From 704aa6079645ac9a5296566fed2f842d12f19c3d Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 16 Aug 2024 21:17:32 +0000 Subject: [PATCH 01/18] Support vLLM single and multi-host TPUs on GKE --- requirements-tpu.txt | 2 +- vllm/attention/backends/pallas.py | 5 +++- .../device_communicators/tpu_communicator.py | 27 ++++++++++++++++++- vllm/executor/ray_tpu_executor.py | 12 +++++++++ vllm/executor/ray_utils.py | 4 +++ 5 files changed, 47 insertions(+), 3 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 5eb27b39eb623..7624010371e81 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,4 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. -ray +ray[default,serve] diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 4ecf698c8d514..adf5bf812a1d3 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -118,7 +118,10 @@ def __init__( raise NotImplementedError("TPU version must be 4 or higher.") self.megacore_mode = None - tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE") + tpu_type = tpu_type.lower() + if "lite" not in tpu_type: if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 81a141e86206a..7830afb80a910 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,3 +1,4 @@ +import os import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -24,9 +25,33 @@ def __init__(self, group: ProcessGroup): # be simply calculated as follows. global_rank = dist.get_rank(group) global_world_size = dist.get_world_size(group) - num_nodes = len(ray.nodes()) + + # Calculate how many TPU nodes are in the current placement group. + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + + print(f"current pg: {current_pg.id.hex()}") + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg['bundles_to_node_id'].items(): + nodes_in_pg.add(node) + print(f"pg nodes: {nodes_in_pg}") + num_nodes = len(nodes_in_pg) + local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size + + print(f"global_rank: {global_rank}") + print(f"global_world_size: {global_world_size}") + print(f"num_nodes: {num_nodes}") + print(f"local_world_size: {local_world_size}") + print(f"local_rank: {local_rank}") + + # Ensure environment variables are set for multihost deployments. 
+ os.environ['CLOUD_TPU_TASK_ID'] = str(global_rank) + os.environ['TPU_VISIBLE_CHIPS'] = str(local_rank) + pjrt.initialize_multiprocess(local_rank, local_world_size) xr._init_world_size_ordinal() diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7048d47980723..4d593f8cb3966 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -70,6 +70,16 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_module_name = "vllm.worker.tpu_worker" worker_class_name = "TPUWorker" + # GKE does not fetch environment information from metadata server + # and instead sets these from within the Ray process. Therefore we + # need to override the Ray environment variables manually. + override_env = {} + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_CHIPS_PER_HOST_BOUNDS": os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]}) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update({"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + worker = ray.remote( num_cpus=0, resources={"TPU": 1}, @@ -80,6 +90,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) + if override_env: + worker.override_env_vars.remote(override_env) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index ab283467d4783..b92bb03e2402f 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple, Union from vllm.config import ParallelConfig @@ -63,6 +64,9 @@ def execute_model_spmd( return execute_model_req, output return output + def override_env_vars(self, vars): + os.environ.update(vars) + ray_import_err = None except ImportError as e: From 98bf7238bd84560b7af715387fdaf185303fce20 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 16 Aug 2024 21:25:45 +0000 Subject: [PATCH 02/18] fix format --- vllm/executor/ray_tpu_executor.py | 152 +++++++++++++++++++----------- 1 file changed, 99 insertions(+), 53 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 4d593f8cb3966..7fed792b4d2cb 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -2,8 +2,16 @@ import os from collections import defaultdict from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, - Union) +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, +) import vllm.envs as envs from vllm.executor.executor_base import ExecutorAsyncBase @@ -11,8 +19,13 @@ from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import ( + get_distributed_init_method, + get_ip, + get_open_port, + get_vllm_instance_id, + make_async, +) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -24,7 +37,6 @@ class RayTPUExecutor(TPUExecutor): - def __init__(self, *args, **kwargs): # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. 
@@ -47,8 +59,9 @@ def _init_executor(self) -> None: # Create the parallel TPU workers. self._init_workers_ray(placement_group) - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray( + self, placement_group: "PlacementGroup", **ray_remote_kwargs + ): # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None @@ -76,9 +89,16 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", override_env = {} if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: override_env.update( - {"TPU_CHIPS_PER_HOST_BOUNDS": os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]}) + { + "TPU_CHIPS_PER_HOST_BOUNDS": os.environ[ + "TPU_CHIPS_PER_HOST_BOUNDS" + ] + } + ) if "TPU_HOST_BOUNDS" in os.environ: - override_env.update({"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]} + ) worker = ray.remote( num_cpus=0, @@ -111,11 +131,13 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", raise ValueError( "Ray does not allocate any TPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " - "TPU node.") + "TPU node." + ) # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = self._run_workers( + "get_node_and_gpu_ids", use_dummy_driver=True + ) node_workers = defaultdict(list) for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): @@ -124,14 +146,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", VLLM_INSTANCE_ID = get_vllm_instance_id() # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - }, ) for _ in worker_node_and_gpu_ids] - self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) + all_args_to_update_environment_variables = [ + ( + { + "VLLM_INSTANCE_ID": VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), + }, + ) + for _ in worker_node_and_gpu_ids + ] + self._run_workers( + "update_environment_variables", + all_args=all_args_to_update_environment_variables, + ) if len(node_workers) == 1: # in single node case, we don't need to get the IP address. @@ -144,7 +171,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # the node. driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port() + ) # Initialize the actual workers inside worker wrapper. init_worker_all_kwargs = [ @@ -152,26 +180,28 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank=node_workers[node_id].index(rank), rank=rank, distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ) + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. 
- max_parallel_loading_workers) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + ) def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None + self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ - return self.driver_worker.execute_method("execute_model", - execute_model_req) + return self.driver_worker.execute_method( + "execute_model", execute_model_req + ) def _run_workers( self, @@ -199,19 +229,27 @@ def _run_workers( if max_concurrent_workers: raise NotImplementedError( - "max_concurrent_workers is not supported yet.") + "max_concurrent_workers is not supported yet." + ) count = len(self.workers) - all_worker_args = repeat(args, count) if all_args is None \ + all_worker_args = ( + repeat(args, count) + if all_args is None else islice(all_args, 1, None) - all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + ) + all_worker_kwargs = ( + repeat(kwargs, count) + if all_kwargs is None else islice(all_kwargs, 1, None) + ) # Start the ray workers first. ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) + for (worker, worker_args, worker_kwargs) in zip( + self.workers, all_worker_args, all_worker_kwargs + ) ] if async_run_remote_workers_only: @@ -224,12 +262,15 @@ def _run_workers( # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) + method, *driver_args, **driver_kwargs + ) else: assert self.driver_dummy_worker is not None driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + method, *driver_args, **driver_kwargs + ) + ) # Get the results of the ray workers. 
if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) @@ -242,20 +283,26 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: ray.get(parallel_worker_tasks) def determine_num_available_blocks(self) -> Tuple[int, int]: - num_blocks = self._run_workers("determine_num_available_blocks", ) + num_blocks = self._run_workers( + "determine_num_available_blocks", + ) num_tpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) return num_tpu_blocks, num_cpu_blocks - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) + def initialize_cache( + self, num_gpu_blocks: int, num_cpu_blocks: int + ) -> None: + logger.info( + "# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._run_workers( + "initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + ) def execute_model( self, @@ -265,7 +312,8 @@ def execute_model( self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", async_run_remote_workers_only=True, - **self.extra_execute_model_run_workers_kwargs) + **self.extra_execute_model_run_workers_kwargs, + ) # Only the driver worker returns the sampling results. return self._driver_execute_model(execute_model_req) @@ -283,18 +331,18 @@ def stop_remote_worker_execution_loop(self) -> None: class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: # Start model execution loop running in the parallel workers self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop()) + self._start_worker_execution_loop() + ) # Only the driver worker returns the sampling results. 
return await self._driver_execute_model_async(execute_model_req) @@ -311,11 +359,9 @@ async def stop_remote_worker_execution_loop_async(self) -> None: await parallel_worker_tasks async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None + self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", - execute_model_req) + return await self.driver_exec_method("execute_model", execute_model_req) async def _start_worker_execution_loop(self): coros = [ From 44683c9db581e03167958c1ca433ae43910c81fd Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 16 Aug 2024 21:28:18 +0000 Subject: [PATCH 03/18] fix import order --- vllm/executor/ray_tpu_executor.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7fed792b4d2cb..68cf830e454d6 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -2,16 +2,8 @@ import os from collections import defaultdict from itertools import islice, repeat -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Tuple, - Union, -) +from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, + Union) import vllm.envs as envs from vllm.executor.executor_base import ExecutorAsyncBase @@ -19,13 +11,8 @@ from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import ( - get_distributed_init_method, - get_ip, - get_open_port, - get_vllm_instance_id, - make_async, -) +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy From 726490ed8500bcd117e936f2aaa1031168782fb1 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 16 Aug 2024 21:31:04 +0000 Subject: [PATCH 04/18] fix import order --- vllm/distributed/device_communicators/tpu_communicator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 7830afb80a910..4816d3f2fb85a 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,4 +1,5 @@ import os + import torch import torch.distributed as dist from torch.distributed import ProcessGroup From 1930cc9cb71e7245445ce98c880d6f0b18f11156 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 16 Aug 2024 21:37:59 +0000 Subject: [PATCH 05/18] format --- vllm/executor/ray_tpu_executor.py | 117 ++++++++++++------------------ 1 file changed, 47 insertions(+), 70 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 68cf830e454d6..93c6bf607cc3c 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -24,6 +24,7 @@ class RayTPUExecutor(TPUExecutor): + def __init__(self, *args, **kwargs): # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -46,9 +47,8 @@ def _init_executor(self) -> None: # Create the parallel TPU workers. 
self._init_workers_ray(placement_group) - def _init_workers_ray( - self, placement_group: "PlacementGroup", **ray_remote_kwargs - ): + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None @@ -75,17 +75,13 @@ def _init_workers_ray( # need to override the Ray environment variables manually. override_env = {} if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: - override_env.update( - { - "TPU_CHIPS_PER_HOST_BOUNDS": os.environ[ - "TPU_CHIPS_PER_HOST_BOUNDS" - ] - } - ) + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) if "TPU_HOST_BOUNDS" in os.environ: override_env.update( - {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]} - ) + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) worker = ray.remote( num_cpus=0, @@ -118,13 +114,11 @@ def _init_workers_ray( raise ValueError( "Ray does not allocate any TPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " - "TPU node." - ) + "TPU node.") # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers( - "get_node_and_gpu_ids", use_dummy_driver=True - ) + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) node_workers = defaultdict(list) for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): @@ -133,15 +127,12 @@ def _init_workers_ray( VLLM_INSTANCE_ID = get_vllm_instance_id() # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [ - ( - { - "VLLM_INSTANCE_ID": VLLM_INSTANCE_ID, - "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), - }, - ) - for _ in worker_node_and_gpu_ids - ] + all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for _ in worker_node_and_gpu_ids] self._run_workers( "update_environment_variables", all_args=all_args_to_update_environment_variables, @@ -158,8 +149,7 @@ def _init_workers_ray( # the node. driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port() - ) + driver_ip, get_open_port()) # Initialize the actual workers inside worker wrapper. init_worker_all_kwargs = [ @@ -167,28 +157,28 @@ def _init_workers_ray( local_rank=node_workers[node_id].index(rank), rank=rank, distributed_init_method=distributed_init_method, - ) - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") self._run_workers( "load_model", - max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, ) def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. 
""" - return self.driver_worker.execute_method( - "execute_model", execute_model_req - ) + return self.driver_worker.execute_method("execute_model", + execute_model_req) def _run_workers( self, @@ -216,27 +206,19 @@ def _run_workers( if max_concurrent_workers: raise NotImplementedError( - "max_concurrent_workers is not supported yet." - ) + "max_concurrent_workers is not supported yet.") count = len(self.workers) - all_worker_args = ( - repeat(args, count) - if all_args is None - else islice(all_args, 1, None) - ) - all_worker_kwargs = ( - repeat(kwargs, count) - if all_kwargs is None - else islice(all_kwargs, 1, None) - ) + all_worker_args = (repeat(args, count) if all_args is None else islice( + all_args, 1, None)) + all_worker_kwargs = (repeat(kwargs, count) if all_kwargs is None else + islice(all_kwargs, 1, None)) # Start the ray workers first. ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs) in zip( - self.workers, all_worker_args, all_worker_kwargs - ) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] if async_run_remote_workers_only: @@ -249,15 +231,12 @@ def _run_workers( # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs - ) + method, *driver_args, **driver_kwargs) else: assert self.driver_dummy_worker is not None driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs - ) - ) + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) @@ -270,19 +249,15 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: ray.get(parallel_worker_tasks) def determine_num_available_blocks(self) -> Tuple[int, int]: - num_blocks = self._run_workers( - "determine_num_available_blocks", - ) + num_blocks = self._run_workers("determine_num_available_blocks", ) num_tpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) return num_tpu_blocks, num_cpu_blocks - def initialize_cache( - self, num_gpu_blocks: int, num_cpu_blocks: int - ) -> None: - logger.info( - "# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks - ) + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks self._run_workers( @@ -318,18 +293,18 @@ def stop_remote_worker_execution_loop(self) -> None: class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: # Start model execution loop running in the parallel workers self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop() - ) + self._start_worker_execution_loop()) # Only the driver worker returns the sampling results. 
return await self._driver_execute_model_async(execute_model_req) @@ -346,9 +321,11 @@ async def stop_remote_worker_execution_loop_async(self) -> None: await parallel_worker_tasks async def _driver_execute_model_async( - self, execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", execute_model_req) + return await self.driver_exec_method("execute_model", + execute_model_req) async def _start_worker_execution_loop(self): coros = [ From b0225afd3c86bf67096f708ecbefa8d8c4f7485b Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Tue, 20 Aug 2024 16:57:06 +0000 Subject: [PATCH 06/18] remove prints --- vllm/distributed/device_communicators/tpu_communicator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 4816d3f2fb85a..1558753f38337 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -31,24 +31,16 @@ def __init__(self, group: ProcessGroup): pg_table = ray.util.placement_group_table() current_pg = ray.util.get_current_placement_group() - print(f"current pg: {current_pg.id.hex()}") nodes_in_pg = set() for pg_key, pg in pg_table.items(): if pg_key == current_pg.id.hex(): for _, node in pg['bundles_to_node_id'].items(): nodes_in_pg.add(node) - print(f"pg nodes: {nodes_in_pg}") num_nodes = len(nodes_in_pg) local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size - print(f"global_rank: {global_rank}") - print(f"global_world_size: {global_world_size}") - print(f"num_nodes: {num_nodes}") - print(f"local_world_size: {local_world_size}") - print(f"local_rank: {local_rank}") - # Ensure environment variables are set for multihost deployments. os.environ['CLOUD_TPU_TASK_ID'] = str(global_rank) os.environ['TPU_VISIBLE_CHIPS'] = str(local_rank) From 852f9fb84bc0a1b421dc78e701ad789aa970a334 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 21 Aug 2024 01:19:17 +0000 Subject: [PATCH 07/18] fix bug with no placement group --- .../device_communicators/tpu_communicator.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 1558753f38337..83a3447a63c60 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -27,16 +27,20 @@ def __init__(self, group: ProcessGroup): global_rank = dist.get_rank(group) global_world_size = dist.get_world_size(group) - # Calculate how many TPU nodes are in the current placement group. + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of nodes in the Ray cluster.. 
+ num_nodes = len(ray.nodes()) pg_table = ray.util.placement_group_table() current_pg = ray.util.get_current_placement_group() - nodes_in_pg = set() - for pg_key, pg in pg_table.items(): - if pg_key == current_pg.id.hex(): - for _, node in pg['bundles_to_node_id'].items(): - nodes_in_pg.add(node) - num_nodes = len(nodes_in_pg) + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg['bundles_to_node_id'].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size From dcb6095e73fb765ad4e1b034b26804af4207bbdd Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 21 Aug 2024 04:20:47 +0000 Subject: [PATCH 08/18] calculate tpu resources --- .../device_communicators/tpu_communicator.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 83a3447a63c60..aad308363e833 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -8,6 +8,7 @@ if current_platform.is_tpu(): import ray + from ray._private.accelerators import TPUAcceleratorManager import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt @@ -29,8 +30,15 @@ def __init__(self, group: ProcessGroup): # Calculate how many TPU nodes are in the current deployment. This # is the Ray placement group if it is deployed with Ray. Default - # to the number of nodes in the Ray cluster.. - num_nodes = len(ray.nodes()) + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + num_nodes = total_tpus // tpus_per_node + pg_table = ray.util.placement_group_table() current_pg = ray.util.get_current_placement_group() From 00fb272551611278a1c21be0dbbbd1727d1c8136 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 21 Aug 2024 04:33:06 +0000 Subject: [PATCH 09/18] ruff --- .../device_communicators/tpu_communicator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index aad308363e833..aeb9f099062c7 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -15,7 +15,6 @@ class TpuCommunicator: - def __init__(self, group: ProcessGroup): if not current_platform.is_tpu(): self.disabled = True @@ -36,7 +35,9 @@ def __init__(self, group: ProcessGroup): # with both CPUs and TPUs. 
cluster_resources = ray.cluster_resources() total_tpus = int(cluster_resources["TPU"]) - tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + tpus_per_node = ( + TPUAcceleratorManager.get_current_node_num_accelerators() + ) num_nodes = total_tpus // tpus_per_node pg_table = ray.util.placement_group_table() @@ -46,7 +47,7 @@ def __init__(self, group: ProcessGroup): nodes_in_pg = set() for pg_key, pg in pg_table.items(): if pg_key == current_pg.id.hex(): - for _, node in pg['bundles_to_node_id'].items(): + for _, node in pg["bundles_to_node_id"].items(): nodes_in_pg.add(node) num_nodes = len(nodes_in_pg) @@ -54,8 +55,8 @@ def __init__(self, group: ProcessGroup): local_rank = global_rank % local_world_size # Ensure environment variables are set for multihost deployments. - os.environ['CLOUD_TPU_TASK_ID'] = str(global_rank) - os.environ['TPU_VISIBLE_CHIPS'] = str(local_rank) + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) pjrt.initialize_multiprocess(local_rank, local_world_size) xr._init_world_size_ordinal() From b7774b3abff956639307b55b1f693d124da1021b Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 21 Aug 2024 04:36:12 +0000 Subject: [PATCH 10/18] isort --- vllm/distributed/device_communicators/tpu_communicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index aeb9f099062c7..3328557c948be 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -8,9 +8,9 @@ if current_platform.is_tpu(): import ray - from ray._private.accelerators import TPUAcceleratorManager import torch_xla.core.xla_model as xm import torch_xla.runtime as xr + from ray._private.accelerators import TPUAcceleratorManager from torch_xla._internal import pjrt From da66cf9f7c6168a9bae587a0b3fb875f4adcc4d3 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 21 Aug 2024 04:40:10 +0000 Subject: [PATCH 11/18] yapf -i --- vllm/distributed/device_communicators/tpu_communicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 3328557c948be..431b5af870d04 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -15,6 +15,7 @@ class TpuCommunicator: + def __init__(self, group: ProcessGroup): if not current_platform.is_tpu(): self.disabled = True @@ -36,8 +37,7 @@ def __init__(self, group: ProcessGroup): cluster_resources = ray.cluster_resources() total_tpus = int(cluster_resources["TPU"]) tpus_per_node = ( - TPUAcceleratorManager.get_current_node_num_accelerators() - ) + TPUAcceleratorManager.get_current_node_num_accelerators()) num_nodes = total_tpus // tpus_per_node pg_table = ray.util.placement_group_table() From 1629b0b77d2d35fb10a8e3f3f20b85d2ab78034d Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Thu, 29 Aug 2024 23:45:39 +0000 Subject: [PATCH 12/18] Reorganize code --- .../device_communicators/tpu_communicator.py | 23 +++-------- vllm/executor/ray_utils.py | 40 ++++++++++++++++--- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 431b5af870d04..2803673f4454b 
100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,9 +10,10 @@ import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr - from ray._private.accelerators import TPUAcceleratorManager from torch_xla._internal import pjrt + from vllm.executor import ray_utils + class TpuCommunicator: @@ -34,22 +35,10 @@ def __init__(self, group: ProcessGroup): # nodes is computed by the total number of TPUs divided by the # number of TPU accelerators per node, to account for clusters # with both CPUs and TPUs. - cluster_resources = ray.cluster_resources() - total_tpus = int(cluster_resources["TPU"]) - tpus_per_node = ( - TPUAcceleratorManager.get_current_node_num_accelerators()) - num_nodes = total_tpus // tpus_per_node - - pg_table = ray.util.placement_group_table() - current_pg = ray.util.get_current_placement_group() - - if current_pg: - nodes_in_pg = set() - for pg_key, pg in pg_table.items(): - if pg_key == current_pg.id.hex(): - for _, node in pg["bundles_to_node_id"].items(): - nodes_in_pg.add(node) - num_nodes = len(nodes_in_pg) + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg: + num_nodes = num_nodes_in_pg local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index d34d5871c20ef..3490dd6840d92 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -21,6 +21,7 @@ from ray._private.state import available_resources_per_node from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup + from ray._private.accelerators import TPUAcceleratorManager class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be @@ -47,9 +48,9 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: return node_id, gpu_ids def execute_model_spmd( - self, req_or_tuple: Union[bytes, - Tuple[bytes, - Optional[IntermediateTensors]]] + self, + req_or_tuple: Union[bytes, Tuple[bytes, + Optional[IntermediateTensors]]], ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. @@ -71,6 +72,7 @@ def execute_model_spmd( # on a background thread, so we need to reset torch's current # device. import torch + if not self.compiled_dag_cuda_device_set: torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True @@ -227,9 +229,11 @@ def initialize_ray_cluster( # Connect to a ray cluster. 
if is_hip() or is_xpu(): - ray.init(address=ray_address, - ignore_reinit_error=True, - num_gpus=parallel_config.world_size) + ray.init( + address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size, + ) else: ray.init(address=ray_address, ignore_reinit_error=True) @@ -295,3 +299,27 @@ def initialize_ray_cluster( _verify_bundles(current_placement_group, parallel_config, device_str) # Set the placement group in the parallel config parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes(): + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group(): + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes From 5a387ec92c122e1a389b7a322deb7f56c4dbadae Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 30 Aug 2024 00:04:36 +0000 Subject: [PATCH 13/18] format --- vllm/executor/ray_utils.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 3490dd6840d92..abce047beea02 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -18,10 +18,10 @@ try: import ray + from ray._private.accelerators import TPUAcceleratorManager from ray._private.state import available_resources_per_node from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup - from ray._private.accelerators import TPUAcceleratorManager class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be @@ -110,16 +110,19 @@ def assert_ray_available(): "`pip install ray`.") from ray_import_err -def _verify_bundles(placement_group: "PlacementGroup", - parallel_config: ParallelConfig, device_str: str): +def _verify_bundles( + placement_group: "PlacementGroup", + parallel_config: ParallelConfig, + device_str: str, +): """Verify a given placement group has bundles located in the right place. There are 2 rules. - Warn if all tensor parallel workers cannot fit in a single node. - Fail if driver node is not included in a placement group. """ - assert ray.is_initialized(), ( - "Ray is not initialized although distributed-executor-backend is ray.") + assert (ray.is_initialized( + )), "Ray is not initialized although distributed-executor-backend is ray." pg_data = placement_group_table(placement_group) # bundle_idx -> node_id bundle_to_node_ids = pg_data["bundles_to_node_id"] @@ -151,8 +154,13 @@ def _verify_bundles(placement_group: "PlacementGroup", "unless you have fast interconnect across nodes, like " "Infiniband. 
To resolve this issue, make sure you have more " "than %d GPUs available at each node.", - parallel_config.tensor_parallel_size, device_str, len(bundles), - device_str, node_id, parallel_config.tensor_parallel_size) + parallel_config.tensor_parallel_size, + device_str, + len(bundles), + device_str, + node_id, + parallel_config.tensor_parallel_size, + ) def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): @@ -181,7 +189,9 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): "Waiting for creating a placement group of specs for " "%d seconds. specs=%s. Check " "`ray status` to see if you have enough resources.", - int(time.time() - s), placement_group_specs) + int(time.time() - s), + placement_group_specs, + ) try: ray.get(pg_ready_ref, timeout=0) @@ -206,7 +216,9 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): wait_interval *= 2 logger.info( "Waiting for removing a placement group of specs for " - "%d seconds.", int(time.time() - s)) + "%d seconds.", + int(time.time() - s), + ) time.sleep(wait_interval) @@ -270,9 +282,9 @@ def initialize_ray_cluster( f"The number of required {device_str}s exceeds the total " f"number of available {device_str}s in the placement group.") # Create a new placement group - placement_group_specs: List[Dict[str, float]] = ([{ + placement_group_specs: List[Dict[str, float]] = [{ device_str: 1.0 - } for _ in range(parallel_config.world_size)]) + } for _ in range(parallel_config.world_size)] # vLLM engine is also a worker to execute model with an accelerator, # so it requires to have the device in a current node. Check if From fb02fe6e46360211f89989261465f720a1c3063a Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Fri, 30 Aug 2024 00:21:49 +0000 Subject: [PATCH 14/18] remove ray[serve] --- requirements-tpu.txt | 2 +- vllm/distributed/device_communicators/tpu_communicator.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 7624010371e81..4c606cf0a9105 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,4 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. -ray[default,serve] +ray[default] diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 2803673f4454b..02271fe27af08 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -7,7 +7,6 @@ from vllm.platforms import current_platform if current_platform.is_tpu(): - import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt From 07cee90bb900ec24a2759c92986a6db0a42cabfd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 30 Aug 2024 07:11:06 +0000 Subject: [PATCH 15/18] Add comment --- vllm/distributed/device_communicators/tpu_communicator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 02271fe27af08..283b1491de338 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -43,6 +43,10 @@ def __init__(self, group: ProcessGroup): local_rank = global_rank % local_world_size # Ensure environment variables are set for multihost deployments. 
+ # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) From 70e6447e981eddfd4dc5a4f28475112375d32622 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 30 Aug 2024 07:14:34 +0000 Subject: [PATCH 16/18] revert --- vllm/executor/ray_tpu_executor.py | 33 ++++++++++++------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 93c6bf607cc3c..a40af1741456f 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -133,10 +133,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for _ in worker_node_and_gpu_ids] - self._run_workers( - "update_environment_variables", - all_args=all_args_to_update_environment_variables, - ) + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) if len(node_workers) == 1: # in single node case, we don't need to get the IP address. @@ -162,11 +160,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) def _driver_execute_model( self, @@ -209,10 +205,10 @@ def _run_workers( "max_concurrent_workers is not supported yet.") count = len(self.workers) - all_worker_args = (repeat(args, count) if all_args is None else islice( - all_args, 1, None)) - all_worker_kwargs = (repeat(kwargs, count) if all_kwargs is None else - islice(all_kwargs, 1, None)) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) # Start the ray workers first. ray_worker_outputs = [ @@ -260,11 +256,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self._run_workers( - "initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - ) + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) def execute_model( self, @@ -274,8 +268,7 @@ def execute_model( self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", async_run_remote_workers_only=True, - **self.extra_execute_model_run_workers_kwargs, - ) + **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. 
return self._driver_execute_model(execute_model_req) From 117fc127eed1a2750ca6f6aaa82d7420e8db4ce4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 30 Aug 2024 07:16:19 +0000 Subject: [PATCH 17/18] Minor --- vllm/distributed/device_communicators/tpu_communicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 283b1491de338..765a0f9cb1c87 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -36,7 +36,7 @@ def __init__(self, group: ProcessGroup): # with both CPUs and TPUs. num_nodes = ray_utils.get_num_tpu_nodes() num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() - if num_nodes_in_pg: + if num_nodes_in_pg > 0: num_nodes = num_nodes_in_pg local_world_size = global_world_size // num_nodes From 1781e11c2b45e7e5bd7a08105b104841dc412b3a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 30 Aug 2024 07:19:57 +0000 Subject: [PATCH 18/18] Update --- vllm/executor/ray_utils.py | 55 ++++++++++++++------------------------ 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index abce047beea02..59e9854393b6b 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -18,7 +18,6 @@ try: import ray - from ray._private.accelerators import TPUAcceleratorManager from ray._private.state import available_resources_per_node from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup @@ -48,9 +47,9 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: return node_id, gpu_ids def execute_model_spmd( - self, - req_or_tuple: Union[bytes, Tuple[bytes, - Optional[IntermediateTensors]]], + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. @@ -72,7 +71,6 @@ def execute_model_spmd( # on a background thread, so we need to reset torch's current # device. import torch - if not self.compiled_dag_cuda_device_set: torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True @@ -87,7 +85,7 @@ def execute_model_spmd( return output - def override_env_vars(self, vars): + def override_env_vars(self, vars: Dict[str, str]): os.environ.update(vars) ray_import_err = None @@ -110,19 +108,16 @@ def assert_ray_available(): "`pip install ray`.") from ray_import_err -def _verify_bundles( - placement_group: "PlacementGroup", - parallel_config: ParallelConfig, - device_str: str, -): +def _verify_bundles(placement_group: "PlacementGroup", + parallel_config: ParallelConfig, device_str: str): """Verify a given placement group has bundles located in the right place. There are 2 rules. - Warn if all tensor parallel workers cannot fit in a single node. - Fail if driver node is not included in a placement group. """ - assert (ray.is_initialized( - )), "Ray is not initialized although distributed-executor-backend is ray." + assert ray.is_initialized(), ( + "Ray is not initialized although distributed-executor-backend is ray.") pg_data = placement_group_table(placement_group) # bundle_idx -> node_id bundle_to_node_ids = pg_data["bundles_to_node_id"] @@ -154,13 +149,8 @@ def _verify_bundles( "unless you have fast interconnect across nodes, like " "Infiniband. 
To resolve this issue, make sure you have more " "than %d GPUs available at each node.", - parallel_config.tensor_parallel_size, - device_str, - len(bundles), - device_str, - node_id, - parallel_config.tensor_parallel_size, - ) + parallel_config.tensor_parallel_size, device_str, len(bundles), + device_str, node_id, parallel_config.tensor_parallel_size) def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): @@ -189,9 +179,7 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): "Waiting for creating a placement group of specs for " "%d seconds. specs=%s. Check " "`ray status` to see if you have enough resources.", - int(time.time() - s), - placement_group_specs, - ) + int(time.time() - s), placement_group_specs) try: ray.get(pg_ready_ref, timeout=0) @@ -216,9 +204,7 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): wait_interval *= 2 logger.info( "Waiting for removing a placement group of specs for " - "%d seconds.", - int(time.time() - s), - ) + "%d seconds.", int(time.time() - s)) time.sleep(wait_interval) @@ -241,11 +227,9 @@ def initialize_ray_cluster( # Connect to a ray cluster. if is_hip() or is_xpu(): - ray.init( - address=ray_address, - ignore_reinit_error=True, - num_gpus=parallel_config.world_size, - ) + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) else: ray.init(address=ray_address, ignore_reinit_error=True) @@ -282,9 +266,9 @@ def initialize_ray_cluster( f"The number of required {device_str}s exceeds the total " f"number of available {device_str}s in the placement group.") # Create a new placement group - placement_group_specs: List[Dict[str, float]] = [{ + placement_group_specs: List[Dict[str, float]] = ([{ device_str: 1.0 - } for _ in range(parallel_config.world_size)] + } for _ in range(parallel_config.world_size)]) # vLLM engine is also a worker to execute model with an accelerator, # so it requires to have the device in a current node. Check if @@ -313,7 +297,8 @@ def initialize_ray_cluster( parallel_config.placement_group = current_placement_group -def get_num_tpu_nodes(): +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager cluster_resources = ray.cluster_resources() total_tpus = int(cluster_resources["TPU"]) tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() @@ -321,7 +306,7 @@ def get_num_tpu_nodes(): return total_tpus // tpus_per_node -def get_num_nodes_in_placement_group(): +def get_num_nodes_in_placement_group() -> int: pg_table = ray.util.placement_group_table() current_pg = ray.util.get_current_placement_group() num_nodes = 0
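
A minimal, illustrative sketch (separate from the patches above) of how the pieces introduced by this series fit together once all of the patches are applied. It assumes the final state of vllm/executor/ray_utils.py (get_num_tpu_nodes, get_num_nodes_in_placement_group) and a Ray cluster with TPU hosts; the torch_xla calls are the ones already used by TpuCommunicator in the diffs above, and the wrapper function itself is only for illustration.

# Sketch: multi-host TPU initialization on GKE as wired up by this series.
# One such call happens per TPU chip process; not part of the patches above.
import os

import torch.distributed as dist
import torch_xla.runtime as xr
from torch.distributed import ProcessGroup
from torch_xla._internal import pjrt

from vllm.executor import ray_utils


def init_tpu_process(group: ProcessGroup) -> None:
    # Rank and world size come from the vLLM-created process group
    # (one process per TPU chip).
    global_rank = dist.get_rank(group)
    global_world_size = dist.get_world_size(group)

    # Count TPU hosts: prefer the current Ray placement group (the Ray/GKE
    # deployment case); otherwise fall back to counting TPU hosts across the
    # whole Ray cluster.
    num_nodes = ray_utils.get_num_tpu_nodes()
    num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
    if num_nodes_in_pg > 0:
        num_nodes = num_nodes_in_pg

    # Each host runs global_world_size // num_nodes chip processes.
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size

    # On GKE the metadata server does not supply these, so libtpu and the TPU
    # driver must be told explicitly which task this process is and which chip
    # it may see; otherwise device initialization fails.
    os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
    os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)

    pjrt.initialize_multiprocess(local_rank, local_world_size)
    xr._init_world_size_ordinal()

On the executor side, the same GKE consideration is handled by forwarding TPU_CHIPS_PER_HOST_BOUNDS and TPU_HOST_BOUNDS from the driver's environment to each Ray worker via RayWorkerWrapper.override_env_vars, since GKE sets these inside the Ray process rather than exposing them through the metadata server.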