From 94e3f171c271c95aa88139cbe49a4c2968abd026 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Fri, 29 Mar 2019 02:58:27 -0700
Subject: [PATCH 1/2] Port changes

---
 python/tvm/relay/frontend/mxnet.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 69d779271be73..793a37c2730fa 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -839,6 +839,8 @@ def _update_shape_dtype(shape, dtype, params):
 def from_mxnet(symbol,
                shape=None,
                dtype="float32",
+               aux_params=None,
+               input_symbols=None,
                arg_params=None,
                aux_params=None):
     """Convert from MXNet"s model into compatible relay Function.
@@ -889,8 +891,14 @@ def from_mxnet(symbol,
         params = {}
         for k, v in symbol.collect_params().items():
             params[k] = _nd.array(v.data().asnumpy())
-        data = mx.sym.Variable("data") 
+        data = mx.sym.Variable("data")
         sym = symbol(data)
+        if input_symbols is not None:
+            inputs = input_symbols
+        else:
+            inputs = []
+            inputs.append(mx.sym.Variable("data"))
+        sym = symbol(*inputs))
         if isinstance(sym, (list, tuple)):
             sym = mx.sym.Group(sym)
         shape, dtype = _update_shape_dtype(shape, dtype, params)

From 5d4f25188617197086a390f37a43a3293d50c137 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Fri, 29 Mar 2019 16:59:53 -0700
Subject: [PATCH 2/2] More fixes

---
 python/tvm/relay/frontend/mxnet.py | 60 +++++++++++++++++++++++++-----
 src/relay/op/tensor/transform.cc   | 14 ++++++-
 2 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 793a37c2730fa..6e4a0326f66c1 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -29,12 +29,13 @@ def _mx_fully_connected(inputs, attrs):
         # no flatten attribute in old mxnet
         has_flatten = False
     use_flatten = attrs.get_bool("flatten", True)
+    assert use_flatten == False
     if has_flatten and use_flatten:
         inputs[0] = _op.nn.batch_flatten(inputs[0])
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
-        res = _op.nn.bias_add(res, inputs[2])
+        res = _op.nn.bias_add(res, inputs[2], axis=-1)
     return res


@@ -192,6 +193,26 @@ def _mx_batch_norm(inputs, attrs):
     new_attrs["scale"] = not attrs.get_bool("fix_gamma", False)
     return _op.nn.batch_norm(*inputs, **new_attrs)

+def _mx_layer_norm(inputs, attrs):
+    # TODO: implement layer norm
+    return inputs[0]
+
+def _mx_div_sqrt_dim(inputs, attrs):
+    assert len(inputs) == 1
+    data = inputs[0]
+    shape = _op.shape_of(data)
+    last_dim_index = _op.subtract(_op.sum(_op.ones_like(shape)), _expr.const(1))
+    last_dim = _op.take(_op.shape_of(data), indices=last_dim_index)
+    return _op.divide(data,
+                      _op.sqrt(last_dim.astype('float32')))
+
+def _mx_erf(inputs, attrs):
+    # TODO: implement erf
+    return inputs[0]
+
+def _mx_sequence_mask(inputs, attrs):
+    # TODO: implement seq mask
+    return inputs[0]

 def _mx_slice(inputs, attrs):
     new_attrs = {}
@@ -413,7 +434,7 @@ def _mx_batch_dot(inputs, attrs):
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
-    return _op.batch_matmul(a, b)
+    return _op.nn.batch_matmul(a, b)


 def _mx_arange(inputs, attrs):
@@ -422,10 +443,19 @@ def _mx_arange(inputs, attrs):
     if repeat != 1:
         raise tvm.error.OpAttributeUnimplemented(
             'Attribute "repeat" is not supported in operator arange.')
     new_attrs = {}
-    new_attrs["start"] = attrs.get_float("start", 0)
-    new_attrs["stop"] = attrs.get_float("stop")
new_attrs["step"] = attrs.get_float("step", 1) - new_attrs["dtype"] = attrs.get_str("dtype", "float32") + stop = attrs.attrs.get('stop') + # This op has special behavior when only start is passed. + if stop != 'None': + new_attrs["start"] = attrs.get_float("start", 0) + new_attrs["stop"] = attrs.get_float("stop") + new_attrs["step"] = attrs.get_float("step", 1) + new_attrs["dtype"] = attrs.get_str("dtype", "float32") + else: + new_attrs["start"] = 0 + new_attrs["stop"] = attrs.get_float("start") + new_attrs["step"] = attrs.get_float("step", 1) + new_attrs["dtype"] = attrs.get_str("dtype", "float32") + return _op.arange(**new_attrs) @@ -754,6 +784,10 @@ def _mx_smooth_l1(inputs, attrs): # "broadcast_to", # "gather_nd", # "Crop" : _crop_like, + "LayerNorm": _mx_layer_norm, + "_contrib_div_sqrt_dim": _mx_div_sqrt_dim, + "erf": _mx_erf, + "SequenceMask": _mx_sequence_mask, } # set identity list @@ -792,6 +826,9 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): attrs = StrAttrsDict(node.get("attrs", {})) node_name = node["name"] op_name = node["op"] + + + if op_name == "null": shape = shape_dict[node_name] if node_name in shape_dict else None if isinstance(dtype_info, dict): @@ -808,6 +845,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): else: raise RuntimeError("unexpected type %s" % type(res)) node_map[nid] = res + + # if op_name == 'FullyConnected': + # outputs = res + # break else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend MXNet.'.format(op_name)) @@ -839,7 +880,6 @@ def _update_shape_dtype(shape, dtype, params): def from_mxnet(symbol, shape=None, dtype="float32", - aux_params=None, input_symbols=None, arg_params=None, aux_params=None): @@ -891,14 +931,14 @@ def from_mxnet(symbol, params = {} for k, v in symbol.collect_params().items(): params[k] = _nd.array(v.data().asnumpy()) - data = mx.sym.Variable("data") - sym = symbol(data) + if input_symbols is not None: inputs = input_symbols else: inputs = [] inputs.append(mx.sym.Variable("data")) - sym = symbol(*inputs)) + sym = symbol(*inputs) + if isinstance(sym, (list, tuple)): sym = mx.sym.Group(sym) shape, dtype = _update_shape_dtype(shape, dtype, params) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a0ea8f2e60a36..456800255f379 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -710,9 +710,19 @@ bool TakeRel(const Array& types, // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + CHECK(types[0].as()) + << "must be tensor type or incomplete type"; + return false; + } + const auto* indices = types[1].as(); - CHECK(indices != nullptr); + if (indices == nullptr) { + CHECK(types[1].as()) + << "must be tensor type or incomplete type"; + return true; + } + const auto param = attrs.as(); CHECK(param != nullptr);