Merge pull request #1 from jroesch/mx-converter
Mx converter
MarisaKirisame authored Apr 4, 2019
2 parents eb1ed11 + 5d4f251 commit 4b8b933
Showing 2 changed files with 68 additions and 10 deletions.
64 changes: 56 additions & 8 deletions python/tvm/relay/frontend/mxnet.py
@@ -29,12 +29,13 @@ def _mx_fully_connected(inputs, attrs):
        # no flatten attribute in old mxnet
        has_flatten = False
    use_flatten = attrs.get_bool("flatten", True)
    assert use_flatten == False
    if has_flatten and use_flatten:
        inputs[0] = _op.nn.batch_flatten(inputs[0])
    res = _op.nn.dense(inputs[0], inputs[1], units=units)
    if use_bias:
        assert len(inputs) == 3
-       res = _op.nn.bias_add(res, inputs[2])
+       res = _op.nn.bias_add(res, inputs[2], axis=-1)
    return res
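
For reference, bias_add with axis=-1 broadcasts the bias over the last axis of
the dense output. A minimal numpy sketch of the equivalent computation
(illustrative only, not the relay API):

import numpy as np

res = np.zeros((2, 4), dtype="float32")     # dense output: (batch, units)
bias = np.arange(4, dtype="float32")        # bias: (units,)
out = res + bias.reshape((1, -1))           # bias_add(..., axis=-1)
assert out.shape == (2, 4)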


@@ -192,6 +193,26 @@ def _mx_batch_norm(inputs, attrs):
new_attrs["scale"] = not attrs.get_bool("fix_gamma", False)
return _op.nn.batch_norm(*inputs, **new_attrs)

def _mx_layer_norm(inputs, attrs):
    # TODO: implement layer norm
    return inputs[0]
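
The stub above just passes its input through. For reference, MXNet's LayerNorm
normalizes over the last axis; a hedged numpy sketch of the intended semantics
(gamma, beta, and eps are illustrative parameter names):

import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    # (x - mean) / sqrt(var + eps) * gamma + beta, over the last axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta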

def _mx_div_sqrt_dim(inputs, attrs):
    assert len(inputs) == 1
    data = inputs[0]
    shape = _op.shape_of(data)
    last_dim_index = _op.subtract(_op.sum(_op.ones_like(shape)), _expr.const(1))
    last_dim = _op.take(_op.shape_of(data), indices=last_dim_index)
    return _op.divide(data,
                      _op.sqrt(last_dim.astype('float32')))
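
The shape arithmetic above computes the last-dimension index symbolically
(sum(ones_like(shape)) - 1) because the rank is not available as a Python
integer here. In plain numpy terms the whole expression is just the scaling
used in scaled dot-product attention:

import numpy as np

data = np.random.rand(3, 8).astype("float32")
out = data / np.sqrt(np.float32(data.shape[-1]))   # divide by sqrt(last dim)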

def _mx_erf(inputs, attrs):
    # TODO: implement erf
    return inputs[0]

def _mx_sequence_mask(inputs, attrs):
    # TODO: implement seq mask
    return inputs[0]
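
For reference, MXNet's SequenceMask overwrites the steps of each sequence that
lie beyond its valid length. A hedged numpy sketch of those semantics, assuming
the default time-major layout (seq_len, batch, ...):

import numpy as np

def sequence_mask_ref(data, lengths, value=0.0):
    # mask out steps >= lengths[b] for every batch element b
    out = data.copy()
    for b, n in enumerate(lengths):
        out[int(n):, b] = value
    return out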

def _mx_slice(inputs, attrs):
    new_attrs = {}
@@ -413,7 +434,7 @@ def _mx_batch_dot(inputs, attrs):
        raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
    if transpose_b is False:
        b = _op.transpose(b, axes=[0, 2, 1])
-   return _op.batch_matmul(a, b)
+   return _op.nn.batch_matmul(a, b)


def _mx_arange(inputs, attrs):
@@ -422,10 +443,19 @@ def _mx_arange(inputs, attrs):
        raise tvm.error.OpAttributeUnimplemented(
            'Attribute "repeat" is not supported in operator arange.')
    new_attrs = {}
-   new_attrs["start"] = attrs.get_float("start", 0)
-   new_attrs["stop"] = attrs.get_float("stop")
-   new_attrs["step"] = attrs.get_float("step", 1)
-   new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+   stop = attrs.attrs.get('stop')
+   # This op has special behavior when only start is passed.
+   if stop != 'None':
+       new_attrs["start"] = attrs.get_float("start", 0)
+       new_attrs["stop"] = attrs.get_float("stop")
+       new_attrs["step"] = attrs.get_float("step", 1)
+       new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+   else:
+       new_attrs["start"] = 0
+       new_attrs["stop"] = attrs.get_float("start")
+       new_attrs["step"] = attrs.get_float("step", 1)
+       new_attrs["dtype"] = attrs.get_str("dtype", "float32")

    return _op.arange(**new_attrs)
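
The branch on the string 'None' mirrors MXNet's calling convention for arange:
arange(5) serializes its single argument as "start" and leaves "stop" unset, so
the converter must reinterpret start as the stop value. A hedged sketch of that
fallback, with hypothetical raw attribute values:

raw = {"start": "5", "stop": "None", "step": "1"}   # as serialized for arange(5)
if raw["stop"] != 'None':
    start, stop = float(raw["start"]), float(raw["stop"])
else:
    start, stop = 0.0, float(raw["start"])          # reinterpret start as stop
assert (start, stop) == (0.0, 5.0)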


@@ -774,6 +804,10 @@ def _mx_deformable_convolution(inputs, attrs):
# "broadcast_to",
# "gather_nd",
# "Crop" : _crop_like,
"LayerNorm": _mx_layer_norm,
"_contrib_div_sqrt_dim": _mx_div_sqrt_dim,
"erf": _mx_erf,
"SequenceMask": _mx_sequence_mask,
}

# set identity list
@@ -812,6 +846,9 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
        attrs = StrAttrsDict(node.get("attrs", {}))
        node_name = node["name"]
        op_name = node["op"]

        if op_name == "null":
            shape = shape_dict[node_name] if node_name in shape_dict else None
            if isinstance(dtype_info, dict):
@@ -828,6 +865,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
            else:
                raise RuntimeError("unexpected type %s" % type(res))
            node_map[nid] = res

            # if op_name == 'FullyConnected':
            #     outputs = res
            #     break
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported in frontend MXNet.'.format(op_name))
@@ -859,6 +900,7 @@ def _update_shape_dtype(shape, dtype, params):
def from_mxnet(symbol,
               shape=None,
               dtype="float32",
               input_symbols=None,
               arg_params=None,
               aux_params=None):
    """Convert from MXNet's model into compatible relay Function.
@@ -909,8 +951,14 @@ def from_mxnet(symbol,
        params = {}
        for k, v in symbol.collect_params().items():
            params[k] = _nd.array(v.data().asnumpy())
-       data = mx.sym.Variable("data")
-       sym = symbol(data)

        if input_symbols is not None:
            inputs = input_symbols
        else:
            inputs = []
            inputs.append(mx.sym.Variable("data"))
        sym = symbol(*inputs)
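
A hedged usage sketch for the new input_symbols parameter (the block and shape
names are illustrative): it lets a multi-input HybridBlock be converted by
passing explicit variables instead of the single default "data" variable.

import mxnet as mx

data0 = mx.sym.Variable("data0")
data1 = mx.sym.Variable("data1")
# func, params = from_mxnet(two_input_block,
#                           shape={"data0": (1, 8), "data1": (1, 8)},
#                           input_symbols=[data0, data1])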

        if isinstance(sym, (list, tuple)):
            sym = mx.sym.Group(sym)
        shape, dtype = _update_shape_dtype(shape, dtype, params)
14 changes: 12 additions & 2 deletions src/relay/op/tensor/transform.cc
@@ -710,9 +710,19 @@ bool TakeRel(const Array<Type>& types,
  // `types` contains: [data, indices, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
-  CHECK(data != nullptr);
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return false;
+  }

  const auto* indices = types[1].as<TensorTypeNode>();
-  CHECK(indices != nullptr);
+  if (indices == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return true;
+  }
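
The pattern above replaces hard CHECK failures with graceful handling of
IncompleteType: returning false tells the relay type solver that the relation
cannot be resolved yet and should be revisited once more types are inferred.
A hedged Python model of that control flow (not the actual C++ API):

def take_rel(data_ty, indices_ty):
    # None stands in for an IncompleteTypeNode (an unsolved type variable)
    if data_ty is None:
        return False    # defer: the solver retries after more inference
    if indices_ty is None:
        return True     # proceed, as in the diff above
    return True         # both known: the result type can be computed

assert take_rel(None, "int32") is False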

  const auto param = attrs.as<TakeAttrs>();
  CHECK(param != nullptr);

