diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 01d8df47d59..c9ff59568e7 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,15 @@ cuda_array = None +def synchronize_stream(stream=0): + import numba.cuda + + ctx = numba.cuda.current_context() + cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) + stream = numba.cuda.driver.Stream(ctx, cu_stream, None) + stream.synchronize() + + def init_once(): global ucp, cuda_array if ucp is not None: @@ -164,6 +173,14 @@ async def write( ) # Send frames + + # It is necessary to first synchronize the default stream before sending. + # We synchronize the default stream because UCX is not stream-ordered and + # syncing the default stream will wait for other blocking CUDA streams. + # Note this is only sufficient if the memory being sent is not currently in use on + # non-blocking CUDA streams. + synchronize_stream(0) + + for each_frame in send_frames: await self.ep.send(each_frame) return sum(map(nbytes, send_frames)) @@ -201,6 +218,11 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): recv_frames = [ each_frame for each_frame in frames if len(each_frame) > 0 ] + + # It is necessary to first populate `frames` with CUDA arrays and synchronize + # the default stream before receiving to ensure buffers have been allocated + synchronize_stream(0) + + for each_frame in recv_frames: await self.ep.recv(each_frame) msg = await from_frames(