Skip to content

Commit

Permalink
Adopt dpnp to DLPack v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Aug 12, 2024
1 parent 4a23239 commit a0e32c0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 11 deletions.
50 changes: 42 additions & 8 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,27 +184,61 @@ def __copy__(self):
# '__divmod__',
# '__doc__',

def __dlpack__(self, stream=None):
def __dlpack__(
self, *, stream=None, max_version=None, dl_device=None, copy=None
):
"""
Produces DLPack capsule.
Parameters
----------
stream : {:class:`dpctl.SyclQueue`, None}, optional
Execution queue to synchronize with. If ``None``,
synchronization is not performed.
Execution queue to synchronize with. If ``None``, synchronization
is not performed.
Default: ``None``.
max_version {tuple of ints, None}, optional
The maximum DLPack version the consumer (caller of ``__dlpack__``)
supports. As ``__dlpack__`` may not always return a DLPack capsule
with version `max_version`, the consumer must verify the version
even if this argument is passed.
Default: ``None``.
dl_device {tuple, None}, optional:
The device the returned DLPack capsule will be placed on. The
device must be a 2-tuple matching the format of
``__dlpack_device__`` method, an integer enumerator representing
the device type followed by an integer representing the index of
the device.
Default: ``None``.
copy {bool, None}, optional:
Boolean indicating whether or not to copy the input.
* If `copy` is ``True``, the input will always be copied.
* If ``False``, a ``BufferError`` will be raised if a copy is
deemed necessary.
* If ``None``, a copy will be made only if deemed necessary,
otherwise, the existing memory buffer will be reused.
Default: ``None``.
Raises
------
MemoryError
MemoryError:
when host memory can not be allocated.
DLPackCreationError
when array is allocated on a partitioned
SYCL device, or with a non-default context.
DLPackCreationError:
when array is allocated on a partitioned SYCL device, or with
a non-default context.
BufferError:
when a copy is deemed necessary but `copy` is ``False`` or when
the provided `dl_device` cannot be handled.
"""

return self._array_obj.__dlpack__(stream=stream)
return self._array_obj.__dlpack__(
stream=stream,
max_version=max_version,
dl_device=dl_device,
copy=copy,
)

def __dlpack_device__(self):
"""
Expand Down
35 changes: 32 additions & 3 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def default_float_type(device=None, sycl_queue=None):
return map_dtype_to_device(float64, _sycl_queue.sycl_device)


def from_dlpack(obj, /):
def from_dlpack(obj, /, *, device=None, copy=None):
"""
Create a dpnp array from a Python object implementing the ``__dlpack__``
protocol.
Expand All @@ -476,17 +476,46 @@ def from_dlpack(obj, /):
obj : object
A Python object representing an array that implements the ``__dlpack__``
and ``__dlpack_device__`` methods.
device : {:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`,
:class:`dpctl.tensor.Device`, tuple, None}, optional
Array API concept of a device where the output array is to be placed.
``device`` can be ``None``, an oneAPI filter selector string,
an instance of :class:`dpctl.SyclDevice` corresponding to
a non-partitioned SYCL device, an instance of :class:`dpctl.SyclQueue`,
a :class:`dpctl.tensor.Device` object returned by
:attr:`dpctl.tensor.usm_ndarray.device`, or a 2-tuple matching
the format of the output of the ``__dlpack_device__`` method,
an integer enumerator representing the device type followed by
an integer representing the index of the device.
Default: ``None``.
copy {bool, None}, optional
Boolean indicating whether or not to copy the input.
* If `copy``is ``True``, the input will always be copied.
* If ``False``, a ``BufferError`` will be raised if a copy is deemed
necessary.
* If ``None``, a copy will be made only if deemed necessary, otherwise,
the existing memory buffer will be reused.
Default: ``None``.
Returns
-------
out : dpnp_array
Returns a new dpnp array containing the data from another array
(obj) with the ``__dlpack__`` method on the same device as object.
Raises
------
TypeError:
if `obj` does not implement ``__dlpack__`` method
ValueError:
if the input array resides on an unsupported device
"""

usm_ary = dpt.from_dlpack(obj)
return dpnp_array._create_from_usm_ndarray(usm_ary)
usm_res = dpt.from_dlpack(obj, device=device, copy=copy)
return dpnp_array._create_from_usm_ndarray(usm_res)


def get_dpnp_descriptor(
Expand Down

0 comments on commit a0e32c0

Please sign in to comment.