From 7c840affaf357b20cc75f21f69cf726d214e39c0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda <masahi129@gmail.com>
Date: Wed, 15 Sep 2021 13:41:06 +0900
Subject: [PATCH 1/2] [Torch] Add an option to make imported models compatible
 with the Relay text parser

---
 python/tvm/relay/frontend/pytorch.py          |  36 +++-
 tests/python/frontend/pytorch/test_forward.py | 191 +-----------------
 2 files changed, 39 insertions(+), 188 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c13d791cf2e2..818a1a4716fb 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3400,8 +3400,8 @@ def _getattr_attr_name(node):
     return attr_name
 
 
-def _getattr_full_name(getattrs):
-    return ".".join([_getattr_attr_name(node) for node in getattrs])
+def _getattr_full_name(getattrs, sep="."):
+    return sep.join([_getattr_attr_name(node) for node in getattrs])
 
 
 def _get_pytorch_value_type(typ, default_dtype="float32"):
@@ -3657,7 +3657,7 @@ def terminate(users):
     return get_use_chains(root_getattr_node, terminate)
 
 
-def convert_params(graph, state_dict):
+def convert_params(graph, state_dict, use_parser_friendly_name=False):
     """
     Return Relay vars and TVM NDArrays for input parameters
     A chain of prim::GetAttr nodes is processed one at a time
@@ -3668,6 +3668,7 @@ def convert_params(graph, state_dict):
     packed_param_map = {}
     vars_by_name = {}
     seen = set()
+    attr_name_sep = "_" if use_parser_friendly_name else "."
 
     for node in getattr_nodes:
         if _get_output_name(node) in seen:
@@ -3676,7 +3677,7 @@ def convert_params(graph, state_dict):
         for getattrs in get_attr_chains(node):
             seen.update(map(_get_output_name, getattrs))
 
-            full_attr = _getattr_full_name(getattrs)
+            full_attr = _getattr_full_name(getattrs, attr_name_sep)
             full_attr_node_name = _get_output_name(getattrs[-1])
 
             if full_attr.endswith("_packed_params"):  # for quantized models
@@ -3706,7 +3707,13 @@ def get_all_op_names(graph):
     return set(node.kind() for node in nodes)
 
 
-def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dtype="float32"):
+def from_pytorch(
+    script_module,
+    input_infos,
+    custom_convert_map=None,
+    default_dtype="float32",
+    use_parser_friendly_name=False,
+):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
 
@@ -3729,6 +3736,15 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
     custom_convert_map : Dictionary of str to Relay op
         A custom op conversion map in the same format as _convert_map above
 
+    default_type : str
+        The default dtype to use when type information is not provided by PyTorch.
+
+    use_parser_friendly_name : bool
+        When True, replace '.' with `_' in a original parameter name.
+        The Relay text parser treats a variable name followed by a period as a tuple element access,
+        so a variable name like "dense.weight" cannot be correctly parsed.
+        Use this option when you want to run the AnnotateSpans pass on the imported module.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -3758,7 +3774,13 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
     outputs = _get_relay_input_vars(
         graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
     )
-    param_vars, tensors, packed_param_map = convert_params(graph, params)
+
+    if use_parser_friendly_name:
+        new_names = [key.replace(".", "_") for key in params.keys()]
+        params = dict(zip(new_names, params.values()))
+
+    param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name)
+
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     outputs.update(param_vars)
@@ -3778,7 +3800,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
         # ListConstruct kept original python list. Convert to tuple.
         ret = _expr.Tuple(ret)
 
-    # Separate data inputs and parameters to make sure data inputs are always in the beginning.
+    # Separate data inputs and parameters to make sure data inputs come first.
     func_args = []
     data_inputs = []
     for arg in _analysis.free_vars(ret):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 7b1cd8f53e8b..abae931e3ede 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3937,185 +3937,14 @@ def forward(self, x):
     verify_model(Flip(axis=-1), input_data=input)
 
 
+def test_annotate_span():
+    model = torchvision.models.resnet18().eval()
+    inp = torch.randn([1, 3, 224, 224])
+    trace = torch.jit.trace(model, inp).eval()
+    mod, params = relay.frontend.from_pytorch(trace, [('input', inp.shape)],
+                                              use_parser_friendly_name=True)
+    relay.transform.AnnotateSpans()(mod)
+
+
 if __name__ == "__main__":
-    # some structural tests
-    test_forward_traced_function()
-    test_forward_dtypes()
-    test_weight_names()
-    test_duplicate_weight_use()
-
-    # Single operator tests
-    test_forward_pixel_shuffle()
-    test_forward_add()
-    test_forward_subtract()
-    test_forward_multiply()
-    test_forward_matmul()
-    test_forward_rsub()
-    test_forward_onehot()
-    test_forward_embedding()
-    test_forward_reshape()
-    test_forward_reciprocal()
-    test_forward_repeat()
-    test_forward_repeat_interleave()
-    test_forward_squeeze()
-    test_forward_unsqueeze()
-    test_forward_concatenate()
-    test_forward_reduce_sum()
-    test_forward_reduce_prod()
-    test_forward_argmin()
-    test_forward_argmax()
-    test_forward_norm()
-    test_forward_frobenius_norm()
-    test_forward_std()
-    test_forward_variance()
-    test_forward_relu()
-    test_forward_prelu()
-    test_forward_leakyrelu()
-    test_forward_elu()
-    test_forward_celu()
-    test_forward_gelu()
-    test_forward_selu()
-    test_forward_log_sigmoid()
-    test_forward_adaptiveavgpool()
-    test_forward_maxpool2d()
-    test_forward_maxpool1d()
-    test_forward_maxpool3d()
-    test_forward_hardtanh()
-    test_forward_conv()
-    test_forward_conv_transpose()
-    test_forward_threshold()
-    test_forward_contiguous()
-    test_forward_batchnorm()
-    test_forward_instancenorm()
-    test_forward_layernorm()
-    test_forward_groupnorm()
-    test_forward_transpose()
-    test_forward_size()
-    test_forward_view()
-    test_forward_select()
-    test_forward_take()
-    test_forward_topk()
-    test_forward_where()
-    test_forward_addcdiv()
-    test_forward_addcmul()
-    test_forward_true_divide()
-    test_forward_is_floating_point()
-    test_forward_clone()
-    test_forward_softplus()
-    test_forward_softsign()
-    test_forward_logsoftmax()
-    test_forward_sigmoid()
-    test_forward_dense()
-    test_forward_linear()
-    test_forward_avgpool1d()
-    test_forward_avgpool2d()
-    test_forward_avgpool3d()
-    test_forward_dropout()
-    test_forward_slice()
-    test_forward_narrow()
-    test_forward_mean()
-    test_forward_expand()
-    test_forward_pow()
-    test_forward_unary()
-    test_forward_clamp()
-    test_forward_clamp_()
-    test_forward_logical_not()
-    test_forward_bitwise_not()
-    test_forward_bitwise_xor()
-    test_forward_logical_xor()
-    test_forward_isfinite()
-    test_forward_isnan()
-    test_forward_isinf()
-    test_forward_ones()
-    test_forward_ones_like()
-    test_forward_zeros()
-    test_forward_zeros_like()
-    test_forward_full()
-    test_forward_full_like()
-    test_forward_linspace()
-    test_forward_arange()
-    test_forward_mesh_grid()
-    test_forward_chunk()
-    test_forward_split()
-    test_forward_gather()
-    test_upsample()
-    test_forward_upsample3d()
-    test_forward_nms()
-    test_forward_roi_align()
-    test_to()
-    test_flatten()
-    test_type_as()
-    test_forward_functional_pad()
-    test_forward_zero_pad2d()
-    test_forward_constant_pad1d()
-    test_forward_constant_pad2d()
-    test_forward_constant_pad3d()
-    test_forward_reflection_pad1d()
-    test_forward_reflection_pad2d()
-    test_forward_replication_pad1d()
-    test_forward_replication_pad2d()
-    test_forward_replication_pad3d()
-    test_adaptive_pool3d()
-    test_conv3d()
-    test_conv3d_transpose()
-    test_forward_index()
-    test_min_max()
-    test_logsumexp()
-    test_stack()
-    test_stack_dynamic()
-    test_forward_unbind()
-    test_forward_nonzero()
-    test_forward_scatter()
-    test_forward_index_put()
-    test_numel()
-    test_bincount()
-    test_cumsum()
-    test_masked_fill()
-    test_transformer()
-    test_sort()
-    test_argsort()
-    test_logical_and()
-    test_masked_select()
-    test_unique()
-    test_hard_swish()
-    test_hard_sigmoid()
-    test_forward_nll_loss()
-    test_forward_flip()
-
-    # Model tests
-    test_resnet18()
-    test_squeezenet1_0()
-    test_squeezenet1_1()
-    test_densenet121()
-    # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
-    # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
-    # test_inception_v3()
-    test_googlenet()
-    test_mnasnet0_5()
-    test_mobilenet_v2()
-
-    test_custom_conversion_map()
-
-    test_segmentation_models()
-    test_3d_models()
-
-    # Quantization test
-    from qnn_test import test_quantized_imagenet, test_quantized_modules
-
-    test_quantized_modules()
-    test_quantized_imagenet()
-
-    # Test simple conditionals and loop
-    test_control_flow()
-    test_simple_rnn()
-
-    # More complex recurrent models
-    from test_lstm import test_custom_lstm
-
-    test_custom_lstm()
-
-    # Test bert model
-    test_forward_pretrained_bert_base_uncased()
-
-    # Test convert torch script(jit) with specific inputs' types
-    test_convert_torch_script_with_input_types()
+    pytest.main([__file__])

From 80cf35543779b7b266f9d7e57d16a20949c160cc Mon Sep 17 00:00:00 2001
From: Masahiro Masuda <masahi129@gmail.com>
Date: Wed, 15 Sep 2021 13:48:44 +0900
Subject: [PATCH 2/2] py format

---
 python/tvm/relay/frontend/pytorch.py          | 2 +-
 tests/python/frontend/pytorch/test_forward.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 818a1a4716fb..39bcfc68e421 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3742,7 +3742,7 @@ def from_pytorch(
     use_parser_friendly_name : bool
         When True, replace '.' with `_' in a original parameter name.
         The Relay text parser treats a variable name followed by a period as a tuple element access,
-        so a variable name like "dense.weight" cannot be correctly parsed.
+        so a variable name like "dense.weight" cannot be parsed correctly.
         Use this option when you want to run the AnnotateSpans pass on the imported module.
 
     Returns
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index abae931e3ede..c27469edf1d7 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3941,8 +3941,9 @@ def test_annotate_span():
     model = torchvision.models.resnet18().eval()
     inp = torch.randn([1, 3, 224, 224])
     trace = torch.jit.trace(model, inp).eval()
-    mod, params = relay.frontend.from_pytorch(trace, [('input', inp.shape)],
-                                              use_parser_friendly_name=True)
+    mod, params = relay.frontend.from_pytorch(
+        trace, [("input", inp.shape)], use_parser_friendly_name=True
+    )
     relay.transform.AnnotateSpans()(mod)