Skip to content

Commit

Permalink
Adjusts tests for in-place element-wise operations to account for `"s…
Browse files Browse the repository at this point in the history
…ame_kind"` casting
  • Loading branch information
ndgrigorian committed Sep 10, 2024
1 parent bab3571 commit 79208c8
Show file tree
Hide file tree
Showing 11 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 += ar2
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_bitwise_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_bitwise_and_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 &= ar2
assert dpt.all(ar1 == 1)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_bitwise_left_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_bitwise_left_shift_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 <<= ar2
assert dpt.all(ar1 == 2)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_bitwise_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_bitwise_or_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 |= ar2
assert dpt.all(ar1 == 1)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_bitwise_xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_bitwise_xor_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 ^= ar2
assert dpt.all(ar1 == 0)

Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/elementwise/test_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if (
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64)
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind")
and dpt.dtype(op1_dtype).kind in "fc"
):
ar1 /= ar2
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_divide_gh_1711():


# don't test for overflowing double as Python won't cast
# an Python integer of that size to a Python float
# a Python integer of that size to a Python float
@pytest.mark.parametrize("fp_dt", [dpt.float16, dpt.float32])
def test_divide_by_scalar_overflow(fp_dt):
q = get_queue_or_skip()
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_floor_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 //= ar2
assert dpt.all(ar1 == 1)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_multiply_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 *= ar2
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 **= ar2
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_remainder.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_remainder_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 %= ar2
assert dpt.all(ar1 == dpt.zeros(ar1.shape, dtype=ar1.dtype))

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/elementwise/test_subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_subtract_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 -= ar2
assert (dpt.asnumpy(ar1) == np.zeros(ar1.shape, dtype=ar1.dtype)).all()

Expand Down

0 comments on commit 79208c8

Please sign in to comment.