Skip to content

Commit

Permalink
Merge pull request #1756 from IntelPython/propagate-ro-flag-for-suai-…
Browse files Browse the repository at this point in the history
…in-asarray

Propagate read-only flag from sycl_usm_array_interface in asarray
  • Loading branch information
oleksandr-pavlyk authored Jul 26, 2024
2 parents f83f95b + 9a5715f commit 9403f76
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
7 changes: 7 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def _usm_ndarray_from_suai(obj):
buffer=membuf,
strides=sua_iface.get("strides", None),
)
_data_field = sua_iface["data"]
if isinstance(_data_field, tuple) and len(_data_field) > 1:
ro_field = _data_field[1]
else:
ro_field = False
if ro_field:
ary.flags["W"] = False
return ary


Expand Down
26 changes: 26 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,3 +2368,29 @@ def test_gh_1201():
c = dpt.flip(dpt.empty(a.shape, dtype=a.dtype))
c[:] = a
assert (dpt.asnumpy(c) == a).all()


class ObjWithSyclUsmArrayInterface:
def __init__(self, ary):
self._array_obj = ary

@property
def __sycl_usm_array_interface__(self):
_suai = self._array_obj.__sycl_usm_array_interface__
return _suai


@pytest.mark.parametrize("ro_flag", [True, False])
def test_asarray_writable_flag(ro_flag):
try:
a = dpt.empty(8)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")

a.flags["W"] = not ro_flag
wrapped = ObjWithSyclUsmArrayInterface(a)

b = dpt.asarray(wrapped)

assert b.flags["W"] == (not ro_flag)
assert b._pointer == a._pointer

0 comments on commit 9403f76

Please sign in to comment.