From aa6b0d9d78ba9614a1ee99ba9372aff20caa976c Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 16 Feb 2023 19:33:16 +0100 Subject: [PATCH 1/6] Added `jax.numpy.convolve()` frontend function and test. --- .../jax/numpy/mathematical_functions.py | 30 ++++++++++++++++ .../test_jax/test_jax_numpy_math.py | 34 +++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index f023542991dd8..349b49a613357 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -38,6 +38,36 @@ def arctan2(x1, x2): return ivy.atan2(x1, x2) +@to_ivy_arrays_and_back +def convolve(a, v, mode='full', *, precision=None): + a, v = promote_types_of_jax_inputs(a, v) + if ivy.get_num_dims(a) != 1: + raise ValueError( + "convolve() only support 1-dimensional inputs." + ) + if len(a) == 0 or len(v) == 0: + raise ValueError( + f"convolve: inputs cannot be empty, got shapes {a.shape} and {v.shape}." + ) + if len(a) < len(v): + a, v = v, a + v = ivy.flip(v) + + out_order = slice(None) + + if mode == 'valid': + padding = [(0, 0)] + elif mode == 'same': + padding = [(v.shape[0] // 2, v.shape[0] - v.shape[0] // 2 - 1)] + elif mode == 'full': + padding = [(v.shape[0] - 1, v.shape[0] - 1)] + + result = ivy.conv_general_dilated( + a[None, None, :], v[:, None, None], (1,), padding, dims=1, data_format='channel_first' + ) + return result[0, 0, out_order] + + @to_ivy_arrays_and_back def cos(x): return ivy.cos(x) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py index c4c092395d583..5bc476a1cd37d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py @@ -171,6 +171,40 @@ def test_jax_numpy_arctan2( ) +# convolve +@handle_frontend_test( + fn_tree = "jax.numpy.convolve", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_value=-1e04, + max_value=1e04, + ), +) +def test_jax_numpy_convolve( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + v=x[1], + mode='full', + precision=None, + ) + + @handle_frontend_test( fn_tree="jax.numpy.cos", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), From 085279c1562e077551e053a1dd5da36759d07906 Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 21 Feb 2023 16:18:16 +0100 Subject: [PATCH 2/6] Fixed linting errors --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 5 +++-- .../test_ivy/test_frontends/test_jax/test_jax_numpy_math.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index eb06412cb52ef..a66cdc7ee8a19 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -39,7 +39,7 @@ def arctan2(x1, x2): @to_ivy_arrays_and_back -def convolve(a, v, mode='full', *, precision=None): +def convolve(a, v, mode='full', *, precision=None): a, v = promote_types_of_jax_inputs(a, v) if ivy.get_num_dims(a) != 1: raise ValueError( @@ -63,7 +63,8 @@ def convolve(a, v, mode='full', *, precision=None): padding = [(v.shape[0] - 1, v.shape[0] - 1)] result = ivy.conv_general_dilated( - a[None, None, :], v[:, None, None], (1,), padding, dims=1, data_format='channel_first' + a[None, None, :], v[:, None, None], (1,), + padding, dims=1, data_format='channel_first', ) return result[0, 0, out_order] diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py index 788e2bd1d9de0..cd0962f9130a8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py @@ -173,7 +173,7 @@ def test_jax_numpy_arctan2( # convolve @handle_frontend_test( - fn_tree = "jax.numpy.convolve", + fn_tree="jax.numpy.convolve", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, From 236306bb0aab8edc16be0f8addb29b9b087e9f84 Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 21 Feb 2023 16:27:42 +0100 Subject: [PATCH 3/6] Fixed linting errors --- ivy/functional/frontends/jax/numpy/logic.py | 1 - .../test_ivy/test_frontends/test_torch/test_reduction_ops.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index c8493767cc40c..2b03f92f404fd 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -183,4 +183,3 @@ def isscalar(x, /): @to_ivy_arrays_and_back def left_shift(x1, x2): return ivy.isscalar(x1, x2) - diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py index 77a8a29988720..5e3c69a5313f2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py @@ -8,9 +8,8 @@ from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( statistical_dtype_values, ) -from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import ( - statistical_dtype_values as statistical_dtype_values_experimental, -) +from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical \ + import (statistical_dtype_values as statistical_dtype_values_experimental) @handle_frontend_test( From 6ddc628f4aaffc53780e6377ed0e29b74699b1fe Mon Sep 17 00:00:00 2001 From: Jacob Date: Wed, 22 Feb 2023 00:30:17 +0100 Subject: [PATCH 4/6] Added data type restrictions --- .../test_ivy/test_frontends/test_jax/test_jax_numpy_math.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py index cd0962f9130a8..8e08da9d6aa08 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py @@ -192,6 +192,7 @@ def test_jax_numpy_convolve( test_flags, ): input_dtype, x = dtype_and_x + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, From 5925c525c749fc0e458ab8b2e76b0b8b4402a4a9 Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 23 Feb 2023 16:11:08 +0100 Subject: [PATCH 5/6] Added rtol and atol --- .../test_ivy/test_frontends/test_jax/test_jax_numpy_math.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py index 8e08da9d6aa08..974abb1f6cd49 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py @@ -198,6 +198,8 @@ def test_jax_numpy_convolve( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + rtol=1e-5, + atol=1e-5, on_device=on_device, a=x[0], v=x[1], From 9851219ad079df7daf2566e209da41000cb7ae0d Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 23 Feb 2023 18:45:06 +0100 Subject: [PATCH 6/6] Added sampling of pooling mode --- .../test_frontends/test_jax/test_jax_numpy_math.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py index 974abb1f6cd49..23b9dc48366ff 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py @@ -171,6 +171,12 @@ def test_jax_numpy_arctan2( ) +@st.composite +def _get_pooling_mode(draw): + mode = draw(st.sampled_from(['full', 'valid', 'same'])) + return mode + + # convolve @handle_frontend_test( fn_tree="jax.numpy.convolve", @@ -182,6 +188,7 @@ def test_jax_numpy_arctan2( min_value=-1e04, max_value=1e04, ), + mode=_get_pooling_mode(), ) def test_jax_numpy_convolve( *, @@ -190,6 +197,7 @@ def test_jax_numpy_convolve( fn_tree, frontend, test_flags, + mode, ): input_dtype, x = dtype_and_x assume("float16" not in input_dtype) @@ -198,12 +206,12 @@ def test_jax_numpy_convolve( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-5, - atol=1e-5, + rtol=1e-4, + atol=1e-4, on_device=on_device, a=x[0], v=x[1], - mode='full', + mode=mode, precision=None, )