[Core][ROCm][AMD] Add optional torchrun multi GPU executor #3691

Closed
wants to merge 6 commits
6 changes: 6 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -27,6 +27,7 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
worker_use_ray=args.worker_use_ray,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
@@ -190,5 +191,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
args = parser.parse_args()
main(args)
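
Note (not part of the diff): with this flag, Ray stays opt-in for the latency benchmark; on ROCm, a multi-GPU run is instead expected to be launched under torchrun, matching the RANK check added to vllm/config.py further down (torchrun --standalone --nproc_per_node=<tp> ...). A minimal sketch of the settings that select the torchrun path, assuming a multi-GPU ROCm host; the model and the script name are placeholders:

# Sketch only: every rank runs this script when launched via, e.g.,
#   torchrun --standalone --nproc_per_node=2 latency_probe.py
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    tensor_parallel_size=2,     # more than one GPU
    worker_use_ray=False,       # default; on ROCm this selects torchrun
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
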
80 changes: 30 additions & 50 deletions benchmarks/benchmark_throughput.py
@@ -58,55 +58,38 @@ def sample_requests(


def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
requests: List[Tuple[str, int, int]],
args: argparse.Namespace,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
seed=args.seed,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
enable_prefix_caching=args.enable_prefix_caching,
download_dir=args.download_dir,
enable_chunked_prefill=args.enable_chunked_prefill,
max_num_batched_tokens=args.max_num_batched_tokens,
worker_use_ray=args.worker_use_ray,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
@@ -219,15 +202,7 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elapsed_time = run_vllm(requests, args)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -354,6 +329,11 @@ def main(args: argparse.Namespace):
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
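
Note (not part of the diff): the refactor above replaces run_vllm's roughly twenty positional parameters with the parsed argparse.Namespace. A minimal sketch of calling it programmatically under the new signature; it assumes the snippet lives in (or imports run_vllm from) benchmarks/benchmark_throughput.py, and the model is a placeholder:

# Sketch only: drive the refactored run_vllm(requests, args) with a hand-built
# namespace whose attribute names mirror the benchmark's CLI flags.
import argparse

args = argparse.Namespace(
    model="facebook/opt-125m", tokenizer="facebook/opt-125m",  # placeholders
    quantization=None, tensor_parallel_size=1, seed=0, n=1,
    use_beam_search=False, trust_remote_code=False, dtype="auto",
    max_model_len=None, enforce_eager=False, kv_cache_dtype="auto",
    quantization_param_path=None, device="cuda",
    enable_prefix_caching=False, enable_chunked_prefill=False,
    max_num_batched_tokens=None, gpu_memory_utilization=0.9,
    download_dir=None, worker_use_ray=False)

# (prompt, prompt_len, output_len) tuples, as produced by sample_requests().
requests = [("Hello, my name is", 5, 32)]
elapsed = run_vllm(requests, args)
print(f"{len(requests) / elapsed:.2f} requests/s")
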
13 changes: 12 additions & 1 deletion vllm/config.py
@@ -505,6 +505,7 @@ def __init__(
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.worker_use_torchrun = False
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
@@ -513,7 +514,17 @@

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
if is_hip() and not self.worker_use_ray:
logger.info("Using torchrun for multi-GPU on "
"ROCM platform. Use --worker-use-ray "
"to override")
if not os.environ.get("RANK"):
raise RuntimeError(
"Needs to be run in torchrun: "
"torchrun --standalone --nproc_per_node=<tp> ...")
self.worker_use_torchrun = True
else:
self.worker_use_ray = True
self._verify_args()

def _verify_args(self) -> None:
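
Not vLLM code: the branch above can be read as a small decision function. The sketch below paraphrases it outside the class, with is_rocm standing in for is_hip():

import os


def pick_multi_gpu_backend(world_size: int, is_rocm: bool,
                           worker_use_ray: bool) -> str:
    """Paraphrase of the executor choice added to ParallelConfig.__init__."""
    if world_size <= 1:
        return "single-gpu"
    if is_rocm and not worker_use_ray:
        # torchrun exports RANK to every process it spawns; its absence means
        # the script was not started the way this code path requires.
        if not os.environ.get("RANK"):
            raise RuntimeError(
                "Needs to be run in torchrun: "
                "torchrun --standalone --nproc_per_node=<tp> ...")
        return "torchrun"
    return "ray"
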
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -212,7 +212,8 @@ def add_cli_args(
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -228,6 +228,9 @@ def from_engine_args(
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif engine_config.parallel_config.worker_use_torchrun:
from vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor
executor_class = TorchrunGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
106 changes: 106 additions & 0 deletions vllm/executor/torchrun_gpu_executor.py
@@ -0,0 +1,106 @@
import os
from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_object_list)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

logger = init_logger(__name__)

# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}


class TorchrunGPUExecutor(GPUExecutor):

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.is_driver_worker = self.local_rank == 0
        super().__init__(model_config, cache_config, parallel_config,
                         scheduler_config, device_config, lora_config,
                         vision_language_config)

    def _init_worker(self):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        assert self.parallel_config.world_size > 1, (
            "TorchrunGPUExecutor only supports multiple GPUs.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            local_rank=self.local_rank,
            rank=self.local_rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=self.is_driver_worker,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        if self.is_driver_worker:
            broadcast_object_list([output], src=0)
        else:
            res = [None]
            broadcast_object_list(res, src=0)
            output = res[0]
        return output


class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
        return output

    async def check_health_async(self) -> None:
        # TorchrunGPUExecutor will always be healthy as long as
        # it's running.
        return
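
Not part of the PR: execute_model runs the model on every rank and then uses broadcast_object_list so that the driver's SamplerOutput is what every rank returns. The standalone sketch below shows the same pattern with plain torch.distributed (gloo backend, so it also runs on CPU); launch it with torchrun --standalone --nproc_per_node=2 <script>.py:

# Minimal standalone illustration of the rank-0 object broadcast used above.
import os

import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo", init_method="env://")
    rank = int(os.environ["RANK"])

    if rank == 0:
        result = {"tokens": [1, 2, 3]}   # stands in for SamplerOutput
        obj_list = [result]
        dist.broadcast_object_list(obj_list, src=0)
    else:
        obj_list = [None]
        dist.broadcast_object_list(obj_list, src=0)
        result = obj_list[0]

    # Every rank prints the same object originating from rank 0.
    print(f"rank {rank} got {result}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
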
8 changes: 6 additions & 2 deletions vllm/worker/worker.py
@@ -273,8 +273,12 @@ def init_worker_distributed_environment(
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
if not parallel_config.worker_use_torchrun:
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
else:
init_distributed_environment(parallel_config.world_size, -1, "env://",
local_rank)

if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size()
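
Not part of the PR: in the torchrun case the worker passes rank=-1 and init_method "env://", which tells torch.distributed to take the rank, world size, and rendezvous address from the environment variables torchrun exports, instead of from vLLM's own distributed_init_method. A small sketch of what that relies on:

import os

import torch.distributed as dist

# torchrun sets these for every process it launches.
for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var))

# With init_method="env://" the rank and world size come from the variables
# above, which is what init_distributed_environment(world_size, -1, "env://",
# local_rank) ends up doing.  "nccl" maps to RCCL on ROCm; use "gloo" to try
# this on a CPU-only machine.
dist.init_process_group(backend="nccl", init_method="env://")
print("initialized rank", dist.get_rank(), "of", dist.get_world_size())
dist.destroy_process_group()
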