[Ray backend] Better error when pg topology is bad. #7584
New file: tests/distributed/test_multi_node.py (@@ -0,0 +1,78 @@)
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. | ||
|
||
Run: | ||
```sh | ||
cd $VLLM_PATH/tests | ||
|
||
pytest distributed/test_multi_node.py | ||
``` | ||
""" | ||
import os | ||
|
||
import pytest | ||
|
||
import ray | ||
from vllm.utils import cuda_device_count_stateless | ||
from ray.cluster_utils import Cluster | ||
|
||
|
||
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") | ||
|
||
|
||
@pytest.mark.skipif(cuda_device_count_stateless() < 2, | ||
reason="Need at least 2 GPUs to run the test.") | ||
@pytest.mark.parametrize( | ||
"model, distributed_executor_backend, test_suite", [ | ||
("facebook/opt-125m", "ray", "L4"), | ||
]) | ||
def test_multi_node_bad_topology( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should test it with 2 nodes test (where ray cluster is already created), and try to launch the test with 1 gpu from both the head node and the worker node, and make sure the vllm instance is scheduled with the same node as the process who launched it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm yeah I think we need real 2 nodes if we want to use node ip resources. let me fix it. |
||
vllm_runner, | ||
model: str, | ||
distributed_executor_backend: str, | ||
test_suite: str, | ||
) -> None: | ||
"""Verify ray + multi node's bad topology raises an exception. | ||
|
||
This test simulates multi node ray cluster, so we don't have to start | ||
real 2 multi nodes. | ||
|
||
There are 2 potential bad issues. | ||
- the engine's node doesn't have enough GPUs. | ||
- the tensor parallel size exceeds the available GPUs in a current node. | ||
""" | ||
dtype = "half" | ||
assert test_suite == TARGET_TEST_SUITE | ||
|
||
# Simulate 2 node clusters, 1 GPU each. | ||
cluster = Cluster() | ||
head_node = cluster.add_node(num_cpus=8, num_gpus=1, resources={"head": 1}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why We don't have any guarentee on the number of GPUs we have for this 2 GPUs test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ray doesn't require real hardware cpus when specifying resources like this. 8 is just a random number. (I think technically num_cpus=1 should also work) |
||
ray.init(address=head_node.address) | ||
cluster.add_node(num_cpus=8, num_gpus=1) | ||
|
||
# Creating tp == 2. Since TP workers are supposed to spread to 2 workers | ||
# it should log warning. | ||
with vllm_runner(model, | ||
dtype=dtype, | ||
tensor_parallel_size=2, | ||
distributed_executor_backend=distributed_executor_backend | ||
) as _: | ||
pass | ||
|
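For reference, the "node IP resources" mentioned in the reply above presumably refer to Ray's built-in per-node custom resources named `node:<ip>`: requesting a small fraction of that resource pins a task or actor to the node with that IP. A minimal sketch of pinning an actor to the driver's node this way (standalone, not part of this diff; it assumes an already-running Ray cluster and that `ray.util.get_node_ip_address` is available):

```python
import ray
from ray.util import get_node_ip_address

# Assumes a Ray cluster has already been started, e.g. via `ray start`.
ray.init(address="auto")

# Each Ray node exposes a custom resource named "node:<ip>" with capacity 1.0.
# Requesting a tiny fraction of it forces placement onto that specific node,
# here the node the driver process is running on.
driver_ip = get_node_ip_address()


@ray.remote(resources={f"node:{driver_ip}": 0.001})
class PinnedWorker:

    def node_ip(self) -> str:
        return get_node_ip_address()


worker = PinnedWorker.remote()
assert ray.get(worker.node_ip.remote()) == driver_ip
```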
Review comment on `head_node = cluster.add_node(num_cpus=8, num_gpus=1, resources={"head": 1})`:

Reviewer: Why? We don't have any guarantee on the number of GPUs we have for this 2-GPU test.

Author: Ray doesn't require real hardware CPUs when specifying resources like this; 8 is just an arbitrary number. (I think technically `num_cpus=1` should also work.)
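To make the logical-resources point concrete: `Cluster.add_node` only registers resource capacities with Ray's scheduler and does not inspect the hardware, so the simulated cluster reports whatever was declared. A small standalone sketch (not part of this diff):

```python
import ray
from ray.cluster_utils import Cluster

# Build a simulated 2-node cluster. The declared CPU/GPU counts are pure
# scheduler bookkeeping; no physical GPUs are needed just to register them.
cluster = Cluster()
head_node = cluster.add_node(num_cpus=1, num_gpus=1, resources={"head": 1})
cluster.add_node(num_cpus=1, num_gpus=1)

ray.init(address=head_node.address)

# Prints the declared totals (2 CPUs, 2 GPUs, plus the custom "head" resource),
# regardless of the hardware actually present on this machine.
print(ray.cluster_resources())

ray.shutdown()
cluster.shutdown()
```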
Review comment on the file:

Reviewer: Please clean up the tests and remove unnecessary code. This test does not need `TARGET_TEST_SUITE`, I think.
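If `TARGET_TEST_SUITE` were dropped as suggested, the parametrization might shrink to something like the sketch below (illustrative only; the final shape is up to the author):

```python
import pytest

from vllm.utils import cuda_device_count_stateless


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
    ("facebook/opt-125m", "ray"),
])
def test_multi_node_bad_topology(
    vllm_runner,
    model: str,
    distributed_executor_backend: str,
) -> None:
    # Same body as in the diff above, minus the
    # `assert test_suite == TARGET_TEST_SUITE` line.
    ...
```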