diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index fd52f15eafd2..a3f959a391c1 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -291,11 +291,14 @@ def convert_arg_min(g, op, block):
 def convert_argsort(g, op, block):
     """Operator converter for argsort."""
 
-    x = g.get_node(op.inputs("X")[0])
+    x = g.get_node(op.input("X")[0])
     axis = op.attr("axis")
     descending = op.attr("descending")
-    out = _op.argsort(x, axis, not descending, dtype="int64")
-    g.add_node(op.output("Indices")[0], out)
+
+    out = _op.sort(x, axis, not descending)
+    out_indices = _op.argsort(x, axis, not descending, dtype="int64")
+    g.add_node(op.output("Out")[0], out)
+    g.add_node(op.output("Indices")[0], out_indices)
 
 
 def convert_assign(g, op, block):
@@ -1296,6 +1299,18 @@ def convert_mul(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_mv(g, op, block):
+    """Operator converter for mv."""
+
+    x = g.get_node(op.input("X")[0])
+    y = g.get_node(op.input("Vec")[0])
+    y = _op.expand_dims(y, axis=-1)
+    y = _op.transpose(y)
+    out = _op.nn.dense(x, y)
+    out = _op.squeeze(out, axis=[-1])
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_numel(g, op, block):
     """Operator converter for numel."""
 
@@ -1639,9 +1654,9 @@ def make_init_param_inputs(g, node, layer):
         if is_bidirec:
             num_directions = 2
 
-        X_shape = infer_shape(input_x)
-        time_steps = X_shape[0]
-        X_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0)
+        x_shape = infer_shape(input_x)
+        time_steps = x_shape[0]
+        x_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0)
         for layer in range(num_layers):
             input_weights, hidden_weights, input_bias, hidden_bias = make_param_inputs(
                 g, op, layer, hidden_size, num_layers
@@ -1661,7 +1676,7 @@ def make_init_param_inputs(g, node, layer):
                 WB = g.get_node(input_bias[i])
                 RB = g.get_node(hidden_bias[i])
                 output, H, C = generate_lstm(
-                    X_steps=X_steps,
+                    X_steps=x_steps,
                     H_t=H_t,
                     C_t=C_t,
                     W=W,
@@ -1682,8 +1697,7 @@ def make_init_param_inputs(g, node, layer):
             output = _op.transpose(output, axes=[0, 2, 1, 3])
             output = _op.reshape(output, newshape=(0, 0, -1))
-            X_steps = output
-            X_steps = _op.split(X_steps, indices_or_sections=time_steps, axis=0)
+            x_steps = _op.split(output, indices_or_sections=time_steps, axis=0)
 
     g.add_node(op.output("Out")[0], output)
 
 
@@ -2054,6 +2068,53 @@ def convert_unsqueeze(g, op, block):
     g.add_node(op.output("Out")[0], x)
 
 
+def convert_unstack(g, op, block):
+    """Operator converter for unstack."""
+
+    x = g.get_node(op.input("X")[0])
+    axis = op.attr("axis")
+    num = op.attr("num")
+    out = _op.split(x, num, axis=axis)
+    for i, out_i in enumerate(out):
+        out_i = _op.squeeze(out_i, axis=[axis])
+        g.add_node(op.output("Y")[i], out_i)
+
+
+def convert_unique(g, op, block):
+    """Operator converter for unique."""
+
+    x = g.get_node(op.input("X")[0])
+    ndim = len(infer_shape(x))
+    assert ndim == 1, "Only 1-D Tensor is supported for PaddlePaddle's unique"
+    is_sorted = op.attr("is_sorted")
+    return_counts = op.attr("return_counts")
+    return_index = op.attr("return_index")
+    return_inverse = op.attr("return_inverse")
+    if return_counts:
+        [unique, indices, inverse_indices, num_uniq, counts] = _op.unique(
+            x, is_sorted=is_sorted, return_counts=True
+        )
+        unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+        counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
+        indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size")
+        counts_sliced = _op.cast(counts_sliced, "int64")
+        g.add_node(op.output("Counts")[0], counts_sliced)
+    else:
+        [unique, indices, inverse_indices, num_uniq] = _op.unique(
+            x, is_sorted=is_sorted, return_counts=False
+        )
+        unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+        indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size")
+
+    inverse_indices = _op.cast(inverse_indices, "int64")
+    indices_sliced = _op.cast(indices_sliced, "int64")
+    g.add_node(op.output("Out")[0], unique_sliced)
+    if return_index:
+        g.add_node(op.output("Indices")[0], indices_sliced)
+    if return_inverse:
+        g.add_node(op.output("Index")[0], inverse_indices)
+
+
 def convert_where(g, op, block):
     """Operator converter for where."""
 
@@ -2152,6 +2213,7 @@ def convert_where(g, op, block):
     "logsumexp": convert_logsumexp,
     "matmul": convert_matmul,
     "matmul_v2": convert_matmul,
     "mul": convert_mul,
+    "mv": convert_mv,
     "nearest_interp_v2": convert_interpolate,
     "not_equal": convert_elementwise_op,
@@ -2176,12 +2238,15 @@ def convert_where(g, op, block):
     "relu6": convert_relu6,
     "reshape2": convert_reshape,
     "rnn": convert_rnn,
+    "round": convert_unary_op,
     "rsqrt": convert_unary_op,
     "scale": convert_scale,
     "selu": convert_selu,
     "shape": convert_shape,
     "sigmoid": convert_unary_op,
+    "sign": convert_unary_op,
     "sin": convert_unary_op,
+    "sinh": convert_unary_op,
     "size": convert_numel,
     "slice": convert_slice,
     "softmax": convert_softmax,
@@ -2189,6 +2254,7 @@ def convert_where(g, op, block):
     "softshrink": convert_softshrink,
     "softsign": convert_softsign,
     "split": convert_split,
+    "sqrt": convert_unary_op,
     "square": convert_square,
     "squeeze2": convert_squeeze,
     "stack": convert_stack,
@@ -2203,6 +2269,8 @@ def convert_where(g, op, block):
     "tile": convert_tile,
     "transpose2": convert_transpose,
+    "unique": convert_unique,
     "unsqueeze2": convert_unsqueeze,
+    "unstack": convert_unstack,
     "where": convert_where,
     "where_index": convert_nonzero,
 }
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index bf731d65a781..7caedeff75f5 100644
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -166,9 +166,15 @@ def forward(self, inputs):
         "log1p",
         "numel",
         "reciprocal",
+        "relu",
+        "round",
         "rsqrt",
+        "sigmoid",
+        "sign",
         "sin",
+        "sinh",
+        "sqrt",
         "square",
         "tan",
         "tanh",
     ]
@@ -1190,6 +1196,45 @@ def forward(self, input1, input2):
     verify_model(MatMul1(), input_data=[input_data1, input_data2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_mm():
+    class Mm(nn.Layer):
+        def forward(self, input1, input2):
+            return paddle.mm(input1, input2)
+
+    # matrix x vector
+    input_data1 = paddle.randn((3, 4), dtype="float32")
+    input_data2 = paddle.randn((4,), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # matrix x matrix
+    input_data1 = paddle.randn((5, 4), dtype="float32")
+    input_data2 = paddle.randn((4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # batched matrix x batched matrix
+    input_data1 = paddle.randn((10, 3, 4), dtype="float32")
+    input_data2 = paddle.randn((10, 4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # batched matrix x broadcasted matrix
+    input_data1 = paddle.randn((10, 3, 4), dtype="float32")
+    input_data2 = paddle.randn((4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+
+@tvm.testing.uses_gpu
+def test_forward_mv():
+    class Mv(nn.Layer):
+        def forward(self, input1, input2):
+            return paddle.mv(input1, input2)
+
+    # matrix x vector
+    input_data1 = paddle.randn((3, 4), dtype="float32")
+    input_data2 = paddle.randn((4,), dtype="float32")
+    verify_model(Mv(), input_data=[input_data1, input_data2])
+
+
 @tvm.testing.uses_gpu
 def test_forward_nonzero():
     class Nonzero(nn.Layer):
@@ -1406,6 +1451,21 @@ def forward(self, x, y):
     verify_model(Pow2(), input_data=[x_data, y_data])
 
 
+@tvm.testing.uses_gpu
+def test_forward_rank():
+    class Rank(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            rank = paddle.rank(inputs)
+            rank = paddle.unsqueeze(rank, axis=0)
+            output = inputs + rank
+            return output
+
+    input_shape = [1, 2, 1, 3, 1]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(Rank(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_reduce_op():
     class ReduceOp_Bool(nn.Layer):
@@ -1566,6 +1626,23 @@ def slice5(inputs):
     verify_model(slice5, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_sort():
+    @paddle.jit.to_static
+    def sort(inputs):
+        return paddle.sort(inputs)
+
+    @paddle.jit.to_static
+    def sort2(inputs):
+        return paddle.sort(inputs, axis=0, descending=True)
+
+    input_shape = [2, 3, 5]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(sort, input_data)
+    input_data2 = np.random.randint(100, size=input_shape)
+    verify_model(sort2, input_data2)
+
+
 @tvm.testing.uses_gpu
 def test_forward_split():
     @paddle.jit.to_static
@@ -1588,7 +1665,7 @@ def split3(inputs):
 
 
 @tvm.testing.uses_gpu
-def test_forward_squeeze2():
+def test_forward_squeeze():
     @paddle.jit.to_static
     def squeeze(inputs):
         return paddle.squeeze(inputs)
@@ -1608,6 +1685,102 @@ def squeeze3(inputs):
     verify_model(squeeze3, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_stack():
+    @paddle.jit.to_static
+    def stack(input1, input2, input3):
+        return paddle.stack([input1, input2, input3])
+
+    @paddle.jit.to_static
+    def stack2(input1, input2, input3):
+        return paddle.stack([input1, input2, input3], axis=-1)
+
+    input_shape = [2, 3]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    input_data2 = paddle.rand(input_shape, dtype="float32")
+    input_data3 = paddle.rand(input_shape, dtype="float32")
+    verify_model(stack, input_data=[input_data, input_data2, input_data3])
+    verify_model(stack2, input_data=[input_data, input_data2, input_data3])
+
+
+@tvm.testing.uses_gpu
+def test_forward_std():
+    class Std1(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, 1, unbiased=False)
+
+    class Std2(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, axis=-2, keepdim=False, unbiased=False)
+
+    class Std3(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, axis=3, keepdim=True, unbiased=False)
+
+    class Std4(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, axis=[2, 3], keepdim=True, unbiased=False)
+
+    class Std5(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, axis=[2, 3], keepdim=False, unbiased=False)
+
+    class Std6(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, unbiased=False)
+
+    class Std7(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.std(inputs, unbiased=True)
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(Std1(), input_data=input_data)
+    verify_model(Std2(), input_data=input_data)
+    verify_model(Std3(), input_data=input_data)
+    verify_model(Std4(), input_data=input_data)
+    verify_model(Std5(), input_data=input_data)
+    verify_model(Std6(), input_data=input_data)
+    verify_model(Std7(), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_subtract():
+    class Subtract(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, x, y):
+            return paddle.subtract(x, y)
+
+    input_data1 = paddle.to_tensor([2, np.nan, 5], dtype="float32")
+    input_data2 = paddle.to_tensor([1, 4, np.nan], dtype="float32")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
+
+    input_data1 = paddle.randint(0, 10, (3, 4), dtype="int32")
+    input_data2 = paddle.randint(0, 10, (4,), dtype="int32")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
+
+    input_data1 = paddle.randint(0, 10, (10, 3, 4), dtype="int64")
+    input_data2 = paddle.randint(0, 10, (3, 4), dtype="int64")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
+
+
+@tvm.testing.uses_gpu
+def test_forward_t():
+    class T(nn.Layer):
+        def forward(self, x):
+            return paddle.t(x)
+
+    input_data1 = paddle.randn((3, 4), dtype="float32")
+    verify_model(T(), input_data=[input_data1])
+
+
 @tvm.testing.uses_gpu
 def test_forward_topk():
     class Topk1(nn.Layer):
@@ -1650,24 +1823,6 @@ def forward(self, inputs):
     verify_model(Topk6(), input_data=input_data)
 
 
-@tvm.testing.uses_gpu
-def test_forward_stack():
-    @paddle.jit.to_static
-    def stack(input1, input2, input3):
-        return paddle.stack([input1, input2, input3])
-
-    @paddle.jit.to_static
-    def stack2(input1, input2, input3):
-        return paddle.stack([input1, input2, input3], axis=-1)
-
-    input_shape = [2, 3]
-    input_data = paddle.rand(input_shape, dtype="float32")
-    input_data2 = paddle.rand(input_shape, dtype="float32")
-    input_data3 = paddle.rand(input_shape, dtype="float32")
-    verify_model(stack, input_data=[input_data, input_data2, input_data3])
-    verify_model(stack2, input_data=[input_data, input_data2, input_data3])
-
-
 @tvm.testing.uses_gpu
 def test_forward_tile():
     @paddle.jit.to_static
@@ -1699,6 +1854,73 @@ def tile3(inputs, inputs2):
     verify_model(tile3, input_data=[input_data, input_data2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_unstack():
+    @paddle.jit.to_static
+    def unstack1(x):
+        return paddle.unstack(x)
+
+    @paddle.jit.to_static
+    def unstack2(x):
+        return paddle.unstack(x, axis=-1)
+
+    @paddle.jit.to_static
+    def unstack3(x):
+        return paddle.unstack(x, axis=-1, num=3)
+
+    input_shape = [2, 3]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(unstack1, input_data=[input_data])
+    verify_model(unstack2, input_data=[input_data])
+    verify_model(unstack3, input_data=[input_data])
+
+
+@tvm.testing.uses_gpu
+def test_forward_unique():
+    @paddle.jit.to_static
+    def unique1(x):
+        return paddle.unique(x)
+
+    @paddle.jit.to_static
+    def unique2(x):
+        return paddle.unique(x, return_index=True, return_inverse=False, return_counts=False)
+
+    @paddle.jit.to_static
+    def unique3(x):
+        return paddle.unique(x, return_index=False, return_inverse=True, return_counts=False)
+
+    @paddle.jit.to_static
+    def unique4(x):
+        return paddle.unique(x, return_index=False, return_inverse=False, return_counts=True)
+
+    @paddle.jit.to_static
+    def unique5(x):
+        return paddle.unique(x, return_index=True, return_inverse=True, return_counts=False)
+
+    @paddle.jit.to_static
+    def unique6(x):
+        return paddle.unique(x, return_index=False, return_inverse=True, return_counts=True)
+
+    @paddle.jit.to_static
+    def unique7(x):
+        return paddle.unique(x, return_index=True, return_inverse=False, return_counts=True)
+
+    @paddle.jit.to_static
+    def unique8(x):
+        return paddle.unique(x, return_index=True, return_inverse=True, return_counts=True)
+
+    input_data = np.array([2, 3, 3, 1, 5, 3])
+    input_data = paddle.to_tensor(input_data)
+    verify_model(unique1, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique2, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique3, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique4, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique5, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique6, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique7, input_data=[input_data], input_shape=[[6]])
+    verify_model(unique8, input_data=[input_data], input_shape=[[6]])
+
+
 @tvm.testing.uses_gpu
 def test_forward_zeros():
     @paddle.jit.to_static
@@ -1761,6 +1983,7 @@ def forward(self, x):
     test_forward_arange()
     test_forward_argmax()
     test_forward_argmin()
+    test_forward_argsort()
     test_forward_assign()
     test_forward_batch_norm()
     test_forward_cast()
@@ -1795,6 +2018,8 @@ def forward(self, x):
     test_forward_look_up()
     test_forward_lstm()
     test_forward_matmul()
+    test_forward_mm()
+    test_forward_mv()
     test_forward_multiply()
     test_forward_nonzero()
     test_forward_norm()
@@ -1803,15 +2028,22 @@ def forward(self, x):
     test_forward_pixel_shuffle()
     test_forward_prelu()
     test_forward_pow()
+    test_forward_rank()
     test_forward_reduce_op()
     test_forward_reshape()
     test_forward_scale()
     test_forward_slice()
+    test_forward_sort()
     test_forward_split()
-    test_forward_squeeze2()
+    test_forward_squeeze()
+    test_forward_std()
+    test_forward_subtract()
+    test_forward_t()
     test_forward_topk()
     test_forward_tile()
     test_forward_conv_transpose()
+    test_forward_unstack()
+    test_forward_unique()
     test_forward_math()
     test_forward_zeros()
     test_forward_where()