diff --git a/torchft/process_group.py b/torchft/process_group.py index 6e368f8..3a49e3c 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -124,6 +124,20 @@ def allgather( """ raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override + def allgather_into_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + """ + Performs an allgather operation on coalesced tensors. + + See torch.distributed.allgather_coalesced for more details. + """ + raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def allreduce( self, @@ -212,6 +226,20 @@ def reduce_scatter( """ raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> Work: + """ + Performs a reduce-scatter operation on coalesced tensors. + + See torch.distributed.reduce_scatter_tensor for more details. + """ + raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: """ @@ -336,10 +364,20 @@ def allgather( self, output_tensors: List[List[torch.Tensor]], input_tensor: List[torch.Tensor], - opts: object, + opts: AllgatherOptions, ) -> Work: return self.parent.allgather(output_tensors, input_tensor, opts) + def allgather_into_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + return self.parent.allgather_into_tensor_coalesced( + output_tensors, input_tensors, opts + ) + def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: return self.parent.allreduce(tensors, opts) @@ -377,6 +415,16 @@ def reduce_scatter( ) -> Work: return self.parent.reduce_scatter(output_tensors, input_tensors, opts) + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> Work: + return self.parent.reduce_scatter_tensor_coalesced( + output_tensors, input_tensors, opts + ) + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: return self.parent.send(tensors, dst_rank, tag) @@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None: self._timeout = timeout def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup: + pg = BaseProcessGroup(store, rank, world_size) + pg._set_default_backend(ProcessGroup.BackendType.GLOO) # pyre-fixme[16]: no attribute ProcessGroupGloo - return BaseProcessGroupGloo(store, rank, world_size, self._timeout) + backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout) + backend_class._set_sequence_number_for_group() + pg._register_backend( + torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class + ) + return pg def getBackendName(self) -> str: return "torchft-gloo" @@ -427,6 +482,28 @@ def reduce_scatter( """ raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.") + # pyre-fixme[15]: inconsistent override + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> None: + """ + This function is a placeholder for the reduce_scatter_tensor_coalesced + operation in 
the ProcessGroupGloo class. + However, this operation is not supported by the + Gloo backend, and thus, calling this function will raise a + RuntimeError. + + Raises: + RuntimeError: Always raised since reduce_scatter is not + supported by ProcessGroupGloo. + """ + raise RuntimeError( + "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced." + ) + class ProcessGroupNCCL(ProcessGroupWrapper): """ @@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper): """ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup: + pg = BaseProcessGroup(store, rank, world_size) + pg._set_default_backend(ProcessGroup.BackendType.NCCL) # pyre-fixme[16]: no attribute ProcessGroupNCCL - return BaseProcessGroupNCCL(store, rank, world_size) + backend_class = BaseProcessGroupNCCL(store, rank, world_size) + backend_class._set_sequence_number_for_group() + pg._register_backend( + torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class + ) + return pg def getBackendName(self) -> str: return "torchft-nccl" @@ -499,6 +583,19 @@ def allgather( self._work.append(res) return res + def allgather_into_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + for o, i in zip(output_tensors, input_tensors): + o.copy_(i) + + res = _DummyWork(output_tensors) + self._work.append(res) + return res + def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: res = _DummyWork(tensors) self._work.append(res) @@ -548,6 +645,19 @@ def reduce_scatter( self._work.append(res) return res + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> Work: + for o, i in zip(output_tensors, input_tensors): + o.copy_(i) + + res = _DummyWork(output_tensors) + self._work.append(res) + return res + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: return _DummyWork(None) @@ -1134,6 +1244,20 @@ def allgather( _maybe_share_tensors(input_tensor) return self._run_func("allgather", output_tensors, input_tensor, opts) + def allgather_into_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + _assert_list(output_tensors) + _assert_list(input_tensors) + _maybe_share_tensors(output_tensors) + _maybe_share_tensors(input_tensors) + return self._run_func( + "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts + ) + def allreduce( self, tensors: List[torch.Tensor], @@ -1200,6 +1324,20 @@ def reduce_scatter( _maybe_share_tensors(input_tensors) return self._run_func("reduce_scatter", output_tensors, input_tensors, opts) + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> Work: + _assert_list(output_tensors) + _assert_list(input_tensors) + _maybe_share_tensors(output_tensors) + _maybe_share_tensors(input_tensors) + return self._run_func( + "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts + ) + def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: _assert_list(tensors) _maybe_share_tensors(tensors) @@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby): @classmethod def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup: + pg = BaseProcessGroup(store, rank, world_size) + 
pg._set_default_backend(ProcessGroup.BackendType.GLOO) # pyre-fixme[16]: no attribute ProcessGroupGloo - return BaseProcessGroupGloo(store, rank, world_size) + backend_class = BaseProcessGroupGloo(store, rank, world_size) + pg._register_backend( + torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class + ) + return pg def getBackendName(self) -> str: return "torchft-baby-gloo" @@ -1303,6 +1447,28 @@ def reduce_scatter( """ raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.") + # pyre-fixme[15]: inconsistent override + def reduce_scatter_tensor_coalesced( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> None: + """ + This function is a placeholder for the reduce_scatter_tensor_coalesced + operation in the ProcessGroupBabyGloo class. + However, this operation is not supported by the + Gloo backend, and thus, calling this function will raise a + RuntimeError. + + Raises: + RuntimeError: Always raised since reduce_scatter is not + supported by ProcessGroupBabyGloo. + """ + raise RuntimeError( + "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced." + ) + class ProcessGroupBabyNCCL(ProcessGroupBaby): """ @@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): @classmethod def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup: + pg = BaseProcessGroup(store, rank, world_size) + pg._set_default_backend(ProcessGroup.BackendType.NCCL) # pyre-fixme[16]: no attribute ProcessGroupNCCL - return BaseProcessGroupNCCL(store, rank, world_size) + backend_class = BaseProcessGroupNCCL(store, rank, world_size) + backend_class._set_sequence_number_for_group() + pg._register_backend( + torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class + ) + return pg def getBackendName(self) -> str: return "torchft-baby-nccl" diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index b3ca93e..236f773 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -98,6 +98,10 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] ("allreduce", ([input_tensor], ReduceOp.SUM)), ("allreduce_coalesced", ([input_tensor], AllreduceCoalescedOptions())), ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), + ( + "allgather_into_tensor_coalesced", + (output_tensors[0], [input_tensor], AllgatherOptions()), + ), ( "alltoall_base", ( @@ -115,6 +119,10 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] "reduce_scatter", (output_tensors[0], [[input_tensor]], ReduceScatterOptions()), ), + ( + "reduce_scatter_tensor_coalesced", + (output_tensors[0], [input_tensor], ReduceScatterOptions()), + ), ] works: Dict[str, dist._Work] = {} @@ -166,6 +174,49 @@ def run_allgather_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> Non torch.testing.assert_close(output_list[r], expected) +def run_allgather_into_tensor_coalesced_test( + pg: ProcessGroup, rank: int, tensor: torch.Tensor +) -> None: + """Test allgather tensor coalesced collective operation. + + This example gathers two local tensors, T0 and T1, from each rank into corresponding + output tensors. + + For world_sz = n, each rank r has: + T0 = [r+1], + T1 = [r+10] + + After allgather_into_tensor_coalesced, we result in two tensors: out0, out1, + both length n. + + out0 gathers T0 from all ranks, out1 gathers T1 from all ranks. + + We verify that out0[k] == [k+1] and out1[k] == [k+10] for all k. 
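+
+    For example, with world_sz = 2:
+        rank 0 contributes T0 = [1] and T1 = [10]
+        rank 1 contributes T0 = [2] and T1 = [11]
+    so every rank ends up with out0 = [1, 2] and out1 = [10, 11].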
+
+    """
+    world_sz = pg.size()
+
+    if world_sz < 2:
+        return
+
+    t0 = torch.tensor([rank + 1], device=tensor.device, dtype=tensor.dtype)
+    t1 = torch.tensor([rank + 10], device=tensor.device, dtype=tensor.dtype)
+
+    out0 = torch.zeros(world_sz, device=tensor.device, dtype=tensor.dtype)
+    out1 = torch.zeros(world_sz, device=tensor.device, dtype=tensor.dtype)
+
+    work = pg.allgather_into_tensor_coalesced(
+        [out0, out1], [t0, t1], AllgatherOptions()
+    )
+    work.wait()
+
+    for r in range(world_sz):
+        expected0 = torch.tensor([r + 1], device=t0.device, dtype=t0.dtype)
+        torch.testing.assert_close(out0[r], expected0[0])
+        expected1 = torch.tensor([r + 10], device=t1.device, dtype=t1.dtype)
+        torch.testing.assert_close(out1[r], expected1[0])
+
+
 def run_allreduce_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
     """Test allreduce collective operation.
 
@@ -351,8 +402,87 @@ def run_reduce_scatter_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -
     torch.testing.assert_close(out, expected_sum)
 
 
+def run_reduce_scatter_tensor_coalesced_test(
+    pg: ProcessGroup, rank: int, tensor: torch.Tensor
+) -> None:
+    """Test reduce_scatter tensor coalesced collective operation.
+
+    We define two 2D tensors, each shaped [world_sz, world_sz] and replicated on every rank.
+
+    reduce_scatter_tensor_coalesced sums each tensor across ranks and then scatters
+    row r of the result to rank r. Because the inputs are replicated on all ranks,
+    rank r receives this row of m0:
+    [r*world_sz + 1, ..., r*world_sz + world_sz] * world_sz
+    (m1 follows the same pattern with rows starting at r*world_sz + 100).
+
+    For example, with 2 ranks:
+    rank 0 gets: [1, 2] * 2 = [2, 4] (first row)
+    rank 1 gets: [3, 4] * 2 = [6, 8] (second row)
+
+    """
+    world_sz = pg.size()
+    if world_sz < 2:
+        return  # skip the trivial single-rank case
+
+    # Build m0, m1 (each is a list of world_sz rows) fully replicated on all ranks
+    m0 = []
+    m1 = []
+    for r in range(world_sz):
+        row0 = torch.arange(
+            start=r * world_sz + 1,
+            end=r * world_sz + world_sz + 1,
+            device=tensor.device,
+            dtype=torch.float32,
+        )
+        row1 = torch.arange(
+            start=r * world_sz + 100,
+            end=r * world_sz + 100 + world_sz,
+            device=tensor.device,
+            dtype=torch.float32,
+        )
+        m0.append(row0)
+        m1.append(row1)
+
+    # Each rank receives one row of m0 and one row of m1 after the reduce-scatter
+    out0 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
+    out1 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
+
+    opts = ReduceScatterOptions()
+    opts.reduceOp = ReduceOp.SUM
+
+    m0 = torch.stack(m0)
+    m1 = torch.stack(m1)
+
+    work = pg.reduce_scatter_tensor_coalesced([out0, out1], [m0, m1], opts)
+    work.wait()
+
+    base0 = (
+        torch.arange(
+            start=rank * world_sz + 1,
+            end=rank * world_sz + world_sz + 1,
+            device=tensor.device,
+            dtype=torch.float32,
+        )
+        * world_sz
+    )
+    base1 = (
+        torch.arange(
+            start=rank * world_sz + 100,
+            end=rank * world_sz + 100 + world_sz,
+            device=tensor.device,
+            dtype=torch.float32,
+        )
+        * world_sz
+    )
+
+    torch.testing.assert_close(out0, base0)
+    torch.testing.assert_close(out1, base1)
+
+
 _COLLECTIVE_TO_FUNC: Dict[str, Callable[[ProcessGroup, int, torch.Tensor], None]] = {
     "allgather": run_allgather_test,
+    "allgather_into_tensor_coalesced": run_allgather_into_tensor_coalesced_test,
     "allreduce": run_allreduce_test,
     "allreduce_coalesced": run_allreduce_coalesced_test,
     "alltoall_base": run_alltoall_test,
@@ -360,13 +490,14 @@ def run_reduce_scatter_test(pg: ProcessGroup, rank: int,
tensor: torch.Tensor) - "broadcast": run_broadcast_test, "broadcast_one": run_broadcast_one_test, "reduce_scatter": run_reduce_scatter_test, + "reduce_scatter_tensor_coalesced": run_reduce_scatter_tensor_coalesced_test, "send/recv": run_send_recv_test, } _ALL_COLLECTIVES: List[str] = list(_COLLECTIVE_TO_FUNC.keys()) class ProcessGroupTest(TestCase): - def test_gloo(self) -> None: + def test_gloo_apis(self) -> None: store = TCPStore( host_name="localhost", port=0, is_master=True, wait_for_workers=False ) @@ -397,7 +528,7 @@ def test_gloo_timeout(self) -> None: # pyre-fixme[56]: Pyre was not able to infer the type of argument @skipUnless(torch.cuda.is_available(), "needs CUDA") - def test_nccl(self) -> None: + def test_nccl_apis(self) -> None: store = TCPStore( host_name="localhost", port=0, is_master=True, wait_for_workers=False ) @@ -790,6 +921,7 @@ class GlooMultiPgTest(MultiPgBaseTest): SKIP = [ "alltoall_base", "reduce_scatter", + "reduce_scatter_tensor_coalesced", ] COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP)) @@ -808,6 +940,7 @@ class BabyGlooMultiPgTest(MultiPgBaseTest): SKIP = [ "alltoall_base", "reduce_scatter", + "reduce_scatter_tensor_coalesced", ] COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP))
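
Usage sketch: a minimal example of calling the new coalesced collectives, assuming the single-rank TCPStore + Gloo setup used in the tests above and the existing torchft ProcessGroup.configure(store_addr, rank, world_size) API; adjust if your setup differs.

    import torch
    from torch._C._distributed_c10d import AllgatherOptions
    from torch.distributed import TCPStore

    from torchft.process_group import ProcessGroupGloo

    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    pg = ProcessGroupGloo()
    pg.configure(f"localhost:{store.port}/prefix", 0, 1)  # rank 0 of world size 1

    # allgather_into_tensor_coalesced takes one flat output per input, each sized
    # world_size * numel(input); with a single rank the outputs mirror the inputs.
    inputs = [torch.tensor([1.0]), torch.tensor([10.0])]
    outputs = [torch.zeros(1), torch.zeros(1)]
    pg.allgather_into_tensor_coalesced(outputs, inputs, AllgatherOptions()).wait()

    # reduce_scatter_tensor_coalesced has the same (outputs, inputs, opts) shape, but
    # the Gloo wrappers raise RuntimeError for it; use the NCCL wrappers
    # (ProcessGroupNCCL / ProcessGroupBabyNCCL) for that collective.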