[core] allow callable in collective_rpc #12151
New test file added by this PR (the diff header shows 35 added lines). It exercises `collective_rpc` with both a string method name and a free function passed as a callable:

```python
import pytest

from vllm import LLM

from ...utils import fork_new_process_for_each_test


def echo_rank(self):
    return self.rank


@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("backend", ["mp", "ray"])
@fork_new_process_for_each_test
def test_collective_rpc(tp_size, backend):
    if tp_size == 1 and backend == "ray":
        pytest.skip("Skip duplicate test case")
    if tp_size == 1:
        backend = None

    from vllm.worker.worker import Worker

    class MyWorker(Worker):

        def echo_rank(self):
            return self.rank

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              enforce_eager=True,
              load_format="dummy",
              tensor_parallel_size=tp_size,
              distributed_executor_backend=backend,
              worker_cls=MyWorker)
    for method in ["echo_rank", echo_rank]:
        assert llm.collective_rpc(method) == list(range(tp_size))
```
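For illustration only, a hedged sketch of forwarding extra arguments to a worker-side callable. It assumes that `LLM.collective_rpc` exposes the same `args`/`kwargs` parameters as the executor-level signature changed below; `scaled_rank` is a hypothetical helper, not part of this PR:

```python
# Hypothetical free function: executed on every worker, with the worker
# instance passed as `self` and extra arguments forwarded by collective_rpc.
def scaled_rank(self, factor, offset=0):
    return self.rank * factor + offset

# Assuming `llm` was built with tensor_parallel_size=2 as in the test above,
# and assuming LLM.collective_rpc forwards args/kwargs to each worker:
# llm.collective_rpc(scaled_rank, args=(10,), kwargs={"offset": 1})
# would be expected to return [1, 11] (ranks 0 and 1).
```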
The second changed file updates the executor base-class signatures so that `collective_rpc` and `_run_workers` accept either a method name or a callable:

```diff
@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
+                    Union)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -47,7 +48,7 @@ def _init_executor(self) -> None:

     @abstractmethod
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ def _driver_execute_model(
         raise NotImplementedError

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ def collective_rpc(self,
     @abstractmethod
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
```

Review thread on the `method` parameter annotation:

- "I think we can add some type annotations to the callable method. Something like:"

```python
from typing import Any, Callable, Mapping, TypeVar
from typing_extensions import Concatenate

_R = TypeVar("_R")

WorkerMethod = Callable[Concatenate[WorkerBase, Any], _R]


def collective_rpc(
    self,
    method: Union[str, WorkerMethod[_R]],
    timeout: Optional[float] = None,
    args: tuple[Any, ...] = (),
    # Mutable default is fine because we annotate it as immutable
    kwargs: Mapping[str, Any] = {},
) -> list[_R]:
    ...
```

- "I'm afraid repeating this type annotation in all places would be less clear. I prefer the current simple type annotation."
- "We can just have this at the top level for user convenience."
- "Personally I don't use …"
- "It would make it easier for users to go to the worker definition when looking at … Now that I think about it, perhaps it would be better to return a wrapper like … This way, we can define …"
- "I guess we can actually just return the executor instead of a wrapper in that case..."
- "Users can print the …"
- "Ok, let's keep this for simplicity then. Maybe we should update the example file to illustrate this."
- "I'll prepare a dedicated doc to show the usage of `collective_rpc`."
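As a side note, here is a minimal sketch (not taken from this PR) of the kind of worker-side dispatch the `Union[str, Callable]` annotation enables: a string is looked up as a bound method on the worker, while a callable is invoked with the worker as its first argument.

```python
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Union


def run_worker_method(worker: Any,
                      method: Union[str, Callable],
                      args: Tuple = (),
                      kwargs: Optional[Dict] = None) -> Any:
    """Sketch only: dispatch either a method name or a callable on a worker."""
    kwargs = kwargs or {}
    if isinstance(method, str):
        # Resolve a bound method by name on the worker instance.
        func = getattr(worker, method)
    else:
        # Treat the callable like an unbound method: the worker becomes `self`.
        func = partial(method, worker)
    return func(*args, **kwargs)
```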
Review thread on the test setup:

- "Does the `multi_gpu_test` decorator not work here?"
- "It is a mixture of a 1-GPU test and a 2-GPU test. Does `multi_gpu_test` work?"
- "`multi_gpu_test` requires 2 GPUs to run the test. Since we're skipping this file in the single GPU CI anyways, I think it shouldn't make a difference."
- "Then we need to duplicate the test in two sections 🤔 I prefer to directly run the whole test in 2 GPU tests."
- "We can move the test logic into a separate function, and have two separate `test_*` functions (one with 1 GPU, without `multi_gpu_test`; and another with 2 GPUs, with the `multi_gpu_test` decorator) to be called by pytest."
- "I intentionally merge them into one function to reduce code duplication :("
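For reference, a sketch of the reviewer's suggested split. The helper names, the import path for `multi_gpu_test`, and its `num_gpus` signature are assumptions, not something shown in this PR:

```python
import pytest

# Assumed to be importable from the same tests utils module as
# fork_new_process_for_each_test; the exact multi_gpu_test signature is
# an assumption here.
from ...utils import fork_new_process_for_each_test, multi_gpu_test


def _run_collective_rpc(tp_size, backend):
    # Shared body: same logic as test_collective_rpc above.
    ...


@fork_new_process_for_each_test
def test_collective_rpc_tp1():
    _run_collective_rpc(tp_size=1, backend=None)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("backend", ["mp", "ray"])
def test_collective_rpc_tp2(backend):
    _run_collective_rpc(tp_size=2, backend=backend)
```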