From 8c828a59af1de62fac0459cab9981384db5e6a81 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Mon, 20 Mar 2023 17:06:11 +0000 Subject: [PATCH] fix: expose array interface for CUDA (#2327) * fix: expose array interface for CUDA * test: cover other cases --- src/awkward/index.py | 9 +++++++-- tests/test_2327_array_interface.py | 31 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 tests/test_2327_array_interface.py diff --git a/src/awkward/index.py b/src/awkward/index.py index 54f6af7eee..c6ab440e25 100644 --- a/src/awkward/index.py +++ b/src/awkward/index.py @@ -149,8 +149,13 @@ def raw(self, nplike): def __len__(self): return self.length - def __array__(self, dtype=None): - return self._nplike.asarray(self._data, dtype=dtype) + @property + def __cuda_array_interface__(self): + return self._data.__cuda_array_interface__ + + @property + def __array_interface__(self): + return self._data.__array_interface__ def __repr__(self): return self._repr("", "", "") diff --git a/tests/test_2327_array_interface.py b/tests/test_2327_array_interface.py new file mode 100644 index 0000000000..dd14bde600 --- /dev/null +++ b/tests/test_2327_array_interface.py @@ -0,0 +1,31 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + +import numpy as np +import pytest + +import awkward as ak + + +def test_wrap_index_cupy(): + cp = pytest.importorskip("cupy") + data = cp.arange(10, dtype=cp.int64) + index = ak.index.Index64(data) + other_index = ak.index.Index64(index) + other_data = cp.asarray(other_index) + assert cp.shares_memory(data, other_data) + + +def test_wrap_index_numpy(): + data = np.arange(10, dtype=np.int64) + index = ak.index.Index64(data) + other_index = ak.index.Index64(index) + other_data = np.asarray(other_index) + assert np.shares_memory(data, other_data) + + +def test_wrap_bare_list(): + data = [1, 2, 3, 4, 5] + index = ak.index.Index64(data) + other_index = ak.index.Index64(index) + other_data = np.asarray(other_index) + assert other_data.tolist() == data