Skip to content

Commit

Permalink
Merge pull request #13380 from rajshukla1102/unravelissue
Browse files Browse the repository at this point in the history
numpy's frontend unravel_index function
  • Loading branch information
zhumakhan authored Mar 31, 2023
2 parents 55013a2 + fcdeca7 commit 11422d0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ def indices(dimensions, dtype=int, sparse=False):
else:
res[i] = idx
return res


# unravel_index
@to_ivy_arrays_and_back
def unravel_index(indices, shape, order='C'):
ret = [x.astype("int64") for x in ivy.unravel_index(indices, shape)]
return tuple(ret)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Testing Function
# global
import numpy as np
from hypothesis import strategies as st

# local
Expand Down Expand Up @@ -190,3 +191,50 @@ def test_indices(
dtype=dtype[0],
sparse=sparse,
)


# unravel_index
@st.composite
def max_value_as_shape_prod(draw):
shape = draw(
helpers.get_shape(
min_num_dims=1,
max_num_dims=5,
min_dim_size=1,
max_dim_size=5,
)
)
dtype_and_x = draw(
helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("valid"),
min_value=0,
max_value=np.prod(shape) - 1,
)
)
return dtype_and_x, shape


@handle_frontend_test(
fn_tree="numpy.unravel_index",
dtype_x_shape=max_value_as_shape_prod(),
test_with_out=st.just(False),
)
def test_numpy_unravel_index(
*,
dtype_x_shape,
test_flags,
frontend,
fn_tree,
on_device,
):
dtype_and_x, shape = dtype_x_shape
input_dtype, x = dtype_and_x[0], dtype_and_x[1]
helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
indices=x[0],
shape=shape,
)

0 comments on commit 11422d0

Please sign in to comment.