Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit "same_kind" casting for usm_ndarray element-wise in-place operators #1827

Merged
merged 10 commits into from
Sep 11, 2024
Merged
146 changes: 143 additions & 3 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)
Expand Down Expand Up @@ -213,8 +214,8 @@ def __call__(self, x, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f" got {out.dtype}"
f"Output array of type {res_dt} is needed, "
ndgrigorian marked this conversation as resolved.
Show resolved Hide resolved
f"got {out.dtype}"
)

if (
Expand Down Expand Up @@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f"Output array of type {res_dt} is needed, "
f"got {out.dtype}"
)

Expand Down Expand Up @@ -927,3 +928,142 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
)
_manager.add_event_pair(ht_, bf_ev)
return out

def _inplace_op(self, o1, o2):
    """Evaluate this binary function in place, storing the result in ``o1``.

    Args:
        o1: Left-hand operand and destination. Must be a writable
            ``dpctl.tensor.usm_ndarray``.
        o2: Right-hand operand. Either a ``usm_ndarray`` or an object
            that ``dpt.asarray`` can convert once its dtype is resolved.

    Returns:
        ``o1``, after the in-place kernel has been submitted.

    Raises:
        ValueError: If the function has no dedicated in-place
            implementation, ``o1`` is read-only, the operands cannot be
            broadcast, the broadcast shape differs from ``o1.shape``,
            ``o2`` has an unsupported data type, or no supported type
            combination exists under "same_kind" casting.
        TypeError: If ``o1`` is not a ``usm_ndarray`` or the shape of
            ``o2`` cannot be inferred as a list or tuple.
        ExecutionPlacementError: If no common execution queue can be
            determined for the operands.
    """
    if self.binary_inplace_fn_ is None:
        raise ValueError(
            "binary function does not have a dedicated in-place "
            "implementation"
        )
    if not isinstance(o1, dpt.usm_ndarray):
        raise TypeError(
            "Expected first argument to be "
            f"dpctl.tensor.usm_ndarray, got {type(o1)}"
        )
    if not o1.flags.writable:
        raise ValueError("provided left-hand side array is read-only")

    # Resolve the execution queue and the USM type of the result.
    q1, o1_usm_type = o1.sycl_queue, o1.usm_type
    q2, o2_usm_type = _get_queue_usm_type(o2)
    if q2 is None:
        # o2 carries no queue, so o1 alone decides placement.
        exec_q = q1
        res_usm_type = o1_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                o1_usm_type,
                o2_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

    # The left-hand side is the destination, so broadcasting may only
    # expand o2: the broadcast shape has to equal o1's own shape.
    o1_shape = o1.shape
    o2_shape = _get_shape(o2)
    if not isinstance(o2_shape, (tuple, list)):
        raise TypeError(
            "Shape of second argument can not be inferred. "
            "Expected list or tuple."
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                o1_shape,
                o2_shape,
            ]
        )
    except ValueError:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{o1_shape} and {o2_shape}"
        )

    if res_shape != o1_shape:
        raise ValueError(
            "The shape of the non-broadcastable left-hand "
            f"side {o1_shape} is inconsistent with the "
            f"broadcast shape {res_shape}."
        )

    # Resolve operand dtypes, then look up the buffer/result dtypes.
    sycl_dev = exec_q.sycl_device
    o1_dtype = o1.dtype
    o2_dtype = _get_dtype(o2, sycl_dev)
    if not _validate_dtype(o2_dtype):
        raise ValueError("Operand has an unsupported data type")

    o1_dtype, o2_dtype = self.weak_type_resolver_(
        o1_dtype, o2_dtype, sycl_dev
    )

    buf_dt, res_dt = _find_buf_dtype_in_place_op(
        o1_dtype,
        o2_dtype,
        self.result_type_resolver_fn_,
        sycl_dev,
    )

    if res_dt is None:
        raise ValueError(
            f"function '{self.name_}' does not support input types "
            f"({o1_dtype}, {o2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule "
            "''same_kind''."
        )

    if res_dt != o1_dtype:
        raise ValueError(
            f"Output array of type {res_dt} is needed, "
            f"got {o1_dtype}"
        )

    _manager = SequentialOrderManager[exec_q]
    if isinstance(o2, dpt.usm_ndarray):
        src2 = o2
        if (
            ti._array_overlap(o2, o1)
            and not ti._same_logical_tensors(o2, o1)
            and buf_dt is None
        ):
            # o2 overlaps o1 without being the same logical tensor:
            # route it through a temporary copy (the buffered branch).
            buf_dt = o2_dtype
    else:
        src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

    if buf_dt is None:
        # Direct path: no temporary is needed for the right-hand side.
        if src2.shape != res_shape:
            src2 = dpt.broadcast_to(src2, res_shape)
        dep_evs = _manager.submitted_events
        ht_, comp_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=src2,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_, comp_ev)
    else:
        # Buffered path: copy o2 into a temporary of dtype buf_dt first.
        buf = dpt.empty_like(src2, dtype=buf_dt)
        dep_evs = _manager.submitted_events
        (
            ht_copy_ev,
            copy_ev,
        ) = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2,
            dst=buf,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)

        buf = dpt.broadcast_to(buf, res_shape)
        # The in-place kernel is made to depend on the copy event.
        ht_, bf_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_, bf_ev)

    return o1
16 changes: 16 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,21 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
return None, None, None


def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
    """Choose the buffer and result dtypes for an in-place binary operation.

    Args:
        arg1_dtype: Data type of the left-hand (destination) operand.
        arg2_dtype: Data type of the right-hand operand.
        query_fn: Callable taking two dtypes and returning the result
            dtype, or a falsy value when the pair is not supported.
        sycl_dev: Device whose ``has_aspect_fp16`` and ``has_aspect_fp64``
            attributes are passed on to ``_can_cast``.

    Returns:
        A ``(buf_dtype, res_dtype)`` pair:

        * ``(None, res)`` when ``query_fn`` accepts the dtypes as given;
        * ``(arg1_dtype, res)`` when the right-hand dtype can be cast to
          ``arg1_dtype`` under "same_kind" casting and ``query_fn``
          accepts ``(arg1_dtype, arg1_dtype)``;
        * ``(None, None)`` otherwise.
    """
    direct_dt = query_fn(arg1_dtype, arg2_dtype)
    if direct_dt:
        return None, direct_dt

    castable = _can_cast(
        arg2_dtype,
        arg1_dtype,
        sycl_dev.has_aspect_fp16,
        sycl_dev.has_aspect_fp64,
        casting="same_kind",
    )
    if not castable:
        return None, None

    # Retry with the right-hand operand cast to the left-hand dtype.
    casted_dt = query_fn(arg1_dtype, arg1_dtype)
    if casted_dt:
        return arg1_dtype, casted_dt
    return None, None


class WeakBooleanType:
"Python type representing type of Python boolean objects"

Expand Down Expand Up @@ -959,4 +974,5 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"WeakComplexType",
"_default_accumulation_dtype",
"_default_accumulation_dtype_fp_types",
"_find_buf_dtype_in_place_op",
]
26 changes: 13 additions & 13 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1508,43 +1508,43 @@ cdef class usm_ndarray:
return dpctl.tensor.bitwise_xor(other, self)

def __iadd__(self, other):
    # ``self += other``: use the dedicated in-place entry point of ``add``.
    # The earlier ``add(self, other, out=self)`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.add._inplace_op(self, other)

def __iand__(self, other):
    # ``self &= other``: use the dedicated in-place entry point of
    # ``bitwise_and``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.bitwise_and._inplace_op(self, other)

def __ifloordiv__(self, other):
    # ``self //= other``: use the dedicated in-place entry point of
    # ``floor_divide``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.floor_divide._inplace_op(self, other)

def __ilshift__(self, other):
    # ``self <<= other``: use the dedicated in-place entry point of
    # ``bitwise_left_shift``. The earlier ``out=self`` return preceded this
    # call and made it unreachable, so it is dropped.
    return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)

def __imatmul__(self, other):
    # ``self @= other``: write the product into ``self`` and pass
    # ``dtype=self.dtype`` explicitly. The earlier return without the
    # ``dtype`` keyword preceded this call and made it unreachable, so it
    # is dropped.
    return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)

def __imod__(self, other):
    # ``self %= other``: use the dedicated in-place entry point of
    # ``remainder``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.remainder._inplace_op(self, other)

def __imul__(self, other):
    # ``self *= other``: use the dedicated in-place entry point of
    # ``multiply``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.multiply._inplace_op(self, other)

def __ior__(self, other):
    # ``self |= other``: use the dedicated in-place entry point of
    # ``bitwise_or``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.bitwise_or._inplace_op(self, other)

def __ipow__(self, other):
    # ``self **= other``: use the dedicated in-place entry point of
    # ``pow``. The earlier ``out=self`` return preceded this call and
    # made it unreachable, so it is dropped.
    return dpctl.tensor.pow._inplace_op(self, other)

def __irshift__(self, other):
    # ``self >>= other``: use the dedicated in-place entry point of
    # ``bitwise_right_shift``. The earlier ``out=self`` return preceded
    # this call and made it unreachable, so it is dropped.
    return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)

def __isub__(self, other):
    # ``self -= other``: use the dedicated in-place entry point of
    # ``subtract``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.subtract._inplace_op(self, other)

def __itruediv__(self, other):
    # ``self /= other``: use the dedicated in-place entry point of
    # ``divide``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.divide._inplace_op(self, other)

def __ixor__(self, other):
    # ``self ^= other``: use the dedicated in-place entry point of
    # ``bitwise_xor``. The earlier ``out=self`` return preceded this call
    # and made it unreachable, so it is dropped.
    return dpctl.tensor.bitwise_xor._inplace_op(self, other)

def __str__(self):
return usm_ndarray_str(self)
Expand Down
70 changes: 62 additions & 8 deletions dpctl/tests/elementwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 += ar2
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
Expand All @@ -373,9 +373,25 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
else:
with pytest.raises(ValueError):
ar1 += ar2

ar1 = dpt.ones(sz, dtype=op1_dtype)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
dpt.add(ar1, ar2, out=ar1)
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
).all()

ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
dpt.add(ar3, ar4, out=ar3)
assert (
dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
).all()
else:
with pytest.raises(ValueError):
dpt.add(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
Expand All @@ -401,7 +417,7 @@ def test_add_inplace_broadcasting():
m = dpt.ones((100, 5), dtype="i4")
v = dpt.arange(5, dtype="i4")

m += v
dpt.add(m, v, out=m)
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()

# check case where second arg is out
Expand All @@ -411,6 +427,26 @@ def test_add_inplace_broadcasting():
).all()


def test_add_inplace_operator_broadcasting():
    """``+=`` broadcasts a 1-D right operand across the rows of the left."""
    get_queue_or_skip()

    row = dpt.arange(5, dtype="i4")
    mat = dpt.ones((100, 5), dtype="i4")

    mat += row

    expected = np.arange(1, 6, dtype="i4")[np.newaxis, :]
    assert (dpt.asnumpy(mat) == expected).all()


def test_add_inplace_operator_mutual_broadcast():
    """In-place add rejects operands whose broadcast shape differs from lhs."""
    get_queue_or_skip()

    # (1, 10) and (10, 1) broadcast to (10, 10), which is not the lhs shape.
    lhs = dpt.ones((1, 10), dtype="i4")
    rhs = dpt.ones((10, 1), dtype="i4")

    with pytest.raises(ValueError):
        dpt.add._inplace_op(lhs, rhs)


def test_add_inplace_errors():
get_queue_or_skip()
try:
Expand All @@ -425,27 +461,45 @@ def test_add_inplace_errors():
ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue)
ar2 = dpt.ones_like(ar1, sycl_queue=cpu_queue)
with pytest.raises(ExecutionPlacementError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones(2, dtype="float32")
ar2 = dpt.ones(3, dtype="float32")
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = np.ones(2, dtype="float32")
ar2 = dpt.ones(2, dtype="float32")
with pytest.raises(TypeError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones(2, dtype="float32")
ar2 = dict()
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones((2, 1), dtype="float32")
ar2 = dpt.ones((1, 2), dtype="float32")
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)


def test_add_inplace_operator_errors():
    """Check the error paths of ``dpt.add._inplace_op``."""
    first_q = get_queue_or_skip()
    second_q = get_queue_or_skip()

    arr = dpt.ones(10, dtype="i4", sycl_queue=first_q)

    # The left-hand operand must be a usm_ndarray.
    with pytest.raises(TypeError):
        dpt.add._inplace_op({}, arr)

    # A read-only left-hand operand is rejected.
    arr.flags["W"] = False
    with pytest.raises(ValueError):
        dpt.add._inplace_op(arr, 2)

    # Operands allocated on two separately obtained queues.
    on_first = dpt.ones(10, dtype="i4", sycl_queue=first_q)
    on_second = dpt.ones(10, dtype="i4", sycl_queue=second_q)
    with pytest.raises(ExecutionPlacementError):
        dpt.add._inplace_op(on_first, on_second)


def test_add_inplace_same_tensors():
Expand Down
Loading
Loading