Fix disagg hang caused by the prefill and decode communication issues
houseroad committed Feb 4, 2025
1 parent 5d98d56 commit 8a54ef7
Showing 1 changed file with 32 additions and 37 deletions.
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py (32 additions, 37 deletions)
@@ -43,7 +43,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
 
         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh
-        self.buffer_lock = threading.Lock()
+        self.buffer_cv = threading.Condition()
         self.signal_pipe = signal_pipe
         self.data_pipe = data_pipe
         self.request_handling_thread: Optional[threading.Thread] = None
@@ -116,11 +116,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
             hidden = hidden.clone()
 
         buffer_item = [input_tokens, roi, key, value, hidden]
+        data_size = sum([self._get_element_size(data) for data in buffer_item])
 
-        with self.buffer_lock:
-            for data in buffer_item:
-                self.buffer_size += self._get_element_size(data)
+        with self.buffer_cv:
+            if self.buffer_size + data_size > self.buffer_size_threshold:
+                # log outside the while loop to avoid this message being logged
+                # repeatedly.
+                logger.debug("KV transfer buffer is full. Handling...")
+                while self.buffer_size + data_size > self.buffer_size_threshold:
+                    self.buffer_cv.wait()
+
+            self.buffer_size += data_size
             self.buffer.append(buffer_item)
+            self.buffer_cv.notify()
 
     def _is_end_signal(self, signal):
         return signal is None
@@ -143,35 +151,29 @@ def drop_select_handler(self):
                 roi = (roi > 0.5)
                 tokens_roi_recver = [input_tokens, roi]
 
-                matched_length = 0
-
-                # perform input tokens and roi matching
-                # FIXME: this matching is O(n), ideally it should be O(1)
-                # but this buffer size won't (and shouldn't) be too large so
-                # the fix is not urgent.
-                with self.buffer_lock:
-
+                def is_buffer_available(
+                        tokens_roi_recver: List[torch.Tensor], ) -> bool:
+                    # perform input tokens and roi matching
+                    # FIXME: this matching is O(n), ideally it should be O(1)
+                    # but this buffer size won't (and shouldn't) be too large so
+                    # the fix is not urgent.
                     for _ in range(len(self.buffer)):
-
-                        temp_length = self._matches(self.buffer[0],
-                                                    tokens_roi_recver)
-                        if temp_length > 0:
-                            matched_length = temp_length
-                            break
+                        if self._matches(self.buffer[0],
+                                         tokens_roi_recver) > 0:
+                            return True
                         # rotate the element we just accessed to the end
                         self.buffer.rotate(-1)
+                    return False
 
-                    if matched_length > 0:
-                        # need to clone the tensor
-                        # in case the tensor is freed before sending finishes
-                        matched_item = self.buffer.popleft()
-                        for tensor in matched_item:
-                            self._send_tensor_and_dec_size(tensor)
-
-                    else:
-                        # no match, just send None
-                        for _ in range(5):
-                            self.data_pipe.send_tensor(None)
+                with self.buffer_cv:
+                    while not is_buffer_available(tokens_roi_recver):
+                        self.buffer_cv.wait()
+                    # need to clone the tensor
+                    # in case the tensor is freed before sending finishes
+                    matched_item = self.buffer.popleft()
+                    for tensor in matched_item:
+                        self._send_tensor_and_dec_size(tensor)
+                    self.buffer_cv.notify()
 
         except RuntimeError as e:
             if 'Connection closed by peer' not in str(e):
@@ -215,13 +217,6 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                hidden: torch.Tensor) -> None:
 
-        if self.buffer_size > self.buffer_size_threshold:
-            # log outside the while loop to avoid this message being logged
-            # repeatedly.
-            logger.debug("KV transfer buffer is full. Handling...")
-            while self.buffer_size > self.buffer_size_threshold:
-                self.full_handler()
-
         self._add_to_buffer(input_tokens, roi, key, value, hidden)
 
         # when calling the insert, the current process is a sender
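The change replaces a plain threading.Lock plus busy-wait (the removed full_handler() loop in insert) with a single threading.Condition shared by the producer path (_add_to_buffer) and the consumer path (drop_select_handler): the producer waits until the buffer has room, the consumer waits until a matching item exists, and each side notifies the other after it mutates the buffer, so neither can spin or miss an update and hang. The snippet below is a minimal standalone sketch of that producer/consumer condition-variable pattern, not the vLLM class itself; the names BoundedBuffer, put, and get_matching are illustrative only.

import threading
from collections import deque
from typing import Callable, Deque, Tuple


class BoundedBuffer:
    """Minimal sketch of the condition-variable pattern used by the fix.

    Producers block while the buffer would exceed its size threshold;
    consumers block until a matching item is present. Both sides notify
    after changing the buffer, so neither waits on a stale state.
    """

    def __init__(self, size_threshold: int) -> None:
        self.buffer: Deque[Tuple[int, object]] = deque()  # (size, payload)
        self.buffer_size = 0
        self.size_threshold = size_threshold
        self.cv = threading.Condition()

    def put(self, size: int, payload: object) -> None:
        with self.cv:
            # Producer side: wait until this item fits under the threshold.
            while self.buffer_size + size > self.size_threshold:
                self.cv.wait()
            self.buffer.append((size, payload))
            self.buffer_size += size
            self.cv.notify()  # wake a consumer waiting for data

    def get_matching(self, match: Callable[[object], bool]) -> object:
        with self.cv:
            # Consumer side: wait until some buffered item satisfies match().
            while True:
                for _ in range(len(self.buffer)):
                    if match(self.buffer[0][1]):
                        size, payload = self.buffer.popleft()
                        self.buffer_size -= size
                        self.cv.notify()  # wake a producer waiting for room
                        return payload
                    # rotate the inspected item to the end, mirroring the patch
                    self.buffer.rotate(-1)
                self.cv.wait()


if __name__ == "__main__":
    buf = BoundedBuffer(size_threshold=4)

    def producer() -> None:
        for i in range(8):
            buf.put(1, i)

    def consumer() -> None:
        for i in range(8):
            got = buf.get_matching(lambda p, want=i: p == want)
            assert got == i

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print("transferred 8 items without hanging")

Compared with polling, the wait()/notify() pairing releases the lock while blocked and wakes the peer exactly when the buffer state it cares about has changed, which is what removes the hang between the prefill (sender) and decode (receiver) sides.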
