Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes element-wise comparisons of mixed signed-unsigned integer inputs #1650

Merged
merged 5 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 7 additions & 36 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@
from dpctl.utils import ExecutionPlacementError

from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_is_weak_dtype,
_strong_dtype_num_kind,
_weak_type_num_kind,
)
Expand All @@ -47,29 +46,10 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
"Resolves weak data types per NEP-0050,"
"where the second and third arguments are"
"permitted to be weak types"
if isinstance(
st_dtype,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
if _is_weak_dtype(st_dtype):
raise ValueError
if isinstance(
dtype1,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if isinstance(
dtype2,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
if _is_weak_dtype(dtype1):
if _is_weak_dtype(dtype2):
kind_num1 = _weak_type_num_kind(dtype1)
kind_num2 = _weak_type_num_kind(dtype2)
st_kind_num = _strong_dtype_num_kind(st_dtype)
Expand Down Expand Up @@ -120,10 +100,7 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
return _to_device_supported_dtype(dpt.float64, dev), dtype2
else:
return max_dtype, dtype2
elif isinstance(
dtype2,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
elif _is_weak_dtype(dtype2):
max_dt_num_kind, max_dtype = max(
[
(_strong_dtype_num_kind(st_dtype), st_dtype),
Expand Down Expand Up @@ -152,15 +129,9 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):

def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
"Resolves one weak data type with one strong data type per NEP-0050"
if isinstance(
st_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if _is_weak_dtype(st_dtype):
raise ValueError
if isinstance(
dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if _is_weak_dtype(dtype):
st_kind_num = _strong_dtype_num_kind(st_dtype)
kind_num = _weak_type_num_kind(dtype)
if kind_num > st_kind_num:
Expand Down
29 changes: 28 additions & 1 deletion dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def __init__(
docs,
binary_inplace_fn=None,
acceptance_fn=None,
weak_type_resolver=None,
):
self.__name__ = "BinaryElementwiseFunc"
self.name_ = name
Expand All @@ -428,6 +429,10 @@ def __init__(
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default_binary
if callable(weak_type_resolver):
self.weak_type_resolver_ = weak_type_resolver
else:
self.weak_type_resolver_ = _resolve_weak_types

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand Down Expand Up @@ -476,6 +481,26 @@ def get_type_promotion_path_acceptance_function(self):
"""
return self.acceptance_fn_

def get_array_dtype_scalar_type_resolver_function(self):
"""Returns the function which determines how to treat
Python scalar types for this elementwise binary function.

Resolver influences what type the scalar will be
treated as prior to type promotion behavior.
The function takes 3 arguments:

Args:
o1_dtype (object, dtype):
A class representing a Python scalar type or a ``dtype``
o2_dtype (object, dtype):
A class representing a Python scalar type or a ``dtype``
sycl_dev (:class:`dpctl.SyclDevice`):
Device on which function evaluation is carried out.

One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance.
"""
return self.weak_type_resolver_

@property
def nin(self):
"""
Expand Down Expand Up @@ -579,7 +604,9 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
raise ValueError("Operands have unsupported data types")

o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)
o1_dtype, o2_dtype = self.weak_type_resolver_(
o1_dtype, o2_dtype, sycl_dev
)

buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
o1_dtype,
Expand Down
27 changes: 23 additions & 4 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_subtract,
_resolve_weak_types_comparisons,
)

# U01: ==== ABS (x)
Expand Down Expand Up @@ -690,7 +691,11 @@
"""

equal = BinaryElementwiseFunc(
"equal", ti._equal_result_type, ti._equal, _equal_docstring_
"equal",
ti._equal_result_type,
ti._equal,
_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _equal_docstring_

Expand Down Expand Up @@ -845,7 +850,11 @@
"""

greater = BinaryElementwiseFunc(
"greater", ti._greater_result_type, ti._greater, _greater_docstring_
"greater",
ti._greater_result_type,
ti._greater,
_greater_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _greater_docstring_

Expand Down Expand Up @@ -881,6 +890,7 @@
ti._greater_equal_result_type,
ti._greater_equal,
_greater_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _greater_equal_docstring_

Expand Down Expand Up @@ -1027,7 +1037,11 @@
"""

less = BinaryElementwiseFunc(
"less", ti._less_result_type, ti._less, _less_docstring_
"less",
ti._less_result_type,
ti._less,
_less_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _less_docstring_

Expand Down Expand Up @@ -1063,6 +1077,7 @@
ti._less_equal_result_type,
ti._less_equal,
_less_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _less_equal_docstring_

Expand Down Expand Up @@ -1499,7 +1514,11 @@
"""

not_equal = BinaryElementwiseFunc(
"not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_
"not_equal",
ti._not_equal_result_type,
ti._not_equal,
_not_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
)
del _not_equal_docstring_

Expand Down
80 changes: 63 additions & 17 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,21 +346,17 @@ def _strong_dtype_num_kind(o):
raise ValueError(f"Unrecognized kind {k} for dtype {o}")


def _is_weak_dtype(dtype):
return isinstance(
dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
)


def _resolve_weak_types(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050"
if isinstance(
o1_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if isinstance(
o2_dtype,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
if _is_weak_dtype(o1_dtype):
if _is_weak_dtype(o2_dtype):
raise ValueError
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
Expand All @@ -377,10 +373,54 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
return o2_dtype, o2_dtype
elif isinstance(
o2_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
elif _is_weak_dtype(o2_dtype):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if isinstance(o2_dtype, WeakComplexType):
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
return o1_dtype, dpt.complex64
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
return (
o1_dtype,
_to_device_supported_dtype(dpt.float64, dev),
)
else:
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype


def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050 for comparisons,"
"where result type is known to be `bool` and special behavior"
"is needed to handle mixed integer kinds"
if _is_weak_dtype(o1_dtype):
if _is_weak_dtype(o2_dtype):
raise ValueError
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakIntegralType):
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if isinstance(o1_dtype, WeakComplexType):
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
return dpt.complex64, o2_dtype
return (
_to_device_supported_dtype(dpt.complex128, dev),
o2_dtype,
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
if isinstance(o1_dtype, WeakIntegralType):
if o2_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
return o2_dtype, o2_dtype
elif _is_weak_dtype(o2_dtype):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
Expand All @@ -395,6 +435,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
_to_device_supported_dtype(dpt.float64, dev),
)
else:
if isinstance(o2_dtype, WeakIntegralType):
if o1_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype
Expand Down Expand Up @@ -789,6 +834,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_acceptance_fn_negative",
"_acceptance_fn_subtract",
"_resolve_weak_types",
"_resolve_weak_types_comparisons",
"_weak_type_num_kind",
"_strong_dtype_num_kind",
"can_cast",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,25 @@ template <typename argT1, typename argT2, typename resT> struct EqualFunctor
#endif
}
else {
return (in1 == in2);
if constexpr (std::is_integral_v<argT1> &&
std::is_integral_v<argT2> &&
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
{
if constexpr (std::is_signed_v<argT1> &&
!std::is_signed_v<argT2>) {
return (in1 < 0) ? false : (static_cast<argT2>(in1) == in2);
}
else {
if constexpr (!std::is_signed_v<argT1> &&
std::is_signed_v<argT2>) {
return (in2 < 0) ? false
: (in1 == static_cast<argT1>(in2));
}
}
}
else {
return (in1 == in2);
}
}
}

Expand Down Expand Up @@ -151,6 +169,10 @@ template <typename T1, typename T2> struct EqualOutputType
bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,25 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
return greater_complex<argT1>(in1, in2);
}
else {
return (in1 > in2);
if constexpr (std::is_integral_v<argT1> &&
std::is_integral_v<argT2> &&
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
{
if constexpr (std::is_signed_v<argT1> &&
!std::is_signed_v<argT2>) {
return (in1 < 0) ? false : (static_cast<argT2>(in1) > in2);
}
else {
if constexpr (!std::is_signed_v<argT1> &&
std::is_signed_v<argT2>) {
return (in2 < 0) ? true
: (in1 > static_cast<argT1>(in2));
}
}
}
else {
return (in1 > in2);
}
}
}

Expand Down Expand Up @@ -148,6 +166,10 @@ template <typename T1, typename T2> struct GreaterOutputType
bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
Expand Down
Loading