Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ESOVICH committed Jun 29, 2023
1 parent 770111b commit 44b62fd
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ivy/functional/frontends/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def promote_types_of_numpy_inputs(
_mod,
_modf,
_multiply,
_remainder,
_negative,
_positive,
_power,
Expand Down Expand Up @@ -628,7 +629,8 @@ def promote_types_of_numpy_inputs(
)

from ivy.functional.frontends.numpy.mathematical_functions.floating_point_routines import ( # noqa
_nextafter, _spacing,
_nextafter,
_spacing,
)

_frontend_array = array
Expand All @@ -652,6 +654,7 @@ def promote_types_of_numpy_inputs(
mod = ufunc("_mod")
modf = ufunc("_modf")
multiply = ufunc("_multiply")
remainder = ufunc("_remainder")
negative = ufunc("_negative")
positive = ufunc("_positive")
power = ufunc("_power")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,27 @@ def _divmod(
out=out,
)
return ret


@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@handle_numpy_casting
@from_zero_dim_arrays_to_scalar
def _remainder(
x1,
x2,
/,
out=None,
*,
where=True,
casting="same_kind",
order="k",
dtype=None,
subok=True,
):
x1, x2 = promote_types_of_numpy_inputs(x1, x2)
ret = ivy.remainder(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,51 @@ def test_numpy_divmod(
x1=xs[0],
x2=xs[1],
)


# remainder
@handle_frontend_test(
fn_tree="numpy.remainder",
dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype(
arr_func=[
lambda: helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
shared_dtype=True,
)
],
),
where=np_frontend_helpers.where(),
number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc(
fn_name="remainder"
),
)
def test_numpy_remainder(
dtypes_values_casting,
where,
frontend,
test_flags,
fn_tree,
on_device,
):
input_dtypes, xs, casting, dtype = dtypes_values_casting
where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools(
where=where,
input_dtype=input_dtypes,
test_flags=test_flags,
)
np_frontend_helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x1=xs[0],
x2=xs[1],
out=None,
where=where,
casting=casting,
order="K",
dtype=dtype,
subok=True,
)

0 comments on commit 44b62fd

Please sign in to comment.