Skip to content

Commit

Permalink
[Autoscheduler] Configurable workload keys (apache#8862)
Browse files Browse the repository at this point in the history
* change workload keys

* remove binary string comparison

* append the tuple not every integer

* clean up

* lint

* dump workload keys to dags

* fix things

* change some strings

* misc fixes, add tests

* jostle ci
  • Loading branch information
AndrewZhaoLuo authored and ylc committed Jan 13, 2022
1 parent ef90631 commit bcc3469
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
15 changes: 11 additions & 4 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,20 +222,27 @@ def rewrite_layout_from_state(self, state):

def workload_key(self):
"""Return the workload key of this compute DAG.
The workload key is a JSON string from a tuple of (hash-key, tensor shapes...)
The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...)
Returns
-------
key: str
The workload key of this compute DAG
"""
str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
str_dag = str_dag.encode(encoding="utf-8")
hash_key = hashlib.md5(str_dag).hexdigest()
hash_func = tvm._ffi.get_global_func(
"auto_scheduler.compute_dag.hash_func", allow_missing=True
)

if hash_func is None:
str_dag = str_dag.encode("utf-8")
hash_key = hashlib.md5(str_dag).hexdigest()
else:
hash_key = hash_func(str_dag)

io_shapes = []
for tensor in self.tensors:
io_shapes += get_const_tuple(tensor.shape)
io_shapes.append(get_const_tuple(tensor.shape))
return json.dumps([hash_key] + io_shapes)

def __str__(self):
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
2. Provide auto-scheduling for all TOPI compute functions
"""

import json
import logging
import threading
from copy import deepcopy
Expand All @@ -30,11 +31,10 @@
from tvm import autotvm, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object

from tvm.target import Target
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import Reduce
from tvm.tir import expr as _expr
from tvm.target import Target

from . import _ffi_api
from .compute_dag import ComputeDAG, LayoutRewriteOption
Expand Down Expand Up @@ -97,6 +97,7 @@ def extract_tasks(
target_host=None,
hardware_params=None,
include_simple_tasks=False,
dump_workload_to_dag_log=None,
opt_level=3,
):
"""Extract tuning tasks from a relay program.
Expand All @@ -115,6 +116,8 @@ def extract_tasks(
Hardware parameters used for the search tasks
include_simple_tasks: bool
Whether to extract simple tasks that do not include complicated ops.
dump_workload_to_dag_log: Optional[str]
A file to dump an association between the workload keys and the actual DAG
opt_level : Optional[int]
The optimization level of the task extractions.
Expand Down Expand Up @@ -170,6 +173,10 @@ def extract_tasks(
)
weights.append(weight)

if dump_workload_to_dag_log is not None:
with open(dump_workload_to_dag_log, "w") as f:
json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)

return tasks, weights


Expand Down
44 changes: 43 additions & 1 deletion tests/python/relay/test_auto_scheduler_task_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Test task extraction for auto-scheduler"""
import pytest
import json
import tempfile

import pytest
import tvm.relay.testing
import tvm.testing
from tvm import _ffi as _ffi_api
from tvm import auto_scheduler, relay


Expand Down Expand Up @@ -248,5 +251,44 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False)
verify_task_extraction(*params)


def test_dump_workload_to_dag_extract_tasks():
mod, _ = get_network("mobilenet", layout="NHWC")
with tempfile.NamedTemporaryFile() as f:
tasks, _ = auto_scheduler.extract_tasks(
mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=f.name
)
expected = {task.workload_key: str(task.compute_dag) for task in tasks}
actual = json.load(f)
assert expected == actual


def test_custom_hash_func_extract_tasks():
@_ffi_api.register_func("auto_scheduler.compute_dag.hash_func")
def counting_unique_hash(str_dag):
ret = counting_unique_hash.i
counting_unique_hash.i += 1
return ret

counting_unique_hash.i = 0

mod, _ = get_network("mobilenet", layout="NHWC")
tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True)

hash_values = []
for task in tasks:
# task.workload_key should look like
# [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the result of the hash
# Extract the hash and keep track of every hash
hash_value = int(task.workload_key[1:].split(",")[0])
hash_values.append(hash_value)

# All values are unique, and we know the min and max
# This is a sufficient condition to know that hashes in hash_values are an increasing list
# of hashes up to counting_unique_hash.i - 1
assert len(hash_values) == len(set(hash_values))
assert min(hash_values) == 0
assert max(hash_values) == counting_unique_hash.i - 1


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit bcc3469

Please sign in to comment.