diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 7b7475a77167c..5e824b0f5a65a 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -4,14 +4,12 @@
 
 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
-    [
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
         (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
         (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
         (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
-        # TODO: figure out why PP=4 tests are flaky
-        # (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        # (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
     ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
                     MODEL_NAME):
     pp_args = [
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 388f934ef75a6..edff9b6c93e09 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -224,13 +224,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []
 
+        tp_driver_worker_ranks = []
+        non_driver_worker_ranks = []
         for idx, rank in enumerate(worker_ranks[1:]):
             # We need to skip the driver worker, which we
             # do by skipping worker_ranks[0] which is always 0.
             if rank % self.parallel_config.tensor_parallel_size == 0:
                 self.tp_driver_workers.append(self.workers[idx])
+                tp_driver_worker_ranks.append(rank)
             else:
                 self.non_driver_workers.append(self.workers[idx])
+                non_driver_worker_ranks.append(rank)
+
+        # Enforce rank order for correct rank to return final output.
+        self.tp_driver_workers = [
+            worker for _, worker in sorted(
+                zip(tp_driver_worker_ranks, self.tp_driver_workers))
+        ]
+        self.non_driver_workers = [
+            worker for _, worker in sorted(
+                zip(non_driver_worker_ranks, self.non_driver_workers))
+        ]
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
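
For reference, the `sorted(zip(...))` idiom added above is what guarantees that `self.tp_driver_workers` and `self.non_driver_workers` end up ordered by rank, so that the correct rank returns the final output. Below is a minimal standalone sketch of that idiom, with a hypothetical `Worker` class standing in for `RayWorkerWrapper` and made-up rank values:

```python
class Worker:
    """Hypothetical stand-in for vLLM's RayWorkerWrapper."""

    def __init__(self, rank: int):
        self.rank = rank

    def __repr__(self) -> str:
        return f"Worker(rank={self.rank})"


# Suppose Ray handed back the workers out of rank order.
ranks = [8, 0, 4]
workers = [Worker(r) for r in ranks]

# sorted() compares the (rank, worker) tuples element-wise; because the
# ranks are unique, the comparison is decided by the rank alone and the
# Worker objects (which define no ordering) are never compared.
ordered = [worker for _, worker in sorted(zip(ranks, workers))]
print(ordered)  # [Worker(rank=0), Worker(rank=4), Worker(rank=8)]
```

Note this only works because ranks never collide: with duplicate ranks, `sorted` would fall through to comparing the `Worker` objects and raise a `TypeError`, which passing `key=lambda pair: pair[0]` would avoid.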