From 9be84c01a402e80419c1a9c1b4337db40979a3c0 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Wed, 8 Jan 2025 21:12:29 +0000
Subject: [PATCH 1/2] Revert "[distributed] remove pynccl's redundant
 change_state (#11749)"

This reverts commit 9e764e7b105a483ebc702cad33922ba8d8c210e1.
---
 tests/distributed/test_pynccl.py              | 64 +++++++++++--------
 .../device_communicators/pynccl.py            | 17 +++++
 vllm/distributed/parallel_state.py            |  9 ++-
 3 files changed, 62 insertions(+), 28 deletions(-)

diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a8571a1157892..a77b48d5e49f3 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
@@ -164,7 +167,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -201,7 +205,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -228,13 +233,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
@@ -265,12 +272,15 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fda4d007ceb5b..93d96fd8f5686 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -212,3 +213,19 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self, enable: Optional[bool] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        old_disable = self.disabled
+
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index a837c1dc5953b..dccd3addbcb35 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -305,7 +305,14 @@ def graph_capture(
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            yield graph_capture_context
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state()
+            with maybe_pynccl_context:
+                yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
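Note: the change_state context manager restored by patch 1/2 only saves and restores the communicator's disabled flag around the yield. The stand-alone sketch below mirrors that save/restore pattern with a hypothetical FakeComm class (not the real PyNcclCommunicator), so it runs without GPUs or NCCL; it is an illustration, not part of the patch.

    from contextlib import contextmanager
    from typing import Optional


    class FakeComm:
        """Hypothetical stand-in for PyNcclCommunicator (illustration only)."""

        def __init__(self, available: bool = True):
            self.available = available
            self.disabled = True

        @contextmanager
        def change_state(self, enable: Optional[bool] = None):
            # Same default rule as the restored code: fall back to
            # self.available when the caller does not pass enable.
            if enable is None:
                enable = self.available
            old_disable = self.disabled
            self.disabled = not enable
            yield
            # Restored only on a clean exit, matching the patch; a
            # try/finally would also restore it if the body raised.
            self.disabled = old_disable


    comm = FakeComm()
    assert comm.disabled
    with comm.change_state(enable=True):
        assert not comm.disabled  # collectives would be issued here
    assert comm.disabled

Patch 2/2 below extends the same pattern to a second piece of state, the communicator's stream.
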
From 177ad8538e18e5a59a219d844b7117a9e225a30c Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Wed, 8 Jan 2025 21:12:34 +0000
Subject: [PATCH 2/2] Revert "[distributed] remove pynccl's redundant stream
 (#11744)"

This reverts commit 635b897246da121238454ed4b2bbc87cb4d4166b.
---
 tests/distributed/test_pynccl.py              |  5 ++--
 .../device_communicators/pynccl.py            | 28 +++++++++++++------
 vllm/distributed/parallel_state.py            |  3 ++-
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a77b48d5e49f3..36cfe42251384 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -137,8 +137,9 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph), \
-            pynccl_comm.change_state(enable=True):
+    with torch.cuda.graph(
+            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 93d96fd8f5686..a6800f93f167b 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -51,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -59,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return
 
         self.available = True
@@ -96,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()
 
-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data
 
     def all_reduce(self,
@@ -120,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -142,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -163,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -178,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -190,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -202,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -215,7 +217,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))
 
     @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
         """
         A context manager to change the state of the communicator.
         """
@@ -223,9 +227,15 @@ def change_state(self,
             # guess a default value when not specified
             enable = self.available
 
+        if stream is None:
+            stream = self.stream
+
         old_disable = self.disabled
+        old_stream = self.stream
 
+        self.stream = stream
         self.disabled = not enable
         yield
 
         self.disabled = old_disable
+        self.stream = old_stream
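Note: with the dedicated stream restored by the pynccl.py hunks above, every collective falls back to self.stream when the caller passes no stream, and change_state now swaps the stream as well as the disabled flag. A rough sketch of that fallback and swap, again with hypothetical GPU-free stand-ins (FakeStream is not torch.cuda.Stream, and enable/disable handling is omitted to keep the focus on the stream):

    from contextlib import contextmanager
    from typing import Optional


    class FakeStream:
        """Hypothetical stand-in for torch.cuda.Stream (illustration only)."""

        def __init__(self, name: str):
            self.name = name


    class FakeComm:
        def __init__(self):
            # The patch creates torch.cuda.Stream() in __init__; a named
            # stand-in takes its place here.
            self.stream = FakeStream("comm-private")

        def pick_stream(self, stream: Optional[FakeStream] = None) -> FakeStream:
            # The fallback each collective now uses:
            # if stream is None: stream = self.stream
            return self.stream if stream is None else stream

        @contextmanager
        def change_state(self, stream: Optional[FakeStream] = None):
            if stream is None:
                stream = self.stream
            old_stream = self.stream
            self.stream = stream
            yield
            self.stream = old_stream


    comm = FakeComm()
    assert comm.pick_stream().name == "comm-private"
    with comm.change_state(stream=FakeStream("capture")):
        # inside the context every collective defaults to the swapped stream
        assert comm.pick_stream().name == "capture"
    assert comm.pick_stream().name == "comm-private"

The parallel_state.py hunk that follows is the caller that takes advantage of this swap.
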
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index dccd3addbcb35..a0d4235460f3b 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -310,7 +310,8 @@ def graph_capture(
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state()
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context
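
Note: after both reverts, GroupCoordinator.graph_capture pins the communicator to the stream under CUDA-graph capture by passing torch.cuda.current_stream() into change_state, and falls back to nullcontext() when the group has no pynccl communicator. A minimal sketch of that composition, with the same kind of hypothetical stand-in as above:

    from contextlib import contextmanager, nullcontext
    from typing import Any, Optional


    class FakeComm:
        """Hypothetical stand-in for PyNcclCommunicator (illustration only)."""

        def __init__(self):
            self.disabled = True
            self.stream = "comm-private"  # a torch.cuda.Stream in the real class

        @contextmanager
        def change_state(self, enable: Optional[bool] = None, stream=None):
            # The real default is enable = self.available; True stands in here.
            enable = True if enable is None else enable
            stream = self.stream if stream is None else stream
            old = (self.disabled, self.stream)
            self.disabled, self.stream = not enable, stream
            yield
            self.disabled, self.stream = old


    def maybe_pynccl_context(pynccl_comm: Optional[FakeComm],
                             capture_stream: str) -> Any:
        # Same shape as graph_capture after the revert: no communicator means
        # nullcontext; otherwise redirect the communicator to the stream the
        # CUDA graph is being captured on.
        if not pynccl_comm:
            return nullcontext()
        return pynccl_comm.change_state(stream=capture_stream)


    comm = FakeComm()
    with maybe_pynccl_context(comm, "capture"):
        assert not comm.disabled and comm.stream == "capture"
    assert comm.disabled and comm.stream == "comm-private"
    with maybe_pynccl_context(None, "capture"):
        pass  # no communicator: nothing to enable or swap

The design point is that collectives captured into the graph are issued on the capture stream, while the communicator's previous stream is restored as soon as capture ends.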