Skip to content

Commit

Permalink
add SequenceLength test
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Jan 29, 2023
1 parent f03a65a commit bf18552
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7760,10 +7760,16 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
"SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis
)

# Test tensor extraction from sequence
at_node = helper.make_node(
"SequenceAt", inputs=["split_sequence", "position"], outputs=["output"]
)

# Test sequence length
split_node = helper.make_node(
"SequenceLength", inputs=["concat_sequence"], outputs=["output_2"]
)

if new_axis is not None:
new_axis_attr = helper.make_attribute("new_axis", new_axis)
concat_node.attribute.append(new_axis_attr)
Expand All @@ -7781,7 +7787,8 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
output_shape[axis] = num_tensors + 1
else:
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape),
helper.make_tensor_value_info("output_2", TensorProto.INT, ())]

graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]

Expand Down

0 comments on commit bf18552

Please sign in to comment.