[FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay. #2850

Merged · 3 commits · Mar 30, 2019
Changes from 1 commit
9 changes: 8 additions & 1 deletion python/tvm/relay/frontend/common.py
@@ -320,6 +320,10 @@ def __call__(self, inputs, attrs, *args):
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)

# ignore 'tvm_custom' always
self._ignores.append('tvm_custom')

# convert attributes
new_attrs = {}
for k in attrs.keys():
@@ -328,7 +332,8 @@ def __call__(self, inputs, attrs, *args):
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
if k != 'tvm_custom':
logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
@@ -415,4 +420,6 @@ def __init__(self, new_name):
self._new_name = new_name

def __call__(self, inputs, attrs, *args):
if 'tvm_custom' in attrs:
attrs.pop('tvm_custom')
return get_relay_op(self._new_name)(*inputs, **attrs)
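For context on the `tvm_custom` handling above: the frontend now attaches a `tvm_custom` entry (carrying the original ONNX node name) to every attribute dict, so both `AttrCvt` and `Renamer` must drop it before the remaining attributes are forwarded as keyword arguments to a relay op, which would otherwise choke on the unknown key. A minimal sketch of the idea, using a hypothetical attribute dict rather than anything from this diff:

```python
# Hypothetical attribute dict as the ONNX importer would now build it; the
# 'tvm_custom' entry only records the source node's name for bookkeeping.
attrs = {'axis': 1, 'tvm_custom': {'name': 'softmax_0'}}

# Strip the frontend-only metadata before dispatch, mirroring Renamer.__call__;
# only the real operator attributes are then passed through as kwargs.
attrs.pop('tvm_custom', None)

# get_relay_op('softmax')(data, **attrs)   # would now receive just axis=1
print(attrs)                                # {'axis': 1}
```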
58 changes: 47 additions & 11 deletions python/tvm/relay/frontend/onnx.py
@@ -102,7 +102,7 @@ def _impl_v1(cls, inputs, attr, params):
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
ignores=['dilations', 'auto_pad'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)
@@ -156,6 +156,7 @@ def _impl_v1(cls, inputs, attr, params):
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)},
ignores=['auto_pad'],
custom_check=dimension_constraint())(inputs[:2], attr, params)
use_bias = len(inputs) == 3
if use_bias:
@@ -328,7 +329,22 @@ def _impl_v1(cls, inputs, attr, params):
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
out = _op.reshape_like(inputs[0], inputs[1])
# Try to infer shape by precompute prune if possible.
# TODO: it would be good to check that the inputs are in params.
# To be enhanced when relay supports a list_input_names API like NNVM's.
logging.warning("Inferring Reshape argument by precompute")
import tvm
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
out = _op.reshape(inputs[0], tuple(params_new.asnumpy().flatten()))

return out

@@ -467,10 +483,20 @@ class Shape(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
return inputs[0]
from topi.util import get_const_tuple
try:
out_type = ir_pass.infer_type(inputs[0])
out_shape = get_const_tuple(out_type.checked_type.shape)
except ValueError as e:
raise ImportError(
"Please pass graph level shapes to compute shape node properly {}".format(e))

node_name = attr['tvm_custom']['name']
params[node_name] = _nd.array(np.asarray(out_shape, dtype='int64'))

return _expr.var(node_name,
shape=params[node_name].shape,
dtype=params[node_name].dtype)

class Cast(OnnxOpConverter):
""" Operator converter for Cast.
@@ -484,7 +510,7 @@ def _impl_v1(cls, inputs, attr, params):
def _impl_v5(cls, inputs, attr, params):
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
except ImportError as e:
raise ImportError(
"Unable to import onnx.mapping which is required {}".format(e))
@@ -664,6 +690,11 @@ class ReduceMean(Reduce):
"""
name = 'mean'

class ReduceProd(Reduce):
""" Operator converter for ArgMax.
"""
name = 'prod'

class ArgMax(OnnxOpConverter):
""" Operator converter for ArgMax.
"""
@@ -815,6 +846,7 @@ def _get_convert_map(opset):
'ReduceMin': ReduceMin.get_converter(opset),
'ReduceSum': ReduceSum.get_converter(opset),
'ReduceMean': ReduceMean.get_converter(opset),
'ReduceProd': ReduceProd.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'
'ArgMax': ArgMax.get_converter(opset),
@@ -831,8 +863,7 @@ def _get_convert_map(opset):
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
# TODO(zhreshold) Shape op is implemented as bypass op in relay
# 'Shape': Shape.get_converter(opset),
'Shape': Shape.get_converter(opset),
}


@@ -872,6 +903,7 @@ def from_onnx(self, graph, opset):
----------
graph : onnx protobuf object
The loaded onnx graph

opset : opset version

Returns
@@ -900,12 +932,12 @@ def from_onnx(self, graph, opset):
dtype=self._params[i_name].dtype)
else:
self._num_input += 1
shape = self._shape[i_name] if i_name in self._shape else ()
tshape = self._shape[i_name] if i_name in self._shape else ()
if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else:
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
@@ -925,6 +957,10 @@ def from_onnx(self, graph, opset):
self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
inputs.append(self._nodes[i_name])

i_name = self._parse_value_proto(node)
attr['tvm_custom'] = {}
attr['tvm_custom']['name'] = i_name

op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output)
if not isinstance(op, _expr.TupleWrapper):
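Taken together, the Shape and Reshape changes mean a graph like `input -> Shape -> Reshape` can now be imported without the old `reshape_like` bypass: the Shape node is folded into `params` as an `int64` constant at import time, and Reshape either reads it directly or precomputes it with the graph runtime. A rough end-to-end sketch; the model path and input name are hypothetical, and the exact `from_onnx` signature and return value depend on the TVM version of this era:

```python
import onnx
import tvm
from tvm import relay

# Hypothetical ONNX file whose Reshape target comes from a Shape node.
model = onnx.load("model_with_shape_op.onnx")

# Graph-level input shapes must be provided so the Shape converter can fold
# its result into an int64 constant; it raises otherwise.
shape_dict = {"data": (1, 3, 224, 224)}
sym, params = relay.frontend.from_onnx(model, shape=shape_dict)

# The folded shapes now live in params alongside the model weights, so the
# rest of the pipeline is the usual relay build.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(sym, target="llvm", params=params)
```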