diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 0bf785fb521..ceda7f2ac43 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,15 @@ cuda_array = None +def synchronize_stream(stream=0): + import numba.cuda + + ctx = numba.cuda.current_context() + cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) + stream = numba.cuda.driver.Stream(ctx, cu_stream, None) + stream.synchronize() + + def init_once(): global ucp, cuda_array if ucp is not None: @@ -164,6 +173,14 @@ async def write( ) # Send frames + + # It is necessary to first synchronize the default stream before start sending + # We synchronize the default stream because UCX is not stream-ordered and + # syncing the default stream will wait for other non-blocking CUDA streams. + # Note this is only sufficient if the memory being sent is not currently in use on + # non-blocking CUDA streams. + synchronize_stream(0) + for frame in send_frames: await self.ep.send(frame) return sum(map(nbytes, frames)) @@ -203,6 +220,14 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): ] for each_frame in recv_frames: await self.ep.recv(each_frame) + + # It is necessary to first populate `frames` with CUDA arrays and synchronize + # the default stream before starting receiving to ensure buffers have been allocated + synchronize_stream(0) + + for each_frame in recv_frames: + await self.ep.recv(each_frame) + msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers )