From 709b069897069ba1c76e7988d8c5a520889b42ec Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Wed, 1 Sep 2021 15:20:30 -0500 Subject: [PATCH 01/12] [Docker] Re-enabled automatic --tty flag when running bash. (#8861) PR8382 split apart the --interactive and --tty flags, but only --interactive was set if the user opens a bash session. This commit restores the previous behavior of running `docker/bash.sh IMAGE_NAME` of opening a bash session with both --interactive and --tty. --- docker/bash.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docker/bash.sh b/docker/bash.sh index 2a05abf4f2bc..372cfded8f89 100755 --- a/docker/bash.sh +++ b/docker/bash.sh @@ -38,7 +38,7 @@ set -euo pipefail function show_usage() { cat < [--] [COMMAND] @@ -95,7 +95,7 @@ DOCKER_IMAGE_NAME COMMAND The command to be run inside the docker container. If this is set - to "bash", both the --interactive and --net=host flags are set. + to "bash", the --interactive, --tty and --net=host flags are set. If no command is specified, defaults to "bash". If the command contains dash-prefixed arguments, the command should be preceded by -- to indicate arguments that are not intended for bash.sh. @@ -235,6 +235,7 @@ fi if [[ ${COMMAND[@]+"${COMMAND[@]}"} = bash ]]; then INTERACTIVE=true + TTY=true USE_NET_HOST=true fi From 8e27d6c18f3cde541ed065009faf2fc43902cdea Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 2 Sep 2021 05:49:12 +0800 Subject: [PATCH 02/12] fix error report on Store (#8895) --- python/tvm/script/parser.py | 2 +- .../unittest/test_tvmscript_error_report.py | 43 +++++++------------ 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 9acf21b6ba3a..60fc49678866 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -536,7 +536,7 @@ def transform_SubscriptAssign(self, node): if len(indexes) != 1: self.report_error( f"Store is only allowed with one index, but {len(indexes)} were provided.", - tvm.ir.Span.union([x.span for x in indexes]), + node.params[1].span, ) # Store return tvm.tir.Store( diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 7aeceeccfa89..70a2aea11293 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import pytest +import sys import tvm from tvm import tir from tvm.script import ty, from_source @@ -380,6 +383,17 @@ def test_match_buffer_shape_mismatch(): check_error(buffer_shape_mismatch, 7) +def high_dim_store() -> None: + with tir.block([], "root"): + B = tir.allocate([256], "float32", "global") + for i, j in tir.grid(16, 16): + B[i, j] = 1.0 # error: Store is only allowed with one index + + +def test_high_dim_store(): + check_error(high_dim_store, 5) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -404,31 +418,4 @@ def render(e): if __name__ == "__main__": - test_buffer_bind() - test_range_missing_args() - test_undefined_buffer() - test_unsupported_stmt() - test_unsupported_function_call() - test_missing_type_annotation() - test_invalid_expr_stmt() - test_invalid_for_function() - test_invalid_block_function() - test_return_not_allowed() - test_tir_assert() - test_no_body() - test_allocate_with_buffers() - test_inconsistent_binding() - test_invalid_block_axes() - test_miss_block_bind() - test_invalid_loop_var() - test_inconsistent_grid() - test_invalid_match_buffer_region() - test_duplicate_buffer() - test_duplicate_block_signature() - test_opaque_access_during_complete() - test_convert_slice_to_bufferload() - test_error_index_type() - test_error_index_with_stop_slice() - test_mismatch_args() - test_tvm_exception_catch() - test_match_buffer_shape_mismatch() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From eaf888c56827ebac1a43f01f93e6d4e6f8623a28 Mon Sep 17 00:00:00 2001 From: mvermeulen <5479696+mvermeulen@users.noreply.github.com> Date: Thu, 2 Sep 2021 01:21:21 -0500 Subject: [PATCH 03/12] [ROCm][TVMC] Add ROCm to the TVMC driver (#8896) * Add ROCm to list of RPC clients. * Add ROCm to list of TVMC devices. * Enable ROCm by adding session call. --- python/tvm/driver/tvmc/runner.py | 4 +++- python/tvm/rpc/client.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 489604d79cf4..5a15d228803d 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -50,7 +50,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl", "metal", "vulkan"], + choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -394,6 +394,8 @@ def run_module( dev = session.metal() elif device == "vulkan": dev = session.vulkan() + elif device == "rocm": + dev = session.rocm() else: assert device == "cpu" dev = session.cpu() diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index d8199c4c93a6..a9834391ed88 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -217,6 +217,10 @@ def metal(self, dev_id=0): """Construct Metal device.""" return self.device(8, dev_id) + def rocm(self, dev_id=0): + """Construct ROCm device.""" + return self.device(10, dev_id) + def ext_dev(self, dev_id=0): """Construct extension device.""" return self.device(12, dev_id) From 910b73e1f7abd3d6e95e905e62eaf3a607b4b5b9 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Thu, 2 Sep 2021 00:18:37 -0700 Subject: [PATCH 04/12] [Onnx] Support Negative Log Loss (#8872) * nll loss v1 * add converter * decode strings in byte form * decode variable length inputs * make shapes correct * unsqueeze * proper weight handling * simplify if statement * fix tests * add comment about tests * delete extra file * lint * so cool * jostle ci Co-authored-by: Andrew Zhao Luo --- python/tvm/relay/frontend/onnx.py | 59 ++++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 19 +------ 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f9b49204b85e..b59c924d8ca3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -22,6 +22,7 @@ import numpy as np import tvm +from tvm import relay from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple @@ -3454,6 +3455,62 @@ def _impl_v1(cls, inputs, attr, params): return vals +class NegativeLogLikelihoodLoss(OnnxOpConverter): + """Operator converter for random_uniform""" + + VALID_REDUCTIONS = {"mean", "sum", "none"} + + @classmethod + def _impl_v13(cls, inputs, attr, params): + ignore_index = attr.get("ignore_index", None) + reduction = attr.get("reduction", b"mean").decode("utf-8") + + if reduction not in cls.VALID_REDUCTIONS: + raise ValueError( + f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}" + ) + + input_tensor, target_tensor = inputs[0], inputs[1] + if len(inputs) == 3: + weight_tensor = inputs[2] + else: + channels = infer_shape(input_tensor)[1] + weight_tensor = relay.ones( + [channels], + dtype=input_tensor.type_annotation.dtype, + ) + + loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1)) + loss = relay.squeeze(loss, axis=[1]) + + expanded_target_tensor = relay.expand_dims(target_tensor, 0) + expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor) + flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor) + select_weights = relay.reshape_like(flattened_weights, loss) + loss *= select_weights + + if ignore_index is not None: + # "Ignore" values whose target is the ignore_index + mask_tensor = relay.equal( + target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype) + ) + mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8") + loss *= relay.cast_like(mask_tensor, loss) + + # This is not explained super clearly in the onnx spec, but masked values don't + # contribute toward the final value in reduction + select_weights *= relay.cast_like(mask_tensor, select_weights) + + weight_total = relay.sum(select_weights) + + if reduction == "mean": + return relay.sum(loss) / weight_total + if reduction == "sum": + return relay.sum(loss) + # Case reduction == 'none' + return loss + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -3636,6 +3693,8 @@ def _get_convert_map(opset): "ConvInteger": ConvInteger.get_converter(opset), # Random number generation. "RandomUniform": RandomUniform.get_converter(opset), + # Loss functions + "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a1d821686ed5..15214192148b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4735,41 +4735,24 @@ def verify_eyelike(indata): "test_momentum_multiple", "test_mvn", "test_nesterov_momentum", - "test_nllloss_NC", + # When unsqueeze is fully supported, remaining nllloss tests should work: "test_nllloss_NC_expanded", - "test_nllloss_NCd1", "test_nllloss_NCd1_expanded", - "test_nllloss_NCd1_ii", "test_nllloss_NCd1_ii_expanded", - "test_nllloss_NCd1_mean_weight_negative_ii", "test_nllloss_NCd1_mean_weight_negative_ii_expanded", - "test_nllloss_NCd1_weight", "test_nllloss_NCd1_weight_expanded", - "test_nllloss_NCd1_weight_ii", "test_nllloss_NCd1_weight_ii_expanded", - "test_nllloss_NCd1d2", "test_nllloss_NCd1d2_expanded", - "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", - "test_nllloss_NCd1d2_reduction_mean", "test_nllloss_NCd1d2_reduction_mean_expanded", - "test_nllloss_NCd1d2_reduction_sum", "test_nllloss_NCd1d2_reduction_sum_expanded", - "test_nllloss_NCd1d2_with_weight", "test_nllloss_NCd1d2_with_weight_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_mean", "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_sum", "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", - "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", - "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", - "test_nllloss_NCd1d2d3_sum_weight_high_ii", "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", - "test_nllloss_NCd1d2d3d4d5_mean_weight", "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", - "test_nllloss_NCd1d2d3d4d5_none_no_weight", "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", "test_pow_types_float", "test_pow_types_float32_int32", From b5662125a5a1fec8f99cb6d2a8e5f2557f93139a Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 2 Sep 2021 08:50:24 +0100 Subject: [PATCH 05/12] Move to new style issue template system (#8898) * Move to new style issue template system This lets us have a template for each type of issue, notably this includes a template for requesting a CI image update. * Fix checkboxes * Codify the use of Discourse rather than raising issues * Change CI to CI Image and introduce CI Issue template * Fix poor english * Add more tags where we have them * CI Issue -> CI Problem --- .github/ISSUE_TEMPLATE/bug-report.md | 27 +++++++++++++++++ .github/ISSUE_TEMPLATE/ci-image.md | 29 +++++++++++++++++++ .github/ISSUE_TEMPLATE/ci-problem.md | 22 ++++++++++++++ .github/ISSUE_TEMPLATE/config.yml | 5 ++++ .../feature-tracking.md} | 15 +++++++--- 5 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/bug-report.md create mode 100644 .github/ISSUE_TEMPLATE/ci-image.md create mode 100644 .github/ISSUE_TEMPLATE/ci-problem.md create mode 100644 .github/ISSUE_TEMPLATE/config.yml rename .github/{ISSUE_TEMPLATE.md => ISSUE_TEMPLATE/feature-tracking.md} (61%) diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 000000000000..532f5f408b35 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,27 @@ +--- +name: "\U0001F41B Bug report" +about: To help the developer act on the issues, please include a description of your environment, preferably a minimum script to reproduce the problem. +title: "[Bug] " +labels: "type: bug" + +--- + +Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first :smile_cat: + +Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed. + +### Expected behavior + +What you were expecting + +### Actual behavior + +What actually happened + +### Environment + +Any environment details, such as: Operating System, TVM version, etc + +### Steps to reproduce + +Preferably a minimal script to cause the issue to occur. diff --git a/.github/ISSUE_TEMPLATE/ci-image.md b/.github/ISSUE_TEMPLATE/ci-image.md new file mode 100644 index 000000000000..d5abd8f20f80 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/ci-image.md @@ -0,0 +1,29 @@ +--- +name: "\U0001F40B Update CI Docker Image" +about: Provide information on CI Docker Images requiring updates +title: "[CI Image] " + +--- + +Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first :smile_cat: + +Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed. + +- [ ] S0. Reason: For example, a blocked PR or a feature issue + +- [ ] S1. Tag of nightly build: TAG. Docker hub: https://hub.docker.com/layers/tlcpackstaging/ci_cpu/... + +- [ ] S2. The nightly is built on TVM commit: TVM_COMMIT. Detailed info can be found here: https://ci.tlcpack.ai/blue/organizations/jenkins/docker-images-ci%2Fdaily-docker-image-rebuild/detail/daily-docker-image-rebuild/.... + +- [ ] S3. Testing the nightly image on ci-docker-staging: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/ci-docker-staging/... + +- [ ] S4. Retag TAG to VERSION: +``` +docker pull tlcpackstaging/IMAGE_NAME:TAG +docker tag tlcpackstaging/IMAGE_NAME:TAG tlcpack/IMAGE_NAME:VERSION +docker push tlcpack/IMAGE_NAME:VERSION +``` + +- [ ] S5. Check if the new tag is really there: https://hub.docker.com/u/tlcpack + +- [ ] S6. Submit a PR updating the IMAGE_NAME version on Jenkins diff --git a/.github/ISSUE_TEMPLATE/ci-problem.md b/.github/ISSUE_TEMPLATE/ci-problem.md new file mode 100644 index 000000000000..f46a42f42cf5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/ci-problem.md @@ -0,0 +1,22 @@ +--- +name: "\U0000274C CI Problem" +about: To help the developers act on these problems, please give us as many details of the CI failure as possible. +title: "[CI Problem] " + +--- + +Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first :smile_cat: + +Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed. + +### Branch/PR Failing + +Please provide a link to the PR that has failed to run CI. + +### Jenkins Link + +Provide a link to the specific run that has failed. + +### Flakiness + +Have you seen this multiple times in this branch or in other branches? \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000000..ef55b6355308 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false # default: true +contact_links: +- name: 💬 Discourse + url: https://discuss.tvm.apache.org/ + about: Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first 😺 diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/feature-tracking.md similarity index 61% rename from .github/ISSUE_TEMPLATE.md rename to .github/ISSUE_TEMPLATE/feature-tracking.md index 0e2a130d489e..8dd0648f69d4 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE/feature-tracking.md @@ -1,7 +1,14 @@ -Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first :) +--- +name: "\U0001F527 Feature Tracking" +about: List clear, small actionable items so we can track the progress of the change. +title: "[Tracking Issue] " +labels: type:rfc-tracking -Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed. +--- + +Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first :smile_cat: -For bug reports, to help the developer act on the issues, please include a description of your environment, preferably a minimum script to reproduce the problem. +Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed. -For feature proposals, list clear, small actionable items so we can track the progress of the change. +### This issue is to track progress for FEATURE NAME +- [ ] P1. Title of this piece of the feature (PR link if available) From 8fbd21d82e516bfd9de04710fd3c4da2e68f6b1d Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Thu, 2 Sep 2021 03:10:09 -0500 Subject: [PATCH 06/12] [Vulkan][Topi] Parametrizing additional topi tests, marking vulkan failures (#8904) * [Pytest] Fixed TestTargetAutoParametrization in cases where LLVM is disabled. * [UnitTests][Vulkan] Improved robustness of test_tir_intrin::test_clz Previously, would fail during build since support for Int64 primitives wasn't declared in the `"vulkan"` target. Now, uses `"vulkan -from_device=0"` target and marks the test as xfail if the current target doesn't support Int64. * [UnitTest][Topi] Parametrized several unit tests, identify vulkan failures - Parametrized topi modules - test_topi_conv1d_transpose_ncw.py - test_topi_conv2d_nhwc.py - test_topi_correlation.py - test_topi_loss.py - test_topi_math.py - test_topi_reduce.py - test_topi_softmax.py - test_topi_sort.py - test_topi_unique.py - test_topi_vision.py - Unit Tests fixed - `test_topi_loss::test_nll_loss`, failure due to `supports_float64` not being passed from the target to the codegen. - Known Vulkan failures (tracked in https://github.com/apache/tvm/issues/8903) - test_topi_math.py::test_ewise, ["tan", "erf", "isnan", "isfinite", "isinf"] Unimplemented CallNode operations - test_topi_reduce.py::test_reduce_map, ["sum", "any", "all"] Fails during codegen, unexpected size of data type. - test_topi_vision.py::test_proposal Marked test_proposal as xfail on vulkan, currently has a type error between bool/int8. - test_topi_conv1d_transpose_ncw.py::test_conv1d_transpose_ncw Incorrect numeric output, a few elements outside of allowed tolerance, only occurs on vulkan backend. - test_softmax.py::test_softmax Marked float64 operations as xfail in vulkan, because GLSL.std.450 only supports 16/32-bit floats. --- src/target/spirv/spirv_support.cc | 3 + .../python/test_topi_conv1d_transpose_ncw.py | 161 ++--- .../topi/python/test_topi_conv2d_nhwc.py | 92 ++- .../topi/python/test_topi_correlation.py | 171 ++--- tests/python/topi/python/test_topi_loss.py | 26 +- tests/python/topi/python/test_topi_math.py | 413 ++++++------ tests/python/topi/python/test_topi_reduce.py | 229 +++---- tests/python/topi/python/test_topi_softmax.py | 116 ++-- tests/python/topi/python/test_topi_sort.py | 156 ++--- tests/python/topi/python/test_topi_unique.py | 128 ++-- tests/python/topi/python/test_topi_vision.py | 603 +++++++++--------- tests/python/unittest/test_tir_intrin.py | 60 +- .../unittest/test_tvm_testing_features.py | 5 +- 13 files changed, 1035 insertions(+), 1128 deletions(-) diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 0f1207f3e9a8..1ef56198df7f 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -72,6 +72,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { if (target->GetAttr("supports_float16")) { supports_float16 = target->GetAttr("supports_float16").value(); } + if (target->GetAttr("supports_float64")) { + supports_float64 = target->GetAttr("supports_float64").value(); + } if (target->GetAttr("supports_int8")) { supports_int8 = target->GetAttr("supports_int8").value(); } diff --git a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py index 81d3b3fd7f3f..93cfecf4239d 100644 --- a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py +++ b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py @@ -15,15 +15,18 @@ # specific language governing permissions and limitations # under the License. """Test code for transposed convolution.""" -import numpy as np + import itertools +import os + +import numpy as np + import tvm -from tvm import te -from tvm import topi +import tvm.testing import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize + +from tvm import te, topi from tvm.topi.utils import get_const_tuple -import tvm.testing _conv1d_transpose_ncw_implement = { "generic": (topi.nn.conv1d_transpose_ncw, topi.generic.schedule_conv1d_transpose_ncw), @@ -31,74 +34,88 @@ } -def verify_conv1d_transpose_ncw( - batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding +( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + output_padding, +) = tvm.testing.parameters( + (1, 3, 224, 32, 5, 1, 0, (0,)), + (1, 3, 224, 32, 7, 1, 2, (0,)), + (1, 3, 224, 32, 5, 2, 1, (0,)), + (1, 3, 224, 32, 5, 2, 1, (1,)), + (1, 3, 224, 32, 5, 2, 0, (0,)), + (1, 32, 32, 128, 5, 1, 0, (0,)), + (1, 32, 32, 128, 5, 2, 1, (0,)), + (1, 1, 1024, 1, 512, 1, 256, (0,)), + (1, 1, 1024, 1, 512, 2, 256, (0,)), + (1, 1, 1024, 1, 512, 5, 256, (0,)), + (1, 1, 1024, 1, 512, 5, 256, (3,)), + (1, 2, 1024, 1, 128, 128, 0, (0,)), + (1, 1, 1024, 2, 128, 128, 0, (0,)), + (1, 1, 1024, 2, 2, 2, 0, (0,)), + (1, 1, 10, 1, 5, 1, (0, 3), (0,)), + (1, 1, 10, 1, 5, 1, (1, 3), (0,)), + (1, 1, 10, 1, 5, 1, (2, 3), (0,)), + (1, 257, 128, 1, 512, 128, 256, (0,)), +) + +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding +): + dtype = "float32" + a_shape = (batch, in_channel, in_size) + w_shape = (in_channel, num_filter, kernel) + + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = tvm.topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np + + +@tvm.testing.known_failing_targets("vulkan") +def test_conv1d_transpose_ncw( + target, + dev, + ref_data, + dtype, + stride, + padding, + output_padding, ): - in_width = in_size - A = te.placeholder((batch, in_channel, in_width), name="A") - W = te.placeholder((in_channel, num_filter, kernel), name="W") - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv1d_transpose.verify_conv1d_transpose_ncw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = tvm.topi.testing.conv1d_transpose_ncw_python( - a_np, w_np, stride, padding, output_padding - ) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_target(target, dev): - dev = tvm.device(target, 0) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv1d_transpose_ncw_implement) - B = fcompute(A, W, stride, padding, A.dtype, output_padding) - C = topi.nn.relu(B) - s1 = fschedule([B]) - s2 = fschedule([C]) - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) - - func1 = tvm.build(s1, [A, W, B], target) - func2 = tvm.build(s2, [A, W, C], target) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - -@tvm.testing.uses_gpu -def test_conv1d_transpose_ncw(): - verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0, (0,)) - verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2, (0,)) - verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (0,)) - verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (1,)) - verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0, (0,)) - verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0, (0,)) - verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,)) - verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,)) - verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,)) - verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,)) - verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,)) - verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,)) - verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,)) + + a_np, w_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv1d_transpose_ncw_implement) + B = fcompute(A, W, stride, padding, A.dtype, output_padding) + C = topi.nn.relu(B) + s1 = fschedule([B]) + s2 = fschedule([C]) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + + func1 = tvm.build(s1, [A, W, B], target) + func2 = tvm.build(s2, [A, W, C], target) + func1(a, w, b) + func2(a, w, c) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) if __name__ == "__main__": - test_conv1d_transpose_ncw() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index eb4c5a343b58..96359860f569 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -27,8 +27,8 @@ _conv2d_nhwc_implement = { - "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc), - "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc), + "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc), + "gpu": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc), "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc), "arm_cpu": ( topi.arm_cpu.conv2d_nhwc_spatial_pack, @@ -45,61 +45,55 @@ "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), } +dtype = tvm.testing.parameter("float32") -def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_height = in_width = in_size - - A = te.placeholder((batch, in_height, in_width, in_channel), name="A") - W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W") +batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( + (1, 256, 32, 256, 3, 1, "SAME", 1), + (4, 128, 16, 128, 5, 2, "SAME", 1), + (4, 128, 16, 256, 5, 2, "SAME", 1), + (1, 256, 32, 256, 3, 1, "VALID", 1), + (1, 256, 32, 256, 3, 1, "VALID", 1), + (4, 128, 16, 128, 5, 2, "VALID", 1), + (4, 128, 16, 256, 5, 2, "VALID", 1), + (1, 128, 16, 256, 3, 2, (0, 0, 1, 1), 1), + (1, 128, 16, 256, 3, 2, (1, 1, 2, 2), 1), + (1, 128, 16, 128, 5, 2, (3, 3, 2, 2), 1), + (1, 128, 16, 256, 3, 2, (0, 1, 2, 3), 1), + (1, 256, 32, 256, 3, 1, "SAME", 2), + (1, 256, 32, 256, 3, 1, (1, 1, 2, 2), 2), +) - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - @memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc.v2") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) - return a_np, w_np, b_np +@tvm.testing.fixture(cache_return_value=True) +def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): + in_height = in_width = in_size + a_shape = (batch, in_height, in_width, in_channel) + w_shape = (kernel, kernel, in_channel, num_filter) - a_np, w_np, b_np = get_ref_data() + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np - def check_device(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement) - B = fcompute(A, W, stride, padding, dilation, dtype) - s = fschedule([B]) - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - func = tvm.build(s, [A, W, B], target) - func(a, w, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) - for target, dev in tvm.testing.enabled_targets(): - check_device(target, dev) +def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation): + a_np, w_np, b_np = ref_data + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) -@tvm.testing.uses_gpu -def test_conv2d_nhwc(): - verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME") - verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME") - verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "SAME") - verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") - verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") - verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID") - verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID") - verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 0, 1, 1)) - verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (1, 1, 2, 2)) - verify_conv2d_nhwc(1, 128, 16, 128, 5, 2, (3, 3, 2, 2)) - verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 1, 2, 3)) - # dilation = 2 - verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2) - verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, (1, 1, 2, 2), dilation=2) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement) + B = fcompute(A, W, stride, padding, dilation, dtype) + s = fschedule([B]) + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + func = tvm.build(s, [A, W, B], target) + func(a, w, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) if __name__ == "__main__": - test_conv2d_nhwc() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_correlation.py b/tests/python/topi/python/test_topi_correlation.py index e6323065d9be..3dff54dfa694 100644 --- a/tests/python/topi/python/test_topi_correlation.py +++ b/tests/python/topi/python/test_topi_correlation.py @@ -15,125 +15,82 @@ # specific language governing permissions and limitations # under the License """test of correlation operator in NCHW layout""" +import sys + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi -import tvm.testing import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize -from tvm.topi.utils import get_const_tuple + +from tvm import autotvm, te, topi _correlation_implement = { "generic": (topi.nn.correlation_nchw, topi.generic.schedule_correlation_nchw), - "cuda": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw), + "gpu": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw), } +( + data_shape, + kernel_size, + max_displacement, + stride1, + stride2, + pad_size, + is_multiply, +) = tvm.testing.parameters( + ((1, 3, 10, 10), 1, 4, 1, 1, 4, True), + ((1, 3, 10, 10), 1, 5, 1, 1, 5, True), + ((5, 1, 4, 4), 3, 1, 2, 1, 2, True), + ((5, 1, 6, 4), 3, 1, 2, 2, 2, False), + ((5, 1, 11, 11), 5, 1, 1, 1, 2, False), +) + +dtype = tvm.testing.parameter("float32") -def verify_correlation_nchw( - data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + dtype, data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply ): - print( - "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" - % ( - data_shape[0], - data_shape[1], - data_shape[2], - data_shape[3], - kernel_size, - max_displacement, - stride1, - stride2, - pad_size, - is_multiply, - ) + a_np = np.random.uniform(size=data_shape).astype(dtype) + b_np = np.random.uniform(size=data_shape).astype(dtype) + c_np = tvm.topi.testing.correlation_nchw_python( + a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply ) + return a_np, b_np, c_np - A = te.placeholder(data_shape, name="data1") - B = te.placeholder(data_shape, name="data2") - dtype = A.dtype - - @memoize("topi.tests.test_topi_correlation_nchw.verify_correlation_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=data_shape).astype(dtype) - b_np = np.random.uniform(size=data_shape).astype(dtype) - c_np = tvm.topi.testing.correlation_nchw_python( - a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply - ) - return a_np, b_np, c_np - - a_np, b_np, c_np = get_ref_data() - - def check_device(target, dev): - print("Running on target: %s" % target) - fcompute, fschedule = tvm.topi.testing.dispatch(target, _correlation_implement) - with tvm.target.Target(target): - C = fcompute( - A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply - ) - s = fschedule([C]) - - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.empty(c_np.shape, dtype=dtype, device=dev) - - func = tvm.build(s, [A, B, C], target) - func(a, b, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_device(target, dev) - - -@tvm.testing.uses_gpu -def test_correlation_nchw(): - verify_correlation_nchw( - (1, 3, 10, 10), - kernel_size=1, - max_displacement=4, - stride1=1, - stride2=1, - pad_size=4, - is_multiply=True, - ) - verify_correlation_nchw( - (1, 3, 10, 10), - kernel_size=1, - max_displacement=5, - stride1=1, - stride2=1, - pad_size=5, - is_multiply=True, - ) - verify_correlation_nchw( - (5, 1, 4, 4), - kernel_size=3, - max_displacement=1, - stride1=2, - stride2=1, - pad_size=2, - is_multiply=True, - ) - verify_correlation_nchw( - (5, 1, 6, 4), - kernel_size=3, - max_displacement=1, - stride1=2, - stride2=2, - pad_size=2, - is_multiply=False, - ) - verify_correlation_nchw( - (5, 1, 11, 11), - kernel_size=5, - max_displacement=1, - stride1=1, - stride2=1, - pad_size=2, - is_multiply=False, - ) + +def test_correlation_nchw( + target, + dev, + ref_data, + dtype, + kernel_size, + max_displacement, + stride1, + stride2, + pad_size, + is_multiply, +): + a_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="data1", dtype=dtype) + B = te.placeholder(b_np.shape, name="data2", dtype=dtype) + + fcompute, fschedule = tvm.topi.testing.dispatch(target, _correlation_implement) + with tvm.target.Target(target): + C = fcompute(A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply) + s = fschedule([C]) + + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.empty(c_np.shape, dtype=dtype, device=dev) + + func = tvm.build(s, [A, B, C], target) + func(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) if __name__ == "__main__": - test_correlation_nchw() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index bb7655b192f5..c1b61e5b49cd 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -25,9 +25,17 @@ import tvm.testing -def verify_nll_loss( - dev, target, prediction_shape, reduction="mean", ignore_index=-100, dtype="float32" -): +prediction_shape, reduction, ignore_index, dtype = tvm.testing.parameters( + ((10, 5), "mean", -100, "float32"), + ((10, 5, 2, 2), "mean", -100, "float32"), + ((10, 5), "sum", -100, "float32"), + ((10, 5), "none", -100, "float32"), + ((10, 5), "mean", 3, "float32"), + ((10, 5), "mean", -100, "float64"), +) + + +def test_nll_loss(target, dev, prediction_shape, reduction, ignore_index, dtype): C = prediction_shape[1] target_shape = prediction_shape[:1] + prediction_shape[2:] predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) @@ -56,15 +64,5 @@ def verify_nll_loss( tvm.testing.assert_allclose(out_topi, out_npy, rtol=1e-4, atol=1e-5) -@tvm.testing.parametrize_targets -def test_nll_loss(dev, target): - verify_nll_loss(dev, target, (10, 5)) - verify_nll_loss(dev, target, (10, 5, 2, 2)) - verify_nll_loss(dev, target, (10, 5), reduction="sum") - verify_nll_loss(dev, target, (10, 5), reduction="none") - verify_nll_loss(dev, target, (10, 5), ignore_index=3) - verify_nll_loss(dev, target, (10, 5), dtype="float64") - - if __name__ == "__main__": - test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm")) + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_math.py b/tests/python/topi/python/test_topi_math.py index c7f80033bdf3..5ee049fa379a 100644 --- a/tests/python/topi/python/test_topi_math.py +++ b/tests/python/topi/python/test_topi_math.py @@ -14,14 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import sys + import numpy as np +import pytest import scipy from scipy import special + import tvm -from tvm import te -from tvm import topi import tvm.testing import tvm.topi.testing + +from tvm import te, topi from tvm.topi import utils @@ -31,211 +36,205 @@ def test_util(): assert utils.get_const_tuple((x, x)) == (100, 100) -@tvm.testing.uses_gpu -def test_ewise(): - def test_apply( - func, - name, - f_numpy, - low, - high, - shape=(20, 3), - dtype="float32", - check_round=False, - skip_name_check=False, - ): - m = te.var("m") - l = te.var("l") - A = te.placeholder((m, l), dtype=dtype, name="A") - - B = func(A) - assert tuple(B.shape) == tuple(A.shape) - if not skip_name_check: - assert B.op.body[0].op.name == "tir." + name - a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 - # avoid round check too close to boundary - if check_round: - a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4 - b_np = f_numpy(a_np) - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - foo = tvm.build(s, [A, B], target, name=name) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) - foo(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - def test_isnan( - low, - high, - shape=(20, 3), - dtype="float32", - check_round=False, - skip_name_check=False, - ): - m = te.var("m") - l = te.var("l") - A = te.placeholder((m, l), dtype=dtype, name="A") - - B = topi.isnan(A) - assert tuple(B.shape) == tuple(A.shape) - if not skip_name_check: - assert B.op.body[0].op.name == "tir.isnan" - a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 - a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan - # avoid round check too close to boundary - if check_round: - a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5 - b_np = np.isnan(a_np) - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - foo = tvm.build(s, [A, B], target, name="isnan") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) - foo(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - def test_infiniteness_ops(topi_op, ref_op, name): - for dtype in ["float32", "float64", "int32", "int16"]: - m = te.var("m") - l = te.var("l") - A = te.placeholder((m, l), dtype=dtype, name="A") - B = topi_op(A) - assert tuple(B.shape) == tuple(A.shape) - - a_np = np.random.uniform(size=(8, 8)).astype(A.dtype) * 10 - if dtype.startswith("float"): - a_np.ravel()[ - np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False) - ] = np.infty - a_np.ravel()[ - np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False) - ] = np.nan - b_np = ref_op(a_np) - - def check_target(target, dev): - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - foo = tvm.build(s, [A, B], target, name=name) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) - foo(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - test_apply(topi.floor, "floor", np.floor, -100, 100) - test_apply(topi.ceil, "ceil", np.ceil, -100, 100) - test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True) - test_apply(topi.trunc, "trunc", np.trunc, -100, 100) - test_apply(topi.abs, "fabs", np.abs, -100, 100) - test_apply(topi.round, "round", np.round, -100, 100, check_round=True) - test_apply(topi.exp, "exp", np.exp, -1, 1) - test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128)) - test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128), dtype="float64") - test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1) - test_apply(topi.log, "log", np.log, 0, 100) - test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) - test_apply( - topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True - ) - test_apply(topi.cos, "cos", np.cos, -2.0 * np.pi, 2.0 * np.pi) - test_apply(topi.tan, "tan", np.tan, -2.0 * np.pi, 2.0 * np.pi, dtype="float32") - test_apply(topi.tan, "tan", np.tan, -2.0 * np.pi, 2.0 * np.pi, dtype="float64") - test_apply(topi.sin, "sin", np.sin, -2.0 * np.pi, 2.0 * np.pi) - test_apply(topi.erf, "erf", scipy.special.erf, -0.1, 0.1, dtype="float32") - test_isnan(-100, 100) - test_infiniteness_ops(topi.isfinite, np.isfinite, "isifinite") - test_infiniteness_ops(topi.isinf, np.isinf, "isinf") - - -@tvm.testing.uses_gpu -def test_cast(): - def verify(from_dtype, to_dtype, low=-100, high=100): - shape = (5, 4) - A = te.placeholder(shape, dtype=from_dtype, name="A") - B = topi.cast(A, to_dtype) - - if from_dtype == "bool": - a_np = np.random.choice([True, False], size=shape) - else: - a_np = np.random.uniform(low, high, size=shape).astype(from_dtype) - if to_dtype == "bool": - a_np = a_np - a_np[2, 3] - b_np = a_np.astype(to_dtype) - - for target, dev in tvm.testing.enabled_targets(): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - foo = tvm.build(s, [A, B], target) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.empty(shape=shape, dtype=to_dtype, device=dev) - foo(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np) - - verify("int32", "float32") - verify("int32", "float64") - verify("int32", "bool") - verify("float32", "int32") - verify("float32", "float64") - verify("float32", "bool") - verify("bool", "float32") - verify("bool", "int32") - - -def test_fastmath(): - def test_apply(func, name, f_numpy, low, high, step, dtype="float32"): - a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1)) - b_np = f_numpy(a_np) - A = te.placeholder(a_np.shape, dtype=dtype, name="A") - B = func(A) - assert tuple(B.shape) == tuple(A.shape) - - def check_target(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - with tvm.target.Target(target): - s = topi.generic.schedule_injective(B) - func = tvm.build(s, [A, B], target, name=name) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) - func(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5, atol=1e-5) - - check_target("llvm") - check_target("llvm -device=arm-cpu") - - test_apply(topi.fast_exp, "fast_exp", np.exp, low=-88, high=88, step=0.01) - test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01) - test_apply(topi.fast_tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01) - test_apply( - topi.nn.fast_softmax, - "fast_softmax", - tvm.topi.testing.softmax_python, - low=-10, - high=10, - step=0.01, - ) +ewise_operations = { + "floor": {"topi": topi.floor, "ref": np.floor, "input_range": (-100, 100)}, + "ceil": {"topi": topi.ceil, "ref": np.ceil, "input_range": (-100, 100)}, + "sign": { + "topi": topi.sign, + "ref": np.sign, + "input_range": (-100, 100), + "skip_name_check": True, + }, + "trunc": {"topi": topi.trunc, "ref": np.trunc, "input_range": (-100, 100)}, + "fabs": {"topi": topi.abs, "ref": np.fabs, "input_range": (-100, 100)}, + "round": {"topi": topi.round, "ref": np.round, "input_range": (-100, 100), "check_round": True}, + "exp": {"topi": topi.exp, "ref": np.exp, "input_range": (-1, 1)}, + "tanh": { + "topi": topi.tanh, + "ref": np.tanh, + "input_range": (-10, 10), + "shape": (128, 128), + "dtype": ["float32", "float64"], + }, + "sigmoid": { + "topi": topi.sigmoid, + "ref": lambda x: 1 / (1 + np.exp(-x)), + "input_range": (-1, 1), + }, + "log": {"topi": topi.log, "ref": np.log, "input_range": (0, 100)}, + "sqrt": {"topi": topi.sqrt, "ref": np.sqrt, "input_range": (0, 100)}, + "rsqrt": { + "topi": topi.rsqrt, + "ref": lambda x: np.ones_like(x) / np.sqrt(x), + "input_range": (0, 100), + "skip_name_check": True, + }, + "cos": {"topi": topi.cos, "ref": np.cos, "input_range": (-2.0 * np.pi, 2.0 * np.pi)}, + "tan": { + "topi": topi.tan, + "ref": np.tan, + "input_range": (-2.0 * np.pi, 2.0 * np.pi), + "dtypes": ["float32", "float64"], + }, + "sin": {"topi": topi.sin, "ref": np.sin, "input_range": (-2.0 * np.pi, 2.0 * np.pi)}, + "erf": {"topi": topi.erf, "ref": scipy.special.erf, "input_range": (-0.1, 0.1)}, + "isnan": { + "topi": topi.isnan, + "ref": np.isnan, + "input_range": (-1, 1), + "replace_with_nan": True, + }, + "isfinite": { + "topi": topi.isfinite, + "ref": np.isfinite, + "input_range": (0, 1), + "shape": (8, 8), + "skip_name_check": True, + "replace_with_nan": True, + "replace_with_inf": True, + "dtypes": ["float32", "float64", "int32", "int16"], + }, + "isinf": { + "topi": topi.isinf, + "ref": np.isinf, + "input_range": (0, 1), + "shape": (8, 8), + "skip_name_check": True, + "replace_with_nan": True, + "replace_with_inf": True, + "dtypes": ["float32", "float64", "int32", "int16"], + }, + "fast_exp": { + "topi": topi.fast_exp, + "ref": np.exp, + "skip_name_check": True, + "input_range": (-88, 88), + "step": 0.01, + }, + "fast_erf": { + "topi": topi.fast_erf, + "ref": scipy.special.erf, + "skip_name_check": True, + "input_range": (-10, 10), + "step": 0.01, + }, + "fast_erf": { + "topi": topi.fast_tanh, + "ref": np.tanh, + "skip_name_check": True, + "input_range": (-10, 10), + "step": 0.01, + }, +} + +topi_name, dtype = tvm.testing.parameters( + *[ + (name, dtype) + for name, config in ewise_operations.items() + for dtype in config.get("dtypes", ["float32"]) + ] +) + + +@tvm.testing.fixture(cache_return_value=True) +def ewise_ref_data(topi_name, dtype): + config = ewise_operations[topi_name] + + input_range = config["input_range"] + shape = config.get("shape", (20, 3)) + + a_np = np.random.uniform(*input_range, size=shape).astype(dtype) + + if dtype.startswith("float"): + if config.get("replace_with_nan", False): + a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan + if config.get("replace_with_inf", False): + a_np.ravel()[ + np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False) + ] = np.infty + + # avoid round check too close to boundary + if topi_name == "round": + a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4 + + b_np = config["ref"](a_np) + + return a_np, b_np + + +def test_ewise(target, dev, topi_name, dtype, ewise_ref_data): + target = tvm.target.Target(target) + if target.kind.name == "vulkan" and topi_name in ["tan", "erf", "isnan", "isfinite", "isinf"]: + pytest.xfail(f"Vulkan runtime doesn't support {topi_name} yet") + + topi_op = ewise_operations[topi_name]["topi"] + skip_name_check = ewise_operations[topi_name].get("skip_name_check", False) + + m = te.var("m") + l = te.var("l") + A = te.placeholder((m, l), dtype=dtype, name="A") + + B = topi_op(A) + assert tuple(B.shape) == tuple(A.shape) + if not skip_name_check: + assert B.op.body[0].op.name == "tir." + topi_name + + a_np, b_np = ewise_ref_data + + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(B) + foo = tvm.build(s, [A, B], target, name=topi_name) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros_like(b_np), dev) + foo(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5, atol=1e-5) + + +from_dtype, to_dtype = tvm.testing.parameters( + ("int32", "float32"), + ("int32", "float64"), + ("int32", "bool"), + ("float32", "int32"), + ("float32", "float64"), + ("float32", "bool"), + ("bool", "float32"), + ("bool", "int32"), +) + + +@tvm.testing.fixture(cache_return_value=True) +def cast_ref_data(from_dtype, to_dtype): + shape = (5, 4) + input_range = (-100, 100) + + if from_dtype == "bool": + a_np = np.random.choice([True, False], size=shape) + else: + a_np = np.random.uniform(*input_range, size=shape).astype(from_dtype) + + if to_dtype == "bool": + a_np = a_np - a_np[2, 3] + b_np = a_np.astype(to_dtype) + + return a_np, b_np + + +def test_cast(target, dev, cast_ref_data, from_dtype, to_dtype): + m = te.var("m") + l = te.var("l") + A = te.placeholder((m, l), dtype=from_dtype, name="A") + B = topi.cast(A, to_dtype) + + a_np, b_np = cast_ref_data + + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(B) + foo = tvm.build(s, [A, B], target) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.empty(b_np.shape, dtype=to_dtype, device=dev) + foo(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np) if __name__ == "__main__": - test_util() - test_ewise() - test_cast() - test_fastmath() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index 07656032f878..23d762c5002a 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -16,23 +16,76 @@ # under the License. """Test code for reduce.""" import os +import sys + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import topi import tvm.testing import tvm.topi.testing +from tvm import te, topi + +in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters( + ((32,), 0, False, "argmax", "float32"), + ((128, 24, 128, 24), (1, 2, 3), True, "sum", "float32"), + ((2, 3), None, True, "all", "bool"), + ((128, 24 * 128 * 24), (1,), False, "max", "float32"), + ((32, 128, 24), None, True, "sum", "float32"), + ((32, 128, 24), None, True, "all", "bool"), + ((128, 24, 128, 24), (0, 2), False, "min", "float32"), + ((32, 128), 1, True, "argmax", "float32"), + ((32, 24, 32, 24), 2, False, "argmin", "float32"), + ((31, 21, 15), None, True, "argmax", "float32"), + ((31, 21, 15), None, False, "sum", "float32"), + ((128, 24, 128, 24), (1, 2, 3), True, "sum", "float64"), + ((2, 3), None, True, "any", "bool"), + ((32, 128, 24), None, True, "any", "bool"), + ((1, 4, 7), 1, True, "any", "bool"), + ((128, 24, 128, 24), 2, False, "any", "bool"), +) + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data(in_shape, axis, keepdims, reduce_type, dtype): + # Test + if dtype == "bool": + in_npy_map = in_npy = np.random.choice([True, False], size=in_shape) + else: + in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) + in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) + + if reduce_type == "sum": + out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) + elif reduce_type == "all" and dtype == "bool": + out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) + elif reduce_type == "any" and dtype == "bool": + out_npy = in_npy_map.any(axis=axis, keepdims=keepdims) + elif reduce_type == "max": + out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) + elif reduce_type == "min": + out_npy = in_npy_map.min(axis=axis, keepdims=keepdims) + elif reduce_type == "argmax": + out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims) + elif reduce_type == "argmin": + out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims) + else: + raise NotImplementedError + + return in_npy, in_npy_map, out_npy + def _my_npy_argmax(arr, axis, keepdims): if not keepdims: return arr.argmax(axis=axis) else: - if axis is not None: + if axis is None: + out_shape = [1 for _ in arr.shape] + else: out_shape = list(arr.shape) out_shape[axis] = 1 - else: - out_shape = [1 for _ in range(len(arr.shape))] + return arr.argmax(axis=axis).reshape(out_shape) @@ -40,120 +93,72 @@ def _my_npy_argmin(arr, axis, keepdims): if not keepdims: return arr.argmin(axis=axis) else: - out_shape = list(arr.shape) - out_shape[axis] = 1 + if axis is None: + out_shape = [1 for _ in arr.shape] + else: + out_shape = list(arr.shape) + out_shape[axis] = 1 return arr.argmin(axis=axis).reshape(out_shape) -def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32"): +def test_reduce_map(target, dev, ref_data, in_shape, axis, keepdims, reduce_type, dtype): + target = tvm.target.Target(target) + if target.kind.name == "vulkan" and reduce_type in ["sum", "any", "all"]: + pytest.xfail(f"Vulkan backend has known errors on {reduce_type}") + + in_npy, in_npy_map, out_npy = ref_data + # Build the logic and compile the function A = te.placeholder(shape=in_shape, name="A", dtype=dtype) A1 = topi.sqrt(topi.exp(A)) out_dtype = dtype - if type == "sum": + if reduce_type == "sum": B = topi.sum(A1, axis=axis, keepdims=keepdims) - elif type == "all": + elif reduce_type == "all": B = topi.all(A, axis=axis, keepdims=keepdims) - elif type == "any": + elif reduce_type == "any": B = topi.any(A, axis=axis, keepdims=keepdims) - elif type == "max": + elif reduce_type == "max": B = topi.max(A1, axis=axis, keepdims=keepdims) - elif type == "min": + elif reduce_type == "min": B = topi.min(A1, axis=axis, keepdims=keepdims) - elif type == "argmax": + elif reduce_type == "argmax": B = topi.argmax(A1, axis=axis, keepdims=keepdims) out_dtype = "int32" - elif type == "argmin": + elif reduce_type == "argmin": B = topi.argmin(A1, axis=axis, keepdims=keepdims) out_dtype = "int32" else: raise NotImplementedError - def check_device(device, dev): - print("Running on target: %s" % device) - with tvm.target.Target(device): - s = tvm.topi.testing.get_reduce_schedule(device)(B) + with tvm.target.Target(target): + s = tvm.topi.testing.get_reduce_schedule(target)(B) - foo = tvm.build(s, [A, B], device, name=type) - # Test - if dtype == "bool": - in_npy_map = in_npy = np.random.choice([True, False], size=in_shape) - else: - in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) - in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) - - if type == "sum": - out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) - elif type == "all" and dtype == "bool": - out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) - elif type == "any" and dtype == "bool": - out_npy = in_npy_map.any(axis=axis, keepdims=keepdims) - elif type == "max": - out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) - elif type == "min": - out_npy = in_npy_map.min(axis=axis, keepdims=keepdims) - elif type == "argmax": - out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims) - elif type == "argmin": - out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims) - else: - raise NotImplementedError - data_tvm = tvm.nd.array(in_npy, device=dev) - out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype) - for _ in range(1): - foo(data_tvm, out_tvm) - if type == "argmax" or type == "argmin": - out_tvm_indices = out_tvm.numpy() - if keepdims: - out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis) - if axis is None: - out_tvm_val = in_npy_map.ravel()[out_tvm_indices] - else: - other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1) :])) - sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:] - out_tvm_val = in_npy_map[sel_indices] - if type == "argmax": - tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3) - elif type == "argmin": - tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3) + foo = tvm.build(s, [A, B], target, name=reduce_type) + + data_tvm = tvm.nd.array(in_npy, device=dev) + out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype) + foo(data_tvm, out_tvm) + + if reduce_type == "argmax" or reduce_type == "argmin": + out_tvm_indices = out_tvm.numpy() + if keepdims: + out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis) + if axis is None: + out_tvm_val = in_npy_map.ravel()[out_tvm_indices] else: - tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) - - for device, dev in tvm.testing.enabled_targets(): - check_device(device, dev) - - -@tvm.testing.uses_gpu -def test_reduce_map(): - - verify_reduce_map_ele(in_shape=(32,), axis=0, keepdims=False, type="argmax") - verify_reduce_map_ele(in_shape=(128, 24, 128, 24), axis=(1, 2, 3), keepdims=True, type="sum") - verify_reduce_map_ele(in_shape=(2, 3), axis=None, keepdims=True, type="all", dtype="bool") - verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24), axis=(1,), keepdims=False, type="max") - verify_reduce_map_ele(in_shape=(32, 128, 24), axis=None, keepdims=True, type="sum") - verify_reduce_map_ele( - in_shape=(32, 128, 24), axis=None, keepdims=True, dtype="bool", type="all" - ) - verify_reduce_map_ele(in_shape=(128, 24, 128, 24), axis=(0, 2), keepdims=False, type="min") - verify_reduce_map_ele(in_shape=(32, 128), axis=1, keepdims=True, type="argmax") - verify_reduce_map_ele(in_shape=(32, 24, 32, 24), axis=2, keepdims=False, type="argmin") - verify_reduce_map_ele(in_shape=(31, 21, 15), axis=None, keepdims=True, type="argmax") - verify_reduce_map_ele(in_shape=(31, 21, 15), axis=None, keepdims=False, type="sum") - verify_reduce_map_ele( - in_shape=(128, 24, 128, 24), axis=(1, 2, 3), keepdims=True, type="sum", dtype="float64" - ) - verify_reduce_map_ele(in_shape=(2, 3), axis=None, keepdims=True, type="any", dtype="bool") - verify_reduce_map_ele( - in_shape=(32, 128, 24), axis=None, keepdims=True, type="any", dtype="bool" - ) - verify_reduce_map_ele(in_shape=(1, 4, 7), axis=1, keepdims=True, type="any", dtype="bool") - verify_reduce_map_ele( - in_shape=(128, 24, 128, 24), axis=2, keepdims=False, type="any", dtype="bool" - ) - - -@tvm.testing.uses_gpu -def test_complex_reduce(): + other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1) :])) + sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:] + out_tvm_val = in_npy_map[sel_indices] + if reduce_type == "argmax": + tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3) + elif reduce_type == "argmin": + tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3) + else: + tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) + + +def test_complex_reduce(target, dev): in_shape = (2, 3) dtype = "float32" axis = 0 @@ -163,20 +168,20 @@ def test_complex_reduce(): C = topi.add(B, B) D = topi.multiply(B, B) E = topi.add(C, D) - for device, dev in tvm.testing.enabled_targets(): - print("Running on target: %s" % device) - with tvm.target.Target(device): - s = tvm.topi.testing.get_reduce_schedule(device)(E) - foo = tvm.build(s, [A, E], device, name="sum") - in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) - sum_npy = in_npy.sum(axis=axis, keepdims=keepdims) - out_npy = sum_npy * 2 + sum_npy * sum_npy - data_tvm = tvm.nd.array(in_npy, device=dev) - out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=dtype) - foo(data_tvm, out_tvm) - tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) + + with tvm.target.Target(target): + s = tvm.topi.testing.get_reduce_schedule(target)(E) + foo = tvm.build(s, [A, E], target, name="sum") + + in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype) + sum_npy = in_npy.sum(axis=axis, keepdims=keepdims) + out_npy = sum_npy * 2 + sum_npy * sum_npy + + data_tvm = tvm.nd.array(in_npy, device=dev) + out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=dtype) + foo(data_tvm, out_tvm) + tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) if __name__ == "__main__": - test_reduce_map() - test_complex_reduce() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_softmax.py b/tests/python/topi/python/test_topi_softmax.py index 8af038a1f7ce..97fbedcc288e 100644 --- a/tests/python/topi/python/test_topi_softmax.py +++ b/tests/python/topi/python/test_topi_softmax.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. """Test code for softmax""" +import logging import os +import sys + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import topi import tvm.testing import tvm.topi.testing -import logging +from tvm import te, topi from tvm.topi.utils import get_const_tuple @@ -34,75 +37,72 @@ } -def check_target(A, B, a_np, b_np, target, dev, name): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s_func = tvm.topi.testing.dispatch(target, _softmax_schedule) - s = s_func(B) - - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - f = tvm.build(s, [A, B], target, name=name) - f(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) +dtype = tvm.testing.parameter("float32", "float64") -def verify_softmax(m, n, dtype="float32"): - A = te.placeholder((m, n), dtype=dtype, name="A") - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = tvm.topi.testing.softmax_python(a_np) +configs = { + "softmax": { + "topi": topi.nn.softmax, + "ref": tvm.topi.testing.softmax_python, + "dimensions": [2, 4], + }, + "log_softmax": { + "topi": topi.nn.log_softmax, + "ref": tvm.topi.testing.log_softmax_python, + "dimensions": [2], + }, +} +shapes = [(32, 10), (3, 4), (1, 16, 256, 256)] +softmax_operation, shape = tvm.testing.parameters( + *[ + (name, shape) + for name, config in configs.items() + for shape in shapes + if len(shape) in config["dimensions"] + ] +) - for target, dev in tvm.testing.enabled_targets(): - check_target(A, B, a_np, b_np, target, dev, "softmax") +@tvm.testing.fixture(cache_return_value=True) +def ref_data(shape, dtype, softmax_operation): + ref_func = configs[softmax_operation]["ref"] -def verify_softmax_4d(shape, dtype="float32"): - A = te.placeholder(shape, dtype=dtype, name="A") - B = topi.nn.softmax(A, axis=1) + a_np = np.random.uniform(size=shape).astype(dtype) - _, c, h, w = shape - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c)) - b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2) + if len(shape) == 2: + b_np = ref_func(a_np) + elif len(shape) == 4: + _, c, h, w = a_np.shape + a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c) + b_np_2d = tvm.topi.testing.softmax_python(a_np_2d) + b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2) - for target, dev in tvm.testing.enabled_targets(): - check_target(A, B, a_np, b_np, target, dev, "softmax") + return a_np, b_np -@tvm.testing.uses_gpu -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - verify_softmax(32, 10, "float64") - verify_softmax_4d((1, 16, 256, 256)) +def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation): + target = tvm.target.Target(target) + if target.kind.name == "vulkan" and dtype == "float64": + # https://www.khronos.org/registry/SPIR-V/specs/1.0/GLSL.std.450.html + pytest.xfail("Vulkan GLSL.std.450 does not support 64-bit floats") + A = te.placeholder(shape, dtype=dtype, name="A") -def verify_log_softmax(m, n, dtype="float32"): - A = te.placeholder((m, n), dtype=dtype, name="A") - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = tvm.topi.testing.log_softmax_python(a_np) + topi_op = configs[softmax_operation]["topi"] + B = topi_op(A, axis=1) - for target, dev in tvm.testing.enabled_targets(): - check_target(A, B, a_np, b_np, target, dev, "log_softmax") + with tvm.target.Target(target): + fschedule = tvm.topi.testing.dispatch(target, _softmax_schedule) + s = fschedule(B) + a_np, b_np = ref_data -@tvm.testing.uses_gpu -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - verify_log_softmax(32, 10, "float64") + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + f = tvm.build(s, [A, B], target) + f(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_sort.py b/tests/python/topi/python/test_topi_sort.py index 65b2ae590308..43c6ce88be76 100644 --- a/tests/python/topi/python/test_topi_sort.py +++ b/tests/python/topi/python/test_topi_sort.py @@ -15,13 +15,16 @@ # specific language governing permissions and limitations # under the License. """Test code for vision package""" -from __future__ import print_function +import sys + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import topi -import tvm.topi.testing import tvm.testing +import tvm.topi.testing + +from tvm import te, topi _sort_implement = { "generic": (topi.sort, topi.generic.schedule_sort), @@ -38,8 +41,17 @@ "gpu": (topi.cuda.topk, topi.cuda.schedule_topk), } +axis = tvm.testing.parameter(0, -1, 1) +is_ascend = tvm.testing.parameter(True, False, ids=["is_ascend", "not_ascend"]) +dtype = tvm.testing.parameter("int64", "float32") + +topk = tvm.testing.parameter(0, 1, 5) +topk_ret_type = tvm.testing.parameter("values", "indices", "both") + + +def test_sort(target, dev, axis, is_ascend): + np.random.seed(0) -def verify_sort(axis, is_ascend): dshape = (20, 100) data_dtype = "float32" data = te.placeholder(dshape, name="data", dtype=data_dtype) @@ -58,28 +70,19 @@ def verify_sort(axis, is_ascend): else: np_sort = np_sort[:, : dshape[axis]] - def check_target(target): - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - dev = tvm.device(target, 0) - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _sort_implement) - out = fcompute(data, axis=axis, is_ascend=is_ascend) - s = fschedule(out) - - tvm_data = tvm.nd.array(np_data, dev) - tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), dev) - f = tvm.build(s, [data, out], target) - f(tvm_data, tvm_out) - tvm.testing.assert_allclose(tvm_out.numpy(), np_sort, rtol=1e0) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _sort_implement) + out = fcompute(data, axis=axis, is_ascend=is_ascend) + s = fschedule(out) - for target in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: - check_target(target) + tvm_data = tvm.nd.array(np_data, dev) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), dev) + f = tvm.build(s, [data, out], target) + f(tvm_data, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_sort, rtol=1e0) -def verify_argsort(axis, is_ascend): +def test_argsort(target, dev, axis, is_ascend): dshape = (20, 100) data_dtype = "float32" data = te.placeholder(dshape, name="data", dtype=data_dtype) @@ -98,28 +101,21 @@ def verify_argsort(axis, is_ascend): else: np_indices = np_indices[:, : dshape[axis]] - def check_target(target): - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - dev = tvm.device(target, 0) - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _argsort_implement) - out = fcompute(data, axis=axis, is_ascend=is_ascend) - s = fschedule(out) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _argsort_implement) + out = fcompute(data, axis=axis, is_ascend=is_ascend) + s = fschedule(out) - tvm_data = tvm.nd.array(np_data, dev) - tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), dev) - f = tvm.build(s, [data, out], target) - f(tvm_data, tvm_out) - tvm.testing.assert_allclose(tvm_out.numpy(), np_indices.astype(data_dtype), rtol=1e0) + tvm_data = tvm.nd.array(np_data, dev) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), dev) + f = tvm.build(s, [data, out], target) + f(tvm_data, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_indices.astype(data_dtype), rtol=1e0) - for target in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: - check_target(target) +def test_topk(target, dev, topk, axis, topk_ret_type, is_ascend, dtype): + np.random.seed(0) -def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) data_dtype = "float32" data = te.placeholder(shape, name="data", dtype=data_dtype) @@ -129,7 +125,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): np_indices = np.argsort(np_data, axis=axis) else: np_indices = np.argsort(-np_data, axis=axis) - kk = k if k >= 1 else shape[axis] + kk = topk if topk >= 1 else shape[axis] if axis == 0: np_indices = np_indices[:kk, :] np_values = np.zeros(np_indices.shape).astype(data_dtype) @@ -142,61 +138,25 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) - def check_target(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _topk_implement) - outs = fcompute(data, k, axis, ret_type, is_ascend, dtype) - outs = outs if isinstance(outs, list) else [outs] - s = fschedule(outs) - tvm_data = tvm.nd.array(np_data, dev) - tvm_res = [] - for t in outs: - tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, device=dev)) - f = tvm.build(s, [data] + outs, target) - f(tvm_data, *tvm_res) - if ret_type == "both": - tvm.testing.assert_allclose(tvm_res[0].numpy(), np_values) - tvm.testing.assert_allclose(tvm_res[1].numpy(), np_indices) - elif ret_type == "values": - tvm.testing.assert_allclose(tvm_res[0].numpy(), np_values) - else: - tvm.testing.assert_allclose(tvm_res[0].numpy(), np_indices) - - for target in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: - check_target(target) - - -@tvm.testing.uses_gpu -def test_sort(): - np.random.seed(0) - for axis in [0, -1, 1]: - verify_sort(axis, True) - verify_sort(axis, False) - - -@tvm.testing.uses_gpu -def test_argsort(): - np.random.seed(0) - for axis in [0, -1, 1]: - verify_argsort(axis, True) - verify_argsort(axis, False) - - -@tvm.testing.uses_gpu -def test_topk(): - np.random.seed(0) - for k in [0, 1, 5]: - for axis in [0, -1, 1]: - for ret_type in ["both", "values", "indices"]: - verify_topk(k, axis, ret_type, True, "int64") - verify_topk(k, axis, ret_type, False, "float32") + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _topk_implement) + outs = fcompute(data, topk, axis, topk_ret_type, is_ascend, dtype) + outs = outs if isinstance(outs, list) else [outs] + s = fschedule(outs) + tvm_data = tvm.nd.array(np_data, dev) + tvm_res = [] + for t in outs: + tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, device=dev)) + f = tvm.build(s, [data] + outs, target) + f(tvm_data, *tvm_res) + if topk_ret_type == "both": + tvm.testing.assert_allclose(tvm_res[0].numpy(), np_values) + tvm.testing.assert_allclose(tvm_res[1].numpy(), np_indices) + elif topk_ret_type == "values": + tvm.testing.assert_allclose(tvm_res[0].numpy(), np_values) + else: + tvm.testing.assert_allclose(tvm_res[0].numpy(), np_indices) if __name__ == "__main__": - test_argsort() - test_topk() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 3e26241cea94..4dd9b193ad57 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -20,9 +20,14 @@ from tvm import topi import tvm.topi.testing +in_dtype = tvm.testing.parameter("int32", "int64") +is_sorted = tvm.testing.parameter(True, False, ids=["sorted", "unsorted"]) +with_counts = tvm.testing.parameter(True, False, ids=["with_counts", "no_counts"]) +arr_size, maxval = tvm.testing.parameters((1, 100), (10, 10), (10000, 100)) + @tvm.testing.parametrize_targets -def test_unique(dev, target): +def test_unique(dev, target, in_dtype, is_sorted, with_counts, arr_size, maxval): def calc_numpy_unique(data, is_sorted=False): uniq, index, inverse, counts = np.unique( data, return_index=True, return_inverse=True, return_counts=True @@ -43,82 +48,67 @@ def calc_numpy_unique(data, is_sorted=False): num_uniq, ] - def check_unique(data, is_sorted=False, with_counts=False): - # numpy reference - np_unique, np_indices, np_inverse_indices, np_counts, np_num_unique = calc_numpy_unique( - data, is_sorted - ) - num_unique = np_num_unique[0] + data = np.random.randint(0, maxval, size=(arr_size)).astype(in_dtype) - implementations = { - "generic": ( - lambda x, return_counts: topi.unique(x, is_sorted, return_counts), - topi.generic.schedule_unique, - ), - "cuda": ( - lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), - topi.cuda.schedule_scan, - ), - "nvptx": ( - lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), - topi.cuda.schedule_scan, - ), - } - fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm_data = tvm.nd.array(data, device=dev) - tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), device=dev) - tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) - tvm_inverse_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) - tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), device=dev) + # numpy reference + np_unique, np_indices, np_inverse_indices, np_counts, np_num_unique = calc_numpy_unique( + data, is_sorted + ) + num_unique = np_num_unique[0] - with tvm.target.Target(target): - te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input, with_counts) - s = fschedule(outs) - func = tvm.build(s, [te_input, *outs]) + implementations = { + "generic": ( + lambda x, return_counts: topi.unique(x, is_sorted, return_counts), + topi.generic.schedule_unique, + ), + "gpu": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), + "nvptx": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm_data = tvm.nd.array(data, device=dev) + tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), device=dev) + tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + tvm_inverse_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), device=dev) - if with_counts: - tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) - func( - tvm_data, - tvm_unique, - tvm_indices, - tvm_inverse_indices, - tvm_num_unique, - tvm_counts, - ) - else: - func(tvm_data, tvm_unique, tvm_indices, tvm_inverse_indices, tvm_num_unique) + with tvm.target.Target(target): + te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) + outs = fcompute(te_input, with_counts) + s = fschedule(outs) + func = tvm.build(s, [te_input, *outs]) - num_unique = np_num_unique[0] - assert tvm_num_unique.numpy()[0] == np_num_unique + if with_counts: + tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + func( + tvm_data, + tvm_unique, + tvm_indices, + tvm_inverse_indices, + tvm_num_unique, + tvm_counts, + ) + else: + func(tvm_data, tvm_unique, tvm_indices, tvm_inverse_indices, tvm_num_unique) - np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose( - tvm_indices.numpy()[:num_unique], np_indices, atol=1e-5, rtol=1e-5 - ) + num_unique = np_num_unique[0] + assert tvm_num_unique.numpy()[0] == np_num_unique - np.testing.assert_allclose( - tvm_inverse_indices.numpy(), np_inverse_indices, atol=1e-5, rtol=1e-5 - ) + np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(tvm_indices.numpy()[:num_unique], np_indices, atol=1e-5, rtol=1e-5) - if with_counts: - np.testing.assert_allclose( - tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5 - ) + np.testing.assert_allclose( + tvm_inverse_indices.numpy(), np_inverse_indices, atol=1e-5, rtol=1e-5 + ) - for in_dtype in ["int32", "int64"]: - for is_sorted in [True, False]: - for with_counts in [True, False]: - data = np.random.randint(0, 100, size=(1)).astype(in_dtype) - check_unique(data, is_sorted, with_counts) - data = np.random.randint(0, 10, size=(10)).astype(in_dtype) - check_unique(data, is_sorted, with_counts) - data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) - check_unique(data, is_sorted, with_counts) + if with_counts: + np.testing.assert_allclose(tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5) if __name__ == "__main__": - test_unique(tvm.device("cpu"), tvm.target.Target("llvm")) - test_unique(tvm.device("cuda"), tvm.target.Target("cuda")) - test_unique(tvm.device("nvptx"), tvm.target.Target("nvptx")) + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 234107d6686e..6ddb86f4027f 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -15,19 +15,18 @@ # specific language governing permissions and limitations # under the License. """Test code for vision package""" -from __future__ import print_function import math +import sys + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import topi +import tvm.testing import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize -from tvm.topi.utils import get_const_tuple +from tvm import te, topi from tvm.topi.vision import ssd, non_max_suppression, get_valid_counts -import pytest -import tvm.testing _get_valid_counts_implement = { "generic": (topi.vision.get_valid_counts, topi.generic.schedule_get_valid_counts), @@ -71,35 +70,46 @@ } -def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): - dtype = "float32" - batch_size, num_anchor, elem_length = dshape - np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) - np_out1 = np.zeros(shape=(batch_size,)) - np_out2 = np.zeros(shape=dshape).astype(dtype) - np_out3 = np.zeros(shape=(batch_size, num_anchor)) - for i in range(batch_size): - np_out1[i] = 0 - inter_idx = 0 - for j in range(num_anchor): - score = np_data[i, j, score_index] - if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): - for k in range(elem_length): - np_out2[i, inter_idx, k] = np_data[i, j, k] - np_out1[i] += 1 - np_out3[i, inter_idx] = j - inter_idx += 1 - if j >= np_out1[i]: - for k in range(elem_length): - np_out2[i, j, k] = -1.0 - np_out3[i, j] = -1 - - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) +class TestValidCounts: + dshape, score_threshold, id_index, score_index = tvm.testing.parameters( + ((1, 1000, 5), 0.5, -1, 0), + ((1, 2500, 6), 0, 0, 1), + ((1, 2500, 5), -1, -1, 0), + ((3, 1000, 6), 0.55, 1, 0), + ((16, 500, 5), 0.95, -1, 1), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.fixture(cache_return_value=True) + def ref_data(self, dtype, dshape, score_threshold, id_index, score_index): + batch_size, num_anchor, elem_length = dshape + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) + np_out1 = np.zeros(shape=(batch_size,)) + np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor)) + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor): + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): + for k in range(elem_length): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + np_out3[i, inter_idx] = j + inter_idx += 1 + if j >= np_out1[i]: + for k in range(elem_length): + np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 + + return np_data, np_out1, np_out2, np_out3 + + def test_get_valid_counts( + self, target, dev, ref_data, dtype, dshape, score_threshold, id_index, score_index + ): + np_data, np_out1, np_out2, np_out3 = ref_data + with tvm.target.Target(target): fcompute, fschedule = tvm.topi.testing.dispatch(target, _get_valid_counts_implement) data = te.placeholder(dshape, name="data", dtype=dtype) @@ -117,20 +127,10 @@ def check_device(target): tvm.testing.assert_allclose(tvm_out2.numpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out3.numpy(), np_out3, rtol=1e-3) - for target in ["llvm", "cuda", "opencl", "vulkan"]: - check_device(target) - - -@tvm.testing.uses_gpu -def test_get_valid_counts(): - verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0) - verify_get_valid_counts((1, 2500, 6), 0, 0, 1) - verify_get_valid_counts((1, 2500, 5), -1, -1, 0) - verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) - verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) - def verify_non_max_suppression( + target, + dev, np_data, np_valid_count, np_indices, @@ -151,63 +151,53 @@ def verify_non_max_suppression( valid_count = te.placeholder((batch,), dtype="int32", name="valid_count") indices = te.placeholder((batch, num_anchors), dtype="int32", name="indices") - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _nms_implement) - out = fcompute( - data, - valid_count, - indices, - max_output_size, - iou_threshold, - force_suppress, - top_k, - coord_start=coord_start, - score_index=score_index, - id_index=id_index, - return_indices=False, - ) - indices_out = fcompute( - data, - valid_count, - indices, - max_output_size, - iou_threshold, - force_suppress, - top_k, - coord_start=coord_start, - score_index=score_index, - id_index=id_index, - return_indices=True, - ) - s = fschedule(out) - indices_s = fschedule(indices_out) - - tvm_data = tvm.nd.array(np_data, dev) - tvm_valid_count = tvm.nd.array(np_valid_count, dev) - tvm_indices = tvm.nd.array(np_indices, dev) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _nms_implement) + out = fcompute( + data, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start=coord_start, + score_index=score_index, + id_index=id_index, + return_indices=False, + ) + indices_out = fcompute( + data, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start=coord_start, + score_index=score_index, + id_index=id_index, + return_indices=True, + ) + s = fschedule(out) + indices_s = fschedule(indices_out) - tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev) - f = tvm.build(s, [data, valid_count, indices, out], target) - f(tvm_data, tvm_valid_count, tvm_indices, tvm_out) - tvm.testing.assert_allclose(tvm_out.numpy(), np_result, rtol=1e-4) + tvm_data = tvm.nd.array(np_data, dev) + tvm_valid_count = tvm.nd.array(np_valid_count, dev) + tvm_indices = tvm.nd.array(np_indices, dev) - tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), dev) - f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], target) - f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) - tvm.testing.assert_allclose(tvm_indices_out.numpy(), np_indices_result, rtol=1e-4) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev) + f = tvm.build(s, [data, valid_count, indices, out], target) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_result, rtol=1e-4) - for target in ["llvm", "cuda", "opencl", "nvptx"]: - check_device(target) + tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), dev) + f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], target) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) + tvm.testing.assert_allclose(tvm_indices_out.numpy(), np_indices_result, rtol=1e-4) -@tvm.testing.uses_gpu -def test_non_max_suppression(): +def test_non_max_suppression(target, dev): np_data = np.array( [ [ @@ -236,6 +226,8 @@ def test_non_max_suppression(): np_indices_result = np.array([[3, 0, -1, -1, -1]]) verify_non_max_suppression( + target, + dev, np_data, np_valid_count, np_indices, @@ -277,6 +269,8 @@ def test_non_max_suppression(): ) np_indices_result = np.array([[3, 0, -1, -1, -1]]) verify_non_max_suppression( + target, + dev, np_data, np_valid_count, np_indices, @@ -292,91 +286,85 @@ def test_non_max_suppression(): ) -def verify_multibox_prior( - dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False -): - data = te.placeholder(dshape, name="data") +class TestMultiboxPrior: + dshape, sizes, ratios, steps, offsets, clip = tvm.testing.parameters( + ((1, 3, 50, 50), (1,), (1,), (-1, -1), (0.5, 0.5), False), + ((1, 3, 224, 224), (0.5, 0.25, 0.1), (1, 2, 0.5), (-1, -1), (0.5, 0.5), False), + ((1, 32, 32, 32), (0.5, 0.25), (1, 2), (2, 2), (0.5, 0.5), True), + ) - dtype = data.dtype - input_data = np.random.uniform(size=dshape).astype(dtype) - - in_height = data.shape[2].value - in_width = data.shape[3].value - num_sizes = len(sizes) - num_ratios = len(ratios) - size_ratio_concat = sizes + ratios - steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height - steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width - offset_h = offsets[0] - offset_w = offsets[1] - - oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4) - np_out = np.zeros(oshape).astype(dtype) - - for i in range(in_height): - center_h = (i + offset_h) * steps_h - for j in range(in_width): - center_w = (j + offset_w) * steps_w - for k in range(num_sizes + num_ratios - 1): - w = ( - size_ratio_concat[k] * in_height / in_width / 2.0 - if k < num_sizes - else size_ratio_concat[0] - * in_height - / in_width - * math.sqrt(size_ratio_concat[k + 1]) - / 2.0 - ) - h = ( - size_ratio_concat[k] / 2.0 - if k < num_sizes - else size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0 - ) - count = ( - i * in_width * (num_sizes + num_ratios - 1) - + j * (num_sizes + num_ratios - 1) - + k - ) - np_out[0][count][0] = center_w - w - np_out[0][count][1] = center_h - h - np_out[0][count][2] = center_w + w - np_out[0][count][3] = center_h + h - if clip: - np_out = np.clip(np_out, 0, 1) - - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.fixture(cache_return_value=True) + def ref_data(self, dtype, dshape, sizes, ratios, offsets, steps, clip): + in_height = dshape[2] + in_width = dshape[3] + num_sizes = len(sizes) + num_ratios = len(ratios) + size_ratio_concat = sizes + ratios + steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + offset_h = offsets[0] + offset_w = offsets[1] + + out_shape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4) + + np_in = np.random.uniform(size=dshape).astype(dtype) + np_out = np.zeros(out_shape).astype(dtype) + + for i in range(in_height): + center_h = (i + offset_h) * steps_h + for j in range(in_width): + center_w = (j + offset_w) * steps_w + for k in range(num_sizes + num_ratios - 1): + w = ( + size_ratio_concat[k] * in_height / in_width / 2.0 + if k < num_sizes + else size_ratio_concat[0] + * in_height + / in_width + * math.sqrt(size_ratio_concat[k + 1]) + / 2.0 + ) + h = ( + size_ratio_concat[k] / 2.0 + if k < num_sizes + else size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0 + ) + count = ( + i * in_width * (num_sizes + num_ratios - 1) + + j * (num_sizes + num_ratios - 1) + + k + ) + np_out[0][count][0] = center_w - w + np_out[0][count][1] = center_h - h + np_out[0][count][2] = center_w + w + np_out[0][count][3] = center_h + h + if clip: + np_out = np.clip(np_out, 0, 1) + + return np_in, np_out + + def test_multibox_prior( + self, target, dev, dtype, dshape, ref_data, sizes, ratios, steps, offsets, clip + ): + np_in, np_out = ref_data + + data = te.placeholder(dshape, name="data", dtype=dtype) fcompute, fschedule = tvm.topi.testing.dispatch(target, _multibox_prior_implement) with tvm.target.Target(target): out = fcompute(data, sizes, ratios, steps, offsets, clip) s = fschedule(out) - tvm_input_data = tvm.nd.array(input_data, dev) - tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), dev) + tvm_input_data = tvm.nd.array(np_in, dev) + tvm_out = tvm.nd.array(np.zeros(np_out.shape, dtype=dtype), dev) f = tvm.build(s, [data, out], target) f(tvm_input_data, tvm_out) tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-3) - for target in ["llvm", "opencl", "cuda"]: - check_device(target) - - -@tvm.testing.uses_gpu -def test_multibox_prior(): - verify_multibox_prior((1, 3, 50, 50)) - verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5)) - verify_multibox_prior( - (1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True - ) - -@tvm.testing.uses_gpu -def test_multibox_detection(): +def test_multibox_detection(target, dev): batch_size = 1 num_anchors = 3 num_classes = 3 @@ -399,41 +387,56 @@ def test_multibox_detection(): ] ) - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - - fcompute, fschedule = tvm.topi.testing.dispatch(target, _multibox_detection_implement) - with tvm.target.Target(target): - out = fcompute(cls_prob, loc_preds, anchors) - s = fschedule(out) - - tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), dev) - tvm_loc_preds = tvm.nd.array(np_loc_preds.astype(loc_preds.dtype), dev) - tvm_anchors = tvm.nd.array(np_anchors.astype(anchors.dtype), dev) - tvm_out = tvm.nd.array(np.zeros((batch_size, num_anchors, 6)).astype(out.dtype), dev) - f = tvm.build(s, [cls_prob, loc_preds, anchors, out], target) - f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out) - tvm.testing.assert_allclose(tvm_out.numpy(), expected_np_out, rtol=1e-4) - - for target in ["llvm", "opencl", "cuda"]: - check_device(target) - - -def verify_roi_align( - batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio, mode -): # For mode, 0 = avg, 1 = max - a_shape = (batch, in_channel, in_size, in_size) - rois_shape = (num_roi, 5) + fcompute, fschedule = tvm.topi.testing.dispatch(target, _multibox_detection_implement) + with tvm.target.Target(target): + out = fcompute(cls_prob, loc_preds, anchors) + s = fschedule(out) + + tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), dev) + tvm_loc_preds = tvm.nd.array(np_loc_preds.astype(loc_preds.dtype), dev) + tvm_anchors = tvm.nd.array(np_anchors.astype(anchors.dtype), dev) + tvm_out = tvm.nd.array(np.zeros((batch_size, num_anchors, 6)).astype(out.dtype), dev) + f = tvm.build(s, [cls_prob, loc_preds, anchors, out], target) + f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), expected_np_out, rtol=1e-4) + + +class TestRoiAlign: + ( + batch, + in_channel, + in_size, + num_roi, + pooled_size, + spatial_scale, + sample_ratio, + mode, + ) = tvm.testing.parameters( + (1, 16, 32, 64, 7, 1.0, -1, 0), + (4, 16, 32, 64, 7, 0.5, 2, 0), + (1, 32, 32, 80, 8, 0.0625, 2, 0), + (1, 32, 500, 80, 8, 0.0625, 2, 0), + (1, 16, 32, 64, 7, 1.0, -1, 1), + (4, 16, 32, 64, 7, 0.5, 2, 1), + (1, 32, 32, 80, 8, 0.0625, 2, 1), + (1, 32, 500, 80, 8, 0.0625, 2, 1), + ) - a = te.placeholder(a_shape) - rois = te.placeholder(rois_shape) + @tvm.testing.fixture(cache_return_value=True) + def ref_data( + self, + batch, + in_channel, + in_size, + num_roi, + pooled_size, + spatial_scale, + sample_ratio, + mode, + ): + a_shape = (batch, in_channel, in_size, in_size) + rois_shape = (num_roi, 5) - @memoize("topi.tests.test_topi_vision.verify_roi_align") - def get_ref_data(): a_np = np.random.uniform(-1, 1, size=a_shape).astype("float32") rois_np = np.random.uniform(-1, 1, size=rois_shape).astype("float32") * in_size rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi) @@ -448,13 +451,22 @@ def get_ref_data(): return a_np, rois_np, b_np - a_np, rois_np, b_np = get_ref_data() + def test_roi_align( + self, + target, + dev, + ref_data, + pooled_size, + spatial_scale, + sample_ratio, + mode, + ): + # For mode, 0 = avg, 1 = max + a_np, rois_np, b_np = ref_data + + a = te.placeholder(a_np.shape) + rois = te.placeholder(rois_np.shape) - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return with tvm.target.Target(target): fcompute, fschedule = tvm.topi.testing.dispatch(target, _roi_align_implement) b = fcompute( @@ -469,37 +481,24 @@ def check_device(target): tvm_a = tvm.nd.array(a_np, dev) tvm_rois = tvm.nd.array(rois_np, dev) - tvm_b = tvm.nd.array(np.zeros(get_const_tuple(b.shape), dtype=b.dtype), device=dev) + tvm_b = tvm.nd.array(np.zeros(b_np.shape, dtype=b.dtype), device=dev) f = tvm.build(s, [a, rois, b], target) f(tvm_a, tvm_rois, tvm_b) tvm_val = tvm_b.numpy() tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-4) - for target in ["llvm", "cuda", "opencl"]: - check_device(target) - -@tvm.testing.uses_gpu -def test_roi_align(): - verify_roi_align(1, 16, 32, 64, 7, 1.0, -1, 0) - verify_roi_align(4, 16, 32, 64, 7, 0.5, 2, 0) - verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2, 0) - verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2, 0) - verify_roi_align(1, 16, 32, 64, 7, 1.0, -1, 1) - verify_roi_align(4, 16, 32, 64, 7, 0.5, 2, 1) - verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2, 1) - verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2, 1) - - -def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): - a_shape = (batch, in_channel, in_size, in_size) - rois_shape = (num_roi, 5) +class TestRoiPool: + batch, in_channel, in_size, num_roi, pooled_size, spatial_scale = tvm.testing.parameters( + (1, 4, 16, 32, 7, 1.0), + (4, 4, 16, 32, 7, 0.5), + ) - a = te.placeholder(a_shape) - rois = te.placeholder(rois_shape) + @tvm.testing.fixture(cache_return_value=True) + def ref_data(self, batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): + a_shape = (batch, in_channel, in_size, in_size) + rois_shape = (num_roi, 5) - @memoize("topi.tests.test_topi_vision.verify_roi_pool") - def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype("float32") rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi).astype("float32") @@ -509,14 +508,11 @@ def get_ref_data(): ) return a_np, rois_np, b_np - a_np, rois_np, b_np = get_ref_data() + def test_roi_pool(self, target, dev, ref_data, pooled_size, spatial_scale): + a_np, rois_np, b_np = ref_data - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) + a = te.placeholder(a_np.shape) + rois = te.placeholder(rois_np.shape) with tvm.target.Target(target): b = topi.vision.rcnn.roi_pool_nchw( @@ -527,50 +523,32 @@ def check_device(target): tvm_a = tvm.nd.array(a_np, dev) tvm_rois = tvm.nd.array(rois_np, dev) - tvm_b = tvm.nd.array(np.zeros(get_const_tuple(b.shape), dtype=b.dtype), device=dev) + tvm_b = tvm.nd.array(np.zeros(b_np.shape, dtype=b.dtype), device=dev) f = tvm.build(s, [a, rois, b], target) f(tvm_a, tvm_rois, tvm_b) tvm.testing.assert_allclose(tvm_b.numpy(), b_np, rtol=1e-4) - for target in ["cuda", "llvm"]: - check_device(target) - -@tvm.testing.uses_gpu -def test_roi_pool(): - verify_roi_pool(1, 4, 16, 32, 7, 1.0) - verify_roi_pool(4, 4, 16, 32, 7, 0.5) - - -def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): +def verify_proposal(target, dev, np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): cls_prob = te.placeholder(np_cls_prob.shape) bbox_pred = te.placeholder(np_bbox_pred.shape) im_info = te.placeholder(np_im_info.shape) - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _proposal_implement) - out = fcompute(cls_prob, bbox_pred, im_info, **attrs) - s = fschedule(out) - f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], target) - tvm_cls_prob = tvm.nd.array(np_cls_prob, device=dev) - tvm_bbox_pred = tvm.nd.array(np_bbox_pred, device=dev) - tvm_im_info = tvm.nd.array(np_im_info, device=dev) - tvm_out = tvm.nd.empty(device=dev, shape=out.shape, dtype=out.dtype) - f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out) - tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-4) - - for target in ["llvm", "cuda"]: - check_device(target) - - -@tvm.testing.uses_gpu -def test_proposal(): + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _proposal_implement) + out = fcompute(cls_prob, bbox_pred, im_info, **attrs) + s = fschedule(out) + f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], target) + tvm_cls_prob = tvm.nd.array(np_cls_prob, device=dev) + tvm_bbox_pred = tvm.nd.array(np_bbox_pred, device=dev) + tvm_im_info = tvm.nd.array(np_im_info, device=dev) + tvm_out = tvm.nd.empty(device=dev, shape=out.shape, dtype=out.dtype) + f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-4) + + +@tvm.testing.known_failing_targets("vulkan") +def test_proposal(target, dev): attrs = { "scales": (0.5,), "ratios": (0.5,), @@ -612,7 +590,7 @@ def test_proposal(): dtype="float32", ) - verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) + verify_proposal(target, dev, np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) np_out = np.array( [ @@ -625,10 +603,12 @@ def test_proposal(): ) attrs["iou_loss"] = True - verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) + verify_proposal(target, dev, np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) def verify_all_class_non_max_suppression( + target, + dev, boxes_np, scores_np, max_output_boxes_per_class, @@ -642,36 +622,24 @@ def verify_all_class_non_max_suppression( boxes = te.placeholder(dshape, name="boxes") scores = te.placeholder(scores_np.shape, dtype="float32", name="scores") - def check_device(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - with tvm.target.Target(target): - fcompute, fschedule = tvm.topi.testing.dispatch(target, _all_class_nms_implement) - out = fcompute( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold - ) - s = fschedule(out) - - tvm_boxes = tvm.nd.array(boxes_np, dev) - tvm_scores = tvm.nd.array(scores_np, dev) - selected_indices = tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int64"), dev) - num_detections = tvm.nd.array(np.zeros((1,), "int64"), dev) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _all_class_nms_implement) + out = fcompute(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) + s = fschedule(out) - f = tvm.build(s, [boxes, scores, out[0], out[1]], target) - f(tvm_boxes, tvm_scores, selected_indices, num_detections) + tvm_boxes = tvm.nd.array(boxes_np, dev) + tvm_scores = tvm.nd.array(scores_np, dev) + selected_indices = tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int64"), dev) + num_detections = tvm.nd.array(np.zeros((1,), "int64"), dev) - tvm_res = selected_indices.numpy()[: num_detections.numpy()[0]] - np.testing.assert_equal(tvm_res, expected_indices) + f = tvm.build(s, [boxes, scores, out[0], out[1]], target) + f(tvm_boxes, tvm_scores, selected_indices, num_detections) - for target in ["llvm", "cuda", "opencl", "vulkan"]: - check_device(target) + tvm_res = selected_indices.numpy()[: num_detections.numpy()[0]] + np.testing.assert_equal(tvm_res, expected_indices) -@tvm.testing.uses_gpu -def test_all_class_non_max_suppression(): +def test_all_class_non_max_suppression(target, dev): boxes = np.array( [ [ @@ -707,7 +675,14 @@ def test_all_class_non_max_suppression(): ) verify_all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + target, + dev, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected, ) boxes = np.array( @@ -730,16 +705,16 @@ def test_all_class_non_max_suppression(): expected = np.array([[0, 0, 3], [0, 0, 0]]) verify_all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + target, + dev, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected, ) if __name__ == "__main__": - test_get_valid_counts() - test_multibox_prior() - test_multibox_detection() - test_roi_align() - test_roi_pool() - test_proposal() - test_non_max_suppression() - test_all_class_non_max_suppression() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index a8d57b3da780..ecc30199c1a7 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -139,7 +139,19 @@ def test_ldexp(): tvm.testing.assert_allclose(c.numpy(), np.ldexp(a.numpy(), b.numpy()), atol=1e-5, rtol=1e-5) -def test_clz(): +dtype = tvm.testing.parameter("int32", "int64") + + +@tvm.testing.parametrize_targets("llvm", "vulkan -from_device=0") +def test_clz(target, dev, dtype): + target = tvm.target.Target(target) + if ( + target.kind.name == "vulkan" + and dtype == "int64" + and not target.attrs.get("supports_int64", False) + ): + pytest.xfail("Vulkan target does not support Int64 types") + def clz_np(x, dtype): ceil_log2 = np.ceil(np.log2(x)).astype(dtype) bits = int(dtype[-2:]) @@ -147,38 +159,32 @@ def clz_np(x, dtype): clz[np.bitwise_and(x, x - 1) == 0] -= 1 return clz - for target in ["llvm", "vulkan"]: - if not tvm.testing.device_enabled("vulkan"): - continue + m = te.var("m") + A = te.placeholder((m,), name="A", dtype=dtype) + B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B") + s = te.create_schedule(B.op) - for dtype in ["int32", "int64"]: - m = te.var("m") - A = te.placeholder((m,), name="A", dtype=dtype) - B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B") - s = te.create_schedule(B.op) + if target.kind.name == "vulkan": + bx, tx = s[B].split(B.op.axis[0], factor=64) - if target == "vulkan": - bx, tx = s[B].split(B.op.axis[0], factor=64) + s[B].bind(bx, te.thread_axis("blockIdx.x")) + s[B].bind(tx, te.thread_axis("threadIdx.x")) - s[B].bind(bx, te.thread_axis("blockIdx.x")) - s[B].bind(tx, te.thread_axis("threadIdx.x")) - - f = tvm.build(s, [A, B], target) - dev = tvm.device(target, 0) - n = 10 + f = tvm.build(s, [A, B], target) + n = 10 - highs = [10, 100, 1000, 10000, 100000, 1000000] + highs = [10, 100, 1000, 10000, 100000, 1000000] - if dtype == "int64": - highs.append((1 << 63) - 1) + if dtype == "int64": + highs.append((1 << 63) - 1) - for high in highs: - a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev) - f(a, b) - ref = clz_np(a_np, dtype) - np.testing.assert_equal(b.numpy(), ref) + for high in highs: + a_np = np.random.randint(1, high=high, size=(n,), dtype=dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev) + f(a, b) + ref = clz_np(a_np, dtype) + np.testing.assert_equal(b.numpy(), ref) @tvm.script.tir diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 4c9c5d91901a..f396eeeee5fb 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -57,7 +57,10 @@ def test_explicit_list(self, target): self.targets_with_explicit_list.append(target) def test_no_repeats_in_explicit_list(self): - assert self.targets_with_explicit_list == ["llvm"] + if tvm.testing.device_enabled("llvm"): + assert self.targets_with_explicit_list == ["llvm"] + else: + assert self.targets_with_explicit_list == [] targets_with_exclusion = [] From b6f985d5ebb35813c7583013966913cf265c206d Mon Sep 17 00:00:00 2001 From: Tom Gall Date: Thu, 2 Sep 2021 04:15:33 -0500 Subject: [PATCH 07/12] Trivial uTVM -> microTVM "spelling" fix to align with branding. (#8905) * Embarrassingly trivial fix remove use of uTVM and replace with the proper microTVM naming convention. --- cmake/config.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index e55f1197d90e..8d8186c1b4f0 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -108,7 +108,7 @@ set(USE_GRAPH_EXECUTOR_CUDA_GRAPH OFF) # Whether to enable the profiler for the graph executor and vm set(USE_PROFILER ON) -# Whether enable uTVM standalone runtime +# Whether enable microTVM standalone runtime set(USE_MICRO_STANDALONE_RUNTIME OFF) # Whether build with LLVM support From 27be4625de60466ff72ec6780f7b6671d76382d8 Mon Sep 17 00:00:00 2001 From: ziheng Date: Thu, 2 Sep 2021 03:35:00 -0700 Subject: [PATCH 08/12] [Community] @Hzfengsy -> Committer (#8908) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e7f70379e49d..96e1b9e42a7f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -35,6 +35,7 @@ We do encourage everyone to work anything they are interested in. - [Tianqi Chen](https://github.com/tqchen) (PMC): @tqchen - topi, compiler, relay, docs - [Wei Chen](https://github.com/wweic): @wweic - runtime, relay, vm - [Zhi Chen](https://github.com/zhiics) (PMC): @zhiics - relay, quantization, pass manager +- [Siyuan Feng](https://github.com/Hzfengsy): @Hzfengsy - tir - [Josh Fromm](https://github.com/jwfromm): @jwfromm - frontends, quantization, topi - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends - [Nick Hynes](https://github.com/nhynes): @nhynes: - sgx, rust From 7c9811c23f75df9361fc3bf5e1767c3f9fb88f0f Mon Sep 17 00:00:00 2001 From: Luyao Ren Date: Thu, 2 Sep 2021 23:28:48 +0800 Subject: [PATCH 09/12] Set default value of p in LpPool as 2 (#8866) * Set default value of p in LpPool as 2 * Update test_forward.py Fix bug in test. * Update test_forward.py update with correct shape. * Update onnx.py * Update python/tvm/relay/frontend/onnx.py Co-authored-by: Wuwei Lin Co-authored-by: luyaor Co-authored-by: Wuwei Lin --- python/tvm/relay/frontend/onnx.py | 5 +++-- tests/python/frontend/onnx/test_forward.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b59c924d8ca3..18e07b369eec 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -908,8 +908,9 @@ def _impl_v1(cls, inputs, attr, params): else: attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2), op_name="LpPool") - p = _expr.const(attr["p"], dtype) - reci_p = _expr.const(1.0 / attr["p"], dtype) + p_value = attr.get("p", 2) + p = _expr.const(p_value, dtype) + reci_p = _expr.const(1.0 / p_value, dtype) data = _op.power(data, p) out = AttrCvt( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 15214192148b..e4451aba1704 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3184,15 +3184,18 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh @tvm.testing.parametrize_targets def test_lppool(target, dev): def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): + kwargs = {} + if p is not None: + kwargs["p"] = p if pads is None: pool_node = helper.make_node( "LpPool", inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, - p=p, auto_pad=auto_pad, strides=strides, + **kwargs, ) else: pool_node = helper.make_node( @@ -3200,9 +3203,9 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, - p=p, pads=pads, strides=strides, + **kwargs, ) graph = helper.make_graph( @@ -3295,6 +3298,15 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" out_shape=[1, 1, 16, 16, 16], auto_pad="SAME_UPPER", ) + # Pool2D with empty p + verify_lppool( + x_shape=[1, 1, 32, 32], + kernel_shape=[3, 3], + p=None, + strides=[1, 1], + pads=[1, 1, 1, 1], + out_shape=[1, 1, 32, 32], + ) def verify_rnn( From 707c4e0afc4bab87c5931236ed9ec3f26c8bea61 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Thu, 2 Sep 2021 19:38:34 +0300 Subject: [PATCH 10/12] Enable python debug runtime for exported network libraries (#8793) * Add get_json method to graph_eceutor factory Signed-off-by: Alexander Peskov * Update Debugger runtime documentation for exported libraries * Fix cpplint * Change module get_json to get_graph_json, add test * Fix get_graph_json test * Change verificatino of llvm support in tet to decorator * Fix sphinx warning in debugger.rst Co-authored-by: Alexander Peskov --- docs/dev/debugger.rst | 20 ++++++++- .../graph_executor/graph_executor_factory.cc | 4 ++ .../test_runtime_module_based_interface.py | 45 ++++++++++--------- 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/docs/dev/debugger.rst b/docs/dev/debugger.rst index 38172a2189e0..4a612cac37be 100644 --- a/docs/dev/debugger.rst +++ b/docs/dev/debugger.rst @@ -123,12 +123,12 @@ Example of loading the parameters How to use Debugger? *************************************** -1. In ``config.cmake`` set the ``USE_GRAPH_EXECUTOR_DEBUG`` flag to ``ON`` +1. In ``config.cmake`` set the ``USE_PROFILER`` flag to ``ON`` :: # Whether enable additional graph debug functions - set(USE_GRAPH_EXECUTOR_DEBUG ON) + set(USE_PROFILER ON) 2. Do 'make' tvm, so that it will make the ``libtvm_runtime.so`` @@ -148,6 +148,22 @@ How to use Debugger? m.run() tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).numpy() +4. If network previously was exported to external libray using ``lib.export_library("network.so")`` + like shared object file/dynamic linked library, the initialization + of debug runtime will be slightly different + +:: + + lib = tvm.runtime.load_module("network.so") + m = graph_executor.create(lib["get_graph_json"](), lib, dev, dump_root="/tmp/tvmdbg") + # set inputs + m.set_input('data', tvm.nd.array(data.astype(dtype))) + m.set_input(**params) + # execute + m.run() + tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).numpy() + + The outputs are dumped to a temporary folder in ``/tmp`` folder or the folder specified while creating the runtime. diff --git a/src/runtime/graph_executor/graph_executor_factory.cc b/src/runtime/graph_executor/graph_executor_factory.cc index a13fbd860d43..42fc28b60df2 100644 --- a/src/runtime/graph_executor/graph_executor_factory.cc +++ b/src/runtime/graph_executor/graph_executor_factory.cc @@ -53,6 +53,10 @@ PackedFunc GraphExecutorFactory::GetFunction( } *rv = this->ExecutorCreate(devices); }); + } else if (name == "get_graph_json") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_json_; }); + } else if (name == "debug_create") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_GE(args.size(), 2); diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index e984979ac14f..2abbcef29283 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -46,10 +46,8 @@ def verify(data): return out +@tvm.testing.requires_llvm def test_legacy_compatibility(): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) @@ -63,10 +61,8 @@ def test_legacy_compatibility(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.requires_llvm def test_cpu(): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -90,6 +86,23 @@ def test_cpu(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +@tvm.testing.requires_llvm +def test_cpu_get_graph_json(): + mod, params = relay.testing.synthetic.get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + from tvm.contrib import utils + + temp = utils.tempdir() + file_name = "deploy_lib.so" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + json = loaded_lib["get_graph_json"]() + assert isinstance(json, str) == True + assert json.find("tvmgen_default_fused_nn_softmax1") > -1 + + @tvm.testing.requires_cuda @tvm.testing.requires_gpu def test_gpu(): @@ -120,9 +133,6 @@ def test_gpu(): @tvm.testing.uses_gpu def test_mod_export(): def verify_cpu_export(obj_format): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -210,10 +220,8 @@ def setup_gmod(): out = gmod.get_output(0).numpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + @tvm.testing.requires_llvm def verify_rpc_cpu_export(obj_format): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -308,12 +316,10 @@ def check_remote(server): verify_rpc_gpu_export(obj_format) +@tvm.testing.requires_llvm @tvm.testing.uses_gpu def test_remove_package_params(): def verify_cpu_remove_package_params(obj_format): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -404,10 +410,8 @@ def verify_gpu_remove_package_params(obj_format): out = gmod.get_output(0).numpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + @tvm.testing.requires_llvm def verify_rpc_cpu_remove_package_params(obj_format): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -517,10 +521,8 @@ def verify_rpc_gpu_remove_package_params(obj_format): verify_rpc_gpu_remove_package_params(obj_format) +@tvm.testing.requires_llvm def test_debug_graph_executor(): - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -619,3 +621,4 @@ def make_module(mod): test_remove_package_params() test_debug_graph_executor() test_multiple_imported_modules() + test_cpu_get_graph_json() From 7deebc646c7c013f1bb0a71f88df0b70af11ffd0 Mon Sep 17 00:00:00 2001 From: wangxiang2713 <49302617+wangxiang2713@users.noreply.github.com> Date: Fri, 3 Sep 2021 02:04:01 +0800 Subject: [PATCH 11/12] [BUG] DataType Bug In SplitRel (#8899) * [BUG] DataType Bug In SplitRel * Add Test Case --- src/relay/op/tensor/transform.cc | 11 +++++++---- tests/python/relay/test_op_level3.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9f9ed1c075cd..3781107eeee1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2829,15 +2829,18 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, } reporter->Assign(types[1], TupleType(Array(fields))); } else { - auto indices = Downcast>(param->indices_or_sections); + Array indices; + for (auto i : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; for (unsigned int i = 0; i < indices.size(); ++i) { - ICHECK(reporter->Assert(Downcast(indices[i]) > begin)) + ICHECK(reporter->Assert(indices[i] > begin)) << "indices_or_sections need to be a sorted ascending list"; std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = Downcast(indices[i]) - begin; - begin = Downcast(indices[i]); + oshape[axis] = indices[i] - begin; + begin = indices[i]; auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 41a866a0a034..e0b95fe7fbf7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -516,6 +516,21 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): ), axis=1, ) + verify_split( + (d1, d2, d3, d4), + tuple(np.array([2, 4, 7]).astype(np.int64)), + relay.ty.TupleType( + tvm.runtime.convert( + [ + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 3, d3, d4), "float32"), + relay.ty.TensorType((d1, (d2 - 7), d3, d4), "float32"), + ] + ) + ), + axis=1, + ) def test_full_infer_type(): From aac0754ed639358af2ad533159745a7b4ff416d6 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Thu, 2 Sep 2021 14:10:28 -0500 Subject: [PATCH 12/12] [UnitTests][Contrib] Enable contrib tensorrt/coreml unit tests (#8902) * [UnitTests][CoreML] Marked test_annotate as a known failure. The unit tests in `test_coreml_codegen.py` haven't run in the CI lately, so this test wasn't caught before. (See tracking issue - Added `pytest.mark.xfail` mark to `test_annotate`. - Added `tvm.testing.requires_package` decorator, which can mark tests as requiring a specific python package to be available. Switched from `pytest.importorskip('coremltools')` to `requires_package('coremltools')` in `test_coreml_codegen.py` so that all tests would explicitly show up as skipped in the report. - Added `uses_gpu` tag to all tests in `test_coreml_codegen.py`, since only ci_gpu has coremltools installed. In the future, if the ci_cpu image has coremltools installed, this mark can be removed. * [Pytest][TensorRT] Mark the TensorRT tests with tvm.testing.requires_cuda Previously, the tests had an early bailout if tensorrt was disabled, or if there was no cuda device present. However, the tests were not marked with `pytest.mark.gpu` and so they didn't run during `task_python_integration_gpuonly.sh`. This commit adds the `requires_cuda` mark, and maintains the same behavior of testing the tensorrt compilation steps if compilation is enabled, and running the results if tensorrt is enabled. In addition, some of the tests result in failures when run. These have been marked with `pytest.mark.xfail`, and are being tracked in issue #8901. --- python/tvm/testing/__init__.py | 1 + python/tvm/testing/utils.py | 44 ++ tests/python/contrib/test_coreml_codegen.py | 28 +- tests/python/contrib/test_tensorrt.py | 523 +++++++++++--------- 4 files changed, 357 insertions(+), 239 deletions(-) diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index f610c6ecc0db..56a435ea3887 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -24,6 +24,7 @@ from .utils import known_failing_targets, requires_cuda, requires_cudagraph from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl +from .utils import requires_package from .utils import identity_after, terminate_self from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 6f115f8da58c..85a8b7738184 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -774,7 +774,51 @@ def requires_rpc(*args): return _compose(args, _requires_rpc) +def requires_package(*packages): + """Mark a test as requiring python packages to run. + + If the packages listed are not available, tests marked with + `requires_package` will appear in the pytest results as being skipped. + This is equivalent to using ``foo = pytest.importorskip('foo')`` inside + the test body. + + Parameters + ---------- + packages : List[str] + + The python packages that should be available for the test to + run. + + Returns + ------- + mark: pytest mark + + The pytest mark to be applied to unit tests that require this + + """ + + def has_package(package): + try: + __import__(package) + return True + except ImportError: + return False + + marks = [ + pytest.mark.skipif(not has_package(package), reason=f"Cannot import '{package}'") + for package in packages + ] + + def wrapper(func): + for mark in marks: + func = mark(func) + return func + + return wrapper + + def parametrize_targets(*args): + """Parametrize a test over a specific set of targets. Use this decorator when you want your test to be run over a diff --git a/tests/python/contrib/test_coreml_codegen.py b/tests/python/contrib/test_coreml_codegen.py index 7e678a790f4a..2edfafaa0bd8 100644 --- a/tests/python/contrib/test_coreml_codegen.py +++ b/tests/python/contrib/test_coreml_codegen.py @@ -19,11 +19,12 @@ from unittest import mock import tvm +import tvm.testing from tvm import relay from tvm.relay import transform from tvm.contrib.target import coreml as _coreml -pytest.importorskip("coremltools") +requires_coremltools = tvm.testing.requires_package("coremltools") def _has_xcode(): @@ -88,6 +89,11 @@ def _create_graph_annotated(): return mod +@pytest.mark.xfail( + reason="Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901" +) +@tvm.testing.uses_gpu +@requires_coremltools def test_annotate(): mod = _create_graph() mod = transform.AnnotateTarget("coremlcompiler")(mod) @@ -98,6 +104,8 @@ def test_annotate(): @pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available") +@tvm.testing.uses_gpu +@requires_coremltools def test_compile_and_run(): dev = tvm.cpu() target = "llvm" @@ -136,6 +144,8 @@ def _construct_model(func, m1, m2): fcompile(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_add(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -144,6 +154,8 @@ def test_add(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_multiply(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -152,6 +164,8 @@ def test_multiply(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_clip(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -160,6 +174,8 @@ def test_clip(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_batch_flatten(): shape = (10, 10, 10) x = relay.var("x", shape=shape) @@ -168,6 +184,8 @@ def test_batch_flatten(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_expand_dims(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -180,6 +198,8 @@ def test_expand_dims(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_relu(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -188,6 +208,8 @@ def test_relu(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_softmax(): shape = (10, 10) x = relay.var("x", shape=shape) @@ -196,6 +218,8 @@ def test_softmax(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_conv2d(): x = relay.var("x", shape=(1, 3, 224, 224)) w = relay.const(np.zeros((16, 3, 3, 3), dtype="float32")) @@ -204,6 +228,8 @@ def test_conv2d(): _construct_model(func) +@tvm.testing.uses_gpu +@requires_coremltools def test_global_avg_pool2d(): shape = (10, 10, 10, 10) x = relay.var("x", shape=shape) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index f40b3368dc85..ec512d7d714f 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -32,26 +32,23 @@ from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt +import tvm.testing -def skip_codegen_test(): - """Skip test if TensorRT and CUDA codegen are not present""" - if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist: - print("Skip because CUDA is not enabled.") - return True - if not tvm.get_global_func("relay.ext.tensorrt", True): - print("Skip because TensorRT codegen is not available.") - return True - return False +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) -def skip_runtime_test(): - if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist: - print("Skip because CUDA is not enabled.") - return True - if not tensorrt.is_tensorrt_runtime_enabled(): - print("Skip because TensorRT runtime is not available.") - return True - return False +run_module = tvm.testing.parameter( + pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda()]), + pytest.param( + True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda()] + ), + ids=["compile", "run"], +) def vmobj_to_list(o): @@ -79,7 +76,7 @@ def set_func_attr(func, compile_name, symbol_name): return func -def run_and_verify_func(config, target="cuda"): +def run_and_verify_func(config, target="cuda", run_module=True): """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. Parameters @@ -87,9 +84,11 @@ def run_and_verify_func(config, target="cuda"): config : Tuple[relay.Function, Dict[str, NDArray], List[str]] A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and 3) A list of which vars should be considered params. + + run_module: bool + + If True, the built module will be run after being compiled. """ - if skip_codegen_test(): - return f, input_shapes, is_param = config params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param} input_dict = { @@ -118,17 +117,14 @@ def run_and_verify_func(config, target="cuda"): func = relay.create_executor( mode, mod=mod, device=dev, target=target ).evaluate() - if not skip_runtime_test(): + if run_module: result_dict[result_key] = func(**input_dict, **params) - if not skip_runtime_test(): + if run_module: assert_result_dict_holds(result_dict) -def run_and_verify_model(model): - if skip_codegen_test(): - return - +def run_and_verify_model(model, run_module): import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model @@ -156,7 +152,7 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True): mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() - res = func(i_data, **params) if not skip_runtime_test() else None + res = func(i_data, **params) if run_module else None return res dtype = "float32" @@ -173,13 +169,11 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True): mod, params, i_data, mode=mode, use_trt=use_trt ) - if not skip_runtime_test(): + if run_module: assert_result_dict_holds(result_dict) -def test_tensorrt_simple(): - if skip_codegen_test(): - return +def test_tensorrt_simple(run_module): dtype = "float32" xshape = (1, 3, 2, 2) yshape = (1, 3, 1, 1) @@ -214,14 +208,14 @@ def test_tensorrt_simple(): func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() - if not skip_runtime_test(): + if run_module: result_dict[result_key] = func(x_data, y_data, z_data) - if not skip_runtime_test(): + if run_module: assert_result_dict_holds(result_dict) -def test_tensorrt_simple_cpu_io(): +def test_tensorrt_simple_cpu_io(run_module): def get_graph(): dtype = "float32" x_shape = (1, 3, 2, 2) @@ -235,12 +229,10 @@ def get_graph(): f = relay.Function([x, y, z], out) return f, {"x": x_shape, "y": y_shape, "z": z_shape}, ["y"] - run_and_verify_func(get_graph(), target="llvm") + run_and_verify_func(get_graph(), target="llvm", run_module=run_module) -def test_tensorrt_not_compatible(): - if skip_codegen_test(): - return +def test_tensorrt_not_compatible(run_module): dtype = "float32" xshape = (1, 32, 14, 14) x_data = np.random.uniform(-1, 1, xshape).astype(dtype) @@ -258,13 +250,11 @@ def test_tensorrt_not_compatible(): func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() - if not skip_runtime_test(): + if run_module: results = func(x_data) -def test_tensorrt_serialize_graph_executor(): - if skip_codegen_test(): - return +def test_tensorrt_serialize_graph_executor(run_module): import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model @@ -311,16 +301,14 @@ def load_graph(): save_graph(graph, lib, graph_params) loaded_graph, loaded_lib, loaded_params = load_graph() - if not skip_runtime_test(): + if run_module: result_dict = dict() result_dict["graph"] = run_graph(graph, lib, graph_params) result_dict["graph_ref"] = run_graph(loaded_graph, loaded_lib, loaded_params) assert_result_dict_holds(result_dict) -def test_tensorrt_serialize_vm(): - if skip_codegen_test(): - return +def test_tensorrt_serialize_vm(run_module): import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model @@ -360,14 +348,14 @@ def load_vm(): save_vm(code_vm, lib_vm) loaded_lib_vm, loaded_code_vm = load_vm() - if not skip_runtime_test(): + if run_module: result_dict = dict() result_dict["vm"] = run_vm(code_vm, lib_vm) result_dict["vm_ref"] = run_vm(loaded_code_vm, loaded_lib_vm) assert_result_dict_holds(result_dict) -def test_conv2d(): +def test_conv2d(run_module): def get_graph( x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), @@ -403,15 +391,17 @@ def get_graph( padding=padding, strides=strides, dilation=dilation, - ) + ), + run_module=run_module, ) run_and_verify_func( - get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24) + get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24), + run_module=run_module, ) - run_and_verify_func(get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1)) + run_and_verify_func(get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1), run_module=run_module) -def test_conv2d_nhwc(): +def test_conv2d_nhwc(run_module): def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -426,10 +416,10 @@ def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)): f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_conv2d_weights_const(): +def test_conv2d_weights_const(run_module): def get_graph( x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), @@ -453,10 +443,10 @@ def get_graph( f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_conv2d_weights_transposed(): +def test_conv2d_weights_transposed(run_module): def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -467,10 +457,10 @@ def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_dense(): +def test_dense(run_module): def get_graph(x_shape=(1, 16), k_shape=(32, 16)): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -479,11 +469,11 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)): f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph()) - run_and_verify_func(get_graph(k_shape=(1, 16))) + run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph(k_shape=(1, 16)), run_module=run_module) -def test_batch_matmul(): +def test_batch_matmul(run_module): def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True): x = relay.var("x", shape=(x_shape), dtype="float32") y = relay.var("y", shape=(y_shape), dtype="float32") @@ -492,20 +482,24 @@ def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb return f, {"x": x_shape, "y": y_shape}, [] run_and_verify_func( - get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True) + get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True), + run_module=run_module, ) run_and_verify_func( - get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False) + get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False), + run_module=run_module, ) run_and_verify_func( - get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True) + get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True), + run_module=run_module, ) run_and_verify_func( - get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False) + get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False), + run_module=run_module, ) -def test_bias_add(): +def test_bias_add(run_module): def get_graph(x_shape=(1, 16), channels=16): x = relay.var("x", shape=(x_shape), dtype="float32") bias = relay.var("bias", shape=(channels,), dtype="float32") @@ -513,11 +507,11 @@ def get_graph(x_shape=(1, 16), channels=16): f = relay.Function([x, bias], out) return f, {"x": x_shape, "bias": (channels,)}, ["bias"] - run_and_verify_func(get_graph()) - run_and_verify_func(get_graph((1, 6, 3, 4), 6)) + run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph((1, 6, 3, 4), 6), run_module=run_module) -def test_pool2d(): +def test_pool2d(run_module): def get_graph( op, x_shape=(1, 3, 32, 32), @@ -567,7 +561,8 @@ def get_graph( padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad, - ) + ), + run_module=run_module, ) run_and_verify_func( get_graph( @@ -576,53 +571,54 @@ def get_graph( strides=strides, padding=padding, ceil_mode=ceil_mode, - ) + ), + run_module=run_module, ) -def test_global_pool2d(): +def test_global_pool2d(run_module): def get_graph(op, x_shape=(1, 3, 32, 32)): x = relay.var("x", shape=(x_shape), dtype="float32") out = op(x) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(relay.nn.global_max_pool2d)) - run_and_verify_func(get_graph(relay.nn.global_avg_pool2d)) + run_and_verify_func(get_graph(relay.nn.global_max_pool2d), run_module=run_module) + run_and_verify_func(get_graph(relay.nn.global_avg_pool2d), run_module=run_module) -def test_batch_flatten(): +def test_batch_flatten(run_module): def get_graph(x_shape=(1, 3, 4, 6)): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.nn.batch_flatten(x) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_expand_dims(): +def test_expand_dims(run_module): def get_graph(x_shape=(1, 3), axis=1, num_newaxis=1): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.expand_dims(x, axis, num_newaxis) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_squeeze(): +def test_squeeze(run_module): def get_graph(x_shape, axis): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.squeeze(x, axis=axis) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 5, 1, 1), (2, 3))) - run_and_verify_func(get_graph((1, 3, 1), (-1,))) + run_and_verify_func(get_graph((1, 5, 1, 1), (2, 3)), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 1), (-1,)), run_module=run_module) -def test_concatenate(): +def test_concatenate(run_module): def get_graph(input_shapes, axis): concat_inputs = [] shapes_dict = {} @@ -634,23 +630,25 @@ def get_graph(input_shapes, axis): f = relay.Function(concat_inputs, out) return f, shapes_dict, [] - run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1)) + run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1), run_module=run_module) -def test_split(): +def test_split(run_module): def get_graph(x_shape, indices_or_sections, axis): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.split(x, indices_or_sections=indices_or_sections, axis=axis) f = relay.Function([x], out.astuple()) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 16), indices_or_sections=2, axis=1)) - run_and_verify_func(get_graph((1, 16), indices_or_sections=4, axis=1)) - run_and_verify_func(get_graph((1, 16), indices_or_sections=[8], axis=1)) - run_and_verify_func(get_graph((1, 16), indices_or_sections=[2, 3, 6, 10, 14], axis=1)) + run_and_verify_func(get_graph((1, 16), indices_or_sections=2, axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 16), indices_or_sections=4, axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 16), indices_or_sections=[8], axis=1), run_module=run_module) + run_and_verify_func( + get_graph((1, 16), indices_or_sections=[2, 3, 6, 10, 14], axis=1), run_module=run_module + ) -def test_conv2d_transpose(): +def test_conv2d_transpose(run_module): def get_graph( x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), @@ -674,19 +672,19 @@ def get_graph( for padding in [(0, 0), (1, 1)]: for strides in [(1, 1), (2, 2)]: - run_and_verify_func(get_graph(padding=padding, strides=strides)) + run_and_verify_func(get_graph(padding=padding, strides=strides), run_module=run_module) -def test_reshape(): +def test_reshape(run_module): def get_graph(x_shape, new_shape): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.reshape(x, new_shape) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 1, 1, 10), (-1, 10))) - run_and_verify_func(get_graph((1, 10, 2, 3), (1, -1))) - run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6))) + run_and_verify_func(get_graph((1, 1, 1, 10), (-1, 10)), run_module=run_module) + run_and_verify_func(get_graph((1, 10, 2, 3), (1, -1)), run_module=run_module) + run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)), run_module=run_module) class AreOpsOnGraph(ExprVisitor): @@ -732,10 +730,10 @@ def are_ops_on_trt(mod, op_list): return True -def test_dynamic_reshape(): - if skip_codegen_test(): - return - +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_dynamic_reshape(run_module): def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): result_arr = [{} for _ in range(len(x_data_list))] for use_trt in [True, False]: @@ -749,7 +747,7 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): mod, params={}, remove_no_mac_subgraphs=False ) assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt - if not skip_runtime_test(): + if run_module: with relay.build_config(opt_level=3): func = relay.create_executor( "vm", mod=mod, device=tvm.cpu(0), target="llvm" @@ -758,7 +756,7 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): for i, x_data in enumerate(x_data_list): result_arr[i][use_trt] = func(x_data) - if not skip_runtime_test(): + if run_module: for i in range(len(x_data_list)): assert_result_dict_holds(result_arr[i]) @@ -791,18 +789,18 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): test_run(x_data_list, x_shape, new_shape, should_offload_to_trt) -def test_transpose(): +def test_transpose(run_module): def get_graph(x_shape, order): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.transpose(x, order) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 16, 7, 7), [0, 2, 3, 1])) - run_and_verify_func(get_graph((1, 7, 7, 16), [0, 3, 1, 2])) + run_and_verify_func(get_graph((1, 16, 7, 7), [0, 2, 3, 1]), run_module=run_module) + run_and_verify_func(get_graph((1, 7, 7, 16), [0, 3, 1, 2]), run_module=run_module) -def test_float_const(): +def test_float_const(run_module): def get_graph(x_shape=(1, 16)): x = relay.var("x", shape=(x_shape), dtype="float32") beta = relay.const(1, dtype="float32") @@ -810,36 +808,45 @@ def get_graph(x_shape=(1, 16)): f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_pad(): +def test_pad(run_module): def get_graph(x_shape, pad_width): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.nn.pad(x, pad_width=pad_width) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]])) - run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]])) - run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]])) - run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]])) + run_and_verify_func( + get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]]), run_module=run_module + ) + run_and_verify_func( + get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]]), run_module=run_module + ) + run_and_verify_func( + get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]]), run_module=run_module + ) + run_and_verify_func( + get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]), + run_module=run_module, + ) -def test_softmax(): +def test_softmax(run_module): def get_graph(x_shape, axis): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.nn.softmax(x, axis=axis) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 1000), axis=1)) - run_and_verify_func(get_graph((1, 1000), axis=-1)) - run_and_verify_func(get_graph((1, 3, 4), axis=-2)) - run_and_verify_func(get_graph((1, 3, 4), axis=1)) + run_and_verify_func(get_graph((1, 1000), axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module) -def test_batch_norm(): +def test_batch_norm(run_module): def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): x = relay.var("x", shape=(x_shape), dtype="float32") beta = relay.var("beta", shape=(param_shape), dtype="float32") @@ -870,17 +877,19 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): ["beta", "gamma", "moving_mean", "moving_var"], ) - run_and_verify_func(get_graph((1, 64, 56, 56), (64,))) - run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05)) - run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2)) - run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1)) - run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2)) - run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4)) - run_and_verify_func(get_graph((1, 8), (8,), axis=1)) - run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2)) + run_and_verify_func(get_graph((1, 64, 56, 56), (64,)), run_module=run_module) + run_and_verify_func( + get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05), run_module=run_module + ) + run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2), run_module=run_module) + run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2), run_module=run_module) + run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4), run_module=run_module) + run_and_verify_func(get_graph((1, 8), (8,), axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2), run_module=run_module) -def test_layer_norm(): +def test_layer_norm(run_module): def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): x = relay.var("x", shape=(x_shape), dtype="float32") gamma = relay.var("gamma", shape=(param_shape), dtype="float32") @@ -905,12 +914,14 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): ["beta", "gamma"], ) - run_and_verify_func(get_graph((1, 32, 8, 8), (32,))) - run_and_verify_func(get_graph((1, 8, 8, 32), (32,), axis=3, epsilon=1.001e-05)) - run_and_verify_func(get_graph((1, 8), (8,), axis=1)) + run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module) + run_and_verify_func( + get_graph((1, 8, 8, 32), (32,), axis=3, epsilon=1.001e-05), run_module=run_module + ) + run_and_verify_func(get_graph((1, 8), (8,), axis=1), run_module=run_module) -def test_unary(): +def test_unary(run_module): def get_graph(op, x_shape=(1, 8, 3, 3)): x = relay.var("x", shape=(x_shape), dtype="float32") out = op(x) @@ -933,30 +944,30 @@ def get_graph(op, x_shape=(1, 8, 3, 3)): relay.floor, relay.erf, ]: - run_and_verify_func(get_graph(op)) + run_and_verify_func(get_graph(op), run_module=run_module) -def test_clip(): +def test_clip(run_module): def get_graph(x_shape=(1, 8, 3, 3)): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.clip(x, a_min=-0.2, a_max=0.4) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_leaky_relu(): +def test_leaky_relu(run_module): def get_graph(x_shape=(1, 8, 3, 3)): x = relay.var("x", shape=(x_shape), dtype="float32") out = relay.nn.leaky_relu(x, alpha=0.1) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_binary(): +def test_binary(run_module): def get_graph(op, x_shape, y_shape, y_is_const=False): x = relay.var("x", shape=(x_shape), dtype="float32") if y_is_const: @@ -971,14 +982,20 @@ def get_graph(op, x_shape, y_shape, y_is_const=False): for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]: for y_is_const in [True, False]: - run_and_verify_func(get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const)) - run_and_verify_func(get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const)) - run_and_verify_func(get_graph(op, (1, 10), (10,), y_is_const)) - run_and_verify_func(get_graph(op, (1, 1, 1, 10), (10,), y_is_const)) - run_and_verify_func(get_graph(op, (1, 1, 1), (3,), y_is_const)) + run_and_verify_func( + get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const), run_module=run_module + ) + run_and_verify_func( + get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const), run_module=run_module + ) + run_and_verify_func(get_graph(op, (1, 10), (10,), y_is_const), run_module=run_module) + run_and_verify_func( + get_graph(op, (1, 1, 1, 10), (10,), y_is_const), run_module=run_module + ) + run_and_verify_func(get_graph(op, (1, 1, 1), (3,), y_is_const), run_module=run_module) -def test_reduce(): +def test_reduce(run_module): def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False): x = relay.var("x", shape=(x_shape), dtype="float32") out = op(x, axis=axis, keepdims=keepdims) @@ -987,13 +1004,19 @@ def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False): for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]: for keepdims in [True, False]: - run_and_verify_func(get_graph(op, axis=(1), keepdims=keepdims)) - run_and_verify_func(get_graph(op, axis=(2, 3), keepdims=keepdims)) - run_and_verify_func(get_graph(op, axis=(1, 2), keepdims=keepdims)) - run_and_verify_func(get_graph(op, axis=(1, 2, 3), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(1), keepdims=keepdims), run_module=run_module) + run_and_verify_func( + get_graph(op, axis=(2, 3), keepdims=keepdims), run_module=run_module + ) + run_and_verify_func( + get_graph(op, axis=(1, 2), keepdims=keepdims), run_module=run_module + ) + run_and_verify_func( + get_graph(op, axis=(1, 2, 3), keepdims=keepdims), run_module=run_module + ) -def test_strided_slice(): +def test_strided_slice(run_module): def get_graph(x_shape, begin, end, strides=None, slice_mode="size"): x = relay.var("x", shape=(x_shape), dtype="float32") if strides: @@ -1016,32 +1039,38 @@ def get_graph(x_shape, begin, end, strides=None, slice_mode="size"): for slice_mode in ["size", "end"]: run_and_verify_func( - get_graph((1, 3, 6, 7), (0, 0, 0, 0), (1, 1, 6, 7), slice_mode=slice_mode) + get_graph((1, 3, 6, 7), (0, 0, 0, 0), (1, 1, 6, 7), slice_mode=slice_mode), + run_module=run_module, + ) + run_and_verify_func( + get_graph((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6], slice_mode=slice_mode), + run_module=run_module, ) run_and_verify_func( - get_graph((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6], slice_mode=slice_mode) + get_graph((2, 3, 6, 7), [0, 0, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode), + run_module=run_module, ) run_and_verify_func( - get_graph((2, 3, 6, 7), [0, 0, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode) + get_graph((2, 3, 6, 7), [0, 1, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode), + run_module=run_module, ) run_and_verify_func( - get_graph((2, 3, 6, 7), [0, 1, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode) + get_graph((1, 6), [0, 1], [1, 3], slice_mode=slice_mode), run_module=run_module ) - run_and_verify_func(get_graph((1, 6), [0, 1], [1, 3], slice_mode=slice_mode)) -def test_adaptive_pool2d(): +def test_adaptive_pool2d(run_module): def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1)): x = relay.var("x", shape=(x_shape), dtype="float32") out = op(x, out_size) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(relay.nn.adaptive_max_pool2d)) - run_and_verify_func(get_graph(relay.nn.adaptive_avg_pool2d)) + run_and_verify_func(get_graph(relay.nn.adaptive_max_pool2d), run_module=run_module) + run_and_verify_func(get_graph(relay.nn.adaptive_avg_pool2d), run_module=run_module) -def test_multiple_outputs(): +def test_multiple_outputs(run_module): def get_graph(): x = relay.var("x", shape=(1, 3), dtype="float32") y = relay.var("y", shape=(1, 3), dtype="float32") @@ -1051,10 +1080,10 @@ def get_graph(): f = relay.Function([x, y], out) return f, {"x": (1, 3), "y": (1, 3)}, [] - run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(), run_module=run_module) -def test_conv3d(): +def test_conv3d(run_module): def get_graph( x_shape=(1, 32, 8, 8, 8), k_shape=(16, 32, 3, 3, 3), @@ -1078,11 +1107,11 @@ def get_graph( f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph()) - run_and_verify_func(get_graph(padding=(0, 0, 0, 1, 1, 1))) + run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph(padding=(0, 0, 0, 1, 1, 1)), run_module=run_module) -def test_pool3d(): +def test_pool3d(run_module): def get_graph( op, x_shape=(1, 3, 8, 32, 32), @@ -1113,13 +1142,15 @@ def get_graph( f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(relay.nn.avg_pool3d)) - run_and_verify_func(get_graph(relay.nn.max_pool3d)) - run_and_verify_func(get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1))) - run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1))) + run_and_verify_func(get_graph(relay.nn.avg_pool3d), run_module=run_module) + run_and_verify_func(get_graph(relay.nn.max_pool3d), run_module=run_module) + run_and_verify_func( + get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)), run_module=run_module + ) + run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module) -def test_conv3d_transpose(): +def test_conv3d_transpose(run_module): def get_graph( x_shape=(1, 32, 8, 8, 8), k_shape=(32, 16, 3, 3, 3), @@ -1143,43 +1174,74 @@ def get_graph( f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph()) - run_and_verify_func(get_graph(strides=(2, 2, 2))) - run_and_verify_func(get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1))) + run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph(strides=(2, 2, 2)), run_module=run_module) + run_and_verify_func( + get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1)), run_module=run_module + ) -def test_alexnet(): - run_and_verify_model("alexnet") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_alexnet(run_module): + run_and_verify_model("alexnet", run_module) -def test_resnet18_v1(): - run_and_verify_model("resnet18_v1") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_resnet18_v1(run_module): + run_and_verify_model("resnet18_v1", run_module) -def test_resnet18_v2(): - run_and_verify_model("resnet18_v2") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_resnet18_v2(run_module): + run_and_verify_model("resnet18_v2", run_module) -def test_squeezenet(): - run_and_verify_model("squeezenet1.0") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_squeezenet(run_module): + run_and_verify_model("squeezenet1.0", run_module) -def test_mobilenet(): - run_and_verify_model("mobilenet0.25") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_mobilenet(run_module): + run_and_verify_model("mobilenet0.25", run_module) -def test_mobilenet_v2(): - run_and_verify_model("mobilenetv2_0.25") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_mobilenet_v2(run_module): + run_and_verify_model("mobilenetv2_0.25", run_module) -def test_vgg11(): - run_and_verify_model("vgg11") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_vgg11(run_module): + run_and_verify_model("vgg11", run_module) -def test_densenet121(): - run_and_verify_model("densenet121") +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +def test_densenet121(run_module): + run_and_verify_model("densenet121", run_module) +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) +@has_tensorrt_codegen +@tvm.testing.requires_cuda def test_dynamic_offload(): """ This test checks for proper dynamic offloading of relay graphs. An addition between @@ -1188,9 +1250,6 @@ def test_dynamic_offload(): offload the conv2d with dynamic arg to TVM while running the other in TRT. """ - if skip_codegen_test(): - return - data_shape = (1, 32, 8, 8) k_shape = (1, 32, 3, 3) @@ -1235,10 +1294,7 @@ def get_expected(): tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True) -def test_tensorrt_dynamic_batch(): - if skip_codegen_test(): - return - +def test_tensorrt_dynamic_batch(run_module): batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2] x_shape = (relay.Any(), 1, 8, 8) x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") @@ -1252,7 +1308,7 @@ def test_tensorrt_dynamic_batch(): if use_trt: mod, _ = tensorrt.partition_for_tensorrt(mod) - if not skip_runtime_test(): + if run_module: with relay.build_config(opt_level=3): func = relay.create_executor( "vm", mod=mod, device=tvm.cpu(0), target="llvm" @@ -1260,14 +1316,12 @@ def test_tensorrt_dynamic_batch(): for i, batch_size in enumerate(batches_to_test): result_arr[i][use_trt] = func(x_data[:batch_size, ...]) - if not skip_runtime_test(): + if run_module: for i in range(len(batches_to_test)): assert_result_dict_holds(result_arr[i]) -def test_tensorrt_dynamic_batch_conv(): - if skip_codegen_test(): - return +def test_tensorrt_dynamic_batch_conv(run_module): batches_to_test = [1, 5, 1, 0, 2, 3, 0, 1, 3, 2] x_shape = (relay.Any(), 32, 8, 8) x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") @@ -1286,7 +1340,7 @@ def test_tensorrt_dynamic_batch_conv(): mod, config = tensorrt.partition_for_tensorrt( mod, params, use_implicit_batch=use_implicit_batch ) - if not skip_runtime_test(): + if run_module: for target in ["llvm", "cuda"]: with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} @@ -1296,21 +1350,18 @@ def test_tensorrt_dynamic_batch_conv(): ).evaluate() for i, batch_size in enumerate(batches_to_test): result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params) - if not skip_runtime_test(): + if run_module: for i in range(len(batches_to_test)): for target in ["llvm", "cuda"]: assert_result_dict_holds(result_arr[i][target]) -def test_maskrcnn_resnet50() -> None: +def test_maskrcnn_resnet50(run_module) -> None: """ This function tests the working of pytorch maskrcnn with resnet50 as backbone with VM and VM + TRT. Since the order of compiled model outputs is a bit different from original pytorch model, it uses a custom logic for comparison check. """ - if skip_codegen_test(): - return - import torch import torchvision @@ -1387,43 +1438,39 @@ def get_maskrcnn_input(in_size: int) -> np.ndarray: traced_module = get_traced_maskrcnn_model(np_sample_input) vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm") - if skip_runtime_test(): - return - - dev = tvm.cpu() - vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, dev) - vm.set_input("main", **{"input0": np_sample_input}) - tvm_res = vm.run() - - # Descending sort by scores and get the high confidence indices. In this example 9 is chosen, - # because this image has 9 boxes over 0.9 confidence - num_high_confidence_boxes = 9 - tvm_indices = np.argsort(-1 * tvm_res[1].numpy())[:num_high_confidence_boxes] - - with torch.no_grad(): - out = traced_module(torch.Tensor(np_sample_input)) - # Descending sort by scores and get the high confidence indices - pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes] - - tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol] - # Because of certain ops, there are certain minor differences in TVM outputs and PT outputs, - # This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around - # this is to test it on an entire dataset and compare mAP with the original model. - # However, since that is not practically possible on CI, the following compromise is made. - # These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g: - # 0.1 pixel difference of a box in a 300X300 image wont make any change. - for i, tol_val in zip(range(4), tol): - np.testing.assert_allclose( - tvm_res[i].numpy()[tvm_indices], - out[i].numpy()[pt_indices], - rtol=tol_val, - atol=tol_val, - ) + if run_module: + dev = tvm.cpu() + vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, dev) + vm.set_input("main", **{"input0": np_sample_input}) + tvm_res = vm.run() + + # Descending sort by scores and get the high confidence indices. In this example 9 is chosen, + # because this image has 9 boxes over 0.9 confidence + num_high_confidence_boxes = 9 + tvm_indices = np.argsort(-1 * tvm_res[1].numpy())[:num_high_confidence_boxes] + + with torch.no_grad(): + out = traced_module(torch.Tensor(np_sample_input)) + # Descending sort by scores and get the high confidence indices + pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes] + + tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol] + # Because of certain ops, there are certain minor differences in TVM outputs and PT outputs, + # This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around + # this is to test it on an entire dataset and compare mAP with the original model. + # However, since that is not practically possible on CI, the following compromise is made. + # These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g: + # 0.1 pixel difference of a box in a 300X300 image wont make any change. + for i, tol_val in zip(range(4), tol): + np.testing.assert_allclose( + tvm_res[i].numpy()[tvm_indices], + out[i].numpy()[pt_indices], + rtol=tol_val, + atol=tol_val, + ) -def test_empty_subgraph(): - if skip_codegen_test(): - return +def test_empty_subgraph(run_module): x_shape = (1, 3, 5) mod = tvm.IRModule() # Empty tensorrt subgraph. @@ -1446,7 +1493,7 @@ def test_empty_subgraph(): func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() - if not skip_runtime_test(): + if run_module: results = func(x_data)