Skip to content

Commit

Permalink
remove unnecessary cast to int32 (#4573)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored and zhiics committed Dec 23, 2019
1 parent dfc4009 commit 9ec0e5c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,8 +1414,6 @@ def from_onnx(self, graph, opset):
self._num_param += 1
# We should convert scalar integers to int32, to normalize.
array = self._parse_array(t_proto)
if len(array.shape) == 0 and array.dtype == 'int64':
array = _nd.array(array.asnumpy().astype('int32'))
self._params[node.output[0]] = array
self._nodes[node.output[0]] = new_var(
node.output[0],
Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,24 @@ def test_convtranspose():
verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2])


def test_unsqueeze_constant():
from torch.nn import Linear, Sequential, Module
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)

import tempfile
with tempfile.NamedTemporaryFile() as fp:
file_name = fp.name
input_size = (1, 16, 32, 32)
dummy_input = torch.randn(*input_size)
layer = Sequential(Flatten(), Linear(16 * 32 * 32, 64))
torch.onnx.export(layer, dummy_input, file_name, export_params=True)

onnx_model = onnx.load(file_name)
relay.frontend.from_onnx(onnx_model, {'0': input_size})


if __name__ == '__main__':
test_flatten()
test_reshape()
Expand Down Expand Up @@ -1882,3 +1900,4 @@ def test_convtranspose():
test_space_to_depth()
test_conv()
test_convtranspose()
test_unsqueeze_constant()

0 comments on commit 9ec0e5c

Please sign in to comment.