diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 6b7f9be06d99..3c94b0b846d8 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -17,8 +17,6 @@ # pylint: disable=import-self, invalid-name, unused-argument """Test torch vision fasterrcnn and maskrcnn models""" import numpy as np -import torch -import torchvision import cv2 import tvm @@ -33,6 +31,8 @@ ) from tvm.contrib.download import download +import torch +import torchvision in_size = 300 @@ -150,10 +150,10 @@ def compile_and_run_vm(mod, params, data_np, target): after = mod["main"] assert not tvm.ir.structural_equal(after, before) - # before = mod["main"] - # mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn - # after = mod["main"] - # assert not tvm.ir.structural_equal(after, before) + before = mod["main"] + mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")