From 7c840affaf357b20cc75f21f69cf726d214e39c0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda <masahi129@gmail.com> Date: Wed, 15 Sep 2021 13:41:06 +0900 Subject: [PATCH 1/2] [Torch] Add an option to make imported models compatible with the Relay text parser --- python/tvm/relay/frontend/pytorch.py | 36 +++- tests/python/frontend/pytorch/test_forward.py | 191 +----------------- 2 files changed, 39 insertions(+), 188 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c13d791cf2e2..818a1a4716fb 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3400,8 +3400,8 @@ def _getattr_attr_name(node): return attr_name -def _getattr_full_name(getattrs): - return ".".join([_getattr_attr_name(node) for node in getattrs]) +def _getattr_full_name(getattrs, sep="."): + return sep.join([_getattr_attr_name(node) for node in getattrs]) def _get_pytorch_value_type(typ, default_dtype="float32"): @@ -3657,7 +3657,7 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def convert_params(graph, state_dict): +def convert_params(graph, state_dict, use_parser_friendly_name=False): """ Return Relay vars and TVM NDArrays for input parameters A chain of prim::GetAttr nodes is processed one at a time @@ -3668,6 +3668,7 @@ def convert_params(graph, state_dict): packed_param_map = {} vars_by_name = {} seen = set() + attr_name_sep = "_" if use_parser_friendly_name else "." 
for node in getattr_nodes: if _get_output_name(node) in seen: @@ -3676,7 +3677,7 @@ for getattrs in get_attr_chains(node): seen.update(map(_get_output_name, getattrs)) - full_attr = _getattr_full_name(getattrs) + full_attr = _getattr_full_name(getattrs, attr_name_sep) full_attr_node_name = _get_output_name(getattrs[-1]) if full_attr.endswith("_packed_params"): # for quantized models @@ -3706,7 +3707,13 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) -def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dtype="float32"): +def from_pytorch( + script_module, + input_infos, + custom_convert_map=None, + default_dtype="float32", + use_parser_friendly_name=False, +): """Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -3729,6 +3736,15 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt custom_convert_map : Dictionary of str to Relay op A custom op conversion map in the same format as _convert_map above + default_dtype : str + The default dtype to use when type information is not provided by PyTorch. + + use_parser_friendly_name : bool + When True, replace '.' with `_' in a original parameter name. + The Relay text parser treats a variable name followed by a period as a tuple element access, + so a variable name like "dense.weight" cannot be correctly parsed. + Use this option when you want to run the AnnotateSpans pass on the imported module. 
+ Returns ------- mod : tvm.IRModule @@ -3758,7 +3774,13 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt outputs = _get_relay_input_vars( graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module ) - param_vars, tensors, packed_param_map = convert_params(graph, params) + + if use_parser_friendly_name: + new_names = [key.replace(".", "_") for key in params.keys()] + params = dict(zip(new_names, params.values())) + + param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name) + tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} outputs.update(param_vars) @@ -3778,7 +3800,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt # ListConstruct kept original python list. Convert to tuple. ret = _expr.Tuple(ret) - # Separate data inputs and parameters to make sure data inputs are always in the beginning. + # Separate data inputs and parameters to make sure data inputs come first. 
func_args = [] data_inputs = [] for arg in _analysis.free_vars(ret): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 7b1cd8f53e8b..abae931e3ede 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3937,185 +3937,14 @@ def forward(self, x): verify_model(Flip(axis=-1), input_data=input) +def test_annotate_span(): + model = torchvision.models.resnet18().eval() + inp = torch.randn([1, 3, 224, 224]) + trace = torch.jit.trace(model, inp).eval() + mod, params = relay.frontend.from_pytorch(trace, [('input', inp.shape)], + use_parser_friendly_name=True) + relay.transform.AnnotateSpans()(mod) + + if __name__ == "__main__": - # some structural tests - test_forward_traced_function() - test_forward_dtypes() - test_weight_names() - test_duplicate_weight_use() - - # Single operator tests - test_forward_pixel_shuffle() - test_forward_add() - test_forward_subtract() - test_forward_multiply() - test_forward_matmul() - test_forward_rsub() - test_forward_onehot() - test_forward_embedding() - test_forward_reshape() - test_forward_reciprocal() - test_forward_repeat() - test_forward_repeat_interleave() - test_forward_squeeze() - test_forward_unsqueeze() - test_forward_concatenate() - test_forward_reduce_sum() - test_forward_reduce_prod() - test_forward_argmin() - test_forward_argmax() - test_forward_norm() - test_forward_frobenius_norm() - test_forward_std() - test_forward_variance() - test_forward_relu() - test_forward_prelu() - test_forward_leakyrelu() - test_forward_elu() - test_forward_celu() - test_forward_gelu() - test_forward_selu() - test_forward_log_sigmoid() - test_forward_adaptiveavgpool() - test_forward_maxpool2d() - test_forward_maxpool1d() - test_forward_maxpool3d() - test_forward_hardtanh() - test_forward_conv() - test_forward_conv_transpose() - test_forward_threshold() - test_forward_contiguous() - test_forward_batchnorm() - 
test_forward_instancenorm() - test_forward_layernorm() - test_forward_groupnorm() - test_forward_transpose() - test_forward_size() - test_forward_view() - test_forward_select() - test_forward_take() - test_forward_topk() - test_forward_where() - test_forward_addcdiv() - test_forward_addcmul() - test_forward_true_divide() - test_forward_is_floating_point() - test_forward_clone() - test_forward_softplus() - test_forward_softsign() - test_forward_logsoftmax() - test_forward_sigmoid() - test_forward_dense() - test_forward_linear() - test_forward_avgpool1d() - test_forward_avgpool2d() - test_forward_avgpool3d() - test_forward_dropout() - test_forward_slice() - test_forward_narrow() - test_forward_mean() - test_forward_expand() - test_forward_pow() - test_forward_unary() - test_forward_clamp() - test_forward_clamp_() - test_forward_logical_not() - test_forward_bitwise_not() - test_forward_bitwise_xor() - test_forward_logical_xor() - test_forward_isfinite() - test_forward_isnan() - test_forward_isinf() - test_forward_ones() - test_forward_ones_like() - test_forward_zeros() - test_forward_zeros_like() - test_forward_full() - test_forward_full_like() - test_forward_linspace() - test_forward_arange() - test_forward_mesh_grid() - test_forward_chunk() - test_forward_split() - test_forward_gather() - test_upsample() - test_forward_upsample3d() - test_forward_nms() - test_forward_roi_align() - test_to() - test_flatten() - test_type_as() - test_forward_functional_pad() - test_forward_zero_pad2d() - test_forward_constant_pad1d() - test_forward_constant_pad2d() - test_forward_constant_pad3d() - test_forward_reflection_pad1d() - test_forward_reflection_pad2d() - test_forward_replication_pad1d() - test_forward_replication_pad2d() - test_forward_replication_pad3d() - test_adaptive_pool3d() - test_conv3d() - test_conv3d_transpose() - test_forward_index() - test_min_max() - test_logsumexp() - test_stack() - test_stack_dynamic() - test_forward_unbind() - test_forward_nonzero() - 
test_forward_scatter() - test_forward_index_put() - test_numel() - test_bincount() - test_cumsum() - test_masked_fill() - test_transformer() - test_sort() - test_argsort() - test_logical_and() - test_masked_select() - test_unique() - test_hard_swish() - test_hard_sigmoid() - test_forward_nll_loss() - test_forward_flip() - - # Model tests - test_resnet18() - test_squeezenet1_0() - test_squeezenet1_1() - test_densenet121() - # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # test_inception_v3() - test_googlenet() - test_mnasnet0_5() - test_mobilenet_v2() - - test_custom_conversion_map() - - test_segmentation_models() - test_3d_models() - - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules - - test_quantized_modules() - test_quantized_imagenet() - - # Test simple conditionals and loop - test_control_flow() - test_simple_rnn() - - # More complex recurrent models - from test_lstm import test_custom_lstm - - test_custom_lstm() - - # Test bert model - test_forward_pretrained_bert_base_uncased() - - # Test convert torch script(jit) with specific inputs' types - test_convert_torch_script_with_input_types() + pytest.main([__file__]) From 80cf35543779b7b266f9d7e57d16a20949c160cc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda <masahi129@gmail.com> Date: Wed, 15 Sep 2021 13:48:44 +0900 Subject: [PATCH 2/2] py format --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 818a1a4716fb..39bcfc68e421 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3742,7 +3742,7 @@ def from_pytorch( use_parser_friendly_name : bool When True, replace '.' 
with `_' in a original parameter name. The Relay text parser treats a variable name followed by a period as a tuple element access, - so a variable name like "dense.weight" cannot be correctly parsed. + so a variable name like "dense.weight" cannot be parsed correctly. Use this option when you want to run the AnnotateSpans pass on the imported module. Returns diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index abae931e3ede..c27469edf1d7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3941,8 +3941,9 @@ def test_annotate_span(): model = torchvision.models.resnet18().eval() inp = torch.randn([1, 3, 224, 224]) trace = torch.jit.trace(model, inp).eval() - mod, params = relay.frontend.from_pytorch(trace, [('input', inp.shape)], - use_parser_friendly_name=True) + mod, params = relay.frontend.from_pytorch( + trace, [("input", inp.shape)], use_parser_friendly_name=True + ) relay.transform.AnnotateSpans()(mod)