[MetaSchedule] Tuning Script Upgrade (apache#11797)
* Support uint8.

* Modify tuning functions.

* Follow legacy setting, use int32 for uint8.

* Add vm support.

* Fix vm usage.

* Use vm in rpc run module.

* Fix lint & stuff.

* Fix backend.

* Fix ftimer.

* Fix lint.

* Limit backend choice.

* Add try catch.

* Display name in rpc try catch.

* Support ahb from tune_relay.

* Modify scripts.

* Fix typo.

* Minor fix.

* Fix try catch & func name.

* Fix utils.

* Move utils to tune_utils.

* Fix tune_utils.
zxybazh authored and masahi committed Jul 15, 2022
1 parent bafebb8 commit 10b4961
Showing 11 changed files with 448 additions and 363 deletions.
150 changes: 59 additions & 91 deletions python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
-from distutils.util import strtobool
 import argparse
 import json
 import os

+from distutils.util import strtobool
-import numpy as np  # type: ignore
 import onnx  # type: ignore

 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
 from tvm.meta_schedule.utils import cpu_count
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
@@ -96,17 +96,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
     )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
+    )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -135,6 +141,7 @@ def main():
         repeat=ARGS.repeat,
         min_repeat_ms=ARGS.min_repeat_ms,
         enable_cpu_cache_flush=ARGS.cpu_flush,
+        timeout=ARGS.rpc_config.session_timeout_sec,
     )

     if ARGS.target.kind.name == "llvm":
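For context on the hunk above: the object being configured is auto_scheduler's RPC measurement runner, and the new line forwards the RPC session timeout so a hung remote measurement cannot outlive its session. A hedged sketch of the enclosing call; the keyword arguments follow the public auto_scheduler.RPCRunner signature, but the exact ARGS field names outside the diffed lines are assumptions:

from tvm import auto_scheduler

runner = auto_scheduler.RPCRunner(
    key=ARGS.rpc_key,    # device key registered with the RPC tracker (assumed flag name)
    host=ARGS.rpc_host,
    port=ARGS.rpc_port,
    repeat=ARGS.repeat,
    min_repeat_ms=ARGS.min_repeat_ms,
    enable_cpu_cache_flush=ARGS.cpu_flush,
    # New in this commit: bound each measurement by the session timeout
    # rather than relying on the runner's short default.
    timeout=ARGS.rpc_config.session_timeout_sec,
)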
@@ -163,102 +170,63 @@
     onnx_model = onnx.load(ARGS.onnx_path)
     shape_dict = {}
     for item in ARGS.input_shape:
-        print(f"  input_name: {item['name']}")
+        print(f"  input_name : {item['name']}")
         print(f"  input_shape: {item['shape']}")
         print(f"  input_dtype: {item['dtype']}")
         shape_dict[item["name"]] = item["shape"]
     mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
-    tasks, task_weights = auto_scheduler.extract_tasks(
-        mod["main"],
-        params,
-        target=ARGS.target,
-        hardware_params=hardware_params,
-    )
-    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
-        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
-        print(task.compute_dag)
-
-    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-    tuner.tune(
-        auto_scheduler.TuningOptions(
-            num_measure_trials=ARGS.num_trials,
-            runner=runner,
-            measure_callbacks=[
-                auto_scheduler.RecordToFile(log_file),
-            ],
-        ),
-        adaptive_training=ARGS.adaptive_training,
-    )
-
-    with auto_scheduler.ApplyHistoryBest(log_file):
-        with tvm.transform.PassContext(
-            opt_level=3,
-            config={"relay.backend.use_auto_scheduler": True},
-        ):
-            lib = relay.build(
-                mod,
-                target=ARGS.target,
-                params=params,
-            )
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    input_data = {}
-    for item in ARGS.input_shape:
-        input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"]
-        if input_dtype.startswith("float"):
-            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
-            )
-
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=lib,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
+    }
+
+    with ms.Profiler() as profiler:
+        tasks, task_weights = auto_scheduler.extract_tasks(
+            mod["main"],
+            params,
+            target=ARGS.target,
+            hardware_params=hardware_params,
+        )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            print(
+                f"==== Task {idx}: {task.desc} "
+                f"(weight {task_weight} key: {task.workload_key}) ====="
+            )
+            print(task.compute_dag)
+
+        if ARGS.num_trials > 0:
+            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+            tuner.tune(
+                auto_scheduler.TuningOptions(
+                    num_measure_trials=ARGS.num_trials,
+                    runner=runner,
+                    measure_callbacks=[
+                        auto_scheduler.RecordToFile(log_file),
+                    ],
+                ),
+                adaptive_training=ARGS.adaptive_training,
+            )
+
+        relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend]
+        with auto_scheduler.ApplyHistoryBest(log_file):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_auto_scheduler": True},
+            ):
+                lib = relay_build(
+                    mod,
+                    target=ARGS.target,
+                    params=params,
+                )
+    print("Tuning Time:")
+    print(profiler.table())

     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
+        lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
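Editor's note: the inline np.random input generation deleted above now lives in tune_utils.generate_input_data, whose source is not part of this diff. A sketch of what such a helper might look like, reconstructed from the replaced inline logic and the "Follow legacy setting, use int32 for uint8" commit note; the integer bounds are illustrative assumptions:

import numpy as np  # type: ignore

def generate_input_data(input_shape, input_dtype):
    """Generate a random input tensor; a sketch, not the tune_utils source."""
    if input_dtype.startswith("float"):
        return np.random.uniform(size=input_shape).astype(input_dtype)
    if input_dtype in ("uint8", "int8"):
        # Legacy setting per the commit message: draw small values but keep
        # the int32 storage type the old inline code produced.
        return np.random.randint(low=0, high=127, size=input_shape, dtype="int32")
    if input_dtype in ("int32", "int64"):
        return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
    raise ValueError(f"Unsupported input dtype: {input_dtype}")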


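Likewise, the deleted f_timer closure is replaced by tune_utils.create_timer(backend), which must return a continuation with the same (rt_mod, dev, input_data) shape but dispatch on the new graph/vm backend choice. A hedged sketch under those assumptions; the vm branch uses the public tvm.runtime.vm.VirtualMachine.benchmark API, which may differ from the actual tune_utils implementation:

import numpy as np  # type: ignore

def create_timer(backend):
    """Return a latency-measuring continuation; a sketch, not the tune_utils source."""

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule
        from tvm.runtime.vm import VirtualMachine

        # pylint: enable=import-outside-toplevel
        if backend == "vm":
            # For the vm backend, rt_mod is the compiled VM Executable.
            vm = VirtualMachine(rt_mod, dev)
            results = vm.benchmark(dev, min_repeat_ms=500, repeat=3, **input_data).results
        else:
            mod = GraphModule(rt_mod["default"](dev))
            for name, value in input_data.items():
                mod.set_input(name, value)
            results = mod.module.time_evaluator("run", dev, min_repeat_ms=500, repeat=3)().results
        print("Running time in time_evaluator:", list(np.array(results) * 1000.0))

    return f_timer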