[Core][Distributed] add same-node detection (#5369)
youkaichao authored Jun 11, 2024
1 parent dcbf428 commit c4bd03c
Showing 4 changed files with 87 additions and 1 deletion.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -37,6 +37,7 @@ steps:
working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
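The new CI entry launches the test with torchrun across four local worker processes, with VLLM_TEST_SAME_HOST=1 telling the test to expect a same-node result. A minimal sketch (not part of the commit) of roughly the same launch done from Python instead of the torchrun CLI, assuming it is run from the vLLM tests directory so the relative script path resolves:

```python
# Hedged sketch: a programmatic equivalent of the CI command above, using
# torchrun's underlying entry point (torch.distributed.run). The 4-process
# count and the environment variable mirror the pipeline entry.
import os

from torch.distributed.run import main as torchrun_main

# All four workers run on this host, so the test should observe "same node".
os.environ["VLLM_TEST_SAME_HOST"] = "1"
torchrun_main(["--nproc-per-node=4", "distributed/test_same_node.py"])
```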
11 changes: 11 additions & 0 deletions tests/distributed/test_same_node.py
@@ -0,0 +1,11 @@
import os

import torch

from vllm.distributed.parallel_state import is_in_the_same_node

torch.distributed.init_process_group(backend="gloo")
test_result = is_in_the_same_node(torch.distributed.group.WORLD)

expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
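The test assumes it is launched under torchrun (see the pipeline change above), which sets up the rendezvous for init_process_group. For experimenting with is_in_the_same_node outside CI, here is a minimal sketch, not part of the commit, that spawns two local gloo workers itself; the TCP port is an arbitrary assumption:

```python
# Hedged sketch: exercising is_in_the_same_node without torchrun by spawning
# two gloo workers on this machine.
import torch.distributed as dist
import torch.multiprocessing as mp

from vllm.distributed.parallel_state import is_in_the_same_node


def _worker(rank: int, world_size: int) -> None:
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29501",
                            rank=rank,
                            world_size=world_size)
    # Both workers live on this host, so the collective check should agree
    # that the whole group shares one node.
    assert is_in_the_same_node(dist.group.WORLD)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, ), nprocs=2)
```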
9 changes: 8 additions & 1 deletion vllm/distributed/device_communicators/custom_all_reduce.py
@@ -10,7 +10,7 @@
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.parallel_state import (
get_local_rank, get_tensor_model_parallel_cpu_group)
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger

try:
@@ -113,6 +113,13 @@ def __init__(self,
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")

if not is_in_the_same_node(group):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return

rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
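The guard added to CustomAllreduce.__init__ turns the communicator into a no-op when the process group spans more than one node, since custom allreduce depends on intra-node GPU peer-to-peer access. A hedged sketch of how calling code could apply the same check up front; only is_in_the_same_node and get_tensor_model_parallel_cpu_group come from vLLM, while the helper name is illustrative and not part of this commit:

```python
# Hedged sketch reusing the new same-node check outside CustomAllreduce.
import torch.distributed as dist

from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_cpu_group, is_in_the_same_node)


def should_try_custom_allreduce() -> bool:
    # Illustrative helper, not vLLM API.
    cpu_group = get_tensor_model_parallel_cpu_group()
    # Custom allreduce relies on direct GPU peer access, which only exists
    # within a single node; a multi-rank, single-node group is the only case
    # worth initializing it for.
    return (dist.get_world_size(group=cpu_group) > 1
            and is_in_the_same_node(cpu_group))
```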
67 changes: 67 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -3,6 +3,8 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from multiprocessing import resource_tracker, shared_memory
from typing import List, Optional

import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
_PP_DEVICE_GROUP = None
global _PP_GLOBAL_RANKS
_PP_GLOBAL_RANKS = None


def is_in_the_same_node(pg: ProcessGroup):
"""
This is a collective operation that checks if all processes in the group
are in the same node. It tests if all processes are attached to the same
memory system (shared access to shared memory).
"""
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"is_in_the_same_node should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)

# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)

magic_message = b"magic_message"
shm = None

try:
with contextlib.suppress(OSError):
if rank == 0:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message
torch.distributed.broadcast_object_list([shm.name],
src=ranks[0],
group=pg)
is_in_the_same_node[0] = 1
else:
# try to open the shared memory segment
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=ranks[0],
group=pg)
name = recv[0]
shm = shared_memory.SharedMemory(name=name)
if shm.buf[:len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1
except Exception as e:
logger.error("Error ignored in is_in_the_same_node: %s", e)
finally:
if shm:
shm.close()

torch.distributed.barrier(group=pg)

# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == 0:
if shm:
shm.unlink()
else:
if shm:
# fix to https://stackoverflow.com/q/62748654/9191338
resource_tracker.unregister(
shm._name, "shared_memory") # type: ignore[attr-defined]
torch.distributed.all_reduce(is_in_the_same_node, group=pg)

return is_in_the_same_node.sum().item() == world_size
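
The function is a small shared-memory handshake: rank 0 creates a named segment, writes a magic value, and broadcasts the segment name; every other rank tries to attach to that name and compare the contents; the final all-reduce then tells every rank whether the whole group succeeded. The resource_tracker.unregister call works around the CPython behavior discussed in the Stack Overflow link above, where a process that merely attaches to a segment still registers it with its own resource tracker and may try to clean it up. A minimal, vLLM-independent sketch of that workaround (not part of the commit):

```python
# Hedged sketch of the resource_tracker workaround: the attaching process
# unregisters the segment so its tracker does not try to unlink a segment
# owned by the creator. Runs standalone, outside vLLM.
from multiprocessing import Process, resource_tracker, shared_memory


def attach_only(name: str) -> None:
    shm = shared_memory.SharedMemory(name=name)  # attach, do not create
    try:
        assert bytes(shm.buf[:5]) == b"magic"
    finally:
        shm.close()
        # Without this, this process's resource tracker warns about (and may
        # unlink) a segment it does not own.
        resource_tracker.unregister(shm._name, "shared_memory")


if __name__ == "__main__":
    owner = shared_memory.SharedMemory(create=True, size=16)
    owner.buf[:5] = b"magic"
    reader = Process(target=attach_only, args=(owner.name, ))
    reader.start()
    reader.join()
    owner.close()
    owner.unlink()
```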
