From fcac430c74edab1aa77d7cde75fc797f5cc5bded Mon Sep 17 00:00:00 2001 From: xc720 Date: Sun, 5 Feb 2023 15:12:55 +0000 Subject: [PATCH 01/14] added jax numpy logic_or --- .idea/ivy.iml | 4 +-- .idea/misc.xml | 4 +-- ivy/functional/frontends/jax/numpy/logic.py | 5 ++++ .../test_jax/test_jax_numpy_logic.py | 28 +++++++++++++++++++ 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/.idea/ivy.iml b/.idea/ivy.iml index f4b5a229e00b1..30b8bd0ac1604 100644 --- a/.idea/ivy.iml +++ b/.idea/ivy.iml @@ -2,7 +2,7 @@ - + @@ -12,4 +12,4 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 8cc1b33864f84..c3409f72a037a 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,8 +3,8 @@ - + - + \ No newline at end of file diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index 5861e7577fda8..83106eb0ccb2b 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -157,3 +157,8 @@ def isinf(x, /): def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): a, b = promote_jax_arrays(a, b) return ivy.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + + +def logical_or(x1, x2, /): + return ivy.logical_or(x1, x2) \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index cb522926d7d05..b4580bb91e404 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -661,3 +661,31 @@ def test_jax_numpy_isclose( b=input[1], equal_nan=equal_nan, ) + + + +#logical_or +@handle_frontend_test( + fn_tree="jax.numpy.logical_or", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), + num_arrays=2, + ), +) +def test_jax_numpy_logical_or( + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, +): + x_dtypes, x = dtypes_values + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x1=x[0], + x2=x[1], + ) \ No newline at end of file From ef502c0234c5c025a95422c025b8fb77f8769030 Mon Sep 17 00:00:00 2001 From: xc720 Date: Sun, 5 Feb 2023 15:27:28 +0000 Subject: [PATCH 02/14] added jax numpy logic_or --- .idea/ivy.iml | 2 +- .idea/misc.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.idea/ivy.iml b/.idea/ivy.iml index 30b8bd0ac1604..d3b3d9a297392 100644 --- a/.idea/ivy.iml +++ b/.idea/ivy.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index c3409f72a037a..072be263e40ea 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,7 +3,7 @@ - + From c29f298a397ce67f536fe12dcb3f30d2cd34e2c3 Mon Sep 17 00:00:00 2001 From: xc720 Date: Sun, 5 Feb 2023 16:37:46 +0000 Subject: [PATCH 03/14] fix formatting issues --- .../test_jax/test_jax_numpy_logic.py | 262 +++++++++--------- 1 file changed, 130 insertions(+), 132 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index b4580bb91e404..4c62dfd4d745f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -20,13 +20,13 @@ test_with_out=st.just(False), ) def test_jax_numpy_allclose( - *, - dtype_and_input, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_input, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, input = dtype_and_input helpers.test_frontend_function( @@ -54,13 +54,13 @@ def test_jax_numpy_allclose( test_with_out=st.just(False), ) def test_jax_numpy_array_equal( - *, - dtype_and_x, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -85,12 +85,12 @@ def test_jax_numpy_array_equal( test_with_out=st.just(False), ) def test_jax_numpy_array_equiv( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -120,12 +120,12 @@ def test_jax_numpy_array_equiv( test_with_out=st.just(False), ) def test_jax_numpy_isneginf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -153,12 +153,12 @@ def test_jax_numpy_isneginf( ), ) def test_jax_numpy_isposinf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -182,10 +182,10 @@ def test_jax_numpy_isposinf( test_with_out=st.just(False), ) def test_jax_numpy_less( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -209,10 +209,10 @@ def test_jax_numpy_less( test_with_out=st.just(False), ) def test_jax_numpy_less_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -236,10 +236,10 @@ def test_jax_numpy_less_equal( test_with_out=st.just(False), ) def test_jax_numpy_greater( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -263,10 +263,10 @@ def test_jax_numpy_greater( test_with_out=st.just(False), ) def test_jax_numpy_greater_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -294,12 +294,12 @@ def test_jax_numpy_greater_equal( ), ) def test_jax_numpy_isnan( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -323,10 +323,10 @@ def test_jax_numpy_isnan( test_with_out=st.just(False), ) def test_jax_numpy_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -350,10 +350,10 @@ def test_jax_numpy_equal( test_with_out=st.just(False), ) def test_jax_numpy_not_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -376,12 +376,12 @@ def test_jax_numpy_not_equal( test_with_out=st.just(False), ) def test_jax_numpy_all( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -404,12 +404,12 @@ def test_jax_numpy_all( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_and( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -430,12 +430,12 @@ def test_jax_numpy_bitwise_and( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_not( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -458,12 +458,12 @@ def test_jax_numpy_bitwise_not( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_or( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -487,12 +487,12 @@ def test_jax_numpy_bitwise_or( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_xor( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -521,14 +521,14 @@ def test_jax_numpy_bitwise_xor( test_with_out=st.just(False), ) def test_jax_numpy_any( - *, - dtype_x_axis, - keepdims, - where, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_x_axis, + keepdims, + where, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtypes, x, axis = dtype_x_axis if isinstance(axis, tuple): @@ -561,11 +561,11 @@ def test_jax_numpy_any( ), ) def test_jax_numpy_logical_and( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -587,11 +587,11 @@ def test_jax_numpy_logical_and( ), ) def test_jax_numpy_invert( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -613,12 +613,12 @@ def test_jax_numpy_invert( test_with_out=st.just(False), ) def test_jax_numpy_isinf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -642,13 +642,13 @@ def test_jax_numpy_isinf( test_with_out=st.just(False), ) def test_jax_numpy_isclose( - *, - dtype_and_input, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_input, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, input = dtype_and_input helpers.test_frontend_function( @@ -663,8 +663,7 @@ def test_jax_numpy_isclose( ) - -#logical_or +# logical_or @handle_frontend_test( fn_tree="jax.numpy.logical_or", dtypes_values=helpers.dtype_and_values( @@ -673,11 +672,11 @@ def test_jax_numpy_isclose( ), ) def test_jax_numpy_logical_or( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -687,5 +686,4 @@ def test_jax_numpy_logical_or( fn_tree=fn_tree, on_device=on_device, x1=x[0], - x2=x[1], - ) \ No newline at end of file + x2=x[1], ) From c59754ef376513d423a477ee112ead03819df0a1 Mon Sep 17 00:00:00 2001 From: xc720 Date: Sun, 5 Feb 2023 16:45:30 +0000 Subject: [PATCH 04/14] fix formatting issues --- ivy/functional/frontends/jax/numpy/logic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index 83106eb0ccb2b..a523b7caf917c 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -127,7 +127,6 @@ def any(a, axis=None, out=None, keepdims=False, *, where=None): alltrue = all - sometrue = any @@ -159,6 +158,5 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return ivy.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def logical_or(x1, x2, /): - return ivy.logical_or(x1, x2) \ No newline at end of file + return ivy.logical_or(x1, x2) From 0a19665ab26b3123214adb36f115219c60a8f1af Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 7 Feb 2023 20:40:18 +0000 Subject: [PATCH 05/14] change idea files back --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index eed057456974f..228ff1fd7e6da 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ with_time_logs/ internal_automation_tools/ .vscode/* .idea/* +*.iml From 89663fd639ac9ced12e4f672b8e162636d503b38 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 7 Feb 2023 20:47:40 +0000 Subject: [PATCH 06/14] change idea files back --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 228ff1fd7e6da..35a762120aef7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,5 +27,4 @@ with_time_logs/ .array_api_tests_k_flag* internal_automation_tools/ .vscode/* -.idea/* -*.iml +.idea/* \ No newline at end of file From 54c1c2187e3eb41ac06cabd45ee51eb5d4289da6 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 7 Feb 2023 21:01:31 +0000 Subject: [PATCH 07/14] change back --- .gitignore | 2 +- .idea/ivy.iml | 2 +- .idea/misc.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 35a762120aef7..eed057456974f 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ with_time_logs/ .array_api_tests_k_flag* internal_automation_tools/ .vscode/* -.idea/* \ No newline at end of file +.idea/* diff --git a/.idea/ivy.iml b/.idea/ivy.iml index d3b3d9a297392..f4b5a229e00b1 100644 --- a/.idea/ivy.iml +++ b/.idea/ivy.iml @@ -12,4 +12,4 @@ - \ No newline at end of file + diff --git a/.idea/misc.xml b/.idea/misc.xml index 072be263e40ea..8cc1b33864f84 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -7,4 +7,4 @@ - \ No newline at end of file + From 58d3531f131da8b32f7bc7668a9b56e79d35e0a0 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 7 Feb 2023 21:19:29 +0000 Subject: [PATCH 08/14] two conflicts resolved --- ivy/functional/frontends/jax/numpy/logic.py | 5 +++-- .../test_jax/test_jax_numpy_logic.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index a523b7caf917c..eb5f05933d534 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -158,5 +158,6 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return ivy.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -def logical_or(x1, x2, /): - return ivy.logical_or(x1, x2) +@to_ivy_arrays_and_back +def logical_not(x, /): + return ivy.logical_not(x) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index 4c62dfd4d745f..65411aeea2bcb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -663,20 +663,20 @@ def test_jax_numpy_isclose( ) -# logical_or +# logical_not @handle_frontend_test( - fn_tree="jax.numpy.logical_or", + fn_tree="jax.numpy.logical_not", dtypes_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("bool"), - num_arrays=2, + num_arrays=1, ), ) -def test_jax_numpy_logical_or( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, +def test_jax_numpy_logical_not( + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -685,5 +685,5 @@ def test_jax_numpy_logical_or( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], ) + x=x[0], + ) From 28154f48ef2f2196bebfa8818ed256b5006c75f5 Mon Sep 17 00:00:00 2001 From: xc720 Date: Sun, 12 Feb 2023 13:50:45 +0000 Subject: [PATCH 09/14] change back the changes made accidentally --- ivy/functional/frontends/jax/numpy/logic.py | 11 +- .../test_jax/test_jax_numpy_logic.py | 273 ++++++++++-------- 2 files changed, 156 insertions(+), 128 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index eb5f05933d534..db4506b87ef9d 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -127,6 +127,7 @@ def any(a, axis=None, out=None, keepdims=False, *, where=None): alltrue = all + sometrue = any @@ -142,6 +143,11 @@ def logical_and(x1, x2, /): return ivy.logical_and(x1, x2) +@to_ivy_arrays_and_back +def logical_not(x, /): + return ivy.logical_not(x) + + @to_ivy_arrays_and_back def invert(x, /): return ivy.bitwise_invert(x) @@ -156,8 +162,3 @@ def isinf(x, /): def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): a, b = promote_jax_arrays(a, b) return ivy.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - - -@to_ivy_arrays_and_back -def logical_not(x, /): - return ivy.logical_not(x) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index 65411aeea2bcb..116c81c8738b2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -20,13 +20,13 @@ test_with_out=st.just(False), ) def test_jax_numpy_allclose( - *, - dtype_and_input, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_input, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, input = dtype_and_input helpers.test_frontend_function( @@ -54,13 +54,13 @@ def test_jax_numpy_allclose( test_with_out=st.just(False), ) def test_jax_numpy_array_equal( - *, - dtype_and_x, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -85,12 +85,12 @@ def test_jax_numpy_array_equal( test_with_out=st.just(False), ) def test_jax_numpy_array_equiv( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -120,12 +120,12 @@ def test_jax_numpy_array_equiv( test_with_out=st.just(False), ) def test_jax_numpy_isneginf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -153,12 +153,12 @@ def test_jax_numpy_isneginf( ), ) def test_jax_numpy_isposinf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -182,10 +182,10 @@ def test_jax_numpy_isposinf( test_with_out=st.just(False), ) def test_jax_numpy_less( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -209,10 +209,10 @@ def test_jax_numpy_less( test_with_out=st.just(False), ) def test_jax_numpy_less_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -236,10 +236,10 @@ def test_jax_numpy_less_equal( test_with_out=st.just(False), ) def test_jax_numpy_greater( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -263,10 +263,10 @@ def test_jax_numpy_greater( test_with_out=st.just(False), ) def test_jax_numpy_greater_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -294,12 +294,12 @@ def test_jax_numpy_greater_equal( ), ) def test_jax_numpy_isnan( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -323,10 +323,10 @@ def test_jax_numpy_isnan( test_with_out=st.just(False), ) def test_jax_numpy_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -350,10 +350,10 @@ def test_jax_numpy_equal( test_with_out=st.just(False), ) def test_jax_numpy_not_equal( - dtype_and_x, - frontend, - test_flags, - fn_tree, + dtype_and_x, + frontend, + test_flags, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -376,12 +376,12 @@ def test_jax_numpy_not_equal( test_with_out=st.just(False), ) def test_jax_numpy_all( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -404,12 +404,12 @@ def test_jax_numpy_all( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_and( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -430,12 +430,12 @@ def test_jax_numpy_bitwise_and( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_not( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -458,12 +458,12 @@ def test_jax_numpy_bitwise_not( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_or( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -487,12 +487,12 @@ def test_jax_numpy_bitwise_or( test_with_out=st.just(False), ) def test_jax_numpy_bitwise_xor( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -521,14 +521,14 @@ def test_jax_numpy_bitwise_xor( test_with_out=st.just(False), ) def test_jax_numpy_any( - *, - dtype_x_axis, - keepdims, - where, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_x_axis, + keepdims, + where, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtypes, x, axis = dtype_x_axis if isinstance(axis, tuple): @@ -561,11 +561,11 @@ def test_jax_numpy_any( ), ) def test_jax_numpy_logical_and( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -587,11 +587,11 @@ def test_jax_numpy_logical_and( ), ) def test_jax_numpy_invert( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -604,6 +604,33 @@ def test_jax_numpy_invert( ) +# isfinite +@handle_frontend_test( + fn_tree="jax.numpy.isfinite", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), allow_nan=True + ), + test_with_out=st.just(False), +) +def test_jax_numpy_isfinite( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + # isinf @handle_frontend_test( fn_tree="jax.numpy.isinf", @@ -613,12 +640,12 @@ def test_jax_numpy_invert( test_with_out=st.just(False), ) def test_jax_numpy_isinf( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -642,13 +669,13 @@ def test_jax_numpy_isinf( test_with_out=st.just(False), ) def test_jax_numpy_isclose( - *, - dtype_and_input, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, + *, + dtype_and_input, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, ): input_dtype, input = dtype_and_input helpers.test_frontend_function( From 4ee90d4764ae5dc58139d60480c9ddbf987a67c9 Mon Sep 17 00:00:00 2001 From: xc720 Date: Mon, 13 Feb 2023 12:49:00 +0000 Subject: [PATCH 10/14] commit changes in master from 719 to 742 --- .../test_jax/test_jax_numpy_logic.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index 116c81c8738b2..1496fc29d843b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -714,3 +714,29 @@ def test_jax_numpy_logical_not( on_device=on_device, x=x[0], ) + + +# isscalar +@handle_frontend_test( + fn_tree="jax.numpy.isscalar", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), +) +def test_jax_numpy_isscalar( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, +): + x_dtypes, x = dtype_and_x + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) \ No newline at end of file From b30b48a1b0316bdc3334792724924b53295c6e53 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 14 Feb 2023 12:22:32 +0000 Subject: [PATCH 11/14] reimplemented logic_or --- ivy/functional/frontends/jax/numpy/logic.py | 5 +++-- .../test_jax/test_jax_numpy_logic.py | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index db4506b87ef9d..2a7b126718425 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -144,8 +144,9 @@ def logical_and(x1, x2, /): @to_ivy_arrays_and_back -def logical_not(x, /): - return ivy.logical_not(x) +def logical_or(x1, x2, /): + x1, x2 = promote_jax_arrays(x1, x2) + return ivy.logical_or(x1, x2) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index 1496fc29d843b..d4f788e330248 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -690,20 +690,20 @@ def test_jax_numpy_isclose( ) -# logical_not +# logical_or @handle_frontend_test( - fn_tree="jax.numpy.logical_not", + fn_tree="jax.numpy.logical_or", dtypes_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("bool"), - num_arrays=1, + num_arrays=2, ), ) -def test_jax_numpy_logical_not( - dtypes_values, - on_device, - fn_tree, - frontend, - test_flags, +def test_jax_numpy_logical_or( + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, ): x_dtypes, x = dtypes_values np_helpers.test_frontend_function( @@ -712,7 +712,8 @@ def test_jax_numpy_logical_not( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) From 459ed74538832505011a752a93bceb0528f2d6d0 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 14 Feb 2023 12:28:01 +0000 Subject: [PATCH 12/14] reimplemented logic_or --- ivy/functional/frontends/jax/numpy/logic.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index f7fd3d081e920..935d2382e4290 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -143,12 +143,6 @@ def logical_and(x1, x2, /): return ivy.logical_and(x1, x2) -@to_ivy_arrays_and_back -def logical_or(x1, x2, /): - x1, x2 = promote_jax_arrays(x1, x2) - return ivy.logical_or(x1, x2) - - @to_ivy_arrays_and_back def invert(x, /): return ivy.bitwise_invert(x) @@ -175,6 +169,12 @@ def logical_not(x, /): return ivy.logical_not(x) +@to_ivy_arrays_and_back +def logical_or(x1, x2, /): + x1, x2 = promote_jax_arrays(x1, x2) + return ivy.logical_or(x1, x2) + + @to_ivy_arrays_and_back def isscalar(x, /): return ivy.isscalar(x) From 09aba4a69e3dd7e52d3826372c799668316625fb Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 14 Feb 2023 12:35:02 +0000 Subject: [PATCH 13/14] add new line at the end of the file --- .../test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index d4f788e330248..a68ed2914a977 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -740,4 +740,4 @@ def test_jax_numpy_isscalar( fn_tree=fn_tree, on_device=on_device, x=x[0], - ) \ No newline at end of file + ) From 91afb37ffcc04e28d8847da14c1a3fcf4aee2fa1 Mon Sep 17 00:00:00 2001 From: xc720 Date: Tue, 14 Feb 2023 12:39:41 +0000 Subject: [PATCH 14/14] remain original logic_not --- .../test_jax/test_jax_numpy_logic.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py index a68ed2914a977..dd73fbaae93fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py @@ -690,6 +690,32 @@ def test_jax_numpy_isclose( ) +# logical_not +@handle_frontend_test( + fn_tree="jax.numpy.logical_not", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), + num_arrays=1, + ), +) +def test_jax_numpy_logical_not( + dtypes_values, + on_device, + fn_tree, + frontend, + test_flags, +): + x_dtypes, x = dtypes_values + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + # logical_or @handle_frontend_test( fn_tree="jax.numpy.logical_or",