Commit

Adds support for allgather_into_tensor_coalesced and `reduce_scatter_tensor_coalesced` (#114)

* initial commit to add final collectives

* adds tests, modifies process group creation to register a backend

* slight cleanups

---------

Co-authored-by: Allen Wang <[email protected]>
allenwang28 and Allen Wang authored Feb 21, 2025
1 parent c782f4e commit d427bef
Showing 2 changed files with 313 additions and 7 deletions.
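For orientation, a minimal usage sketch of the two new collectives follows. It is illustrative only: `pg` is assumed to be an already-configured torchft process group (ProcessGroupGloo works for the allgather; note the Gloo wrappers below deliberately raise for the reduce-scatter variant), running with two ranks.

# Sketch only: `pg` (a configured torchft ProcessGroup) and a 2-rank job are assumed.
import torch
from torch.distributed.distributed_c10d import AllgatherOptions, ReduceScatterOptions

world_size = 2

# allgather_into_tensor_coalesced: each output receives the gathered shards of
# the corresponding input from every rank, for all tensors in a single call.
inputs = [torch.rand(4), torch.rand(8)]
outputs = [torch.empty(4 * world_size), torch.empty(8 * world_size)]
pg.allgather_into_tensor_coalesced(outputs, inputs, AllgatherOptions()).wait()

# reduce_scatter_tensor_coalesced: each input is reduced across ranks and then
# scattered, so each output holds this rank's reduced shard.
inputs = [torch.rand(4 * world_size), torch.rand(8 * world_size)]
outputs = [torch.empty(4), torch.empty(8)]
pg.reduce_scatter_tensor_coalesced(outputs, inputs, ReduceScatterOptions()).wait()
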
183 changes: 178 additions & 5 deletions torchft/process_group.py
@@ -124,6 +124,20 @@ def allgather(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
"""
Performs an allgather operation on coalesced tensors.
See torch.distributed.allgather_coalesced for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allreduce(
self,
@@ -212,6 +226,20 @@ def reduce_scatter(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
"""
Performs a reduce-scatter operation on coalesced tensors.
See torch.distributed.reduce_scatter_tensor for more details.
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
"""
@@ -336,10 +364,20 @@ def allgather(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
- opts: object,
+ opts: AllgatherOptions,
) -> Work:
return self.parent.allgather(output_tensors, input_tensor, opts)

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
return self.parent.allgather_into_tensor_coalesced(
output_tensors, input_tensors, opts
)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)

@@ -377,6 +415,16 @@ def reduce_scatter(
) -> Work:
return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
return self.parent.reduce_scatter_tensor_coalesced(
output_tensors, input_tensors, opts
)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return self.parent.send(tensors, dst_rank, tag)

@@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
self._timeout = timeout

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
# pyre-fixme[16]: no attribute ProcessGroupGloo
- return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+ backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-gloo"
@@ -427,6 +482,28 @@ def reduce_scatter(
"""
raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")

# pyre-fixme[15]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter_tensor_coalesced
operation in the ProcessGroupGloo class.
However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.
Raises:
RuntimeError: Always raised since reduce_scatter_tensor_coalesced is not
supported by ProcessGroupGloo.
"""
raise RuntimeError(
"ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
)


class ProcessGroupNCCL(ProcessGroupWrapper):
"""
@@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
"""

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
# pyre-fixme[16]: no attribute ProcessGroupNCCL
- return BaseProcessGroupNCCL(store, rank, world_size)
+ backend_class = BaseProcessGroupNCCL(store, rank, world_size)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-nccl"
@@ -499,6 +583,19 @@ def allgather(
self._work.append(res)
return res

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
for o, i in zip(output_tensors, input_tensors):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensors)
self._work.append(res)
@@ -548,6 +645,19 @@ def reduce_scatter(
self._work.append(res)
return res

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
for o, i in zip(output_tensors, input_tensors):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
return _DummyWork(None)

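The dummy implementations above mirror ProcessGroupDummy's existing single-process semantics: each output tensor is simply a copy of the corresponding input, and a _DummyWork is returned so callers can still wait() on the result. A small sketch of that behavior, assuming ProcessGroupDummy keeps its existing (rank, world) constructor:

# Sketch only: on the dummy process group the coalesced collectives are identity copies.
import torch
from torch.distributed.distributed_c10d import ReduceScatterOptions

pg = ProcessGroupDummy(0, 1)  # (rank, world) assumed per the existing class
outs = [torch.empty(3), torch.empty(5)]
ins = [torch.ones(3), torch.full((5,), 2.0)]
pg.reduce_scatter_tensor_coalesced(outs, ins, ReduceScatterOptions()).wait()
assert all(torch.equal(o, i) for o, i in zip(outs, ins))
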
@@ -1134,6 +1244,20 @@ def allgather(
_maybe_share_tensors(input_tensor)
return self._run_func("allgather", output_tensors, input_tensor, opts)

def allgather_into_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
_assert_list(output_tensors)
_assert_list(input_tensors)
_maybe_share_tensors(output_tensors)
_maybe_share_tensors(input_tensors)
return self._run_func(
"allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
)

def allreduce(
self,
tensors: List[torch.Tensor],
@@ -1200,6 +1324,20 @@ def reduce_scatter(
_maybe_share_tensors(input_tensors)
return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> Work:
_assert_list(output_tensors)
_assert_list(input_tensors)
_maybe_share_tensors(output_tensors)
_maybe_share_tensors(input_tensors)
return self._run_func(
"reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
)

def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
_assert_list(tensors)
_maybe_share_tensors(tensors)
@@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
# pyre-fixme[16]: no attribute ProcessGroupGloo
- return BaseProcessGroupGloo(store, rank, world_size)
+ backend_class = BaseProcessGroupGloo(store, rank, world_size)
pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-baby-gloo"
@@ -1303,6 +1447,28 @@ def reduce_scatter(
"""
raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")

# pyre-fixme[15]: inconsistent override
def reduce_scatter_tensor_coalesced(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[torch.Tensor],
opts: ReduceScatterOptions,
) -> None:
"""
This function is a placeholder for the reduce_scatter_tensor_coalesced
operation in the ProcessGroupBabyGloo class.
However, this operation is not supported by the
Gloo backend, and thus, calling this function will raise a
RuntimeError.
Raises:
RuntimeError: Always raised since reduce_scatter_tensor_coalesced is not
supported by ProcessGroupBabyGloo.
"""
raise RuntimeError(
"ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
)


class ProcessGroupBabyNCCL(ProcessGroupBaby):
"""
@@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
# pyre-fixme[16]: no attribute ProcessGroupNCCL
- return BaseProcessGroupNCCL(store, rank, world_size)
+ backend_class = BaseProcessGroupNCCL(store, rank, world_size)
backend_class._set_sequence_number_for_group()
pg._register_backend(
torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
)
return pg

def getBackendName(self) -> str:
return "torchft-baby-nccl"