From 4a20df55a141453cca1b502f0a1d8a7f942316a7 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 1 Feb 2021 00:42:29 +0900 Subject: [PATCH] swap pytorch and tvm import order (#7380) --- .../python/frontend/pytorch/test_object_detection.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 6b7f9be06d99..3c94b0b846d8 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -17,8 +17,6 @@ # pylint: disable=import-self, invalid-name, unused-argument """Test torch vision fasterrcnn and maskrcnn models""" import numpy as np -import torch -import torchvision import cv2 import tvm @@ -33,6 +31,8 @@ ) from tvm.contrib.download import download +import torch +import torchvision in_size = 300 @@ -150,10 +150,10 @@ def compile_and_run_vm(mod, params, data_np, target): after = mod["main"] assert not tvm.ir.structural_equal(after, before) - # before = mod["main"] - # mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn - # after = mod["main"] - # assert not tvm.ir.structural_equal(after, before) + before = mod["main"] + mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")