From 0d64f3a3c2f72543420b6f2967e8e789ad265a27 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 20 Mar 2020 03:10:06 +0100 Subject: [PATCH] Synchronize default CUDA stream before UCX send/recv (#3598) * Synchronize default CUDA stream before UCX send/recv * Add more clarity on UCX.write comment Co-Authored-By: Mark Harris * Add more clarity on UCX.read comment Co-Authored-By: Mark Harris Co-authored-by: Mark Harris --- distributed/comm/ucx.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 7295b11bb48..04eecdf4482 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,15 @@ cuda_array = None +def synchronize_stream(stream=0): + import numba.cuda + + ctx = numba.cuda.current_context() + cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) + stream = numba.cuda.driver.Stream(ctx, cu_stream, None) + stream.synchronize() + + def init_once(): global ucp, cuda_array if ucp is not None: @@ -160,6 +169,14 @@ async def write( np.array([nbytes(f) for f in frames], dtype=np.uint64) ) # Send frames + + # It is necessary to first synchronize the default stream before sending. + # We synchronize the default stream because UCX is not stream-ordered and + # syncing the default stream will wait for other blocking CUDA streams. + # Note this is only sufficient if the memory being sent is not currently in use on + # non-blocking CUDA streams.
+ synchronize_stream(0) + for frame in frames: if nbytes(frame) > 0: await self.ep.send(frame) @@ -196,13 +213,20 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): frame = cuda_array(size) else: frame = np.empty(size, dtype=np.uint8) - await self.ep.recv(frame) frames.append(frame) else: if is_cuda: frames.append(cuda_array(size)) else: frames.append(b"") + + # It is necessary to first populate `frames` with CUDA arrays and synchronize + # the default stream before starting receiving to ensure buffers have been allocated + synchronize_stream(0) + for i, (is_cuda, size) in enumerate(zip(is_cudas.tolist(), sizes.tolist())): + if size > 0: + await self.ep.recv(frames[i]) + msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers )