Skip to content

Commit

Permalink
[ONNX] Support SequenceErase op (#13865)
Browse files Browse the repository at this point in the history
* SequenceErase was implemented in ONNX front-end

* add SequenceErase node to Sequence test

* remark from reviewer. fix negative position recalculation

* add assert

---------

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Jan 31, 2023
1 parent 4daf38f commit 0d5baac
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
42 changes: 37 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6148,13 +6148,35 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(inputs)


class SequenceLength(OnnxOpConverter):
"""Operator converter for sequence length op."""
class SequenceErase(OnnxOpConverter):
"""Operator converter for sequence erase op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Get length of input sequence
return _expr.const(len(inputs[0]), dtype="int64")
# Erase tensor from sequence on specified position
input_sequence = inputs[0]

if len(inputs) == 2:
position = inputs[1]
# Non constant position is not supported.
if isinstance(position, _expr.Constant):
position = position.data.numpy()
elif position.name_hint in params:
position = params[position.name_hint].numpy()
else:
raise NotImplementedError("Position must be a constant.")
else:
position = -1

seq_len = len(input_sequence)
assert -seq_len <= position < seq_len, "Position is out of bounds"

if position < 0:
position = seq_len + position
# Convert sequence to a list, insert tensors before erased, and repackage as Tuple.
tensor_list = [input_sequence[i] for i in range(seq_len) if i != position]
# Create new tuple and return.
return _expr.Tuple(tensor_list)


class SequenceInsert(OnnxOpConverter):
Expand Down Expand Up @@ -6188,6 +6210,15 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(tensor_list)


class SequenceLength(OnnxOpConverter):
"""Operator converter for sequence length op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Get length of input sequence
return _expr.const(len(inputs[0]), dtype="int64")


class ConcatFromSequence(OnnxOpConverter):
"""Operator converter for sequence concatenation op."""

Expand Down Expand Up @@ -6492,8 +6523,9 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
"SequenceErase": SequenceErase.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
"SequenceAt": SequenceAt.get_converter(opset),
Expand Down
10 changes: 9 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7747,10 +7747,17 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
outputs=["inserted_sequence"],
)

# Test sequence erase.
erase_node = helper.make_node(
"SequenceErase",
inputs=["inserted_sequence", "position"],
outputs=["erased_sequence"],
)

# Test sequence concatenation.
concat_node = helper.make_node(
"ConcatFromSequence",
inputs=["inserted_sequence"],
inputs=["erased_sequence"],
outputs=["concat_sequence"],
axis=axis,
)
Expand Down Expand Up @@ -7796,6 +7803,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
position_node,
construct_node,
insert_node,
erase_node,
concat_node,
split_node,
at_node,
Expand Down

0 comments on commit 0d5baac

Please sign in to comment.