Skip to content

Commit

Permalink
Filter out non-trivial frames to transmit
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Mar 20, 2020
1 parent e5b4d3d commit 2ae3427
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
8 changes: 4 additions & 4 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ async def read(self, deserializers=None):
lengths = struct.unpack("Q" * n_frames, lengths)

frames = [bytearray(each_length) for each_length in lengths]
for each_frame in frames:
recv_frames = [each_frame for each_frame in frames if len(each_frame) > 0]
for each_frame in recv_frames:
each_length = len(each_frame)
if each_length:
n = await stream.read_into(each_frame)
assert n == each_length, (n, each_length)
n = await stream.read_into(each_frame)
assert n == each_length, (n, each_length)
except StreamClosedError as e:
self.stream = None
if not shutting_down():
Expand Down
17 changes: 11 additions & 6 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ async def write(
frames = await to_frames(
msg, serializers=serializers, on_error=on_error
)
send_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# Send meta data
await self.ep.send(np.array([len(frames)], dtype=np.uint64))
Expand All @@ -159,10 +162,10 @@ async def write(
await self.ep.send(
np.array([nbytes(f) for f in frames], dtype=np.uint64)
)

# Send frames
for frame in frames:
if nbytes(frame) > 0:
await self.ep.send(frame)
for frame in send_frames:
await self.ep.send(frame)
return sum(map(nbytes, frames))
except (ucp.exceptions.UCXBaseException):
self.abort()
Expand Down Expand Up @@ -195,9 +198,11 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
else np.empty(each_size, dtype=np.uint8)
for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist())
]
for each_frame in frames:
if len(each_frame) > 0:
await self.ep.recv(each_frame)
recv_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]
for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down

0 comments on commit 2ae3427

Please sign in to comment.