Skip to content

Commit

Permalink
Added jax.numpy outer frontend (#11102)
Browse files Browse the repository at this point in the history
Co-authored-by: nathzi1505 <[email protected]>
  • Loading branch information
lucasalavapena and p3jitnath authored Mar 5, 2023
1 parent 0a2d9b7 commit 59a098b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ivy/functional/frontends/jax/numpy/mathematical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,8 @@ def floor_divide(x1, x2, /, out=None):
def inner(a, b):
a, b = promote_types_of_jax_inputs(a, b)
return ivy.inner(a, b)


@to_ivy_arrays_and_back
def outer(a, b, out=None):
return ivy.outer(a, b, out=out)
32 changes: 32 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,3 +2300,35 @@ def test_jax_numpy_inner(
a=xs[0],
b=xs[1],
)


@handle_frontend_test(
fn_tree="jax.numpy.outer",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
min_value=-10,
max_value=10,
min_num_dims=1,
max_num_dims=1,
shared_dtype=True,
),
)
def test_jax_numpy_outer(
*,
dtype_and_x,
test_flags,
on_device,
fn_tree,
frontend,
):
input_dtypes, xs = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtypes,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
a=xs[0],
b=xs[1],
)

0 comments on commit 59a098b

Please sign in to comment.