Skip to content

Commit

Permalink
[ONNX] Support SequenceEmpty op (#13866)
Browse files Browse the repository at this point in the history
* add SequenceEmpty

* add SequenceEmpty test

* pylint fix

---------

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Jan 31, 2023
1 parent 0d5baac commit d8833bd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6148,6 +6148,15 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(inputs)


class SequenceEmpty(OnnxOpConverter):
"""Operator converter for sequence empty op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Construct an empty tuple.
return _expr.Tuple([])


class SequenceErase(OnnxOpConverter):
"""Operator converter for sequence erase op."""

Expand Down Expand Up @@ -6523,6 +6532,7 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceEmpty": SequenceEmpty.get_converter(opset),
"SequenceErase": SequenceErase.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
Expand Down
32 changes: 32 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7829,6 +7829,38 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
verify_sequence_ops((3, 3, 3, 3), 4, axis=2, new_axis=1)


@tvm.testing.parametrize_targets
def test_empty_sequence(target, dev):
"""test_empty_sequence"""

# Test creating an empty tensor sequence.
empty_node = helper.make_node(
"SequenceEmpty",
inputs=[],
outputs=["empty_sequence"],
)

length_node = helper.make_node("SequenceLength", inputs=["empty_sequence"], outputs=["output"])

graph_outputs = [helper.make_tensor_value_info("output", TensorProto.INT64, [])]

graph_nodes = [empty_node, length_node]

graph = helper.make_graph(
graph_nodes,
"Sequence_empty_test",
inputs=[],
outputs=graph_outputs,
)

model = helper.make_model(
graph,
producer_name="Sequence_empty_test",
)

verify_with_ort_with_inputs(model, [], target=target, dev=dev)


def test_exporting_node_renamed_model():
"""test exproting model when export_node_renamed_model is set"""

Expand Down

0 comments on commit d8833bd

Please sign in to comment.