[Bugfix] Fix for multinode crash on 4 PP (#6495)
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
andoorve authored Jul 17, 2024
1 parent 5bf35a9 commit 5fa6e98
Showing 2 changed files with 17 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/distributed/test_pipeline_parallel.py
@@ -4,14 +4,12 @@


 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
-    [
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
         (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
         (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
         (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
-        # TODO: figure out why PP=4 tests are flaky
-        # (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        # (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
     ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
     pp_args = [
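For reference, each tuple above becomes its own test case, and the GPU world size a case needs is TP_SIZE * PP_SIZE, so the re-enabled PP=4 rows each require 4 GPUs. A minimal standalone sketch (a hypothetical test, not part of the vLLM suite) of how such a parametrization expands:

import pytest


# Hypothetical sketch: each (tp_size, pp_size) tuple expands into its own
# test case, and the world size it needs is tp_size * pp_size.
@pytest.mark.parametrize(
    "tp_size, pp_size", [
        (2, 2),  # 2-way TP x 2-way PP -> world size 4
        (1, 3),  # PP only, 3 stages   -> world size 3
        (1, 4),  # the PP=4 cases re-enabled here -> world size 4
    ])
def test_world_size(tp_size, pp_size):
    assert tp_size * pp_size in (3, 4)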
14 changes: 14 additions & 0 deletions vllm/executor/ray_gpu_executor.py
@@ -224,13 +224,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []

+        tp_driver_worker_ranks = []
+        non_driver_worker_ranks = []
         for idx, rank in enumerate(worker_ranks[1:]):
             # We need to skip the driver worker, which we
             # do by skipping worker_ranks[0] which is always 0.
             if rank % self.parallel_config.tensor_parallel_size == 0:
                 self.tp_driver_workers.append(self.workers[idx])
+                tp_driver_worker_ranks.append(rank)
             else:
                 self.non_driver_workers.append(self.workers[idx])
+                non_driver_worker_ranks.append(rank)
+
+        # Enforce rank order for correct rank to return final output.
+        self.tp_driver_workers = [
+            worker for _, worker in sorted(
+                zip(tp_driver_worker_ranks, self.tp_driver_workers))
+        ]
+        self.non_driver_workers = [
+            worker for _, worker in sorted(
+                zip(non_driver_worker_ranks, self.non_driver_workers))
+        ]

     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
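The substance of the fix is the rank-ordering step: sorted(zip(ranks, workers)) orders the (rank, worker) pairs by rank, and because ranks are unique the worker objects themselves are never compared. A self-contained sketch of that pattern, with placeholder strings standing in for the RayWorkerWrapper instances:

# Standalone sketch (placeholder data, not vLLM code) of the rank-ordering
# pattern added above: workers are collected in arbitrary placement order,
# and the list is reordered so that position i holds the worker with the
# i-th smallest rank.
ranks = [8, 4, 12]                                              # ranks collected in the loop
workers = ["worker-rank-8", "worker-rank-4", "worker-rank-12"]  # same arbitrary order

# Pair each worker with its rank, sort by rank, then drop the rank again.
# Ranks are unique, so tuple comparison never falls through to the worker.
ordered = [worker for _, worker in sorted(zip(ranks, workers))]

assert ordered == ["worker-rank-4", "worker-rank-8", "worker-rank-12"]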
