[core] allow callable in collective_rpc #12151

Merged (13 commits) on Jan 17, 2025

Changes from 11 commits
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -107,7 +107,7 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
4 changes: 2 additions & 2 deletions tests/engine/test_custom_executor.py
@@ -1,6 +1,6 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pytest

@@ -18,7 +18,7 @@ class Mock:
class CustomUniExecutor(UniProcExecutor):

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
35 changes: 35 additions & 0 deletions tests/entrypoints/llm/test_collective_rpc.py
@@ -0,0 +1,35 @@
import pytest

from vllm import LLM

from ...utils import fork_new_process_for_each_test


def echo_rank(self):
return self.rank


@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("backend", ["mp", "ray"])
@fork_new_process_for_each_test
DarkLight1337 (Member, Jan 17, 2025):
Does the multi_gpu_test decorator not work here?

Author:
It is a mixture of a 1-GPU test and a 2-GPU test. Does multi_gpu_test work for that?

DarkLight1337:
multi_gpu_test requires 2 GPUs to run the test. Since we're skipping this file in the single-GPU CI anyway, I think it shouldn't make a difference.

Author:
Then we would need to duplicate the test in two sections 🤔 I prefer to directly run the whole test in the 2-GPU tests.

DarkLight1337:
We can move the test logic into a separate function and have two separate test_* functions (one with 1 GPU, without multi_gpu_test; another with 2 GPUs, with the multi_gpu_test decorator) to be called by pytest.

Author:
I intentionally merged them into one function to reduce code duplication :(

def test_collective_rpc(tp_size, backend):
if tp_size == 1 and backend == "ray":
pytest.skip("Skip duplicate test case")
if tp_size == 1:
backend = None

from vllm.worker.worker import Worker

class MyWorker(Worker):

def echo_rank(self):
return self.rank

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
load_format="dummy",
tensor_parallel_size=tp_size,
distributed_executor_backend=backend,
worker_cls=MyWorker)
for method in ["echo_rank", echo_rank]:
assert llm.collective_rpc(method) == list(range(tp_size))
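
For readers following the review thread above: a rough sketch of the two-function layout that was discussed and ultimately not adopted. It assumes a multi_gpu_test decorator in tests/utils that skips the test unless the requested number of GPUs is available; treat it as an illustration of the trade-off, not as code from this PR.

    # Hypothetical alternative to the merged single-function test (not in this PR).
    # Assumes a multi_gpu_test decorator exists in tests/utils; reuses the
    # module-level echo_rank helper and the imports from the file above.
    from ...utils import multi_gpu_test


    def _check_collective_rpc(tp_size, backend):
        llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
                  enforce_eager=True,
                  load_format="dummy",
                  tensor_parallel_size=tp_size,
                  distributed_executor_backend=backend)
        assert llm.collective_rpc(echo_rank) == list(range(tp_size))


    @fork_new_process_for_each_test
    def test_collective_rpc_1_gpu():
        _check_collective_rpc(tp_size=1, backend=None)


    @multi_gpu_test(num_gpus=2)
    @pytest.mark.parametrize("backend", ["mp", "ray"])
    @fork_new_process_for_each_test
    def test_collective_rpc_2_gpus(backend):
        _check_collective_rpc(tp_size=2, backend=backend)
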
17 changes: 14 additions & 3 deletions vllm/engine/llm_engine.py
@@ -5,10 +5,10 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
from typing import Set, Tuple, Type, Union, cast, overload

import torch
from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ def start_profile(self) -> None:
def stop_profile(self) -> None:
self.model_executor.stop_profile()

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)

def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
14 changes: 9 additions & 5 deletions vllm/entrypoints/llm.py
@@ -1,8 +1,8 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)

import cloudpickle
from tqdm import tqdm
@@ -464,7 +464,7 @@ def generate(
return self.engine_class.validate_outputs(outputs, RequestOutput)

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ def collective_rpc(self,
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

def beam_search(
self,
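
To make the docstring above concrete, a minimal usage sketch of the callable form (the engine arguments mirror the test added in this PR, and the printed result assumes two tensor-parallel workers):

    from vllm import LLM


    # A plain function sent to every worker via cloudpickle; the worker object
    # is passed as `self`, followed by whatever comes through `args`/`kwargs`.
    def scaled_rank(self, factor, offset=0):
        return self.rank * factor + offset


    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              enforce_eager=True,
              load_format="dummy",
              tensor_parallel_size=2)

    # One result per worker, ordered by rank: [1, 11] here.
    print(llm.collective_rpc(scaled_rank, args=(10,), kwargs={"offset": 1}))
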
9 changes: 5 additions & 4 deletions vllm/executor/executor_base.py
@@ -1,6 +1,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)

from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -47,7 +48,7 @@ def _init_executor(self) -> None:

@abstractmethod
def collective_rpc(self,
method: str,
DarkLight1337 (Member, Jan 17, 2025):
I think we can add some type annotations for the callable method. Something like:

    from typing import Any, Callable, Mapping, TypeVar
    from typing_extensions import Concatenate

    _R = TypeVar("_R")
    WorkerMethod = Callable[Concatenate[WorkerBase, Any], _R]

    def collective_rpc(
        self,
        method: Union[str, WorkerMethod[_R]],
        timeout: Optional[float] = None,
        args: tuple[Any, ...] = (),
        # Mutable default is fine because we annotate it as immutable
        kwargs: Mapping[str, Any] = {},
    ) -> list[_R]:
        ...

Author:
I'm afraid repeating this type annotation in all places would be less clear. I prefer the current simple type annotation.

DarkLight1337:
We can just have this at the top level for user convenience.

Author:
Personally I don't use TypeVar, as I find it to be less clear.

DarkLight1337:
It would make it easier for users to go to the worker definition when looking at LLM.collective_rpc.

Now that I think about it, perhaps it would be better to return a wrapper like RemoteWorkers instead of directly returning the result? Something like:

    workers = llm.get_workers()
    outputs = workers.collective_rpc(...)

This way, we can define collective_rpc in just one place without repeating the arguments every time.

DarkLight1337:
I guess we can actually just return the executor instead of a wrapper in that case...

Author:
Users can print the self argument to see what the actual worker is.

DarkLight1337:
OK, let's keep this for simplicity then. Maybe we should update the example file to illustrate this.

Author:
I'll prepare a dedicated doc to show the usage of collective_rpc.

method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ def _driver_execute_model(
raise NotImplementedError

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ def collective_rpc(self,
@abstractmethod
def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
21 changes: 14 additions & 7 deletions vllm/executor/mp_distributed_executor.py
@@ -1,5 +1,7 @@
import asyncio
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional, Union

import cloudpickle

from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
get_ip, get_open_port, make_async, run_method)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -107,7 +109,7 @@ def _driver_execute_model(

def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ def _run_workers(
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method

if max_concurrent_workers:
raise NotImplementedError(
@@ -129,18 +136,18 @@ def _run_workers(
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.non_driver_workers
]

# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.workers
]

driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
driver_worker_output = run_method(self.driver_worker, sent_method,
args, kwargs)

# Get the results of the workers.
return [driver_worker_output
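
The run_method helper imported above is not part of this diff; a minimal sketch of the behaviour the executors rely on might look as follows (the real helper lives in vllm/utils.py and may differ in details such as error handling):

    from functools import partial
    from typing import Any, Callable, Dict, Tuple, Union

    import cloudpickle


    def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple,
                   kwargs: Dict) -> Any:
        """Run a method on obj, where the method is a name, a callable, or
        cloudpickled bytes produced by _run_workers."""
        if isinstance(method, bytes):
            # Serialized callable: deserialize and bind obj as `self`.
            func = partial(cloudpickle.loads(method), obj)
        elif isinstance(method, str):
            # Method name: look it up on the target object.
            func = getattr(obj, method)
        else:
            # Plain callable: bind obj as `self`.
            func = partial(method, obj)
        return func(*args, **kwargs)
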
12 changes: 6 additions & 6 deletions vllm/executor/multiproc_worker_utils.py
@@ -15,7 +15,7 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context
from vllm.utils import _check_multiproc_method, get_mp_context, run_method

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ def __init__(self, result_handler: ResultHandler,
self.process.start()

def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
method: str, args, kwargs):
method: Union[str, bytes], args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
@@ -180,12 +180,13 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
del self.tasks[task_id]
raise ChildProcessError("worker died") from e

def execute_method(self, method: str, *args, **kwargs):
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future

async def execute_method_async(self, method: str, *args, **kwargs):
async def execute_method_async(self, method: Union[str, bytes], *args,
**kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
output = run_method(worker, method, args, kwargs)
except SystemExit:
raise
except KeyboardInterrupt:
14 changes: 10 additions & 4 deletions vllm/executor/ray_distributed_executor.py
@@ -2,8 +2,9 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import cloudpickle
import msgspec

import vllm.envs as envs
@@ -410,7 +411,7 @@ def execute_model(

def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ def _run_workers(
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in ray_workers
]

@@ -455,7 +461,7 @@
if not self.use_ray_spmd_worker:
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *args, **kwargs)
self.driver_worker.execute_method(sent_method, *args, **kwargs)
]

# Get the results of the ray workers.
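
A side note on the serialization choice above (an illustration, not code from this PR): cloudpickle serializes functions by value, so locally defined functions and closures survive the round trip to the workers, whereas the standard pickle module serializes functions by reference and fails for them:

    import pickle

    import cloudpickle


    def make_probe(tag):
        # A closure created at runtime; pickle only records a module plus
        # qualified name, which does not resolve for local functions.
        def probe(self):
            return (tag, self.rank)

        return probe


    probe = make_probe("demo")
    restored = cloudpickle.loads(cloudpickle.dumps(probe))  # works: by value

    try:
        pickle.dumps(probe)
    except (pickle.PicklingError, AttributeError) as exc:
        print(f"pickle rejects local functions: {exc}")
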
14 changes: 5 additions & 9 deletions vllm/executor/uniproc_executor.py
@@ -1,13 +1,14 @@
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -39,18 +40,13 @@ def _init_executor(self) -> None:
self.collective_rpc("load_model")

def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
try:
func = getattr(self.driver_worker, method)
except AttributeError:
raise NotImplementedError(f"Method {method} is not implemented.") \
from None
answer = func(*args, **kwargs)
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]

def check_health(self) -> None: