Skip to content

Commit

Permalink
Fix lint & stuff.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Jun 23, 2022
1 parent 7de8bce commit 0f273b9
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 161 deletions.
40 changes: 16 additions & 24 deletions python/tvm/auto_scheduler/testing/tune_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import os

from distutils.util import strtobool
import numpy as np # type: ignore
import onnx # type: ignore
import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_per_layer
from tvm.meta_schedule.utils import cpu_count
from tvm.relay.frontend import from_onnx
from tvm.support import describe
Expand Down Expand Up @@ -103,10 +102,10 @@ def _parse_args():
help="example: `True / False",
)
args.add_argument(
"--use-vm",
type=lambda x: bool(strtobool(x)),
"--backend",
type=str,
required=True,
help="example: `True / False",
help="example: graph / vm",
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
Expand Down Expand Up @@ -197,19 +196,21 @@ def main():
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
if ARGS.use_vm:
if ARGS.backend == "vm":
lib = relay.vm.compile(
mod,
target=ARGS.target,
params=params,
)
else:
elif ARGS.backend == "graph":
lib = relay.build(
mod,
target=ARGS.target,
params=params,
)
if not ARGS.use_vm:
else:
raise ValueError(f"Backend {ARGS.backend} not supported!")
if ARGS.backend == "graph":
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

run_module_via_rpc(
Expand All @@ -220,22 +221,13 @@ def main():
continuation=f_per_layer(graph),
)

run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer,
)
else:
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer_vm,
use_vm=True,
)
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer(ARGS.backend),
)


if __name__ == "__main__":
Expand Down
41 changes: 17 additions & 24 deletions python/tvm/auto_scheduler/testing/tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import os

from distutils.util import strtobool
import numpy as np # type: ignore
import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_per_layer
from tvm.meta_schedule.utils import cpu_count
from tvm.support import describe

Expand Down Expand Up @@ -101,10 +100,10 @@ def _parse_args():
help="example: `True / False",
)
args.add_argument(
"--use-vm",
type=lambda x: bool(strtobool(x)),
"--backend",
type=str,
required=True,
help="example: `True / False",
help="example: graph / vm",
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
Expand Down Expand Up @@ -197,19 +196,22 @@ def main():
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
if ARGS.use_vm:
if ARGS.backend == "vm":
lib = relay.vm.compile(
mod,
target=ARGS.target,
params=params,
)
else:
elif ARGS.backend == "graph":
lib = relay.build(
mod,
target=ARGS.target,
params=params,
)
if not ARGS.use_vm:
else:
raise ValueError(f"Backend {ARGS.backend} not supported!")

if ARGS.backend == "graph":
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

run_module_via_rpc(
Expand All @@ -220,22 +222,13 @@ def main():
continuation=f_per_layer(graph),
)

run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer,
)
else:
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer_vm,
use_vm=True,
)
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer(ARGS.backend),
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def update(
_ffi_api.CostModelUpdate(self, context, candidates, results) # type: ignore # pylint: disable=no-member

def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
"""Update the cost model given running results.
"""Predict normalized score with the cost model.
Parameters
----------
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run_module_via_rpc(
dev_type: str,
args: Dict[str, "np.ndarray"],
continuation: Callable,
use_vm: Optional[bool] = False,
backend: Optional[str] = "graph",
):
"""Execute a tvm.runtime.Module on RPC remote"""
# pylint: disable=import-outside-toplevel
Expand All @@ -162,14 +162,14 @@ def run_module_via_rpc(

with tempfile.TemporaryDirectory() as tmp_dir:
filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
if use_vm:
if backend == "vm":
code, lib = lib.save()
lib.export_library(filename, tar)
session = rpc_config.connect_server()
session.upload(filename)
_, filename = os.path.split(filename)
rt_mod = session.load_module(filename)
if use_vm:
if backend == "vm":
rt_mod = session.get_function("runtime.Load_Executable")(code, rt_mod)
dev = session.device(dev_type=dev_type, dev_id=0)
nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
Expand Down
36 changes: 13 additions & 23 deletions python/tvm/meta_schedule/testing/tune_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import logging

from distutils.util import strtobool
import numpy as np # type: ignore
import onnx # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.relay.frontend import from_onnx
from tvm.support import describe
from .utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
from .utils import generate_input_data, f_timer, f_per_layer


def _parse_args():
Expand Down Expand Up @@ -100,10 +99,10 @@ def _parse_args():
help="example: `True / False",
)
args.add_argument(
"--use-vm",
type=lambda x: bool(strtobool(x)),
"--backend",
type=str,
required=True,
help="example: `True / False",
help="example: graph / vm",
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
Expand Down Expand Up @@ -161,12 +160,12 @@ def main():
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
use_vm=ARGS.use_vm,
backend=ARGS.backend,
)
print("Tuning Time:")
print(profiler.table())

if not ARGS.use_vm:
if ARGS.backend == "graph":
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

run_module_via_rpc(
Expand All @@ -177,22 +176,13 @@ def main():
continuation=f_per_layer(graph),
)

run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer,
)
else:
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer_vm,
use_vm=True,
)
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer(ARGS.backend),
)


if __name__ == "__main__":
Expand Down
37 changes: 14 additions & 23 deletions python/tvm/meta_schedule/testing/tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
import logging

from distutils.util import strtobool
import numpy as np # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_per_layer
from tvm.support import describe


Expand Down Expand Up @@ -98,10 +97,11 @@ def _parse_args():
help="example: `True / False",
)
args.add_argument(
"--use-vm",
type=lambda x: bool(strtobool(x)),
"--backend",
type=str,
required=True,
help="example: `True / False",
choices=["graph", "vm"],
help="example: graph / vm",
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
Expand Down Expand Up @@ -161,13 +161,13 @@ def main():
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
use_vm=ARGS.use_vm,
backend=ARGS.backend,
)
print("Tuning Time:")
print(profiler.table())
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

if not ARGS.use_vm:
if ARGS.backend == "graph":
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

run_module_via_rpc(
Expand All @@ -178,22 +178,13 @@ def main():
continuation=f_per_layer(graph),
)

run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer,
)
else:
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer_vm,
use_vm=True,
)
run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_timer(ARGS.backend),
)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 0f273b9

Please sign in to comment.