diff --git a/conda/recipes/dask-cuda/meta.yaml b/conda/recipes/dask-cuda/meta.yaml
index ece107eb8..8988de6a9 100644
--- a/conda/recipes/dask-cuda/meta.yaml
+++ b/conda/recipes/dask-cuda/meta.yaml
@@ -24,7 +24,7 @@ requirements:
     - setuptools
   run:
     - python x.x
-    - dask-core >=2.4.0
+    - dask >=2.4.0
     - distributed >=2.18.0
     - pynvml >=8.0.3
     - numpy >=1.16.0
diff --git a/dask_cuda/cli/dask_cuda_worker.py b/dask_cuda/cli/dask_cuda_worker.py
index 65cef580c..cf99e5f09 100755
--- a/dask_cuda/cli/dask_cuda_worker.py
+++ b/dask_cuda/cli/dask_cuda_worker.py
@@ -191,6 +191,11 @@
     "InfiniBand only and will still cause unpredictable errors if not _ALL_ "
     "interfaces are connected and properly configured.",
 )
+@click.option(
+    "--enable-jit-unspill/--disable-jit-unspill",
+    default=None,  # If not specified, use the Dask config
+    help="Enable just-in-time unspilling",
+)
 def main(
     scheduler,
     host,
@@ -218,6 +223,7 @@ def main(
     enable_nvlink,
     enable_rdmacm,
     net_devices,
+    enable_jit_unspill,
     **kwargs,
 ):
     if tls_ca_file and tls_cert and tls_key:
@@ -252,6 +258,7 @@ def main(
         enable_nvlink,
         enable_rdmacm,
         net_devices,
+        enable_jit_unspill,
         **kwargs,
     )
diff --git a/dask_cuda/cuda_worker.py b/dask_cuda/cuda_worker.py
index 39c464355..bd700df20 100644
--- a/dask_cuda/cuda_worker.py
+++ b/dask_cuda/cuda_worker.py
@@ -71,6 +71,7 @@ def __init__(
         enable_nvlink=False,
         enable_rdmacm=False,
         net_devices=None,
+        jit_unspill=None,
         **kwargs,
     ):
         # Required by RAPIDS libraries (e.g., cuDF) to ensure no context
@@ -177,6 +178,11 @@ def del_pid_file():
             cuda_device_index=0,
         )

+        if jit_unspill is None:
+            self.jit_unspill = dask.config.get("jit-unspill", default=False)
+        else:
+            self.jit_unspill = jit_unspill
+
         self.nannies = [
             t(
                 scheduler,
@@ -216,6 +222,7 @@ def del_pid_file():
                     ),
                     "memory_limit": memory_limit,
                     "local_directory": local_directory,
+                    "jit_unspill": self.jit_unspill,
                 },
             ),
             **kwargs,
diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py
index 00c747099..ae4a34886 100644
--- a/dask_cuda/device_host_file.py
+++ b/dask_cuda/device_host_file.py
@@ -16,15 +16,14 @@
 from distributed.utils import nbytes
 from distributed.worker import weight

+from . import proxy_object
 from .is_device_object import is_device_object
 from .utils import nvtx_annotate


 class DeviceSerialized:
     """ Store device object on the host

-    This stores a device-side object as
-
     1. A msgpack encodable header
     2. A list of `bytes`-like objects (like NumPy arrays)
        that are in host memory
@@ -66,6 +65,27 @@ def host_to_device(s: DeviceSerialized) -> object:
     return deserialize(s.header, s.frames)


+@nvtx_annotate("SPILL_D2H", color="red", domain="dask_cuda")
+def pxy_obj_device_to_host(obj: object) -> proxy_object.ProxyObject:
+    try:
+        # Never re-serialize proxy objects.
+        if obj._obj_pxy["serializers"] is not None:
+            return obj
+    except (KeyError, AttributeError):
+        pass
+
+    # Notice, both the "dask" and the "pickle" serializers will
+    # spill `obj` to main memory.
+    return proxy_object.asproxy(obj, serializers=["dask", "pickle"])
+
+
+@nvtx_annotate("SPILL_H2D", color="green", domain="dask_cuda")
+def pxy_obj_host_to_device(s: proxy_object.ProxyObject) -> object:
+    # Notice, we do _not_ deserialize at this point. The proxy
+    # object deserializes itself just-in-time when accessed.
+    return s
+
+
 class DeviceHostFile(ZictBase):
     """ Manages serialization/deserialization of objects.
@@ -86,10 +106,16 @@ class DeviceHostFile(ZictBase):
         implies no spilling to disk.
     local_directory: path
         Path where to store serialized objects on disk
+    jit_unspill: bool
+        If True, enable just-in-time unspilling (see proxy_object.ProxyObject).
     """

     def __init__(
-        self, device_memory_limit=None, memory_limit=None, local_directory=None,
+        self,
+        device_memory_limit=None,
+        memory_limit=None,
+        local_directory=None,
+        jit_unspill=False,
     ):
         if local_directory is None:
             local_directory = dask.config.get("temporary-directory") or os.getcwd()
@@ -115,7 +141,14 @@ def __init__(
         self.device_keys = set()
         self.device_func = dict()
-        self.device_host_func = Func(device_to_host, host_to_device, self.host_buffer)
+        if jit_unspill:
+            self.device_host_func = Func(
+                pxy_obj_device_to_host, pxy_obj_host_to_device, self.host_buffer
+            )
+        else:
+            self.device_host_func = Func(
+                device_to_host, host_to_device, self.host_buffer
+            )
         self.device_buffer = Buffer(
             self.device_func, self.device_host_func, device_memory_limit, weight=weight
         )
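For orientation, here is a minimal sketch of what the two new spill functions above do. A NumPy array stands in for a device object (any object the "dask"/"pickle" serializers accept behaves the same way); the assertions mirror the semantics shown in the diff:

```python
import numpy as np

from dask_cuda.device_host_file import (
    pxy_obj_device_to_host,
    pxy_obj_host_to_device,
)
from dask_cuda.proxy_object import ProxyObject, unproxy

x = np.arange(10)

# Spilling device-to-host serializes the object and wraps it in a proxy.
p = pxy_obj_device_to_host(x)
assert type(p) is ProxyObject

# Unspilling host-to-device is a no-op: the proxy is returned as-is and
# deserializes itself just-in-time on first access.
assert pxy_obj_host_to_device(p) is p

# unproxy() forces deserialization and returns the underlying object.
assert (unproxy(p) == x).all()
```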
diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py
index 8513af730..63cb5cbdb 100644
--- a/dask_cuda/local_cuda_cluster.py
+++ b/dask_cuda/local_cuda_cluster.py
@@ -92,6 +92,8 @@ class LocalCUDACluster(LocalCluster):
         but in that case with default (non-managed) memory type.
         WARNING: managed memory is currently incompatible with NVLink, trying
         to enable both will result in an exception.
+    jit_unspill: bool
+        If True, enable just-in-time unspilling (see proxy_object.ProxyObject).

     Examples
     --------
@@ -133,6 +135,7 @@ def __init__(
         ucx_net_devices=None,
         rmm_pool_size=None,
         rmm_managed_memory=False,
+        jit_unspill=None,
         **kwargs,
     ):
         # Required by RAPIDS libraries (e.g., cuDF) to ensure no context
@@ -182,6 +185,11 @@ def __init__(
                 "Processes are necessary in order to use multiple GPUs with Dask"
             )

+        if jit_unspill is None:
+            self.jit_unspill = dask.config.get("jit-unspill", default=False)
+        else:
+            self.jit_unspill = jit_unspill
+
         if data is None:
             data = (
                 DeviceHostFile,
@@ -191,6 +199,7 @@ def __init__(
                     "local_directory": local_directory
                     or dask.config.get("temporary-directory")
                     or os.getcwd(),
+                    "jit_unspill": self.jit_unspill,
                 },
             )
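All three entry points (the `--enable-jit-unspill` CLI flag, `CUDAWorker`, and `LocalCUDACluster`) resolve the setting the same way: an explicit value wins, and `None` falls back to the Dask config key `jit-unspill`. A sketch of both styles (spinning up the cluster requires a CUDA-capable machine):

```python
import dask
from dask_cuda import LocalCUDACluster

# An explicit keyword argument takes precedence:
cluster = LocalCUDACluster(n_workers=1, jit_unspill=True)

# With jit_unspill=None (the default), the Dask config decides:
with dask.config.set({"jit-unspill": True}):
    cluster = LocalCUDACluster(n_workers=1)
```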
diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py
new file mode 100644
index 000000000..cbb33d80c
--- /dev/null
+++ b/dask_cuda/proxy_object.py
@@ -0,0 +1,548 @@
+import operator
+import pickle
+import threading
+
+import dask
+import dask.dataframe.methods
+import dask.dataframe.utils
+import distributed.protocol
+import distributed.utils
+from dask.sizeof import sizeof
+
+from .is_device_object import is_device_object
+
+# List of attributes that should be copied to the proxy at creation, which makes
+# them accessible without deserialization of the proxied object
+_FIXED_ATTRS = ["name"]
+
+
+def asproxy(obj, serializers=None, subclass=None):
+    """Wrap `obj` in a ProxyObject if it isn't one already.
+
+    Parameters
+    ----------
+    obj: object
+        Object to wrap in a ProxyObject.
+    serializers: list(str), optional
+        List of serializers to use to serialize `obj`. If None,
+        no serialization is done.
+    subclass: class, optional
+        Specify a subclass of ProxyObject to create instead of ProxyObject.
+        `subclass` must be picklable.
+
+    Returns
+    -------
+    The ProxyObject proxying `obj`
+    """
+
+    if hasattr(obj, "_obj_pxy"):  # Already a proxy object
+        ret = obj
+    else:
+        fixed_attr = {}
+        for attr in _FIXED_ATTRS:
+            try:
+                fixed_attr[attr] = getattr(obj, attr)
+            except AttributeError:
+                pass
+
+        if subclass is None:
+            subclass = ProxyObject
+        ret = subclass(
+            obj=obj,
+            fixed_attr=fixed_attr,
+            type_serialized=pickle.dumps(type(obj)),
+            typename=dask.utils.typename(type(obj)),
+            is_cuda_object=is_device_object(obj),
+            subclass=pickle.dumps(subclass) if subclass else None,
+            serializers=None,
+        )
+    if serializers is not None:
+        ret._obj_pxy_serialize(serializers=serializers)
+    return ret
+
+
+def unproxy(obj):
+    """Unwrap ProxyObject objects and pass through anything else.
+
+    Use this function to retrieve the proxied object.
+
+    Parameters
+    ----------
+    obj: object
+        Any kind of object
+
+    Returns
+    -------
+    The proxied object or `obj` itself if it isn't a ProxyObject
+    """
+    try:
+        obj = obj._obj_pxy_deserialize()
+    except AttributeError:
+        pass
+    return obj
+
+
+class ProxyObject:
+    """Object wrapper/proxy for serializable objects
+
+    This is used by DeviceHostFile to delay deserialization of returned objects.
+
+    Objects proxied by an instance of this class will be JIT-deserialized when
+    accessed. The instance behaves like the proxied object and can be accessed/used
+    just like the proxied object.
+
+    ProxyObject has some limitations and doesn't mimic the proxied object perfectly.
+    Thus, if you encounter problems, remember that it is always possible to use
+    unproxy() to access the proxied object directly, or to disable JIT
+    deserialization completely with `jit_unspill=False`.
+
+    Type checking using isinstance() works as expected but direct type checking
+    doesn't:
+    >>> import numpy as np
+    >>> from dask_cuda.proxy_object import asproxy
+    >>> x = np.arange(3)
+    >>> isinstance(asproxy(x), type(x))
+    True
+    >>> type(asproxy(x)) is type(x)
+    False
+
+    Parameters
+    ----------
+    obj: object
+        Any kind of object to be proxied.
+    fixed_attr: dict
+        Dictionary of attributes that are accessible without deserializing
+        the proxied object.
+    type_serialized: bytes
+        Pickled type of `obj`.
+    typename: str
+        Name of the type of `obj`.
+    is_cuda_object: boolean
+        Whether `obj` is a CUDA object or not.
+    subclass: bytes
+        Pickled type to use instead of ProxyObject when deserializing. The type
+        must inherit from ProxyObject.
+    serializers: list(str), optional
+        List of serializers to use to serialize `obj`. If None, `obj`
+        isn't serialized.
+    """

+    __slots__ = [
+        "_obj_pxy",  # A dict that holds the state of the proxy object
+        "_obj_pxy_lock",  # Threading lock for all obj_pxy access
+        "__obj_pxy_cache",  # A dict used for caching attributes
+    ]
+
+    def __init__(
+        self,
+        obj,
+        fixed_attr,
+        type_serialized,
+        typename,
+        is_cuda_object,
+        subclass,
+        serializers,
+    ):
+        self._obj_pxy = {
+            "obj": obj,
+            "fixed_attr": fixed_attr,
+            "type_serialized": type_serialized,
+            "typename": typename,
+            "is_cuda_object": is_cuda_object,
+            "subclass": subclass,
+            "serializers": serializers,
+        }
+        self._obj_pxy_lock = threading.RLock()
+        self.__obj_pxy_cache = {}
+
+    def _obj_pxy_get_meta(self):
+        """Return the metadata of the proxy object.
+
+        Returns
+        -------
+        Dictionary of metadata
+        """
+        with self._obj_pxy_lock:
+            return {k: self._obj_pxy[k] for k in self._obj_pxy.keys() if k != "obj"}
+
+    def _obj_pxy_serialize(self, serializers):
+        """Inplace serialization of the proxied object using the `serializers`
+
+        Parameters
+        ----------
+        serializers: list(str)
+            List of serializers to use to serialize the proxied object.
+
+        Returns
+        -------
+        header: dict
+            The header of the serialized frames
+        frames: list(bytes)
+            List of frames that make up the serialized object
+        """
+        if not serializers:
+            raise ValueError("Please specify a list of serializers")
+
+        with self._obj_pxy_lock:
+            if self._obj_pxy["serializers"] is not None and tuple(
+                self._obj_pxy["serializers"]
+            ) != tuple(serializers):
+                # The proxied object is serialized with different serializers
+                self._obj_pxy_deserialize()
+
+            if self._obj_pxy["serializers"] is None:
+                self._obj_pxy["obj"] = distributed.protocol.serialize(
+                    self._obj_pxy["obj"], serializers
+                )
+                self._obj_pxy["serializers"] = serializers
+
+            return self._obj_pxy["obj"]
+
+    def _obj_pxy_deserialize(self):
+        """Inplace deserialization of the proxied object
+
+        Returns
+        -------
+        object
+            The proxied object (deserialized)
+        """
+        with self._obj_pxy_lock:
+            if self._obj_pxy["serializers"] is not None:
+                header, frames = self._obj_pxy["obj"]
+                self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)
+                self._obj_pxy["serializers"] = None
+            return self._obj_pxy["obj"]
+
+    def _obj_pxy_is_cuda_object(self):
+        """Return whether the proxied object is a CUDA object or not
+
+        Returns
+        -------
+        ret : boolean
+            Is the proxied object a CUDA object?
+        """
+        with self._obj_pxy_lock:
+            return self._obj_pxy["is_cuda_object"]
+
+    def __getattr__(self, name):
+        with self._obj_pxy_lock:
+            typename = self._obj_pxy["typename"]
+            if name in _FIXED_ATTRS:
+                try:
+                    return self._obj_pxy["fixed_attr"][name]
+                except KeyError:
+                    raise AttributeError(
+                        f"type object '{typename}' has no attribute '{name}'"
+                    )
+
+            return getattr(self._obj_pxy_deserialize(), name)
+
+    def __str__(self):
+        return str(self._obj_pxy_deserialize())
+
+    def __repr__(self):
+        with self._obj_pxy_lock:
+            typename = self._obj_pxy["typename"]
+            ret = (
+                f"<{dask.utils.typename(type(self))} at {hex(id(self))} for {typename}"
+            )
+            if self._obj_pxy["serializers"] is not None:
+                ret += f" (serialized={repr(self._obj_pxy['serializers'])})>"
+            else:
+                ret += f" at {hex(id(self._obj_pxy['obj']))}>"
+            return ret
+
+    @property
+    def __class__(self):
+        with self._obj_pxy_lock:
+            try:
+                return self.__obj_pxy_cache["type_serialized"]
+            except KeyError:
+                ret = pickle.loads(self._obj_pxy["type_serialized"])
+                self.__obj_pxy_cache["type_serialized"] = ret
+                return ret
+
+    def __sizeof__(self):
+        with self._obj_pxy_lock:
+            if self._obj_pxy["serializers"] is not None:
+                frames = self._obj_pxy["obj"][1]
+                return sum(map(distributed.utils.nbytes, frames))
+            else:
+                return sizeof(self._obj_pxy_deserialize())
+
+    def __len__(self):
+        return len(self._obj_pxy_deserialize())
+
+    def __contains__(self, value):
+        return value in self._obj_pxy_deserialize()
+
+    def __getitem__(self, key):
+        return self._obj_pxy_deserialize()[key]
+
+    def __setitem__(self, key, value):
+        self._obj_pxy_deserialize()[key] = value
+
+    def __delitem__(self, key):
+        del self._obj_pxy_deserialize()[key]
+
+    def __getslice__(self, i, j):
+        return self._obj_pxy_deserialize()[i:j]
+
+    def __setslice__(self, i, j, value):
+        self._obj_pxy_deserialize()[i:j] = value
+
+    def __delslice__(self, i, j):
+        del self._obj_pxy_deserialize()[i:j]
+
+    def __iter__(self):
+        return iter(self._obj_pxy_deserialize())
+
+    def __array__(self):
+        return getattr(self._obj_pxy_deserialize(), "__array__")()
+
+    def __add__(self, other):
+        return self._obj_pxy_deserialize() + other
+
+    def __sub__(self, other):
+        return self._obj_pxy_deserialize() - other
+
+    def __mul__(self, other):
+        return self._obj_pxy_deserialize() * other
+
+    def __truediv__(self, other):
+        return operator.truediv(self._obj_pxy_deserialize(), other)
+
+    def __floordiv__(self, other):
+        return self._obj_pxy_deserialize() // other
+
+    def __mod__(self, other):
+        return self._obj_pxy_deserialize() % other
+
+    def __divmod__(self, other):
+        return divmod(self._obj_pxy_deserialize(), other)
+
+    def __pow__(self, other, *args):
+        return pow(self._obj_pxy_deserialize(), other, *args)
+
+    def __lshift__(self, other):
+        return self._obj_pxy_deserialize() << other
+
+    def __rshift__(self, other):
+        return self._obj_pxy_deserialize() >> other
+
+    def __and__(self, other):
+        return self._obj_pxy_deserialize() & other
+
+    def __xor__(self, other):
+        return self._obj_pxy_deserialize() ^ other
+
+    def __or__(self, other):
+        return self._obj_pxy_deserialize() | other
+
+    def __radd__(self, other):
+        return other + self._obj_pxy_deserialize()
+
+    def __rsub__(self, other):
+        return other - self._obj_pxy_deserialize()
+
+    def __rmul__(self, other):
+        return other * self._obj_pxy_deserialize()
+
+    def __rtruediv__(self, other):
+        return operator.truediv(other, self._obj_pxy_deserialize())
+
+    def __rfloordiv__(self, other):
+        return other // self._obj_pxy_deserialize()
+
+    def __rmod__(self, other):
+        return other % self._obj_pxy_deserialize()
+
+    def __rdivmod__(self, other):
+        return divmod(other, self._obj_pxy_deserialize())
+
+    def __rpow__(self, other, *args):
+        return pow(other, self._obj_pxy_deserialize(), *args)
+
+    def __rlshift__(self, other):
+        return other << self._obj_pxy_deserialize()
+
+    def __rrshift__(self, other):
+        return other >> self._obj_pxy_deserialize()
+
+    def __rand__(self, other):
+        return other & self._obj_pxy_deserialize()
+
+    def __rxor__(self, other):
+        return other ^ self._obj_pxy_deserialize()
+
+    def __ror__(self, other):
+        return other | self._obj_pxy_deserialize()
+
+    def __iadd__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied += other
+        return self
+
+    def __isub__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied -= other
+        return self
+
+    def __imul__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied *= other
+        return self
+
+    def __itruediv__(self, other):
+        with self._obj_pxy_lock:
+            proxied = self._obj_pxy_deserialize()
+            self._obj_pxy["obj"] = operator.itruediv(proxied, other)
+        return self
+
+    def __ifloordiv__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied //= other
+        return self
+
+    def __imod__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied %= other
+        return self
+
+    def __ipow__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied **= other
+        return self
+
+    def __ilshift__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied <<= other
+        return self
+
+    def __irshift__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied >>= other
+        return self
+
+    def __iand__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied &= other
+        return self
+
+    def __ixor__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied ^= other
+        return self
+
+    def __ior__(self, other):
+        proxied = self._obj_pxy_deserialize()
+        proxied |= other
+        return self
+
+    def __neg__(self):
+        return -self._obj_pxy_deserialize()
+
+    def __pos__(self):
+        return +self._obj_pxy_deserialize()
+
+    def __abs__(self):
+        return abs(self._obj_pxy_deserialize())
+
+    def __invert__(self):
+        return ~self._obj_pxy_deserialize()
+
+    def __int__(self):
+        return int(self._obj_pxy_deserialize())
+
+    def __float__(self):
+        return float(self._obj_pxy_deserialize())
+
+    def __complex__(self):
+        return complex(self._obj_pxy_deserialize())
+
+    def __index__(self):
+        return operator.index(self._obj_pxy_deserialize())
+
+
+@is_device_object.register(ProxyObject)
+def obj_pxy_is_device_object(obj: ProxyObject):
+    """
+    In order to avoid deserializing the proxied object, we call
+    `_obj_pxy_is_cuda_object()` instead of the default
+    `hasattr(o, "__cuda_array_interface__")` check.
+    """
+    return obj._obj_pxy_is_cuda_object()
+
+
+@distributed.protocol.dask_serialize.register(ProxyObject)
+def obj_pxy_dask_serialize(obj: ProxyObject):
+    """
+    The generic serialization of ProxyObject used by Dask when communicating
+    ProxyObject instances. As serializers, it uses "dask" or "pickle", which
+    means that proxied CUDA objects are spilled to main memory before being
+    communicated.
+    """
+    header, frames = obj._obj_pxy_serialize(serializers=["dask", "pickle"])
+    return {"proxied-header": header, "obj-pxy-meta": obj._obj_pxy_get_meta()}, frames
+
+
+@distributed.protocol.cuda.cuda_serialize.register(ProxyObject)
+def obj_pxy_cuda_serialize(obj: ProxyObject):
+    """
+    The CUDA serialization of ProxyObject used by Dask when communicating using UCX
+    or another CUDA-friendly communication library. As serializers, it uses "cuda",
+    "dask" or "pickle", which means that proxied CUDA objects are _not_ spilled to
+    main memory.
+    """
+    if obj._obj_pxy["serializers"] is not None:  # Already serialized
+        header, frames = obj._obj_pxy["obj"]
+    else:
+        header, frames = obj._obj_pxy_serialize(serializers=["cuda", "dask", "pickle"])
+    return {"proxied-header": header, "obj-pxy-meta": obj._obj_pxy_get_meta()}, frames
+
+
+@distributed.protocol.dask_deserialize.register(ProxyObject)
+@distributed.protocol.cuda.cuda_deserialize.register(ProxyObject)
+def obj_pxy_dask_deserialize(header, frames):
+    """
+    The generic deserialization of ProxyObject. Notice, it doesn't deserialize
+    the proxied object at this time. When accessed, the proxied object is
+    deserialized using the same serializers that were used when the object was
+    serialized.
+    """
+    meta = header["obj-pxy-meta"]
+    if meta["subclass"] is None:
+        subclass = ProxyObject
+    else:
+        subclass = pickle.loads(meta["subclass"])
+    return subclass(
+        obj=(header["proxied-header"], frames),
+        **header["obj-pxy-meta"],
+    )
+
+
+@dask.dataframe.utils.hash_object_dispatch.register(ProxyObject)
+def obj_pxy_hash_object(obj: ProxyObject, index=True):
+    return dask.dataframe.utils.hash_object_dispatch(obj._obj_pxy_deserialize(), index)
+
+
+@dask.dataframe.utils.group_split_dispatch.register(ProxyObject)
+def obj_pxy_group_split(obj: ProxyObject, c, k, ignore_index=False):
+    return dask.dataframe.utils.group_split_dispatch(
+        obj._obj_pxy_deserialize(), c, k, ignore_index
+    )
+
+
+@dask.dataframe.utils.make_scalar.register(ProxyObject)
+def obj_pxy_make_scalar(obj: ProxyObject):
+    return dask.dataframe.utils.make_scalar(obj._obj_pxy_deserialize())
+
+
+@dask.dataframe.methods.concat_dispatch.register(ProxyObject)
+def obj_pxy_concat(objs, *args, **kwargs):
+    # Deserialize concat inputs (in-place)
+    for i in range(len(objs)):
+        try:
+            objs[i] = objs[i]._obj_pxy_deserialize()
+        except AttributeError:
+            pass
+    return dask.dataframe.methods.concat(objs, *args, **kwargs)
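To summarize the proxy life cycle implemented above, a small self-contained example (NumPy is used so it runs without a GPU):

```python
import numpy as np

from dask_cuda.proxy_object import asproxy, unproxy

pxy = asproxy(np.arange(3), serializers=["dask", "pickle"])

# The proxy mimics the proxied object; attribute access and operators
# trigger just-in-time deserialization.
assert isinstance(pxy, np.ndarray)  # __class__ reports the proxied type
assert pxy.sum() == 3

# asproxy() is idempotent: wrapping a proxy returns the proxy itself.
assert asproxy(pxy) is pxy

# unproxy() returns the underlying object.
assert type(unproxy(pxy)) is np.ndarray
```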
+ """ + meta = header["obj-pxy-meta"] + if meta["subclass"] is None: + subclass = ProxyObject + else: + subclass = pickle.loads(meta["subclass"]) + return subclass( + obj=(header["proxied-header"], frames), + **header["obj-pxy-meta"], + ) + + +@dask.dataframe.utils.hash_object_dispatch.register(ProxyObject) +def obj_pxy_hash_object(obj: ProxyObject, index=True): + return dask.dataframe.utils.hash_object_dispatch(obj._obj_pxy_deserialize(), index) + + +@dask.dataframe.utils.group_split_dispatch.register(ProxyObject) +def obj_pxy_group_split(obj: ProxyObject, c, k, ignore_index=False): + return dask.dataframe.utils.group_split_dispatch( + obj._obj_pxy_deserialize(), c, k, ignore_index + ) + + +@dask.dataframe.utils.make_scalar.register(ProxyObject) +def obj_pxy_make_scalar(obj: ProxyObject): + return dask.dataframe.utils.make_scalar(obj._obj_pxy_deserialize()) + + +@dask.dataframe.methods.concat_dispatch.register(ProxyObject) +def obj_pxy_concat(objs, *args, **kwargs): + # Deserialize concat inputs (in-place) + for i in range(len(objs)): + try: + objs[i] = objs[i]._obj_pxy_deserialize() + except AttributeError: + pass + return dask.dataframe.methods.concat(objs, *args, **kwargs) diff --git a/dask_cuda/tests/test_device_host_file.py b/dask_cuda/tests/test_device_host_file.py index a3327b9c5..2b26320de 100644 --- a/dask_cuda/tests/test_device_host_file.py +++ b/dask_cuda/tests/test_device_host_file.py @@ -4,8 +4,7 @@ import numpy as np import pytest -import dask -from dask import array as da +import dask.array from distributed.protocol import ( deserialize, deserialize_bytes, @@ -23,6 +22,12 @@ cupy = pytest.importorskip("cupy") +def assert_eq(x, y): + # Explicitly calling "cupy.asnumpy" to support `ProxyObject` because + # "cupy" is hardcoded in `dask.array.normalize_to_array()` + return dask.array.assert_eq(cupy.asnumpy(x), cupy.asnumpy(y)) + + def test_device_host_file_config(tmp_path): dhf_disk_path = str(tmp_path / "dask-worker-space" / "storage") with dask.config.set(temporary_directory=str(tmp_path)): @@ -34,13 +39,17 @@ def test_device_host_file_config(tmp_path): @pytest.mark.parametrize("num_host_arrays", [1, 10, 100]) @pytest.mark.parametrize("num_device_arrays", [1, 10, 100]) @pytest.mark.parametrize("array_size_range", [(1, 1000), (100, 100), (1000, 1000)]) +@pytest.mark.parametrize("jit_unspill", [True, False]) def test_device_host_file_short( - tmp_path, num_device_arrays, num_host_arrays, array_size_range + tmp_path, num_device_arrays, num_host_arrays, array_size_range, jit_unspill ): tmpdir = tmp_path / "storage" tmpdir.mkdir() dhf = DeviceHostFile( - device_memory_limit=1024 * 16, memory_limit=1024 * 16, local_directory=tmpdir + device_memory_limit=1024 * 16, + memory_limit=1024 * 16, + local_directory=tmpdir, + jit_unspill=jit_unspill, ) host = [ @@ -64,7 +73,7 @@ def test_device_host_file_short( for k, original in full: acquired = dhf[k] - da.assert_eq(original, acquired) + assert_eq(original, acquired) del dhf[k] assert set(dhf.device.keys()) == set() @@ -72,11 +81,15 @@ def test_device_host_file_short( assert set(dhf.disk.keys()) == set() -def test_device_host_file_step_by_step(tmp_path): +@pytest.mark.parametrize("jit_unspill", [True, False]) +def test_device_host_file_step_by_step(tmp_path, jit_unspill): tmpdir = tmp_path / "storage" tmpdir.mkdir() dhf = DeviceHostFile( - device_memory_limit=1024 * 16, memory_limit=1024 * 16, local_directory=tmpdir + device_memory_limit=1024 * 16, + memory_limit=1024 * 16, + local_directory=tmpdir, + jit_unspill=jit_unspill, ) a = 
@@ -119,17 +132,17 @@ def test_device_host_file_step_by_step(tmp_path):
     assert set(dhf.host.keys()) == set(["a2", "b2"])
     assert set(dhf.disk.keys()) == set(["a1", "b1"])

-    da.assert_eq(dhf["a1"], a)
+    assert_eq(dhf["a1"], a)
     del dhf["a1"]
-    da.assert_eq(dhf["a2"], a)
+    assert_eq(dhf["a2"], a)
     del dhf["a2"]
-    da.assert_eq(dhf["b1"], b)
+    assert_eq(dhf["b1"], b)
     del dhf["b1"]
-    da.assert_eq(dhf["b2"], b)
+    assert_eq(dhf["b2"], b)
     del dhf["b2"]
-    da.assert_eq(dhf["b3"], b)
+    assert_eq(dhf["b3"], b)
     del dhf["b3"]
-    da.assert_eq(dhf["b4"], b)
+    assert_eq(dhf["b4"], b)
     del dhf["b4"]

     assert set(dhf.device.keys()) == set()
@@ -152,7 +165,7 @@ def test_serialize_cupy_collection(collection, length, value):
         assert_func = dd.assert_eq
     else:
         x = cupy.arange(10)
-        assert_func = da.assert_eq
+        assert_func = assert_eq

     if length == 0:
         obj = device_to_host(x)
diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py
new file mode 100644
index 000000000..9649a0cfd
--- /dev/null
+++ b/dask_cuda/tests/test_proxy.py
@@ -0,0 +1,276 @@
+import operator
+
+import pytest
+from pandas.testing import assert_frame_equal
+
+from distributed import Client
+from distributed.protocol.serialize import deserialize, serialize
+
+import dask_cudf
+
+import dask_cuda
+from dask_cuda import proxy_object
+
+
+@pytest.mark.parametrize("serializers", [None, ["dask", "pickle"]])
+def test_proxy_object(serializers):
+    """Check "transparency" of the proxy object"""
+
+    org = list(range(10))
+    pxy = proxy_object.asproxy(org, serializers=serializers)
+
+    assert len(org) == len(pxy)
+    assert org[0] == pxy[0]
+    assert 1 in pxy
+    assert -1 not in pxy
+    assert str(org) == str(pxy)
+    assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy)
+    assert "list at " in repr(pxy)
+
+    pxy._obj_pxy_serialize(serializers=["dask", "pickle"])
+    assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy)
+    assert "list (serialized=['dask', 'pickle'])" in repr(pxy)
+
+    assert org == proxy_object.unproxy(pxy)
+    assert org == proxy_object.unproxy(org)
+
+
+@pytest.mark.parametrize("serializers_first", [None, ["dask", "pickle"]])
+@pytest.mark.parametrize("serializers_second", [None, ["dask", "pickle"]])
+def test_double_proxy_object(serializers_first, serializers_second):
+    """Check asproxy() when creating a proxy object of a proxy object"""
+    org = list(range(10))
+    pxy1 = proxy_object.asproxy(org, serializers=serializers_first)
+    assert pxy1._obj_pxy["serializers"] == serializers_first
+    pxy2 = proxy_object.asproxy(pxy1, serializers=serializers_second)
+    if serializers_second is None:
+        # Check that `serializers=None` doesn't change the initial serializers
+        assert pxy2._obj_pxy["serializers"] == serializers_first
+    else:
+        assert pxy2._obj_pxy["serializers"] == serializers_second
+    assert pxy1 is pxy2
+
+
+@pytest.mark.parametrize("serializers", [None, ["dask", "pickle"]])
+def test_proxy_object_of_numpy(serializers):
+    """Check that a proxied numpy array behaves like a regular numpy array"""
+
+    np = pytest.importorskip("numpy")
+
+    # Make sure that equality works, which we use to test the other operators
+    org = np.arange(10) + 1
+    pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+    assert all(org == pxy)
+    assert all(org + 1 != pxy)
+
+    # Check unary scalar operators
+    for op in [int, float, complex, operator.index, oct, hex]:
+        org = np.int64(42)
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = op(org)
+        got = op(pxy)
+        assert type(expect) == type(got)
+        assert expect == got
+
+    # Check unary operators
+    for op_str in ["neg", "pos", "abs", "inv"]:
+        op = getattr(operator, op_str)
+        org = np.arange(10) + 1
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = op(org)
+        got = op(pxy)
+        assert type(expect) == type(got)
+        assert all(expect == got)
+
+    # Check binary operators that take a scalar as the second argument
+    for op_str in ["rshift", "lshift", "pow"]:
+        op = getattr(operator, op_str)
+        org = np.arange(10) + 1
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = op(org, 2)
+        got = op(pxy, 2)
+        assert type(expect) == type(got)
+        assert all(expect == got)
+
+    # Check binary operators
+    for op_str in [
+        "add",
+        "eq",
+        "floordiv",
+        "ge",
+        "gt",
+        "le",
+        "lshift",
+        "lt",
+        "mod",
+        "mul",
+        "ne",
+        "or_",
+        "sub",
+        "truediv",
+        "xor",
+        "iadd",
+        "ior",
+        "iand",
+        "ifloordiv",
+        "ilshift",
+        "irshift",
+        "ipow",
+        "imod",
+        "imul",
+        "isub",
+        "ixor",
+    ]:
+        op = getattr(operator, op_str)
+        org = np.arange(10) + 1
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = op(org.copy(), org)
+        got = op(org.copy(), pxy)
+        assert isinstance(got, type(expect))
+        assert all(expect == got)
+
+        expect = op(org.copy(), org)
+        got = op(pxy, org)
+        assert isinstance(got, type(expect))
+        assert all(expect == got)
+
+    # Check unary truth operators
+    for op_str in ["not_", "truth"]:
+        op = getattr(operator, op_str)
+        org = np.arange(1) + 1
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = op(org)
+        got = op(pxy)
+        assert type(expect) == type(got)
+        assert expect == got
+
+    # Check reflected methods
+    for op_str in [
+        "__radd__",
+        "__rsub__",
+        "__rmul__",
+        "__rtruediv__",
+        "__rfloordiv__",
+        "__rmod__",
+        "__rpow__",
+        "__rlshift__",
+        "__rrshift__",
+        "__rxor__",
+        "__ror__",
+    ]:
+        org = np.arange(10) + 1
+        pxy = proxy_object.asproxy(org.copy(), serializers=serializers)
+        expect = getattr(org, op_str)(org)
+        got = getattr(org, op_str)(pxy)
+        assert isinstance(got, type(expect))
+        assert all(expect == got)
+
+
+@pytest.mark.parametrize("serializers", [None, ["dask"]])
+def test_proxy_object_of_cudf(serializers):
+    """Check that a proxied cudf dataframe behaves like a regular dataframe"""
+    cudf = pytest.importorskip("cudf")
+    df = cudf.DataFrame({"a": range(10)})
+    pxy = proxy_object.asproxy(df, serializers=serializers)
+    assert_frame_equal(df.to_pandas(), pxy.to_pandas())
+
+
+@pytest.mark.parametrize("proxy_serializers", [None, ["dask"], ["cuda"]])
+@pytest.mark.parametrize("dask_serializers", [["dask"], ["cuda"]])
+def test_serialize_of_proxied_cudf(proxy_serializers, dask_serializers):
+    """Check that we can serialize a proxied cudf dataframe, which might
+    be serialized already.
+ """ + cudf = pytest.importorskip("cudf") + + df = cudf.DataFrame({"a": range(10)}) + pxy = proxy_object.asproxy(df, serializers=proxy_serializers) + header, frames = serialize(pxy, serializers=dask_serializers) + pxy = deserialize(header, frames) + assert_frame_equal(df.to_pandas(), pxy.to_pandas()) + + +@pytest.mark.parametrize("jit_unspill", [True, False]) +def test_spilling_local_cuda_cluster(jit_unspill): + """Testing spilling of a proxied cudf dataframe in a local cuda cluster""" + cudf = pytest.importorskip("cudf") + + def task(x): + assert isinstance(x, cudf.DataFrame) + if jit_unspill: + # Check that `x` is a proxy object and the proxied DataFrame is serialized + assert type(x) is proxy_object.ProxyObject + assert x._obj_pxy_get_meta()["serializers"] == ["dask", "pickle"] + else: + assert type(x) == cudf.DataFrame + assert len(x) == 10 # Trigger deserialization + return x + + # Notice, setting `device_memory_limit=1B` to trigger spilling + with dask_cuda.LocalCUDACluster( + n_workers=1, device_memory_limit="1B", jit_unspill=jit_unspill + ) as cluster: + with Client(cluster): + df = cudf.DataFrame({"a": range(10)}) + ddf = dask_cudf.from_cudf(df, npartitions=1) + ddf = ddf.map_partitions(task, meta=df.head()) + got = ddf.compute() + assert_frame_equal(got.to_pandas(), df.to_pandas()) + + +class _PxyObjTest(proxy_object.ProxyObject): + """ + A class that: + - defines `__dask_tokenize__` in order to avoid deserialization when + calling `client.scatter()` + - Asserts that no deserialization is performaned when communicating. + """ + + def __dask_tokenize__(self): + return 42 + + def _obj_pxy_deserialize(self): + if self.assert_on_deserializing: + assert self._obj_pxy["serializers"] is None + return super()._obj_pxy_deserialize() + + +@pytest.mark.parametrize("send_serializers", [None, ["dask", "pickle"], ["cuda"]]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +def test_communicating_proxy_objects(protocol, send_serializers): + """Testing serialization of cuDF dataframe when communicating""" + cudf = pytest.importorskip("cudf") + + def task(x): + # Check that the subclass survives the trip from client to worker + assert isinstance(x, _PxyObjTest) + serializers_used = list(x._obj_pxy_get_meta()["serializers"]) + + # Check that `x` is serialized with the expected serializers + if protocol == "ucx": + if send_serializers is None: + assert serializers_used == ["cuda", "dask", "pickle"] + else: + assert serializers_used == send_serializers + else: + assert serializers_used == ["dask", "pickle"] + + with dask_cuda.LocalCUDACluster( + n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx" + ) as cluster: + with Client(cluster) as client: + df = cudf.DataFrame({"a": range(10)}) + df = proxy_object.asproxy( + df, serializers=send_serializers, subclass=_PxyObjTest + ) + + # Notice, in one case we expect deserialization when communicating. + # Since "tcp" cannot send device memory directly, it will be re-serialized + # using the default dask serializers that spill the data to main memory. + if protocol == "tcp" and send_serializers == ["cuda"]: + df.assert_on_deserializing = False + else: + df.assert_on_deserializing = True + df = client.scatter(df) + client.submit(task, df).result() + client.shutdown() # Avoids a UCX shutdown error