diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 14c674f56f2d..84db5decd503 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -134,9 +134,10 @@ def get_outputs(sym, params, in_shape, in_label): # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided # by user. Also remove in_label, which is the name of the label symbol that may have been used # as the label for loss during training. - inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)} + inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], + in_shape)} # Add params and their shape to list of inputs - inputs.update({n: v.shape for n, v in params.items()}) + inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()}) # Provide input data as well as input params to infer_shape() _, out_shapes, _ = sym.infer_shape(**inputs) diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index bbff7833fe20..a08a184f3c92 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -260,18 +260,19 @@ def _optional_group(symbols, group=False): return symbols -def _check_onnx_export(net, group_outputs=False): +def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params={}): net.initialize() data = nd.random.uniform(0, 1, (1, 1024)) output = _force_list(net(data)) # initialize weights net_sym = _optional_group(net(sym.Variable('data')), group_outputs) net_params = {name:param._reduce() for name, param in net.collect_params().items()} + net_params.update(extra_params) with tempfile.TemporaryDirectory() as tmpdirname: onnx_file_path = os.path.join(tmpdirname, 'net.onnx') export_path = onnx_mxnet.export_model( sym=net_sym, params=net_params, - input_shape=[data.shape], + input_shape=[shape_type(data.shape)], onnx_file_path=onnx_file_path) assert export_path == onnx_file_path # Try importing the model to symbol @@ -314,6 +315,22 @@ def hybrid_forward(self, F, x): _check_onnx_export(net, group_outputs=True) +@with_seed() +def test_onnx_export_list_shape(): + net = nn.HybridSequential(prefix='list_shape_net') + with net.name_scope(): + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + _check_onnx_export(net, shape_type=list) + + +@with_seed() +def test_onnx_export_extra_params(): + net = nn.HybridSequential(prefix='extra_params_net') + with net.name_scope(): + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])}) + + if __name__ == '__main__': test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000)) test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))