Skip to content

Commit

Permalink
Fixed array API test failure by adding validation
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Aug 6, 2023
1 parent cc08b5d commit 5c1a961
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 2 additions & 0 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,8 @@ def _extract_impl(ary, ary_mask, axis=0):
dst = dpt.empty(
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
)
if dst.size == 0:
return dst
hev, _ = ti._extract(
src=ary,
cumsum=cumsum,
Expand Down
19 changes: 18 additions & 1 deletion dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,8 @@ cdef class usm_ndarray:
ind, (<object>self).shape, (<object> self).strides,
self.get_offset())
cdef usm_ndarray res
cdef int i = 0
cdef bint matching = 1

if len(_meta) < 5:
raise RuntimeError
Expand All @@ -787,7 +789,20 @@ cdef class usm_ndarray:

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
return _extract_impl(res, adv_ind[0], axis=adv_ind_start_p)
key_ = adv_ind[0]
adv_ind_end_p = key_.ndim + adv_ind_start_p
if adv_ind_end_p > res.ndim:
raise IndexError("too many indices for the array")
key_shape = key_.shape
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
for i in range(key_.ndim):
if matching:
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
matching = 0
if not matching:
raise IndexError("boolean index did not match indexed array in dimensions")
res = _extract_impl(res, key_, axis=adv_ind_start_p)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
adv_ind_int = list()
Expand Down Expand Up @@ -1152,6 +1167,8 @@ cdef class usm_ndarray:
if adv_ind_start_p < 0:
# basic slicing
if isinstance(rhs, usm_ndarray):
if Xv.size == 0:
return
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
else:
if hasattr(rhs, "__sycl_usm_array_interface__"):
Expand Down

0 comments on commit 5c1a961

Please sign in to comment.