Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

usm_ndarray.to_device(dev, stream=queue) support #1331

Merged
merged 8 commits into from
Aug 11, 2023
45 changes: 41 additions & 4 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
).shape


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
returns tuple type strides.
"""
out_strides = [0] * res_ndim
X_shape_len = len(X_shape)
str_dim = -X_shape_len
for i in range(X_shape_len):
shape_value = X_shape[i]
if not shape_value == 1:
out_strides[str_dim] = X_strides[i]
str_dim += 1

return tuple(out_strides)


def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
if any(
not isinstance(arg, dpt.usm_ndarray)
Expand All @@ -268,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
except ValueError as exc:
raise ValueError("Shapes of two arrays are not compatible") from exc

if dst.size < src.size:
if dst.size < src.size and dst.size < np.prod(common_shape):
raise ValueError("Destination is smaller ")

if len(common_shape) > dst.ndim:
Expand All @@ -279,13 +296,33 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
common_shape = common_shape[ones_count:]

if src.ndim < len(common_shape):
new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
new_src_strides = _broadcast_strides(
src.shape, src.strides, len(common_shape)
)
src_same_shape = dpt.usm_ndarray(
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
)
elif src.ndim == len(common_shape):
new_src_strides = _broadcast_strides(
src.shape, src.strides, len(common_shape)
)
src_same_shape = dpt.usm_ndarray(
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
)
else:
src_same_shape = src
src_same_shape.shape = common_shape
# since broadcasting succeeded, src.ndim is greater because of
# leading sequence of ones, so we trim it
n = len(common_shape)
new_src_strides = _broadcast_strides(
src.shape[-n:], src.strides[-n:], n
)
src_same_shape = dpt.usm_ndarray(
common_shape,
dtype=src.dtype,
buffer=src.usm_data,
strides=new_src_strides,
offset=src._element_offset,
)

_copy_same_shape(dst, src_same_shape)

Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,5 +343,5 @@ def nonzero(arr):
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(arr)}"
)
if arr.ndim == 0:
raise ValueError("Array of positive rank is exepcted")
raise ValueError("Array of positive rank is expected")
return _nonzero_impl(arr)
18 changes: 1 addition & 17 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dputils

from ._copy_utils import _broadcast_strides
from ._type_utils import _to_device_supported_dtype

__doc__ = (
Expand Down Expand Up @@ -120,23 +121,6 @@ def __repr__(self):
return self._finfo.__repr__()


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
returns tuple type strides.
"""
out_strides = [0] * res_ndim
X_shape_len = len(X_shape)
str_dim = -X_shape_len
for i in range(X_shape_len):
shape_value = X_shape[i]
if not shape_value == 1:
out_strides[str_dim] = X_strides[i]
str_dim += 1

return tuple(out_strides)


def _broadcast_shape_impl(shapes):
if len(set(shapes)) == 1:
return shapes[0]
Expand Down
17 changes: 11 additions & 6 deletions dpctl/tensor/_stride_utils.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ cdef int _from_input_shape_strides(
cdef int j
cdef bint all_incr = 1
cdef bint all_decr = 1
cdef bint all_incr_modified = 0
cdef bint all_decr_modified = 0
cdef bint strides_inspected = 0
cdef Py_ssize_t elem_count = 1
cdef Py_ssize_t min_shift = 0
cdef Py_ssize_t max_shift = 0
Expand Down Expand Up @@ -167,27 +166,33 @@ cdef int _from_input_shape_strides(
while (j < nd and shape_arr[j] == 1):
j = j + 1
if j < nd:
strides_inspected = 1
if all_incr:
all_incr_modified = 1
all_incr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
(strides_arr[i] <= strides_arr[j])
)
if all_decr:
all_decr_modified = 1
all_decr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
(strides_arr[i] >= strides_arr[j])
)
i = j
else:
if not strides_inspected:
# all dimensions have size 1 except
# dimension 'i'. Array is both C and F
# contiguous
strides_inspected = 1
all_incr = (strides_arr[i] == 1)
all_decr = all_incr
break
# should only set contig flags on actually obtained
# values, rather than default values
all_incr = all_incr and all_incr_modified
all_decr = all_decr and all_decr_modified
all_incr = all_incr and strides_inspected
all_decr = all_decr and strides_inspected
if all_incr and all_decr:
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
elif all_incr:
Expand Down
12 changes: 9 additions & 3 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ cdef class usm_ndarray:
return _take_multi_index(res, adv_ind, adv_ind_start_p)


def to_device(self, target):
def to_device(self, target, stream=None):
""" to_device(target_device)

Transfers this array to specified target device.
Expand Down Expand Up @@ -856,6 +856,14 @@ cdef class usm_ndarray:
cdef c_dpctl.DPCTLSyclQueueRef QRef = NULL
cdef c_dpmem._Memory arr_buf
d = Device.create_device(target)

if (stream is None or type(stream) is not dpctl.SyclQueue or
stream == self.sycl_queue):
pass
else:
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])

if (d.sycl_context == self.sycl_context):
arr_buf = <c_dpmem._Memory> self.usm_data
QRef = (<c_dpctl.SyclQueue> d.sycl_queue).get_queue_ref()
Expand Down Expand Up @@ -1167,8 +1175,6 @@ cdef class usm_ndarray:
if adv_ind_start_p < 0:
# basic slicing
if isinstance(rhs, usm_ndarray):
if Xv.size == 0:
return
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
else:
if hasattr(rhs, "__sycl_usm_array_interface__"):
Expand Down
66 changes: 66 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,25 @@ def test_usm_ndarray_flags():
x.flags["C"] = False


def test_usm_ndarray_flags_bug_gh_1334():
get_queue_or_skip()
a = dpt.ones((2, 3), dtype="u4")
r = dpt.reshape(a, (1, 6, 1))
assert r.flags["C"] and r.flags["F"]

a = dpt.ones((2, 3), dtype="u4", order="F")
r = dpt.reshape(a, (1, 6, 1), order="F")
assert r.flags["C"] and r.flags["F"]

a = dpt.ones((2, 3, 4), dtype="i8")
r = dpt.sum(a, axis=(1, 2), keepdims=True)
assert r.flags["C"] and r.flags["F"]

a = dpt.ones((2, 1), dtype="?")
r = a[:, 1::-1]
assert r.flags["F"] and r.flags["C"]


@pytest.mark.parametrize(
"dtype",
[
Expand Down Expand Up @@ -1012,6 +1031,53 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
Zusm_empty[Ellipsis] = Zusm_3d[0, 0, 0:0]


def test_setitem_broadcasting():
get_queue_or_skip()
dst = dpt.ones((2, 3, 4), dtype="u4")
src = dpt.zeros((3, 1), dtype=dst.dtype)
dst[...] = src
expected = np.zeros(dst.shape, dtype=dst.dtype)
assert np.array_equal(dpt.asnumpy(dst), expected)


def test_setitem_broadcasting_empty_dst_validation():
"Broadcasting rules apply, except exception"
get_queue_or_skip()
dst = dpt.ones((2, 0, 5, 4), dtype="i8")
src = dpt.ones((2, 0, 3, 4), dtype="i8")
with pytest.raises(ValueError):
dst[...] = src


def test_setitem_broadcasting_empty_dst_edge_case():
"""RHS is shunken to empty array by
broadasting rule, hence no exception"""
get_queue_or_skip()
dst = dpt.ones(1, dtype="i8")[0:0]
src = dpt.ones(tuple(), dtype="i8")
dst[...] = src


def test_setitem_broadcasting_src_ndim_equal_dst_ndim():
get_queue_or_skip()
dst = dpt.ones((2, 3, 4), dtype="i4")
src = dpt.zeros((2, 1, 4), dtype="i4")
dst[...] = src

expected = np.zeros(dst.shape, dtype=dst.dtype)
assert np.array_equal(dpt.asnumpy(dst), expected)


def test_setitem_broadcasting_src_ndim_greater_than_dst_ndim():
get_queue_or_skip()
dst = dpt.ones((2, 3, 4), dtype="i4")
src = dpt.zeros((1, 2, 1, 4), dtype="i4")
dst[...] = src

expected = np.zeros(dst.shape, dtype=dst.dtype)
assert np.array_equal(dpt.asnumpy(dst), expected)


@pytest.mark.parametrize(
"dtype",
_all_dtypes,
Expand Down