Skip to content

Commit

Permalink
Make proxy UCX tests async
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jan 16, 2023
1 parent 10b73ac commit 4ba0d30
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from dask.sizeof import sizeof
from distributed import Client
from distributed.protocol.serialize import deserialize, serialize
from distributed.utils_test import gen_test

import dask_cuda
from dask_cuda import proxy_object
from dask_cuda import LocalCUDACluster, proxy_object
from dask_cuda.disk_io import SpillToDiskFile
from dask_cuda.proxify_device_objects import proxify_device_objects
from dask_cuda.proxify_host_file import ProxifyHostFile
Expand Down Expand Up @@ -299,8 +300,10 @@ def task(x):
return x

# Notice, setting `device_memory_limit=1B` to trigger spilling
with dask_cuda.LocalCUDACluster(
n_workers=1, device_memory_limit="1B", jit_unspill=jit_unspill
with LocalCUDACluster(
n_workers=1,
device_memory_limit="1B",
jit_unspill=jit_unspill,
) as cluster:
with Client(cluster):
df = cudf.DataFrame({"a": range(10)})
Expand Down Expand Up @@ -395,11 +398,12 @@ def _pxy_deserialize(self):

@pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
def test_communicating_proxy_objects(protocol, send_serializers):
@gen_test(timeout=20)
async def test_communicating_proxy_objects(protocol, send_serializers):
"""Testing serialization of cuDF dataframe when communicating"""
cudf = pytest.importorskip("cudf")

def task(x):
async def task(x):
# Check that the subclass survives the trip from client to worker
assert isinstance(x, _PxyObjTest)
serializers_used = x._pxy_get().serializer
Expand All @@ -413,10 +417,13 @@ def task(x):
else:
assert serializers_used == "dask"

with dask_cuda.LocalCUDACluster(
n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
with Client(cluster) as client:
async with Client(cluster, asynchronous=True) as client:
df = cudf.DataFrame({"a": range(10)})
df = proxy_object.asproxy(
df, serializers=send_serializers, subclass=_PxyObjTest
Expand All @@ -429,19 +436,20 @@ def task(x):
df._pxy_get().assert_on_deserializing = False
else:
df._pxy_get().assert_on_deserializing = True
df = client.scatter(df)
client.submit(task, df).result()
client.shutdown() # Avoids a UCX shutdown error
df = await client.scatter(df)
await client.submit(task, df)
await client.shutdown() # Avoids a UCX shutdown error


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("shared_fs", [True, False])
def test_communicating_disk_objects(protocol, shared_fs):
@gen_test(timeout=20)
async def test_communicating_disk_objects(protocol, shared_fs):
"""Testing disk serialization of cuDF dataframe when communicating"""
cudf = pytest.importorskip("cudf")
ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs

def task(x):
async def task(x):
# Check that the subclass survives the trip from client to worker
assert isinstance(x, _PxyObjTest)
serializer_used = x._pxy_get().serializer
Expand All @@ -450,16 +458,19 @@ def task(x):
else:
assert serializer_used == "dask"

with dask_cuda.LocalCUDACluster(
n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
with Client(cluster) as client:
async with Client(cluster, asynchronous=True) as client:
df = cudf.DataFrame({"a": range(10)})
df = proxy_object.asproxy(df, serializers=("disk",), subclass=_PxyObjTest)
df._pxy_get().assert_on_deserializing = False
df = client.scatter(df)
client.submit(task, df).result()
client.shutdown() # Avoids a UCX shutdown error
df = await client.scatter(df)
await client.submit(task, df)
await client.shutdown() # Avoids a UCX shutdown error


@pytest.mark.parametrize("array_module", ["numpy", "cupy"])
Expand Down

0 comments on commit 4ba0d30

Please sign in to comment.