Skip to content

Commit

Permalink
Merge pull request #1331 from IntelPython/to_device-stream-support
Browse files Browse the repository at this point in the history
usm_ndarray.to_device(dev, stream=queue) support
  • Loading branch information
oleksandr-pavlyk authored Aug 11, 2023
2 parents 6f0969c + 706d80f commit 074ec3a
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 31 deletions.
45 changes: 41 additions & 4 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
).shape


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
returns tuple type strides.
"""
out_strides = [0] * res_ndim
X_shape_len = len(X_shape)
str_dim = -X_shape_len
for i in range(X_shape_len):
shape_value = X_shape[i]
if not shape_value == 1:
out_strides[str_dim] = X_strides[i]
str_dim += 1

return tuple(out_strides)


def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
if any(
not isinstance(arg, dpt.usm_ndarray)
Expand All @@ -268,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
except ValueError as exc:
raise ValueError("Shapes of two arrays are not compatible") from exc

if dst.size < src.size:
if dst.size < src.size and dst.size < np.prod(common_shape):
raise ValueError("Destination is smaller ")

if len(common_shape) > dst.ndim:
Expand All @@ -279,13 +296,33 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
common_shape = common_shape[ones_count:]

if src.ndim < len(common_shape):
new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
new_src_strides = _broadcast_strides(
src.shape, src.strides, len(common_shape)
)
src_same_shape = dpt.usm_ndarray(
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
)
elif src.ndim == len(common_shape):
new_src_strides = _broadcast_strides(
src.shape, src.strides, len(common_shape)
)
src_same_shape = dpt.usm_ndarray(
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
)
else:
src_same_shape = src
src_same_shape.shape = common_shape
# since broadcasting succeeded, src.ndim is greater because of
# leading sequence of ones, so we trim it
n = len(common_shape)
new_src_strides = _broadcast_strides(
src.shape[-n:], src.strides[-n:], n
)
src_same_shape = dpt.usm_ndarray(
common_shape,
dtype=src.dtype,
buffer=src.usm_data,
strides=new_src_strides,
offset=src._element_offset,
)

_copy_same_shape(dst, src_same_shape)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,5 +343,5 @@ def nonzero(arr):
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(arr)}"
)
if arr.ndim == 0:
raise ValueError("Array of positive rank is exepcted")
raise ValueError("Array of positive rank is expected")
return _nonzero_impl(arr)
18 changes: 1 addition & 17 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dputils

from ._copy_utils import _broadcast_strides
from ._type_utils import _to_device_supported_dtype

__doc__ = (
Expand Down Expand Up @@ -120,23 +121,6 @@ def __repr__(self):
return self._finfo.__repr__()


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
returns tuple type strides.
"""
out_strides = [0] * res_ndim
X_shape_len = len(X_shape)
str_dim = -X_shape_len
for i in range(X_shape_len):
shape_value = X_shape[i]
if not shape_value == 1:
out_strides[str_dim] = X_strides[i]
str_dim += 1

return tuple(out_strides)


def _broadcast_shape_impl(shapes):
if len(set(shapes)) == 1:
return shapes[0]
Expand Down
17 changes: 11 additions & 6 deletions dpctl/tensor/_stride_utils.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ cdef int _from_input_shape_strides(
cdef int j
cdef bint all_incr = 1
cdef bint all_decr = 1
cdef bint all_incr_modified = 0
cdef bint all_decr_modified = 0
cdef bint strides_inspected = 0
cdef Py_ssize_t elem_count = 1
cdef Py_ssize_t min_shift = 0
cdef Py_ssize_t max_shift = 0
Expand Down Expand Up @@ -167,27 +166,33 @@ cdef int _from_input_shape_strides(
while (j < nd and shape_arr[j] == 1):
j = j + 1
if j < nd:
strides_inspected = 1
if all_incr:
all_incr_modified = 1
all_incr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
(strides_arr[i] <= strides_arr[j])
)
if all_decr:
all_decr_modified = 1
all_decr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
(strides_arr[i] >= strides_arr[j])
)
i = j
else:
if not strides_inspected:
# all dimensions have size 1 except
# dimension 'i'. Array is both C and F
# contiguous
strides_inspected = 1
all_incr = (strides_arr[i] == 1)
all_decr = all_incr
break
# should only set contig flags on actually obtained
# values, rather than default values
all_incr = all_incr and all_incr_modified
all_decr = all_decr and all_decr_modified
all_incr = all_incr and strides_inspected
all_decr = all_decr and strides_inspected
if all_incr and all_decr:
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
elif all_incr:
Expand Down
12 changes: 9 additions & 3 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ cdef class usm_ndarray:
return _take_multi_index(res, adv_ind, adv_ind_start_p)


def to_device(self, target):
def to_device(self, target, stream=None):
""" to_device(target_device)
Transfers this array to specified target device.
Expand Down Expand Up @@ -856,6 +856,14 @@ cdef class usm_ndarray:
cdef c_dpctl.DPCTLSyclQueueRef QRef = NULL
cdef c_dpmem._Memory arr_buf
d = Device.create_device(target)

if (stream is None or type(stream) is not dpctl.SyclQueue or
stream == self.sycl_queue):
pass
else:
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])

if (d.sycl_context == self.sycl_context):
arr_buf = <c_dpmem._Memory> self.usm_data
QRef = (<c_dpctl.SyclQueue> d.sycl_queue).get_queue_ref()
Expand Down Expand Up @@ -1167,8 +1175,6 @@ cdef class usm_ndarray:
if adv_ind_start_p < 0:
# basic slicing
if isinstance(rhs, usm_ndarray):
if Xv.size == 0:
return
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
else:
if hasattr(rhs, "__sycl_usm_array_interface__"):
Expand Down
66 changes: 66 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,25 @@ def test_usm_ndarray_flags():
x.flags["C"] = False


def test_usm_ndarray_flags_bug_gh_1334():
    """Check that both the "C" and "F" flags are set on each of the
    arrays constructed below (test is named after issue gh-1334)."""
    get_queue_or_skip()

    # default-order (2, 3) array reshaped to (1, 6, 1)
    base = dpt.ones((2, 3), dtype="u4")
    reshaped = dpt.reshape(base, (1, 6, 1))
    assert reshaped.flags["C"]
    assert reshaped.flags["F"]

    # same reshape, with order="F" for both creation and reshape
    base = dpt.ones((2, 3), dtype="u4", order="F")
    reshaped = dpt.reshape(base, (1, 6, 1), order="F")
    assert reshaped.flags["C"]
    assert reshaped.flags["F"]

    # reduction over the two trailing axes with keepdims=True
    base = dpt.ones((2, 3, 4), dtype="i8")
    reduced = dpt.sum(base, axis=(1, 2), keepdims=True)
    assert reduced.flags["C"]
    assert reduced.flags["F"]

    # negative-step slice along the second axis of a (2, 1) array
    base = dpt.ones((2, 1), dtype="?")
    sliced = base[:, 1::-1]
    assert sliced.flags["F"]
    assert sliced.flags["C"]


@pytest.mark.parametrize(
"dtype",
[
Expand Down Expand Up @@ -1012,6 +1031,53 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
Zusm_empty[Ellipsis] = Zusm_3d[0, 0, 0:0]


def test_setitem_broadcasting():
    """Assign a (3, 1) array of zeros into a (2, 3, 4) array of ones and
    check that the destination then compares equal to all zeros."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="u4")
    zeros_src = dpt.zeros((3, 1), dtype=target.dtype)
    target[Ellipsis] = zeros_src
    reference = np.zeros(target.shape, dtype=target.dtype)
    host_copy = dpt.asnumpy(target)
    assert np.array_equal(host_copy, reference)


def test_setitem_broadcasting_empty_dst_validation():
    """Shapes (2, 0, 5, 4) and (2, 0, 3, 4) differ in the third axis
    (5 vs 3), so the assignment is expected to raise ``ValueError``
    even though both arrays have a zero-length axis."""
    get_queue_or_skip()
    lhs = dpt.ones((2, 0, 5, 4), dtype="i8")
    rhs = dpt.ones((2, 0, 3, 4), dtype="i8")
    with pytest.raises(ValueError):
        lhs[Ellipsis] = rhs


def test_setitem_broadcasting_empty_dst_edge_case():
    """RHS is shrunk to an empty array by the
    broadcasting rule, hence no exception"""
    get_queue_or_skip()
    # zero-length slice of a one-element array
    empty_lhs = dpt.ones(1, dtype="i8")[0:0]
    # array created with an empty shape tuple
    scalar_rhs = dpt.ones((), dtype="i8")
    empty_lhs[Ellipsis] = scalar_rhs


def test_setitem_broadcasting_src_ndim_equal_dst_ndim():
    """Assign a (2, 1, 4) array of zeros into a (2, 3, 4) array of ones
    (same number of dimensions) and check the destination equals zeros."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="i4")
    zeros_src = dpt.zeros((2, 1, 4), dtype="i4")
    target[Ellipsis] = zeros_src

    reference = np.zeros(target.shape, dtype=target.dtype)
    host_copy = dpt.asnumpy(target)
    assert np.array_equal(host_copy, reference)


def test_setitem_broadcasting_src_ndim_greater_than_dst_ndim():
    """Assign a (1, 2, 1, 4) array of zeros into a (2, 3, 4) array of
    ones (source has one more dimension) and check the destination
    equals zeros."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="i4")
    zeros_src = dpt.zeros((1, 2, 1, 4), dtype="i4")
    target[Ellipsis] = zeros_src

    reference = np.zeros(target.shape, dtype=target.dtype)
    host_copy = dpt.asnumpy(target)
    assert np.array_equal(host_copy, reference)


@pytest.mark.parametrize(
"dtype",
_all_dtypes,
Expand Down

0 comments on commit 074ec3a

Please sign in to comment.