Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Fix reshape precompute, and type error (#3230)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tqchen committed Jun 17, 2019
1 parent a748f5f commit df6957a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
43 changes: 27 additions & 16 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,21 +409,24 @@ def _impl_v1(cls, inputs, attr, params):
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
# Try to infer shape by precompute prune if possible.
# TODO: good to check inputs to be in params.
# to be enhanced when relay support list_input_names API of NNVM
logging.warning("Infering Reshape argument by precompute")
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
data, shape = inputs
logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
shape_params = ir_pass.free_vars(shape)
func = _expr.Function(shape_params, shape)
func = ir_pass.infer_type(func)
func = ir_pass.fold_constant(func)
shape_params = ir_pass.free_vars(func.body)
func = _expr.Function(shape_params, func.body)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
ex = tvm.relay.create_executor("debug")
inputs = []
for sp in shape_params:
if not sp.name_hint in params:
sh = [int(i) for i in sp.type_annotation.shape]
inputs.append(
tvm.nd.array(np.random.rand(*sh).astype('float32')))
static_shape = ex.evaluate(func)(*inputs, **params)
out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))

return out

Expand Down Expand Up @@ -568,6 +571,7 @@ class Shape(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
# TODO(@jroesch): use shape_of once it has been fixed
return _op.shape_of(inputs[0])

class Cast(OnnxOpConverter):
Expand Down Expand Up @@ -1058,8 +1062,15 @@ def from_onnx(self, graph, opset):
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1
self._params[node.output[0]] = self._parse_array(t_proto)
self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
# We should convert scalar integers to int32, to normalize.
array = self._parse_array(t_proto)
if len(array.shape) == 0 and array.dtype == 'int64':
array = _nd.array(array.asnumpy().astype('int32'))
self._params[node.output[0]] = array
self._nodes[node.output[0]] = new_var(
node.output[0],
shape=list(t_proto.dims),
dtype=array.dtype)
else:
if op_name == "ConstantFill":
fill_value = attr.get('value', 0.0)
Expand Down
47 changes: 47 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import attr
import numpy as np
import math
import torch
import torchvision
import topi
import topi.testing
import tvm
Expand Down Expand Up @@ -1072,6 +1075,47 @@ def test_LogSoftmax():
'LogSoftmax',
{'axis': 1})

def check_torch_conversion(model, input_size):
dummy_input = torch.randn(*input_size)
file_name = '{}.onnx'.format(model.__name__)
# Set verbose=True for more output
torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
onnx_model = onnx.load(file_name)
shapes = { '0' : input_size }
expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)

def test_resnet():
check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
# check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))

# def test_alexnet():
# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))

# Torch's ONNX export does not support the adaptive pooling used by vgg16?
# def test_vgg16():
# check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_squeezenet():
# # Torch's ONNX export does not support the max pooling used by Squezenet
# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))

def test_densenet():
check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))

def test_inception():
check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_googlenet():
# check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_shufflenetv2():
# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))


if __name__ == '__main__':
test_flatten()
test_reshape()
Expand Down Expand Up @@ -1111,3 +1155,6 @@ def test_LogSoftmax():
test_ParametricSoftplus()
test_Scale()
test_LogSoftmax()
test_resnet()
test_inception()
test_densenet()

0 comments on commit df6957a

Please sign in to comment.