This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Error: shape inconsistent while converting PyTorch model to mxnet model with onnx #13949

Open
wangliye00 opened this issue Jan 21, 2019 · 18 comments · May be fixed by #15996
Labels
ONNX pr-awaiting-response PR is reviewed and waiting for contributor to respond

Comments

@wangliye00

Description

I tried to convert a pretrained PyTorch resnet18 model to ONNX format and then to an MXNet model, but a shape-inconsistency error occurred when I loaded the ONNX file with the mxnet.contrib.onnx module.

Environment info (Required)

  • PyTorch 1.0, MXNet 1.3.0 (Windows)
## Error Message:

File "D:\software\Anaconda3\lib\site-packages\mxnet\contrib\onnx\onnx2mx\import_model.py", line 53, in import_model
sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)

File "D:\software\Anaconda3\lib\site-packages\mxnet\contrib\onnx\onnx2mx\import_onnx.py", line 96, in from_onnx
self._params[init_tensor.name] = self._parse_array(init_tensor)

File "D:\software\Anaconda3\lib\site-packages\mxnet\contrib\onnx\onnx2mx\import_onnx.py", line 200, in _parse_array
return nd.array(np_array)

File "D:\software\Anaconda3\lib\site-packages\mxnet\ndarray\utils.py", line 146, in array
return _array(source_array, ctx=ctx, dtype=dtype)

File "D:\software\Anaconda3\lib\site-packages\mxnet\ndarray\ndarray.py", line 2435, in array
arr[:] = source_array

File "D:\software\Anaconda3\lib\site-packages\mxnet\ndarray\ndarray.py", line 444, in __setitem__
self._set_nd_basic_indexing(key, value)

File "D:\software\Anaconda3\lib\site-packages\mxnet\ndarray\ndarray.py", line 710, in _set_nd_basic_indexing
self._sync_copyfrom(value)

File "D:\software\Anaconda3\lib\site-packages\mxnet\ndarray\ndarray.py", line 872, in _sync_copyfrom
str(self.shape), str(source_array.shape)))

ValueError: Shape inconsistent: expected () vs got (1,)
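
For clarity, the two shapes in that message are a zero-dimensional (scalar) array versus a one-element vector; a minimal NumPy illustration of the distinction (not part of the original report):

import numpy as np

scalar = np.array(2.0)    # shape (): a zero-dimensional (scalar) array
vector = np.array([2.0])  # shape (1,): a one-element vector
print(scalar.shape, vector.shape)  # -> () (1,)

# Copying one into a buffer allocated for the other is what MXNet rejects
# with "Shape inconsistent: expected () vs got (1,)".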

## Minimum reproducible example
import torch
import torchvision

dummy_input = torch.randn(1, 3, 224, 224)
model = torchvision.models.resnet18(pretrained=True)
input_names = ["input_1"]
output_names = ["output1"]

torch.onnx.export(model, dummy_input, "resnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

from mxnet.contrib import onnx as onnx_mxnet
sym, arg, aux = onnx_mxnet.import_model("resnet.onnx")
@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: ONNX

@piyushghai
Contributor

@wangliye00 Thanks for raising this issue.
I'm tagging MXNet ONNX experts to look at this. @Roshrini @vandanavk Can you have a look at this?

@mxnet-label-bot Add [ONNX, Bug]

@Con-Mi

Con-Mi commented Feb 12, 2019

I face the same error on version 1.3.1 for Linux.
Any updates?

@vandanavk
Contributor

Traceback (most recent call last):
  File "/Users/vandanavk/test.py", line 79, in <module>
    sym, arg, aux = onnx_mxnet.import_model("resnet.onnx")
  File "/Users/vandanavk/Documents/mxnet/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_model.py", line 59, in import_model
    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
  File "/Users/vandanavk/Documents/mxnet/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py", line 116, in from_onnx
    mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
  File "/Users/vandanavk/Documents/mxnet/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py", line 62, in _convert_operator
    op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)
  File "/Users/vandanavk/Documents/mxnet/incubator-mxnet/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py", line 460, in reshape
    reshape_shape = list(proto_obj._params[inputs[1].name].asnumpy())
KeyError: 'concat0'

I get this error. (MXNet built from source).
This error is the same as #13395 and similar to #13774. All these issues depend on dynamic shape support.
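
For illustration, a short sketch (assuming the standard onnx Python package is installed) that shows where the failing lookup comes from: in the exported ResNet graph, the Reshape that flattens the features takes its shape from the output of a Concat node rather than from a stored initializer, which is why the proto_obj._params lookup above raises the KeyError.

import onnx

# Load the file produced by torch.onnx.export in the reproducible example.
model = onnx.load("resnet.onnx")

# Tensors stored inside the model; the MXNet reshape translation expects the
# Reshape shape input to be found among these.
stored = {init.name for init in model.graph.initializer}

# Report, for every Reshape node, whether its shape input is a stored tensor
# or the output of another node (e.g. a Concat computed from Shape/Gather).
for node in model.graph.node:
    if node.op_type == "Reshape":
        shape_src = node.input[1]
        kind = "stored initializer" if shape_src in stored else "computed by another node"
        print(node.output[0], "shape input:", shape_src, "->", kind)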

@vandanavk
Contributor

Related: #12732, #10789

@vandanavk
Contributor

@wangliye00 @Con-Mi However, I do see the error that you are facing with MXNet v1.3.1. To fix this, could you try pulling in commit #13413 and checking whether you are able to proceed further?

@roywei
Member

roywei commented Mar 4, 2019

@mxnet-label-bot remove[Bug]

@marcoabreu marcoabreu removed the Bug label Mar 4, 2019
@roywei
Member

roywei commented Mar 4, 2019

@mxnet-label-bot add [pr-awaiting-response]

@marcoabreu marcoabreu added the pr-awaiting-response PR is reviewed and waiting for contributor to respond label Mar 4, 2019
@nswamy
Member

nswamy commented Mar 4, 2019

Please open a new issue if you still face problems after pulling in the commit that @vandanavk suggested, and we'll investigate further.

@nswamy nswamy closed this as completed Mar 4, 2019
@turtleizzy

turtleizzy commented Mar 10, 2019

@wangliye00 @Con-Mi However, I do see the error that you are facing with MXNet v1.3.1. To fix this, could you try pulling in commit #13413 and checking whether you are able to proceed further?

I am facing a similar issue when loading a PyTorch DenseNet ONNX model into MXNet. The error message reads:

/usr/local/lib/python3.6/site-packages/mxnet/contrib/onnx/onnx2mx/import_onnx.py in _convert_operator(self, node_name, op_name, attrs, inputs)
     59         """
     60         if op_name in convert_map:
---> 61             op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)
     62         else:
     63             raise NotImplementedError("Operator {} not implemented.".format(op_name))

/usr/local/lib/python3.6/site-packages/mxnet/contrib/onnx/onnx2mx/_op_translations.py in reshape(attrs, inputs, proto_obj)
    432     if len(inputs) == 1:
    433         return 'reshape', attrs, inputs[0]
--> 434     reshape_shape = list(proto_obj._params[inputs[1].name].asnumpy())
    435     reshape_shape = [int(i) for i in reshape_shape]
    436     new_attrs = {'shape': reshape_shape}

KeyError: 'concat51'

I tried MXNet 1.3.1 (after patching import_onnx.py following your suggestion) and 1.4.0 with no luck; both raised a similar exception.
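
If patching MXNet is not an option, one workaround worth trying (not verified in this thread) is to constant-fold the dynamically computed shapes in the exported model before importing it, for example with the third-party onnx-simplifier package. A minimal sketch, assuming that package is installed:

import onnx
from onnxsim import simplify  # third-party package: onnx-simplifier

# Fold Shape/Gather/Concat chains into constant tensors so that the shape
# input of each Reshape becomes a stored initializer.
model = onnx.load("resnet.onnx")
model_simplified, check = simplify(model)
assert check, "simplified ONNX model failed validation"
onnx.save(model_simplified, "resnet_simplified.onnx")

# Import the simplified file with MXNet as before.
from mxnet.contrib import onnx as onnx_mxnet
sym, arg, aux = onnx_mxnet.import_model("resnet_simplified.onnx")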

@lupesko
Contributor

lupesko commented Mar 27, 2019

@nswamy I think we should reopen this issue, since it still persists.

@JohnLee168

@nswamy I think we should reopen this issue, since it still persists.

I have the same problem when loading a resnet50 ONNX model.

@BrettLL

BrettLL commented Apr 29, 2019

Same problem as you, @JohnLee168.

@weihua04

@vandanavk, I have the same problem: "KeyError: 'concat0'". I have tested many ONNX models exported from PyTorch. Is there any solution to this problem?

Thank you

@Ankit01Mishra

I too have the same problem. Is there any solution? Thanks

@Roshrini Roshrini reopened this Jul 2, 2019
@vandanavk vandanavk linked a pull request Aug 24, 2019 that will close this issue
@dusigh

dusigh commented Nov 20, 2019

I too have run into the same issue. Hoping the fix can be released ASAP.

@high426

high426 commented Jan 19, 2022

Maybe the from_onnx() function in "import_onnx.py" loses some of the _nodes information; you can try adding that information there.

@pribadihcr

any solution?
