Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Add an option to make imported models compatible with the Relay text parser #9015

Merged
merged 2 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3400,8 +3400,8 @@ def _getattr_attr_name(node):
return attr_name


def _getattr_full_name(getattrs):
return ".".join([_getattr_attr_name(node) for node in getattrs])
def _getattr_full_name(getattrs, sep="."):
return sep.join([_getattr_attr_name(node) for node in getattrs])


def _get_pytorch_value_type(typ, default_dtype="float32"):
Expand Down Expand Up @@ -3657,7 +3657,7 @@ def terminate(users):
return get_use_chains(root_getattr_node, terminate)


def convert_params(graph, state_dict):
def convert_params(graph, state_dict, use_parser_friendly_name=False):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
Expand All @@ -3668,6 +3668,7 @@ def convert_params(graph, state_dict):
packed_param_map = {}
vars_by_name = {}
seen = set()
attr_name_sep = "_" if use_parser_friendly_name else "."

for node in getattr_nodes:
if _get_output_name(node) in seen:
Expand All @@ -3676,7 +3677,7 @@ def convert_params(graph, state_dict):
for getattrs in get_attr_chains(node):
seen.update(map(_get_output_name, getattrs))

full_attr = _getattr_full_name(getattrs)
full_attr = _getattr_full_name(getattrs, attr_name_sep)
full_attr_node_name = _get_output_name(getattrs[-1])

if full_attr.endswith("_packed_params"): # for quantized models
Expand Down Expand Up @@ -3706,7 +3707,13 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)


def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dtype="float32"):
def from_pytorch(
script_module,
input_infos,
custom_convert_map=None,
default_dtype="float32",
use_parser_friendly_name=False,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.

Expand All @@ -3729,6 +3736,15 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
custom_convert_map : Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above

default_dtype : str
The default dtype to use when type information is not provided by PyTorch.

use_parser_friendly_name : bool
When True, replace '.' with '_' in the original parameter name.
The Relay text parser treats a variable name followed by a period as a tuple element access,
so a variable name like "dense.weight" cannot be parsed correctly.
Use this option when you want to run the AnnotateSpans pass on the imported module.

Returns
-------
mod : tvm.IRModule
Expand Down Expand Up @@ -3758,7 +3774,13 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
outputs = _get_relay_input_vars(
graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
)
param_vars, tensors, packed_param_map = convert_params(graph, params)

if use_parser_friendly_name:
new_names = [key.replace(".", "_") for key in params.keys()]
params = dict(zip(new_names, params.values()))

param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name)

tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

outputs.update(param_vars)
Expand All @@ -3778,7 +3800,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
# ListConstruct kept original python list. Convert to tuple.
ret = _expr.Tuple(ret)

# Separate data inputs and parameters to make sure data inputs are always in the beginning.
# Separate data inputs and parameters to make sure data inputs come first.
func_args = []
data_inputs = []
for arg in _analysis.free_vars(ret):
Expand Down
192 changes: 11 additions & 181 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3937,185 +3937,15 @@ def forward(self, x):
verify_model(Flip(axis=-1), input_data=input)


def test_annotate_span():
model = torchvision.models.resnet18().eval()
inp = torch.randn([1, 3, 224, 224])
trace = torch.jit.trace(model, inp).eval()
mod, params = relay.frontend.from_pytorch(
trace, [("input", inp.shape)], use_parser_friendly_name=True
)
relay.transform.AnnotateSpans()(mod)


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
test_forward_dtypes()
test_weight_names()
test_duplicate_weight_use()

# Single operator tests
test_forward_pixel_shuffle()
test_forward_add()
test_forward_subtract()
test_forward_multiply()
test_forward_matmul()
test_forward_rsub()
test_forward_onehot()
test_forward_embedding()
test_forward_reshape()
test_forward_reciprocal()
test_forward_repeat()
test_forward_repeat_interleave()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
test_forward_reduce_sum()
test_forward_reduce_prod()
test_forward_argmin()
test_forward_argmax()
test_forward_norm()
test_forward_frobenius_norm()
test_forward_std()
test_forward_variance()
test_forward_relu()
test_forward_prelu()
test_forward_leakyrelu()
test_forward_elu()
test_forward_celu()
test_forward_gelu()
test_forward_selu()
test_forward_log_sigmoid()
test_forward_adaptiveavgpool()
test_forward_maxpool2d()
test_forward_maxpool1d()
test_forward_maxpool3d()
test_forward_hardtanh()
test_forward_conv()
test_forward_conv_transpose()
test_forward_threshold()
test_forward_contiguous()
test_forward_batchnorm()
test_forward_instancenorm()
test_forward_layernorm()
test_forward_groupnorm()
test_forward_transpose()
test_forward_size()
test_forward_view()
test_forward_select()
test_forward_take()
test_forward_topk()
test_forward_where()
test_forward_addcdiv()
test_forward_addcmul()
test_forward_true_divide()
test_forward_is_floating_point()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
test_forward_linear()
test_forward_avgpool1d()
test_forward_avgpool2d()
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()
test_forward_narrow()
test_forward_mean()
test_forward_expand()
test_forward_pow()
test_forward_unary()
test_forward_clamp()
test_forward_clamp_()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
test_forward_logical_xor()
test_forward_isfinite()
test_forward_isnan()
test_forward_isinf()
test_forward_ones()
test_forward_ones_like()
test_forward_zeros()
test_forward_zeros_like()
test_forward_full()
test_forward_full_like()
test_forward_linspace()
test_forward_arange()
test_forward_mesh_grid()
test_forward_chunk()
test_forward_split()
test_forward_gather()
test_upsample()
test_forward_upsample3d()
test_forward_nms()
test_forward_roi_align()
test_to()
test_flatten()
test_type_as()
test_forward_functional_pad()
test_forward_zero_pad2d()
test_forward_constant_pad1d()
test_forward_constant_pad2d()
test_forward_constant_pad3d()
test_forward_reflection_pad1d()
test_forward_reflection_pad2d()
test_forward_replication_pad1d()
test_forward_replication_pad2d()
test_forward_replication_pad3d()
test_adaptive_pool3d()
test_conv3d()
test_conv3d_transpose()
test_forward_index()
test_min_max()
test_logsumexp()
test_stack()
test_stack_dynamic()
test_forward_unbind()
test_forward_nonzero()
test_forward_scatter()
test_forward_index_put()
test_numel()
test_bincount()
test_cumsum()
test_masked_fill()
test_transformer()
test_sort()
test_argsort()
test_logical_and()
test_masked_select()
test_unique()
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()
test_forward_flip()

# Model tests
test_resnet18()
test_squeezenet1_0()
test_squeezenet1_1()
test_densenet121()
# disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
# See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# test_inception_v3()
test_googlenet()
test_mnasnet0_5()
test_mobilenet_v2()

test_custom_conversion_map()

test_segmentation_models()
test_3d_models()

# Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules

test_quantized_modules()
test_quantized_imagenet()

# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()

# More complex recurrent models
from test_lstm import test_custom_lstm

test_custom_lstm()

# Test bert model
test_forward_pretrained_bert_base_uncased()

# Test convert torch script(jit) with specific inputs' types
test_convert_torch_script_with_input_types()
pytest.main([__file__])