Skip to content

Commit

Permalink
Merge pull request #1830 from IntelPython/backport-gh-1827
Browse files Browse the repository at this point in the history
Backport gh-1827
  • Loading branch information
oleksandr-pavlyk authored Sep 12, 2024
2 parents 0a0e9ae + 406af46 commit 19331c0
Show file tree
Hide file tree
Showing 16 changed files with 259 additions and 83 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The full list of changes that went into this release is:
* Update version of 'pybind11' used [gh-1758](https://github.com/IntelPython/dpctl/pull/1758), [gh-1812](https://github.com/IntelPython/dpctl/pull/1812)
* Handle possible exceptions by `usm_host_allocator` used with `std::vector` [gh-1791](https://github.com/IntelPython/dpctl/pull/1791)
* Use `dpctl::tensor::offset_utils::sycl_free_noexcept` instead of `sycl::free` in `host_task` tasks associated with life-time management of temporary USM allocations [gh-1797](https://github.com/IntelPython/dpctl/pull/1797)
* Add `"same_kind"`-style casting for in-place mathematical operators of `tensor.usm_ndarray` [gh-1827](https://github.com/IntelPython/dpctl/pull/1827), [gh-1830](https://github.com/IntelPython/dpctl/pull/1830)

### Fixed

Expand Down
144 changes: 141 additions & 3 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)
Expand Down Expand Up @@ -213,8 +214,8 @@ def __call__(self, x, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f" got {out.dtype}"
f"Output array of type {res_dt} is needed, "
f"got {out.dtype}"
)

if (
Expand Down Expand Up @@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):

if res_dt != out.dtype:
raise ValueError(
f"Output array of type {res_dt} is needed,"
f"Output array of type {res_dt} is needed, "
f"got {out.dtype}"
)

Expand Down Expand Up @@ -927,3 +928,140 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
)
_manager.add_event_pair(ht_, bf_ev)
return out

def _inplace_op(self, o1, o2):
    """Apply this binary elementwise function in place: ``o1 <op>= o2``.

    Unlike ``__call__`` with ``out=o1``, the right-hand operand may be
    coerced to the left-hand operand's data type under the
    ``"same_kind"`` casting rule (see the error message below), mirroring
    the semantics used for Python in-place operators.

    Args:
        o1: left-hand side; must be a writable
            ``dpctl.tensor.usm_ndarray``. Modified in place.
        o2: right-hand side; a ``usm_ndarray`` or an object convertible
            via ``dpt.asarray`` (e.g. a Python scalar).

    Returns:
        ``o1``, after scheduling the in-place computation.

    Raises:
        ValueError: if no dedicated in-place kernel exists, ``o1`` is
            read-only, the operands cannot be broadcast to ``o1.shape``,
            or the data types cannot be coerced under ``"same_kind"``.
        TypeError: if ``o1`` is not a ``usm_ndarray`` or the shape of
            ``o2`` cannot be inferred.
        ExecutionPlacementError: if the operands' queues do not yield an
            unambiguous execution queue.
    """
    if self.binary_inplace_fn_ is None:
        raise ValueError(
            "binary function does not have a dedicated in-place "
            "implementation"
        )
    if not isinstance(o1, dpt.usm_ndarray):
        raise TypeError(
            "Expected first argument to be "
            f"dpctl.tensor.usm_ndarray, got {type(o1)}"
        )
    # In-place writes require a writable left-hand side
    if not o1.flags.writable:
        raise ValueError("provided left-hand side array is read-only")
    q1, o1_usm_type = o1.sycl_queue, o1.usm_type
    q2, o2_usm_type = _get_queue_usm_type(o2)
    if q2 is None:
        # o2 carries no queue of its own (e.g. a Python scalar):
        # execute on o1's queue with o1's USM type
        exec_q = q1
        res_usm_type = o1_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                o1_usm_type,
                o2_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
    o1_shape = o1.shape
    o2_shape = _get_shape(o2)
    if not isinstance(o2_shape, (tuple, list)):
        raise TypeError(
            "Shape of second argument can not be inferred. "
            "Expected list or tuple."
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                o1_shape,
                o2_shape,
            ]
        )
    except ValueError:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{o1_shape} and {o2_shape}"
        )

    # The result must fit exactly in o1: broadcasting may not enlarge
    # the left-hand side of an in-place operation
    if res_shape != o1_shape:
        raise ValueError(
            "The shape of the non-broadcastable left-hand "
            f"side {o1_shape} is inconsistent with the "
            f"broadcast shape {res_shape}."
        )

    sycl_dev = exec_q.sycl_device
    o1_dtype = o1.dtype
    o2_dtype = _get_dtype(o2, sycl_dev)
    if not _validate_dtype(o2_dtype):
        raise ValueError("Operand has an unsupported data type")

    # Resolve weak (Python scalar) types against the array operand
    o1_dtype, o2_dtype = self.weak_type_resolver_(
        o1_dtype, o2_dtype, sycl_dev
    )

    # buf_dt: dtype to copy o2 into before running the kernel
    # (None when the kernel accepts o2_dtype directly);
    # res_dt: dtype the kernel produces
    buf_dt, res_dt = _find_buf_dtype_in_place_op(
        o1_dtype,
        o2_dtype,
        self.result_type_resolver_fn_,
        sycl_dev,
    )

    if res_dt is None:
        raise ValueError(
            f"function '{self.name_}' does not support input types "
            f"({o1_dtype}, {o2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule "
            "''same_kind''."
        )

    # In-place: the computed result dtype must match o1's dtype
    if res_dt != o1_dtype:
        raise ValueError(
            f"Output array of type {res_dt} is needed, " f"got {o1_dtype}"
        )

    _manager = SequentialOrderManager[exec_q]
    if isinstance(o2, dpt.usm_ndarray):
        src2 = o2
        if (
            ti._array_overlap(o2, o1)
            and not ti._same_logical_tensors(o2, o1)
            and buf_dt is None
        ):
            # o2 partially overlaps o1 (but is not the same logical
            # tensor): force a copy of o2 so the kernel does not read
            # memory it is concurrently overwriting
            buf_dt = o2_dtype
    else:
        src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
    if buf_dt is None:
        # Run the in-place kernel directly against src2
        if src2.shape != res_shape:
            src2 = dpt.broadcast_to(src2, res_shape)
        dep_evs = _manager.submitted_events
        # ht_ presumably is the host-task event paired with the
        # computation event comp_ev — both are handed to the order
        # manager for sequencing
        ht_, comp_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=src2,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_, comp_ev)
    else:
        # Copy src2 into a temporary of dtype buf_dt first, then run
        # the in-place kernel on the copy
        buf = dpt.empty_like(src2, dtype=buf_dt)
        dep_evs = _manager.submitted_events
        (
            ht_copy_ev,
            copy_ev,
        ) = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2,
            dst=buf,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)

        buf = dpt.broadcast_to(buf, res_shape)
        # The kernel must wait on the copy completing
        ht_, bf_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_, bf_ev)

    return o1
16 changes: 16 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,21 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
return None, None, None


def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
res_dt = query_fn(arg1_dtype, arg2_dtype)
if res_dt:
return None, res_dt

_fp16 = sycl_dev.has_aspect_fp16
_fp64 = sycl_dev.has_aspect_fp64
if _can_cast(arg2_dtype, arg1_dtype, _fp16, _fp64, casting="same_kind"):
res_dt = query_fn(arg1_dtype, arg1_dtype)
if res_dt:
return arg1_dtype, res_dt

return None, None


class WeakBooleanType:
"Python type representing type of Python boolean objects"

Expand Down Expand Up @@ -959,4 +974,5 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"WeakComplexType",
"_default_accumulation_dtype",
"_default_accumulation_dtype_fp_types",
"_find_buf_dtype_in_place_op",
]
26 changes: 13 additions & 13 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1508,43 +1508,43 @@ cdef class usm_ndarray:
return dpctl.tensor.bitwise_xor(other, self)

def __iadd__(self, other):
return dpctl.tensor.add(self, other, out=self)
return dpctl.tensor.add._inplace_op(self, other)

def __iand__(self, other):
return dpctl.tensor.bitwise_and(self, other, out=self)
return dpctl.tensor.bitwise_and._inplace_op(self, other)

def __ifloordiv__(self, other):
return dpctl.tensor.floor_divide(self, other, out=self)
return dpctl.tensor.floor_divide._inplace_op(self, other)

def __ilshift__(self, other):
return dpctl.tensor.bitwise_left_shift(self, other, out=self)
return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)

def __imatmul__(self, other):
return dpctl.tensor.matmul(self, other, out=self)
return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)

def __imod__(self, other):
return dpctl.tensor.remainder(self, other, out=self)
return dpctl.tensor.remainder._inplace_op(self, other)

def __imul__(self, other):
return dpctl.tensor.multiply(self, other, out=self)
return dpctl.tensor.multiply._inplace_op(self, other)

def __ior__(self, other):
return dpctl.tensor.bitwise_or(self, other, out=self)
return dpctl.tensor.bitwise_or._inplace_op(self, other)

def __ipow__(self, other):
return dpctl.tensor.pow(self, other, out=self)
return dpctl.tensor.pow._inplace_op(self, other)

def __irshift__(self, other):
return dpctl.tensor.bitwise_right_shift(self, other, out=self)
return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)

def __isub__(self, other):
return dpctl.tensor.subtract(self, other, out=self)
return dpctl.tensor.subtract._inplace_op(self, other)

def __itruediv__(self, other):
return dpctl.tensor.divide(self, other, out=self)
return dpctl.tensor.divide._inplace_op(self, other)

def __ixor__(self, other):
return dpctl.tensor.bitwise_xor(self, other, out=self)
return dpctl.tensor.bitwise_xor._inplace_op(self, other)

def __str__(self):
return usm_ndarray_str(self)
Expand Down
75 changes: 67 additions & 8 deletions dpctl/tests/elementwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
# operators use a different Python implementation which permits
# same kind style casting
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
ar1 += ar2
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
Expand All @@ -373,9 +375,28 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
else:
with pytest.raises(ValueError):
ar1 += ar2

# here, test the special case where out is the first argument
# so an in-place kernel is used for efficiency
# this covers a specific branch in the BinaryElementwiseFunc logic
ar1 = dpt.ones(sz, dtype=op1_dtype)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
dpt.add(ar1, ar2, out=ar1)
assert (
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
).all()

ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
dpt.add(ar3, ar4, out=ar3)
assert (
dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
).all()
else:
with pytest.raises(ValueError):
dpt.add(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
Expand All @@ -401,7 +422,7 @@ def test_add_inplace_broadcasting():
m = dpt.ones((100, 5), dtype="i4")
v = dpt.arange(5, dtype="i4")

m += v
dpt.add(m, v, out=m)
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()

# check case where second arg is out
Expand All @@ -411,6 +432,26 @@ def test_add_inplace_broadcasting():
).all()


def test_add_inplace_operator_broadcasting():
    """In-place `+=` broadcasts a 1-D row across a 2-D left-hand side."""
    get_queue_or_skip()

    lhs = dpt.ones((100, 5), dtype="i4")
    row = dpt.arange(5, dtype="i4")

    lhs += row

    expected = np.arange(1, 6, dtype="i4")[np.newaxis, :]
    assert (dpt.asnumpy(lhs) == expected).all()


def test_add_inplace_operator_mutual_broadcast():
    """Mutual broadcasting that would enlarge the LHS must raise."""
    get_queue_or_skip()

    row = dpt.ones((1, 10), dtype="i4")
    col = dpt.ones((10, 1), dtype="i4")

    # (1, 10) op (10, 1) broadcasts to (10, 10), which cannot be stored
    # in-place into the (1, 10) left-hand side
    with pytest.raises(ValueError):
        dpt.add._inplace_op(row, col)


def test_add_inplace_errors():
get_queue_or_skip()
try:
Expand All @@ -425,27 +466,45 @@ def test_add_inplace_errors():
ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue)
ar2 = dpt.ones_like(ar1, sycl_queue=cpu_queue)
with pytest.raises(ExecutionPlacementError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones(2, dtype="float32")
ar2 = dpt.ones(3, dtype="float32")
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = np.ones(2, dtype="float32")
ar2 = dpt.ones(2, dtype="float32")
with pytest.raises(TypeError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones(2, dtype="float32")
ar2 = dict()
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)

ar1 = dpt.ones((2, 1), dtype="float32")
ar2 = dpt.ones((1, 2), dtype="float32")
with pytest.raises(ValueError):
ar1 += ar2
dpt.add(ar1, ar2, out=ar1)


def test_add_inplace_operator_errors():
    """Validation errors raised by the in-place operator helper."""
    q1 = get_queue_or_skip()
    q2 = get_queue_or_skip()

    arr = dpt.ones(10, dtype="i4", sycl_queue=q1)

    # left-hand side must be a usm_ndarray
    with pytest.raises(TypeError):
        dpt.add._inplace_op(dict(), arr)

    # left-hand side must be writable
    arr.flags["W"] = False
    with pytest.raises(ValueError):
        dpt.add._inplace_op(arr, 2)

    # operands on distinct queues: no unambiguous execution placement
    lhs_q1 = dpt.ones(10, dtype="i4", sycl_queue=q1)
    rhs_q2 = dpt.ones(10, dtype="i4", sycl_queue=q2)
    with pytest.raises(ExecutionPlacementError):
        dpt.add._inplace_op(lhs_q1, rhs_q2)


def test_add_inplace_same_tensors():
Expand Down
Loading

0 comments on commit 19331c0

Please sign in to comment.