Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP simplify receiving frames in comms #3599

Closed
wants to merge 10 commits into from
18 changes: 6 additions & 12 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,12 @@ async def read(self, deserializers=None):
lengths = await stream.read_bytes(8 * n_frames)
lengths = struct.unpack("Q" * n_frames, lengths)

frames = []
for length in lengths:
if length:
if self._iostream_has_read_into:
frame = bytearray(length)
n = await stream.read_into(frame)
assert n == length, (n, length)
else:
frame = await stream.read_bytes(length)
else:
frame = b""
frames.append(frame)
frames = [bytearray(each_length) for each_length in lengths]
recv_frames = [each_frame for each_frame in frames if len(each_frame) > 0]
for each_frame in recv_frames:
each_length = len(each_frame)
n = await stream.read_into(each_frame)
assert n == each_length, (n, each_length)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this change has a semantic difference in that previously we used to include empty frames but now we don't. Is this correct?

Is there a specific reason for this change? Does it affect performance in some way, or is it strictly cosmetic? If it's strictly cosmetic, and if there is a change to semantics (however minor) I'm tempted to avoid the change just because I wouldn't be surprised if it has some unforeseen effect.

Copy link
Member Author

@jakirkham jakirkham Mar 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not true actually. We excluded them before as well. It's just now done outside of the for-loop. 🙂

The idea was to more-or-less pass control to Tornado or UCX-Py completely while receiving frames without all of the boilerplate in Dask between each receive. Tried asyncio.gather to push this even further, but Tornado doesn't seem to like that (seems to work with UCX-Py though 😉).

In any event I thought it would be a nice thing to do the same thing we are doing in UCX for TCP. At the end of the day, I don't have strong feelings about it (especially if there is push back). 🤷‍♂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any other thoughts here, @mrocklin? 🙂

except StreamClosedError as e:
self.stream = None
if not shutting_down():
Expand Down
40 changes: 18 additions & 22 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ async def write(
frames = await to_frames(
msg, serializers=serializers, on_error=on_error
)
send_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# Send meta data
await self.ep.send(np.array([len(frames)], dtype=np.uint64))
Expand All @@ -168,6 +171,7 @@ async def write(
await self.ep.send(
np.array([nbytes(f) for f in frames], dtype=np.uint64)
)

# Send frames

# It is necessary to first synchronize the default stream before start sending
Expand All @@ -177,10 +181,9 @@ async def write(
# non-blocking CUDA streams.
synchronize_stream(0)

for frame in frames:
if nbytes(frame) > 0:
await self.ep.send(frame)
return sum(map(nbytes, frames))
for each_frame in send_frames:
await self.ep.send(each_frame)
return sum(map(nbytes, send_frames))
except (ucp.exceptions.UCXBaseException):
self.abort()
raise CommClosedError("While writing, the connection was closed")
Expand All @@ -206,29 +209,22 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
raise CommClosedError("While reading, the connection was closed")
else:
# Recv frames
frames = []
for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()):
if size > 0:
if is_cuda:
frame = cuda_array(size)
else:
frame = np.empty(size, dtype=np.uint8)
frames.append(frame)
else:
if is_cuda:
frames.append(cuda_array(size))
else:
frames.append(b"")
frames = [
cuda_array(each_size)
if is_cuda
else np.empty(each_size, dtype=np.uint8)
for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist())
]
recv_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# It is necessary to first populate `frames` with CUDA arrays and synchronize
# the default stream before starting receiving to ensure buffers have been allocated
synchronize_stream(0)
for i, (is_cuda, size) in enumerate(
zip(is_cudas.tolist(), sizes.tolist())
):
if size > 0:
await self.ep.recv(frames[i])

for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down