Skip to content

Commit

Permalink
[BUG] fix empty initialization of device_ndarray in pylibraft (#2061)
Browse files Browse the repository at this point in the history
The `device_ndarray.empty()` function can be used to allocate device memory without initialization. Previously, the memory has been allocated (uninitialized) on the host and then been copied to the device.

This PR fixes the behavior for 'empty()' by allowing the `device_ndarray` to be initialized by an `array_interface` instead of an `numpy.ndarray` instance, which conditionally allows to skip the initialization of the `DeviceBuffer`.

CC @tfeher, @cjnolet

Authors:
  - Malte Förster (https://github.com/mfoerste4)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2061
  • Loading branch information
mfoerste4 authored Dec 15, 2023
1 parent 80a48ca commit 1beb556
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions python/pylibraft/pylibraft/common/device_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, np_ndarray):
Parameters
----------
ndarray : A numpy.ndarray which will be copied and moved to the device
ndarray : Can be numpy.ndarray, array like or even directly
an __array_interface__. Only case it is a numpy.ndarray its
contents will be copied to the device.
Examples
--------
Expand All @@ -58,11 +60,38 @@ def __init__(self, np_ndarray):
raft_array = device_ndarray.empty((100, 50))
torch_tensor = torch.as_tensor(raft_array, device='cuda')
"""
self.ndarray_ = np_ndarray

if type(np_ndarray) is np.ndarray:
# np_ndarray IS an actual numpy.ndarray
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = True
elif hasattr(np_ndarray, "__array_interface__"):
# np_ndarray HAS an __array_interface__
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = False
elif all(
name in np_ndarray for name in {"typestr", "shape", "version"}
):
# np_ndarray IS an __array_interface__
self.__array_interface__ = np_ndarray.copy()
self.ndarray_ = None
copy = False
else:
raise ValueError(
"np_ndarray should be or contain __array_interface__"
)

order = "C" if self.c_contiguous else "F"
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
if copy:
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
else:
self.device_buffer_ = rmm.DeviceBuffer(
size=np.prod(self.shape) * self.dtype.itemsize
)

@classmethod
def empty(cls, shape, dtype=np.float32, order="C"):
Expand All @@ -82,7 +111,7 @@ def empty(cls, shape, dtype=np.float32, order="C"):
or column-major (Fortran-style) order in memory
"""
arr = np.empty(shape, dtype=dtype, order=order)
return cls(arr)
return cls(arr.__array_interface__.copy())

@property
def c_contiguous(self):
Expand All @@ -104,23 +133,23 @@ def dtype(self):
"""
Datatype of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return np.dtype(array_interface["typestr"])

@property
def shape(self):
"""
Shape of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface["shape"]

@property
def strides(self):
"""
Strides of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface.get("strides")

@property
Expand All @@ -131,7 +160,7 @@ def __cuda_array_interface__(self):
zero-copy semantics.
"""
device_cai = self.device_buffer_.__cuda_array_interface__
host_cai = self.ndarray_.__array_interface__.copy()
host_cai = self.__array_interface__.copy()
host_cai["data"] = (device_cai["data"][0], device_cai["data"][1])

return host_cai
Expand Down

0 comments on commit 1beb556

Please sign in to comment.