diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 468a7486ca5c..bb968ec0bea8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -23,6 +23,7 @@ import tvm from ... import nd as _nd from .. import ir_pass +from .. import transform as _transform from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -409,21 +410,27 @@ def _impl_v1(cls, inputs, attr, params): shape = tuple(params[inputs[1].name_hint].asnumpy()) out = _op.reshape(inputs[0], shape) else: - # Try to infer shape by precompute prune if possible. - # TODO: good to check inputs to be in params. - # to be enhanced when relay support list_input_names API of NNVM - logging.warning("Infering Reshape argument by precompute") - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + data, shape = inputs + logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") + shape_params = ir_pass.free_vars(shape) + func = _expr.Function(shape_params, shape) + mod = _module.Module.from_expr(func) + seq = _transform.Sequential([_transform.InferType(), + _transform.FoldConstant(), + _transform.FuseOps(0), + _transform.InferType()]) + with tvm.relay.PassContext(opt_level=2): + mod = seq(mod) with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) - inputs.pop(1) - out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten())) + ex = tvm.relay.create_executor("debug", mod=mod) + inputs = [] + for sp in shape_params: + if not sp.name_hint in params: + sh = [int(i) for i in sp.type_annotation.shape] + inputs.append( + tvm.nd.array(np.random.rand(*sh).astype('float32'))) + static_shape = ex.evaluate()(*inputs, **params) + out = _op.reshape(data, newshape=tuple(static_shape.asnumpy())) return out @@ -568,6 +575,7 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + # TODO(@jroesch): use shape_of once it has been fixed) return _op.shape_of(inputs[0]) class Cast(OnnxOpConverter): @@ -1058,8 +1066,15 @@ def from_onnx(self, graph, opset): if op_name == "Constant": t_proto = self._parse_attr(node.attribute)["value"] self._num_param += 1 - self._params[node.output[0]] = self._parse_array(t_proto) - self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims)) + # We should convert scalar integers to int32, to normalize. + array = self._parse_array(t_proto) + if len(array.shape) == 0 and array.dtype == 'int64': + array = _nd.array(array.asnumpy().astype('int32')) + self._params[node.output[0]] = array + self._nodes[node.output[0]] = new_var( + node.output[0], + shape=list(t_proto.dims), + dtype=array.dtype) else: if op_name == "ConstantFill": fill_value = attr.get('value', 0.0) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7371a88ca677..a52e3f0cc16e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -14,8 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import attr import numpy as np import math +import torch +import torchvision import topi import topi.testing import tvm @@ -1072,6 +1075,48 @@ def test_LogSoftmax(): 'LogSoftmax', {'axis': 1}) + +def check_torch_conversion(model, input_size): + dummy_input = torch.randn(*input_size) + file_name = '{}.onnx'.format(model.__name__) + # Set verbose=True for more output + torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) + onnx_model = onnx.load(file_name) + shapes = { '0' : input_size } + expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes) + +def test_resnet(): + check_torch_conversion(torchvision.models.resnet18, (1,3,224,224)) + # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224)) + +# def test_alexnet(): + # Torch's ONNX export does not support the adaptive pooling used by AlexNet? + # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224)) + +# Torch's ONNX export does not support the adaptive pooling used by vgg16? +# def test_vgg16(): +# check_torch_conversion(torchvision.models.vgg16, (1,3,224,224)) + +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_squeezenet(): +# # Torch's ONNX export does not support the max pooling used by Squezenet +# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) + +def test_densenet(): + check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) + +def test_inception(): + check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224)) + +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_googlenet(): +# check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) + +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_shufflenetv2(): +# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1111,3 +1156,6 @@ def test_LogSoftmax(): test_ParametricSoftplus() test_Scale() test_LogSoftmax() + test_resnet() + test_inception() + test_densenet()