Skip to content

Commit

Permalink
Implements BinaryElementwiseFunc._inplace_op method
Browse files Browse the repository at this point in the history
This method permits casting behavior equivalent to `"same_kind"` when using in-place operators by introducing the `_inplace_op` method

Expands this to `__imatmul__` as well through use of the already-implemented `dtype` keyword
  • Loading branch information
ndgrigorian committed Sep 10, 2024
1 parent cd7d41c commit bab3571
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 15 deletions.
130 changes: 128 additions & 2 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)
Expand Down Expand Up @@ -213,7 +214,7 @@ def __call__(self, x, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f"Output array of type {res_dt} is needed, "
f" got {out.dtype}"
)

Expand Down Expand Up @@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f"Output array of type {res_dt} is needed, "
f"got {out.dtype}"
)

Expand Down Expand Up @@ -927,3 +928,128 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
)
_manager.add_event_pair(ht_, bf_ev)
return out

def _inplace_op(self, o1, o2):
if not isinstance(o1, dpt.usm_ndarray):
raise TypeError(
"Expected first argument to be "
f"dpctl.tensor.usm_ndarray, got {type(o1)}"
)
if not o1.flags.writable:
raise ValueError("provided left-hand side array is read-only")
q1, o1_usm_type = o1.sycl_queue, o1.usm_type
q2, o2_usm_type = _get_queue_usm_type(o2)
if q2 is None:
exec_q = q1
res_usm_type = o1_usm_type
else:
exec_q = dpctl.utils.get_execution_queue((q1, q2))
if exec_q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments."
)
res_usm_type = dpctl.utils.get_coerced_usm_type(
(
o1_usm_type,
o2_usm_type,
)
)
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
o1_shape = o1.shape
o2_shape = _get_shape(o2)
if not isinstance(o2_shape, (tuple, list)):
raise TypeError(
"Shape of second argument can not be inferred. "
"Expected list or tuple."
)
try:
res_shape = _broadcast_shape_impl(
[
o1_shape,
o2_shape,
]
)
except ValueError:
raise ValueError(
"operands could not be broadcast together with shapes "
f"{o1_shape} and {o2_shape}"
)
if res_shape != o1_shape:
raise ValueError("")
sycl_dev = exec_q.sycl_device
o1_dtype = o1.dtype
o2_dtype = _get_dtype(o2, sycl_dev)
if not _validate_dtype(o2_dtype):
raise ValueError("Operand has an unsupported data type")

o1_dtype, o2_dtype = self.weak_type_resolver_(
o1_dtype, o2_dtype, sycl_dev
)

buf_dt, res_dt = _find_buf_dtype_in_place_op(
o1_dtype,
o2_dtype,
self.result_type_resolver_fn_,
sycl_dev,
)

if res_dt is None:
raise ValueError(
f"function '{self.name_}' does not support input types "
f"({o1_dtype}, {o2_dtype}), "
"and the inputs could not be safely coerced to any "
"supported types according to the casting rule ''same_kind''."
)

if res_dt != o1_dtype:
raise ValueError(
f"Output array of type {res_dt} is needed, " f"got {o1_dtype}"
)

_manager = SequentialOrderManager[exec_q]
if isinstance(o2, dpt.usm_ndarray):
src2 = o2
if (
ti._array_overlap(o2, o1)
and not ti._same_logical_tensors(o2, o1)
and buf_dt is None
):
buf_dt = o2_dtype
else:
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
if buf_dt is None:
if src2.shape != res_shape:
src2 = dpt.broadcast_to(src2, res_shape)
dep_evs = _manager.submitted_events
ht_, comp_ev = self.binary_inplace_fn_(
lhs=o1,
rhs=src2,
sycl_queue=exec_q,
depends=dep_evs,
)
_manager.add_event_pair(ht_, comp_ev)
else:
buf = dpt.empty_like(src2, dtype=buf_dt)
dep_evs = _manager.submitted_events
(
ht_copy_ev,
copy_ev,
) = ti._copy_usm_ndarray_into_usm_ndarray(
src=src2,
dst=buf,
sycl_queue=exec_q,
depends=dep_evs,
)
_manager.add_event_pair(ht_copy_ev, copy_ev)

buf = dpt.broadcast_to(buf, res_shape)
ht_, bf_ev = self.binary_inplace_fn_(
lhs=o1,
rhs=buf,
sycl_queue=exec_q,
depends=[copy_ev],
)
_manager.add_event_pair(ht_, bf_ev)

return o1
16 changes: 16 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,21 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
return None, None, None


def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
res_dt = query_fn(arg1_dtype, arg2_dtype)
if res_dt:
return None, res_dt

_fp16 = sycl_dev.has_aspect_fp16
_fp64 = sycl_dev.has_aspect_fp64
if _can_cast(arg2_dtype, arg1_dtype, _fp16, _fp64, casting="same_kind"):
res_dt = query_fn(arg1_dtype, arg1_dtype)
if res_dt:
return arg1_dtype, res_dt

return None, None


class WeakBooleanType:
"Python type representing type of Python boolean objects"

Expand Down Expand Up @@ -959,4 +974,5 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"WeakComplexType",
"_default_accumulation_dtype",
"_default_accumulation_dtype_fp_types",
"_find_buf_dtype_in_place_op",
]
26 changes: 13 additions & 13 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1508,43 +1508,43 @@ cdef class usm_ndarray:
return dpctl.tensor.bitwise_xor(other, self)

def __iadd__(self, other):
return dpctl.tensor.add(self, other, out=self)
return dpctl.tensor.add._inplace_op(self, other)

def __iand__(self, other):
return dpctl.tensor.bitwise_and(self, other, out=self)
return dpctl.tensor.bitwise_and._inplace_op(self, other)

def __ifloordiv__(self, other):
return dpctl.tensor.floor_divide(self, other, out=self)
return dpctl.tensor.floor_divide._inplace_op(self, other)

def __ilshift__(self, other):
return dpctl.tensor.bitwise_left_shift(self, other, out=self)
return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)

def __imatmul__(self, other):
return dpctl.tensor.matmul(self, other, out=self)
return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)

def __imod__(self, other):
return dpctl.tensor.remainder(self, other, out=self)
return dpctl.tensor.remainder._inplace_op(self, other)

def __imul__(self, other):
return dpctl.tensor.multiply(self, other, out=self)
return dpctl.tensor.multiply._inplace_op(self, other)

def __ior__(self, other):
return dpctl.tensor.bitwise_or(self, other, out=self)
return dpctl.tensor.bitwise_or._inplace_op(self, other)

def __ipow__(self, other):
return dpctl.tensor.pow(self, other, out=self)
return dpctl.tensor.pow._inplace_op(self, other)

def __irshift__(self, other):
return dpctl.tensor.bitwise_right_shift(self, other, out=self)
return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)

def __isub__(self, other):
return dpctl.tensor.subtract(self, other, out=self)
return dpctl.tensor.subtract._inplace_op(self, other)

def __itruediv__(self, other):
return dpctl.tensor.divide(self, other, out=self)
return dpctl.tensor.divide._inplace_op(self, other)

def __ixor__(self, other):
return dpctl.tensor.bitwise_xor(self, other, out=self)
return dpctl.tensor.bitwise_xor._inplace_op(self, other)

def __str__(self):
return usm_ndarray_str(self)
Expand Down

0 comments on commit bab3571

Please sign in to comment.