From 4a9b5b5cf597418c8bdbf2e0fcb8ac8cf24f0d07 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 25 Aug 2021 22:58:40 +0200 Subject: [PATCH 01/42] Update CI Lint Image Version (#8841) * Update CI Lint Image Version * trigger --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b96bdf566f3b..4814dc7bb802 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,7 +44,7 @@ // // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_lint = "tlcpack/ci-lint:v0.66" +ci_lint = "tlcpack/ci-lint:v0.67" ci_gpu = "tlcpack/ci-gpu:v0.76" ci_cpu = "tlcpack/ci-cpu:v0.76" ci_wasm = "tlcpack/ci-wasm:v0.71" From 0648fffc9b6fddd27dc04a91ebac9cccd780b3b3 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Wed, 25 Aug 2021 17:34:10 -0500 Subject: [PATCH 02/42] [BUG] ToBasicBlockNormalForm immutability (#8778) * ToBasicBlockNormalForm immutability * better comment on ToBasicBlock * refine comment of ToBasicBlockForm --- .../transforms/to_basic_block_normal_form.cc | 11 +++++---- .../test_pass_to_basic_block_normal_form.py | 24 ++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 79157bba1918..d03fc1488aea 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -51,8 +51,11 @@ Expr ToBasicBlockNormalFormAux(const Expr& e) { IRModule ToBasicBlockNormalForm(const IRModule& mod) { DLOG(INFO) << "ToBBlock:" << std::endl << mod; + // Create a new module by shallow copy. + auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); + tvm::Map updates; - auto funcs = mod->functions; + auto funcs = mod_->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { @@ -63,12 +66,12 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { } for (auto pair : updates) { - mod->Add(pair.first, pair.second, true); + mod_->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod; + DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_; - return mod; + return mod_; } bool BasicBlockNormalFormCheck(const Expr& e) { diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index 642cab751b79..d04afe15b5bb 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude -from tvm.relay.testing import count +from tvm.relay.testing import count, create_workload from tvm.relay.analysis import Feature from tvm.relay.analysis import check_basic_block_normal_form @@ -491,5 +491,27 @@ def test_higher_order_nested(): check_basic_block_normal_form(bblock) +def test_immutability(): + simple_net = relay.nn.conv2d( + data=relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")), + weight=relay.var("weight"), + kernel_size=(5, 5), + channels=3, + padding=(1, 1), + ) + simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) + mod, _ = create_workload(simple_net) + + old_mod = mod + + with tvm.transform.PassContext(opt_level=4): + with tvm.target.Target("llvm"): + seq = tvm.transform.Sequential(passes=[transform.ToBasicBlockNormalForm()], opt_level=4) + new_mod = seq(mod) + + assert old_mod.astext() == mod.astext() + assert old_mod.astext() != new_mod.astext() + + if __name__ == "__main__": pytest.main([__file__]) From f1ca91d4e401096d04e962c982d62b1f2669c9f5 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 25 Aug 2021 18:25:29 -0700 Subject: [PATCH 03/42] [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm (#8807) * [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm This new benchmarking function is just a convenience function for calling time_evaluator on the underlying module. Hopefully this should make it easier for users to get good benchmarks of their code. * formatting * import order * more test, more comments, more precision * fix tests * add seconds descriptions to doc --- python/tvm/contrib/graph_executor.py | 59 +++++++++++++++ python/tvm/driver/tvmc/model.py | 27 ++----- python/tvm/driver/tvmc/runner.py | 8 +- python/tvm/runtime/module.py | 75 +++++++++++++++++-- python/tvm/runtime/vm.py | 64 ++++++++++++++++ src/runtime/rpc/rpc_module.cc | 5 +- tests/python/driver/tvmc/test_model.py | 3 +- tests/python/driver/tvmc/test_runner.py | 5 +- .../relay/test_backend_graph_executor.py | 26 +++++++ tests/python/relay/test_vm.py | 26 +++++++ tests/python/unittest/test_runtime_measure.py | 11 +++ tutorials/auto_scheduler/tune_network_arm.py | 6 +- tutorials/auto_scheduler/tune_network_cuda.py | 4 +- tutorials/auto_scheduler/tune_network_mali.py | 6 +- tutorials/auto_scheduler/tune_network_x86.py | 4 +- tutorials/autotvm/tune_relay_arm.py | 7 +- tutorials/autotvm/tune_relay_cuda.py | 7 +- tutorials/autotvm/tune_relay_mobile_gpu.py | 7 +- tutorials/autotvm/tune_relay_x86.py | 6 +- tutorials/frontend/deploy_model_on_android.py | 4 +- tutorials/frontend/deploy_prequantized.py | 4 +- .../frontend/deploy_prequantized_tflite.py | 4 +- tutorials/frontend/deploy_sparse.py | 7 +- 23 files changed, 283 insertions(+), 92 deletions(-) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index f9d1b9734d45..2e8ff1d62421 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -320,3 +320,62 @@ def __getitem__(self, key): The key to the module. """ return self.module[key] + + def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=None, **kwargs): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). + + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached before running timing code, so that + data transfer costs are not counted in the runtime. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if kwargs: + self.set_input(**kwargs) + return self.module.time_evaluator( + func_name, device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )() diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index a9516e1e2c42..48bb052124ee 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -46,7 +46,7 @@ import os import tarfile import json -from typing import Optional, Union, List, Dict, Callable, TextIO +from typing import Optional, Union, Dict, Callable, TextIO import numpy as np import tvm @@ -54,6 +54,7 @@ from tvm import relay from tvm.contrib import utils from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule +from tvm.runtime.module import BenchmarkResult try: from tvm.micro import export_model_library_format @@ -371,14 +372,14 @@ def import_package(self, package_path: str): class TVMCResult(object): """A class that stores the results of tvmc.run and provides helper utilities.""" - def __init__(self, outputs: Dict[str, np.ndarray], times: List[float]): + def __init__(self, outputs: Dict[str, np.ndarray], times: BenchmarkResult): """Create a convenience wrapper around the output of tvmc.run Parameters ---------- outputs : dict Outputs dictionary mapping the name of the output to its numpy value. - times : list of float + times : BenchmarkResult The execution times measured by the time evaluator in seconds to produce outputs. """ self.outputs = outputs @@ -390,29 +391,15 @@ def format_times(self): This has the effect of producing a small table that looks like: .. code-block:: Execution time summary: - mean (ms) max (ms) min (ms) std (ms) - 0.14310 0.16161 0.12933 0.01004 + mean (ms) median (ms) max (ms) min (ms) std (ms) + 0.14310 0.14310 0.16161 0.12933 0.01004 Returns ------- str A formatted string containing the statistics. """ - - # timestamps - mean_ts = np.mean(self.times) * 1000 - std_ts = np.std(self.times) * 1000 - max_ts = np.max(self.times) * 1000 - min_ts = np.min(self.times) * 1000 - - header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( - "mean (ms)", "max (ms)", "min (ms)", "std (ms)" - ) - stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format( - mean_ts, max_ts, min_ts, std_ts - ) - - return "%s\n%s\n" % (header, stats) + return str(self.times) def get_output(self, name: str): """A helper function to grab one of the outputs by name. diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 8515bc9b053c..489604d79cf4 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -421,12 +421,8 @@ def run_module( # This print is intentional print(report) - # create the module time evaluator (returns a function) - timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat) - # call the evaluator function to invoke the module and save execution times - prof_result = timer() - # collect a list of execution times from the profiling results - times = prof_result.results + # call the benchmarking function of the executor + times = module.benchmark(dev, number=number, repeat=repeat) logger.debug("Collecting the output tensors.") num_outputs = module.get_num_outputs() diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 8107ab5b87d2..25a57bbb1c36 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -20,7 +20,8 @@ import os import ctypes import struct -from collections import namedtuple +from typing import Sequence +import numpy as np import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY @@ -30,8 +31,69 @@ from . import _ffi_api -# profile result of time evaluator -ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) +class BenchmarkResult: + """Runtimes from benchmarking""" + + def __init__(self, results: Sequence[float]): + """Construct a new BenchmarkResult from a sequence of runtimes. + + Parameters + ---------- + results : Sequence[float] + Raw times from benchmarking + + Attributes + ---------- + min : float + Minimum runtime in seconds of all results. + mean : float + Mean runtime in seconds of all results. If py:meth:`Module.time_evaluator` or + `benchmark` is called with `number` > 0, then each result is already the mean of a + `number` of runtimes, so this becomes the mean of means. + median : float + Median runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the median of means. + max : float + Maximum runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the maximum of those means. + std : float + Standard deviation in seconds of runtimes. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the standard deviation of means. + results : Sequence[float] + The collected runtimes (in seconds). This may be a series of mean runtimes if + py:meth:`Module.time_evaluator` or `benchmark` was run with `number` > 1. + """ + self.results = results + self.mean = np.mean(self.results) + self.std = np.std(self.results) + self.median = np.median(self.results) + self.min = np.min(self.results) + self.max = np.max(self.results) + + def __repr__(self): + return "BenchmarkResult(min={}, mean={}, median={}, max={}, std={}, results={})".format( + self.min, self.mean, self.median, self.max, self.std, self.results + ) + + def __str__(self): + return """Execution time summary: +{:^12} {:^12} {:^12} {:^12} {:^12} +{:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} + """.format( + "mean (ms)", + "median (ms)", + "max (ms)", + "min (ms)", + "std (ms)", + self.mean * 1000, + self.median * 1000, + self.max * 1000, + self.min * 1000, + self.std * 1000, + ) class Module(object): @@ -209,7 +271,7 @@ def time_evaluator(self, func_name, dev, number=10, repeat=1, min_repeat_ms=0, f Returns ------- ftimer : function - The function that takes same argument as func and returns a ProfileResult. + The function that takes same argument as func and returns a BenchmarkResult. The ProfileResult reports `repeat` time costs in seconds. """ try: @@ -230,12 +292,11 @@ def evaluator(*args): blob = feval(*args) fmt = "@" + ("d" * repeat) results = struct.unpack(fmt, blob) - mean = sum(results) / float(repeat) - return ProfileResult(mean=mean, results=results) + return BenchmarkResult(results) return evaluator except NameError: - raise NameError("time_evaluate is only supported when RPC is enabled") + raise NameError("time_evaluator is only supported when RPC is enabled") def _collect_from_import_tree(self, filter_func): """Helper function to collect modules from the tree matching a filter_func, then return it. diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 2f133e1a422d..aeb651cb5ae4 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -507,3 +507,67 @@ def get_input_index(self, input_name, func_name="main"): The input index. -1 will be returned if the given input name is not found. """ return self._get_input_index(input_name, func_name) + + def benchmark( + self, device, *args, func_name="main", repeat=5, number=5, min_repeat_ms=None, **kwargs + ): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). + + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + args : Sequence[Object] + Arguments to the function. These are cached before running timing code, so that data + transfer costs are not counted in the runtime. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached like `args`. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if args or kwargs: + self.set_input(func_name, *args, **kwargs) + return self.module.time_evaluator( + "invoke", device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )(func_name) diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 7272269680c5..b9ed54e73508 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -417,8 +417,9 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - return WrapTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms, - f_preproc); + PackedFunc pf = m.GetFunction(name, false); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry"; + return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc); } } else { auto* pf = runtime::Registry::Get(name); diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index f5a28d419cbb..fd2637a85f1f 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -21,6 +21,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult +from tvm.runtime.module import BenchmarkResult def test_tvmc_workflow(keras_simple): @@ -35,7 +36,7 @@ def test_tvmc_workflow(keras_simple): assert type(result) is TVMCResult assert path.exists(tuning_records) assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 7acb376baba6..2ce363ab5911 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -20,6 +20,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCResult from tvm.driver.tvmc.result_utils import get_top_results +from tvm.runtime.module import BenchmarkResult def test_generate_tensor_data_zeros(): @@ -52,7 +53,7 @@ def test_generate_tensor_data__type_unknown(): def test_format_times__contains_header(): - fake_result = TVMCResult(outputs=None, times=[0.6, 1.2, 0.12, 0.42]) + fake_result = TVMCResult(outputs=None, times=BenchmarkResult([0.6, 1.2, 0.12, 0.42])) sut = fake_result.format_times() assert "std (ms)" in sut @@ -101,5 +102,5 @@ def test_run_tflite_module__with_profile__valid_input( tiger_cat_mobilenet_id in top_5_ids ), "tiger cat is expected in the top-5 for mobilenet v1" assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index c6f2748e9ec8..9e212527838e 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -16,6 +16,7 @@ # under the License. import numpy as np import pytest +from unittest.mock import patch import tvm import json @@ -23,6 +24,7 @@ from tvm.contrib import graph_executor from tvm.relay.op import add import tvm.testing +from tvm.relay.testing import mlp # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -322,5 +324,29 @@ def test_graph_executor_api(): assert mod.get_input_index("Invalid") == -1 +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target="llvm", params=params) + exe = graph_executor.create(lib.get_graph_json(), lib.lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 7ae7e0eabeee..c7043481ee3d 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -17,6 +17,7 @@ import numpy as np import pytest import time +from unittest.mock import patch import tvm from tvm import runtime @@ -30,6 +31,7 @@ from tvm import rpc import tvm.testing from tvm.relay.transform import InferType +from tvm.relay.testing import mlp def check_result(args, expected_result, mod=None): @@ -955,5 +957,29 @@ def test_get_input_index(): assert vm_factory.get_input_index("invalid") == -1 +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target="llvm", params=params) + exe = runtime.vm.VirtualMachine(lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda x: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_runtime_measure.py b/tests/python/unittest/test_runtime_measure.py index 0d02f910a44c..8955b03241a2 100644 --- a/tests/python/unittest/test_runtime_measure.py +++ b/tests/python/unittest/test_runtime_measure.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm.contrib.utils import tempdir +from tvm.runtime.module import BenchmarkResult def test_min_repeat_ms(): @@ -56,5 +57,15 @@ def my_debug(filename): assert ct > 10 + 2 +def test_benchmark_result(): + r = BenchmarkResult([1, 2, 2, 5]) + assert r.mean == 2.5 + assert r.median == 2.0 + assert r.min == 1 + assert r.max == 5 + assert r.std == 1.5 + + if __name__ == "__main__": test_min_repeat_ms() + test_benchmark_result() diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 5b0931405212..1619a55dc7e9 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -349,11 +349,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since the server doesn't have a Raspberry Pi, diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py index 7b5619c671be..08c15264e3c1 100644 --- a/tutorials/auto_scheduler/tune_network_cuda.py +++ b/tutorials/auto_scheduler/tune_network_cuda.py @@ -288,9 +288,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/auto_scheduler/tune_network_mali.py b/tutorials/auto_scheduler/tune_network_mali.py index 8275f96806b8..2d1e51520952 100644 --- a/tutorials/auto_scheduler/tune_network_mali.py +++ b/tutorials/auto_scheduler/tune_network_mali.py @@ -264,11 +264,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since server doesn't have mali gpu. diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 76068fa79605..6cb8d6f14cb9 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -322,9 +322,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index debf8b8ecf60..f072c5ddac93 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -359,12 +359,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=10)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 65991cc83454..b2af2e13f4fe 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -244,12 +244,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=600) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=600)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index 790c2ff2c2b9..d3f4ec62fafc 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -352,12 +352,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=30) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=30)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index 6b497ae9c0bd..771220bb3314 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -203,11 +203,7 @@ def evaluate_performance(lib, data_shape): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=100, repeat=3) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=100, repeat=3)) def tune_and_evaluate(tuning_opt): diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index f435befb8250..c7b610d5d503 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -332,9 +332,7 @@ def transform_image(image): print("TVM prediction top-1: {}".format(synset[top1])) print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) -prof_res = np.array(ftimer().results) * 1000 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, number=1, repeat=10)) ###################################################################### # Sample Output diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index a59655222278..11a9e3e3eee8 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -199,9 +199,7 @@ def quantize_model(model, inp): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. note:: diff --git a/tutorials/frontend/deploy_prequantized_tflite.py b/tutorials/frontend/deploy_prequantized_tflite.py index e3934e9b250f..7bbb06bdf801 100644 --- a/tutorials/frontend/deploy_prequantized_tflite.py +++ b/tutorials/frontend/deploy_prequantized_tflite.py @@ -232,9 +232,7 @@ def run_tvm(lib): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. note:: diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py index f0af12b709e2..768a697f45cf 100644 --- a/tutorials/frontend/deploy_sparse.py +++ b/tutorials/frontend/deploy_sparse.py @@ -233,12 +233,7 @@ def run_relay_graph(mod, params, shape_dict, target, dev): m.run() tvm_output = m.get_output(0) - ftimer = m.module.time_evaluator("run", dev, repeat=5, number=5) - prof_res = np.array(ftimer().results) * 1000 - print( - "%-20s %-19s (%s)" - % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) - ) + print(m.benchmark(dev, repeat=5, number=5)) return tvm_output From d80528db0becfc471acd1e7cda122f8283117627 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 26 Aug 2021 09:54:01 +0100 Subject: [PATCH 04/42] Apply CPPLint to CRT Tests (#8844) This one was a bit trickier as there was more usage of dynamic arrays and less safe casts. I've tried to minimise the changes to just those required to passing linting. --- tests/crt/buffer_write_stream.h | 4 +++- tests/crt/framing_test.cc | 18 +++++++++------- tests/crt/func_registry_test.cc | 35 +++++++++++++++++-------------- tests/crt/page_allocator_test.cc | 5 +++-- tests/crt/session_test.cc | 4 ++-- tests/crt/stack_allocator_test.cc | 2 +- tests/lint/cpplint.sh | 2 +- 7 files changed, 39 insertions(+), 31 deletions(-) diff --git a/tests/crt/buffer_write_stream.h b/tests/crt/buffer_write_stream.h index 66ef044e6ba1..48a30ac4b273 100644 --- a/tests/crt/buffer_write_stream.h +++ b/tests/crt/buffer_write_stream.h @@ -24,6 +24,8 @@ #include #include +#include + using ::tvm::runtime::micro_rpc::FrameBuffer; using ::tvm::runtime::micro_rpc::WriteStream; @@ -51,7 +53,7 @@ class BufferWriteStream : public WriteStream { std::string BufferContents() { return std::string((const char*)buffer_data_, buffer_.Size()); } - static constexpr unsigned int capacity() { return N; }; + static constexpr unsigned int capacity() { return N; } private: bool packet_done_{false}; diff --git a/tests/crt/framing_test.cc b/tests/crt/framing_test.cc index c6dd2c098dd0..e257dfc641ab 100644 --- a/tests/crt/framing_test.cc +++ b/tests/crt/framing_test.cc @@ -150,23 +150,25 @@ TEST_F(UnframerTest, PacketTooLong) { unframer_.Write(packet_length_bytes, sizeof(packet_length), &bytes_consumed)); EXPECT_EQ(sizeof(packet_length), bytes_consumed); - uint8_t long_payload[decltype(write_stream_)::capacity() + 1]; - for (size_t i = 0; i < sizeof(long_payload); i++) { + unsigned int long_payload_len = decltype(write_stream_)::capacity() + 1; + auto long_payload = std::make_unique(long_payload_len); + for (size_t i = 0; i < long_payload_len; i++) { long_payload[i] = i & 0xff; if (long_payload[i] == uint8_t(Escape::kEscapeStart)) { long_payload[i] = 0; } } - crc = tvm::runtime::micro_rpc::crc16_compute(long_payload, sizeof(long_payload), &crc); + crc = tvm::runtime::micro_rpc::crc16_compute(long_payload.get(), long_payload_len, &crc); EXPECT_EQ(kTvmErrorWriteStreamShortWrite, - unframer_.Write(long_payload, sizeof(long_payload), &bytes_consumed)); + unframer_.Write(long_payload.get(), long_payload_len, &bytes_consumed)); EXPECT_EQ(write_stream_.capacity(), bytes_consumed); - EXPECT_EQ(kTvmErrorNoError, unframer_.Write((uint8_t*)&crc, sizeof(crc), &bytes_consumed)); + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(reinterpret_cast(&crc), sizeof(crc), &bytes_consumed)); EXPECT_EQ(2UL, bytes_consumed); // 2, because framer is now in kFindPacketStart. EXPECT_FALSE(write_stream_.packet_done()); EXPECT_FALSE(write_stream_.is_valid()); - EXPECT_EQ(std::string((char*)long_payload, write_stream_.capacity()), + EXPECT_EQ(std::string(reinterpret_cast(long_payload.get()), write_stream_.capacity()), write_stream_.BufferContents()); // Writing a smaller packet directly afterward should work. @@ -177,7 +179,7 @@ TEST_F(UnframerTest, PacketTooLong) { EXPECT_TRUE(write_stream_.packet_done()); EXPECT_TRUE(write_stream_.is_valid()); EXPECT_EQ(kPacket1.payload, write_stream_.BufferContents()); -}; +} class UnframerTestParameterized : public UnframerTest, public ::testing::WithParamInterface {}; @@ -297,4 +299,4 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { #pragma GCC diagnostic ignored "-Wdeprecated-declarations" INSTANTIATE_TEST_CASE_P(UnframerTests, UnframerTestParameterized, ::testing::ValuesIn(TestPacket::instances)); -#pragma GCC diagnostic pop \ No newline at end of file +#pragma GCC diagnostic pop diff --git a/tests/crt/func_registry_test.cc b/tests/crt/func_registry_test.cc index 7a40c27f6765..9f0e7f8d1a5a 100644 --- a/tests/crt/func_registry_test.cc +++ b/tests/crt/func_registry_test.cc @@ -178,22 +178,23 @@ TEST(MutableFuncRegistry, Create) { for (unsigned int rem = 0; rem < kTvmAverageFuncEntrySizeBytes; rem++) { // test_function name will be used to test overfilling. - char test_function_name[kTvmAverageFunctionNameStrlenBytes + 2 + rem]; + auto test_function_name = + std::make_unique(kTvmAverageFunctionNameStrlenBytes + 2 + rem); TVMMutableFuncRegistry reg; memset(mem_buffer, 0, sizeof(mem_buffer)); EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Create( ®, mem_buffer, kTvmAverageFuncEntrySizeBytes * 2 + rem)); - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes + 1, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes + 1, function_name_chars); // Add function #1, and verify it can be retrieved. - EXPECT_EQ(kTvmErrorNoError, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x01), 0)); + EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Set(®, test_function_name.get(), + TestFunctionHandle(0x01), 0)); tvm_function_index_t func_index = 100; EXPECT_EQ(kTvmErrorNoError, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); EXPECT_EQ(func_index, 0); TVMBackendPackedCFunc func = NULL; @@ -201,22 +202,23 @@ TEST(MutableFuncRegistry, Create) { EXPECT_EQ(func, TestFunctionHandle(0x01)); // Ensure that overfilling `names` by 1 char is not allowed. - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes + rem + 2, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes + rem + 2, function_name_chars + 1); - EXPECT_EQ(kTvmErrorFunctionRegistryFull, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x02), 0)); + EXPECT_EQ( + kTvmErrorFunctionRegistryFull, + TVMMutableFuncRegistry_Set(®, test_function_name.get(), TestFunctionHandle(0x02), 0)); EXPECT_EQ(kTvmErrorFunctionNameNotFound, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); // Add function #2, with intentionally short (by 2 char) name. Verify it can be retrieved. - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes - 2 + 1, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes - 2 + 1, function_name_chars + 1); - EXPECT_EQ(kTvmErrorNoError, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x02), 0)); + EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Set(®, test_function_name.get(), + TestFunctionHandle(0x02), 0)); EXPECT_EQ(kTvmErrorNoError, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); EXPECT_EQ(func_index, 1); func = NULL; @@ -226,7 +228,8 @@ TEST(MutableFuncRegistry, Create) { // Try adding another function, which should fail due to lack of function pointers. test_function_name[0] = 'a'; test_function_name[1] = 0; - EXPECT_EQ(kTvmErrorFunctionRegistryFull, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x03), 0)); + EXPECT_EQ( + kTvmErrorFunctionRegistryFull, + TVMMutableFuncRegistry_Set(®, test_function_name.get(), TestFunctionHandle(0x03), 0)); } -} \ No newline at end of file +} diff --git a/tests/crt/page_allocator_test.cc b/tests/crt/page_allocator_test.cc index 3963569c5972..924bf295ffd2 100644 --- a/tests/crt/page_allocator_test.cc +++ b/tests/crt/page_allocator_test.cc @@ -36,9 +36,10 @@ class PageAllocatorTest : public ::testing::Test { protected: void SetUp() override { memset(raw_memory_pool, 0, sizeof(raw_memory_pool)); - memory_pool = (uint8_t*)(ROUND_UP(((uintptr_t)raw_memory_pool), (1 << kPageSizeBytesLog))); + memory_pool = reinterpret_cast( + ROUND_UP(((uintptr_t)raw_memory_pool), (1 << kPageSizeBytesLog))); PageMemoryManagerCreate(&interface, memory_pool, kMemoryPoolSizeBytes, kPageSizeBytesLog); - mgr = (MemoryManager*)interface; + mgr = reinterpret_cast(interface); ASSERT_EQ(kNumUsablePages, mgr->ptable.max_pages); dev_ = {kDLCPU, 0}; } diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc index 48d7475334c2..b6b58e819700 100644 --- a/tests/crt/session_test.cc +++ b/tests/crt/session_test.cc @@ -51,7 +51,7 @@ class ReceivedMessage { class TestSession { public: - TestSession(uint8_t initial_nonce) + explicit TestSession(uint8_t initial_nonce) : framer{&framer_write_stream}, receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)}, sess{&framer, &receive_buffer, TestSessionMessageReceivedThunk, this}, @@ -247,4 +247,4 @@ TEST_F(SessionTest, DoubleStart) { bob_.ClearBuffers(); alice_.WriteTo(&bob_); EXPECT_TRUE(bob_.sess.IsEstablished()); -} \ No newline at end of file +} diff --git a/tests/crt/stack_allocator_test.cc b/tests/crt/stack_allocator_test.cc index 0eae62e92aea..cd0c4a8b65e2 100644 --- a/tests/crt/stack_allocator_test.cc +++ b/tests/crt/stack_allocator_test.cc @@ -198,4 +198,4 @@ TEST(StackAllocatorTest, InitialMemoryMisAlignment) { ASSERT_EQ(tvm_runtime_workspace.next_alloc, &model_memory_ptr[alignment_offset]); ASSERT_EQ(tvm_runtime_workspace.workspace_size, sizeof(model_memory) - offset - alignment_offset); -} \ No newline at end of file +} diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh index fccd0a23bd63..31eb1d94a347 100755 --- a/tests/lint/cpplint.sh +++ b/tests/lint/cpplint.sh @@ -21,4 +21,4 @@ python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp \ include src \ examples/extension/src examples/graph_executor/src \ - tests/cpp + tests/cpp tests/crt From 46f8b61bd3f9f3351104a0bb9934afe3bafa2c28 Mon Sep 17 00:00:00 2001 From: Anastasia Stulova <38433336+AnastasiaStulova@users.noreply.github.com> Date: Thu, 26 Aug 2021 10:03:06 +0100 Subject: [PATCH 05/42] [Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost. (#8584) * [Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost. Added initial tunable autotvm templates for depthwise conv2d with NHWC layout for Mali and Bifrost. * [Relay][TOPI] Misc fixes for depthwise conv2d Mali/Bifrost. - Fix assert for Bifrost. - Set reasonable default axis splits to avoid using tophub for NHWC. - Fixed typo: arm cpu -> Mali. * [Relay][TOPI] Fixed formatting in depthwise conv2d Mali/Bifrost. --- python/tvm/relay/op/strategy/bifrost.py | 8 + python/tvm/relay/op/strategy/mali.py | 17 +- python/tvm/topi/mali/depthwise_conv2d.py | 200 ++++++++++++------ .../topi/python/test_topi_depthwise_conv2d.py | 2 + 4 files changed, 156 insertions(+), 71 deletions(-) diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index 8008391fe86c..ec3edab2c8b1 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -83,6 +83,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.bifrost", ) + elif layout == "NHWC": + assert kernel_layout == "HWOI" + # For now just reuse general Mali strategy. + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nchw.bifrost", + ) else: raise RuntimeError( "Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".format(layout) diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index d38fe0d82758..e5f4b4e58562 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -120,14 +120,17 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWOI" if not is_auto_scheduler_enabled(): - raise RuntimeError( - "depthwise_conv2d NHWC layout is not enabled for mali without auto_scheduler." + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.mali", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + naive_schedule, + name="depthwise_conv2d_nhwc.mali", ) - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - naive_schedule, - name="depthwise_conv2d_nhwc.mali", - ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout)) else: # group_conv2d diff --git a/python/tvm/topi/mali/depthwise_conv2d.py b/python/tvm/topi/mali/depthwise_conv2d.py index b292f694b995..98109ab4535f 100644 --- a/python/tvm/topi/mali/depthwise_conv2d.py +++ b/python/tvm/topi/mali/depthwise_conv2d.py @@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) -# register customized schedule for arm cpu. +# register customized schedule for Mali. @autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali") def schedule_depthwise_conv2d_nchw(cfg, outs): """Schedule depthwise conv2d @@ -51,86 +51,158 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - def _schedule(pad_data, kernel, conv): - """schedule depthwise_conv2d""" - max_unroll = 16 - vec_size = [1, 2, 4, 8, 16] + def _callback(op): + """traverse to find op to schedule""" + # schedule depthwise_conv2d + if op.tag == "depthwise_conv2d_nchw": + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + _schedule(cfg, s, pad_data, kernel, conv, "NCHW") - ##### space definition begin ##### - n, c, y, x = s[conv].op.axis - bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3) - by, ty, yi = cfg.define_split("tile_y", y, num_outputs=3) - bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3) - cfg.define_annotate("ann_spatial", [ci, yi, xi], policy="try_unroll_vec") + traverse_inline(s, outs[0].op, _callback) + return s - # fallback support - if cfg.is_fallback: - ref_log = autotvm.tophub.load_reference_log( - "mali", "rk3399", "depthwise_conv2d_nchw.mali" - ) - cfg.fallback_with_reference_log(ref_log) - ###### space definition end ###### - # schedule padding - n, c, y, x = s[pad_data].op.axis - tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1) +# register original implementation of depthwise_conv2d_nhwc since we don't need to change this part +@autotvm.register_topi_compute("depthwise_conv2d_nhwc.mali") +def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): + return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) - # schedule dilation - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - # schedule conv - if conv.op not in s.outputs: - s[conv].set_scope("local") - OL = conv - output = s.outputs[0].output(0) - else: - OL = s.cache_write(conv, "local") - output = conv - - n, c, y, x = s[output].op.axis - bc, tc, ci = cfg["tile_c"].apply(s, output, c) - by, ty, yi = cfg["tile_y"].apply(s, output, y) - bx, tx, xi = cfg["tile_x"].apply(s, output, x) - - bc = s[output].fuse(n, bc) - s[output].bind(bc, te.thread_axis("blockIdx.z")) - s[output].bind(tc, te.thread_axis("threadIdx.z")) - s[output].bind(by, te.thread_axis("blockIdx.y")) - s[output].bind(ty, te.thread_axis("threadIdx.y")) - s[output].bind(bx, te.thread_axis("blockIdx.x")) - s[output].bind(tx, te.thread_axis("threadIdx.x")) - - di, dj = s[OL].op.reduce_axis - s[OL].unroll(di) - s[OL].unroll(dj) - - s[OL].compute_at(s[output], tx) - n, ci, yi, xi = s[OL].op.axis - - cfg["ann_spatial"].apply( - s, - OL, - [ci, yi, xi], - axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]], - max_unroll=max_unroll, - vec_size=vec_size, - cfg=cfg, - ) +# register customized schedule for Mali. +@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali") +def schedule_depthwise_conv2d_nhwc(cfg, outs): + """Schedule depthwise conv2d + + Parameters + ---------- + cfg: ConfigEntity + The configuration of this template + outs: Array of Tensor + The computation graph description of depthwise convolution2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nchw. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) def _callback(op): """traverse to find op to schedule""" # schedule depthwise_conv2d - if op.tag == "depthwise_conv2d_nchw": + if op.tag == "depthwise_conv2d_nhwc": pad_data = op.input_tensors[0] kernel = op.input_tensors[1] conv = op.output(0) - _schedule(pad_data, kernel, conv) + _schedule(cfg, s, pad_data, kernel, conv, "NHWC") traverse_inline(s, outs[0].op, _callback) return s +def _schedule(cfg, s, pad_data, kernel, conv, layout): + """schedule depthwise_conv2d""" + assert layout in ("NCHW", "NHWC") + + max_unroll = 16 + vec_size = [1, 2, 4, 8, 16] + + ##### space definition begin ##### + if layout == "NCHW": + n, c, h, w = s[conv].op.axis + else: + n, h, w, c = s[conv].op.axis + + bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3) + bh, th, hi = cfg.define_split("tile_y", h, num_outputs=3) + bw, tw, wi = cfg.define_split("tile_x", w, num_outputs=3) + cfg.define_annotate("ann_spatial", [ci, hi, wi], policy="try_unroll_vec") + + # fallback support + if cfg.is_fallback: + if layout == "NCHW": + ref_log = autotvm.tophub.load_reference_log( + "mali", "rk3399", "depthwise_conv2d_nchw.mali" + ) + cfg.fallback_with_reference_log(ref_log) + else: + cfg.fallback_split("tile_c", [-1, 4, 2]) + cfg.fallback_split("tile_y", [-1, 4, 2]) + cfg.fallback_split("tile_x", [-1, 4, 2]) + ###### space definition end ###### + + # schedule padding + if layout == "NCHW": + n, c, h, w = s[pad_data].op.axis + z, y, x = c, h, w + z_factor, y_factor, x_factor = cfg["tile_c"].size[1], 1, 1 + else: + n, h, w, c = s[pad_data].op.axis + z, y, x = h, w, c + z_factor, y_factor, x_factor = 1, 1, cfg["tile_c"].size[1] + tile_and_bind3d(s, pad_data, z, y, x, z_factor, y_factor, x_factor) + + # schedule dilation + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + # schedule conv + if conv.op not in s.outputs: + s[conv].set_scope("local") + OL = conv + output = s.outputs[0].output(0) + else: + OL = s.cache_write(conv, "local") + output = conv + + if layout == "NCHW": + n, c, h, w = s[output].op.axis + else: + n, h, w, c = s[output].op.axis + + bc, tc, ci = cfg["tile_c"].apply(s, output, c) + bh, th, hi = cfg["tile_y"].apply(s, output, h) + bw, tw, wi = cfg["tile_x"].apply(s, output, w) + + if layout == "NCHW": + bz, tz, by, ty, bx, tx = bc, tc, bh, th, bw, tw + else: + bz, tz, by, ty, bx, tx = bh, th, bw, tw, bc, tc + + bz = s[output].fuse(n, bz) + s[output].bind(bz, te.thread_axis("blockIdx.z")) + s[output].bind(tz, te.thread_axis("threadIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + + di, dj = s[OL].op.reduce_axis + s[OL].unroll(di) + s[OL].unroll(dj) + + s[OL].compute_at(s[output], tx) + + if layout == "NCHW": + n, ci, hi, wi = s[OL].op.axis + else: + n, hi, wi, ci = s[OL].op.axis + + cfg["ann_spatial"].apply( + s, + OL, + [ci, hi, wi], + axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]], + max_unroll=max_unroll, + vec_size=vec_size, + cfg=cfg, + ) + + def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): """tile and bind 3d""" y_factor = y_factor or z_factor diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 5952e624b708..27601cd32b89 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -61,6 +61,8 @@ ) ], "gpu": [(topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc)], + "mali": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)], + "bifrost": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)], }, "NCHWc": { "generic": [(topi.x86.depthwise_conv2d_NCHWc, topi.x86.schedule_depthwise_conv2d_NCHWc)], From 349157641b17882fcf944409fba79c7300978a77 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Thu, 26 Aug 2021 11:16:45 +0100 Subject: [PATCH 06/42] Support for CMSIS-NN in Corstone300 Makefile (#8831) Change-Id: Ifc2305db4e11d1d15d45407287f8f0bea469100a --- tests/python/relay/aot/corstone300.mk | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk index bca5dd266491..3a946f2cd876 100644 --- a/tests/python/relay/aot/corstone300.mk +++ b/tests/python/relay/aot/corstone300.mk @@ -42,6 +42,8 @@ PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I${PLATFORM_PATH} \ -I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \ -I${CMSIS_PATH}/CMSIS/Core/Include \ + -I${CMSIS_PATH}/CMSIS/NN/Include \ + -I${CMSIS_PATH}/CMSIS/DSP/Include \ -isystem$(STANDALONE_CRT_DIR)/include \ PKG_LDFLAGS = -lm -specs=nosys.specs -static -T ${AOT_TEST_ROOT}/corstone300.ld @@ -56,6 +58,7 @@ CRT_SRCS = $(shell find $(CRT_ROOT)) CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c)) CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS)) CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c) +CMSIS_NN_SRCS = $(shell find ${CMSIS_PATH}/CMSIS/NN/Source/*/*.c) UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) aot_test_runner: $(build_dir)/aot_test_runner @@ -79,13 +82,19 @@ ${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS) $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_startup.a) $(abspath $(build_dir))/libcmsis_startup/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_startup.a) +${build_dir}/libcmsis_nn.a: $(CMSIS_NN_SRCS) + $(QUIET)mkdir -p $(abspath $(build_dir)/libcmsis_nn) + $(QUIET)cd $(abspath $(build_dir)/libcmsis_nn) && $(CC) -c $(PKG_CFLAGS) -D${ARM_CPU} $^ + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_nn.a) $(abspath $(build_dir))/libcmsis_nn/*.o + $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_nn.a) + ${build_dir}/libuart.a: $(UART_SRCS) $(QUIET)mkdir -p $(abspath $(build_dir)/libuart) $(QUIET)cd $(abspath $(build_dir)/libuart) && $(CC) -c $(PKG_CFLAGS) $^ $(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a) -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(PKG_CFLAGS) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) From 98a3476bfc7428f592ad0fd6b8c863b5fd5ec1f9 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 26 Aug 2021 13:22:43 +0200 Subject: [PATCH 07/42] [microtvm][Zephyr] Increase timeout to fix flaky tests (#8846) * increase timeout * trigger --- apps/microtvm/zephyr/template_project/microtvm_api_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index c51389acd90c..f267648a83f9 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -633,8 +633,8 @@ def open(self): return server.TransportTimeouts( session_start_retry_timeout_sec=2.0, - session_start_timeout_sec=5.0, - session_established_timeout_sec=5.0, + session_start_timeout_sec=10.0, + session_established_timeout_sec=10.0, ) def close(self): From bca57cb1e74fe946c2db3d24fe5042b74da9fea7 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Thu, 26 Aug 2021 04:23:28 -0700 Subject: [PATCH 08/42] [AMP] Bump up tolerance on flaky test (#8850) * bumpy up tol * bumped tolerance up even more * jostle ci --- tests/python/relay/test_to_mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 5ab2eb346d8b..99078b7371ba 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -373,7 +373,7 @@ def test_let_statement_simple(): "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.05, rtol=0.15) # Construct expected structure var1 = relay.var("var1", shape=[1, 20], dtype="float16") From 3f777d555f1b1a125b0f7f83291d1d8693ffa6be Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 26 Aug 2021 09:08:17 -0500 Subject: [PATCH 09/42] [Hexagon] Rework tvm.target.hexagon() interface (#8823) * [Hexagon] Rework tvm.target.hexagon() interface Make the tvm.target.hexagon() function take most options as keyword parameters. This will allow adding additional parameters without changing the interface. No changes are required to existing code, except for changing positional parameters following the CPU version to keyword parameters, and updating the names of the keyword parameters: sim_args -> sim_options, llvm_args -> llvm_options, although the old names will be accepted for the time being. * formatting * change ' to " * Rename 'args' to 'config' for clarity * Use 'strip' instad of 'replace' * Restart build --- python/tvm/target/target.py | 137 +++++++++++++++++++++++------------- 1 file changed, 90 insertions(+), 47 deletions(-) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 559e7b3b28d3..aa9226101b52 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -413,79 +413,118 @@ def bifrost(model="unknown", options=None): return Target(" ".join(["opencl"] + opts)) -def hexagon(cpu_ver="v66", sim_args=None, llvm_args=None, hvx=128): +def hexagon(cpu_ver="v66", **kwargs): """Returns a Hexagon target. Parameters ---------- - cpu_ver : str + cpu_ver : str (default: "v66") CPU version used for code generation. Not all allowed cpu str will be valid, LLVM will throw an error. - sim_args : str or list of str + + Recognized keyword parameters + ----------------------------- + hvx : int (default: 128) + Size of HVX vector in bytes. Value of 0 disables HVX codegen. + sim_options : str or list of str (default: None) User defined sim arguments. CPU version defaults to cpu_ver. Otherwise, separate versions are used for codegen and sim. Not all allowed cpu strings will be valid, simulator will throw an error if invalid. Does not affect codegen. - llvm_args : str or list of str + llvm_options : str or list of str (default: None) User defined compiler arguments. - hvx : int - Size of hvx register. Value of 0 indicates disabled hvx. """ + + # Some of the target parameters correspond to target kind attributes + # listed in src/target/target_kind.cc. For those parameters, their + # names follow the attribute names with the exception of '_' being used + # in place of '-'. + # Example compiler arguments # llvm -mtriple=hexagon -mcpu=hexagonv66 -mattr=+hvxv66,+hvx-length128b # Check for valid codegen cpu - valid_hex = ["v60", "v62", "v65", "v66", "v67", "v67t"] + valid_hex = ["v60", "v62", "v65", "v66", "v67", "v67t", "v68"] try: cpu_ver = cpu_ver[cpu_ver.index("v") :].lower() - assert 3 <= len(cpu_ver) <= 4 + assert cpu_ver in valid_hex except: msg = "{} is not a valid Hexagon version\nvalid versions include {}" raise ValueError(msg.format(cpu_ver, valid_hex)) from None - assert hvx in [0, 64, 128] + # Target configuration: + config = { + "hvx": 128, + "sim_options": None, + "llvm_options": None, + } + config.update(kwargs) + + # Warn about obsolete parameter names. + if config.get("sim_args"): + msg = "The keyword parameter 'sim_args' is deprecated, use 'sim_options' instead" + warnings.warn(msg, stacklevel=2) + config.update({"sim_options": config["sim_args"]}) + if config.get("llvm_args"): + msg = "The keyword parameter 'llvm_args' is deprecated, use 'llvm_options' instead" + warnings.warn(msg, stacklevel=2) + config.update({"llvm_options": config["llvm_args"]}) + + # LLVM target string + def create_llvm_target(cpu_ver, config): + """ Create LLVM target string. """ - # Target string - def create_target(cpu_ver): target = " -mtriple=hexagon" mcpu = " -mcpu=hexagon" + cpu_ver - mattr = "" - # HVX enable - if hvx: - mattr = " -mattr=+hvx" + cpu_ver + ",+hvx-length" + str(hvx) + "b" - return target + mcpu + mattr - - # Simulator string - def create_sim(cpu_ver, sim_args): - def validate_hvx_length(codegen_hvx, sim_args): - if sim_args and "--hvx_length" in sim_args: + + # Process the options that affect target features and return the + # target feature string. + def create_target_features(config): + tfs = [] + if config["hvx"] > 0: + valid_hvx = [0, 64, 128] + if not config["hvx"] in valid_hvx: + raise ValueError("Invalid hvx value, should be one of " + str(valid_hvx)) + tfs += ["+hvx" + cpu_ver, "+hvx-length" + str(config["hvx"]) + "b"] + else: + tfs += ["-hvx"] + return "-mattr=" + ",".join(tfs) if tfs else "" + + return target + mcpu + " " + create_target_features(config) + + # Simulator options string + def create_sim_options(cpu_ver, config): + """ Create simulator option string. """ + + def validate_hvx_length(codegen_hvx, sim_options): + if sim_options and "--hvx_length" in sim_options: # If --hvx_length was specified, check HVX length of sim # vs codegen - i = sim_args.index("hvx_length") + len("hvx_length") + 1 - sim_hvx = sim_args[i : i + 3] + i = sim_options.index("hvx_length") + len("hvx_length") + 1 + sim_hvx = sim_options[i : i + 3] if sim_hvx != str(codegen_hvx): - print( - "WARNING: sim hvx {} and codegen hvx {} mismatch!".format( - sim_hvx, codegen_hvx - ) - ) + msg = "sim hvx {} and codegen hvx {} mismatch!".format(sim_hvx, codegen_hvx) + # Set the stacklevel to the tvm.target.hexagon() call. + warnings.warn(msg, stacklevel=4) elif codegen_hvx != 0: # If --hvx_length was not given, add it if HVX is enabled - sim_args = sim_args + " " if isinstance(sim_args, str) else "" - sim_args += "--hvx_length " + str(codegen_hvx) - return sim_args or "" + sim_options = sim_options + " " if isinstance(sim_options, str) else "" + sim_options += "--hvx_length " + str(codegen_hvx) + return sim_options or "" - if not sim_args: - return cpu_ver + " " + validate_hvx_length(hvx, sim_args) + hvx = config["hvx"] + sim_options = config["sim_options"] + if not sim_options: + return cpu_ver + " " + validate_hvx_length(hvx, sim_options) sim_cpu = cpu_ver + " " # Add user defined args - if isinstance(sim_args, list): - sim_args = " ".join(sim_args) + if isinstance(sim_options, list): + sim_options = " ".join(sim_options) # Check for supplied sim cpu version - if "v6" in sim_args: + if "v6" in sim_options: sim_cpu = "" # Regex match for allowed cpus @@ -494,13 +533,13 @@ def validate_hvx_length(codegen_hvx, sim_args): + r"(?Pv6[25678])(?P[a-z])?" + r"(?P_[0-9]+)?(?P_rev[0-9])?\s?(?P--.*)?" ) - m = re.match(valid_cpu_str_regex, sim_args.lower()) + m = re.match(valid_cpu_str_regex, sim_options.lower()) if not m: - raise ValueError('Invalid simulator argument string "{}"'.format(sim_args)) + raise ValueError('Invalid simulator argument string "{}"'.format(sim_options)) # Parse options into correct order cpu_attr = {x: str(m.groupdict()[x] or "") for x in m.groupdict()} - sim_args = ( + sim_options = ( cpu_attr["base_version"] + cpu_attr["sub_version"] + cpu_attr["l2_size"] @@ -510,23 +549,27 @@ def validate_hvx_length(codegen_hvx, sim_args): + cpu_attr["post"] ) - return sim_cpu + " " + validate_hvx_length(hvx, sim_args) + return sim_cpu + " " + validate_hvx_length(hvx, sim_options) + + # LLVM options string + def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument + """ Create LLVM options string. """ + + llvm_options = config["llvm_options"] - # LLVM string - def create_llvm(llvm_args): # TVM's option parser doesn't allow '=' in values, but '=' can # appear in LLVM flags. Replace it with '@', since it's unlikely # that '@' will be used in another context. - if llvm_args is None or len(llvm_args.replace(" ", "")) == 0: + if llvm_options is None or len(llvm_options.strip()) == 0: return "" - args = [s.replace("=", "@") for s in llvm_args.split()] + args = [s.replace("=", "@") for s in llvm_options.split()] return "--llvm-options=" + ",".join(args) # Sim args - os.environ["HEXAGON_SIM_ARGS"] = create_sim(cpu_ver, sim_args) + os.environ["HEXAGON_SIM_ARGS"] = create_sim_options(cpu_ver, config) - target_str = create_target(cpu_ver) - llvm_str = create_llvm(llvm_args) + target_str = create_llvm_target(cpu_ver, config) + llvm_str = create_llvm_options(cpu_ver, config) args_list = target_str.split() + llvm_str.split() return Target(" ".join(["hexagon"] + args_list)) From d263c6d4300170cc6cf7f58b923edcb23b5a7791 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Thu, 26 Aug 2021 18:06:23 +0100 Subject: [PATCH 10/42] [Pattern matching] Add an option to rewrite the graph only once (#8843) * [Pattern matching] Add an option to rewrite the graph only once If the graph returned from the callback consists of the original pattern, the rewriter will run in the loop, which is not always desired. So this patch proposes an option to run the rewriter only once. Change-Id: I85cf0a055b8961d52394f21c1e4d7aad0a7e1d06 * Make rewrite_once default to false Change-Id: Idf6f01f254c403158883681e75c2a5978efbd2d0 --- include/tvm/relay/dataflow_matcher.h | 6 +- python/tvm/relay/dataflow_pattern/__init__.py | 17 +++- src/relay/ir/dataflow_matcher.cc | 11 ++- tests/python/relay/test_dataflow_pattern.py | 98 +++++++------------ 4 files changed, 58 insertions(+), 74 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 12e4e3f45fef..10e461645c8b 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -47,10 +47,13 @@ class DFPatternCallbackNode : public Object { PackedFunc function; /*! \brief Require InferType to be run before the callback */ bool require_type; + /*! \brief Run the callback only once */ + bool rewrite_once; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pattern", &pattern); v->Visit("require_type", &require_type); + v->Visit("rewrite_once", &rewrite_once); } static constexpr const char* _type_key = "DFPatternCallbackNode"; @@ -63,7 +66,8 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: - TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type); + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type, + bool rewrite_once = false); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 320a599d5d91..1f6d8bb9ab0b 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -796,11 +796,14 @@ class DFPatternCallback: ---------- require_type: bool Whether InferType is required to be run before the callback. + rewrite_once: bool + If True, run the callback only once. """ - def __init__(self, require_type=False): + def __init__(self, require_type=False, rewrite_once=False): self.pattern = None self.require_type = require_type + self.rewrite_once = rewrite_once def rewrite(self, expr: Expr) -> Expr: """ @@ -842,8 +845,10 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp class _DFPatternCallback(Object): """C++ implemenation""" - def __init__(self, pattern, callback, require_type): - self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) + def __init__(self, pattern, callback, require_type, rewrite_once): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback, require_type, rewrite_once + ) def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: @@ -870,7 +875,11 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: tmp = [] for callback in callbacks: assert callback.pattern is not None - tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) + tmp.append( + _DFPatternCallback( + callback.pattern, callback.callback, callback.require_type, callback.rewrite_once + ) + ) return ffi.rewrite(tmp, expr, mod) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index d7f130f2796d..851a498377b2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -752,19 +752,22 @@ bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) { // Rewrite -DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->function = std::move(function); n->require_type = require_type; + n->rewrite_once = rewrite_once; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) { - return DFPatternCallback(pattern, function, require_type); + .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { + return DFPatternCallback(pattern, function, require_type, rewrite_once); }); Expr PatternRewriter::Rewrite(const Array& callbacks, const Expr& pre) { @@ -790,7 +793,7 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E count++; } equal = (*structural_equal)(last, post, false, true); - } while (!equal && count < 100); + } while (!equal && count < 100 && !callback_->rewrite_once); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 1c721f40d129..74e03f6a9755 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1727,69 +1727,37 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) +def test_rewrite_once(): + # This class recursively removes the arguments to concat until there is nothing left to concatenate. + class ConcatRewriter(DFPatternCallback): + def __init__(self, rewrite_once): + super().__init__(rewrite_once=rewrite_once) + self.pattern = is_op("concatenate")(None) + + def callback(self, pre, post, node_map): + concat_args = post.args[0] + # Remove the last argument + new_args = [concat_args[i] for i in range(len(concat_args) - 1)] + if new_args: + return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0) + else: + return concat_args + + x = relay.var("x") + y = relay.var("y") + z = relay.var("z") + concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0) + + # Let the rewriter run recursively + out = rewrite(ConcatRewriter(False), concat) + expected = relay.expr.Tuple([x]) + assert tvm.ir.structural_equal(out, expected) + + # Run the rewriter once + out = rewrite(ConcatRewriter(True), concat) + expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) + assert tvm.ir.structural_equal(out, expected) + + if __name__ == "__main__": - test_expr_pattern() - test_var_pattern() - test_constant_pattern() - test_wildcard_pattern() - test_CallPattern() - test_TuplePattern() - test_TupleGetItemPattern() - test_AltPattern() - test_TypePattern() - test_DataTypePattern() - test_ShapePattern() - test_AttrPattern() - test_match_op() - test_no_match_op() - test_match_op_or() - test_match_call_commutive() - test_no_match_call_commutive() - test_match_call() - test_no_match_call() - test_match_option() - test_no_match_option() - test_match_const() - test_match_tuple() - test_no_match_tuple() - test_match_type() - test_no_match_type() - test_match_dtype() - test_no_match_dtype() - test_match_shape() - test_no_match_shape() - test_match_op_attr() - test_no_match_op_attr() - test_match_func_attr() - test_no_match_func_attr() - test_match_call_attr() - test_no_match_call_attr() - test_match_diamond() - test_no_match_diamond() - test_match_fake_diamond() - test_match_dominator() - test_not_match_dominator() - test_rewrite() - test_rewrite_func() - test_nested_rewrite() - test_not_fuse_multi_diamond() - test_fuse_batchnorm() - test_no_fuse_batchnorm() - test_fuse_double_batchnorm() - test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() - test_quadruple_rewrite_dominator() - test_algebraic_simplify() - test_double_partition() - test_partition_dominator() - test_quadruple_partition_dominator() - test_partition_batchnorm() - test_partition_double_batchnorm() - test_partition_check() - test_partition_check_types() - test_partition_option() - test_match_match() - test_partition_constant_embedding() - test_IfPattern() - test_match_if() - test_no_match_if() + pytest.main([__file__]) From 4fd1bf4e512aafc0bea0b809789cd27f8dd944d4 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 26 Aug 2021 19:08:15 +0200 Subject: [PATCH 11/42] update gpu and cpu (#8853) --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 4814dc7bb802..9eafb449c5d3 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,8 +45,8 @@ // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = "tlcpack/ci-lint:v0.67" -ci_gpu = "tlcpack/ci-gpu:v0.76" -ci_cpu = "tlcpack/ci-cpu:v0.76" +ci_gpu = "tlcpack/ci-gpu:v0.77" +ci_cpu = "tlcpack/ci-cpu:v0.77" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.73" ci_qemu = "tlcpack/ci-qemu:v0.08" From 04bdd32281c4ae50d086e4469fd6a9ee6f0c93b6 Mon Sep 17 00:00:00 2001 From: Anton Sorokin Date: Thu, 26 Aug 2021 10:21:18 -0700 Subject: [PATCH 12/42] VTA cmake change to include Verilator header for building tsim library (#8797) * VTA cmake file require Verilator include for tsim target. VTA module.cc uses svOpenArrayHandle to send wide data through DPI * Refactor Verialtor check conditions * Build TSIM only for CPU target. CPU target don't use -Werror to compile with Verilator. Jenkinsfile to have tvm_multilib_tsim defined for CPU build target. * remove build/libvta_tsim.so from non tsim targeting builds * Revert to enable TSIM build i386. Revert to -Werror in CPU config. Remove verilator CPP objects from cmake config for tsim and put them as include into vta module.cc to avoid Verilator compilation warnings --- Jenkinsfile | 8 +++++--- cmake/modules/VTA.cmake | 13 ++++++++++++- tests/scripts/task_config_build_arm.sh | 1 - tests/scripts/task_config_build_gpu.sh | 1 - tests/scripts/task_config_build_i386.sh | 2 +- tests/scripts/task_config_build_wasm.sh | 1 - 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 9eafb449c5d3..fa1629205080 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -73,10 +73,12 @@ tvm_runtime = "build/libtvm_runtime.so, build/config.cmake" tvm_lib = "build/libtvm.so, " + tvm_runtime // LLVM upstream lib tvm_multilib = "build/libtvm.so, " + - "build/libvta_tsim.so, " + "build/libvta_fsim.so, " + tvm_runtime +tvm_multilib_tsim = "build/libvta_tsim.so, " + + tvm_multilib + // command to start a docker container docker_run = 'docker/bash.sh' // timeout in minutes @@ -218,7 +220,7 @@ stage('Build') { init_git() sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh" make(ci_cpu, 'build', '-j2') - pack_lib('cpu', tvm_multilib) + pack_lib('cpu', tvm_multilib_tsim) timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" @@ -252,7 +254,7 @@ stage('Build') { init_git() sh "${docker_run} ${ci_i386} ./tests/scripts/task_config_build_i386.sh" make(ci_i386, 'build', '-j2') - pack_lib('i386', tvm_multilib) + pack_lib('i386', tvm_multilib_tsim) } } }, diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index e520e62711f3..1f9d08b50a10 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -73,6 +73,17 @@ elseif(PYTHON) # Cycle accurate simulator driver build if(USE_VTA_TSIM) + if(DEFINED ENV{VERILATOR_INC_DIR}) + set(VERILATOR_INC_DIR $ENV{VERILATOR_INC_DIR}) + elseif (EXISTS /usr/local/share/verilator/include) + set(VERILATOR_INC_DIR /usr/local/share/verilator/include) + elseif (EXISTS /usr/share/verilator/include) + set(VERILATOR_INC_DIR /usr/share/verilator/include) + else() + message(STATUS "Verilator not found in /usr/local/share/verilator/include") + message(STATUS "Verilator not found in /usr/share/verilator/include") + message(FATAL_ERROR "Cannot find Verilator, VERILATOR_INC_DIR is not defined") + endif() # Add tsim driver sources file(GLOB TSIM_RUNTIME_SRCS ${VTA_HW_PATH}/src/*.cc) file(GLOB TSIM_RUNTIME_SRCS vta/runtime/*.cc) @@ -81,7 +92,7 @@ elseif(PYTHON) list(APPEND TSIM_RUNTIME_SRCS ${VTA_HW_PATH}/src/vmem/virtual_memory.cc) # Target lib: vta_tsim add_library(vta_tsim SHARED ${TSIM_RUNTIME_SRCS}) - target_include_directories(vta_tsim SYSTEM PUBLIC ${VTA_HW_PATH}/include) + target_include_directories(vta_tsim SYSTEM PUBLIC ${VTA_HW_PATH}/include ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd) target_compile_definitions(vta_tsim PUBLIC DMLC_USE_LOGGING_LIBRARY=) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index cb42b9a71d59..47fa243e8d38 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -31,7 +31,6 @@ echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LLVM llvm-config-8\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "/opt/acl"\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 6e20087df34a..5f86476c64c7 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -38,7 +38,6 @@ echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_STACKVM_RUNTIME ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_ANTLR ON\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake diff --git a/tests/scripts/task_config_build_i386.sh b/tests/scripts/task_config_build_i386.sh index ce244fa59276..298259682972 100755 --- a/tests/scripts/task_config_build_i386.sh +++ b/tests/scripts/task_config_build_i386.sh @@ -31,8 +31,8 @@ echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LLVM llvm-config-4.0\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_wasm.sh b/tests/scripts/task_config_build_wasm.sh index 490e9446007e..9a1edbccc1fc 100755 --- a/tests/scripts/task_config_build_wasm.sh +++ b/tests/scripts/task_config_build_wasm.sh @@ -32,6 +32,5 @@ echo set\(USE_ANTLR ON\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake From 423958fd8fdf1a2bd8d45d604135054953c5c73b Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 27 Aug 2021 03:33:37 +0800 Subject: [PATCH 13/42] [FIX] Bug fix for a floormod rewrite simplify rule (#8852) * Update rewrite_simplify.cc * Update test_arith_rewrite_simplify.py * Update test_arith_rewrite_simplify.py * Update test_arith_rewrite_simplify.py --- src/arith/rewrite_simplify.cc | 16 +++++++++----- .../unittest/test_arith_rewrite_simplify.py | 22 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ff6536ab066b..1d3475b13dad 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -858,14 +858,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { + if (ramp_min == ramp_max) { + // If b1 can devide c2 + if (bmod->coeff % c2val == 0) { return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); - } else { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - } else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 231c376c50ca..641eed51d5cf 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -101,15 +101,16 @@ def test_vector_simplify(): ck.verify( fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - ) + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1] ck.verify( fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [0, 1, 1, 1] ck.verify( fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [0, 1, 1, 1] + # floor mod ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)) @@ -136,16 +137,21 @@ def test_vector_simplify(): flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) ) ck.verify( - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5) - ) + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] ck.verify( flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4), - ) + flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] + ck.verify( + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + ) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 13, 20] # Min/Max rules vx = te.var("vx", dtype="int32x2") From 3d81489a2656214e93c6ea983e82c55b310cd28b Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 26 Aug 2021 23:24:51 +0200 Subject: [PATCH 14/42] move rust lint script (#8726) --- tests/lint/rust_format.sh | 35 +++++++++++++++++++++++++++++++++++ tests/scripts/task_lint.sh | 3 +++ tests/scripts/task_rust.sh | 3 --- 3 files changed, 38 insertions(+), 3 deletions(-) create mode 100755 tests/lint/rust_format.sh diff --git a/tests/lint/rust_format.sh b/tests/lint/rust_format.sh new file mode 100755 index 000000000000..10c8feec1fcf --- /dev/null +++ b/tests/lint/rust_format.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_HOME="$(git rev-parse --show-toplevel)" +RUST_DIR="$TVM_HOME/rust" + +if [[ "$1" == "-i" ]]; then + INPLACE_FORMAT=1 + shift 1 +else + INPLACE_FORMAT=0 +fi + +cd $RUST_DIR + +if [[ ${INPLACE_FORMAT} -eq 1 ]]; then + cargo fmt +else + cargo fmt -- --check +fi diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 12b32709a392..2889c3a94f11 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -39,6 +39,9 @@ tests/lint/cpplint.sh echo "clang-format check..." tests/lint/clang_format.sh +echo "Rust check..." +tests/lint/rust_format.sh + echo "black check..." tests/lint/python_format.sh diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 4b34b6cf8db4..5cc1dc0503f7 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -43,9 +43,6 @@ echo "Using TVM_CARGO_VERSION=$TVM_CARGO_VERSION" export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 -cd $RUST_DIR -cargo fmt -- --check - # First we test tvm-sys the core Rust bindings. cd $RUST_DIR/tvm-sys # First we test w/o the bindings feature on. From f4f525dab86af653636bce95ce3609288fbaa587 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 27 Aug 2021 07:16:54 +0900 Subject: [PATCH 15/42] [AMP] Disallow fp16 conversion for summation-like ops (#8810) * [AMP] Disallow fp16 conversion for summation-like ops * test only structural equality --- python/tvm/relay/transform/mixed_precision.py | 15 +++++---- tests/python/relay/test_to_mixed_precision.py | 31 +++++++++++++------ 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 1657f895dcd7..fb4d3fa208a8 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -81,8 +81,6 @@ "divide", "nn.bias_add", "nn.batch_norm", - "sum", - "mean", "sqrt", "shape_of", # Simple activations @@ -107,15 +105,9 @@ # "nn.global_max_pool1d", # does not exist yet "nn.global_max_pool2d", # "nn.global_max_pool3d", # does not exist yet - # "nn.global_avg_pool1d", # does not exist yet - "nn.global_avg_pool2d", - # "nn.global_avg_pool3d", # does not exist yet "nn.adaptive_max_pool1d", "nn.adaptive_max_pool2d", "nn.adaptive_max_pool3d", - "nn.adaptive_avg_pool1d", - "nn.adaptive_avg_pool2d", - "nn.adaptive_avg_pool3d", ] DEFAULT_NEVER_LIST = [ # In general if |f(x)| >> |x| for expected inputs then put the op here. @@ -131,6 +123,13 @@ # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number # not representable in fp16. "arange", + # Ops that could involve a large summation are not allowed in fp16. + "nn.global_avg_pool2d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", + "sum", + "mean", ] diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 99078b7371ba..472f98715ec5 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -221,12 +221,9 @@ def test_do_not_convert_softmax(): b = relay.nn.softmax(a) mod = tvm.IRModule.from_expr(b) mod = tvm.relay.transform.InferType()(mod) - - mod_params = { - "a": np.random.uniform(-1, 1, size=shape).astype("float32"), - } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_do_not_convert_arange(): @@ -234,10 +231,26 @@ def test_do_not_convert_arange(): dtype = "float32" arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) mod = tvm.IRModule.from_expr(arange) - mod = tvm.relay.transform.InferType()(mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) - output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + +def test_do_not_convert_summation(): + """Ops that could involve a large summation are not allowed in fp16.""" + shape = [1, 3, 16, 16] + a = relay.var("a", shape=shape) + ops = [ + relay.sum, + relay.mean, + relay.nn.global_avg_pool2d, + lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)), + ] + for op in ops: + mod = tvm.IRModule.from_expr(op(a)) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_green_gray_propagates_simple(): From 227bf7ffafec7a2cff543a8a22f3741f45970b8d Mon Sep 17 00:00:00 2001 From: Tantalus13A98B5F Date: Thu, 26 Aug 2021 21:18:34 -0400 Subject: [PATCH 16/42] [TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels (#8605) * [topi] add spconv2d_3x3 nhwc * [relay] sparse_conv2d: add kernel_size attr * [relay] add strategy for spconv2d_3x3 nhwc * [relay] pass to convert spconv2d with const args * [relay] convert sparse conv2d pass fixes * use array for sparse conv2d attr * fixup 1x1 tests; new 3x3 tests --- include/tvm/relay/attrs/nn.h | 4 + python/tvm/autotvm/measure/measure_methods.py | 15 +- python/tvm/relay/analysis/sparse_conv2d.py | 58 ++++-- .../relay/data_dep_optimization/bsr_conv2d.py | 44 ++++- python/tvm/relay/op/nn/_nn.py | 6 +- python/tvm/relay/op/strategy/x86.py | 25 +++ python/tvm/relay/transform/transform.py | 24 ++- python/tvm/topi/nn/sparse.py | 21 ++- python/tvm/topi/x86/sparse.py | 162 +++++++++++++++- src/relay/op/nn/sparse.cc | 3 +- src/relay/transforms/convert_sparse_conv2d.cc | 173 +++++++++++++++++- .../relay/test_sparse_conv2d_convert.py | 63 +++++++ 12 files changed, 548 insertions(+), 50 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 77cba5fa2ff1..d28044c3845d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1066,12 +1066,16 @@ struct SparseTransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes for sparse_dense operator */ struct SparseConv2DAttrs : public tvm::AttrsNode { std::string layout; + Array kernel_size; TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NHWC").describe( "Dimension ordering of input data. Can be 'NCHW', 'NHWC'" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively."); + TVM_ATTR_FIELD(kernel_size) + .set_default(Array{1, 1}) + .describe("Kernel size for SparseConv2D, 1x1 or 3x3. "); } }; diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index db4ff26857bd..eab6822b63b8 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -254,13 +254,14 @@ def ref_input(self): @ref_input.setter def ref_input(self, val): - warnings.warn( - "You are specifying fixed input for tuning the operator. " - "Be sure your input always fits the operator. Some " - "operators may conduct layout transformation during tuning, " - "thus can lead to unexpected behaviors. ", - RuntimeWarning, - ) + if val is not None: + warnings.warn( + "You are specifying fixed input for tuning the operator. " + "Be sure your input always fits the operator. Some " + "operators may conduct layout transformation during tuning, " + "thus can lead to unexpected behaviors. ", + RuntimeWarning, + ) self._ref_input = val def set_task(self, task): diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py index 11278bddca33..1862ded831f6 100644 --- a/python/tvm/relay/analysis/sparse_conv2d.py +++ b/python/tvm/relay/analysis/sparse_conv2d.py @@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr): return _ffi_api.search_conv2d_op_weight(expr) -def process_params(expr, params, block_size, sparsity_threshold, layout): +def process_params( + expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True +): """Process parameters of conv2d from dense to sparse. Parameters @@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): for name in weight_names: name = str(name) w_np = params[name].numpy() - # currently only support conv2d_1*1 - if not ( - (w_np.shape[0] == 1 and w_np.shape[1] == 1) - or (w_np.shape[2] == 1 and w_np.shape[3] == 1) - ): + + if layout == "NHWC": # HWIO + weight_kernel = (w_np.shape[0], w_np.shape[1]) + elif layout == "NCHW": # OIHW + weight_kernel = (w_np.shape[2], w_np.shape[3]) + if weight_kernel[0] != weight_kernel[1]: continue - sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) - if sparsity >= sparsity_threshold: + + if weight_kernel[0] == kernel_size == 1: + sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) + if sparsity < sparsity_threshold: + continue if layout == "NHWC": w_np = w_np.squeeze().T elif layout == "NCHW": @@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): ) else: sparse_weight_data = sparse_weight.data + elif weight_kernel[0] == kernel_size == 3: + if layout == "NHWC": # HWIO + w_np = w_np.reshape((-1, w_np.shape[-1])).T + elif layout == "NCHW": # OIHW + w_np = w_np.reshape((w_np.shape[0], -1)) + sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size) + if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold: + continue + sparse_weight_data = sparse_weight.data + else: + continue - # remove dense weight - del params[name] - memo.weight_name.append(name) - memo.weight_shape.append( - list(sparse_weight_data.shape) - + list(sparse_weight.indices.shape) - + list(sparse_weight.indptr.shape) - ) - params[name + ".data"] = tvm.nd.array(sparse_weight_data) - params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) - params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) - + # remove dense weight + del params[name] + memo.weight_name.append(name) + memo.weight_shape.append( + list(sparse_weight_data.shape) + + list(sparse_weight.indices.shape) + + list(sparse_weight.indptr.shape) + ) + params[name + ".data"] = tvm.nd.array(sparse_weight_data) + params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) + params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) + + if reg_task_input: prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % ( w_np.shape[0], w_np.shape[1], diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py index 6913a428b2ac..20e01da1493e 100644 --- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py +++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py @@ -23,8 +23,8 @@ from .utils import _run_opt_pass -def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): - """Convert a dense func and according parameters to block sparse +def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1): + """Convert a conv2d func and according parameters to block sparse Parameters ---------- @@ -49,10 +49,46 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): params: Dict[Srting, tvm.nd.array] New params with BSR matrix for mutated Expr """ - weight_info = process_params(func, params, blocksize, sparsity_threshold, layout) + weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size) new_func = _run_opt_pass( func, - relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout), + relay.transform.Conv2dToSparse( + weight_info.weight_name, weight_info.weight_shape, layout, kernel_size + ), ) return new_func, params + + +def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size): + """Convert a freezed conv2d func to block sparse + + Parameters + ---------- + func : relay.Expr + Expr will be optimized to sparse operation, with params freezed + params : Dict[Srting, tvm.nd.array] + Parameters of the Expr (not used in this pass) + blocksize : Tuple(int, int) + Blocksize for BSR matrix + sparsity_threshold : float + Minimal sparsity requirement for converting. + If weight sparsity is lower than this threshold, + the dense operation will be kept. + layout : str + layout of network + kernel_size : int + kernel size of the conv2d, for filtering + + Returns + ------- + new_func: relay.Expr + Mutated Expr with sparse operations + + params: Dict[Srting, tvm.nd.array] + New params with BSR matrix for mutated Expr (not modified) + """ + new_func = _run_opt_pass( + func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold) + ) + return new_func, params diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index a9ccc5aa2d24..a9e485866381 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type): @reg.register_compute("nn.sparse_conv2d") def compute_sparse_conv2d(attrs, inputs, out_type): """Compute definition of sparse_conv2d""" - return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])] + return [ + topi.nn.sparse_conv2d( + inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"] + ) + ] reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a6e141f2753b..1c8d1b478cb1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -565,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): return strategy +@sparse_conv2d_strategy.register("cpu") +def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target): + """sparse conv2d x86 strategy""" + strategy = _op.OpStrategy() + if attrs["kernel_size"][0] == 1: + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d), + wrap_topi_schedule(topi.generic.schedule_sparse_conv2d), + name="sparse_conv2d.generic", + ) + elif attrs["kernel_size"][0] == 3: + if attrs["layout"] == "NHWC": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc), + name="conv3x3_spNHWC.x86", + ) + elif attrs["layout"] == "NCHW": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw), + ) + return strategy + + @roi_align_strategy.register("cpu") def roi_align_strategy_cpu(attrs, inputs, out_type, target): """roi_align x86 strategy""" diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 6294e7acea15..9a7857a01fe6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape): return _ffi_api.DenseToSparse(weight_name, weight_shape) -def Conv2dToSparse(weight_name, weight_shape, layout): +def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size): """ Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d``` @@ -1113,7 +1113,27 @@ def Conv2dToSparse(weight_name, weight_shape, layout): ret : tvm.transform.Pass The registered DenseToSparse pass. """ - return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout) + return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size) + + +def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold): + """ + Rewrite freezed ```nn.conv2d``` operation to ```nn.sparse_conv2d``` + + Parameters + ---------- + layout : str + layout of data + + kernel_size : int + kernel size of conv2d + + Returns + ------- + ret : tvm.transform.Pass + The registered DenseToSparse pass. + """ + return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold) def SimplifyFCTranspose(target_weight_name): diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 948847e60d92..e577104c3ddc 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103 ) -def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"): +def sparse_conv2d( + dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1 +): """ Computes sparse-conv2d(1*1) of ``data`` and ``(weight_data, weight_indices, weight_indptr)`` @@ -598,14 +600,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout 4-D with shape [M, H, W, N] (layout=NHWC) 4-D with shape [M, N, H ,W] (layout=NCHW) """ - if layout == "NHWC": - return _sparse_conv2d_bsr_compute_nhwc( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) - elif layout == "NCHW": - return _sparse_conv2d_bsr_compute_nchw( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) + if kernel_size == 1: + if layout == "NHWC": + return _sparse_conv2d_bsr_compute_nhwc( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) + elif layout == "NCHW": + return _sparse_conv2d_bsr_compute_nchw( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) else: raise ValueError("Unsupport Layout %s" % layout) diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index c6300f6701e0..48ec233fa4bb 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -16,8 +16,10 @@ # under the License. """sparse_dense schedule on x86""" -from tvm import te +from functools import partial, reduce +from tvm import te, tir, autotvm +from ..transform import reshape from ..utils import traverse_inline, get_const_int from .utils import get_fp32_len @@ -60,3 +62,161 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +@autotvm.register_topi_compute("conv3x3_spNHWC.x86") +def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"): + """Sparse Conv2d 3x3 compute (NHWC).""" + assert layout == "NHWC" + nsamples, imh, imw, chanin = [i.value for i in data.shape] + nelems, bsrr, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * bsrr + + imglen, chanlen = nsamples * imh * imw, 9 * chanin + cfg.define_split("tile_y", imglen, num_outputs=3) + cfg.define_split("tile_x", chanout // bsrr, num_outputs=2) + cfg.add_flop(imglen * (nelems * bsrc * bsrr * 2 - chanout)) + if cfg.is_fallback: + cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) + cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) + + idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x]) + + @partial(te.compute, (imglen, chanlen), name="Im2Col") + def im2col(row, col): + j_w, j_h, j_n = idxsplit(row, [imw, imh]) + j_c, k_w, k_h = idxsplit(col, [chanin, 3]) + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imh, i_w >= 0, i_w < imw), data[j_n, i_h, i_w, j_c], 0 + ) + + @partial(te.compute, (imglen, chanout // bsrr, bsrr, bsrc), name="CC") + def matmul(drow, wrow, brow, bcol): + row_start, row_end = wptr[wrow], wptr[wrow + 1] + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") + elem = row_start + elem_idx + return te.sum( + im2col[drow, wind[elem] * bsrc + bcol] * wdat[elem, brow, bcol], axis=elem_idx + ) + + sum_bsrc = te.reduce_axis((0, bsrc), name="k") + ret = te.compute( + (imglen, chanout), + lambda y, x: te.sum(matmul[y, x // bsrr, x % bsrr, sum_bsrc], axis=sum_bsrc), + name="C", + tag="conv3x3_spNHWC", + ) + return reshape(ret, (nsamples, imh, imw, chanout)) + + +@autotvm.register_topi_schedule("conv3x3_spNHWC.x86") +def schedule_spconv2d_3x3_nhwc(cfg, outs): + """Sparse Conv2d 3x3 schedule (NHWC).""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv3x3_spNHWC": + (matmul,) = op.input_tensors + # wptr, wind, im2col, wdat + _, _, im2col, _ = matmul.op.input_tensors + (data,) = im2col.op.input_tensors + bsrr = matmul.shape[-2].value + chanin = data.shape[-1].value + + mm_y, mm_x = s[op].op.axis + y_t, y_o, y_i = cfg["tile_y"].apply(s, op, mm_y) + x_o, x_i = s[op].split(mm_x, factor=bsrr) + x_t, x_o = cfg["tile_x"].apply(s, op, x_o) + (sum_ax,) = s[op].op.reduce_axis + s[op].reorder(y_t, x_t, y_o, x_o, y_i, x_i, sum_ax) + s[op].unroll(sum_ax) + s[op].vectorize(x_i) + s[op].unroll(y_i) + + s[matmul].compute_at(s[op], x_o) + y_i, x_i, bsrr, bsrc = s[matmul].op.axis + (sum_ax,) = s[matmul].op.reduce_axis + s[matmul].reorder(x_i, sum_ax, y_i, bsrr, bsrc) + s[matmul].unroll(bsrc) + s[matmul].vectorize(bsrr) + s[matmul].unroll(y_i) + + s[im2col].compute_at(s[op], y_o) + y_i, sum_ax = s[im2col].op.axis + _, k_i = s[im2col].split(sum_ax, factor=chanin) + s[im2col].vectorize(k_i) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv3x3_spNCHW.x86") +def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"): + """Sparse Conv2d 3x3 compute (NCHW).""" + nsamples, chanin, imgh, imgw = [i.value for i in data.shape] + nelems, veclen, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * veclen + assert bsrc == 1 and layout == "NCHW" + + cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout)) + cfg.define_split("tile_hw", imgh * imgw, num_outputs=3) + cfg.define_split("tile_ckk", chanin * 9, num_outputs=3) + + @partial(te.compute, (nsamples, chanin * 3 * 3, imgh * imgw), name="im2col") + def im2col(nsamples, ckk, imglen): + j_h, j_w = imglen // imgw, imglen % imgw + i_c, k_h, k_w = ckk // 9, ckk // 3 % 3, ckk % 3 + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imgh, i_w >= 0, i_w < imgw), data[nsamples, i_c, i_h, i_w], 0 + ) + + @partial( + te.compute, + (nsamples, chanout // veclen, veclen, bsrc, imgh * imgw), + name="CC", + tag="conv3x3_spNCHW", + ) + def matmul(nsamples, f_o, f_i, bsrk, imglen): + row_start, row_end = wptr[f_o], wptr[f_o + 1] + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") + elem = row_start + elem_idx + return te.sum( + im2col[nsamples, wind[elem] * bsrc + bsrk, imglen] * wdat[elem, f_i, bsrk], + axis=elem_idx, + ) + + return reshape(matmul, [nsamples, chanout, imgh, imgw]) + + +@autotvm.register_topi_schedule("conv3x3_spNCHW.x86") +def schedule_spconv2d_3x3_nchw(cfg, outs): + """Sparse Conv2d 3x3 schedule (NCHW).""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv3x3_spNCHW": + # wptr, wind, im2col, wdat + _, _, im2col, _ = op.input_tensors + + n_samples, f_o, f_i, b_c, imglen = s[op].op.axis + (sum_ax,) = s[op].op.reduce_axis + hw1, hw2, hw3 = cfg["tile_hw"].apply(s, op, imglen) + s[op].reorder(n_samples, hw1, f_o, hw2, sum_ax, f_i, b_c, hw3) + s[op].unroll(f_i) + s[op].unroll(b_c) + s[op].vectorize(hw3) + + s[im2col].compute_at(s[op], hw1) + n_samples, ckk, imglen = s[im2col].op.axis + ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk) + hw2, hw3 = s[im2col].split(imglen, factor=cfg["tile_hw"].size[-1]) + s[im2col].reorder(n_samples, ckk1, ckk2, hw2, ckk3, hw3) + s[im2col].unroll(ckk3) + s[im2col].vectorize(hw3) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 32b0811b48ac..7d21005cb4db 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -274,10 +274,11 @@ bool SparseConv2dRel(const Array& types, int num_inputs, const Attrs& attr } Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr, - std::string layout) { + std::string layout, Array kernel_size) { static const Op& op = Op::Get("nn.sparse_conv2d"); auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->kernel_size = std::move(kernel_size); return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 6e4c03b0fcbc..3f2c25e988f9 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -73,10 +73,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea class Conv2dToSparseConv2dMutator : public ExprRewriter { public: Conv2dToSparseConv2dMutator(const Array& weight_name, - const Array>& weight_shape, const String& layout) + const Array>& weight_shape, const String& layout, + int kernel_size) : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); layout_ = layout; + kernel_size_ = kernel_size; for (size_t i = 0; i < weight_name.size(); ++i) { ICHECK(weight_name[i]->IsInstance()); std::string k = weight_name[i].as()->data; @@ -112,6 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { Var weight_indptr(prefix + ".indptr", ws_indptr_type); auto attrs = make_object(); attrs->layout = std::move(layout_); + attrs->kernel_size = Array{kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs)); } @@ -126,22 +129,168 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { const Op& sparse_conv2d_op_; std::unordered_map> target_weights_; String layout_; + int kernel_size_; }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, - const Array>& weight_shape, const String& layout) { - auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout); + const Array>& weight_shape, const String& layout, + int kernel_size) { + auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); + return PostOrderRewrite(e, &rewriter); +} + +template +auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence) { + return std::make_tuple(arr[Is]...); +} + +template +auto unpack_to_tuple(elemTy* arr) { + return unpack_to_tuple_internal(arr, std::make_index_sequence{}); +} + +struct Range { + size_t dim; + explicit Range(size_t d) : dim(d) {} + + struct iterpoint { + size_t val, lim; + iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {} + + size_t operator*() const { return val; } + + iterpoint operator/(const iterpoint& rhs) const { + return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim); + } + }; + + struct iterator { + size_t val, lim; + iterator(size_t v1, size_t v2) : val(v1), lim(v2) {} + + bool operator!=(const iterator& rhs) const { return val != rhs.val; } + + void operator++() { ++val; } + + iterpoint operator*() const { return iterpoint(val, lim); } + }; + + iterator begin() { return iterator(0, dim); } + + iterator end() { return iterator(dim, dim); } +}; + +// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` +class Conv2dToSparseConv2dMutator2 : public ExprRewriter { + public: + Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) + : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), + dev_cpu0_{DLDeviceType::kDLCPU, 0}, + layout_(layout), + kernel_size_(kernel_size), + blockH_(blockH), + blockW_(blockW), + sparse_thresh_(sparse_thresh) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + // check op type & attrs + const auto pre_attrs = pre->attrs.as(); + if (!pre_attrs || pre_attrs->data_layout != layout_ || + pre_attrs->strides[0].as()->value != 1 || + pre_attrs->kernel_size[0].as()->value != kernel_size_) + return post; + // check constant weight + const auto pre_weight_node = pre->args[1].as(); + if (!pre_weight_node) return post; + + // check weight dtype & shape + auto&& pre_weight = pre_weight_node->data; + auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); + ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only + auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); + int O, I, H, W; + if (layout_ == "NCHW") { + std::tie(O, I, H, W) = pre_weight_shape; + } else { // NHWC + std::tie(H, W, I, O) = pre_weight_shape; + } + int CO = O, CI = H * W * I; + + // copy to vector + std::vector pre_weight_data(CO * CI); + pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); + if (layout_ == "NHWC") { + std::vector tmp(pre_weight_data.size()); + for (auto i : Range(CO)) + for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; + std::swap(tmp, pre_weight_data); + } + // convert to BSR + std::vector wdata, block(blockH_ * blockW_); + std::vector windices, windptr; + for (auto bh : Range(CO / blockH_)) { + windptr.push_back(windices.size()); + for (auto bw : Range(CI / blockW_)) { + int cntnnz = 0; + for (auto i : Range(blockH_)) + for (auto j : Range(blockW_)) { + auto tmp = pre_weight_data[*(bh / i / bw / j)]; + if (tmp) cntnnz++; + block[*(i / j)] = tmp; + } + if (cntnnz) { + wdata.insert(wdata.end(), block.begin(), block.end()); + windices.push_back(*bw); + } + } + } + windptr.push_back(windices.size()); + double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size(); + if (sprate < sparse_thresh_) return post; + + // constrct return data + int nnz = windices.size(); + auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_); + auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_); + auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_); + weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float)); + weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t)); + weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); + + // construct return call + auto args = runtime::Array{post.as()->args[0], Constant(weight_data), + Constant(weight_indices), Constant(weight_indptr)}; + auto attrs = make_object(); + attrs->layout = layout_; + attrs->kernel_size = Array{kernel_size_, kernel_size_}; + return Call(sparse_conv2d_op_, args, Attrs(attrs)); + } + + private: + const Op& sparse_conv2d_op_; + DLDevice dev_cpu0_; + String layout_; + int kernel_size_, blockH_, blockW_; + double sparse_thresh_; +}; // class Conv2dToSparseConv2dMutator2 + +Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) { + auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); return PostOrderRewrite(e, &rewriter); } namespace transform { +// Convert a model with seperate weight info (already sparsified). Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, - const String& layout) { + const String& layout, int kernel_size) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout)); + auto f0 = + Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); @@ -155,6 +304,20 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = + [=](Function f, IRModule m, PassContext pc) { + auto f0 = Downcast( + Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); + return f0; + }; + return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2); + } // namespace transform } // namespace relay diff --git a/tests/python/relay/test_sparse_conv2d_convert.py b/tests/python/relay/test_sparse_conv2d_convert.py index 0af78fc033ac..045462475ee1 100644 --- a/tests/python/relay/test_sparse_conv2d_convert.py +++ b/tests/python/relay/test_sparse_conv2d_convert.py @@ -25,6 +25,7 @@ from tvm.ir import IRModule from tvm import relay from tvm.topi.sparse.utils import random_bsr_matrix +from tvm.relay.build_module import bind_params_by_name def run_func(func, params, x): @@ -100,6 +101,68 @@ def test_bsr_sparse_conv2d_nhwc(): np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) +def test_bsr_sparse_conv2d_3x3_nchw(): + data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(128, 64, 3, 3), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NCHW", kernel_layout="OIHW" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).reshape( + 128, 64, 3, 3 + ) + ) + } + + x_np = np.random.randn(1, 64, 32, 32).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NCHW", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + +def test_bsr_sparse_conv2d_3x3_nhwc(): + data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NHWC", kernel_layout="HWIO" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).T.reshape( + 3, 3, 64, 128 + ) + ) + } + + x_np = np.random.randn(1, 32, 32, 64).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NHWC", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_bsr_sparse_conv2d_nhwc() test_bsr_sparse_conv2d_nchw() + test_bsr_sparse_conv2d_3x3_nhwc() + test_bsr_sparse_conv2d_3x3_nchw() From b4b194dbb0db1f152740bbb84cab96721482e2cf Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 27 Aug 2021 05:48:25 +0300 Subject: [PATCH 17/42] extend repeat_interleave op for relay.Expr (#8839) Co-authored-by: Valery Chernov --- python/tvm/relay/frontend/pytorch.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 613643f091d7..c13d791cf2e2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -573,6 +573,12 @@ def repeat_interleave(self, inputs, input_types): if isinstance(inputs[1], int): repeats = inputs[1] axis = inputs[2] + elif isinstance(inputs[1], _expr.Expr): + if isinstance(inputs[1], _expr.Constant): + repeats = int(inputs[1].data.numpy()) + else: + repeats, _ = try_infer_value(inputs[1], lambda ret: ret.tolist()) + axis = inputs[2] else: msg = "Only repeat with one value as repeat is currently supported." raise AssertionError(msg) From 9d168822f2950083a59be243cb35ad51888dbc5d Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Fri, 27 Aug 2021 06:04:09 +0100 Subject: [PATCH 18/42] Change AOT from ExprVisitor to MixedModeVisitor (#8856) This should allow better scale-ability for AOT when targeting larger networks. --- src/relay/backend/aot_executor_codegen.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2fb35f3a2e27 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -53,7 +53,7 @@ using StorageMap = * This is an on demand allocator for AOT. A new temporary * (storage allocator identifier) is allocated for each operation. */ -class AOTOnDemandAllocator : public ExprVisitor { +class AOTOnDemandAllocator : public MixedModeVisitor { public: // run the visitor on a function. void Run(const Function& func) { @@ -84,10 +84,7 @@ class AOTOnDemandAllocator : public ExprVisitor { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const VarNode* op) final { - ExprVisitor::VisitExpr_(op); - AssignReturnSid(GetRef(op)); - } + void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } void VisitExpr_(const FunctionNode* op) final { // do not recurse into sub function. @@ -218,7 +215,7 @@ class AOTOnDemandAllocator : public ExprVisitor { }; /*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public ExprVisitor { +class AOTExecutorCodegen : public MixedModeVisitor { protected: /*! * \brief Utility function to allocate a DLTensor or TVMValue @@ -437,7 +434,6 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const OpNode* op) override { throw std::runtime_error("can not compile op in non-eta expanded form"); } - void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { ICHECK(op->GetAttr(attr::kCompiler).defined()) From e774fed67c2d12e6cfc29a013f029d4b55c28e2a Mon Sep 17 00:00:00 2001 From: Jason <928090362@qq.com> Date: Fri, 27 Aug 2021 15:34:05 +0800 Subject: [PATCH 19/42] Add a PaddlePaddle Frontend (#8645) * fix some problems for matmul * fix some problems for matmul * add alpha parameter for matmul * remove unnecessary condition * add TranslatedLayer which support model loaded by jit.load * add mul operator support * Add padding mode support for conv/pool2d * support 4 two-tuples * add paddle test case * add paddle conv2d case * update test_forward.py * fix paddle convert_matmul * add paddle multiply and matmul op test case * add test case and fix bug * delete import pandas * add paddlepaddle tests * modify the variable name of convert_reshape * formatting * formatting * use black to format python code * pylint check * Remove fluid api * black format Co-authored-by: root Co-authored-by: wjj19950828 Co-authored-by: heliqi <1101791222@qq.com> Co-authored-by: Junru Shao --- python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/paddlepaddle.py | 918 ++++++++++++++++++ .../frontend/paddlepaddle/test_forward.py | 661 +++++++++++++ tests/scripts/task_python_frontend.sh | 3 + 4 files changed, 1583 insertions(+) create mode 100644 python/tvm/relay/frontend/paddlepaddle.py create mode 100644 tests/python/frontend/paddlepaddle/test_forward.py diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aa8ac4fc7434..aa49b63203f2 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -31,4 +31,5 @@ from .darknet import from_darknet from .pytorch import from_pytorch from .caffe import from_caffe +from .paddlepaddle import from_paddle from .change_datatype import ChangeDatatype diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py new file mode 100644 index 000000000000..76a12691d2bf --- /dev/null +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -0,0 +1,918 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +# pylint: disable=import-outside-toplevel +"""Paddle: PArallel Distributed Deep LEarning.""" +import warnings + +import numpy as np + +import tvm +from tvm.ir import IRModule + +from .. import analysis +from .. import expr as _expr +from .. import function as _function +from .. import ty as _ty +from .. import op as _op +from .common import ( + fold_constant, + infer_shape, + infer_type, + infer_value, + new_var, +) + +__all__ = ["from_paddle"] + + +def shape_of(x, dtype="int32"): + """Get shape of a tensor""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + +def _get_pad_size(in_size, dilated_kernel_size, stride_size): + """calculate the paddings size""" + + if stride_size == 1 or in_size % stride_size == 0: + pad = max(dilated_kernel_size - stride_size, 0) + else: + pad = max(dilated_kernel_size - (in_size % stride_size), 0) + + pad_before = pad // 2 + pad_after = pad - pad_before + + return [pad_before, pad_after] + + +def convert_arg_max(g, op, block): + """Operator converter for arg_max.""" + + axis = op.attr("axis") + keepdims = op.attr("keepdims") + flatten = op.attr("flatten") + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + out = _op.argmax(x, axis=None, keepdims=True) + else: + out = _op.argmax(x, axis=axis, keepdims=keepdims) + g.add_node(op.output("Out")[0], out) + + +def convert_assign(g, op, block): + """Operator converter for assign.""" + + out = _op.copy(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) + + +def convert_batch_norm(g, op, block): + """Operator converter for batch_norm.""" + + ipt_name = op.input("X")[0] + scale_name = op.input("Scale")[0] + bias_name = op.input("Bias")[0] + mean_name = op.input("Mean")[0] + variance_name = op.input("Variance")[0] + epsilon = op.attr("epsilon") + out = _op.nn.batch_norm( + g.get_node(ipt_name), + g.get_node(scale_name), + g.get_node(bias_name), + g.get_node(mean_name), + g.get_node(variance_name), + epsilon=epsilon, + ) + g.add_node(op.output("Y")[0], out[0]) + + +def convert_cast(g, op, block): + """Operator converter for cast.""" + + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + x = g.get_node(op.input("X")[0]) + out = _op.cast(x, dtype=dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_concat(g, op, block): + """Operator converter for concat.""" + + inputs = [g.get_node(op.input("X")[i]) for i in range(len(op.input("X")))] + axis = op.attr("axis") + out = _op.concatenate(inputs, axis=axis) + g.add_node(op.output("Out")[0], out) + + +def convert_conv2d(g, op, block): + """Operator converter for conv2d.""" + + dilations = op.attr("dilations") + groups = op.attr("groups") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + strides = op.attr("strides") + + kernel = g.get_node(op.input("Filter")[0]) + input_x = g.get_node(op.input("Input")[0]) + out_channels, _, k_h, k_w = infer_shape(kernel) + in_h, in_w = infer_shape(input_x)[2:] + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + out = _op.nn.conv2d( + input_x, + kernel, + strides=strides, + padding=paddings, + dilation=dilations, + groups=groups, + channels=out_channels, + kernel_size=[k_h, k_w], + ) + g.add_node(op.output("Output")[0], out) + + +def convert_cumsum(g, op, block): + """Operator converter for cumsum.""" + + axis = op.attr("axis") + exclusive = op.attr("exclusive") + flatten = op.attr("flatten") + reverse = op.attr("reverse") + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + if reverse: + x = _op.reverse(x, axis=axis) + out = _op.cumsum(x, axis=axis, exclusive=exclusive) + out = _op.reverse(out, axis=axis) + else: + out = _op.cumsum(x, axis=axis, exclusive=exclusive) + g.add_node(op.output("Out")[0], out) + + +def convert_dropout(g, op, block): + """Operator converter for dropout.""" + + x = g.get_node(op.input("X")[0]) + out = _op.copy(x) + g.add_node(op.output("Out")[0], out) + + +def convert_elementwise_op(g, op, block): + """Operator converter for all the elementwise operators.""" + + op_map = { + "elementwise_div": lambda x, y: x / y, + "elementwise_add": lambda x, y: x + y, + "elementwise_mul": lambda x, y: x * y, + "elementwise_sub": lambda x, y: x - y, + "elementwise_mod": lambda x, y: x % y, + } + op_func = op_map[op.type] + ipt0 = g.get_node(op.input("X")[0]) + ipt1 = g.get_node(op.input("Y")[0]) + ipt0_shape = block.var(op.input("X")[0]).shape + ipt1_shape = block.var(op.input("Y")[0]).shape + axis = op.attr("axis") + if len(ipt0_shape) != len(ipt1_shape): + if axis < 0: + axis = axis + len(ipt0_shape) + if axis != len(ipt0_shape) - 1: + ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1)) + out = op_func(ipt0, ipt1) + g.add_node(op.output("Out")[0], out) + + +def convert_equal(g, op, block): + """Operator converter for equal.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + out = _op.equal(x, y) + g.add_node(op.output("Out")[0], out) + + +def convert_activation(g, op, block): + """Operator converter for all the activation.""" + + op_map = { + "exp": _op.exp, + "relu": _op.nn.relu, + "tanh": _op.tanh, + "sqrt": _op.sqrt, + "erf": _op.erf, + "abs": _op.abs, + } + act_func = op_map[op.type] + out = act_func(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) + + +def convert_feed(g, op, block): + """Converter for model input node.""" + + if block is not None: + ipt_name = op.output("Out")[0] + ipt_shape = block.var(ipt_name).shape + ipt_dtype = block.var(ipt_name).dtype + ipt_dtype = str(ipt_dtype).strip().split(".")[1] + else: + ipt_shape = op.shape + ipt_dtype = str(op.dtype).strip().split(".")[1] + ipt_name = op.name + if g.shape_dict is not None: + ipt_shape = g.shape_dict[ipt_name] + out = new_var(ipt_name, shape=ipt_shape, dtype=ipt_dtype) + g.add_node(ipt_name, out) + + +def convert_fill_any_like(g, op, block): + """Operator converter for fill_any_like.""" + + out_name = op.output("Out")[0] + out_dtype = block.var(out_name).dtype + out_dtype = str(out_dtype).strip().split(".")[1] + x = g.get_node(op.input("X")[0]) + ipt_type = infer_type(x).checked_type + value = op.attr("value") + if not _ty.is_dynamic(ipt_type): + shape = infer_shape(x) + const = np.ones(shape) * value + out = _expr.const(const.astype(out_dtype)) + else: + out = _op.transform.full_like(x, value).astype(out_dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_fill_constant(g, op, block): + """Operator converter for fill_constant.""" + + value = op.attr("value") + shape = block.var(op.output("Out")[0]).shape + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + if op.input("ValueTensor"): + shape = g.get_node(op.input("ValueTensor")[0]) + shape = infer_value(shape, g.get_params()).numpy() + if op.input("ShapeTensor"): + shape = g.get_node(op.input("ShapeTensor")[0]) + shape = infer_value(shape, g.get_params()).numpy() + value = np.full(shape, value, dtype) + out = _expr.const(value.astype(dtype)).astype(dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_gelu(g, op, block): + """Operator converter for gelu.""" + + x = g.get_node(op.input("X")[0]) + out = x * ( + _expr.const(0.5, dtype="float32") + + _op.erf(x * _expr.const(0.5 ** 0.5, dtype="float32")) * _expr.const(0.5, dtype="float32") + ) + g.add_node(op.output("Out")[0], out) + + +def convert_hard_sigmoid(g, op, block): + """Operator converter for hard_sigmoid.""" + + slope = op.attr("slope") + x = g.get_node(op.input("X")[0]) + out = x * _expr.const(slope) + _expr.const(0.5) + out = _op.clip(out, 0, 1) + g.add_node(op.output("Out")[0], out) + + +def convert_hard_swish(g, op, block): + """Operator converter for hard_swish.""" + + offset = op.attr("offset") + scale = op.attr("scale") + threshold = op.attr("threshold") + assert np.isclose(offset, 3.0), "Only support offset==3.0 for PaddlePaddle's hard_swish" + assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish" + assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish" + x = g.get_node(op.input("X")[0]) + out = _op.clip(x, -1 * offset, offset) + out = out / _expr.const(threshold) + _expr.const(0.5) + out = x * out + g.add_node(op.output("Out")[0], out) + + +def convert_layer_norm(g, op, block): + """Operator converter for layer_norm.""" + + begin_norm_axis = op.attr("begin_norm_axis") + epsilon = op.attr("epsilon") + x = g.get_node(op.input("X")[0]) + bias_input = op.input("Bias") + scale_input = op.input("Scale") + + x_shape = infer_shape(x) + assert begin_norm_axis in ( + len(x_shape) - 1, + -1, + ), "Support only normalization over last one dimension." + + if bias_input: + bias = g.get_node(bias_input[0]) + else: + bias = _expr.const(np.zeros(x_shape[begin_norm_axis])) + + if scale_input: + scale = g.get_node(scale_input[0]) + else: + scale = _expr.const(np.ones(x_shape[begin_norm_axis])) + + out = _op.nn.layer_norm( + x, gamma=scale, beta=bias, axis=begin_norm_axis, epsilon=epsilon, center=True, scale=True + ) + g.add_node(op.output("Y")[0], out) + + +def convert_leaky_relu(g, op, block): + """Operator converter for leaky_relu.""" + + alpha = op.attr("alpha") + x = g.get_node(op.input("X")[0]) + out = _op.nn.leaky_relu(x, alpha=alpha) + g.add_node(op.output("Out")[0], out) + + +def convert_lookup_table(g, op, block): + """Operator converter for lookup_table_v2.""" + + indices = g.get_node(op.input("Ids")[0]) + padding_idx = op.attr("padding_idx") + if padding_idx != -1: + g.get_params[op.input("W")[0]][padding_idx] = 0.0 + g.add_node(op.input("W")[0], _expr.const(g.params[op.input("W")[0]])) + weights = g.get_node(op.input("W")[0]) + out = _op.take(weights, indices.astype("int32"), axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_matmul(g, op, block): + """Operator converter for matmul.""" + + inputs = [g.get_node(op.input("X")[0]), g.get_node(op.input("Y")[0])] + a_shape = infer_shape(inputs[0]) + b_shape = infer_shape(inputs[1]) + if op.has_attr("trans_x"): + # for matmul_v2 + trans_x = op.attr("trans_x") + trans_y = op.attr("trans_y") + else: + # for matmul + trans_x = op.attr("transpose_X") + trans_y = op.attr("transpose_Y") + if trans_x: + perm = list(range(len(a_shape))) + perm[-2] = len(a_shape) - 1 + perm[-1] = len(a_shape) - 2 + inputs[0] = _op.transpose(inputs[0], axes=perm) + if trans_y: + perm = list(range(len(b_shape))) + perm[-2] = len(b_shape) - 1 + perm[-1] = len(b_shape) - 2 + inputs[1] = _op.transpose(inputs[1], axes=perm) + + # This implemention almost keeps same with ONNX + # Need to check input shape as batch matmul must be supported. + a_shape = shape_of(inputs[0]) + a_rank = infer_shape(a_shape)[0] + b_shape = shape_of(inputs[1]) + b_rank = infer_shape(b_shape)[0] + # When performing a batch matmul, we need to properly handle N-dim shapes. + if a_rank > 2 or b_rank > 2: + + def flatten_to_nd(x, x_shape, nd=3): + ndims = infer_shape(x_shape)[0] + if ndims == nd: + return x + newshape = _op.concatenate( + [ + _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), + _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), + ], + 0, + ) + out = _op.reshape(x, fold_constant(newshape)) + return out + + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b) + else: + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(inputs[0], a_shape, 3) + b = flatten_to_nd(inputs[1], b_shape, 3) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) + # Determine the output batch dimension. + if a_rank > b_rank: + out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) + elif a_rank < b_rank: + out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2]) + # If its unclear how broadcasting should be applied, the output + # shape is determined by choosing the maximum value from each input. + else: + out_batch = _op.concatenate( + [ + _op.maximum( + _op.strided_slice(a_shape, [i], [i + 1]), + _op.strided_slice(b_shape, [i], [i + 1]), + ) + for i in range(a_rank - 2) + ], + 0, + ) + # Reshape output to original dimensions. + final_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice( + a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1] + ), + _op.strided_slice( + b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] + ), + ], + 0, + ) + out = _op.reshape(output, fold_constant(final_shape)) + else: + if b_rank == 1: + inputs[1] = _op.expand_dims(inputs[1], 1, 1) + # Otherwise a simple dense op will get the job done. + input_1_t = _op.transpose(inputs[1], axes=(1, 0)) + out = _op.nn.dense(inputs[0], input_1_t) + if b_rank == 1: + out = _op.squeeze(out, axis=[-1]) + if op.has_attr("alpha"): + alpha = op.attr("alpha") + if not np.isclose(alpha, 1.0): + out = out * _expr.const(alpha).astype("float32") + g.add_node(op.output("Out")[0], out) + + +def convert_mul(g, op, block): + """Operator converter for mul.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + x_num_col_dims = op.attr("x_num_col_dims") + y_num_col_dims = op.attr("y_num_col_dims") + x_shape = shape_of(x) + y_shape = shape_of(y) + x_dim = infer_shape(x_shape)[0] + y_dim = infer_shape(y_shape)[0] + if x_num_col_dims < 0: + x_num_col_dims += x_dim + if y_num_col_dims < 0: + y_num_col_dims += y_dim + if x_num_col_dims == 1: + x = _op.nn.batch_flatten(x) + else: + pre_shape = _op.prod(_op.strided_slice(x_shape, [0], [x_num_col_dims], [1]), keepdims=True) + post_shape = _op.prod( + _op.strided_slice(x_shape, [x_num_col_dims], [x_dim], [1]), keepdims=True + ) + new_shape = _op.concatenate([pre_shape, post_shape], axis=0) + new_shape = fold_constant(new_shape) + x = _op.reshape(x, new_shape) + if y_num_col_dims == 1: + y = _op.nn.batch_flatten(y) + else: + pre_shape = _op.prod(_op.strided_slice(y_shape, [0], [y_num_col_dims], [1]), keepdims=True) + post_shape = _op.prod( + _op.strided_slice(y_shape, [y_num_col_dims], [y_dim], [1]), keepdims=True + ) + new_shape = _op.concatenate([pre_shape, post_shape], axis=0) + new_shape = fold_constant(new_shape) + y = _op.reshape(y, new_shape) + y = _op.transpose(y) + out = _op.nn.dense(x, y) + out_pre_shape = _op.strided_slice(x_shape, [0], [x_num_col_dims], [1]) + out_post_shape = _op.strided_slice(y_shape, [y_num_col_dims], [y_dim], [1]) + out_shape = _op.concatenate([out_pre_shape, out_post_shape], axis=0) + out_shape = fold_constant(out_shape) + out = _op.reshape(out, out_shape) + g.add_node(op.output("Out")[0], out) + + +def convert_pool2d(g, op, block): + """Operator converter for pool2d.""" + + adaptive = op.attr("adaptive") + ceil_mode = op.attr("ceil_mode") + global_pooling = op.attr("global_pooling") + ksize = op.attr("ksize") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + pooling_type = op.attr("pooling_type") + if global_pooling: + adaptive = True + ksize = [1, 1] + + input_x = g.get_node(op.input("X")[0]) + in_h, in_w = infer_shape(input_x)[2:] + + op_map = { + "avg": "avg_pool2d", + "max": "max_pool2d", + } + strides = op.attr("strides") + if isinstance(strides, int): + strides = [strides, strides] + if isinstance(ksize, int): + ksize = [ksize, ksize] + if isinstance(paddings, int): + paddings = [paddings] * 2 + + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + pad_h = _get_pad_size(in_h, ksize[0], strides[0]) + pad_w = _get_pad_size(in_w, ksize[1], strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + if not adaptive: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) + else: + out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) + g.add_node(op.output("Out")[0], out) + + +def convert_reshape(g, op, block): + """Operator converter for reshape.""" + + input_shape = op.input("Shape") + input_shape_tensor = op.input("ShapeTensor") + data = g.get_node(op.input("X")[0]) + if input_shape: + new_shape = g.get_node(input_shape[0]) + elif input_shape_tensor: + tmp_shape = [] + for shape_name in input_shape_tensor: + shape = g.get_node(shape_name) + if len(infer_shape(shape)) == 0: + shape = _op.reshape(shape, [-1]) + if isinstance(shape, _expr.Constant): + tmp_shape.append(shape) + elif isinstance(shape, _expr.Expr): + tmp_shape.append(shape) + else: + tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) + new_shape = _op.concatenate(tmp_shape, axis=0) + else: + new_shape = op.attr("shape") + out = _op.reshape(data, new_shape) + g.add_node(op.output("Out")[0], out) + + +def convert_scale(g, op, block): + """Operator converter for scale.""" + + scale = op.attr("scale") + bias = op.attr("bias") + bias_after_scale = op.attr("bias_after_scale") + x = g.get_node(op.input("X")[0]) + if np.isclose(scale, 1.0) and np.isclose(bias, 0.0): + out = _op.copy(x) + else: + if np.isclose(bias, 0.0): + out = x * _expr.const(np.array(scale).astype("float32")) + elif np.isclose(scale, 1.0): + out = x + _expr.const(np.array(bias).astype("float32")) + else: + if bias_after_scale: + out = x * _expr.const(np.array(scale).astype("float32")) + _expr.const( + np.array(bias).astype("float32") + ) + else: + out = (x + _expr.const(np.array(bias).astype("float32"))) * _expr.const( + np.array(scale).astype("float32") + ) + g.add_node(op.output("Out")[0], out) + + +def convert_shape(g, op, block): + """Operator converter for shape.""" + + x = g.get_node(op.input("Input")[0]) + out = shape_of(x) + g.add_node(op.output("Out")[0], out) + + +def convert_slice(g, op, block): + """Operator converter for slice.""" + + def parameter_process(starts, ends, axes, dshape): + new_axes = [] + new_starts = [] + new_ends = [] + pop_index = 0 + for i in range(max(axes) + 1): + new_axes.append(i) + if i in axes: + new_starts.append(starts[pop_index]) + new_ends.append(ends[pop_index]) + pop_index += 1 + else: + new_starts.append(0) + new_ends.append(dshape[i]) + return new_starts, new_ends, new_axes + + data = g.get_node(op.input("Input")[0]) + dshape = infer_shape(data) + starts = op.attr("starts") + ends = op.attr("ends") + axes = op.attr("axes") + decrease_axis = op.attr("decrease_axis") + if isinstance(starts, int): + starts = [starts] + if isinstance(ends, int): + ends = [ends] + if isinstance(axes, int): + axes = [axes] + if isinstance(decrease_axis, int): + decrease_axis = [decrease_axis] + starts, ends, axes = parameter_process(starts, ends, axes, dshape) + out = _op.strided_slice(data, begin=starts, end=ends) + if decrease_axis: + out = _op.squeeze(out, axis=decrease_axis) + g.add_node(op.output("Out")[0], out) + + +def convert_softmax(g, op, block): + """Operator converter for softmax.""" + + axis = op.attr("axis") + input_shape = block.var(op.input("X")[0]).shape + if axis < 0: + axis = len(input_shape) + axis + x = g.get_node(op.input("X")[0]) + m = _op.max(x, axis, keepdims=True) + e = _op.exp(x - m) + out = e / _op.sum(e, axis, keepdims=True) + g.add_node(op.output("Out")[0], out) + + +def convert_unsqueeze(g, op, block): + """Operator converter for unsqueeze.""" + + x = g.get_node(op.input("X")[0]) + axes = sorted(op.attr("axes")) + for axis in axes: + x = _op.expand_dims(x, axis=axis, num_newaxis=1) + g.add_node(op.output("Out")[0], x) + + +_convert_map = { + "arg_max": convert_arg_max, + "assign": convert_assign, + "batch_norm": convert_batch_norm, + "cast": convert_cast, + "concat": convert_concat, + "conv2d": convert_conv2d, + "cumsum": convert_cumsum, + "depthwise_conv2d": convert_conv2d, + "dropout": convert_dropout, + "elementwise_add": convert_elementwise_op, + "elementwise_div": convert_elementwise_op, + "elementwise_mul": convert_elementwise_op, + "elementwise_sub": convert_elementwise_op, + "equal": convert_equal, + "exp": convert_activation, + "feed": convert_feed, + "fill_any_like": convert_fill_any_like, + "fill_constant": convert_fill_constant, + "gelu": convert_gelu, + "hard_sigmoid": convert_hard_sigmoid, + "hard_swish": convert_hard_swish, + "layer_norm": convert_layer_norm, + "leaky_relu": convert_leaky_relu, + "lookup_table_v2": convert_lookup_table, + "matmul": convert_matmul, + "matmul_v2": convert_matmul, + "mul": convert_mul, + "pool2d": convert_pool2d, + "relu": convert_activation, + "reshape2": convert_reshape, + "scale": convert_scale, + "shape": convert_shape, + "slice": convert_slice, + "softmax": convert_softmax, + "tanh": convert_activation, + "unsqueeze2": convert_unsqueeze, +} + + +class GraphProto: + """A helper class for handling relay functions from PaddlePaddle model.""" + + def __init__(self): + self.nodes = {} + self.params = {} + self.shape_dict = None + + def get_node(self, name): + """get node from graph""" + + assert name in self.nodes + return self.nodes[name] + + def add_node(self, name, node): + """add a node to graph""" + + self.nodes[name] = fold_constant(node) + + def get_params(self, name=None): + """get params from graph""" + + if name is None: + return self.params + assert name in self.params + return self.params[name] + + def extract_parameters(self, program, scope=None): + """Extract all the weights from PaddlePaddle program.""" + + self.params = {} + variables = program.global_block().vars + for name in variables: + var = program.global_block().var(name) + if name.endswith("feed") or name.endswith("fetch"): + continue + if not var.persistable: + continue + if isinstance(scope, dict): + self.params[name] = scope[name] + else: + self.params[name] = np.array(scope.var(name).get_tensor()) + self.nodes[name] = _expr.const(self.params[name]) + + def check_input_shape(self, op, block): + """Check the shape information of model's inputs, fixed shape is recommended.""" + + ipt_name = op.input(op.input_names[0]) + ipt_shape = block.var(ipt_name).shape + for i in ipt_shape: + if i < 0: + warning_msg = "Input {}(shape={}) has unkown dimension shapes. \ + Specifying static values may improve performance".format( + ipt_name, ipt_shape + ) + warnings.warn(warning_msg) + + def check_unsupported_ops(self, program): + """Check whether all the operators are supported.""" + + unsupported_ops = set() + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + continue + if op.type not in _convert_map: + unsupported_ops.add(op.type) + if len(unsupported_ops) > 0: + msg = "The following operators are not supported for frontend Paddle: " + msg += ", ".join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) + + def ops_to_relay(self, program, input_specs=None): + """Convert PaddlePaddle operators to TVM relay functions.""" + + if input_specs is not None: + for input_spec in input_specs: + convert_feed(self, input_spec, None) + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + continue + convert_func = _convert_map[op.type] + convert_func(self, op, block) + + def from_program(self, program, shape_dict, scope): + """Construct the TVM relay expression from PaddlePaddle program.""" + + self.shape_dict = shape_dict + if scope is None: + import paddle + + scope = paddle.fluid.global_scope() + self.check_unsupported_ops(program) + self.extract_parameters(program, scope) + self.ops_to_relay(program) + + output_names = list() + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + output_names.append(op.input("X")[0]) + + outputs = [self.nodes[name] for name in output_names] + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + free_vars = analysis.free_vars(outputs) + func = _function.Function(free_vars, outputs) + mod = IRModule.from_expr(func) + return mod, self.params + + def from_translated_layer(self, layer, shape_dict): + """Construct the TVM relay expression from PaddlePaddle TranslatedLayer.""" + + self.shape_dict = shape_dict + program = layer.program() + parameters = dict() + for param in layer.parameters(): + parameters[param.name] = np.array(param.value().get_tensor()) + self.check_unsupported_ops(program) + self.extract_parameters(program, parameters) + + input_specs = layer._input_spec() + self.ops_to_relay(program, input_specs) + + output_names = [x.name for x in layer._output_spec()] + + outputs = [self.nodes[name] for name in output_names] + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + free_vars = analysis.free_vars(outputs) + func = _function.Function(free_vars, outputs) + mod = IRModule.from_expr(func) + return mod, self.params + + +def from_paddle(program_or_layer, shape_dict=None, scope=None): + """Convert a PaddlePaddle model into an equivalent Relay Function. + + PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, + and PaddlePaddle scope stores all the weights of PaddlePaddle model. + """ + + import paddle + + g = GraphProto() + if isinstance(program_or_layer, paddle.jit.TranslatedLayer): + # model is loaded by `paddle.jit.load` + mod, params = g.from_translated_layer(program_or_layer, shape_dict) + elif isinstance(program_or_layer, paddle.static.Program): + # model is loaded by `paddle.static.load_inference_model` + mod, params = g.from_program(program_or_layer, shape_dict, scope) + else: + raise Exception("Only PaddlePaddle's Program and TranslatedLayer are supported.") + return mod, params diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py new file mode 100644 index 000000000000..db07e07f9d83 --- /dev/null +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -0,0 +1,661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from pathlib import Path +import shutil + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import paddle +import paddle.nn as nn + +PADDLE_TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data", "paddle") +PADDLE_TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) + + +def assert_shapes_match(tru, est): + if tru.shape != est.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(tru.shape, est.shape)) + + +def get_paddle_model(func, input_spec): + global PADDLE_TEST_DATA_ROOT_PATH + model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model") + + paddle.jit.save(func, str(model_path), input_spec=input_spec) + baseline_model = paddle.jit.load(str(model_path)) + + shutil.rmtree(str(PADDLE_TEST_DATA_ROOT_PATH)) + return baseline_model + + +def verify_model(func, input_data, rtol=1e-5, atol=1e-5): + if not (isinstance(input_data, (tuple, list))): + input_data = [input_data] + + input_spec = [] + input_names = [] + input_shape_dict = {} + compiled_input = {} + for idx, data in enumerate(input_data): + input_name = "input{}".format(idx) + input_spec.append( + paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name) + ) + input_names.append(input_name) + input_shape_dict[input_name] = data.shape + if isinstance(data, np.ndarray): + compiled_input[input_name] = data + else: + compiled_input[input_name] = data.numpy() + + baseline_model = get_paddle_model(func, input_spec) + baseline_outputs = baseline_model(*[input[:] for input in input_data]) + + # get paddle outputs + if isinstance(baseline_outputs, (tuple, list)): + baseline_outputs = tuple(out.numpy() for out in baseline_outputs) + else: + baseline_outputs = (baseline_outputs.numpy(),) + + mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) + parms_num = min(len(input_names), len(mod["main"].params)) + compiled_names = [] + for arg in mod["main"].params[:parms_num]: + assert arg.name_hint in input_names + compiled_names.append(arg.name_hint) + + with tvm.transform.PassContext(opt_level=3): + for target, dev in tvm.testing.enabled_targets(): + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](dev)) + for name in compiled_names: + gmod.set_input(name, compiled_input[name]) + gmod.run() + + for i, baseline_output in enumerate(baseline_outputs): + compiled_output = gmod.get_output(i).numpy() + + assert_shapes_match(baseline_output, compiled_output) + tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) + + +@tvm.testing.uses_gpu +def test_forward_add_subtract(): + input_shape = [10] + + @paddle.jit.to_static + def add_subtract(inputs): + return paddle.subtract(paddle.add(inputs, inputs), inputs) + + @paddle.jit.to_static + def add_subtract2(inputs): + return inputs + 1 - 2 + + @paddle.jit.to_static + def add_subtract3(inputs1, inputs2): + ones = paddle.ones([10], dtype="float32") + return inputs1 + ones - inputs2 + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(add_subtract, input_data) + verify_model(add_subtract2, input_data) + input_data2 = paddle.rand(input_shape, dtype="float32") + verify_model(add_subtract3, [input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_argmax(): + input_shape = [1, 3, 10, 10] + + class ArgMax(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmax(inputs) + + class ArgMax1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1) + + class ArgMax2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1, keepdim=False) + + class ArgMax3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=2, keepdim=True) + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax(), input_data=input_data) + verify_model(ArgMax1(), input_data=input_data) + verify_model(ArgMax2(), input_data=input_data) + verify_model(ArgMax3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_assign(): + @paddle.jit.to_static + def assign(inputs): + return paddle.assign(inputs) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model( + assign, + [ + input_data, + ], + ) + input_data2 = np.random.randint(100, size=input_shape) + verify_model( + assign, + [ + input_data2, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_batch_norm(): + class BatchNorm1D(nn.Layer): + def __init__(self): + super(BatchNorm1D, self).__init__() + self.batch_norm = nn.BatchNorm1D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm2D(nn.Layer): + def __init__(self): + super(BatchNorm2D, self).__init__() + self.batch_norm = nn.BatchNorm2D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm3D(nn.Layer): + def __init__(self): + super(BatchNorm3D, self).__init__() + self.batch_norm = nn.BatchNorm3D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + input_data = paddle.rand((2, 2, 3), dtype="float32") + verify_model(BatchNorm1D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm2D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm3D(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_cast(): + @paddle.jit.to_static + def cast1(inputs, dtype="uint8"): + return paddle.cast(inputs, dtype) + + @paddle.jit.to_static + def cast2(inputs, dtype="int64"): + return inputs.cast(dtype) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") * 100 + verify_model( + cast1, + [ + input_data, + ], + ) + verify_model( + cast2, + [ + input_data, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_concat_unsqueeze(): + @paddle.jit.to_static + def concat_unsqueeze1(inputs): + return paddle.concat([inputs[:, 0].unsqueeze(1), inputs[:, 1].unsqueeze(1)], axis=1) + + @paddle.jit.to_static + def concat_unsqueeze2(inputs): + a = (inputs[:, :, 0] + 2) * 7 + b = (inputs[:, :, 1] + 3) * 11 + c = (inputs[:, :, 2] + 5) * 13 + return paddle.concat([paddle.unsqueeze(t, axis=2) for t in [a, b, c]], axis=2) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(concat_unsqueeze1, input_data=input_data) + verify_model(concat_unsqueeze2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_cumsum(): + @paddle.jit.to_static + def cusum1(inputs): + return paddle.cumsum(inputs) + + @paddle.jit.to_static + def cusum2(inputs): + return paddle.cumsum(inputs, axis=0) + + @paddle.jit.to_static + def cusum3(inputs): + return paddle.cumsum(inputs, axis=1) + + input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) + verify_model(cusum1, [input_data]) + verify_model(cusum1, [input_data.astype(paddle.int64)]) + verify_model( + cusum2, + [ + input_data, + ], + ) + verify_model( + cusum3, + [ + input_data, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_conv(): + conv2d_input_shape = [1, 3, 10, 10] + + class Conv2D1(nn.Layer): + def __init__(self): + super(Conv2D1, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + class Conv2D2(nn.Layer): + def __init__(self): + super(Conv2D2, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=conv2d_input_data) + verify_model(Conv2D2(), input_data=conv2d_input_data) + + +@tvm.testing.uses_gpu +def test_forward_dropout(): + @paddle.jit.to_static + def dropout(inputs): + return nn.functional.dropout(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(dropout, input_data=input_data[0, 0]) + verify_model(dropout, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_shape_full(): + @paddle.jit.to_static + def full1(inputs): + return paddle.full(paddle.shape(inputs), 3.14) + + @paddle.jit.to_static + def full2(inputs): + return paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(full1, input_data=[input_data]) + verify_model(full2, input_data=[input_data]) + + +@tvm.testing.uses_gpu +def test_forward_ones_like(): + @paddle.jit.to_static + def ones_like1(inputs): + return paddle.ones_like(inputs) + + @paddle.jit.to_static + def ones_like2(inputs): + return paddle.ones_like(inputs, dtype="int32") + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ones_like1, input_data=input_data) + verify_model(ones_like2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_gelu(): + @paddle.jit.to_static + def gelu(inputs): + return nn.functional.gelu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(gelu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_hard_sigmoid(): + @paddle.jit.to_static + def hard_sigmoid(inputs): + return nn.functional.hardsigmoid(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_sigmoid, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_hard_swish(): + @paddle.jit.to_static + def hard_swish(inputs): + return nn.functional.hardswish(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_swish, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_layer_norm(): + @paddle.jit.to_static + def layer_norm(inputs, weight, bias): + return nn.functional.layer_norm(inputs, inputs.shape[-1], weight=weight, bias=bias) + + class LayerNorm(nn.Layer): + def __init__(self): + super(LayerNorm, self).__init__() + data_shape = [10] + self.layer_norm = nn.LayerNorm(data_shape) + + @paddle.jit.to_static + def forward(self, inputs): + return self.layer_norm(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + weight = paddle.rand([10], dtype="float32") + bias = paddle.rand([10], dtype="float32") + verify_model(layer_norm, input_data=[input_data, weight, bias]) + verify_model(LayerNorm(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_leaky_relu(): + @paddle.jit.to_static + def leaky_relu(inputs): + return nn.functional.leaky_relu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(leaky_relu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_look_up(): + @paddle.jit.to_static + def look_up(inputs, weight): + return nn.functional.embedding(inputs, weight) + + class LookUp(nn.Layer): + def __init__(self): + super(LookUp, self).__init__() + self.embedding = paddle.nn.Embedding(10, 4, sparse=True) + + @paddle.jit.to_static + def forward(self, inputs): + return self.embedding(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.randint(0, 10, input_shape, dtype="int32") + weight = paddle.rand([10, 4], dtype="float32") + verify_model(look_up, input_data=[input_data, weight]) + verify_model(LookUp(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_multiply(): + @paddle.jit.to_static + def multiply1(inputs): + return inputs * inputs + + @paddle.jit.to_static + def multiply2(inputs): + return inputs * 1.0 / 2.0 + + @paddle.jit.to_static + def multiply3(inputs, inputs2): + ones = paddle.ones([10], dtype="float32") + return inputs * ones / inputs2 + + input_shape = [10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(multiply1, input_data=input_data) + verify_model(multiply2, input_data=input_data) + input_data2 = paddle.rand(input_shape, dtype="float32") + verify_model(multiply3, input_data=[input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_matmul(): + class MatMul1(nn.Layer): + def forward(self, input1, input2): + return paddle.matmul(input1, input2) + + # matrix x vector + input_data1 = paddle.randn((3, 4), dtype="float32") + input_data2 = paddle.randn((4,), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # matrix x matrix + input_data1 = paddle.randn((5, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # batched matrix x batched matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((10, 4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # batched matrix x broadcasted matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_pool2d(): + @paddle.jit.to_static + def pool2d1(inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) + + @paddle.jit.to_static + def pool2d2(inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + + @paddle.jit.to_static + def pool2d3(inputs): + return nn.functional.max_pool2d( + inputs, kernel_size=2, stride=2, padding=0, return_mask=True + ) + + input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) + verify_model(pool2d1, input_data=input_data) + verify_model(pool2d2, input_data=input_data) + # verify_model(pool2d3, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_relu(): + @paddle.jit.to_static + def relu(inputs): + return nn.functional.relu(inputs) + + input_shape = [10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(relu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_reshape(): + @paddle.jit.to_static + def reshape1(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, new_shape) + + @paddle.jit.to_static + def reshape2(inputs): + return inputs.reshape([-1]) + + @paddle.jit.to_static + def reshape3(inputs): + data_shape = inputs.shape + return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]]) + + @paddle.jit.to_static + def reshape4(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, [new_shape[2], 2, -1]) + + input_shape = [2, 1, 10, 1, 10] + input_data = paddle.rand(input_shape, dtype="float32") + input_data2 = paddle.randn([2, 1, 10, 10]) + verify_model(reshape1, input_data=[input_data, input_data2]) + verify_model(reshape2, input_data=input_data) + verify_model(reshape3, input_data=paddle.randn((2, 3, 4))) + verify_model(reshape4, input_data=[input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_scale(): + @paddle.jit.to_static + def scale1(inputs): + return paddle.scale(inputs, scale=2.0, bias=1.0) + + @paddle.jit.to_static + def scale2(inputs): + return paddle.scale(inputs, scale=3, bias=2.1, act="gelu") + + input_data = paddle.randn(shape=[2, 3], dtype="float32") + verify_model( + scale1, + input_data=[ + input_data, + ], + ) + verify_model(scale2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_slice(): + @paddle.jit.to_static + def slice1(inputs): + return inputs[:, :, :, :3] + + @paddle.jit.to_static + def slice2(inputs): + return inputs[0, :, :-3, :] + + @paddle.jit.to_static + def slice3(inputs): + return inputs[0::2, 0::2] + inputs[1::2, 1::2] + + @paddle.jit.to_static + def slice4(inputs): + x0 = paddle.to_tensor([2]) - paddle.to_tensor([1]) + x1 = paddle.to_tensor([3]) + paddle.to_tensor([1]) + return inputs[:, x0:, 1:x1, :] + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model( + slice1, + input_data=[ + input_data, + ], + ) + verify_model(slice2, input_data=input_data) + # need op "strided_slice" + # verify_model(slice3, input_data=paddle.randn((4, 4))) + # need op "assign_value" + # verify_model(slice4, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_tanh(): + @paddle.jit.to_static + def tanh(inputs): + return paddle.tanh(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(tanh, input_data=input_data) + + +if __name__ == "__main__": + test_forward_add_subtract() + test_forward_argmax() + test_forward_assign() + test_forward_batch_norm() + test_forward_cast() + test_forward_concat_unsqueeze() + test_forward_cumsum() + test_forward_conv() + test_forward_dropout() + test_forward_shape_full() + test_forward_ones_like() + test_forward_gelu() + test_forward_hard_sigmoid() + test_forward_hard_swish() + test_forward_layer_norm() + test_forward_leaky_relu() + test_forward_look_up() + test_forward_multiply() + test_forward_matmul() + test_forward_pool2d() + test_forward_relu() + test_forward_reshape() + test_forward_scale() + test_forward_slice() + test_forward_tanh() diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 62a0fa1e7fc8..a2f6d706a163 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -51,3 +51,6 @@ run_pytest cython python-frontend-darknet tests/python/frontend/darknet echo "Running relay PyTorch frontend test..." run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch + +echo "Running relay PaddlePaddle frontend test..." +run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle From 3306857d80bfc76cdd10d7a40028f52b7ca696aa Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 27 Aug 2021 17:28:50 +0800 Subject: [PATCH 20/42] [Runtime] add set_output_zero_copy (#8497) * Update graph_executor.h * Update graph_executor.cc * modify zero copy UT add set input zero copy * modify C style * add runtime test * realy build generatr the json Co-authored-by: hwstaff --- src/runtime/graph_executor/graph_executor.cc | 106 +++++++++++-- src/runtime/graph_executor/graph_executor.h | 28 ++++ tests/cpp/runtime_test.cc | 154 +++++++++++++++++++ 3 files changed, 274 insertions(+), 14 deletions(-) create mode 100644 tests/cpp/runtime_test.cc diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index bc73a5988377..dbd072a68fb5 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -91,6 +91,11 @@ void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module mod std::string& name = nodes_[nid].name; input_map_[name] = i; } + for (size_t i = 0; i < outputs_.size(); i++) { + const uint32_t nid = outputs_[i].node_id; + std::string& name = nodes_[nid].name; + output_map_[name] = i; + } } /*! * \brief Get the input index given the name of input. @@ -104,6 +109,18 @@ int GraphExecutor::GetInputIndex(const std::string& name) { } return -1; } +/*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output. + */ +int GraphExecutor::GetOutputIndex(const std::string& name) { + auto it = output_map_.find(name); + if (it != output_map_.end()) { + return it->second; + } + return -1; +} /*! * \brief set index-th input to the graph. * \param index The input index. @@ -114,6 +131,23 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) { uint32_t eid = this->entry_id(input_nodes_[index], 0); data_entry_[eid].CopyFrom(data_in); } +/*! + * \brief Check the legality of external DLTensor*. + * \param external The external DLTensor*. + * \param eid The data_enrty_ index. + */ +void GraphExecutor::CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const { + const DLTensor* internal = data_entry_[eid].operator->(); + + ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*external)); + ICHECK_EQ(reinterpret_cast(external->data) % kAllocAlignment, 0); + ICHECK_EQ(internal->ndim, static_cast(external->ndim)); + ICHECK_EQ(internal->device.device_type, external->device.device_type); + ICHECK_EQ(internal->device.device_id, external->device.device_id); + for (auto i = 0; i < external->ndim; ++i) { + ICHECK_EQ(internal->shape[i], external->shape[i]); + } +} /*! * \brief set index-th input to the graph without copying the data. * \param index The input index. @@ -122,23 +156,37 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) { void GraphExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { ICHECK_LT(static_cast(index), input_nodes_.size()); uint32_t eid = this->entry_id(input_nodes_[index], 0); - const DLTensor* old_t = data_entry_[eid].operator->(); - // check the consistency of input - ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref)); - ICHECK_EQ(reinterpret_cast(data_ref->data) % kAllocAlignment, 0); - ICHECK_EQ(old_t->ndim, static_cast(data_ref->ndim)); - ICHECK_EQ(old_t->device.device_type, data_ref->device.device_type); - ICHECK_EQ(old_t->device.device_id, data_ref->device.device_id); - for (auto i = 0; i < data_ref->ndim; ++i) { - ICHECK_EQ(old_t->shape[i], data_ref->shape[i]); - } - + CheckExternalDLTensor(data_ref, eid); // Update the data pointer for each argument of each op for (DLTensor* t : input_dltensors_[eid]) { t->data = data_ref->data; } } +/*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ +void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { + ICHECK_LT(static_cast(index), outputs_.size()); + ICHECK_LT(static_cast(index), output_dltensors_.size()); + const NodeEntry& output_node = outputs_[index]; + uint32_t output_node_eid = this->entry_id(output_node); + + // check the consistency of output + CheckExternalDLTensor(data_ref, output_node_eid); + + // Update the data pointer for output op + for (DLTensor* t : output_dltensors_[output_node_eid]) { + t->data = data_ref->data; + } + + // Update the input of the op connected to the output + for (DLTensor* t : both_output_opinput_dltensors_[output_node_eid]) { + t->data = data_ref->data; + } +} /*! * \brief Get the number of outputs * @@ -358,11 +406,17 @@ void GraphExecutor::SetupStorage() { void GraphExecutor::SetupOpExecs() { op_execs_.resize(this->GetNumOfNodes()); input_dltensors_.resize(num_node_entries()); + output_dltensors_.resize(num_node_entries()); + both_output_opinput_dltensors_.resize(num_node_entries()); std::unordered_set input_node_eids; for (size_t i = 0; i < input_nodes_.size(); i++) { uint32_t nid = input_nodes_[i]; input_node_eids.insert(entry_id(nid, 0)); } + std::unordered_set output_node_eids; + for (size_t i = 0; i < outputs_.size(); i++) { + output_node_eids.insert(entry_id(outputs_[i])); + } // setup the array and requirements. for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) { @@ -383,10 +437,25 @@ void GraphExecutor::SetupOpExecs() { std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args); for (size_t i = 0; i < inode.inputs.size(); i++) { - uint32_t eid = this->entry_id(inode.inputs[i]); + uint32_t input_eid = this->entry_id(inode.inputs[i]); // check if op input is model input - if (input_node_eids.count(eid) > 0) { - input_dltensors_[eid].push_back(static_cast(op_args->arg_values[i].v_handle)); + if (input_node_eids.count(input_eid) > 0) { + input_dltensors_[input_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); + } + // check if any model output is the input of the op + if (output_node_eids.count(input_eid) > 0) { + both_output_opinput_dltensors_[input_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); + } + } + + for (uint32_t i = inode.inputs.size(); i < inode.inputs.size() + inode.param.num_outputs; ++i) { + uint32_t output_eid = this->entry_id(nid, i - inode.inputs.size()); + // check if op output is model output + if (output_node_eids.count(output_eid) > 0) { + output_dltensors_[output_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); } } } @@ -462,6 +531,15 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, this->SetInputZeroCopy(args[0], args[1]); } }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); } else if (name == "get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (args.num_args == 2) { diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 42b5c405b406..87e8aa3cee34 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -107,6 +107,13 @@ class TVM_DLL GraphExecutor : public ModuleNode { */ int GetInputIndex(const std::string& name); + /*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output. + */ + int GetOutputIndex(const std::string& name); + /*! * \brief set index-th input to the graph. * \param index The input index. @@ -119,6 +126,12 @@ class TVM_DLL GraphExecutor : public ModuleNode { * \param data_ref The input data that is referred. */ void SetInputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ + void SetOutputZeroCopy(int index, DLTensor* data_ref); /*! * \brief Get the number of outputs * @@ -193,6 +206,9 @@ class TVM_DLL GraphExecutor : public ModuleNode { uint32_t node_id; uint32_t index; uint32_t version; + inline bool operator==(const NodeEntry& other) const { + return node_id == other.node_id && index == other.index && version == other.version; + } // JSON Loader void Load(dmlc::JSONReader* reader) { reader->BeginArray(); @@ -377,6 +393,12 @@ class TVM_DLL GraphExecutor : public ModuleNode { void SetupStorage(); /*! \brief Setup the executors. */ void SetupOpExecs(); + /*! + * \brief Check the legality of external DLTensor*. + * \param external The external DLTensor*. + * \param eid The data_enrty_ index. + */ + void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const; /*! * \brief Create an execution function given input. * \param attrs The node attributes. @@ -397,8 +419,14 @@ class TVM_DLL GraphExecutor : public ModuleNode { std::vector input_nodes_; /*! \brief Map of input names to input indices. */ std::unordered_map input_map_; + /*! \brief Map of output names to output indices. */ + std::unordered_map output_map_; /*! \brief Used for quick node input DLTensor* lookup given an input eid. */ std::vector> input_dltensors_; + /*! \brief Used for quick node output DLTensor* lookup given an output eid. */ + std::vector> output_dltensors_; + /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */ + std::vector> both_output_opinput_dltensors_; /*! \brief Used for quick entry indexing. */ std::vector node_row_ptr_; /*! \brief Output entries. */ diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc new file mode 100644 index 000000000000..6dbcd61b8c37 --- /dev/null +++ b/tests/cpp/runtime_test.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::relay; + +TVM_REGISTER_GLOBAL("runtime_test.strategy") + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { + ICHECK_EQ(inputs.size(), 2U); + return {topi::add(inputs[0], inputs[1])}; + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { + With target_scope(target); + return topi::generic::schedule_injective(target, outs); + }; + + auto n = make_object(); + auto strategy = tvm::relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "runtime_test.strategy", 10); + return strategy; + }); + +TEST(Runtime, ZeroCopy) { + auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32)); + auto a = relay::Var("a", tensor_type); + auto b = relay::Var("b", tensor_type); + auto add_op = relay::Op::Get("add"); + auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); + auto c = relay::Var("c", tensor_type); + auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); + auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); + auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + auto pA = static_cast(A->data); + auto pB = static_cast(B->data); + auto pC = static_cast(C->data); + auto pY = static_cast(Y->data); + + for (int i = 0; i < 6; ++i) { + pA[i] = i; + pB[i] = i + 1; + pC[i] = i + 2; + } + // get schedule + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); + if (!reg) { + LOG(FATAL) << "no _Register"; + } + auto fs = tvm::runtime::Registry::Get("runtime_test.strategy"); + if (!fs) { + LOG(FATAL) << "No test_strategy registered."; + } + auto fgeneric = GenericFunc::Get("runtime_test.strategy_generic").set_default(*fs); + (*reg)("add", "FTVMStrategy", fgeneric, 10); + Array dep; + dep.push_back(0); + (*reg)("add", "TShapeDataDependent", dep, 10); + // build + auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); + tvm::runtime::Module build_mod = (*pfb)(); + auto build_f = build_mod.GetFunction("build", false); + auto json_f = build_mod.GetFunction("get_graph_json", false); + auto mod_f = build_mod.GetFunction("get_module", false); + Map targets; + Target llvm_tgt = Target("llvm"); + targets.Set(0, llvm_tgt); + auto relay_mod = tvm::IRModule::FromExpr(func); + ICHECK(relay_mod.defined()) << "Module must be defined"; + build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); + // create graph executor + std::string json = json_f(); + tvm::runtime::Module mod = mod_f(); + auto dev = A->device; + auto pfr = tvm::runtime::Registry::Get("tvm.graph_executor.create"); + ICHECK(mod.defined()) << "Module must be defined"; + tvm::runtime::Module run_mod = + (*pfr)(json, mod, static_cast(dev.device_type), dev.device_id); + // get function + auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false); + auto set_output_f = run_mod.GetFunction("set_output_zero_copy", false); + auto run_f = run_mod.GetFunction("run", false); + // set input zero copy + set_input_f("a", const_cast(A.operator->())); + set_input_f("b", const_cast(B.operator->())); + set_input_f("c", const_cast(C.operator->())); + // set output zero copy + set_output_f(0, const_cast(Y.operator->())); + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); + } + // mutate the input a bit and run it again + for (int i = 0; i < 6; ++i) { + pB[i] = i + 3; + } + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 2))), 1e-4); + } + // attach a different input and run it again + auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto pC2 = static_cast(C2->data); + for (int i = 0; i < 6; ++i) { + pC2[i] = i + 4; + } + set_input_f("c", const_cast(C2.operator->())); + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 4))), 1e-4); + } +} From cf19c889214ca9a1b8c420baff35aa10986b3d9c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 27 Aug 2021 05:22:47 -0500 Subject: [PATCH 21/42] [Hexagon] Change declaration order of unique_ptr objects to fix crash (#8859) A crash occurs when automatically deleting an instance of CodeGenHexagon because the LLVMContext object has already been freed. Objects of both types are created using unique_ptr, but the object managed by the LLVMContext unique_ptr is passed to CodeGenHexagon object (not as a unique_ptr). This crash is fixed by moving the declaration of the LLVMContext object before the CodeGenHexagon object. I'm not sure if this is the best way to fix this, but it does fix the crash. Also, in other files, the LLVMContext object is always created first. Co-authored-by: Cahoon, Brendon --- src/target/llvm/codegen_hexagon.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 9d324d56887f..26356a547990 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -704,8 +704,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { (void)CallOnce; std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr cg(new CodeGenHexagon()); std::unique_ptr ctx(new llvm::LLVMContext()); + std::unique_ptr cg(new CodeGenHexagon()); cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; From 55bb8b60b707d5fc25c3828adf6086aa01bcc039 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 27 Aug 2021 14:39:03 -0700 Subject: [PATCH 22/42] [Graph Executor, VM] Add end to end benchmarking of models (#8858) Add benchmarking that includes ovearhead of transfering inputs and outputs to and from the device. This should give an accurate measurement of the runtime a user would see when using the model. This is accomplished by adding functions that run from inputs to return values into the graph executor and the VM. --- include/tvm/runtime/vm/vm.h | 10 ++ python/tvm/contrib/graph_executor.py | 34 ++++++- python/tvm/runtime/vm.py | 37 +++++++- src/runtime/graph_executor/graph_executor.cc | 28 ++++++ src/runtime/vm/vm.cc | 93 ++++++++++++------- .../relay/test_backend_graph_executor.py | 36 +++++++ tests/python/relay/test_vm.py | 32 +++++++ 7 files changed, 229 insertions(+), 41 deletions(-) diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 58c6ee037fb5..2fdfec9452af 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -258,6 +258,16 @@ class VirtualMachine : public runtime::ModuleNode { */ void InvokeGlobal(const VMFunction& func, const std::vector& args); + /*! + * \brief Set inputs to a function. + * \param name The function name + * \param args args[offset:] are arguments to the + * function. If the arguments are not of the correct device for the function, + * they will be copied to the device. + * \param offset Starting offset of the arguments in `args`. + */ + void SetInput(std::string name, TVMArgs args, int offset); + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs_; diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index 2e8ff1d62421..f064f8dbee69 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -321,7 +321,16 @@ def __getitem__(self, key): """ return self.module[key] - def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=None, **kwargs): + def benchmark( + self, + device, + func_name="run", + repeat=5, + number=5, + min_repeat_ms=None, + end_to_end=False, + **kwargs, + ): """Calculate runtime of a function by repeatedly calling it. Use this function to get an accurate measurement of the runtime of a function. The function @@ -329,7 +338,8 @@ def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=N or other external factors. Mean, median, standard deviation, min and max runtime are all reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that synchonization and data transfer operations are not counted towards the runtime. This allows - for fair comparison of runtimes across different functions and models. + for fair comparison of runtimes across different functions and models. The `end_to_end` flag + switches this behavior to include data transfer operations in the runtime. The benchmarking loop looks approximately like so: @@ -346,7 +356,7 @@ def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=N Parameters ---------- func_name : str - The function to benchmark + The function to benchmark. This is ignored if `end_to_end` is true. repeat : int Number of times to run the outer loop of the timing code (see above). The output will @@ -363,6 +373,11 @@ def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=N milliseconds. This can be used to ensure that the function is run enough to get an accurate measurement. + end_to_end : bool + If set, include time to transfer input tensors to the device and time to transfer + returned tensors in the total runtime. This will give accurate timings for end to end + workloads. + kwargs : Dict[str, Object] Named arguments to the function. These are cached before running timing code, so that data transfer costs are not counted in the runtime. @@ -374,6 +389,19 @@ def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=N access the individual runtimes (in seconds). """ min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if end_to_end: + # Have to unpack kwargs into a single list + args = [] + for k, v in kwargs.items(): + args.append(k) + args.append(v) + return self.module.time_evaluator( + "run_from_inputs", + device, + repeat=repeat, + number=number, + min_repeat_ms=min_repeat_ms, + )(device.device_type, device.device_id, *args) if kwargs: self.set_input(**kwargs) return self.module.time_evaluator( diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index aeb651cb5ae4..6416ad7814e1 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -509,16 +509,25 @@ def get_input_index(self, input_name, func_name="main"): return self._get_input_index(input_name, func_name) def benchmark( - self, device, *args, func_name="main", repeat=5, number=5, min_repeat_ms=None, **kwargs + self, + device, + *args, + func_name="main", + repeat=5, + number=5, + min_repeat_ms=None, + end_to_end=False, + **kwargs, ): """Calculate runtime of a function by repeatedly calling it. Use this function to get an accurate measurement of the runtime of a function. The function is run multiple times in order to account for variability in measurements, processor speed or other external factors. Mean, median, standard deviation, min and max runtime are all - reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that synchonization and data transfer operations are not counted towards the runtime. This allows - for fair comparison of runtimes across different functions and models. + for fair comparison of runtimes across different functions and models. The `end_to_end` flag + switches this behavior to include data transfer operations in the runtime. The benchmarking loop looks approximately like so: @@ -552,6 +561,11 @@ def benchmark( milliseconds. This can be used to ensure that the function is run enough to get an accurate measurement. + end_to_end : bool + If set, include time to transfer input tensors to the device and time to transfer + returned tensors in the total runtime. This will give accurate timings for end to end + workloads. + args : Sequence[Object] Arguments to the function. These are cached before running timing code, so that data transfer costs are not counted in the runtime. @@ -566,6 +580,23 @@ def benchmark( access the individual runtimes (in seconds). """ min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if end_to_end: + # We need to unpack keyword arguments into positional arguments + packed_args = list(args) + for k, v in kwargs.items(): + i = self.get_input_index(k, func_name) + if i < 0: + raise TypeError(f"{func_name}() got an unexpected keyword argument '{k}'") + while i >= len(packed_args): + packed_args.append(None) + packed_args[i] = v + return self.module.time_evaluator( + "invoke_return_to_device", + device, + repeat=repeat, + number=number, + min_repeat_ms=min_repeat_ms, + )(func_name, device.device_type, device.device_id, *packed_args) if args or kwargs: self.set_input(func_name, *args, **kwargs) return self.module.time_evaluator( diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index dbd072a68fb5..6fe640e87404 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -568,6 +568,34 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); } else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + } else if (name == "run_from_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() % 2 == 0) + << "Number of arguments to run_from_inputs must be an even number of key-value pairs"; + Device host{static_cast(args[0].operator int()), args[1].operator int()}; + for (int i = 2; i < args.size(); i += 2) { + if (String::CanConvertFrom(args[i])) { + int in_idx = this->GetInputIndex(args[i].operator String()); + if (in_idx >= 0) { + this->SetInput(in_idx, args[i + 1]); + } else { + LOG(FATAL) << args[i].operator String() << " is not a valid input name"; + } + } else { + this->SetInput(args[i], args[i + 1]); + } + } + this->Run(); + Array outputs; + for (int i = 0; i < this->NumOutputs(); i++) { + NDArray out = this->GetOutput(i); + NDArray a = NDArray::Empty(out.Shape(), out.DataType(), host); + a.CopyFrom(out); + outputs.push_back(a); + } + *rv = outputs; + }); } else if (name == "load_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->LoadParams(args[0].operator std::string()); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 925e867f2e1b..4df013baa2fb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -118,6 +118,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK(exec_) << "The executable is not created yet."; + std::string func_name = args[0]; auto git = exec_->global_map.find(func_name); ICHECK(git != exec_->global_map.end()) @@ -140,6 +141,26 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, TVMRetValue rv_; invoke.CallPacked(args, &rv_); }); + } else if (name == "invoke_return_to_device") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Device host{static_cast(args[1].operator int()), args[2].operator int()}; + + SetInput(args[0].operator std::string(), args, 3); + PackedFunc invoke = GetFunction("invoke", sptr_to_self); + TVMRetValue rv_; + invoke.CallPacked(args, &rv_); // Invoke only uses the first arg, so the rest of the args + // should not cause an issue + if (rv_.type_code() == kTVMObjectHandle) { + ADT adt = Downcast(rv_.operator ObjectRef()); + std::vector transfered; + for (size_t i = 0; i < adt.size(); i++) { + transfered.push_back(CopyTo(adt[i], host)); + } + *rv = ADT(adt.tag(), transfered); + } else { + *rv = CopyTo(rv_, host); + } + }); } else if (name == "get_output") { return TypedPackedFunc([this](int64_t index) { if (this->return_register_.as()) { @@ -191,47 +212,49 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, this->Init(devices, alloc_types); }); } else if (name == "set_input") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK(exec_) << "The executable is not created yet."; - std::string func_name = args[0]; - auto gvit = exec_->global_map.find(func_name); - ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; - auto func_index = gvit->second; - const auto& vm_func = exec_->functions[func_index]; - const auto& param_names = vm_func.params; - ICHECK_EQ(args.size() - 1, param_names.size()) - << "The number of provided parameters doesn't match the number of arguments"; - ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) - << "The number of provided parameters doesn't match the number of assigned devices"; - std::vector func_args(param_names.size()); - for (int i = 1; i < args.size(); ++i) { - Index device_type = vm_func.params_device_type[i - 1]; - Device dev = GetDevice(device_type); - - if (args[i].type_code() == kTVMDLTensorHandle) { - // Automatically convert input DLTensors to NDArray - DLTensor* tensor = args[i]; - std::vector shape; - for (int64_t i = 0; i < tensor->ndim; i++) { - shape.push_back(tensor->shape[i]); - } - NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); - ary.CopyFrom(tensor); - func_args[i - 1] = ary; - } else { - ObjectRef obj = CopyTo(args[i], dev); - func_args[i - 1] = obj; - } - } - inputs_.erase(func_name); - inputs_.emplace(func_name, func_args); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } +void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { + ICHECK(exec_) << "The executable is not created yet."; + auto gvit = exec_->global_map.find(func_name); + ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + const auto& vm_func = exec_->functions[func_index]; + const auto& param_names = vm_func.params; + ICHECK_EQ(args.size() - offset, param_names.size()) + << "The number of provided parameters doesn't match the number of arguments"; + ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) + << "The number of provided parameters doesn't match the number of assigned devices"; + std::vector func_args(param_names.size()); + for (int i = offset; i < args.size(); ++i) { + Index device_type = vm_func.params_device_type[i - offset]; + Device dev = GetDevice(device_type); + + if (args[i].type_code() == kTVMDLTensorHandle) { + // Automatically convert input DLTensors to NDArray + DLTensor* tensor = args[i]; + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); + ary.CopyFrom(tensor); + func_args[i - offset] = ary; + } else { + ObjectRef obj = CopyTo(args[i], dev); + func_args[i - offset] = obj; + } + } + inputs_.erase(func_name); + inputs_.emplace(func_name, func_args); +} + inline Device VirtualMachine::GetDevice(Index device_type) const { ICHECK_GE(devices_.size(), device_type) << "devices_ doesn't contain device:" << device_type; diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 9e212527838e..f1ab58e7bf07 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -25,6 +25,8 @@ from tvm.relay.op import add import tvm.testing from tvm.relay.testing import mlp +from tvm import rpc +from tvm.contrib import utils # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -348,5 +350,39 @@ def test_benchmark(): assert result.std == 1.5 +@tvm.testing.parametrize_targets("cuda", "llvm") +def test_benchmark_end_to_end(dev, target): + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target=target, params=params) + exe = graph_executor.create(lib.get_graph_json(), lib.lib, dev) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(dev, data=data, func_name="run", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + assert len(result.results) == 2 + + +@tvm.testing.requires_llvm +def test_benchmark_end_to_end_rpc(): + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port) + + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target="llvm", params=params) + + temp = utils.tempdir() + path = temp.relpath("library.so") + lib.export_library(path) + remote.upload(path) + rlib = remote.load_module("library.so") + + dev = remote.cpu() + exe = graph_executor.create(lib.get_graph_json(), rlib, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + result = exe.benchmark(dev, data=data, func_name="run", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + assert len(result.results) == 2 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index c7043481ee3d..4c5b98514724 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -981,5 +981,37 @@ def test_benchmark(): assert result.std == 1.5 +@tvm.testing.parametrize_targets("cuda", "llvm") +def test_benchmark_end_to_end(dev, target): + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target=target, params=params) + exe = runtime.vm.VirtualMachine(lib, dev) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + result = exe.benchmark(dev, data, func_name="main", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + + +@tvm.testing.requires_llvm +def test_benchmark_end_to_end_rpc(): + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port) + + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target="llvm", params=params) + + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + lib.mod.export_library(path) + remote.upload(path) + rlib = remote.load_module("vm_library.so") + + exe = runtime.vm.VirtualMachine(rlib, remote.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=remote.cpu()) + result = exe.benchmark( + remote.cpu(), data=data, func_name="main", repeat=2, number=1, end_to_end=True + ) + assert result.mean > 0 + + if __name__ == "__main__": pytest.main([__file__]) From 3c86eec10ff8ced914db2af5873dfa91b76e5523 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Fri, 27 Aug 2021 16:43:49 -0500 Subject: [PATCH 23/42] [UnitTests] Expose TVM pytest helpers as plugin (#8532) * [UnitTests] Expose TVM pytest helpers as plugin Previously, pytest helper utilities such as automatic parametrization of `target`/`dev`, or `tvm.testing.parameter` were only available for tests within the `${TVM_HOME}/tests` directory. This PR extracts the helper utilities into an importable plugin, which can be used in external tests (e.g. one-off debugging). * [UnitTests] Refactor the plugin-specific logic out into plugin.py. * [UnitTests] Moved marker definition out to global variable. --- conftest.py | 33 +- pytest.ini | 26 -- python/tvm/testing/__init__.py | 5 +- python/tvm/testing/plugin.py | 294 ++++++++++++++++++ python/tvm/testing/utils.py | 238 ++------------ .../unittest/test_tvm_testing_features.py | 7 +- 6 files changed, 328 insertions(+), 275 deletions(-) delete mode 100644 pytest.ini create mode 100644 python/tvm/testing/plugin.py diff --git a/conftest.py b/conftest.py index f591fe970de8..28859fd4a17b 100644 --- a/conftest.py +++ b/conftest.py @@ -14,36 +14,5 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest -from pytest import ExitCode -import tvm -import tvm.testing - - -def pytest_configure(config): - print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets()))) - print("pytest marker:", config.option.markexpr) - - -@pytest.fixture -def dev(target): - return tvm.device(target) - - -def pytest_generate_tests(metafunc): - tvm.testing._auto_parametrize_target(metafunc) - tvm.testing._parametrize_correlated_parameters(metafunc) - - -def pytest_collection_modifyitems(config, items): - tvm.testing._count_num_fixture_uses(items) - tvm.testing._remove_global_fixture_definitions(items) - - -def pytest_sessionfinish(session, exitstatus): - # Don't exit with an error if we select a subset of tests that doesn't - # include anything - if session.config.option.markexpr != "": - if exitstatus == ExitCode.NO_TESTS_COLLECTED: - session.exitstatus = ExitCode.OK +pytest_plugins = ["tvm.testing.plugin"] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 675f8fe9b5a0..000000000000 --- a/pytest.ini +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -[pytest] -markers = - gpu: mark a test as requiring a gpu - tensorcore: mark a test as requiring a tensorcore - cuda: mark a test as requiring cuda - opencl: mark a test as requiring opencl - rocm: mark a test as requiring rocm - vulkan: mark a test as requiring vulkan - metal: mark a test as requiring metal - llvm: mark a test as requiring llvm diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index 54004365ca7a..f610c6ecc0db 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + # pylint: disable=redefined-builtin, wildcard-import """Utility Python functions for TVM testing""" from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true @@ -23,9 +24,7 @@ from .utils import known_failing_targets, requires_cuda, requires_cudagraph from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl -from .utils import _auto_parametrize_target, _count_num_fixture_uses -from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters -from .utils import _pytest_target_params, identity_after, terminate_self +from .utils import identity_after, terminate_self from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py new file mode 100644 index 000000000000..06b4fa4f65eb --- /dev/null +++ b/python/tvm/testing/plugin.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pytest plugin for using tvm testing extensions. + +TVM provides utilities for testing across all supported targets, and +to more easily parametrize across many inputs. For more information +on usage of these features, see documentation in the tvm.testing +module. + +These are enabled by default in all pytests provided by tvm, but may +be useful externally for one-off testing. To enable, add the +following line to the test script, or to the conftest.py in the same +directory as the test scripts. + + pytest_plugins = ['tvm.testing.plugin'] + +""" + +import collections + +import pytest +import _pytest + +import tvm +from tvm.testing import utils + + +MARKERS = { + "gpu": "mark a test as requiring a gpu", + "tensorcore": "mark a test as requiring a tensorcore", + "cuda": "mark a test as requiring cuda", + "opencl": "mark a test as requiring opencl", + "rocm": "mark a test as requiring rocm", + "vulkan": "mark a test as requiring vulkan", + "metal": "mark a test as requiring metal", + "llvm": "mark a test as requiring llvm", +} + + +def pytest_configure(config): + """Runs at pytest configure time, defines marks to be used later.""" + + for markername, desc in MARKERS.items(): + config.addinivalue_line("markers", "{}: {}".format(markername, desc)) + + print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets()))) + print("pytest marker:", config.option.markexpr) + + +def pytest_generate_tests(metafunc): + """Called once per unit test, modifies/parametrizes it as needed.""" + _parametrize_correlated_parameters(metafunc) + _auto_parametrize_target(metafunc) + + +def pytest_collection_modifyitems(config, items): + """Called after all tests are chosen, currently used for bookkeeping.""" + # pylint: disable=unused-argument + _count_num_fixture_uses(items) + _remove_global_fixture_definitions(items) + + +@pytest.fixture +def dev(target): + """Give access to the device to tests that need it.""" + return tvm.device(target) + + +def pytest_sessionfinish(session, exitstatus): + # Don't exit with an error if we select a subset of tests that doesn't + # include anything + if session.config.option.markexpr != "": + if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED: + session.exitstatus = pytest.ExitCode.OK + + +def _auto_parametrize_target(metafunc): + """Automatically applies parametrize_targets + + Used if a test function uses the "target" fixture, but isn't + already marked with @tvm.testing.parametrize_targets. Intended + for use in the pytest_generate_tests() handler of a conftest.py + file. + + """ + + def update_parametrize_target_arg( + argnames, + argvalues, + *args, + **kwargs, + ): + args = [arg.strip() for arg in argnames.split(",") if arg.strip()] + if "target" in args: + target_i = args.index("target") + + new_argvalues = [] + for argvalue in argvalues: + + if isinstance(argvalue, _pytest.mark.structures.ParameterSet): + # The parametrized value is already a + # pytest.param, so track any marks already + # defined. + param_set = argvalue.values + target = param_set[target_i] + additional_marks = argvalue.marks + elif len(args) == 1: + # Single value parametrization, argvalue is a list of values. + target = argvalue + param_set = (target,) + additional_marks = [] + else: + # Multiple correlated parameters, argvalue is a list of tuple of values. + param_set = argvalue + target = param_set[target_i] + additional_marks = [] + + new_argvalues.append( + pytest.param( + *param_set, marks=_target_to_requirement(target) + additional_marks + ) + ) + + try: + argvalues[:] = new_argvalues + except TypeError as err: + pyfunc = metafunc.definition.function + filename = pyfunc.__code__.co_filename + line_number = pyfunc.__code__.co_firstlineno + msg = ( + f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) " + "is parametrized using a tuple of parameters instead of a list " + "of parameters." + ) + raise TypeError(msg) from err + + if "target" in metafunc.fixturenames: + # Update any explicit use of @pytest.mark.parmaetrize to + # parametrize over targets. This adds the appropriate + # @tvm.testing.requires_* markers for each target. + for mark in metafunc.definition.iter_markers("parametrize"): + update_parametrize_target_arg(*mark.args, **mark.kwargs) + + # Check if any explicit parametrizations exist, and apply one + # if they do not. If the function is marked with either + # excluded or known failing targets, use these to determine + # the targets to be used. + parametrized_args = [ + arg.strip() + for mark in metafunc.definition.iter_markers("parametrize") + for arg in mark.args[0].split(",") + ] + if "target" not in parametrized_args: + excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", []) + xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", []) + metafunc.parametrize( + "target", + _pytest_target_params(None, excluded_targets, xfail_targets), + scope="session", + ) + + +def _count_num_fixture_uses(items): + # Helper function, counts the number of tests that use each cached + # fixture. Should be called from pytest_collection_modifyitems(). + for item in items: + is_skipped = item.get_closest_marker("skip") or any( + mark.args[0] for mark in item.iter_markers("skipif") + ) + if is_skipped: + continue + + for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): + # Only increment the active fixturedef, in a name has been overridden. + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "num_tests_use_this_fixture"): + fixturedef.func.num_tests_use_this_fixture[0] += 1 + + +def _remove_global_fixture_definitions(items): + # Helper function, removes fixture definitions from the global + # variables of the modules they were defined in. This is intended + # to improve readability of error messages by giving a NameError + # if a test function accesses a pytest fixture but doesn't include + # it as an argument. Should be called from + # pytest_collection_modifyitems(). + + modules = set(item.module for item in items) + + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if hasattr(obj, "_pytestfixturefunction") and isinstance( + obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker + ): + delattr(module, name) + + +def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): + # Include unrunnable targets here. They get skipped by the + # pytest.mark.skipif in _target_to_requirement(), showing up as + # skipped tests instead of being hidden entirely. + if targets is None: + if excluded_targets is None: + excluded_targets = set() + + if xfail_targets is None: + xfail_targets = set() + + target_marks = [] + for t in utils._get_targets(): + # Excluded targets aren't included in the params at all. + if t["target_kind"] not in excluded_targets: + + # Known failing targets are included, but are marked + # as expected to fail. + extra_marks = [] + if t["target_kind"] in xfail_targets: + extra_marks.append( + pytest.mark.xfail( + reason='Known failing test for target "{}"'.format(t["target_kind"]) + ) + ) + + target_marks.append((t["target"], extra_marks)) + + else: + target_marks = [(target, []) for target in targets] + + return [ + pytest.param(target, marks=_target_to_requirement(target) + extra_marks) + for target, extra_marks in target_marks + ] + + +def _target_to_requirement(target): + if isinstance(target, str): + target = tvm.target.Target(target) + + # mapping from target to decorator + if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): + return utils.requires_cudnn() + if target.kind.name == "cuda": + return utils.requires_cuda() + if target.kind.name == "rocm": + return utils.requires_rocm() + if target.kind.name == "vulkan": + return utils.requires_vulkan() + if target.kind.name == "nvptx": + return utils.requires_nvptx() + if target.kind.name == "metal": + return utils.requires_metal() + if target.kind.name == "opencl": + return utils.requires_opencl() + if target.kind.name == "llvm": + return utils.requires_llvm() + return [] + + +def _parametrize_correlated_parameters(metafunc): + parametrize_needed = collections.defaultdict(list) + + for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items(): + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "parametrize_group") and hasattr( + fixturedef.func, "parametrize_values" + ): + group = fixturedef.func.parametrize_group + values = fixturedef.func.parametrize_values + parametrize_needed[group].append((name, values)) + + for parametrize_group in parametrize_needed.values(): + if len(parametrize_group) == 1: + name, values = parametrize_group[0] + metafunc.parametrize(name, values, indirect=True) + else: + names = ",".join(name for name, values in parametrize_group) + value_sets = zip(*[values for name, values in parametrize_group]) + metafunc.parametrize(names, value_sets, indirect=True) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 04a235b64fdf..6f115f8da58c 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -16,7 +16,14 @@ # under the License. # pylint: disable=invalid-name,unnecessary-comprehension -""" TVM testing utilities +"""TVM testing utilities + +Organization +************ + +This file contains functions expected to be called directly by a user +while writing unit tests. Integrations with the pytest framework +are in plugin.py. Testing Markers *************** @@ -53,8 +60,8 @@ def test_something(): fpgas), we need to add a new marker in `tests/python/pytest.ini` and a new function in this module. Then targets using this node should be added to the `TVM_TEST_TARGETS` environment variable in the CI. + """ -import collections import copy import copyreg import ctypes @@ -65,7 +72,6 @@ def test_something(): import time import pickle import pytest -import _pytest import numpy as np import tvm import tvm.arith @@ -768,153 +774,6 @@ def requires_rpc(*args): return _compose(args, _requires_rpc) -def _target_to_requirement(target): - if isinstance(target, str): - target = tvm.target.Target(target) - - # mapping from target to decorator - if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): - return requires_cudnn() - if target.kind.name == "cuda": - return requires_cuda() - if target.kind.name == "rocm": - return requires_rocm() - if target.kind.name == "vulkan": - return requires_vulkan() - if target.kind.name == "nvptx": - return requires_nvptx() - if target.kind.name == "metal": - return requires_metal() - if target.kind.name == "opencl": - return requires_opencl() - if target.kind.name == "llvm": - return requires_llvm() - return [] - - -def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): - # Include unrunnable targets here. They get skipped by the - # pytest.mark.skipif in _target_to_requirement(), showing up as - # skipped tests instead of being hidden entirely. - if targets is None: - if excluded_targets is None: - excluded_targets = set() - - if xfail_targets is None: - xfail_targets = set() - - target_marks = [] - for t in _get_targets(): - # Excluded targets aren't included in the params at all. - if t["target_kind"] not in excluded_targets: - - # Known failing targets are included, but are marked - # as expected to fail. - extra_marks = [] - if t["target_kind"] in xfail_targets: - extra_marks.append( - pytest.mark.xfail( - reason='Known failing test for target "{}"'.format(t["target_kind"]) - ) - ) - - target_marks.append((t["target"], extra_marks)) - - else: - target_marks = [(target, []) for target in targets] - - return [ - pytest.param(target, marks=_target_to_requirement(target) + extra_marks) - for target, extra_marks in target_marks - ] - - -def _auto_parametrize_target(metafunc): - """Automatically applies parametrize_targets - - Used if a test function uses the "target" fixture, but isn't - already marked with @tvm.testing.parametrize_targets. Intended - for use in the pytest_generate_tests() handler of a conftest.py - file. - - """ - - def update_parametrize_target_arg( - argnames, - argvalues, - *args, - **kwargs, - ): - args = [arg.strip() for arg in argnames.split(",") if arg.strip()] - if "target" in args: - target_i = args.index("target") - - new_argvalues = [] - for argvalue in argvalues: - - if isinstance(argvalue, _pytest.mark.structures.ParameterSet): - # The parametrized value is already a - # pytest.param, so track any marks already - # defined. - param_set = argvalue.values - target = param_set[target_i] - additional_marks = argvalue.marks - elif len(args) == 1: - # Single value parametrization, argvalue is a list of values. - target = argvalue - param_set = (target,) - additional_marks = [] - else: - # Multiple correlated parameters, argvalue is a list of tuple of values. - param_set = argvalue - target = param_set[target_i] - additional_marks = [] - - new_argvalues.append( - pytest.param( - *param_set, marks=_target_to_requirement(target) + additional_marks - ) - ) - - try: - argvalues[:] = new_argvalues - except TypeError as e: - pyfunc = metafunc.definition.function - filename = pyfunc.__code__.co_filename - line_number = pyfunc.__code__.co_firstlineno - msg = ( - f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) " - "is parametrized using a tuple of parameters instead of a list " - "of parameters." - ) - raise TypeError(msg) from e - - if "target" in metafunc.fixturenames: - # Update any explicit use of @pytest.mark.parmaetrize to - # parametrize over targets. This adds the appropriate - # @tvm.testing.requires_* markers for each target. - for mark in metafunc.definition.iter_markers("parametrize"): - update_parametrize_target_arg(*mark.args, **mark.kwargs) - - # Check if any explicit parametrizations exist, and apply one - # if they do not. If the function is marked with either - # excluded or known failing targets, use these to determine - # the targets to be used. - parametrized_args = [ - arg.strip() - for mark in metafunc.definition.iter_markers("parametrize") - for arg in mark.args[0].split(",") - ] - if "target" not in parametrized_args: - excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", []) - xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", []) - metafunc.parametrize( - "target", - _pytest_target_params(None, excluded_targets, xfail_targets), - scope="session", - ) - - def parametrize_targets(*args): """Parametrize a test over a specific set of targets. @@ -1164,28 +1023,6 @@ def fixture_func(*_cls, request): return outputs -def _parametrize_correlated_parameters(metafunc): - parametrize_needed = collections.defaultdict(list) - - for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items(): - fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "parametrize_group") and hasattr( - fixturedef.func, "parametrize_values" - ): - group = fixturedef.func.parametrize_group - values = fixturedef.func.parametrize_values - parametrize_needed[group].append((name, values)) - - for parametrize_group in parametrize_needed.values(): - if len(parametrize_group) == 1: - name, values = parametrize_group[0] - metafunc.parametrize(name, values, indirect=True) - else: - names = ",".join(name for name, values in parametrize_group) - value_sets = zip(*[values for name, values in parametrize_group]) - metafunc.parametrize(names, value_sets, indirect=True) - - def fixture(func=None, *, cache_return_value=False): """Convenience function to define pytest fixtures. @@ -1319,7 +1156,9 @@ def _fixture_cache(func): # Can't use += on a bound method's property. Therefore, this is a # list rather than a variable so that it can be accessed from the # pytest_collection_modifyitems(). - num_uses_remaining = [0] + num_tests_use_this_fixture = [0] + + num_times_fixture_used = 0 # Using functools.lru_cache would require the function arguments # to be hashable, which wouldn't allow caching fixtures that @@ -1344,6 +1183,14 @@ def get_cache_key(*args, **kwargs): @functools.wraps(func) def wrapper(*args, **kwargs): + if num_tests_use_this_fixture[0] == 0: + raise RuntimeError( + "Fixture use count is 0. " + "This can occur if tvm.testing.plugin isn't registered. " + "If using outside of the TVM test directory, " + "please add `pytest_plugins = ['tvm.testing.plugin']` to your conftest.py" + ) + try: cache_key = get_cache_key(*args, **kwargs) @@ -1364,52 +1211,17 @@ def wrapper(*args, **kwargs): finally: # Clear the cache once all tests that use a particular fixture # have completed. - num_uses_remaining[0] -= 1 - if not num_uses_remaining[0]: + nonlocal num_times_fixture_used + num_times_fixture_used += 1 + if num_times_fixture_used >= num_tests_use_this_fixture[0]: cache.clear() - # Set in the pytest_collection_modifyitems() - wrapper.num_uses_remaining = num_uses_remaining + # Set in the pytest_collection_modifyitems(), by _count_num_fixture_uses + wrapper.num_tests_use_this_fixture = num_tests_use_this_fixture return wrapper -def _count_num_fixture_uses(items): - # Helper function, counts the number of tests that use each cached - # fixture. Should be called from pytest_collection_modifyitems(). - for item in items: - is_skipped = item.get_closest_marker("skip") or any( - mark.args[0] for mark in item.iter_markers("skipif") - ) - if is_skipped: - continue - - for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): - # Only increment the active fixturedef, in a name has been overridden. - fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "num_uses_remaining"): - fixturedef.func.num_uses_remaining[0] += 1 - - -def _remove_global_fixture_definitions(items): - # Helper function, removes fixture definitions from the global - # variables of the modules they were defined in. This is intended - # to improve readability of error messages by giving a NameError - # if a test function accesses a pytest fixture but doesn't include - # it as an argument. Should be called from - # pytest_collection_modifyitems(). - - modules = set(item.module for item in items) - - for module in modules: - for name in dir(module): - obj = getattr(module, name) - if hasattr(obj, "_pytestfixturefunction") and isinstance( - obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker - ): - delattr(module, name) - - def identity_after(x, sleep): """Testing function to return identity after sleep diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 8885f55bbf4b..4c9c5d91901a 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -199,7 +199,7 @@ def test_num_uses_cached(self): class TestAutomaticMarks: @staticmethod def check_marks(request, target): - parameter = tvm.testing._pytest_target_params([target])[0] + parameter = tvm.testing.plugin._pytest_target_params([target])[0] required_marks = [decorator.mark for decorator in parameter.marks] applied_marks = list(request.node.iter_markers()) @@ -239,6 +239,11 @@ def uncacheable_fixture(self): return self.EmptyClass() def test_uses_uncacheable(self, request): + # Normally the num_tests_use_this_fixture would be set before + # anything runs. For this test case only, because we are + # delaying the use of the fixture, we need to manually + # increment it. + self.uncacheable_fixture.num_tests_use_this_fixture[0] += 1 with pytest.raises(TypeError): request.getfixturevalue("uncacheable_fixture") From f188a4fb11971c9bfce9b059fd2b9dacdbe1a0d1 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Fri, 27 Aug 2021 23:29:42 +0100 Subject: [PATCH 24/42] Remove AOT Executor header from Arduino project (#8857) --- apps/microtvm/arduino/example_project/src/model.c | 1 - apps/microtvm/arduino/host_driven/src/model_support.c | 1 - 2 files changed, 2 deletions(-) diff --git a/apps/microtvm/arduino/example_project/src/model.c b/apps/microtvm/arduino/example_project/src/model.c index 77566ffc6a8f..9e7c47f75160 100644 --- a/apps/microtvm/arduino/example_project/src/model.c +++ b/apps/microtvm/arduino/example_project/src/model.c @@ -20,7 +20,6 @@ #include "model.h" #include "Arduino.h" -#include "standalone_crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h" #include "standalone_crt/include/tvm/runtime/crt/stack_allocator.h" // AOT memory array diff --git a/apps/microtvm/arduino/host_driven/src/model_support.c b/apps/microtvm/arduino/host_driven/src/model_support.c index ae467441fede..dfcb031136c5 100644 --- a/apps/microtvm/arduino/host_driven/src/model_support.c +++ b/apps/microtvm/arduino/host_driven/src/model_support.c @@ -17,7 +17,6 @@ * under the License. */ -#include "standalone_crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h" #include "stdarg.h" // Blink code for debugging purposes From 1df6c273f0fb1242d0b399614616635cef38bc15 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Aug 2021 19:33:58 -0700 Subject: [PATCH 25/42] [Community] @mdw-octoml -> Reviewer (#8868) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 7b2d5dc29dad..8398bdd5e0a2 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -139,6 +139,7 @@ We do encourage everyone to work anything they are interested in. - [Leyuan Wang](https://github.com/Laurawly): @Laurawly - [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Logan Weber](https://github.com/weberlo): @weberlo +- [Matt Welsh](https://github.com/mdw-octoml): @mdw-octoml - [Jian Weng](https://github.com/were): @were - [Yong Wu](https://github.com/yongwww): @yongwww - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene From 7214f5239dbb8da4585d4d10fbc8c65c8f155b12 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 28 Aug 2021 17:23:43 +0800 Subject: [PATCH 26/42] [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (#8855) * init * fix * Update src/tir/transforms/plan_update_buffer_allocation_location.cc Co-authored-by: Ruihang Lai * Update src/tir/transforms/plan_update_buffer_allocation_location.cc Co-authored-by: Ruihang Lai * address Co-authored-by: Junru Shao Co-authored-by: Ruihang Lai --- .../analysis/block_access_region_detector.cc | 7 ++- .../plan_update_buffer_allocation_location.cc | 39 ++++++++---- ...st_tir_analysis_get_block_access_region.py | 21 +++++-- ..._plan_update_buffer_allocation_location.py | 62 +++++++++++++++++++ 4 files changed, 109 insertions(+), 20 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 8f87ef920784..dd01aed61c52 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) { ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { const Var& target_var = match_buffer->buffer->data; - match_buffers_[target_var.get()] = match_buffer; - buffer_var_map_.Set(target_var, match_buffer->buffer); + const Var& source_var = match_buffer->source->buffer->data; + if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) { + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } } StmtExprVisitor::operator()(stmt); } diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index bee11ad72280..59f9170786b6 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - bool is_root = is_root_; - is_root_ = false; Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { @@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator { buffer_data_to_buffer_.Set(buf->data, buf); } } + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + const Var& source_var = match_buffer->source->buffer->data; + ICHECK(buffer_data_to_buffer_.count(source_var)); + buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); + } Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - // Ignore buffer allocated inside the block when getting access region. + // No longer consider buffers created by match_buffer inside the block when updating access + // region. + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + buffer_data_to_buffer_.erase(target_var); + } + // No longer consider buffers allocated inside the block when updating access region. if (it != alloc_buffers_.end()) { for (const Buffer& buf : it->second) { buffer_data_to_buffer_.erase(buf->data); @@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); - // The read/write regions of root block are always empty. - if (!is_root) { - // Recalculate block access region - CollectReadWrite(GetRef(op), &n->reads, &n->writes); - } - + // Erase buffer allocated inside the block from access region. + n->reads = RemoveRedundantBufferRegion(n->reads); + n->writes = RemoveRedundantBufferRegion(n->writes); return Stmt(n); } @@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator { return std::move(realize); } + Array RemoveRedundantBufferRegion(const Array& region) const { + Array result; + for (const BufferRegion& buffer_region : region) { + if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { + result.push_back(buffer_region); + } + } + return result; + } + void CollectReadWrite(const Block& block, Array* reads, - Array* writes) { + Array* writes) const { Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); *reads = access[0]; *writes = access[1]; @@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; - /*! \brief indicate the whether the block is root. */ - bool is_root_{true}; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 7641f0ac46cb..9c95b9819e6f 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -114,20 +114,29 @@ def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block block_inner = block.body[0].body.body.block - alloc_buffers = func.body.block.alloc_buffers + alloc_buffers = match_buffer_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - # Check inner block AAA - ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) - tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) - tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) - # Check block ret = tir.analysis.get_block_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.writes, ret[1]) # B is opaque access tvm.ir.assert_structural_equal(block.reads, ret[2]) + # Check inner block AAA without updating buffer_var_map + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + # Since AA is not in the buffer_var_map, region of AA will not be collected. + tvm.ir.assert_structural_equal([], ret[1]) + + # Check inner block AAA + for match_buffer in block.match_buffers: + target_buffer = match_buffer.buffer + buffer_var_map[target_buffer.data] = target_buffer + + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + if __name__ == "__main__": test_block_access_region_detector() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 8418e192d060..07140ab458e6 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None: C1[()] = 0 +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + A_cache = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[(v * 128) : ((v * 128) + 128)]]) + tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) + tir.evaluate( + tir.call_extern( + "test", + A_cache.data, + (v * 128), + 128, + A.data, + (v * 128), + 128, + dtype="float32", + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + tir.reads(A[vi * 128 : vi * 128 + 128]) + tir.writes(B[vi * 128 : vi * 128 + 128]) + A_cache = tir.alloc_buffer([1024]) + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([A_cache[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern( + "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -149,6 +206,10 @@ def test_match_buffer_allocation(): _check(match_buffer_func, transformed_match_buffer_func) +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + def test_lower_te(): x = te.placeholder((1,)) y = te.compute((1,), lambda i: x[i] + 2) @@ -164,4 +225,5 @@ def test_lower_te(): test_elementwise() test_locate_buffer_allocation() test_match_buffer_allocation() + test_opaque_access() test_lower_te() From 5ab527a71f7eb1d352db1408b225c79a21945c94 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Sat, 28 Aug 2021 02:24:16 -0700 Subject: [PATCH 27/42] [Autoscheduler] Configurable workload keys (#8862) * change workload keys * remove binary string comparison * append the tuple not every integer * clean up * lint * dump workload keys to dags * fix things * change some strings * misc fixes, add tests * jostle ci --- python/tvm/auto_scheduler/compute_dag.py | 15 +++++-- .../tvm/auto_scheduler/relay_integration.py | 11 ++++- .../test_auto_scheduler_task_extraction.py | 44 ++++++++++++++++++- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index f7a5f39c829a..c212d143f987 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -222,7 +222,7 @@ def rewrite_layout_from_state(self, state): def workload_key(self): """Return the workload key of this compute DAG. - The workload key is a JSON string from a tuple of (hash-key, tensor shapes...) + The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...) Returns ------- @@ -230,12 +230,19 @@ def workload_key(self): The workload key of this compute DAG """ str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) - str_dag = str_dag.encode(encoding="utf-8") - hash_key = hashlib.md5(str_dag).hexdigest() + hash_func = tvm._ffi.get_global_func( + "auto_scheduler.compute_dag.hash_func", allow_missing=True + ) + + if hash_func is None: + str_dag = str_dag.encode("utf-8") + hash_key = hashlib.md5(str_dag).hexdigest() + else: + hash_key = hash_func(str_dag) io_shapes = [] for tensor in self.tensors: - io_shapes += get_const_tuple(tensor.shape) + io_shapes.append(get_const_tuple(tensor.shape)) return json.dumps([hash_key] + io_shapes) def __str__(self): diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 850e50004337..8b68f4e9002a 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -22,6 +22,7 @@ 2. Provide auto-scheduling for all TOPI compute functions """ +import json import logging import threading from copy import deepcopy @@ -30,11 +31,10 @@ from tvm import autotvm, transform from tvm.ir.transform import PassContext from tvm.runtime import convert_to_object - +from tvm.target import Target from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import Reduce from tvm.tir import expr as _expr -from tvm.target import Target from . import _ffi_api from .compute_dag import ComputeDAG, LayoutRewriteOption @@ -97,6 +97,7 @@ def extract_tasks( target_host=None, hardware_params=None, include_simple_tasks=False, + dump_workload_to_dag_log=None, opt_level=3, ): """Extract tuning tasks from a relay program. @@ -115,6 +116,8 @@ def extract_tasks( Hardware parameters used for the search tasks include_simple_tasks: bool Whether to extract simple tasks that do not include complicated ops. + dump_workload_to_dag_log: Optional[str] + A file to dump an association between the workload keys and the actual DAG opt_level : Optional[int] The optimization level of the task extractions. @@ -170,6 +173,10 @@ def extract_tasks( ) weights.append(weight) + if dump_workload_to_dag_log is not None: + with open(dump_workload_to_dag_log, "w") as f: + json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f) + return tasks, weights diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index 39596186d211..a53b68cca885 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Test task extraction for auto-scheduler""" -import pytest +import json +import tempfile +import pytest import tvm.relay.testing import tvm.testing +from tvm import _ffi as _ffi_api from tvm import auto_scheduler, relay @@ -248,5 +251,44 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False) verify_task_extraction(*params) +def test_dump_workload_to_dag_extract_tasks(): + mod, _ = get_network("mobilenet", layout="NHWC") + with tempfile.NamedTemporaryFile() as f: + tasks, _ = auto_scheduler.extract_tasks( + mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=f.name + ) + expected = {task.workload_key: str(task.compute_dag) for task in tasks} + actual = json.load(f) + assert expected == actual + + +def test_custom_hash_func_extract_tasks(): + @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func") + def counting_unique_hash(str_dag): + ret = counting_unique_hash.i + counting_unique_hash.i += 1 + return ret + + counting_unique_hash.i = 0 + + mod, _ = get_network("mobilenet", layout="NHWC") + tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True) + + hash_values = [] + for task in tasks: + # task.workload_key should look like + # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the result of the hash + # Extract the hash and keep track of every hash + hash_value = int(task.workload_key[1:].split(",")[0]) + hash_values.append(hash_value) + + # All values are unique, and we know the min and max + # This is a sufficient condition to know that hashes in hash_values are an increasing list + # of hashes up to counting_unique_hash.i - 1 + assert len(hash_values) == len(set(hash_values)) + assert min(hash_values) == 0 + assert max(hash_values) == counting_unique_hash.i - 1 + + if __name__ == "__main__": pytest.main([__file__]) From 0961b65cbf0d6e1c5f51e0e88dd17886d6111522 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Sat, 28 Aug 2021 04:28:07 -0500 Subject: [PATCH 28/42] [Tutorial][Executor] Fix the usage of executors in tutorials (#8586) * fix: executor usage for keras tutorial * fix: executor usage for onnx tutorial * [Tutorial][Executor] Fix executors in tutorials --- tutorials/dev/bring_your_own_datatypes.py | 3 ++- tutorials/frontend/from_keras.py | 4 ++-- tutorials/frontend/from_onnx.py | 6 ++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py index a5e8e2898d39..1cf556ddd056 100644 --- a/tutorials/dev/bring_your_own_datatypes.py +++ b/tutorials/dev/bring_your_own_datatypes.py @@ -257,8 +257,9 @@ def get_cat_image(): ###################################################################### # It's easy to execute MobileNet with native TVM: +ex = tvm.relay.create_executor("graph", mod=module, params=params) input = get_cat_image() -result = tvm.relay.create_executor("graph", mod=module).evaluate()(input, **params).numpy() +result = ex.evaluate()(input).numpy() # print first 10 elements print(result.flatten()[:10]) diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py index e62836d2ccfe..182e769b35b1 100644 --- a/tutorials/frontend/from_keras.py +++ b/tutorials/frontend/from_keras.py @@ -103,14 +103,14 @@ # due to a latent bug. Note that the pass context only has an effect within # evaluate() and is not captured by create_executor(). with tvm.transform.PassContext(opt_level=0): - model = relay.build_module.create_executor("graph", mod, dev, target).evaluate() + model = relay.build_module.create_executor("graph", mod, dev, target, params).evaluate() ###################################################################### # Execute on TVM # --------------- dtype = "float32" -tvm_out = model(tvm.nd.array(data.astype(dtype)), **params) +tvm_out = model(tvm.nd.array(data.astype(dtype))) top1_tvm = np.argmax(tvm_out.numpy()[0]) ##################################################################### diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index 890bfbac4d8a..fd51d7a76992 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -92,13 +92,15 @@ mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) with tvm.transform.PassContext(opt_level=1): - compiled = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target).evaluate() + executor = relay.build_module.create_executor( + "graph", mod, tvm.cpu(0), target, params + ).evaluate() ###################################################################### # Execute on TVM # --------------------------------------------- dtype = "float32" -tvm_output = compiled(tvm.nd.array(x.astype(dtype)), **params).numpy() +tvm_output = executor(tvm.nd.array(x.astype(dtype))).numpy() ###################################################################### # Display results From 2545e9caecadd66c72fbb6734c30d100e823b0fb Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sat, 28 Aug 2021 12:59:20 -0700 Subject: [PATCH 29/42] [Frontend][Onnx] Simplify onnx input since name accesses are not reliable. (#8867) * Simplify onnx input since name accesses are no longer supported. * move Celu importer. --- python/tvm/relay/frontend/onnx.py | 82 +++++++++---------------------- 1 file changed, 22 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5471f67ea106..9144d3e145c8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -63,54 +63,16 @@ } -class onnx_input: - """Dual purpose list or dictionary access object.""" - - def __init__(self): - self.input_keys = [] - self.input_dict = {} +class onnx_input(list): + """A helper extension to list that returns None for out of bound indices.""" def __getitem__(self, item): - if isinstance(item, int): - if item > (len(self.input_keys) - 1): - return None - return self.input_dict[self.input_keys[item]] - if isinstance(item, str): - if item not in self.input_keys: - return None - return self.input_dict[item] if isinstance(item, slice): - keys = self.input_keys[item] - return [self.input_dict[key] for key in keys] - - raise ValueError("Only integer, string, and slice accesses allowed.") - - def __setitem__(self, item, value): + indices = list(range(item.stop)[item]) + return [self[i] for i in indices] if isinstance(item, int): - self.input_dict[self.input_keys[item]] = value - elif isinstance(item, str): - self.input_keys.append(item) - self.input_dict[item] = value - else: - raise ValueError("Only integer and string indexed writes allowed.") - - def keys(self): - return self.input_keys - - def __len__(self): - return len(self.input_keys) - - def __iter__(self): - self.n = 0 - return self - - def __next__(self): - if self.n < len(self.input_keys): - output = self.input_dict[self.input_keys[self.n]] - self.n += 1 - return output - - raise StopIteration + return list(self)[item] if item < len(self) else None + raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__) def get_numpy(tensor_proto): @@ -2672,6 +2634,19 @@ def _impl_v10(cls, inputs, attr, params): return isinf +class Celu(OnnxOpConverter): + """Operator convereter for celu""" + + @classmethod + def _impl_v12(cls, inputs, attr, params): + x = inputs[0] + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(attr.get("alpha", 1.0), dtype) + zero = _op.const(0, dtype) + one = _op.const(1, dtype) + return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + + class MaxRoiPool(OnnxOpConverter): """Operator converter for MaxRoiPool.""" @@ -3822,13 +3797,13 @@ def from_onnx(self, graph, opset, get_output_expr=False): for node in graph.node: op_name = node.op_type attr = self._parse_attr(node.attribute) - # Create and populate onnx input object. + # Create and populate input list. inputs = onnx_input() for i in node.input: if i != "": - inputs[i] = self._nodes[self._renames.get(i, i)] + inputs.append(self._nodes[self._renames.get(i, i)]) else: - inputs[i] = None + inputs.append(None) i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} @@ -3981,19 +3956,6 @@ def _fix_outputs(self, op_name, outputs): return outputs -class Celu(OnnxOpConverter): - """Operator convereter for celu""" - - @classmethod - def _impl_v12(cls, inputs, attr, params): - x = inputs[0] - dtype = infer_type(x).checked_type.dtype - alpha = _op.const(attr.get("alpha", 1.0), dtype) - zero = _op.const(0, dtype) - one = _op.const(1, dtype) - return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) - - def from_onnx( model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None ): From 27d3d605f340889732501ec0bcb0a34c2a49c11b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 30 Aug 2021 00:30:16 +0800 Subject: [PATCH 30/42] [TIR] GetBlockReadWriteRegion (#8875) * [TIR] GetBlockReadWriteRegion * Fix black issue * Use constant reference for the interface * Fix lint issue --- include/tvm/tir/analysis.h | 19 ++++++++--- python/tvm/tir/analysis/analysis.py | 24 ++++++++++++- .../analysis/block_access_region_detector.cc | 34 ++++++++++++++++++- src/tir/schedule/primitive/compute_inline.cc | 2 +- .../plan_update_buffer_allocation_location.cc | 16 +++------ ...st_tir_analysis_get_block_access_region.py | 29 ++++++++++++++++ .../unittest/test_tir_schedule_reduction.py | 1 - 7 files changed, 105 insertions(+), 20 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index dce9736adec7..51bdb18d2217 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); /*! - * \brief Auto detect the block read/write region according to body stmt - * It will detect the read/write region as an array in order of appearance in AST + * \brief Auto detect the block access region according to its body stmt + * It will detect the access region as an array in order of appearance in AST * \param block The block to be detected * \param buffer_var_map The outside buffers which may be accessed the block. * It is a map from buffer var to the buffer. @@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain * - second: write regions * - third: opaque regions */ -Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL Array> GetBlockAccessRegion(const Block& block, + const Map& buffer_var_map); + +/*! + * \brief Auto detect the block read/write region according to its body stmt. An opaque access will + * be counted as both a read and a write access + * \param block The block to be detected + * \param buffer_var_map The outside buffers which may be accessed the block. + * It is a map from buffer var to the buffer + * \return An array only consisting of the read regions and write regions of the input block + */ +TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map); /*! * \brief Calculate the expresion complexity based on number of symbols it contains. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 500195ac9a13..d1aaa61c3aae 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -136,7 +136,29 @@ def get_block_access_region( - second: write regions - third: opaque regions """ - return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore + return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore + + +def get_block_read_write_region( + block: Block, buffer_var_map: Dict[Var, Buffer] +) -> List[List[BufferRegion]]: + """Auto detect the block read/write region according to its body stmt. + An opaque access will be counted as both a read and a write access + + Parameters + ---------- + block: tvm.tir.Block + The block in which we are detecting read/write regions. + + buffer_var_map : Dict[Var, Buffer] + The outside buffers which may access the block. Mapping from buffer var to the buffer + + Returns + ------- + result : List[List[BufferRegion]] + An array only consisting of the read regions and write regions of the input block + """ + return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map) # type: ignore def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int: diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index dd01aed61c52..90aaa35d60d8 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -285,7 +285,39 @@ Array> GetBlockAccessRegion(const Block& block, return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; } -TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion); +Array> GetBlockReadWriteRegion(const Block& block, + const Map& buffer_var_map) { + // Step 1. Get all the read/write/opaque accesses in the input block. + Array> access_regions = GetBlockAccessRegion(block, buffer_var_map); + // Step 2. Collect all the buffers that are opaquely accessed. + std::unordered_set opaque_accessed_buffers; + for (const BufferRegion& opaque_access : access_regions[2]) { + opaque_accessed_buffers.insert(opaque_access->buffer.get()); + } + // Step 3. Create new arrays of read/write regions. + Array new_read_regions; + Array new_write_regions; + new_read_regions.reserve(access_regions[0].size() + access_regions[2].size()); + new_write_regions.reserve(access_regions[1].size() + access_regions[2].size()); + for (const BufferRegion& read_access : access_regions[0]) { + if (!opaque_accessed_buffers.count(read_access->buffer.get())) { + new_read_regions.push_back(read_access); + } + } + for (const BufferRegion& write_access : access_regions[1]) { + if (!opaque_accessed_buffers.count(write_access->buffer.get())) { + new_write_regions.push_back(write_access); + } + } + for (const BufferRegion& opaque_access : access_regions[2]) { + new_read_regions.push_back(opaque_access); + new_write_regions.push_back(opaque_access); + } + return {new_read_regions, new_write_regions}; +} + +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 2583b21227e4..9c88cc1e787a 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator { Array reads = std::move(block->reads); Array writes = std::move(block->writes); if (!is_scope_root) { - Array> inspected = GetBlockAccessRegion(block, buffer_var_map_); + Array> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); reads = std::move(inspected[0]); writes = std::move(inspected[1]); } diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 59f9170786b6..97153aedc6a3 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -129,7 +129,10 @@ class BufferAllocationLocator : public StmtExprMutator { /*init=*/NullOpt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - CollectReadWrite(opaque_block, &n->reads, &n->writes); + Array> access = + GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); + n->reads = access[0]; + n->writes = access[1]; BlockRealize realize({}, Bool(true), Block(n)); return std::move(realize); } @@ -144,17 +147,6 @@ class BufferAllocationLocator : public StmtExprMutator { return result; } - void CollectReadWrite(const Block& block, Array* reads, - Array* writes) const { - Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - *reads = access[0]; - *writes = access[1]; - for (const auto& opaque_access : access[2]) { - reads->push_back(opaque_access); - writes->push_back(opaque_access); - } - } - /*! \brief The map from stmt to the buffers to be allocated under it. */ std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 9c95b9819e6f..bc421aa4d19b 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import tir, script from tvm.ir import Range @@ -81,6 +82,20 @@ def opaque_block_func() -> None: B[i, j] = A[i, j] + 1.0 +@tvm.script.tir +def opaque_access_func() -> None: + A = tir.alloc_buffer([1024]) + B = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [v]: + tir.bind(v, i) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([B[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") + ) + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -110,6 +125,19 @@ def test_opaque_block(): tvm.ir.assert_structural_equal(block1.writes, ret[1]) +def test_opaque_access(): + block = opaque_access_func.body.block.body.body.block + alloc_buffers = opaque_access_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block @@ -141,4 +169,5 @@ def test_match_buffer(): if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() + test_opaque_access() test_match_buffer() diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 067952899c0a..bc054938d282 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -17,7 +17,6 @@ # pylint: disable=missing-function-docstring,missing-module-docstring import sys -import numpy as np import pytest import tvm import tvm.testing From 06fc788fc15ffbeedc25f73c47595878bc9f8263 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Sun, 29 Aug 2021 18:02:54 -0700 Subject: [PATCH 31/42] [RISCV] Add support for llvm parameter -mabi (-target-abi) (#8860) --- python/tvm/target/__init__.py | 17 +++++++++++- python/tvm/target/target.py | 50 ++++++++++++++++++++++++++++++++++ src/target/llvm/llvm_common.cc | 6 ++++ src/target/llvm/llvm_module.cc | 14 ++++++++-- src/target/target_kind.cc | 1 + 5 files changed, 85 insertions(+), 3 deletions(-) diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 92d72b25b44d..1e906cb381d8 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -43,6 +43,10 @@ such as whether SIMD operations are enabled or not. The default set of attributes is set by the current CPU. +- **-mabi=** + + Generate code for the specified ABI, for example "lp64d". + - **-system-lib** Build TVM system library module. System lib is a global module that contains @@ -55,7 +59,18 @@ We can also use other specific function in this module to create specific targets. """ from .target import Target, create -from .target import cuda, rocm, mali, intel_graphics, arm_cpu, rasp, vta, bifrost, hexagon +from .target import ( + cuda, + rocm, + mali, + intel_graphics, + arm_cpu, + rasp, + vta, + bifrost, + riscv_cpu, + hexagon, +) from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index aa9226101b52..d4b538a4bef0 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -87,6 +87,8 @@ def __init__(self, target, host=None): mfloat-abi : str (optional) An llvm setting that is one of 'hard' or 'soft' indicating whether to use hardware or software floating-point operations. + mabi : str (optional) + An llvm setting. Generate code for the specified ABI, for example "lp64d". host : Union[str, Dict[str, Any]] (optional) Description for target host. Can be recursive. Similar to target. host : Optional[Union[str, Dict[str, Any]]] @@ -413,6 +415,54 @@ def bifrost(model="unknown", options=None): return Target(" ".join(["opencl"] + opts)) +def riscv_cpu(model="sifive-u54", options=None): + """Returns a RISC-V CPU target. + Default: sifive-u54 rv64gc + + Parameters + ---------- + model: str + CPU name. + options : str or list of str + Additional options + """ + trans_table = { + "sifive-e31": [ + "-model=sifive-e31", + "-mtriple=riscv32-unknown-linux-gnu", + "-mcpu=sifive-e31", + "-mabi=ilp32", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv32imac -mabi=ilp32 -mcpu=sifive-e31 + ], + "sifive-e76": [ + "-model=sifive-e76", + "-mtriple=riscv32-unknown-linux-gnu", + "-mcpu=sifive-e76", + "-mabi=ilp32", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv32imafc -mabi=ilp32 -mcpu=sifive-e76 + ], + "sifive-u54": [ + "-model=sifive-u54", + "-mtriple=riscv64-unknown-linux-gnu", + "-mcpu=sifive-u54", + "-mabi=lp64d", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u54 + ], + "sifive-u74": [ + "-model=sifive-u74", + "-mtriple=riscv64-unknown-linux-gnu", + "-mcpu=sifive-u74", + "-mabi=lp64d", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74 + ], + } + pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) + + opts = ["-device=arm_cpu"] + pre_defined_opt + opts = _merge_opts(opts, options) + return Target(" ".join(["llvm"] + opts)) + + def hexagon(cpu_ver="v66", **kwargs): """Returns a Hexagon target. diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 61dd7024ff05..be80a8bc767e 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -115,6 +115,9 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri } else { opt.FloatABIType = llvm::FloatABI::Hard; } + if (const Optional& v = target->GetAttr("mabi")) { + opt.MCOptions.ABIName = v.value(); + } } std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null) { @@ -164,6 +167,9 @@ std::string LLVMTargetToString(const Target& target) { if (Optional mfloat_abo = target->GetAttr("mfloat-abi")) { os << " -mfloat-abi=" << mfloat_abo.value(); } + if (Optional mabi = target->GetAttr("mabi")) { + os << " -mabi=" << mabi.value(); + } return os.str(); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 12c7a3132947..8bdf6d1b0422 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -69,12 +69,22 @@ class LLVMModuleNode final : public runtime::ModuleNode { } else if (name == "get_const_vars") { return PackedFunc(nullptr); } else if (name == "_get_target_triple") { - std::string target_triple = tm_->getTargetTriple().str(); + std::ostringstream target_triple_ss; + target_triple_ss << tm_->getTargetTriple().str(); // getTargetTriple() doesn't include other flags besides the triple. Add back flags which are // important for ModulePackImportsToLLVM. if (tm_->Options.FloatABIType == llvm::FloatABI::ABIType::Soft) { - target_triple += " -mfloat-abi=soft"; + target_triple_ss << " -mfloat-abi=soft"; } + std::string mabi = tm_->Options.MCOptions.ABIName; + if (!mabi.empty()) { + target_triple_ss << " -mabi=" << mabi; + } + llvm::StringRef mcpu = tm_->getTargetCPU(); + if (!mcpu.empty() && mcpu != "generic") { + target_triple_ss << " -mcpu=" << mcpu.str(); + } + std::string target_triple = target_triple_ss.str(); return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; }); } if (ee_ == nullptr) LazyInitJIT(); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 97317b5c4800..d536b2e7b4b4 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -216,6 +216,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option("mfloat-abi") + .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") .add_attr_option("link-params", Bool(false)) From 421dbf14f44c390bda56ba82e0d992b9ece14bf4 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Sun, 29 Aug 2021 22:32:54 -0700 Subject: [PATCH 32/42] [Community] @manupa-arm -> Committer (#8870) * adding Manupa to the contributors list * re-trigger CI --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8398bdd5e0a2..e7f70379e49d 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ We do encourage everyone to work anything they are interested in. - [Animesh Jain](https://github.com/anijain2305): @anijain2305 - quantization, relay - [Chenfan Jia](https://github.com/jcf94): @jcf94 - auto_scheduler - [Ziheng Jiang](https://github.com/ZihengJiang) (PMC): @ZihengJiang - relay, compiler +- [Manupa Karunaratne](https://github.com/manupa-arm): @manupa-arm - ethos-u, memory planner - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame - relay - [Wuwei Lin](https://github.com/vinx13): @vinx13 - relay, topi - [Yizhi Liu](https://github.com/yzhliu) (PMC): @yzhliu - jvm, topi, relay From b774d7f1ecde6260696f667567f5a1177772483e Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 31 Aug 2021 03:23:41 +0300 Subject: [PATCH 33/42] [RPC] Fix ios_rpc build (#8864) --- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index e8788cb78e88..ccd372b1adf4 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -34,6 +34,7 @@ #include "../../../src/runtime/object.cc" #include "../../../src/runtime/profiling.cc" #include "../../../src/runtime/registry.cc" +#include "../../../src/runtime/source_utils.cc" #include "../../../src/runtime/system_library.cc" #include "../../../src/runtime/thread_pool.cc" #include "../../../src/runtime/threading_backend.cc" From 9a9cd7022e93fbbbff146d18f23d61189c70be1d Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Mon, 30 Aug 2021 20:21:50 -0500 Subject: [PATCH 34/42] [Vulkan][Target] Added the driver name to the vulkan target string. (#8882) Driver name (e.g. "NVIDIA", "radv", "AMD open-source driver") is read from the `driverName` property in [VkPhysicalDeviceDriverProperties](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceDriverProperties.html), or is left as `"unknown_driver_name"` if the driver does not support querying the driver name. --- src/runtime/vulkan/vulkan_device.cc | 4 ++++ src/runtime/vulkan/vulkan_device.h | 1 + src/runtime/vulkan/vulkan_device_api.cc | 3 +++ src/target/target_kind.cc | 1 + 4 files changed, 9 insertions(+) diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 156f86dbb03e..f5d26ace50d9 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -156,6 +156,10 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, device_name = properties.properties.deviceName; driver_version = properties.properties.driverVersion; + if (device.HasExtension("VK_KHR_driver_properties")) { + driver_name = driver.driverName; + } + switch (properties.properties.deviceType) { case VK_PHYSICAL_DEVICE_TYPE_OTHER: device_type = "other"; diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 412542029209..d7788ef5df29 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -94,6 +94,7 @@ struct VulkanDeviceProperties { uint32_t max_shared_memory_per_block{16384}; std::string device_type{"unknown_device_type"}; std::string device_name{"unknown_device_name"}; + std::string driver_name{"unknown_driver_name"}; uint32_t driver_version{0}; uint32_t vulkan_api_version{VK_API_VERSION_1_0}; uint32_t max_spirv_version{0x10000}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 3d27e1651852..cf0b16c6c471 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -242,6 +242,9 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, if (property == "device_type") { *rv = prop.device_type; } + if (property == "driver_name") { + *rv = prop.driver_name; + } if (property == "driver_version") { *rv = int64_t(prop.driver_version); } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d536b2e7b4b4..ab8e6eaad157 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -312,6 +312,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") + .add_attr_option("driver_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") From 6df070aac6d0e26d1e127095a323c61c2287eb9d Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 31 Aug 2021 02:45:47 -0700 Subject: [PATCH 35/42] [ONNX][TOPI] Support select_last_index for argmin/max (#8816) * support select_last_index for argmin/max * reverse conditions which made on accident * forward args in reduce.py * make proper nodes for reduction ops * remove complicated nested lambdas * fix lambda capture for conversion * forward more arguments * forward more args * enable onnx tests * wrapping casts to remove ambiguity * revert changes extraneous * correct incorrect attrs being used for ops * change attributes * remove old impl * register new attribute node * clean up test * reformat * reformat * coolio * stable comparison * casts to avoid ambiguity * casting more * correct arg passing * support select_last_index for argmin/max * reverse conditions which made on accident * forward args in reduce.py * make proper nodes for reduction ops * remove complicated nested lambdas * fix lambda capture for conversion * forward more arguments * forward more args * enable onnx tests * wrapping casts to remove ambiguity * revert changes extraneous * correct incorrect attrs being used for ops * change attributes * remove old impl * register new attribute node * clean up test * reformat * reformat * coolio * stable comparison * casts to avoid ambiguity * casting more * correct arg passing * fix broken input * OneElementReduceAttrs-->ArgReduceAttrs" * reduce boilerplate * change names * remove log statement * jostle ci Co-authored-by: Andrew Zhao Luo --- include/tvm/relay/attrs/reduce.h | 36 ++++++++ include/tvm/topi/reduction.h | 99 +++++++++++++++++----- python/tvm/relay/frontend/onnx.py | 20 ++--- python/tvm/relay/op/reduce.py | 20 +++-- python/tvm/topi/reduction.py | 16 +++- src/relay/op/tensor/reduce.cc | 80 +++++++++++++---- src/topi/reduction.cc | 4 +- tests/python/frontend/onnx/test_forward.py | 19 +---- tests/python/relay/test_op_level4.py | 38 +++++++-- 9 files changed, 245 insertions(+), 87 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 14b75ff1c0a8..d91b3594b5a3 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ +struct ArgReduceAttrs : public tvm::AttrsNode { + Array axis; + bool keepdims; + bool select_last_index; + bool exclude; + + TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(select_last_index) + .set_default(false) + .describe( + "Whether to select the last index if the target element appears multiple times, else " + "select the first index which the target element appears"); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); + } +}; + struct VarianceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 15d1455bb267..d4e420d80b02 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } +inline FCommReduce MakeArgminReducer(bool select_last_index = false) { + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [=](Array lhs, Array rhs) { + Array result; + + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + + // These variables compare the actual values of the array + auto is_smaller = lhs_val < rhs_val; + auto is_same = lhs_val == rhs_val; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs_idx > rhs_idx; + } else { + proper_index = lhs_idx < rhs_idx; + } + + PrimExpr update_index = is_smaller || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val + return result; + }; + auto fidentity = [&](std::vector types) { + Array result; + result.push_back(tvm::tir::make_const(types[0], -1)); // idx + result.push_back(tvm::max_value(types[1])); // val + return result; + }; + return MakeCommReducer(fcombine, fidentity, "argmin"); +} + /*! * \brief Creates an operation that finds the indices of the minimum * values over a given axis. @@ -442,35 +481,48 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. + * \param select_last_index Whether to select the last index if the minimum element + * appears multiple times, else select the first index. * * \return A Tensor whose op member is the argmin operation */ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto fcombine = [](Array lhs, Array rhs) { - Array result; - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val - return result; - }; - auto fidentity = [](std::vector types) { - Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val - return result; - }; - auto func = MakeCommReducer(fcombine, fidentity, "argmin"); - return CommReduceIdx(data, axis, func, keepdims, atleast1d); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgminReducer(select_last_index); + return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } -inline FCommReduce MakeArgmaxReducer() { - auto fcombine = [](Array lhs, Array rhs) { +inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [=](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val + + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + + // These variables compare the actual values of the array + auto is_bigger = lhs_val > rhs_val; + auto is_same = lhs_val == rhs_val; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs_idx > rhs_idx; + } else { + proper_index = lhs_idx < rhs_idx; + } + + PrimExpr update_index = is_bigger || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val return result; }; - auto fidentity = [](std::vector types) { + auto fidentity = [&](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val @@ -490,12 +542,13 @@ inline FCommReduce MakeArgmaxReducer() { * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. - * + * \param select_last_index Whether to select the last index if the maximum element + * appears multiple times, else select the first index. * \return A Tensor whose op member is the argmax operation */ inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto reducer = MakeArgmaxReducer(); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgmaxReducer(select_last_index); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9144d3e145c8..f9b49204b85e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -32,23 +32,23 @@ from .. import loops as _loops from .. import op as _op from .. import qnn as _qnn +from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .. import random as _random from .common import ( AttrCvt, Renamer, fold_constant, get_name, get_relay_op, + gru_cell, infer_channels, infer_shape, infer_type, infer_value, + lstm_cell, new_var, unbind, - gru_cell, - lstm_cell, ) __all__ = ["from_onnx"] @@ -1786,12 +1786,11 @@ class ArgMax(OnnxOpConverter): """Operator converter for ArgMax.""" @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMax") + def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") @@ -1799,12 +1798,11 @@ class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMin") + def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 368ffb5ab0ca..23accebfd0ec 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -17,13 +17,13 @@ """Reduce operators.""" # pylint: disable=redefined-builtin +from ..expr import Tuple, TupleWrapper from . import _make -from .tensor import sqrt, log, exp +from .tensor import exp, log, sqrt from .transform import squeeze -from ..expr import Tuple, TupleWrapper -def argmax(data, axis=None, keepdims=False, exclude=False): +def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the max element appears in + multiple indices, default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmax(data, axis, keepdims, exclude) + return _make.argmax(data, axis, keepdims, exclude, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False): +def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the min element appears in + multiple indices, default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmin(data, axis, keepdims, exclude) + return _make.argmin(data, axis, keepdims, exclude, select_last_index) def sum(data, axis=None, keepdims=False, exclude=False): diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index 77f9ad447ed1..45d07af577a3 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False): +def argmax(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False): with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the maximum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims) + return cpp.argmax(data, axis, keepdims, select_last_index) -def argmin(data, axis=None, keepdims=False): +def argmin(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False): with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the minimum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims) + return cpp.argmin(data, axis, keepdims, select_last_index) def prod(data, axis=None, keepdims=False): diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index f08af1e7e4ad..693589fecfb4 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -38,6 +38,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); +TVM_REGISTER_NODE_TYPE(ArgReduceAttrs); TVM_REGISTER_NODE_TYPE(VarianceAttrs); /*! @@ -207,9 +208,29 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp return {topi::identity(inputs[0])}; } } + return {f(inputs[0], axes, param->keepdims, false)}; } +template +Array ArgReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { + const ArgReduceAttrs* param = attrs.as(); + ICHECK(param != nullptr); + if (inputs[0]->shape.size() == 0) { + return {topi::identity(inputs[0])}; + } + auto axes = param->axis; + if (param->exclude) { + axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } + } + + return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; +} + /*! * \brief ReduceShapeImpl get the outshape for the reduction operator * \param in_shape Shape of input data. @@ -269,22 +290,16 @@ inline std::vector ReduceShapeImpl(const std::vector& in_s } } -/*! - * \brief ArgReduceRel Output type and shape relation evaluation function. - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. true if this relation has been resolved. - */ -bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +template +bool GenericReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; ICHECK(static_cast(data->shape.size()) != 0); std::vector in_shape(data->shape.begin(), data->shape.end()); - const ReduceAttrs* param = attrs.as(); + const T* param = attrs.as(); ICHECK(param != nullptr); // assign output type and shape @@ -292,6 +307,17 @@ bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TensorType(oshape, DataType::Int(32))); return true; } +/*! + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + return GenericReduceRel(types, num_inputs, attrs, reporter); +} /*! * \brief ReduceRel Output type and shape relation evaluation function. @@ -324,6 +350,16 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); } +Expr MakeOneElementReduce(Expr data, Array axis, bool keepdims, bool exclude, + bool select_last_index, String op_name) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + attrs->select_last_index = select_last_index; + return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); +} + #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ @@ -331,35 +367,43 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str }); \ RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") +#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude, \ + bool select_last_index) { \ + return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index, OpName); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmax); + return ArgReduceCompute(attrs, inputs, out_type, topi::argmax); } -RELAY_REGISTER_REDUCE_OP("argmax") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax") .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) .set_attr("TOpPattern", kCommReduce); Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmin); + return ArgReduceCompute(attrs, inputs, out_type, topi::argmin); } -RELAY_REGISTER_REDUCE_OP("argmin") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin") .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMinCompute) .set_attr("TOpPattern", kCommReduce); diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 55c59162e68c..3d1c6f9f7d5b 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -45,11 +45,11 @@ TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e0eb1f75217..a1d821686ed5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,6 @@ import glob import os import re -import glob import numpy as np import pytest @@ -236,7 +235,7 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] @@ -4680,22 +4679,6 @@ def verify_eyelike(indata): "test_adagrad_multiple", "test_adam", "test_adam_multiple", - "test_argmax_default_axis_example_select_last_index", - "test_argmax_default_axis_random_select_last_index", - "test_argmax_keepdims_example_select_last_index", - "test_argmax_keepdims_random_select_last_index", - "test_argmax_negative_axis_keepdims_example_select_last_index", - "test_argmax_negative_axis_keepdims_random_select_last_index", - "test_argmax_no_keepdims_example_select_last_index", - "test_argmax_no_keepdims_random_select_last_index", - "test_argmin_default_axis_example_select_last_index", - "test_argmin_default_axis_random_select_last_index", - "test_argmin_keepdims_example_select_last_index", - "test_argmin_keepdims_random_select_last_index", - "test_argmin_negative_axis_keepdims_example_select_last_index", - "test_argmin_negative_axis_keepdims_random_select_last_index", - "test_argmin_no_keepdims_example_select_last_index", - "test_argmin_no_keepdims_random_select_last_index", "test_cast_BFLOAT16_to_FLOAT", "test_cast_DOUBLE_to_FLOAT16", "test_cast_FLOAT_to_BFLOAT16", diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index df77c33658de..6415976bfd59 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np -from tvm import relay +import numpy.random +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay, te from tvm.relay import transform from tvm.relay.testing import run_infer_type -import tvm.topi.testing -import tvm.testing @tvm.testing.uses_gpu @@ -342,6 +342,34 @@ def _unbiased_func(a, axis=None, dtype=None, keepdims=None): verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) +@tvm.testing.uses_gpu +def test_argmin_argmax_get_last_elements(): + def get_test_case(shape, gt_func, test_argmin=False): + total_ele = np.product(shape) + arr = np.zeros(total_ele) + target_value = -1 if test_argmin else 1 + arr[: total_ele // 3] = target_value + np.random.shuffle(arr) + arr = arr.reshape(shape) + ans = gt_func(np.flip(arr)) + return arr, len(arr) - ans - 1 + + funcs_and_gt_funcs = [(relay.argmax, np.argmax), (relay.argmin, np.argmin)] + lengths = [5, 10, 15] + for func, gt_func in funcs_and_gt_funcs: + for shape in lengths: + x_in = relay.var("x_in", shape=[shape]) + output = func(x_in, select_last_index=True) + arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin) + + mod = tvm.IRModule.from_expr(output) + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor( + "graph", mod=mod, device=dev, target=target + ).evaluate()(arr) + assert op_res.numpy().item() == ans + + def verify_mean_var_std(funcs, shape, axis, keepdims): test_func = funcs[0] ref_func = funcs[1] From 400baf2b32d460497c3c65ebb50666536783d49e Mon Sep 17 00:00:00 2001 From: Adam Straw Date: Tue, 31 Aug 2021 08:57:30 -0700 Subject: [PATCH 36/42] refactor optimize GEMM on CPU tutorial (#8825) * refactor optimize GEMM on CPU tutorial * fix lint errors * fix more lint errors * fix typo * fix problem with redefinition of `k` add TODO and comments around loop unrolling clarify note on the array packing figure * reword general description of array packing * grap kaxis from compute definition * remove duplicate comments on unrolling --- tutorials/optimize/opt_gemm.py | 133 ++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 61 deletions(-) diff --git a/tutorials/optimize/opt_gemm.py b/tutorials/optimize/opt_gemm.py index 7af772784cd6..5d698c612ee8 100644 --- a/tutorials/optimize/opt_gemm.py +++ b/tutorials/optimize/opt_gemm.py @@ -101,7 +101,7 @@ k = te.reduce_axis((0, K), "k") A = te.placeholder((M, K), name="A") B = te.placeholder((K, N), name="B") -C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C") +C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C") # Default schedule s = te.create_schedule(C.op) @@ -130,15 +130,16 @@ # fill 32 * 32 * sizeof(float) which is 4KB in the cache whose total size is 32KB (L1 data cache) bn = 32 +kfactor = 4 s = te.create_schedule(C.op) # Blocking by loop tiling -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -(k,) = s[C].op.reduce_axis -ko, ki = s[C].split(k, factor=4) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +(kaxis,) = s[C].op.reduce_axis +ko, ki = s[C].split(kaxis, factor=kfactor) # Hoist reduction domain outside the blocking loop -s[C].reorder(xo, yo, ko, ki, xi, yi) +s[C].reorder(mo, no, ko, ki, mi, ni) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func @@ -162,19 +163,20 @@ # ------------- # Another important trick is vectorization. When the memory access pattern is uniform, # the compiler can detect this pattern and pass the continuous memory to vector processor. In TVM, -# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it vastly. +# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it +# vastly. # # In this tutorial, we chose to vectorize the inner loop row data since it is cache friendly. s = te.create_schedule(C.op) -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -(k,) = s[C].op.reduce_axis -ko, ki = s[C].split(k, factor=4) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +(kaxis,) = s[C].op.reduce_axis +ko, ki = s[C].split(kaxis, factor=kfactor) -s[C].reorder(xo, yo, ko, ki, xi, yi) +s[C].reorder(mo, no, ko, ki, mi, ni) # Vectorization -s[C].vectorize(yi) +s[C].vectorize(ni) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func @@ -194,20 +196,19 @@ ################################################################################################### # Loop Permutation # ---------------- -# If we look at the above IR, we can see the inner loop row data is vectorized and -# B is transformed into PackedB. The traversal of PackedB is sequential now. -# So we will look at the access pattern of A. In current schedule, A is accessed column by column -# which is not cache friendly. If we change the nested loop order of ki and inner axes xi, +# If we look at the above IR, we can see the inner loop row data is vectorized for both B and C. +# Next we will look at the access pattern of A. In current schedule, A is accessed column by column +# which is not cache friendly. If we change the nested loop order of ki and inner axes mi, # the access pattern for A matrix is more cache friendly. s = te.create_schedule(C.op) -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -(k,) = s[C].op.reduce_axis -ko, ki = s[C].split(k, factor=4) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +(kaxis,) = s[C].op.reduce_axis +ko, ki = s[C].split(kaxis, factor=kfactor) # re-ordering -s[C].reorder(xo, yo, ko, xi, ki, yi) -s[C].vectorize(yi) +s[C].reorder(mo, no, ko, mi, ki, ni) +s[C].vectorize(ni) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func @@ -227,43 +228,48 @@ ################################################################################################### # Array Packing # ------------- -# Another important trick is array packing. This trick is to reorder the storage dimension of the -# array to convert the continuous access pattern on certain dimension to a sequential pattern after -# flattening. +# Another important trick is array packing. The trick is to reorder the storage of a multi- +# dimensional array so that it is accessed sequentially after it is flattened and stored in one- +# dimensional memory. # # .. image:: https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png # :align: center # +# NOTE: This figure is a general illustration of how array packing works. ################################################################################################### -# Just as it is shown in the figure above, after blocking the computations, we can observe the array -# access pattern of B (after flattening), which is regular but discontinuous. We expect that after -# some transformation we can get continuous access pattern. We can reorder a [16][16] array to -# a [16/4][16][4] array, so that the access pattern of B will be sequential when grabing -# the corresponding value from the packed array. -# +# We can use array packing to address the access pattern for B. Observe the array access pattern of +# B after flattening which is not sequential as we iterate over the K dimension. We can reorder B +# with dimensions [K][N] so that it has dimensions [N/bn][K][bn] where bn is the blocking factor and +# also the vector size for B in the inner loop. This reorder splits N into two dimensions --- +# bigN (N/bn) and littleN (bn) --- and the new dimensions [N/bn][K][bn] match the indexing of B +# from outer to inner loops (no, ko, ki, ni) resulting in a sequential access pattern for B after +# flattening. + # We have to re-write the algorithm slightly. -packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB") +packedB = te.compute( + (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB" +) C = te.compute( (M, N), - lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k), + lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k), name="C", ) s = te.create_schedule(C.op) -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -(k,) = s[C].op.reduce_axis -ko, ki = s[C].split(k, factor=4) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +(kaxis,) = s[C].op.reduce_axis +ko, ki = s[C].split(kaxis, factor=kfactor) -s[C].reorder(xo, yo, ko, xi, ki, yi) -s[C].vectorize(yi) +s[C].reorder(mo, no, ko, mi, ki, ni) +s[C].vectorize(ni) -x, y, z = s[packedB].op.axis -s[packedB].vectorize(z) -s[packedB].parallel(x) +bigN, _, littleN = s[packedB].op.axis +s[packedB].vectorize(littleN) +s[packedB].parallel(bigN) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func @@ -293,23 +299,28 @@ # Allocate write cache CC = s.cache_write(C, "global") -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -# Write cache is computed at yo -s[CC].compute_at(s[C], yo) +# Write cache is computed at no +s[CC].compute_at(s[C], no) # New inner axes -xc, yc = s[CC].op.axis +mc, nc = s[CC].op.axis + +(kaxis,) = s[CC].op.reduce_axis +ko, ki = s[CC].split(kaxis, factor=kfactor) +s[CC].reorder(ko, mc, ki, nc) +s[CC].vectorize(nc) -(k,) = s[CC].op.reduce_axis -ko, ki = s[CC].split(k, factor=4) -s[CC].reorder(ko, xc, ki, yc) +# TODO: Add separate optimization step to discuss loop unrolloing +# unrolling is a loop optimization strategy which can reduce branch +# prediction failures and increases the chance of concurrent execution +# unroll kfactor loops s[CC].unroll(ki) -s[CC].vectorize(yc) -x, y, z = s[packedB].op.axis -s[packedB].vectorize(z) -s[packedB].parallel(x) +bigN, _, littleN = s[packedB].op.axis +s[packedB].vectorize(littleN) +s[packedB].parallel(bigN) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func @@ -335,24 +346,24 @@ CC = s.cache_write(C, "global") -xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) +mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn) -s[CC].compute_at(s[C], yo) +s[CC].compute_at(s[C], no) -xc, yc = s[CC].op.axis +mc, nc = s[CC].op.axis -(k,) = s[CC].op.reduce_axis -ko, ki = s[CC].split(k, factor=4) -s[CC].reorder(ko, xc, ki, yc) +(kaxis,) = s[CC].op.reduce_axis +ko, ki = s[CC].split(kaxis, factor=kfactor) +s[CC].reorder(ko, mc, ki, nc) +s[CC].vectorize(nc) s[CC].unroll(ki) -s[CC].vectorize(yc) # parallel -s[C].parallel(xo) +s[C].parallel(mo) -x, y, z = s[packedB].op.axis -s[packedB].vectorize(z) -s[packedB].parallel(x) +bigN, _, littleN = s[packedB].op.axis +s[packedB].vectorize(littleN) +s[packedB].parallel(bigN) func = tvm.build(s, [A, B, C], target=target, name="mmult") assert func From 7b91e62669a526641785fd573118b62b865ba381 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 31 Aug 2021 10:29:06 -0700 Subject: [PATCH 37/42] Change target string to Target object in the TE compiler and interpreter (#8835) * # This is a combination of 2 commits. # This is the 1st commit message: Initial changes # This is the commit message #2: Ftarget string -> Target object works! * Fix remaining target strings * fix bad rebase * Fix typo * 1 more bad rebase fix * Lint * typo * Forgot to commit this * Add TargetStrHash and Map --- include/tvm/target/target.h | 2 + src/relay/backend/aot_executor_codegen.cc | 9 ++-- src/relay/backend/build_module.cc | 17 +++---- src/relay/backend/interpreter.cc | 30 +++++++----- src/relay/backend/te_compiler.cc | 25 +++++----- src/relay/backend/te_compiler.h | 4 +- src/relay/backend/utils.cc | 18 ++++++++ src/relay/backend/utils.h | 56 ++++++++++++++++++++++- 8 files changed, 121 insertions(+), 40 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 9c1fe55749e4..deec662e74ad 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -31,6 +31,7 @@ #include #include +#include #include #include @@ -203,5 +204,6 @@ void CheckAndUpdateHostConsistency(Map* target, Target* host); * \param host The Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Map* target, Target* host); + } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 2fb35f3a2e27..b2e862b22b48 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -665,11 +665,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - auto target_host_str = target_host_->str(); - if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Update(mod_run); + if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_str, mod_run); + ret.lowered_funcs.Set(target_host_, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -774,7 +773,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b2b73e9bad02..69dced36295e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -92,8 +92,8 @@ struct ExecutorCodegen { return CallFunc>("get_external_modules", nullptr); } - Map GetIRModule() { - return CallFunc>("get_irmodule", nullptr); + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); } runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. - if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) { - lowered_funcs.Set("ext_dev", IRModule()); + Target ext_dev("ext_dev"); + if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { + lowered_funcs.Set(ext_dev, IRModule()); } // Generate a placeholder function that attaches linked params as its arguments. @@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode { DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); - if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) { - lowered_funcs.Set(target_host->str(), IRModule(Map({}))); + if (lowered_funcs.find(target_host) == lowered_funcs.end()) { + lowered_funcs.Set(target_host, IRModule(Map({}))); } - lowered_funcs[target_host->str()]->Add( - GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); + lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), + prim); } // When there is no lowered_funcs due to reasons such as optimization. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..76b6f9186eb5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -53,7 +53,11 @@ namespace { struct PairHash { template std::size_t operator()(const std::pair& k) const { - return std::hash()(k.first) ^ std::hash()(k.second); + return dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + template + std::size_t operator()(const std::pair& k) const { + return dmlc::HashCombine(ObjectHash()(k.first), std::hash()(k.second)); } }; @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. - Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + std::pair packed_func_key(target, tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. @@ -382,8 +386,11 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target->str()); - ICHECK(mod_itr != per_target_module_.end()) + std::unordered_map + per_target_module_std_map = + backend::TargetModuleMapToTargetStrModuleMap(per_target_module_); + auto mod_itr = per_target_module_std_map.find(target); + ICHECK(mod_itr != per_target_module_std_map.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { @@ -407,7 +414,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } // Return just what we need for this call. @@ -874,11 +881,10 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> - compiled_packed_funcs_; + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; @@ -895,7 +901,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1020,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1063,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 71ac752ec680..06d862b781e1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,19 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + std::unordered_map + lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions[target] = IRModule(Map({})); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,13 +105,13 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions[target] = IRModule(Map({})); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } - return lowered_functions; + return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions); } Array LowerExternalFunctions() { @@ -884,7 +885,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const String target = kv.first; + const Target target = kv.first; const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so @@ -926,7 +927,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -934,7 +935,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - Optional target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e9cfb0d62e66..65ba67ac7e1b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. */ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. */ Array external_mods; // TODO(@electriclilies): THis might need to become a map diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 4b4844599e29..ea0ab093aa1d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -187,6 +187,24 @@ Array GetPassPrefix(const Map& targets, bool is return pass_seqs; } +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map) { + std::unordered_map std_map; + for (auto kv : input_map) { + std_map[kv.first] = kv.second; + } + return std_map; +} + +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map) { + Map tvm_map; + for (auto kv : input_map) { + tvm_map.Set(kv.first, kv.second); + } + return tvm_map; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..cf8a2dd4b8e0 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; @@ -427,6 +427,60 @@ inline bool IsCompileEngineCacheDisabled() { */ Array GetPassPrefix(const Map& targets, bool is_vm); +/*! \brief Target hash function */ +struct TargetStrHash { + /*! + * \brief Calculate the hash code of a Target based on the string value of the Target. + Note that this hash should NOT be used in new usecases, equality of targets based on their + value is not well-defined. + This will be removed when maps from Targets to IRModules are removed from the codebase. + * \param target The Target to hash + * \return String hash of the target + */ + size_t operator()(const Target& target) const { + return String::HashBytes(target->str().c_str(), target->str().size()); + } +}; + +/*! \brief Target equality function based on the string value of Target +Note that this equality function should NOT be used in new usecases, equality of targets based on +their value is not well-defined. This will be removed when maps from Targets to IRModules are +removed from the codebase.*/ +struct TargetStrEqual { + /*! + * \brief Check if the two Targets are equal + * \param target One Target + * \param other_target The other Target + * \return String equality of the targets + */ + const bool operator()(const Target& target, const Target& other_target) const { + TargetStrHash target_hash = TargetStrHash(); + return target_hash(target) == target_hash(other_target); + } +}; + +/*! + * \brief Convert a Map to std::unordered_map Target equality is currently based on pointer equality, which is a problem since + * we have a lot of Map in the codebase. This function converts the map to a + * version that is keyed based on string value of the Target instead. Note that once we remove + * Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map); + +/*! + * \brief Convert a std::unordered_map to + * Map This function is a helper that undoes TargetModuleMapToTargetStr. Note that + * once we remove Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map); + } // namespace backend } // namespace relay } // namespace tvm From b01ab9e81e6a9605e6d2dce5b0c81ce551c1839b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 1 Sep 2021 03:59:53 +0800 Subject: [PATCH 38/42] [TensorIR][M2a] CacheRead/Write (#8863) Co-authored-by: Junru Shao Co-authored-by: Wuwei Lin Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> --- include/tvm/tir/schedule/schedule.h | 22 + include/tvm/tir/schedule/state.h | 5 + python/tvm/tir/schedule/schedule.py | 135 +++ src/tir/schedule/analysis.h | 21 +- src/tir/schedule/analysis/analysis.cc | 50 +- src/tir/schedule/concrete_schedule.cc | 21 + src/tir/schedule/concrete_schedule.h | 4 + src/tir/schedule/primitive.h | 24 + src/tir/schedule/primitive/block_annotate.cc | 4 +- .../schedule/primitive/cache_read_write.cc | 781 ++++++++++++++++++ src/tir/schedule/schedule.cc | 4 + src/tir/schedule/state.cc | 18 + src/tir/schedule/traced_schedule.cc | 23 + src/tir/schedule/traced_schedule.h | 4 + src/tir/schedule/transform.cc | 40 + src/tir/schedule/transform.h | 29 + src/tir/schedule/utils.h | 1 + .../test_tir_schedule_cache_read_write.py | 677 +++++++++++++++ 18 files changed, 1840 insertions(+), 23 deletions(-) create mode 100644 src/tir/schedule/primitive/cache_read_write.cc create mode 100644 tests/python/unittest/test_tir_schedule_cache_read_write.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 79fed09c3e36..33776cbe1985 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object { */ virtual void Unroll(const LoopRV& loop_rv) = 0; /******** Schedule: Insert cache stages ********/ + /*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * \param block_rv The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. + * \return The cache stage block. + */ + virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + /*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block who writes the target buffer. + * 2) The scope block have stage-pipeline property. + * \param block_rv The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \return The cache stage block. + */ + virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 7cd1b00c15ef..35299a3fa84b 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -128,6 +128,11 @@ class ScheduleStateNode : public Object { */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); + /*! + * \brief Recalculate the `affine_binding` flag of the scope block info. + * \param scope_sref The sref to the interested scope block. + */ + TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 9433d019f9a5..ac09bdbb264d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None: ########## Schedule: Insert cache stages ########## + def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV: + """Create a block that reads a buffer region into a read cache. It requires: + + 1) There is at most one block who write the buffer in the scope. + + 2) The scope block have stage-pipeline property. + + Parameters + ---------- + block : BlockRV + The consumer block of the target buffer. + + read_buffer_index: int + The index of the buffer in block's read region. + + storage_scope: str + The target storage scope. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before cache_read, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_cache_read(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and cache_read: + + .. code-block:: python + + sch = tir.Schedule(before_cache_read) + block_b = sch.get_block("B") + sch.cache_read(block_b, 0, "local") + print(tvm.script.asscript(sch.mod["main"])) + + After applying cache_read, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_cache_read(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + A_local = tir.alloc_buffer((128, 128), scope="local") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A_local") as [vi, vj]: + A_local[vi, vj] = A[vi, vj] + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A_local[vi, vj] * 2.0 + + """ + return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member + self, block, read_buffer_index, storage_scope + ) + + def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV: + """Create a block that reads a buffer region into a write cache. It requires: + + 1) There is only one block who write the buffer in the scope. + + 2) The scope block have stage-pipeline property. + + Parameters + ---------- + block : BlockRV + The producer block of the target buffer. + + write_buffer_index: int + The index of the buffer in block's write region. + + storage_scope: str + The target storage scope. + + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before cache_write, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_cache_write(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and cache_write: + + .. code-block:: python + + sch = tir.Schedule(before_cache_write) + block_b = sch.get_block("B") + sch.cache_write(block_b, 0, "local") + print(tvm.script.asscript(sch.mod["main"])) + + After applying cache_write, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_cache_write(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + B_local = tir.alloc_buffer((128, 128), scope="local") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A_local") as [vi, vj]: + B_local[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_local[vi, vj] + + """ + return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member + self, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## def compute_inline(self, block: BlockRV) -> None: diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 3fa0c63b2e2f..d4e4728abfe0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self); const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, GlobalVar* result_g_var); +/*! + * \brief Get the root node of the sref tree, which is the root block of the PrimFunc. + * \param sref The given sref. + * \return The root node of the sref tree which contains the given node. + */ +StmtSRef GetSRefTreeRoot(const StmtSRef& sref); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr /******** Block-buffer relation ********/ /*! - * \brief Get the BlockRealize of the single child block of the block or loop specified by - * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks - * \param self The schedule state - * \param block The queried block - * \param n The index of the queried buffer - * \return The buffer of the n-th write region of the block. + * \brief Get the n-th read or write buffer of the given block. + * \param self The schedule state. + * \param block The queried block. + * \param n The index of the queried buffer. + * \param is_write A boolean flag to indicate querying write buffer or read buffer. + * \return The buffer of the n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. */ -Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); +Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); /******** Commutative Reducer ********/ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c9f8ff4c7e75..3865781c5870 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr /******** Block-buffer relation ********/ -Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { - class WriteBufferIndexOutOfRangeError : public ScheduleError { +Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) { + class BufferIndexOutOfRangeError : public ScheduleError { public: - explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) - : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write) + : mod_(std::move(mod)), + block_(std::move(block)), + buffer_index_(buffer_index), + is_write_(is_write) {} String FastErrorString() const final { - return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " - "range [0, num_write_regions) where `num_write_regions` is the number of buffer " - "regions written by the block."; + if (is_write_) { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range " + "[0, num_write_regions) where `num_write_regions` is the number of buffer regions " + "written by the block."; + } else { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range " + "[0, num_read_regions) where `num_read_regions` is the number of buffer regions " + "read by the block."; + } } String DetailRenderTemplate() const final { std::ostringstream os; - size_t num_writes = block_->writes.size(); - os << "The block {0} has " << num_writes - << " write regions, so `buffer_index` is required to be in [0, " << num_writes + size_t num = is_write_ ? block_->writes.size() : block_->reads.size(); + std::string access_type = is_write_ ? "write" : "read"; + os << "The block {0} has " << num << " " << access_type + << " regions, so `buffer_index` is required to be in [0, " << num << "). However, the input `buffer_index` is " << buffer_index_ - << ", which is out of the expected range"; + << ", which is out of the expected range."; return os.str(); } @@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { IRModule mod_; Block block_; int buffer_index_; + bool is_write_; }; - if (n < 0 || static_cast(n) >= block->writes.size()) { - throw WriteBufferIndexOutOfRangeError(self->mod, block, n); + const Array& access_region = is_write ? block->writes : block->reads; + + if (n < 0 || static_cast(access_region.size()) <= n) { + throw BufferIndexOutOfRangeError(self->mod, block, n, is_write); } - return block->writes[n]->buffer; + return access_region[n]->buffer; } /******** Pattern Matcher ********/ @@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, return false; } +/******** SRef Tree Related ********/ +StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { + const StmtSRefNode* p = sref.get(); + for (; p->parent != nullptr; p = p->parent) { + } + return GetRef(p); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index cd9aad8ae512..86223e11c196 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { } /******** Schedule: Insert cache stages ********/ + +BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 0bd902d183bf..e756f9da41b2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode { void Bind(const LoopRV& loop_rv, const String& thread_axis) override; void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index be33c2acca10..412611adf76d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& */ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); /******** Schedule: Insert cache stages ********/ +/*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * \param self The state of the schedule + * \param block_sref The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. + * \return The cache stage block. + */ +TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope); +/*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block that writes the target buffer. + * 2) The scope block have stage-pipeline property. + * \param self The state of the schedule + * \param block_sref The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \return The cache stage block. + */ +TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope); /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 937bc7c3802f..06f7ac3c1bc2 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include "../transform.h" #include "../utils.h" namespace tvm { @@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); - Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); + Buffer buffer = + GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, /*is_write=*/true); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc new file mode 100644 index 000000000000..df54c9652ece --- /dev/null +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -0,0 +1,781 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Error Classes ********/ + +class NotSingleWriteBlock : public ScheduleError { + public: + explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array write_blocks) + : mod_(std::move(mod)), buffer_(std::move(buffer)) { + ICHECK_GT(write_blocks.size(), 1); + write_blocks_.reserve(write_blocks.size()); + for (const StmtSRef& block_sref : write_blocks) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + write_blocks_.push_back(GetRef(block)); + } + } + + String FastErrorString() const final { + return "ScheduleError: The buffer is allowed to be written by single block."; + } + + String DetailRenderTemplate() const final { + size_t k = write_blocks_.size(); + return "The buffer " + buffer_->name + " is expected to be written by single block, but got " + + std::to_string(k) + " blocks who write it."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { + return {write_blocks_.begin(), write_blocks_.end()}; + } + + private: + IRModule mod_; + Buffer buffer_; + Array write_blocks_; +}; + +/******** Helper Functions/Classes ********/ + +/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */ +struct CacheStageInfo { + /*! \brief The buffer to be read. */ + Buffer read_buffer; + /*! \brief The buffer to be written. */ + Buffer write_buffer; + /*! \brief The buffer allocation to be inserted into the block signature. */ + Buffer alloc; + /*! \brief The AST node whose body is where the cache stage should be inserted. */ + StmtSRef loc_sref; + /*! \brief The index to insert the cache_read/cache_write stage. */ + size_t loc_pos; + /*! \brief The cache_read/cache_write stage to be inserted. */ + Stmt cache_stage; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map block_reuse; +}; + +/*! \brief Return the buffer region realted with the buffer */ +Optional GetBufferRegionFromBuffer(const Array& buffer_regions, + const Buffer& buffer) { + Optional res = NullOpt; + for (const auto& region : buffer_regions) { + if (region->buffer.same_as(buffer)) { + ICHECK(!res.defined()); + res = region; + } + } + return res; +} + +/*! + * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer + * to write buffer. + * \note This function will store the stmt with loop nesting to the CacheStageInfo, but only return + * the inside block. + * \param cache_region The cached copy region. + * \param info The cache stage information, which will be updated in the function. + * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \returns A block indicating the body of the loop nesting. + */ +Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, + const String& storage_scope) { + // loop variables + std::vector loop_vars; + // bindings in block realize + std::vector iter_values; + // Create loop vars and block vars' binding_value + for (const Range& axis_range : cache_region->region) { + Var loop_var("ax" + std::to_string(loop_vars.size())); + loop_vars.push_back(loop_var); + iter_values.push_back(axis_range->min + loop_var); + } + // block variables + Array block_vars; + // block access region for read/write buffers + Region access_region; + // indices used in block body + Array access_indices; + // Create block vars, block's accessed region and accessing indices + for (const PrimExpr& dim : cache_region->buffer->shape) { + Var var("v" + std::to_string(access_indices.size())); + block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim), + /*var=*/var, + /*IterVarType=*/kDataPar)); + access_indices.push_back(var); + access_region.push_back(Range::FromMinExtent(var, 1)); + } + + // Create the body block: + // reads = [read_buffer[access_region]] + // writes = [write_buffer[access_region]] + // write_buffer[access_indices] = read_buffer[access_indices] + Block block( + /*iter_vars=*/std::move(block_vars), + /*reads=*/{BufferRegion(info->read_buffer, access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, access_region)}, + /*name_hint=*/cache_region->buffer->name + "_" + storage_scope, + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices), + access_indices), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/{}); + // Create the block realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/Bool(true), + /*block=*/block); + // Create surrounding loops + for (size_t i = loop_vars.size(); i >= 1; --i) { + body = For(/*loop_var=*/loop_vars[i - 1], + /*min=*/0, + /*extent=*/cache_region->region[i - 1]->extent, + /*kind=*/ForKind::kSerial, + /*body=*/body); + } + info->cache_stage = std::move(body); + return block; +} + +/*! + * \brief Insert the cache_read/cache_write stage into the specific position + * \param stmt A sequence of statements or a single statement that the new stage is inserted in + * \param pos The position where the cache stage is inserted + * \param stage The stage to be inserted + * \return A SeqStmt, the result after insertion + */ +SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { + if (const auto* seq_stmt = stmt.as()) { + ObjectPtr result = make_object(*seq_stmt); + result->seq.insert(result->seq.begin() + pos, stage); + return SeqStmt(result); + } + if (pos == 0) { + return SeqStmt({stage, stmt}); + } + ICHECK_EQ(pos, 1); + return SeqStmt({stmt, stage}); +} + +/*! + * \brief Get the only writer block of the input buffer in a given scope block. + * \param self The state of the schedule + * \param scope_sref The scope block where the write is considered + * \param buffer The queried buffer + * \return The sref of the only writer of the input buffer in the given scope, + * or `NullOpt` if no block writes it in the scope. + * \throw NotSingleWriteBlock if there are more than one intrested block. + */ +Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, + const Buffer& buffer) { + BlockScope scope = self->GetBlockScope(scope_sref); + auto it = scope->buffer_writers.find(buffer); + if (it == scope->buffer_writers.end()) { + return NullOpt; + } else { + const Array& block_srefs = it->second; + ICHECK(!block_srefs.empty()); + if (block_srefs.size() > 1) { + throw NotSingleWriteBlock(self->mod, buffer, block_srefs); + } + return block_srefs[0]; + } +} + +/*! + * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) + * \param self The state of the schedule. + * \param buffer_region The buffer region to be analyzed. + * \param block_sref The sref of the block related to the region. + * \param dom_low_inclusive The lowest node in the sref tree path. + * \param dom_high_exclusive The highest node in the sref tree path. + * \return The relaxed buffer region. + */ +BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region, + const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive) { + BlockRealize realize = GetBlockRealize(self, block_sref); + Map binding = GetBindings(realize); + const Buffer& buffer = buffer_region->buffer; + Array int_sets = + arith::EvalSet(Substitute(buffer_region->region, binding), + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(buffer.scope())))); + ICHECK_EQ(buffer_region->region.size(), int_sets.size()); + + Region region; + region.reserve(int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + return BufferRegion(buffer, region); +} + +/*! \brief Detect the insertion position of the new cache stage */ +class CacheLocDetector : public StmtVisitor { + public: + /*! + * \brief Detect the insertion position of the cache stage, and write the position into the + * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique + * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref + * of the scope block of the cached block \param info The cache stage info. + */ + static void Detect(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_sref, CacheStageInfo* info) { + std::vector related_blocks; + for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + if (def->kind == DepKind::kRAW) { + related_blocks.push_back(def->dst); + } + } + if (!related_blocks.empty()) { + CacheLocDetector detector(self, block_sref, scope_sref, related_blocks); + detector(GetRef(scope_sref->stmt)); + info->loc_sref = detector.loc_sref_; + info->loc_pos = detector.loc_pos_; + } else { + info->loc_sref = scope_sref; + const auto* body = scope_sref->StmtAs()->body.as(); + info->loc_pos = body == nullptr ? 1 : body->size(); + } + } + + private: + /*! + * \brief Constructor + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or + * cache_write \param scope_sref The sref of the scope block of the cached block \param + * related_blocks Producer blocks for cache_write, or consumer blocks for cache_read + */ + CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, + const std::vector& related_blocks) + : self_(self), + block_sref_(block_sref), + scope_sref_(scope_sref), + related_blocks_(related_blocks) {} + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + bool previous_visited_block = visited_block_; + bool previous_visited_related = visited_related_; + visited_block_ = visited_related_ = false; + + int pos = -1; + for (size_t i = 0; i < seq_stmt->size(); ++i) { + if (loc_pos_ != -1) { + break; + } + VisitStmt(seq_stmt->seq[i]); + // `pos` can be assigned only once when we visited `block_sref` + if (visited_block_ && visited_related_ && pos == -1) { + // The offset of insert position from the block + pos = i; + } + } + visited_block_ = visited_block_ || previous_visited_block; + visited_related_ = visited_related_ || previous_visited_related; + // Only we visited the writing block and any one of the related blocks + // That means that we have found the lowest ancestor + // of the block and any one of the related ones + if (visited_block_ && visited_related_ && loc_pos_ == -1) { + loc_pos_ = pos; + } + } + + void VisitStmt_(const BlockNode* block) final { + // Only visit the current scope under buffer writer's parent block + if (block == scope_sref_->stmt) { + // The block vistied is the current parent scope + StmtVisitor::VisitStmt_(block); + // Handling cache_read for input buffer + if (visited_block_ && visited_related_ && !loc_sref_.defined()) { + loc_sref_ = self_->stmt2ref.at(block); + if (loc_pos_ == -1) { + loc_pos_ = 1; + } + } + return; + } + // Update `visited_block` + if (block_sref_->stmt == block) { + visited_block_ = true; + return; + } + // Update `visited_related` + for (const StmtSRef& related_block : related_blocks_) { + if (related_block->stmt == block) { + visited_related_ = true; + return; + } + } + } + + void VisitStmt_(const ForNode* loop) final { + StmtVisitor::VisitStmt_(loop); + if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_ != -1) { + loc_sref_ = self_->stmt2ref.at(loop); + } + } + + private: + /*! \brief The schedule class */ + const ScheduleState self_; + /*! \brief The dominate block which write the buffer */ + const StmtSRef& block_sref_; + /*! \brief The parent scope of the dominate block */ + const StmtSRef& scope_sref_; + /*! \brief Producer blocks for cache_write and consumer blocks for cache_read */ + const std::vector& related_blocks_; + /*! \brief The flag whether we have visited the dominate block */ + bool visited_block_{false}; + /*! \brief The flag whether we have visited at least one related blocks */ + bool visited_related_{false}; + /*! \brief The AST node whose body is where the cache stage should be inserted */ + StmtSRef loc_sref_{nullptr}; + /*! \brief The index to insert the cache_read/cache_write stage */ + int loc_pos_{-1}; +}; + +/*! \brief Mutator for CacheRead. */ +class CacheReadRewriter : public StmtExprMutator { + public: + /*! + * \brief Rewrite the AST and add a cache_read stage with the information provided + * \param scope_sref The parent scope of this mutation + * \param info The cache stage information + * \return The new AST rooting at the original parent scope + */ + static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) { + CacheReadRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + // We don't mutate the block which generates info->read_buffer + if (block != scope_sref_->stmt && + GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { + return std::move(old_stmt); + } + // Mutate the body + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + // Check the insertion point + if (block == info_->loc_sref->stmt) { + // Insert cache stage into the block if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Check if it is the block corresponding to the parent scope + if (block == scope_sref_->stmt) { + // If so, put buffer allocation on the parent scope + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } else { + // Otherwise, update read regions and match_buffers + Array reads = + ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer); + Array match_buffers = + ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer); + if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->read_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->write_buffer; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const LoadNode* load) final { + if (load->buffer_var.same_as(info_->read_buffer->data)) { + ObjectPtr n = make_object(*load); + n->buffer_var = info_->write_buffer->data; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->read_buffer->data.get()) { + return info_->write_buffer->data; + } + return GetRef(op); + } + + private: + /*! \brief The parent scope of the insertion */ + const StmtSRef& scope_sref_; + /*! \brief The info for inserting cache stage */ + CacheStageInfo* info_; +}; + +/*! \brief Mutator for CacheWrite */ +class CacheWriteRewriter : public StmtExprMutator { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. + */ + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info) { + CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info) + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + // We only mutate the block which generates info->write_buffer + if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { + return std::move(old_stmt); + } + + // Mutate the body + bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt; + std::swap(under_scope, under_writer_block_); + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + std::swap(under_scope, under_writer_block_); + + // Find the insertion point + if (block == info_->loc_sref->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Put buffer allocation on the parent scope + if (block == scope_sref_->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } else { + // Since cache_write changes the block, we need to update the buffer it writes + auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); + auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); + auto match_buffers = + ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->writes = std::move(writes); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); + if (stmt->buffer.same_as(info_->write_buffer)) { + auto n = CopyOnWrite(stmt.get()); + n->buffer = info_->read_buffer; + return Stmt(n); + } else { + return std::move(stmt); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->write_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->read_buffer; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const LoadNode* load) final { + if (load->buffer_var.same_as(info_->write_buffer->data)) { + ObjectPtr n = make_object(*load); + n->buffer_var = info_->read_buffer->data; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + Stmt VisitStmt_(const StoreNode* store) final { + if (store->buffer_var.same_as(info_->write_buffer->data)) { + ObjectPtr n = make_object(*store); + n->buffer_var = info_->read_buffer->data; + return Stmt(n); + } + return StmtMutator::VisitStmt_(store); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->write_buffer->data.get()) { + return info_->read_buffer->data; + } + return GetRef(op); + } + + private: + /*! \brief The parent scope of the insertion. */ + const StmtSRef& scope_sref_; + /*! \brief The parent scope of the insertion. */ + const StmtSRef& writer_block_sref_; + /*! \brief The info for inserting cache stage. */ + CacheStageInfo* info_; + /*! \brief Whether the current node is under the given block. */ + bool under_writer_block_{false}; +}; + +/******** Implementation ********/ + +StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is at most one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the consumers blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 1. Check index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer read_buffer = + GetNthAccessBuffer(self, GetRef(block), read_buffer_index, /*is_write=*/false); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + + // Step 2. Creat CacheStageInfo + CacheStageInfo info; + info.read_buffer = read_buffer; + // Create the corresponding buffer to be written, i.e. result of cache_read + info.write_buffer = WithScope(read_buffer, storage_scope); + // Create the corresponding buffer allocation + info.alloc = info.write_buffer; + + // Step 3. Update cache stage info. + BufferRegion cache_region{nullptr}; + if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + // Case 1. The buffer is written inside the block. + StmtSRef write_block_sref = _write_block_sref.value(); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref); + // Find the producing region + BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); + StmtSRef parent_sref = GetRef(write_block_sref->parent); + + // Detect insert position + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref); + } else { + // Case 2. The buffer is the input block for the scope. + info.loc_sref = scope_sref; + info.loc_pos = 0; + if (Optional region = + GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) { + cache_region = region.value(); + } else { + cache_region = BufferRegion::FullRegion(read_buffer); + } + } + + // Step 4. Making new cache stage block and rewrite readers. + Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + + // Step 5. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + self->UpdateAffineFlag(result_block_sref); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is only one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. + * - Copy the buffer with the consumed region. + */ + // Step 1. Checking index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer write_buffer = + GetNthAccessBuffer(self, GetRef(block), write_buffer_index, /*is_write=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Creating CacheStageInfo + CacheStageInfo info; + info.read_buffer = WithScope(write_buffer, storage_scope); + // Create the corresponding buffer to be written, i.e. result of cache_write + info.write_buffer = write_buffer; + // Create the corresponding buffer allocation + info.alloc = info.read_buffer; + + // Step 3. Check the only writer block. + ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + + // Step 4. Find the producing region and insert position + BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); + StmtSRef parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + BufferRegion cache_region = + RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); + + // Step 5. Making new cache stage block and rewrite readers. + Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info); + + // Step 6. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + self->UpdateAffineFlag(result_block_sref); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +/******** Instruction Registration ********/ + +struct CacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheRead"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index, + String storage_scope) { + return sch->CacheRead(block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, + String storage_scope) { + PythonAPICall py("cache_read"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct CacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheWrite"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index, + String storage_scope) { + return sch->CacheWrite(block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, + String storage_scope) { + PythonAPICall py("cache_write"); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d24cdc625912..fd30b02fc9dd 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -141,6 +141,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") + .set_body_method(&ScheduleNode::CacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") + .set_body_method(&ScheduleNode::CacheWrite); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 9a9b97497e04..799806bef7b5 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -1029,6 +1029,24 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl Bool(info.scope->stage_pipeline)}; } +TVM_DLL void ScheduleStateNode::UpdateAffineFlag(const StmtSRef& scope_sref) { + auto it = this->block_info.find(scope_sref); + ICHECK(it != this->block_info.end()) << "Cannot find the block info of the given block."; + BlockInfo& info = it->second; + + bool is_root_block = scope_sref->parent == nullptr; + if (is_root_block) { + info.affine_binding = true; + } else { + BlockRealize realize = GetBlockRealize(GetRef(this), scope_sref); + arith::Analyzer analyzer; + StmtSRef parent_sref = GetRef(scope_sref->parent); + info.affine_binding = IsAffineBinding(/*realize=*/realize, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), + /*analyzer=*/&analyzer); + } +} + /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index af4a6588f064..f429a917858b 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -166,6 +166,29 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { } /******** Schedule: Insert cache stages ********/ +BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) { + BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("CacheRead"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { + BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 48dadbc03b3b..a6b5251a96a3 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -71,6 +71,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Bind(const LoopRV& loop_rv, const String& thread_axis) final; void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index f27e0f6d62eb..da376fdde90f 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -19,6 +19,8 @@ #include "./transform.h" +#include "./utils.h" + namespace tvm { namespace tir { @@ -31,5 +33,43 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec return Block(new_block); } +/******** Buffer Related ********/ +Buffer WithScope(const Buffer& buffer, const String& scope) { + ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_var = make_object(*buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, scope); + new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); + new_buffer->name = buffer->name + "_" + scope; + return Buffer(new_buffer); +} + +Array ReplaceBuffer(Array regions, const Buffer& source, + const Buffer& target) { + regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion { + if (region->buffer.same_as(source)) { + ObjectPtr n = make_object(*region.get()); + n->buffer = target; + return BufferRegion(n); + } + return region; + }); + return regions; +} + +Array ReplaceBuffer(Array match_buffers, const Buffer& source, + const Buffer& target) { + match_buffers.MutateByApply([&source, + &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source)) { + ObjectPtr n = make_object(*match_buffer.get()); + n->source = BufferRegion(target, n->source->region); + return MatchBufferRegion(n); + } + return match_buffer; + }); + return match_buffers; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 53483829a303..85cce9da216e 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -35,6 +35,35 @@ namespace tir { */ Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); +/******** Buffer Related ********/ + +/*! + * \brief Create a new buffer by changing the storage scope. + * \param buffer The given buffer. + * \param scope The target storage scope. + * \return The new buffer with target storage scope. + */ +Buffer WithScope(const Buffer& buffer, const String& scope); + +/*! + * \brief Replaces the buffer within the specific sequence of regions + * \param regions The regions whose buffers are to be replaced + * \param source The buffer to be replaced + * \param target The buffer to be replaced to + * \return The new sequence of regions after replacement + */ +Array ReplaceBuffer(Array regions, const Buffer& source, + const Buffer& target); + +/*! + * \brief Replaces the buffer within the specific sequence of match_buffers + * \param match_buffers The match_buffers whose buffers are to be replaced + * \param source The buffer to be replaced + * \param target The buffer to be replaced to + * \return The new sequence of match_buffers after replacement + */ +Array ReplaceBuffer(Array match_buffers, const Buffer& source, + const Buffer& target); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 8ccf8da731b5..c2f430181664 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -42,6 +42,7 @@ #include "./error.h" #include "./instruction_traits.h" #include "./primitive.h" +#include "./transform.h" namespace tvm { namespace tir { diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py new file mode 100644 index 000000000000..d7eb8d864135 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -0,0 +1,677 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + +########## Function before schedule ########## + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def access_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = A[vi, vj] + 1.0 + + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A[vi, vj]) + tir.writes(D[vi, vj]) + D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.tir +def func_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A[vi] + 1.0 + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A[vi] + + +@tvm.script.tir +def func_multi_producer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + with tir.block([128], "A0") as [vi]: + A[vi] = 1.0 + with tir.block([128], "A1") as [vi]: + A[vi] = 2.0 + with tir.block([128], "B") as [vi]: + B[vi] = A[vi] + + +########## Expected function after cache_read ########## + + +@tvm.script.tir +def cache_read_elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + A_global = tir.alloc_buffer((128, 128)) + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A_global[vi, vj] * 2.0 + with tir.block([128, 128], "B_local") as [vi, vj]: + B_local[vi, vj] = B[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B_local[vi, vj] + 1.0 + + +@tvm.script.tir +def cache_read_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + A_global = tir.alloc_buffer((128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + A_local = tir.alloc_buffer((128, 128), scope="local") + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A_local") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_local[vi, vj] = A[vi, vj] + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = A_local[vi, vj] + 1.0 + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A_global[vi, vj] * 2.0 + + +@tvm.script.tir +def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + A_global = tir.alloc_buffer((128, 128), dtype="float16") + + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A_global[vi, vj]) + tir.writes(D[vi, vj]) + D.data[vi * 128 + vj] = tir.load("float16", A_global.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A_global.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.tir +def cache_read_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = tir.alloc_buffer((128)) + A_global = tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A_global[vi] = A[vi] + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A_global[vi] + 1.0 + + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A_global[vi] + + +@tvm.script.tir +def continuous_cache_read(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + B_shared = tir.alloc_buffer((128, 128), scope="shared") + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "B_shared") as [vi, vj]: + B_shared[vi, vj] = B[vi, vj] + with tir.block([128, 128], "B_local") as [vi, vj]: + B_local[vi, vj] = B_shared[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B_local[vi, vj] + 1.0 + + +########## Expected function after cache_write ########## + + +@tvm.script.tir +def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + B_global = tir.alloc_buffer((128, 128), scope="local") + C_local = tir.alloc_buffer((128, 128)) + with tir.block([128, 128], "B_global") as [vi, vj]: + B_global[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "C_local") as [vi, vj]: + C_local[vi, vj] = B[vi, vj] + 1.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = C_local[vi, vj] + + +@tvm.script.tir +def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + A_global = tir.alloc_buffer((128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + A_local = tir.alloc_buffer((128, 128), scope="local") + B_global = tir.alloc_buffer((128, 128)) + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A_local") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_local[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_global[vi, vj] = A_local[vi, vj] + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B_global") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B_global[vi, vj] = A_global[vi, vj] + 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B_global") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "A_global") as [vi, vj]: + A[vi, vj] = A_global[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + D_global = tir.alloc_buffer((128, 128), dtype="float16") + B_global = tir.alloc_buffer((128, 128), dtype="float16") + C_global = tir.alloc_buffer((128, 128), dtype="float16") + + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A[vi, vj]) + tir.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B_global.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = D_global[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = C_global[vi, vj] + + +@tvm.script.tir +def cache_write_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = tir.alloc_buffer((128)) + A_global = tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A_global") as [vi]: + tir.bind(vi, i * 16 + j) + A_global[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = A_global[vi] + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A[vi] + 1.0 + + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A[vi] + + +@tvm.script.tir +def continuous_cache_write(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + B_shared = tir.alloc_buffer((128, 128), scope="shared") + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "B") as [vi, vj]: + B_local[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "B") as [vi, vj]: + B_shared[vi, vj] = B_local[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_shared[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +########## Testcases for cache_read ########## + + +def test_cache_read_elementwise(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + cached_a = sch.cache_read(block_b, 0, "global") + cached_b = sch.cache_read(block_c, 0, "local") + assert sch.get(cached_a) == sch.get(sch.get_block("A_global")) + assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) + assert sch.get(block_b) == sch.get(sch.get_block("B")) + assert sch.get(block_c) == sch.get(sch.get_block("C")) + tvm.ir.assert_structural_equal(cache_read_elementwise, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_read_under_scope(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.cache_read(block_b, 0, "local") + sch.cache_read(block_c, 0, "global") + tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=access_under_scope) + + +def test_cache_read_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block = sch.get_block("load_store") + sch.cache_read(block, 0, "global") + tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_cache_read_location(): + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_b = sch.get_block("B") + sch.cache_read(block_b, 0, "global") + tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + +def test_continuous_cache_read(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_c = sch.get_block("C") + sch.cache_read(block_c, 0, "shared") + sch.cache_read(block_c, 0, "local") + tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_read_fail_multi_producer(): + sch = tir.Schedule(func_multi_producer, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_read(block_b, 0, "global") + + +def test_cache_read_fail_index_out_of_bound(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_read(block_b, 1, "global") + + +########## Testcases for cache_write ########## + + +def test_cache_write_elementwise(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + cached_b = sch.cache_write(block_b, 0, "local") + cached_c = sch.cache_write(block_c, 0, "global") + assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) + assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) + assert sch.get(block_b) == sch.get(sch.get_block("B")) + assert sch.get(block_c) == sch.get(sch.get_block("C")) + tvm.ir.assert_structural_equal(cache_write_elementwise, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_write_under_scope(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + block_a = sch.get_block("A") + block_b = sch.get_block("B") + block_scope = sch.get_block("scope") + sch.cache_write(block_a, 0, "local") + sch.cache_write(block_b, 0, "global") + sch.cache_write(block_scope, 0, "global") + tvm.ir.assert_structural_equal(cache_write_under_scope, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=access_under_scope) + + +def test_cache_write_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_store = sch.get_block("load_store") + block_opaque = sch.get_block("opaque") + block_match_buffer = sch.get_block("match_buffer") + sch.cache_write(block_store, 0, "global") + sch.cache_write(block_opaque, 0, "global") + sch.cache_write(block_match_buffer, 0, "global") + tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_cache_write_location(): + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_a = sch.get_block("A") + sch.cache_write(block_a, 0, "global") + tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + +def test_continuous_cache_write(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + sch.cache_write(block_b, 0, "shared") + sch.cache_write(block_b, 0, "local") + tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_write_fail_multi_producer(): + sch = tir.Schedule(func_multi_producer, debug_mask="all") + block_a0 = sch.get_block("A0") + block_a1 = sch.get_block("A1") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_a0, 0, "global") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_a1, 0, "global") + + +def test_cache_write_fail_index_out_of_bound(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_b, 1, "global") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 2f1c8454c832f6bc25e9da356ccd847d1668c9a4 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 1 Sep 2021 01:42:19 +0300 Subject: [PATCH 39/42] [CI] make pre-commit hooks to run on every push instead of every commit (#8888) --- .pre-commit-config.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a2c07de458a..982b78180f2a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ # Requirements: # - How to configure: # - $ pip install pre-commit -# - $ pre-commit install +# - $ pre-commit install --hook-type pre-push # - How to prevent running it: # - git options: --no-verify or -n # - $ git commit -n -m "YOUR COMMIT MESSAGE" @@ -32,8 +32,9 @@ # default_language_version: - python: python3.8 + python: python3.6 fail_fast: True +default_stages: [push] repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 @@ -42,7 +43,9 @@ repos: - id: check-merge-conflict - id: check-yaml - id: end-of-file-fixer + stages: [push] - id: trailing-whitespace + stages: [push] - repo: local hooks: - id: run-black From fa3b34c03673d7581f91c429b601124a2943d9dc Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 31 Aug 2021 23:15:45 -0400 Subject: [PATCH 40/42] [TVMScript] Fix printing ForNode annotations (#8891) --- src/printer/tvmscript_printer.cc | 2 +- tests/python/unittest/test_tvmscript_roundtrip.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index cc7536b48cfd..df02a6906a0c 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1069,7 +1069,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) { res << Print(loop->thread_binding.value()->thread_tag); } if (!loop->annotations.empty()) { - res << ", annotation = {"; + res << ", annotations = {"; res << PrintAnnotations(loop->annotations); res << "}"; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0566ff5044d9..f9aee67f1d71 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2803,7 +2803,9 @@ def for_thread_binding(a: ty.handle, b: ty.handle) -> None: B = tir.match_buffer(b, (16, 16), "float32") for i in tir.thread_binding(0, 16, thread="threadIdx.x"): - for j in tir.thread_binding(0, 16, thread="threadIdx.y"): + for j in tir.thread_binding( + 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} + ): A[i, j] = B[i, j] + tir.float32(1) @@ -2818,6 +2820,7 @@ def test_for_thread_binding(): assert isinstance(rt_func.body.body, tir.stmt.For) assert rt_func.body.body.kind == 4 assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" + assert rt_func.body.body.annotations["attr_key"] == "attr_value" @tvm.script.tir From 3f3c067fec57f33199d7c846a1bb1d60a339ea1a Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Wed, 1 Sep 2021 11:49:49 +0100 Subject: [PATCH 41/42] [1/10] CMSIS-NN graph partitioner for softmax (#8653) * cmsis graph partitioner for softmax Change-Id: I80ecd7bc5351f241b4674ef53b36e4398c8adb83 * Updated docstring in the partioning function Change-Id: Ieb4b623e5929cfdb6aa0235db64c825fac8d7055 --- python/tvm/relay/op/contrib/cmsisnn.py | 80 +++++++++++++ .../contrib/test_cmsisnn/test_softmax.py | 107 ++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 python/tvm/relay/op/contrib/cmsisnn.py create mode 100644 tests/python/contrib/test_cmsisnn/test_softmax.py diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py new file mode 100644 index 000000000000..daf1e098d7f1 --- /dev/null +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Arm(R) CMSIS-NN supported operators for Cortex-M.""" +import tvm.ir +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name + +from ...dataflow_pattern import is_constant, is_op, wildcard +from .register import register_pattern_table + + +def partition_for_cmsisnn(mod, params=None, **opts): + """Partition the graph greedily offloading supported + operators on Cortex-M using CMSIS-NN + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + ret : Module + annotated and partitioned module. + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.MergeComposite(pattern_table()), + transform.AnnotateTarget("cmsisnn"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + + return seq(mod) + + +@register_pattern_table("cmsisnn") +def pattern_table(): + """Get the cmsisnn compiler pattern table.""" + + def softmax_pattern(): + pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + pattern = is_op("nn.softmax")(pattern) + pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant()) + return pattern + + def check_quantized_softmax(extract): + """Check if softmax is supported by CMSIS-NN.""" + + # check for dtypes of quantize and dequantize + return ( + extract.attrs.out_dtype == "int8" + and extract.args[0].args[0].args[0].checked_type.dtype == "int8" + ) + + return [ + ("cmsisnn.qnn_softmax", softmax_pattern(), check_quantized_softmax), + ] diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py new file mode 100644 index 000000000000..afbc302af66f --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: softmax""" + +import pytest +import sys + +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn +import numpy as np + + +def count_num_calls(mod): + class CallCounter(relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + counter = CallCounter() + for var in mod.get_global_vars(): + counter.visit(mod[var.name_hint]) + return counter.count + + +def make_module(func): + func = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(func) + return relay.transform.InferType()(mod) + + +def make_model(shape, zero_point, scale, in_dtype, out_dtype): + a = relay.var("a", shape=shape, dtype=in_dtype) + dequantize = relay.qnn.op.dequantize( + a, + input_scale=relay.const(scale, "float32"), + input_zero_point=relay.const(zero_point, "int32"), + ) + softmax = relay.nn.softmax(dequantize) + model = relay.qnn.op.quantize( + softmax, + output_scale=relay.const(scale, "float32"), + output_zero_point=relay.const(zero_point, "int32"), + out_dtype=out_dtype, + ) + return model + + +def test_softmax_int8(): + model = make_model([1, 16, 16, 3], 64, 0.02, "int8", "int8") + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + +@pytest.mark.parametrize("in_dtype,out_dtype", [["uint8", "int8"], ["int8", "uint8"]]) +def test_softmax_not_int8(in_dtype, out_dtype): + model = make_model([1, 16, 16, 3], 64, 0.02, in_dtype, out_dtype) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 544724439efb9a795c92bd7ec9f7929e41c843c6 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 1 Sep 2021 09:51:49 -0400 Subject: [PATCH 42/42] [microTVM][RVM] Add Arduino RVM (#8748) * Functioning Arduino Vagrant VM Begin building Arduino Vagrant VM Mostly working Vagrant VM Changes for debugging Add ignored json file Fix venv path * Generalize parts of RVM for multiple platforms cwd hack Add unit tests from apps directory to task_python_microtvm.sh Generalize parts of RVM for multiple platforms * Add Vagrantfile lint exceptions * Address PR comments Address Mehrdad's PR comments More PR comments Documentation tweaks Add dialout group to user * Rerun tests * Spresense fix * Rerun CI tests * Rerun tests --- .../template_project/microtvm_api_server.py | 5 + apps/microtvm/reference-vm/arduino/.gitignore | 1 + apps/microtvm/reference-vm/arduino/README.md | 44 +++++++ .../microtvm/reference-vm/arduino/Vagrantfile | 66 +++++++++++ .../reference-vm/arduino/base-box/.gitignore | 4 + .../base-box/Vagrantfile.packer-template | 47 ++++++++ .../arduino/base-box/base_box_provision.sh | 77 +++++++++++++ .../arduino/base-box/base_box_setup.sh | 75 ++++++++++++ .../arduino/base-box/base_box_test.sh | 40 +++++++ .../arduino/base-box/test-config.json | 30 +++++ .../reference-vm/arduino/provision_setup.sh | 48 ++++++++ apps/microtvm/reference-vm/base-box-tool.py | 109 ++++++++++++------ apps/microtvm/reference-vm/rebuild-tvm.sh | 43 +++++++ .../reference-vm/zephyr/provision_setup.sh | 2 +- docker/install/ubuntu_install_arduino.sh | 2 +- tests/lint/check_file_type.py | 2 + tests/micro/arduino/conftest.py | 1 + 17 files changed, 561 insertions(+), 35 deletions(-) create mode 100644 apps/microtvm/reference-vm/arduino/.gitignore create mode 100644 apps/microtvm/reference-vm/arduino/README.md create mode 100644 apps/microtvm/reference-vm/arduino/Vagrantfile create mode 100644 apps/microtvm/reference-vm/arduino/base-box/.gitignore create mode 100644 apps/microtvm/reference-vm/arduino/base-box/Vagrantfile.packer-template create mode 100644 apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh create mode 100644 apps/microtvm/reference-vm/arduino/base-box/base_box_setup.sh create mode 100755 apps/microtvm/reference-vm/arduino/base-box/base_box_test.sh create mode 100644 apps/microtvm/reference-vm/arduino/base-box/test-config.json create mode 100644 apps/microtvm/reference-vm/arduino/provision_setup.sh create mode 100755 apps/microtvm/reference-vm/rebuild-tvm.sh diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 91beaf558249..57177179bcd0 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -65,6 +65,11 @@ class BoardAutodetectFailed(Exception): "architecture": "esp32", "board": "feathers2", }, + "metrom4": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_metro_m4", + }, # Spresense only works as of its v2.3.0 sdk "spresense": { "package": "SPRESENSE", diff --git a/apps/microtvm/reference-vm/arduino/.gitignore b/apps/microtvm/reference-vm/arduino/.gitignore new file mode 100644 index 000000000000..dace7081e3f2 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/.gitignore @@ -0,0 +1 @@ +/.vagrant diff --git a/apps/microtvm/reference-vm/arduino/README.md b/apps/microtvm/reference-vm/arduino/README.md new file mode 100644 index 000000000000..3fa1d8bfb4e3 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/README.md @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + +# microTVM Arduino Reference Virtual Machine + +This directory contains setup files for Arduino virtual machine used for testing +microTVM platforms that are supported by [Arduino](https://www.arduino.cc/). + +## VM Information for Developers +Arduino VM is published under [tlcpack](https://app.vagrantup.com/tlcpack). +Here is a list of different release versions and their tools. + +(none currently) + +## Supported Arduino Boards +This RVM has been tested and is known to work with these boards: +- Adafruit Metro M4 +- Adafruit Pybadge +- Arduino Due +- Arduino Nano 33 BLE +- Feather S2 +- Sony Spresense +- Wio Terminal + +However, the RVM *should* work with any Arduino with sufficient memory, provided +its core is installed in `base-box/base_box_provision.sh`. + +Note that this RVM does not work with the Teensy boards, even though they are +supported by microTVM. This is because arduino-cli does not support Teensy +boards (https://github.com/arduino/arduino-cli/issues/700)/). diff --git a/apps/microtvm/reference-vm/arduino/Vagrantfile b/apps/microtvm/reference-vm/arduino/Vagrantfile new file mode 100644 index 000000000000..2511a6ae296e --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/Vagrantfile @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +Vagrant.configure("2") do |config| + config.vm.box = "tlcpack/microtvm-arduino-0.18.3" + + if ENV.has_key?("TVM_RVM_NUM_CORES") + num_cores = ENV["TVM_RVM_NUM_CORES"] + else + num_cores = 2 + end + + if ENV.has_key?("TVM_RVM_RAM_BYTES") + ram_bytes = ENV["TVM_RVM_RAM_BYTES"] + else + ram_bytes = 2048 + end + + tvm_home = "../../../.." + dirs_to_mount = [Pathname.new(Pathname.new(tvm_home).expand_path())] + if ENV.has_key?("TVM_PROJECT_DIR") then + dirs_to_mount.append(ENV["TVM_PROJECT_DIR"]) + puts "NOTE: also configuring project dir: %s" % [dirs_to_mount[-1]] + end + + git_file = Pathname.new(tvm_home + "/.git") + if git_file.ftype() == "file" then + gitdir_match = Regexp.new('^gitdir: (?.*/.git).*\n$', Regexp::MULTILINE).match(git_file.read()) + if !gitdir_match.nil? then + dirs_to_mount.append(Pathname.new(tvm_home).realpath.join(gitdir_match.named_captures["gitdir"])) + puts "NOTE: also configuring git-worktree gitdir: %s" % [dirs_to_mount[-1]] + end + end + + config.vm.provision "shell", path: "provision_setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false + + # Enable USB Controller on VirtualBox + vm_name = "microtvm-arduino-#{Time.now.tv_sec}" + config.vm.provider "virtualbox" do |vb, overrides| + vb.name = vm_name + vb.cpus = num_cores + vb.memory = ram_bytes + vb.customize ["modifyvm", :id, "--usb", "on"] + vb.customize ["modifyvm", :id, "--usbehci", "on"] + vb.customize ["modifyvm", :id, "--usbxhci", "on"] + vb.customize [ "guestproperty", "set", :id, "/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold", 10000] + dirs_to_mount.each do |d| + overrides.vm.synced_folder d.to_s, d.to_s + end + end + +end diff --git a/apps/microtvm/reference-vm/arduino/base-box/.gitignore b/apps/microtvm/reference-vm/arduino/base-box/.gitignore new file mode 100644 index 000000000000..e4406c4f61e2 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/.gitignore @@ -0,0 +1,4 @@ +*.box +.vagrant +/output-packer-* +/packer.json diff --git a/apps/microtvm/reference-vm/arduino/base-box/Vagrantfile.packer-template b/apps/microtvm/reference-vm/arduino/base-box/Vagrantfile.packer-template new file mode 100644 index 000000000000..b43596bb83c1 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/Vagrantfile.packer-template @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +Vagrant.configure("2") do |config| + # From hashicorp default template: + # https://github.com/hashicorp/packer/blob/master/builder/vagrant/step_create_vagrantfile.go#L23-L37 + + config.vm.define "source" do |source| + source.vm.box = "{{.SourceBox}}" + config.ssh.insert_key = {{.InsertKey}} + end + + config.vm.define "output" do |output| + output.vm.box = "{{.BoxName}}" + output.vm.box_url = "file://package.box" + config.ssh.insert_key = {{.InsertKey}} + end + + {{ if ne .SyncedFolder "" -}} + config.vm.synced_folder "{{.SyncedFolder}}", "/vagrant" + {{- else -}} + config.vm.synced_folder ".", "/vagrant", disabled: true + {{- end}} + + + {{ if eq .BoxName "microtvm-base-vmware_desktop" -}} + config.vm.provision "shell", inline: "touch ~/skip_zeroing_disk", privileged: false + {{- end}} + + # NOTE: base_box_setup.sh resides in the parent directory (../) because this template is expanded into a + # sub-directory of base-box (output-packer-*). + config.vm.provision "shell", path: "../base_box_setup.sh", privileged: false +end diff --git a/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh new file mode 100644 index 000000000000..996d303d48fb --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh @@ -0,0 +1,77 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Using this script we can reuse docker/install scripts to configure the reference +# virtual machine similar to CI QEMU setup. +# + +set -e +set -x + +source ~/.profile + +# Init Arduino +cd ~ + +sudo apt-get install -y ca-certificates + +# Install Arduino-CLI (latest version) +export PATH="/home/vagrant/bin:$PATH" +wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s + +# Arduino (the CLI and GUI) require the dialout permission for uploading +sudo usermod -a -G dialout $USER + +# ubuntu_init_arduino.sh only installs a few officially +# supported architectures, so we don't use it here + +# 3rd party board URLs +ADAFRUIT_BOARDS_URL="https://adafruit.github.io/arduino-board-index/package_adafruit_index.json" +ESP32_BOARDS_URL="https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_dev_index.json" +SPARKFUN_BOARDS_URL="https://raw.githubusercontent.com/sparkfun/Arduino_Boards/master/IDE_Board_Manager/package_sparkfun_index.json" +SEEED_BOARDS_URL="https://files.seeedstudio.com/arduino/package_seeeduino_boards_index.json" +SPRESENSE_BOARDS_URL="https://github.com/sonydevworld/spresense-arduino-compatible/releases/download/generic/package_spresense_index.json" +arduino-cli core update-index --additional-urls $ADAFRUIT_BOARDS_URL,$ESP32_BOARDS_URL,$SPARKFUN_BOARDS_URL,$SEEED_BOARDS_URL,$SPRESENSE_BOARDS_URL + +# Install supported cores from those URLS +arduino-cli core install arduino:mbed_nano +arduino-cli core install arduino:sam +arduino-cli core install adafruit:samd --additional-urls $ADAFRUIT_BOARDS_URL +arduino-cli core install esp32:esp32 --additional-urls $ESP32_BOARDS_URL +arduino-cli core install Seeeduino:samd --additional-urls $SEEED_BOARDS_URL +arduino-cli core install SPRESENSE:spresense --additional-urls $SPRESENSE_BOARDS_URL + +# The Sony Spresense SDK has a major bug that breaks TVM. It's scheduled to be fixed in +# release 2.3.0, but until that's published we need to use the below hack. This ONLY +# fixes the bug in the main core release SDK - the subcore release SDK and both +# the main and subcore debug SDKs will continue to fail until an official fix is made. +# https://github.com/sonydevworld/spresense/issues/200 +SPRESENSE_NUTTX_BUGFIX_PATH=~/.arduino15/packages/SPRESENSE/tools/spresense-sdk/2.2.1/spresense/release/nuttx/include/sys/types.h +sed -i 's/#ifndef CONFIG_WCHAR_BUILTIN/#if !defined(__cplusplus)/g' $SPRESENSE_NUTTX_BUGFIX_PATH + +# There's also a bug in arduino-cli where {runtime.os} is not properly templated in +# platform.txt. This bug only seems to appear with the SPRESENSE SDK. A fix has been +# merged and will be part of arduino-cli 0.18.4, but that has yet to be published. +# This change is only needed to upload code (not compile) for the Spresense. +# https://github.com/arduino/arduino-cli/issues/1198 +SPRESENSE_FLASH_WRITER_BUGFIX_PATH=~/.arduino15/packages/SPRESENSE/hardware/spresense/2.2.1/platform.txt +sed -i 's/tools.spresense-tools.cmd.path={path}\/flash_writer\/{runtime.os}\/flash_writer/tools.spresense-tools.cmd.path={path}\/flash_writer\/linux\/flash_writer/g' $SPRESENSE_FLASH_WRITER_BUGFIX_PATH +sed -i 's/tools.spresense-tools.cmd.path.linux={path}\/flash_writer\/{runtime.os}\/flash_writer/tools.spresense-tools.cmd.path.linux={path}\/flash_writer\/linux\/flash_writer/g' $SPRESENSE_FLASH_WRITER_BUGFIX_PATH + +# Cleanup +rm -f *.sh diff --git a/apps/microtvm/reference-vm/arduino/base-box/base_box_setup.sh b/apps/microtvm/reference-vm/arduino/base-box/base_box_setup.sh new file mode 100644 index 000000000000..d02518c538b4 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/base_box_setup.sh @@ -0,0 +1,75 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -x + +skip_zeroing_disk=0 +if [ -e "$HOME/skip_zeroing_disk" ]; then + echo "NOTE: will not zero disk at the end due to VMWare Fusion bug" + echo "See: https://communities.vmware.com/t5/VMware-Fusion-Discussions/VMWare-Fusion-Pro-11-15-6-16696540-causes-macOS-crash-during/m-p/2284011#M139190" + skip_zeroing_disk=1 +fi + +sudo apt update +sudo apt install -y build-essential +sudo apt-get --purge remove modemmanager # required to access serial ports. + +sudo apt install -y --no-install-recommends git \ + cmake cmake-data \ + ninja-build gperf ccache dfu-util device-tree-compiler wget \ + python3-dev python3-pip python3-setuptools python3-tk python3-wheel xz-utils file \ + make gcc gcc-multilib g++-multilib libsdl2-dev + +OLD_HOSTNAME=$(hostname) +sudo hostnamectl set-hostname microtvm +sudo sed -i.bak "s/${OLD_HOSTNAME}/microtvm.localdomain/g" /etc/hosts + +# Poetry deps +sudo apt install -y python3-venv + +# TVM deps +sudo apt install -y llvm + +# ONNX deps +sudo apt install -y protobuf-compiler libprotoc-dev + +# TODO do we need this? +echo 'export PATH=$HOME/vagrant/bin:"$PATH"' >> ~/.profile +source ~/.profile +echo PATH=$PATH + +# Poetry +curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python3 +sed -i "/^# If not running interactively,/ i source \$HOME/.poetry/env" ~/.bashrc +sed -i "/^# If not running interactively,/ i\\ " ~/.bashrc + +# Clean box for packaging as a base box +sudo apt-get clean +if [ $skip_zeroing_disk -eq 0 ]; then + echo "Zeroing disk..." + EMPTY_FILE="$HOME/EMPTY" + dd if=/dev/zero "of=${EMPTY_FILE}" bs=1M || /bin/true + if [ ! -e "${EMPTY_FILE}" ]; then + echo "failed to zero empty sectors on disk" + exit 2 + fi + rm -f "${EMPTY_FILE}" +else + echo "NOTE: skipping zeroing disk due to command-line argument." +fi diff --git a/apps/microtvm/reference-vm/arduino/base-box/base_box_test.sh b/apps/microtvm/reference-vm/arduino/base-box/base_box_test.sh new file mode 100755 index 000000000000..3d8597f19b64 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/base_box_test.sh @@ -0,0 +1,40 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Usage: base_box_test.sh +# Execute microTVM Arduino tests. +# + +set -e +set -x + +if [ "$#" -lt 1 ]; then + echo "Usage: base_box_test.sh " + exit -1 +fi + +microtvm_platform=$1 + +pytest tests/micro/arduino/test_arduino_workflow.py --microtvm-platforms=${microtvm_platform} + +if [ $microtvm_platform == "nano33ble" ]; then + # https://github.com/apache/tvm/issues/8730 + echo "NOTE: skipped test_arduino_rpc_server.py on $microtvm_platform -- known failure" +else + pytest tests/micro/arduino/test_arduino_rpc_server.py --microtvm-platforms=${microtvm_platform} +fi diff --git a/apps/microtvm/reference-vm/arduino/base-box/test-config.json b/apps/microtvm/reference-vm/arduino/base-box/test-config.json new file mode 100644 index 000000000000..80cc17f56847 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/base-box/test-config.json @@ -0,0 +1,30 @@ +{ + "due": { + "vid_hex": "2341", + "pid_hex": "003d" + }, + "feathers2": { + "vid_hex": "303a", + "pid_hex": "0002" + }, + "nano33ble": { + "vid_hex": "2341", + "pid_hex": "805a" + }, + "spresense": { + "vid_hex": "10c4", + "pid_hex": "ea60" + }, + "teensy40": { + "vid_hex": "16c0", + "pid_hex": "0478" + }, + "teensy41": { + "vid_hex": "16c0", + "pid_hex": "0478" + }, + "wioterminal": { + "vid_hex": "2886", + "pid_hex": "802d" + } +} diff --git a/apps/microtvm/reference-vm/arduino/provision_setup.sh b/apps/microtvm/reference-vm/arduino/provision_setup.sh new file mode 100644 index 000000000000..aeb46a8f7649 --- /dev/null +++ b/apps/microtvm/reference-vm/arduino/provision_setup.sh @@ -0,0 +1,48 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex + +# TVM +# NOTE: TVM is presumed to be mounted already by Vagrantfile. +cd "${TVM_HOME}" + +apps/microtvm/reference-vm/rebuild-tvm.sh + +# Build poetry +cd apps/microtvm/reference-vm/arduino + +poetry env use 3.6 +# NOTE: due to https://github.com/python-poetry/poetry/issues/2247, download torch here. +poetry run pip3 install torch==1.4.0 torchvision==0.5.0 + +# importers +poetry install -E importer-onnx +poetry install -E importer-tflite + +echo "------------------------------[ TVM Message ]------------------------------" +echo "WARNING: running 'poetry lock', which could take several minutes (depending" +echo "on your network connection and the state of PyPI) as dependencies are" +echo "downloaded and cached for future use." +echo "------------------------------[ TVM Message ]------------------------------" +poetry lock -vvv +poetry install + +echo "export TVM_LIBRARY_PATH=\"$TVM_HOME\"/build-microtvm" >>~/.profile +echo "VENV_PATH=\$((cd \"$TVM_HOME\"/apps/microtvm/reference-vm/arduino && poetry env list --full-path) | sed -E 's/^(.*)[[:space:]]\(Activated\)\$/\1/g')" >>~/.profile +echo "source \$VENV_PATH/bin/activate" >>~/.profile diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index be9c5173de73..f32885433c2b 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -41,14 +41,43 @@ "vmware_desktop", ) -# List of microTVM platforms for testing. -ALL_MICROTVM_PLATFORMS = ( - "stm32f746xx_nucleo", - "stm32f746xx_disco", - "nrf5340dk", - "mps2_an521", +# List of supported electronics platforms. Each must correspond +# to a sub-directory of this directory. +ALL_PLATFORMS = ( + "arduino", + "zephyr", ) +# List of identifying strings for microTVM platforms for testing. +# Must match PLATFORMS as defined in tvm/tests/micro/[platform]/conftest.py +# TODO add a way to declare supported platforms to ProjectAPI +ALL_MICROTVM_PLATFORMS = { + "arduino": ( + "due", + "feathers2", + "metrom4", + "nano33ble", + "pybadge", + "spresense", + "teensy40", + "teensy41", + "wioterminal", + ), + "zephyr": ( + "stm32f746xx_nucleo", + "stm32f746xx_disco", + "nrf5340dk", + "mps2_an521", + ), +} + +# Extra scripts required to execute on provisioning +# in [platform]/base-box/base_box_provision.sh +EXTRA_SCRIPTS = { + "arduino": (), + "zephyr": ("docker/install/ubuntu_init_zephyr_project.sh",), +} + PACKER_FILE_NAME = "packer.json" @@ -176,12 +205,8 @@ def attach_vmware(uuid, vid_hex=None, pid_hex=None, serial=None): "vmware_desktop": attach_vmware, } -# Extra scripts required to execute on provisioning -# in zephyr/base-box/base_box_provision.sh -EXTRA_SCRIPTS = ("docker/install/ubuntu_init_zephyr_project.sh",) - -def generate_packer_config(file_path, providers): +def generate_packer_config(platform, file_path, providers): builders = [] provisioners = [] for provider_name in providers: @@ -199,9 +224,9 @@ def generate_packer_config(file_path, providers): ) repo_root = subprocess.check_output( - ["git", "rev-parse", "--show-toplevel"], cwd=os.path.dirname(__file__), encoding="utf-8" + ["git", "rev-parse", "--show-toplevel"], encoding="utf-8" ).strip() - for script in EXTRA_SCRIPTS: + for script in EXTRA_SCRIPTS[platform]: script_path = os.path.join(repo_root, script) filename = os.path.basename(script_path) provisioners.append({"type": "file", "source": script_path, "destination": f"~/{filename}"}) @@ -227,6 +252,7 @@ def generate_packer_config(file_path, providers): def build_command(args): generate_packer_config( + args.platform, os.path.join(THIS_DIR, args.platform, "base-box", PACKER_FILE_NAME), args.provider or ALL_PROVIDERS, ) @@ -311,7 +337,7 @@ def do_build_release_test_vm(release_test_dir, user_box_dir, base_box_dir, provi return True -def do_run_release_test(release_test_dir, provider_name, test_config, test_device_serial): +def do_run_release_test(release_test_dir, platform, provider_name, test_config, test_device_serial): with open( os.path.join(release_test_dir, ".vagrant", "machines", "default", provider_name, "id") ) as f: @@ -335,7 +361,7 @@ def _quote_cmd(cmd): + " && " + _quote_cmd( [ - "apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh", + f"apps/microtvm/reference-vm/{platform}/base-box/base_box_test.sh", test_config["microtvm_platform"], ] ) @@ -377,7 +403,11 @@ def test_command(args): release_test_dir, user_box_dir, base_box_dir, provider_name ) do_run_release_test( - release_test_dir, provider_name, microtvm_test_platform, args.test_device_serial + release_test_dir, + args.platform, + provider_name, + microtvm_test_platform, + args.test_device_serial, ) provider_passed[provider_name] = True @@ -439,26 +469,27 @@ def parse_args(): "--provider", choices=ALL_PROVIDERS, action="append", - help="Name of the provider or providers to act on; if not specified, act on all.", + required=True, + help="Name of the provider or providers to act on", ) - parser.add_argument( - "platform", - help="Name of the platform VM to act on. Must be a sub-directory of this directory.", - ) + # "test" has special options for different platforms, and "build", "release" might + # in the future, so we'll add the platform argument to each one individually. + platform_help_str = "Platform to use (e.g. Arduino, Zephyr)" + # Options for build subcommand parser_build = subparsers.add_parser("build", help="Build a base box.") parser_build.set_defaults(func=build_command) - parser_test = subparsers.add_parser("test", help="Test a base box before release.") - parser_test.set_defaults(func=test_command) - parser_release = subparsers.add_parser("release", help="Release base box to cloud.") - parser_release.set_defaults(func=release_command) - + parser_build.add_argument("platform", help=platform_help_str, choices=ALL_PLATFORMS) parser_build.add_argument( "--debug-packer", action="store_true", help=("Run packer in debug mode, and write log to the base-box directory."), ) + + # Options for test subcommand + parser_test = subparsers.add_parser("test", help="Test a base box before release.") + parser_test.set_defaults(func=test_command) parser_test.add_argument( "--skip-build", action="store_true", @@ -475,12 +506,21 @@ def parse_args(): "iSerial field from `lsusb -v` output." ), ) - parser_test.add_argument( - "--microtvm-platform", - choices=ALL_MICROTVM_PLATFORMS, - required=True, - help="MicroTVM platfrom used for testing.", - ) + parser_test_platform_subparsers = parser_test.add_subparsers(help=platform_help_str) + for platform in ALL_PLATFORMS: + platform_specific_parser = parser_test_platform_subparsers.add_parser(platform) + platform_specific_parser.set_defaults(platform=platform) + platform_specific_parser.add_argument( + "--microtvm-platform", + choices=ALL_MICROTVM_PLATFORMS[platform], + required=True, + help="MicroTVM platfrom used for testing.", + ) + + # Options for release subcommand + parser_release = subparsers.add_parser("release", help="Release base box to cloud.") + parser_release.set_defaults(func=release_command) + parser_release.add_argument("platform", help=platform_help_str, choices=ALL_PLATFORMS) parser_release.add_argument( "--release-version", required=True, @@ -494,7 +534,10 @@ def parse_args(): parser_release.add_argument( "--platform-version", required=True, - help="Platform version to release, in the form 'x.y'.", + help=( + "For Zephyr, the platform version to release, in the form 'x.y'. " + "For Arduino, the version of arduino-cli that's being used, in the form 'x.y.z'." + ), ) return parser.parse_args() diff --git a/apps/microtvm/reference-vm/rebuild-tvm.sh b/apps/microtvm/reference-vm/rebuild-tvm.sh new file mode 100755 index 000000000000..1cebcf7166af --- /dev/null +++ b/apps/microtvm/reference-vm/rebuild-tvm.sh @@ -0,0 +1,43 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +# Get number of cores for build +if [ -n "${TVM_CI_NUM_CORES}" ]; then + num_cores=${TVM_CI_NUM_CORES} +else + # default setup for Vagrantfile + num_cores=2 +fi + +cd "$(dirname $0)" +cd "$(git rev-parse --show-toplevel)" +BUILD_DIR=build-microtvm + +if [ ! -e "${BUILD_DIR}" ]; then + mkdir "${BUILD_DIR}" +fi +cp cmake/config.cmake "${BUILD_DIR}" +cd "${BUILD_DIR}" +sed -i 's/USE_MICRO OFF/USE_MICRO ON/' config.cmake +sed -i 's/USE_GRAPH_EXECUTOR_DEBUG OFF/USE_GRAPH_EXECUTOR_DEBUG ON/' config.cmake +sed -i 's/USE_LLVM OFF/USE_LLVM ON/' config.cmake +cmake .. +rm -rf standalone_crt host_standalone_crt # remove stale generated files +make -j${num_cores} diff --git a/apps/microtvm/reference-vm/zephyr/provision_setup.sh b/apps/microtvm/reference-vm/zephyr/provision_setup.sh index fcefc1176821..2ee2350b377a 100644 --- a/apps/microtvm/reference-vm/zephyr/provision_setup.sh +++ b/apps/microtvm/reference-vm/zephyr/provision_setup.sh @@ -22,7 +22,7 @@ set -ex # NOTE: TVM is presumed to be mounted already by Vagrantfile. cd "${TVM_HOME}" -apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +apps/microtvm/reference-vm/rebuild-tvm.sh # Build poetry cd apps/microtvm/reference-vm/zephyr diff --git a/docker/install/ubuntu_install_arduino.sh b/docker/install/ubuntu_install_arduino.sh index d5c4303f211b..c374850aa1df 100644 --- a/docker/install/ubuntu_install_arduino.sh +++ b/docker/install/ubuntu_install_arduino.sh @@ -26,7 +26,7 @@ apt-get install -y ca-certificates # Install arduino-cli latest version wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s -# Install supported cores from those URLS +# Install the cores we want to test on arduino-cli core install arduino:mbed_nano arduino-cli core install arduino:sam diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index ed7288ef00d4..0677292371ae 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -145,6 +145,8 @@ "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32", "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64", # microTVM Virtual Machines + "apps/microtvm/reference-vm/arduino/Vagrantfile", + "apps/microtvm/reference-vm/arduino/base-box/Vagrantfile.packer-template", "apps/microtvm/reference-vm/zephyr/Vagrantfile", "apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template", } diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index bcb2bddf2cab..aea1381a43f8 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -26,6 +26,7 @@ PLATFORMS = { "due": ("sam3x8e", "due"), "feathers2": ("esp32", "feathers2"), + "metrom4": ("atsamd51", "metrom4"), "nano33ble": ("nrf52840", "nano33ble"), "pybadge": ("atsamd51", "pybadge"), "spresense": ("cxd5602gg", "spresense"),