Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Populate aux_params in dynamic shape forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Aug 24, 2019
1 parent 6a723e8 commit 092b4f7
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,19 @@ def get_outputs(sym, params, in_shape, in_label):

data_forward = []
data_names = []
for k, v in inputs.items():
data_names.append(k)
data_forward.append(mx.nd.array(mx.nd.random_normal(shape=v)))
for arg in sym.list_arguments():
data_names.append(arg)
data_forward.append(mx.nd.array(mx.nd.random_normal(shape=inputs[arg])))

aux_names = []
aux_forward = []
for aux in sym.list_auxiliary_states():
aux_names.append(aux)
aux_forward.append(mx.nd.array(mx.nd.random_normal(shape=inputs[aux])))

args = dict(zip(data_names, data_forward))
exe = sym.bind(mx.cpu(0), args=args, aux_states=None)
auxs = dict(zip(aux_names, aux_forward))
exe = sym.bind(mx.cpu(0), args=args, aux_states=auxs)
exe.forward(is_train=False)
result = []
for output in exe.outputs:
Expand Down

0 comments on commit 092b4f7

Please sign in to comment.