Skip to content

Commit

Permalink
test fixed by setting force_surpress=False
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 22, 2020
1 parent 05cc2a0 commit e2446b0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
14 changes: 11 additions & 3 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
# pylint: disable=import-outside-toplevel, unused-argument
""" Common utilities used by PyTorch frontend """
from .. import op
from ..dataflow_pattern import *
from ..dataflow_pattern import (
is_constant,
is_op,
rewrite,
is_tuple,
is_tuple_get_item,
wildcard,
DFPatternCallback,
)


def is_version_greater_than(ver):
Expand Down Expand Up @@ -79,7 +87,7 @@ class NMSRewrite(DFPatternCallback):

def __init__(self):
super().__init__()
# exprs I want to extract
# exprs to extract
self.boxes = wildcard()
self.scores = wildcard()
self.idxs = wildcard()
Expand All @@ -102,7 +110,7 @@ def convert_batched_nms(self, boxes, scores, idxs, iou_thres):
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_thres,
force_suppress=True,
force_suppress=False,
top_k=top_k,
coord_start=2,
score_index=1,
Expand Down
36 changes: 21 additions & 15 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import cv2

import tvm
import tvm.testing

from tvm import relay
from tvm.runtime.vm import VirtualMachine
Expand Down Expand Up @@ -108,7 +109,7 @@ def compile_and_run_vm(mod, params, data_np):
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)
vm.set_input("main", **{input_name: data_np})
return vm.run()
Expand All @@ -123,19 +124,17 @@ def compile_and_run_vm(mod, params, data_np):
pt_res = scripted_model(data)
pt_scores = pt_res[1].detach().numpy().tolist()
tvm_scores = tvm_res[1].asnumpy().tolist()
num_pt_valid_scores = num_tvm_valid_scores = 0

for score in pt_scores:
if score >= score_threshold:
num_pt_valid_scores += 1
else:
break
def count_valid_scores(scores):
num_valid_scores = 0
for score in pt_scores:
if score >= score_threshold:
num_valid_scores += 1
else:
return num_valid_scores

for score in tvm_scores:
if score >= score_threshold:
num_tvm_valid_scores += 1
else:
break
num_pt_valid_scores = count_valid_scores(pt_scores)
num_tvm_valid_scores = count_valid_scores(tvm_scores)

assert num_pt_valid_scores == num_tvm_valid_scores, (
"Output mismatch: Under score threshold {}, Pytorch has {} valid "
Expand All @@ -145,9 +144,16 @@ def compile_and_run_vm(mod, params, data_np):
before = mod["main"]
after = rewrite(NMSRewrite(), before)
# TODO(masahi): Is there a better way to test if the desired rewrite has happened?
assert tvm.ir.structural_equal(after, before)
assert not tvm.ir.structural_equal(after, before)

mod["main"] = after
tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np)

for res1, res2 in zip(tvm_res, tvm_res_after_rewrite):
tvm.testing.assert_allclose(res1, res2)
num_tvm_after_rewrite_valid_scores = count_valid_scores(
tvm_res_after_rewrite[1].asnumpy().tolist()
)
assert num_tvm_valid_scores == num_tvm_after_rewrite_valid_scores

# Results should be equivalent after rewriting
for i, (res1, res2) in enumerate(zip(tvm_res, tvm_res_after_rewrite)):
tvm.testing.assert_allclose(res1.asnumpy(), res2.asnumpy())

0 comments on commit e2446b0

Please sign in to comment.