[core] allow callable in collective_rpc (vllm-project#12151)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
youkaichao authored and Isotr0py committed Feb 2, 2025
1 parent 72c4266 commit d7d1f20
Showing 13 changed files with 147 additions and 50 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -107,7 +107,7 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
4 changes: 2 additions & 2 deletions tests/engine/test_custom_executor.py
@@ -1,6 +1,6 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pytest

@@ -18,7 +18,7 @@ class Mock:
class CustomUniExecutor(UniProcExecutor):

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
36 changes: 36 additions & 0 deletions tests/entrypoints/llm/test_collective_rpc.py
@@ -0,0 +1,36 @@
import pytest

from vllm import LLM

from ...utils import fork_new_process_for_each_test


@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("backend", ["mp", "ray"])
@fork_new_process_for_each_test
def test_collective_rpc(tp_size, backend):
if tp_size == 1 and backend == "ray":
pytest.skip("Skip duplicate test case")
if tp_size == 1:
backend = None

# intentionally define the method and class in the test function,
# to test if they can be serialized and sent to the workers
def echo_rank(self):
return self.rank

from vllm.worker.worker import Worker

class MyWorker(Worker):

def echo_rank(self):
return self.rank

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
load_format="dummy",
tensor_parallel_size=tp_size,
distributed_executor_backend=backend,
worker_cls=MyWorker)
for method in ["echo_rank", echo_rank]:
assert llm.collective_rpc(method) == list(range(tp_size))
17 changes: 14 additions & 3 deletions vllm/engine/llm_engine.py
@@ -5,10 +5,10 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
from typing import Set, Tuple, Type, Union, cast, overload

import torch
from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ def start_profile(self) -> None:
def stop_profile(self) -> None:
self.model_executor.stop_profile()

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)

def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
14 changes: 9 additions & 5 deletions vllm/entrypoints/llm.py
@@ -1,8 +1,8 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)

import cloudpickle
from tqdm import tqdm
@@ -464,7 +464,7 @@ def generate(
return self.engine_class.validate_outputs(outputs, RequestOutput)

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ def collective_rpc(self,
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

def beam_search(
self,
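As a usage sketch of the callable form described in the docstring above (mirroring the new tests/entrypoints/llm/test_collective_rpc.py earlier in this commit; the model name, enforce_eager, and load_format settings are copied from that test and are illustrative only):

```python
from vllm import LLM

# Locally defined callable: collective_rpc serializes it, and each worker
# invokes it with the worker object bound to `self`, then args/kwargs.
def describe(self, prefix=""):
    return f"{prefix}{self.rank}"

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          load_format="dummy",
          tensor_parallel_size=2)

# One return value per worker, e.g. ['rank 0', 'rank 1'].
print(llm.collective_rpc(describe, args=("rank ",)))
```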
9 changes: 5 additions & 4 deletions vllm/executor/executor_base.py
@@ -1,6 +1,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)

from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -47,7 +48,7 @@ def _init_executor(self) -> None:

@abstractmethod
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ def _driver_execute_model(
raise NotImplementedError

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ def collective_rpc(self,
@abstractmethod
def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
21 changes: 14 additions & 7 deletions vllm/executor/mp_distributed_executor.py
@@ -1,5 +1,7 @@
import asyncio
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional, Union

import cloudpickle

from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
get_ip, get_open_port, make_async, run_method)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -107,7 +109,7 @@ def _driver_execute_model(

def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ def _run_workers(
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method

if max_concurrent_workers:
raise NotImplementedError(
@@ -129,18 +136,18 @@ def _run_workers(
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.non_driver_workers
]

# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.workers
]

driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
driver_worker_output = run_method(self.driver_worker, sent_method,
args, kwargs)

# Get the results of the workers.
return [driver_worker_output
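The cloudpickle round-trip added above is what lets locally defined functions and closures reach the worker processes. A minimal standalone sketch of the pattern (the FakeWorker class is a stand-in invented for illustration, not part of vLLM):

```python
import cloudpickle

class FakeWorker:
    """Stand-in for a worker object living in another process."""
    rank = 0

def echo_rank(self):
    return self.rank

# Driver side: strings pass through unchanged; callables are pickled once
# so that functions defined inside a test or a notebook survive the trip.
method = echo_rank
sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)

# Worker side: bytes are unpickled back into a callable and invoked with
# the worker object as the first positional argument.
fn = cloudpickle.loads(sent_method)
assert fn(FakeWorker()) == 0
```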
12 changes: 6 additions & 6 deletions vllm/executor/multiproc_worker_utils.py
@@ -15,7 +15,7 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context
from vllm.utils import _check_multiproc_method, get_mp_context, run_method

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ def __init__(self, result_handler: ResultHandler,
self.process.start()

def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
method: str, args, kwargs):
method: Union[str, bytes], args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
@@ -180,12 +180,13 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
del self.tasks[task_id]
raise ChildProcessError("worker died") from e

def execute_method(self, method: str, *args, **kwargs):
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future

async def execute_method_async(self, method: str, *args, **kwargs):
async def execute_method_async(self, method: Union[str, bytes], *args,
**kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
output = run_method(worker, method, args, kwargs)
except SystemExit:
raise
except KeyboardInterrupt:
14 changes: 10 additions & 4 deletions vllm/executor/ray_distributed_executor.py
@@ -2,8 +2,9 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import cloudpickle
import msgspec

import vllm.envs as envs
@@ -410,7 +411,7 @@ def execute_model(

def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ def _run_workers(
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in ray_workers
]

@@ -455,7 +461,7 @@
if not self.use_ray_spmd_worker:
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *args, **kwargs)
self.driver_worker.execute_method(sent_method, *args, **kwargs)
]

# Get the results of the ray workers.
14 changes: 5 additions & 9 deletions vllm/executor/uniproc_executor.py
@@ -1,13 +1,14 @@
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -39,18 +40,13 @@ def _init_executor(self) -> None:
self.collective_rpc("load_model")

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
try:
func = getattr(self.driver_worker, method)
except AttributeError:
raise NotImplementedError(f"Method {method} is not implemented.") \
from None
answer = func(*args, **kwargs)
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]

def check_health(self) -> None:
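Both executors now funnel dispatch through vllm.utils.run_method, whose body is not part of this diff. A hedged sketch of what it plausibly does, inferred only from the call sites above (string → attribute lookup on the worker, cloudpickled bytes → loads, callable → invoked with the worker as `self`); the real helper in vllm/utils.py may differ in detail:

```python
from functools import partial
from typing import Any, Callable, Dict, Tuple, Union

import cloudpickle

def run_method(obj: Any, method: Union[str, bytes, Callable],
               args: Tuple, kwargs: Dict) -> Any:
    """Resolve `method` against `obj` and invoke it with args/kwargs."""
    if isinstance(method, bytes):
        # A callable serialized by the driver with cloudpickle.dumps.
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        # A plain method name, looked up on the worker object.
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented.") from None
    else:
        # Already a callable: bind the worker object as the first argument.
        func = partial(method, obj)
    return func(*args, **kwargs)
```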