Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][PaddlePaddle] Add topk op and Fix bug, when the output is a dimension, it … #13701

Merged
merged 11 commits into from
Jan 21, 2023
35 changes: 34 additions & 1 deletion python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,8 @@ def convert_slice(g, op, block):
strides = _op.const([1] * dims, dtype="int64")

out = _op.strided_slice(data, begin=starts, end=ends, strides=strides)
if decrease_axis:
out_shape = infer_shape(out)
if decrease_axis and len(out_shape) > 1:
out = _op.squeeze(out, axis=decrease_axis)
g.add_node(op.output("Out")[0], out)

Expand Down Expand Up @@ -1998,6 +1999,37 @@ def convert_swish(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_topk(g, op, block):
"""Operator converter for topk."""

data = g.get_node(op.input("X")[0])
if op.input("K"):
k = g.get_node(op.input("K")[0])
else:
k = op.attr("k")

largest = op.attr("largest")
is_ascend = not largest
axis = op.attr("axis")

value_names = op.output("Out")
indice_names = op.output("Indices")

out = None
indice = None
if value_names and indice_names:
out, indice = _op.topk(data=data, k=k, axis=axis, ret_type="both", is_ascend=is_ascend)
elif value_names:
out = _op.topk(data=data, k=k, axis=axis, ret_type="values", is_ascend=is_ascend)
elif indice_names:
indice = _op.topk(data=data, k=k, axis=axis, ret_type="indices", is_ascend=is_ascend)

if out is not None:
g.add_node(value_names[0], out)
if indice is not None:
g.add_node(indice_names[0], indice)


def convert_transpose(g, op, block):
"""Operator converter for transpose."""

Expand Down Expand Up @@ -2148,6 +2180,7 @@ def convert_unsqueeze(g, op, block):
"swish": convert_swish,
"tan": convert_unary_op,
"tanh": convert_unary_op,
"top_k_v2": convert_topk,
"transpose2": convert_transpose,
"unsqueeze2": convert_unsqueeze,
}
Expand Down
16 changes: 16 additions & 0 deletions tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,11 @@ def slice4(inputs):
x1 = paddle.to_tensor([3]) + paddle.to_tensor([1])
return inputs[:, x0:, 1:x1, :]

@paddle.jit.to_static
def slice5(inputs):
b, c, h, w = paddle.shape(inputs) # add decrease_axis
return h

input_shape = [1, 3, 10, 10]
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(
Expand All @@ -1362,6 +1367,7 @@ def slice4(inputs):
verify_model(slice2, input_data=input_data)
verify_model(slice3, input_data=paddle.randn((4, 4)))
verify_model(slice4, input_data=input_data)
verify_model(slice5, input_data=input_data)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -1681,5 +1687,15 @@ def forward(self, inputs, prev_h):
)


@tvm.testing.uses_gpu
def test_forward_topk():
@paddle.jit.to_static
def topk1(inputs):
return paddle.topk(inputs, k=1)

input_data = paddle.to_tensor([1, 4, 5, 7])
verify_model(topk1, input_data=input_data)
Copy link
Contributor

@heliqi heliqi Jan 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add a few different cases, example:

  1. k is a tensor and not just a int
  2. set largest/axis/sort attribute
  3. The output is only the values or the indices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to paddle.topk api, there seems is no a argument to control the output that only have the values or indices

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the paddle.topk api will definitely have two outputs.



if __name__ == "__main__":
pytest.main([__file__])