diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 58f576bef45ea..51473bbbabef7 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -17,7 +17,15 @@ # pylint: disable=import-outside-toplevel, unused-argument """ Common utilities used by PyTorch frontend """ from .. import op -from ..dataflow_pattern import * +from ..dataflow_pattern import ( + is_constant, + is_op, + rewrite, + is_tuple, + is_tuple_get_item, + wildcard, + DFPatternCallback, +) def is_version_greater_than(ver): @@ -79,7 +87,7 @@ class NMSRewrite(DFPatternCallback): def __init__(self): super().__init__() - # exprs I want to extract + # exprs to extract self.boxes = wildcard() self.scores = wildcard() self.idxs = wildcard() @@ -102,7 +110,7 @@ def convert_batched_nms(self, boxes, scores, idxs, iou_thres): indices=indices, max_output_size=max_out_size, iou_threshold=iou_thres, - force_suppress=True, + force_suppress=False, top_k=top_k, coord_start=2, score_index=1, diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index c35bacd42df9a..2476a3e3583c2 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -22,6 +22,7 @@ import cv2 import tvm +import tvm.testing from tvm import relay from tvm.runtime.vm import VirtualMachine @@ -108,7 +109,7 @@ def compile_and_run_vm(mod, params, data_np): with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): vm_exec = relay.vm.compile(mod, target=target, params=params) - ctx = tvm.cpu() + ctx = tvm.context(target, 0) vm = VirtualMachine(vm_exec, ctx) vm.set_input("main", **{input_name: data_np}) return vm.run() @@ -123,19 +124,17 @@ def compile_and_run_vm(mod, params, data_np): pt_res = scripted_model(data) pt_scores = pt_res[1].detach().numpy().tolist() tvm_scores = tvm_res[1].asnumpy().tolist() - num_pt_valid_scores = num_tvm_valid_scores = 0 - for score in pt_scores: - if score >= score_threshold: - num_pt_valid_scores += 1 - else: - break + def count_valid_scores(scores): + num_valid_scores = 0 + for score in pt_scores: + if score >= score_threshold: + num_valid_scores += 1 + else: + return num_valid_scores - for score in tvm_scores: - if score >= score_threshold: - num_tvm_valid_scores += 1 - else: - break + num_pt_valid_scores = count_valid_scores(pt_scores) + num_tvm_valid_scores = count_valid_scores(tvm_scores) assert num_pt_valid_scores == num_tvm_valid_scores, ( "Output mismatch: Under score threshold {}, Pytorch has {} valid " @@ -145,9 +144,16 @@ def compile_and_run_vm(mod, params, data_np): before = mod["main"] after = rewrite(NMSRewrite(), before) # TODO(masahi): Is there a better way to test if the desired rewrite has happened? - assert tvm.ir.structural_equal(after, before) + assert not tvm.ir.structural_equal(after, before) + mod["main"] = after tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np) - for res1, res2 in zip(tvm_res, tvm_res_after_rewrite): - tvm.testing.assert_allclose(res1, res2) + num_tvm_after_rewrite_valid_scores = count_valid_scores( + tvm_res_after_rewrite[1].asnumpy().tolist() + ) + assert num_tvm_valid_scores == num_tvm_after_rewrite_valid_scores + + # Results should be equivalent after rewriting + for i, (res1, res2) in enumerate(zip(tvm_res, tvm_res_after_rewrite)): + tvm.testing.assert_allclose(res1.asnumpy(), res2.asnumpy())