Skip to content

Commit

Permalink
Extracted dtype validate method into dtypes.py.
Browse files Browse the repository at this point in the history
Cross applied to nan_functions.py.
  • Loading branch information
alxmrs committed Jan 22, 2025
1 parent 0ffcf3c commit 046a766
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 59 deletions.
33 changes: 33 additions & 0 deletions cubed/array_api/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copied from numpy.array_api
from cubed.array_api.inspection import __array_namespace_info__
from cubed.backend_array_api import namespace as nxp

int8 = nxp.int8
Expand Down Expand Up @@ -86,3 +87,35 @@
"complex floating-point": _complex_floating_dtypes,
"floating-point": _floating_dtypes,
}


# A Cubed-specific utility.
def _validate_and_define_dtype(x, dtype=None, *, allowed_dtypes=("numeric",), fname=None, device=None):
"""Ensure the input dtype is allowed. If it's None, provide a good default dtype."""
dtypes = __array_namespace_info__().default_dtypes(device=device)

# Validate.
is_invalid = all(x.dtype not in _dtype_categories[a] for a in allowed_dtypes)
if is_invalid:
errmsg = f"Only {' or '.join(allowed_dtypes)} dtypes are allowed"
if fname:
errmsg += f" in {fname}"
raise TypeError(errmsg)

# Choose a good default dtype, when None
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _signed_integer_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _unsigned_integer_dtypes:
# Type arithemetic to produce an unsinged integer dtype at the same default precision.
dtype = nxp.dtype(dtypes["integral"].str.replace("i", "u"))
elif x.dtype == _complex_floating_dtypes:
dtype = dtypes["complex floating"]
elif x.dtype == _real_floating_dtypes:
dtype = dtypes["real floating"]
else:
dtype = x.dtype

return dtype
37 changes: 3 additions & 34 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import math

from cubed.array_api import __array_namespace_info__
from cubed.array_api.dtypes import (
_boolean_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
_validate_and_define_dtype,
)
from cubed.array_api.elementwise_functions import sqrt
from cubed.backend_array_api import namespace as nxp
Expand Down Expand Up @@ -114,7 +110,7 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None):


def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _validate_and_define_numeric_or_bool_dtype(x, dtype, fname="prod", device=device)
dtype = _validate_and_define_dtype(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="prod", device=device)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -140,7 +136,7 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _validate_and_define_numeric_or_bool_dtype(x, dtype, fname="sum", device=device)
dtype = _validate_and_define_dtype(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="sum", device=device)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand Down Expand Up @@ -221,30 +217,3 @@ def _var_combine(a, axis=None, correction=None, **kwargs):

def _var_aggregate(a, correction=None, **kwargs):
return nxp.divide(a["M2"], a["n"] - correction)


def _validate_and_define_numeric_or_bool_dtype(x, dtype=None, *, fname=None, device=None):
"""Validate the type of the numeric function. If it's None, provide a good default dtype."""
dtypes = __array_namespace_info__().default_dtypes(device=device)

# Validate.
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
errmsg = "Only numeric or boolean dtypes are allowed"
if fname:
errmsg += f" in {fname}"
raise TypeError(errmsg)

# Choose a good default dtype, when None
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _signed_integer_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _unsigned_integer_dtypes:
# Type arithemetic to produce an unsinged integer dtype at the same default precision.
dtype = nxp.dtype(dtypes["integral"].str.replace("i", "u"))
else:
dtype = x.dtype

return dtype
28 changes: 3 additions & 25 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import numpy as np

from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int64,
uint64,
)
from cubed.array_api.dtypes import _validate_and_define_dtype
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction

Expand Down Expand Up @@ -60,21 +50,9 @@ def _nannumel(x, **kwargs):
return nxp.sum(~(nxp.isnan(x)), **kwargs)


def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
"""Return the sum of array elements over a given axis treating NaNs as zero."""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in nansum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
dtype = _validate_and_define_dtype(x, dtype, allowed_dtypes=("numeric",), fname="nansum", device=device)
return reduction(
x,
nxp.nansum,
Expand Down

0 comments on commit 046a766

Please sign in to comment.