Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Gather
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Aug 24, 2019
1 parent 6c325eb commit 6a723e8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
17 changes: 17 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,23 @@ def convert_dynamic_reshape(node, **kwargs):
"""
return create_basic_op_node('Reshape', node, kwargs)

@mx_op.register("take")
def convert_take(node, **kwargs):
"""Map MXNet's Take operator attributes to onnx's Gather operator.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', 0))

node = onnx.helper.make_node(
"Gather",
input_nodes,
[name],
axis=axis,
name=name,
)
return [node]

# Changing shape and type.
@mx_op.register("Reshape")
def convert_reshape(node, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@
'test_max_',
'test_softplus',
'test_reduce_',
'test_split_equal'
'test_split_equal',
'test_gather'
],
'import': ['test_gather',
'test_softsign',
'import': ['test_softsign',
'test_mean',
'test_averagepool_1d',
'test_averagepool_2d_pads_count_include_pad',
Expand Down

0 comments on commit 6a723e8

Please sign in to comment.