Skip to content

Commit

Permalink
add disable-custom-all-reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxu02 committed Aug 19, 2024
1 parent b997a18 commit eddb973
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class ServerArgs:
enable_mla: bool = False
attention_reduce_in_fp32: bool = False
efficient_weight_load: bool = False
disable_custom_all_reduce: bool = False

# Distributed args
nccl_init_addr: Optional[str] = None
Expand Down Expand Up @@ -428,6 +429,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=False,
help='Disable the custom all-reduce kernel and fall back to NCCL.')

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down

0 comments on commit eddb973

Please sign in to comment.