Support vLLM single and multi-host TPUs on GKE #7613

Merged 20 commits on Aug 30, 2024
Changes from 13 commits
2 changes: 1 addition & 1 deletion requirements-tpu.txt
@@ -4,4 +4,4 @@
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
ray
ray[default,serve]
Collaborator

Please remove Ray Serve from the dependency. vLLM does not need to be used with Ray Serve, and we'd like to minimize the dependencies.

Contributor Author

Can we still install ray[default]? The reason is that the GCS endpoint needs to run through the Ray dashboard, which does not get installed if you just do pip install ray. The endpoint is needed for other Ray workers to join the cluster for multi-host inference.

Collaborator

Oh, how is ray[default] different from just ray?

Contributor Author

According to https://docs.ray.io/en/latest/ray-overview/installation.html, ray[default] includes the Ray dashboard while ray is just the Ray core libraries.

The reason for including the dashboard (in addition to debuggability) is that the GCS and Ray Job endpoints are exposed through the dashboard, so without it the other Ray nodes cannot join the cluster. For example, the Kubernetes operator initializes Ray worker nodes by having them ping the GCS endpoint.
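As an illustration, here is a minimal sketch of how one might confirm from the head node that the worker hosts have joined through the GCS endpoint before starting vLLM. The expected host count is a made-up example and the check itself is not part of this PR:

import ray

# Connect to the already-running cluster started with ray[default].
ray.init(address="auto")

expected_tpu_hosts = 2  # hypothetical two-host TPU slice
alive_tpu_nodes = [
    node for node in ray.nodes()
    if node["Alive"] and node["Resources"].get("TPU", 0) > 0
]
print(f"TPU hosts joined: {len(alive_tpu_nodes)}/{expected_tpu_hosts}")
print("Total TPUs in cluster:", ray.cluster_resources().get("TPU", 0))
assert len(alive_tpu_nodes) == expected_tpu_hosts, (
    "Worker hosts have not registered with the GCS yet")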

5 changes: 4 additions & 1 deletion vllm/attention/backends/pallas.py
@@ -123,7 +123,10 @@ def __init__(
raise NotImplementedError("TPU version must be 4 or higher.")

self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
tpu_type = tpu_type.lower()

if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
33 changes: 32 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,3 +1,5 @@
import os

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -8,6 +10,7 @@
import ray
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from ray._private.accelerators import TPUAcceleratorManager
from torch_xla._internal import pjrt


@@ -24,9 +27,37 @@ def __init__(self, group: ProcessGroup):
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())

# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
cluster_resources = ray.cluster_resources()
total_tpus = int(cluster_resources["TPU"])
tpus_per_node = (
TPUAcceleratorManager.get_current_node_num_accelerators())
num_nodes = total_tpus // tpus_per_node
Collaborator

nit: Can we add assert total_tpus % tpus_per_node == 0?

Collaborator

Can we move this to ray_utils.py?

Contributor Author

Done.
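
For reference, a rough sketch of what such helpers in ray_utils.py could look like, including the suggested assertion; the names and structure are illustrative rather than the exact merged code:

import ray
from ray._private.accelerators import TPUAcceleratorManager


def _get_num_tpu_nodes() -> int:
    # Total TPUs in the cluster divided by TPUs per host, so clusters that
    # mix CPU-only and TPU nodes are counted correctly.
    total_tpus = int(ray.cluster_resources()["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    assert total_tpus % tpus_per_node == 0, (
        "Total TPU count must be divisible by the TPUs per node.")
    return total_tpus // tpus_per_node


def _get_num_nodes_in_placement_group() -> int:
    # Number of distinct nodes backing the current placement group,
    # or 0 if the caller is not running inside one.
    current_pg = ray.util.get_current_placement_group()
    if current_pg is None:
        return 0
    pg_table = ray.util.placement_group_table()
    nodes_in_pg = {
        node
        for pg_key, pg in pg_table.items()
        if pg_key == current_pg.id.hex()
        for node in pg["bundles_to_node_id"].values()
    }
    return len(nodes_in_pg)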


pg_table = ray.util.placement_group_table()
current_pg = ray.util.get_current_placement_group()

if current_pg:
nodes_in_pg = set()
for pg_key, pg in pg_table.items():
if pg_key == current_pg.id.hex():
for _, node in pg["bundles_to_node_id"].items():
nodes_in_pg.add(node)
num_nodes = len(nodes_in_pg)

local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size

# Ensure environment variables are set for multihost deployments.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
Collaborator

Why is this needed?

Contributor Author

This is needed for libtpu and the TPU driver to know which TPU chip is actually visible. On GKE these need to be set; otherwise the TPU driver will fail to initialize because the number of devices would differ from the number of visible worker addresses.


pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()

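For intuition, a small walk-through of the rank arithmetic above for a hypothetical deployment with 2 TPU hosts and 4 chips per host; the numbers are assumed for the example and not taken from the PR:

# Hypothetical setup: 2 TPU hosts x 4 chips per host -> 8 workers total.
num_nodes = 2
global_world_size = 8
local_world_size = global_world_size // num_nodes  # 4 chips visible per host

for global_rank in range(global_world_size):
    local_rank = global_rank % local_world_size
    # The communicator would export CLOUD_TPU_TASK_ID=global_rank (0..7)
    # and TPU_VISIBLE_CHIPS=local_rank (0..3) for each worker process.
    print(f"rank {global_rank}: host {global_rank // local_world_size}, chip {local_rank}")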
48 changes: 35 additions & 13 deletions vllm/executor/ray_tpu_executor.py
WoosukKwon marked this conversation as resolved.
@@ -70,6 +70,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"

# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

worker = ray.remote(
Collaborator

QQ: can we use the runtime_env arg in ray.remote instead?

Contributor Author

I think runtime_env will be overwritten when Ray starts up and tries to initialize the environment from within the Ray process. We are using this to manually override the environment after the Ray task starts up.

num_cpus=0,
resources={"TPU": 1},
@@ -80,6 +93,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
if override_env:
worker.override_env_vars.remote(override_env)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
@@ -118,8 +133,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for _ in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
self._run_workers(
"update_environment_variables",
all_args=all_args_to_update_environment_variables,
)

if len(node_workers) == 1:
# in single node case, we don't need to get the IP address.
@@ -145,9 +162,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)

def _driver_execute_model(
self,
@@ -190,10 +209,10 @@ def _run_workers(
"max_concurrent_workers is not supported yet.")

count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
all_worker_args = (repeat(args, count) if all_args is None else islice(
all_args, 1, None))
all_worker_kwargs = (repeat(kwargs, count) if all_kwargs is None else
islice(all_kwargs, 1, None))

# Start the ray workers first.
ray_worker_outputs = [
@@ -241,9 +260,11 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self._run_workers(
"initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
)

def execute_model(
self,
@@ -253,7 +274,8 @@ def execute_model(
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
**self.extra_execute_model_run_workers_kwargs,
)

# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
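Regarding the runtime_env question in the ray_tpu_executor.py thread above, a minimal sketch contrasting the two approaches; the actor, method, and example value are illustrative, and per the discussion the values set via runtime_env may be overwritten when the process initializes its environment on GKE:

import os

import ray


@ray.remote
class EnvWorker:

    def override_env_vars(self, env_vars: dict) -> None:
        # Applied after the actor process has started, so it takes effect
        # even if startup re-initialized the environment.
        os.environ.update(env_vars)

    def get_env(self, key: str):
        return os.environ.get(key)


ray.init(ignore_reinit_error=True)

# Approach 1: set variables through runtime_env before the process starts.
worker = EnvWorker.options(
    runtime_env={"env_vars": {"TPU_HOST_BOUNDS": "1,1,1"}}).remote()

# Approach 2 (what this PR does): push the values after the actor is up.
ray.get(worker.override_env_vars.remote({"TPU_HOST_BOUNDS": "1,1,1"}))
print(ray.get(worker.get_env.remote("TPU_HOST_BOUNDS")))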
4 changes: 4 additions & 0 deletions vllm/executor/ray_utils.py
@@ -1,3 +1,4 @@
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
@@ -84,6 +85,9 @@ def execute_model_spmd(

return output

def override_env_vars(self, vars):
os.environ.update(vars)

ray_import_err = None

except ImportError as e: