diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index 8aaad3544e..5eb0bc1216 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -34,10 +34,9 @@ from dpctl.utils import ExecutionPlacementError from ._type_utils import ( - WeakBooleanType, WeakComplexType, - WeakFloatingType, WeakIntegralType, + _is_weak_dtype, _strong_dtype_num_kind, _weak_type_num_kind, ) @@ -47,29 +46,10 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): "Resolves weak data types per NEP-0050," "where the second and third arguments are" "permitted to be weak types" - if isinstance( - st_dtype, - ( - WeakBooleanType, - WeakIntegralType, - WeakFloatingType, - WeakComplexType, - ), - ): + if _is_weak_dtype(st_dtype): raise ValueError - if isinstance( - dtype1, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): - if isinstance( - dtype2, - ( - WeakBooleanType, - WeakIntegralType, - WeakFloatingType, - WeakComplexType, - ), - ): + if _is_weak_dtype(dtype1): + if _is_weak_dtype(dtype2): kind_num1 = _weak_type_num_kind(dtype1) kind_num2 = _weak_type_num_kind(dtype2) st_kind_num = _strong_dtype_num_kind(st_dtype) @@ -120,10 +100,7 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): return _to_device_supported_dtype(dpt.float64, dev), dtype2 else: return max_dtype, dtype2 - elif isinstance( - dtype2, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): + elif _is_weak_dtype(dtype2): max_dt_num_kind, max_dtype = max( [ (_strong_dtype_num_kind(st_dtype), st_dtype), @@ -152,15 +129,9 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev): "Resolves one weak data type with one strong data type per NEP-0050" - if isinstance( - st_dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): + if _is_weak_dtype(st_dtype): raise ValueError - if isinstance( - dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): + if _is_weak_dtype(dtype): st_kind_num = _strong_dtype_num_kind(st_dtype) kind_num = _weak_type_num_kind(dtype) if kind_num > st_kind_num: diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index fbfe22410d..75f13942c3 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -416,6 +416,7 @@ def __init__( docs, binary_inplace_fn=None, acceptance_fn=None, + weak_type_resolver=None, ): self.__name__ = "BinaryElementwiseFunc" self.name_ = name @@ -428,6 +429,10 @@ def __init__( self.acceptance_fn_ = acceptance_fn else: self.acceptance_fn_ = _acceptance_fn_default_binary + if callable(weak_type_resolver): + self.weak_type_resolver_ = weak_type_resolver + else: + self.weak_type_resolver_ = _resolve_weak_types def __str__(self): return f"<{self.__name__} '{self.name_}'>" @@ -476,6 +481,26 @@ def get_type_promotion_path_acceptance_function(self): """ return self.acceptance_fn_ + def get_array_dtype_scalar_type_resolver_function(self): + """Returns the function which determines how to treat + Python scalar types for this elementwise binary function. + + Resolver influences what type the scalar will be + treated as prior to type promotion behavior. + The function takes 3 arguments: + + Args: + o1_dtype (object, dtype): + A class representing a Python scalar type or a ``dtype`` + o2_dtype (object, dtype): + A class representing a Python scalar type or a ``dtype`` + sycl_dev (:class:`dpctl.SyclDevice`): + Device on which function evaluation is carried out. + + One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance. + """ + return self.weak_type_resolver_ + @property def nin(self): """ @@ -579,7 +604,9 @@ def __call__(self, o1, o2, /, *, out=None, order="K"): if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)): raise ValueError("Operands have unsupported data types") - o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev) + o1_dtype, o2_dtype = self.weak_type_resolver_( + o1_dtype, o2_dtype, sycl_dev + ) buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( o1_dtype, diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 5a47962bbd..fbf2ad2c6b 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -22,6 +22,7 @@ _acceptance_fn_negative, _acceptance_fn_reciprocal, _acceptance_fn_subtract, + _resolve_weak_types_comparisons, ) # U01: ==== ABS (x) @@ -690,7 +691,11 @@ """ equal = BinaryElementwiseFunc( - "equal", ti._equal_result_type, ti._equal, _equal_docstring_ + "equal", + ti._equal_result_type, + ti._equal, + _equal_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _equal_docstring_ @@ -845,7 +850,11 @@ """ greater = BinaryElementwiseFunc( - "greater", ti._greater_result_type, ti._greater, _greater_docstring_ + "greater", + ti._greater_result_type, + ti._greater, + _greater_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _greater_docstring_ @@ -881,6 +890,7 @@ ti._greater_equal_result_type, ti._greater_equal, _greater_equal_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _greater_equal_docstring_ @@ -1027,7 +1037,11 @@ """ less = BinaryElementwiseFunc( - "less", ti._less_result_type, ti._less, _less_docstring_ + "less", + ti._less_result_type, + ti._less, + _less_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _less_docstring_ @@ -1063,6 +1077,7 @@ ti._less_equal_result_type, ti._less_equal, _less_equal_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _less_equal_docstring_ @@ -1499,7 +1514,11 @@ """ not_equal = BinaryElementwiseFunc( - "not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_ + "not_equal", + ti._not_equal_result_type, + ti._not_equal, + _not_equal_docstring_, + weak_type_resolver=_resolve_weak_types_comparisons, ) del _not_equal_docstring_ diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index d8f6b8d28d..691f538336 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -346,21 +346,17 @@ def _strong_dtype_num_kind(o): raise ValueError(f"Unrecognized kind {k} for dtype {o}") +def _is_weak_dtype(dtype): + return isinstance( + dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ) + + def _resolve_weak_types(o1_dtype, o2_dtype, dev): "Resolves weak data type per NEP-0050" - if isinstance( - o1_dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): - if isinstance( - o2_dtype, - ( - WeakBooleanType, - WeakIntegralType, - WeakFloatingType, - WeakComplexType, - ), - ): + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): raise ValueError o1_kind_num = _weak_type_num_kind(o1_dtype) o2_kind_num = _strong_dtype_num_kind(o2_dtype) @@ -377,10 +373,54 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): return _to_device_supported_dtype(dpt.float64, dev), o2_dtype else: return o2_dtype, o2_dtype - elif isinstance( - o2_dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): + elif _is_weak_dtype(o2_dtype): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + +def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): + "Resolves weak data type per NEP-0050 for comparisons," + "where result type is known to be `bool` and special behavior" + "is needed to handle mixed integer kinds" + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): + raise ValueError + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + if isinstance(o1_dtype, WeakIntegralType): + if o2_dtype.kind == "u": + # Python scalar may be negative, assumes mixed int loops + # exist + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + return o2_dtype, o2_dtype + elif _is_weak_dtype(o2_dtype): o1_kind_num = _strong_dtype_num_kind(o1_dtype) o2_kind_num = _weak_type_num_kind(o2_dtype) if o2_kind_num > o1_kind_num: @@ -395,6 +435,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): _to_device_supported_dtype(dpt.float64, dev), ) else: + if isinstance(o2_dtype, WeakIntegralType): + if o1_dtype.kind == "u": + # Python scalar may be negative, assumes mixed int loops + # exist + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) return o1_dtype, o1_dtype else: return o1_dtype, o2_dtype @@ -789,6 +834,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q): "_acceptance_fn_negative", "_acceptance_fn_subtract", "_resolve_weak_types", + "_resolve_weak_types_comparisons", "_weak_type_num_kind", "_strong_dtype_num_kind", "can_cast", diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index 269cfa5d2e..61ac3ca128 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -76,7 +76,25 @@ template struct EqualFunctor #endif } else { - return (in1 == in2); + if constexpr (std::is_integral_v && + std::is_integral_v && + std::is_signed_v != std::is_signed_v) + { + if constexpr (std::is_signed_v && + !std::is_signed_v) { + return (in1 < 0) ? false : (static_cast(in1) == in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? false + : (in1 == static_cast(in2)); + } + } + } + else { + return (in1 == in2); + } } } @@ -151,6 +169,10 @@ template struct EqualOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index 5a2aa4651a..768a5bb7f8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -71,7 +71,25 @@ template struct GreaterFunctor return greater_complex(in1, in2); } else { - return (in1 > in2); + if constexpr (std::is_integral_v && + std::is_integral_v && + std::is_signed_v != std::is_signed_v) + { + if constexpr (std::is_signed_v && + !std::is_signed_v) { + return (in1 < 0) ? false : (static_cast(in1) > in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? true + : (in1 > static_cast(in2)); + } + } + } + else { + return (in1 > in2); + } } } @@ -148,6 +166,10 @@ template struct GreaterOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index eb8bd51584..8569eb0216 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -72,7 +72,25 @@ struct GreaterEqualFunctor return greater_equal_complex(in1, in2); } else { - return (in1 >= in2); + if constexpr (std::is_integral_v && + std::is_integral_v && + std::is_signed_v != std::is_signed_v) + { + if constexpr (std::is_signed_v && + !std::is_signed_v) { + return (in1 < 0) ? false : (static_cast(in1) >= in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? true + : (in1 >= static_cast(in2)); + } + } + } + else { + return (in1 >= in2); + } } } @@ -149,6 +167,10 @@ template struct GreaterEqualOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 7ecb7a064a..294a78ba2f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -70,7 +70,25 @@ template struct LessFunctor return less_complex(in1, in2); } else { - return (in1 < in2); + if constexpr (std::is_integral_v && + std::is_integral_v && + std::is_signed_v != std::is_signed_v) + { + if constexpr (std::is_signed_v && + !std::is_signed_v) { + return (in1 < 0) ? true : (static_cast(in1) < in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? false + : (in1 < static_cast(in2)); + } + } + } + else { + return (in1 < in2); + } } } @@ -79,7 +97,6 @@ template struct LessFunctor operator()(const sycl::vec &in1, const sycl::vec &in2) const { - auto tmp = (in1 < in2); if constexpr (std::is_same_v struct LessOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 5c878d559b..7b18a0b045 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ -70,7 +70,25 @@ template struct LessEqualFunctor return less_equal_complex(in1, in2); } else { - return (in1 <= in2); + if constexpr (std::is_integral_v && + std::is_integral_v && + std::is_signed_v != std::is_signed_v) + { + if constexpr (std::is_signed_v && + !std::is_signed_v) { + return (in1 < 0) ? true : (static_cast(in1) <= in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? false + : (in1 <= static_cast(in2)); + } + } + } + else { + return (in1 <= in2); + } } } @@ -147,6 +165,10 @@ template struct LessEqualOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index 73fcf26677..c31a05b266 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -61,15 +61,18 @@ template struct NotEqualFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - if constexpr (std::is_same_v> && - std::is_same_v) + if constexpr (std::is_integral_v && std::is_integral_v && + std::is_signed_v != std::is_signed_v) { - return (std::real(in1) != in2 || std::imag(in1) != 0.0f); - } - else if constexpr (std::is_same_v && - std::is_same_v>) - { - return (in1 != std::real(in2) || std::imag(in2) != 0.0f); + if constexpr (std::is_signed_v && !std::is_signed_v) { + return (in1 < 0) ? true : (static_cast(in1) != in2); + } + else { + if constexpr (!std::is_signed_v && + std::is_signed_v) { + return (in2 < 0) ? true : (in1 != static_cast(in2)); + } + } } else { return (in1 != in2); @@ -147,6 +150,10 @@ template struct NotEqualOutputType bool>, td_ns:: BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, @@ -160,10 +167,6 @@ template struct NotEqualOutputType T2, std::complex, bool>, - td_ns:: - BinaryTypeMapResultEntry, bool>, - td_ns:: - BinaryTypeMapResultEntry, T2, float, bool>, td_ns::DefaultResultEntry>::result_type; }; diff --git a/dpctl/tests/elementwise/test_greater.py b/dpctl/tests/elementwise/test_greater.py index ec1412f8a2..d9fd852f18 100644 --- a/dpctl/tests/elementwise/test_greater.py +++ b/dpctl/tests/elementwise/test_greater.py @@ -263,3 +263,21 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.greater(a, c) + + +def test_greater_mixed_integer_kinds(): + get_queue_or_skip() + + x1 = dpt.flip(dpt.arange(-9, 1, dtype="i8")) + x2 = dpt.arange(10, dtype="u8") + + # u8 - i8 + res = dpt.greater(x2, x1) + assert dpt.all(res[1:]) + assert not res[0] + # i8 - u8 + assert not dpt.any(dpt.greater(x1, x2)) + + # Python scalar + assert dpt.all(dpt.greater(x2, -1)) + assert not dpt.any(dpt.greater(-1, x2)) diff --git a/dpctl/tests/elementwise/test_greater_equal.py b/dpctl/tests/elementwise/test_greater_equal.py index fa8ba17c9f..0f24aaa9b4 100644 --- a/dpctl/tests/elementwise/test_greater_equal.py +++ b/dpctl/tests/elementwise/test_greater_equal.py @@ -261,3 +261,22 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.greater_equal(a, c) + + +def test_greater_equal_mixed_integer_kinds(): + get_queue_or_skip() + + x1 = dpt.flip(dpt.arange(-9, 1, dtype="i8")) + x2 = dpt.arange(10, dtype="u8") + + # u8 - i8 + res = dpt.greater_equal(x2, x1) + assert dpt.all(res) + # i8 - u8 + res = dpt.greater_equal(x1, x2) + assert not dpt.any(res[1:]) + assert res[0] + + # Python scalar + assert dpt.all(dpt.greater_equal(x2, -1)) + assert not dpt.any(dpt.greater_equal(-1, x2)) diff --git a/dpctl/tests/elementwise/test_less.py b/dpctl/tests/elementwise/test_less.py index bd7cdd1b5b..b1cb497b04 100644 --- a/dpctl/tests/elementwise/test_less.py +++ b/dpctl/tests/elementwise/test_less.py @@ -263,3 +263,21 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.less(a, c) + + +def test_less_mixed_integer_kinds(): + get_queue_or_skip() + + x1 = dpt.flip(dpt.arange(-9, 1, dtype="i8")) + x2 = dpt.arange(10, dtype="u8") + + # u8 - i8 + assert not dpt.any(dpt.less(x2, x1)) + # i8 - u8 + res = dpt.less(x1, x2) + assert not res[0] + assert dpt.all(res[1:]) + + # Python scalar + assert not dpt.any(dpt.less(x2, -1)) + assert dpt.all(dpt.less(-1, x2)) diff --git a/dpctl/tests/elementwise/test_less_equal.py b/dpctl/tests/elementwise/test_less_equal.py index 57e4e14e02..e189d94cdc 100644 --- a/dpctl/tests/elementwise/test_less_equal.py +++ b/dpctl/tests/elementwise/test_less_equal.py @@ -262,3 +262,21 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.less_equal(a, c) + + +def test_less_equal_mixed_integer_kinds(): + get_queue_or_skip() + + x1 = dpt.flip(dpt.arange(-9, 1, dtype="i8")) + x2 = dpt.arange(10, dtype="u8") + + # u8 - i8 + res = dpt.less_equal(x2, x1) + assert res[0] + assert not dpt.any(res[1:]) + # i8 - u8 + assert dpt.all(dpt.less_equal(x1, x2)) + + # Python scalar + assert not dpt.any(dpt.less_equal(x2, -1)) + assert dpt.all(dpt.less_equal(-1, x2)) diff --git a/dpctl/tests/test_usm_ndarray_operators.py b/dpctl/tests/test_usm_ndarray_operators.py index a2571d9f2a..2396dc4109 100644 --- a/dpctl/tests/test_usm_ndarray_operators.py +++ b/dpctl/tests/test_usm_ndarray_operators.py @@ -124,3 +124,19 @@ def test_mat_ops(namespace): M.__matmul__(M) M.__imatmul__(M) M.__rmatmul__(M) + + +@pytest.mark.parametrize("namespace", [dpt, Dummy()]) +def test_comp_ops(namespace): + try: + X = dpt.ones(1, dtype="u8") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + X._set_namespace(namespace) + assert X.__array_namespace__() is namespace + assert X.__gt__(-1) + assert X.__ge__(-1) + assert not X.__lt__(-1) + assert not X.__le__(-1) + assert not X.__eq__(-1) + assert X.__ne__(-1)