
[Ray backend] Better error when pg topology is bad. #7584

Merged
Changes from 3 commits
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -299,6 +299,7 @@ steps:
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
- pytest -v -s distributed/test_multi_node_topology.py

- label: Pipeline Parallelism Test # 23min
working_dir: "/vllm-workspace/tests"
78 changes: 78 additions & 0 deletions tests/distributed/test_multi_node_topology.py
@@ -0,0 +1,78 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run:
```sh
cd $VLLM_PATH/tests

pytest distributed/test_multi_node_topology.py
```
"""
import os

import pytest
import ray
from ray.cluster_utils import Cluster

from vllm.utils import cuda_device_count_stateless


TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
Member
please clean up the tests and remove unnecessary code. this test does not need TARGET_TEST_SUITE I think.
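A minimal sketch of the simplification the reviewer is suggesting (hypothetical, not part of this diff): drop TARGET_TEST_SUITE and parametrize only what the test actually uses.

```python
# Hypothetical simplified parametrization; values mirror the existing test.
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend",
                         [("facebook/opt-125m", "ray")])
def test_multi_node_bad_topology(vllm_runner, model: str,
                                 distributed_executor_backend: str) -> None:
    ...
```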



@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "model, distributed_executor_backend, test_suite", [
        ("facebook/opt-125m", "ray", "L4"),
    ])
def test_multi_node_bad_topology(
Member @youkaichao (Aug 19, 2024)
I think you should test it with a 2-node setup (where the ray cluster is already created), launch the test with 1 GPU from both the head node and the worker node, and make sure the vLLM instance is scheduled on the same node as the process that launched it.

Collaborator Author
Hmm, yeah, I think we need a real 2-node setup if we want to use node IP resources. Let me fix it.

    vllm_runner,
    model: str,
    distributed_executor_backend: str,
    test_suite: str,
) -> None:
    """Verify that a bad ray + multi-node topology raises an exception.

    This test simulates a multi-node ray cluster, so we don't have to start
    two real nodes.

    There are 2 potential bad cases:
    - The engine's node doesn't have enough GPUs.
    - The tensor parallel size exceeds the GPUs available on a single node.
    """
    dtype = "half"
    assert test_suite == TARGET_TEST_SUITE

    # Simulate a 2-node cluster, 1 GPU each.
    cluster = Cluster()
    head_node = cluster.add_node(num_cpus=8,
                                 num_gpus=1,
                                 resources={"head": 1})
Member
Why num_cpus=8 here? Is 8 some magic number?

We don't have any guarantee on the number of CPUs we have for this 2-GPU test.

Collaborator Author
Ray doesn't require real hardware CPUs when specifying resources like this; 8 is just an arbitrary number. (I think technically num_cpus=1 should also work.)
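For illustration, a standalone sketch (not part of the test, arbitrary values) of what the author means: the CPU/GPU counts passed to `Cluster.add_node()` are logical resources that Ray only does bookkeeping on, so they don't need to match the real hardware.

```python
# Standalone illustration of logical resources in ray.cluster_utils.Cluster.
import ray
from ray.cluster_utils import Cluster

mini_cluster = Cluster()
# One logical CPU is enough; the custom "head" resource lets tests pin
# actors or placement groups to this node.
node = mini_cluster.add_node(num_cpus=1, num_gpus=1, resources={"head": 1})
ray.init(address=node.address)
print(ray.cluster_resources())  # reports the logical CPU/GPU/"head" counts
```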

    ray.init(address=head_node.address)
    cluster.add_node(num_cpus=8, num_gpus=1)

    # Create tp == 2. Since the TP workers are supposed to be spread across
    # the 2 nodes, it should log a warning.
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=2,
                     distributed_executor_backend=distributed_executor_backend
                     ) as _:
        pass

    # Simulate that there's no free GPU on the current (head) node.
    @ray.remote(num_gpus=1, resources={"head": 1})
    class Actor:
        pass

    # The actor is created on the head node and occupies its only GPU.
    a = Actor.remote()  # type: ignore
    ray.get(a.__ray_ready__.remote())

    # Now vLLM is created from the head node, but it has no free GPU.
    # It should raise an exception.
    with pytest.raises(RuntimeError), vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=1,
            distributed_executor_backend=distributed_executor_backend
    ) as _:
        pass
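A rough sketch of the kind of check the reviewer asked for above (a real 2-node setup where the engine must land on the launching node). The helper name and `pg` argument are hypothetical, but the Ray calls mirror the ones used by `_verify_bundles` below.

```python
# Hypothetical helper for a real 2-node test: assert that the engine's
# placement group has at least one bundle on the node that launched it.
import ray
from ray.util import placement_group_table


def assert_pg_on_launching_node(pg) -> None:
    launching_node_id = ray.get_runtime_context().get_node_id()
    bundle_node_ids = placement_group_table(pg)["bundles_to_node_id"].values()
    assert launching_node_id in bundle_node_ids, (
        f"No bundle was placed on the launching node {launching_node_id}; "
        f"bundles ended up on {set(bundle_node_ids)}.")
```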
54 changes: 52 additions & 2 deletions vllm/executor/ray_utils.py
@@ -1,4 +1,8 @@
from typing import List, Optional, Tuple, Union
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup

from vllm.config import ParallelConfig
from vllm.logger import init_logger
@@ -83,6 +87,49 @@ def assert_ray_available():
"`pip install ray`.") from ray_import_err


def _verify_bundles(placement_group: PlacementGroup,
                    parallel_config: ParallelConfig):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules:
    - Warn if all tensor parallel workers cannot fit on a single node.
    - Fail if the driver node is not included in the placement group.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundles (e.g., [{"GPU": 1}])
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in the "
            f"placement group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available on the current node. Check "
            "`ray status` to see if you have available GPUs on node "
            f"{driver_node_id} before starting a vLLM engine.")

    for node_id, bundles in node_id_to_bundle.items():
        if len(bundles) < parallel_config.tensor_parallel_size:
            raise RuntimeError(
                f"tensor_parallel_size={parallel_config.tensor_parallel_size} "
                f"is bigger than the reserved number of GPUs ({len(bundles)} "
                f"GPUs) on node {node_id}. Tensor parallel workers will be "
                "spread across more than one node, which can degrade "
                "performance. To resolve this issue, make sure you have at "
                f"least {parallel_config.tensor_parallel_size} GPUs available "
                "on each node.")


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
@@ -144,12 +191,15 @@ def initialize_ray_cluster(
        placement_group_specs = ([{
            device_str: 1
        }] * parallel_config.world_size)
        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
            placement_group_specs, strategy="PACK")
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
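To make the new check concrete, here is a small, self-contained sketch (resource shapes assumed, node ids illustrative) of the data `placement_group_table` returns and that `_verify_bundles` groups per node:

```python
# Standalone sketch of the placement-group data _verify_bundles inspects.
import ray
from ray.util import placement_group_table

ray.init()
# "PACK" asks Ray to co-locate bundles on as few nodes as possible,
# which is why TP workers normally end up on one node.
pg = ray.util.placement_group([{"GPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready(), timeout=1800)

pg_data = placement_group_table(pg)
# e.g. {0: {"GPU": 1.0}, 1: {"GPU": 1.0}}
print(pg_data["bundles"])
# e.g. {0: "<node-a>", 1: "<node-a>"} when both bundles fit on one node
print(pg_data["bundles_to_node_id"])
```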