From 6d386932442109483307d61fb92953cc9275cca0 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Tue, 11 Jul 2023 17:38:48 +0530 Subject: [PATCH 1/5] Add Gather paddle manipulation frontend --- .../frontends/paddle/tensor/manipulation.py | 9 ++++++ .../test_tensor/test_paddle_manipulation.py | 31 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/manipulation.py b/ivy/functional/frontends/paddle/tensor/manipulation.py index f25088711c44e..19dab9ff12880 100644 --- a/ivy/functional/frontends/paddle/tensor/manipulation.py +++ b/ivy/functional/frontends/paddle/tensor/manipulation.py @@ -96,3 +96,12 @@ def cast(x, dtype): @to_ivy_arrays_and_back def broadcast_to(x, shape, name=None): return ivy.broadcast_to(x, shape) + + +@with_supported_dtypes( + {"2.5.0 and below": ("bool", "float32", "float64", "int32", "int64")}, + "paddle", +) +@to_ivy_arrays_and_back +def gather(x, indices, axis=0, name=None): + return ivy.gather(x, indices, axis=axis) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_manipulation.py index 97749d7cfe02c..bdf5591eea296 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_paddle_manipulation.py @@ -450,3 +450,34 @@ def test_paddle_broadcast_to( x=x[0], shape=shape, ) + + +# gather +@handle_frontend_test( + fn_tree="paddle.gather", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ), + index=st.integers(min_value=0, max_value=10), +) +def test_paddle_gather( + *, + dtype_and_x, + index, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + index=index, + ) From bc7b2464112383d29ee5dbc8095bd53af10f10d1 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Fri, 14 Jul 2023 21:46:21 +0530 Subject: [PATCH 2/5] test fix --- .../test_paddle/test_tensor/test_manipulation.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index bdf5591eea296..6a1eec0ca3ead 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -457,15 +457,13 @@ def test_paddle_broadcast_to( fn_tree="paddle.gather", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, ), - index=st.integers(min_value=0, max_value=10), + dtype=helpers.get_dtypes("valid", full=False), ) def test_paddle_gather( *, dtype_and_x, - index, + dtype, on_device, fn_tree, frontend, @@ -478,6 +476,6 @@ def test_paddle_gather( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - index=index, + x=x[0], + dtype=dtype[0], ) From 6ed361aa48382cb85f1e6bcca0adfc8c835102b6 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Sat, 15 Jul 2023 15:58:11 +0530 Subject: [PATCH 3/5] test fix --- .../test_tensor/test_manipulation.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index 6a1eec0ca3ead..56c57a3568e12 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -456,26 +456,33 @@ def test_paddle_broadcast_to( @handle_frontend_test( fn_tree="paddle.gather", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("all"), + min_num_dims=1, + max_num_dims=6, ), - dtype=helpers.get_dtypes("valid", full=False), + index=st.integers(min_value=0, max_value=10), + axis=st.integers(min_value=0, max_value=5), + test_with_out=st.just(False), ) def test_paddle_gather( *, dtype_and_x, - dtype, + index, + axis, on_device, fn_tree, frontend, test_flags, ): input_dtype, x = dtype_and_x + index = index % x[0].shape[axis] helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], - dtype=dtype[0], + index=index, + axis=axis, + on_device=on_device, ) From 0dd1ff364cce88c1d71ac17293228eb34ad498e5 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Sat, 15 Jul 2023 17:45:35 +0530 Subject: [PATCH 4/5] test fix --- .../test_frontends/test_paddle/test_tensor/test_manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index 56c57a3568e12..638c8c85a3197 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -456,7 +456,7 @@ def test_paddle_broadcast_to( @handle_frontend_test( fn_tree="paddle.gather", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("all"), + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, max_num_dims=6, ), From 879db81b72cbcca67a04aaeafcf4725aff43fc41 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Sun, 16 Jul 2023 20:58:46 +0530 Subject: [PATCH 5/5] test not working --- .../frontends/paddle/tensor/manipulation.py | 5 ++- .../test_tensor/test_manipulation.py | 44 ++++++++++++------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/ivy/functional/frontends/paddle/tensor/manipulation.py b/ivy/functional/frontends/paddle/tensor/manipulation.py index 19dab9ff12880..338c127af7697 100644 --- a/ivy/functional/frontends/paddle/tensor/manipulation.py +++ b/ivy/functional/frontends/paddle/tensor/manipulation.py @@ -103,5 +103,6 @@ def broadcast_to(x, shape, name=None): "paddle", ) @to_ivy_arrays_and_back -def gather(x, indices, axis=0, name=None): - return ivy.gather(x, indices, axis=axis) +def gather(params, indices, axis=-1, batch_dims=0, name=None): + return ivy.gather(params, indices, axis=axis, batch_dims=batch_dims) + diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index 638c8c85a3197..b22a6bd73dad6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -452,37 +452,47 @@ def test_paddle_broadcast_to( ) -# gather +@st.composite +def _gather_helper(draw): + dtype_and_param = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype_and_indices = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + dtype, param = dtype_and_param + dtype, indices = dtype_and_indices + return dtype, param, indices + + @handle_frontend_test( fn_tree="paddle.gather", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ), - index=st.integers(min_value=0, max_value=10), - axis=st.integers(min_value=0, max_value=5), - test_with_out=st.just(False), + dtype_param_and_indices=_gather_helper(), ) def test_paddle_gather( *, - dtype_and_x, - index, - axis, + dtype_param_and_indices, on_device, fn_tree, frontend, test_flags, ): - input_dtype, x = dtype_and_x - index = index % x[0].shape[axis] + input_dtype, param, indices = dtype_param_and_indices helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - x=x[0], - index=index, - axis=axis, on_device=on_device, + param=param[0], + indices=indices[0], )