From 0f04c19378e2099a258d8503fd992354c6309216 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 23 Dec 2019 17:35:57 +0900 Subject: [PATCH] remove unnecessary cast to int32 --- python/tvm/relay/frontend/onnx.py | 2 -- tests/python/frontend/onnx/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 52211f89221a..c6f3a2e8714b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1414,8 +1414,6 @@ def from_onnx(self, graph, opset): self._num_param += 1 # We should convert scalar integers to int32, to normalize. array = self._parse_array(t_proto) - if len(array.shape) == 0 and array.dtype == 'int64': - array = _nd.array(array.asnumpy().astype('int32')) self._params[node.output[0]] = array self._nodes[node.output[0]] = new_var( node.output[0], diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 12dee0ff534d..a35ebd23ae0a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1826,6 +1826,24 @@ def test_convtranspose(): verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) +def test_unsqueeze_constant(): + from torch.nn import Linear, Sequential, Module + class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + import tempfile + with tempfile.NamedTemporaryFile() as fp: + file_name = fp.name + input_size = (1, 16, 32, 32) + dummy_input = torch.randn(*input_size) + layer = Sequential(Flatten(), Linear(16 * 32 * 32, 64)) + torch.onnx.export(layer, dummy_input, file_name, export_params=True) + + onnx_model = onnx.load(file_name) + relay.frontend.from_onnx(onnx_model, {'0': input_size}) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1882,3 +1900,4 @@ def test_convtranspose(): test_space_to_depth() test_conv() test_convtranspose() + test_unsqueeze_constant()