Skip to content

Commit

Permalink
Adding numpy dtype routine functions to numpy frontend (ivy-llc#10840)
Browse files Browse the repository at this point in the history
  • Loading branch information
fnhirwa authored Feb 22, 2023
1 parent fcb5558 commit b347b10
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# local
import ivy


def iinfo(dtype):
return ivy.iinfo(dtype)


def finfo(dtype):
return ivy.finfo(dtype)
62 changes: 62 additions & 0 deletions ivy/functional/frontends/numpy/data_type_routines/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,65 @@ def promote_types(type1, type2, /):
if isinstance(type2, np_frontend.dtype):
type2 = type2._ivy_dtype
return np_frontend.dtype(np_frontend.promote_numpy_dtypes(type1, type2))


# dtypes as string
all_int_dtypes = ["int8", "int16", "int32", "int64"]
all_uint_dtypes = ["uint8", "uint16", "uint32", "uint64"]
all_float_dtypes = [
"float16",
"float32",
"float64",
]
all_complex_dtypes = ["complex64", "complex128"]


def min_scalar_type(a, /):
if ivy.is_array(a) and a.shape == ():
a = a.item()
if np_frontend.isscalar(a):
validation_dtype = type(a)
if "int" in validation_dtype.__name__:
for dtype in all_uint_dtypes:
if np_frontend.iinfo(dtype).min <= a <= np_frontend.iinfo(dtype).max:
return np_frontend.dtype(dtype)
for dtype in all_int_dtypes:
if np_frontend.iinfo(dtype).min <= a <= np_frontend.iinfo(dtype).max:
return np_frontend.dtype(dtype)

elif "float" in validation_dtype.__name__:
for dtype in all_float_dtypes:
if np_frontend.finfo(dtype).min <= a <= np_frontend.finfo(dtype).max:
return np_frontend.dtype(dtype)
elif "complex" in validation_dtype.__name__:
for dtype in all_complex_dtypes:
if np_frontend.finfo(dtype).min <= a <= np_frontend.finfo(dtype).max:
return np_frontend.dtype(dtype)
else:
return np_frontend.dtype(validation_dtype)
else:
return np_frontend.dtype(a.dtype)


@to_ivy_arrays_and_back
def result_type(*arrays_and_dtypes):
if len(arrays_and_dtypes) == 0:
raise ivy.utils.exceptions.IvyException(
"At least one array or dtype must be provided"
)
if len(arrays_and_dtypes) == 1:
if isinstance(arrays_and_dtypes[0], np_frontend.dtype):
return arrays_and_dtypes[0]
else:
return np_frontend.dtype(arrays_and_dtypes[0].dtype)
else:
res = (
arrays_and_dtypes[0]
if not ivy.is_array(arrays_and_dtypes[0])
else np_frontend.dtype(arrays_and_dtypes[0].dtype)
)
for elem in arrays_and_dtypes:
if ivy.is_array(elem):
elem = np_frontend.dtype(elem.dtype)
res = promote_types(res, elem)
return res
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def isscalar(element):
isinstance(element, int)
or isinstance(element, bool)
or isinstance(element, float)
or isinstance(element, complex)
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# global
from hypothesis import strategies as st, settings
from hypothesis import strategies as st, settings, assume

# local
import ivy
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers.testing_helpers import handle_frontend_test

Expand Down Expand Up @@ -67,3 +68,34 @@ def test_numpy_promote_types(
test_values=False,
)
assert ret._ivy_dtype == frontend_ret[0].name


@handle_frontend_test(
fn_tree="numpy.min_scalar_type",
x=st.one_of(
helpers.ints(min_value=-256, max_value=256),
st.booleans(),
helpers.floats(min_value=-256, max_value=256),
),
)
@settings(max_examples=200)
def test_numpy_min_scalar_type(
*,
x,
on_device,
fn_tree,
frontend,
test_flags,
): # skip torch backend uint
if ivy.current_backend_str() == "torch":
assume(not isinstance(x, int))
ret, frontend_ret = helpers.test_frontend_function(
input_dtypes=[],
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
a=x,
test_values=False,
)
assert ret._ivy_dtype == frontend_ret[0].name
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_numpy_any(

@handle_frontend_test(
fn_tree="numpy.isscalar",
element=st.booleans() | st.floats() | st.integers(),
element=st.booleans() | st.floats() | st.integers() | st.complex_numbers(),
test_with_out=st.just(False),
)
def test_numpy_isscalar(
Expand Down

0 comments on commit b347b10

Please sign in to comment.