ProcessGroupBaby: support full suite of PG tests
d4l3k committed Jan 29, 2025
1 parent 4bdb8a7 commit b3f5b93
Showing 2 changed files with 143 additions and 18 deletions.
130 changes: 121 additions & 9 deletions torchft/process_group.py
@@ -19,8 +19,20 @@
 import logging
 import queue
 import threading
+from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -29,7 +41,6 @@
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    BroadcastOptions,
     DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
@@ -40,7 +51,14 @@
     get_rank,
     init_device_mesh,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+    _world,
+)
 from torch.futures import Future
 
 if TYPE_CHECKING:
@@ -54,6 +72,9 @@
 _FUTURE_EXCEPTION = "fut_exception"
 
 
+T = TypeVar("T")
+
+
 def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
     """
     Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +143,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         """
         Gathers tensors from the whole group in a list.
@@ -140,7 +163,9 @@ def allgather(
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
         """
         Broadcasts the tensor to the whole group.
@@ -567,6 +592,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
+    def __del__(self) -> None:
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
+
 
 class _BabyWorkNCCL(_BabyWork):
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +723,18 @@ def _worker(
                 cmd = op[0]
                 if cmd == "func":
                     func_name, args, kwargs = op[1:]
+                    args = _PickleSafeOptions.unsafe_args(args)
                     fn = getattr(pg, func_name)
                     work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
                     work[op_id].wait()
-                    del work[op_id]
                     tx.put(op_id)
+                elif cmd == "del":
+                    op_id: int = op[1]
+                    del work[op_id]
                 elif cmd == "future":
                     op_id: int = op[1]
 
@@ -775,7 +806,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None
 
-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        tx.put(
+            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            timeout=self._timeout,
+        )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +818,11 @@
             pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
         )
 
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Union[dist.AllreduceOptions, dist.ReduceOp],
+    ) -> Work:
         assert isinstance(tensors, list), "input must be list"
 
         for tensor in tensors:
@@ -793,10 +831,84 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:

         return self._run_func("allreduce", tensors, opts)
 
+    def allgather(
+        self,
+        output_tensors: List[List[torch.Tensor]],
+        input_tensor: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "input must be list"
+        assert isinstance(input_tensor, list), "input must be list"
+
+        for tensor_list in output_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+
+        for tensor in input_tensor:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("allgather", output_tensors, input_tensor, opts)
+
+    def broadcast(
+        self,
+        tensor_list: List[torch.Tensor],
+        opts: BroadcastOptions,
+    ) -> Work:
+        assert isinstance(tensor_list, list), "input must be list"
+
+        for tensor in tensor_list:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("broadcast", tensor_list, opts)
+
     def size(self) -> int:
         return self._world_size
 
 
+@dataclass
+class _PickleSafeOptions:
+    func: Callable[[], object]
+    fields: Dict[str, object]
+
+    @classmethod
+    def safe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.safe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.safe_args(arg) for arg in args]
+        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+            return cls.from_torch(args)
+        else:
+            return args
+
+    @classmethod
+    def unsafe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.unsafe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.unsafe_args(arg) for arg in args]
+        elif isinstance(args, cls):
+            return args.to_torch()
+        else:
+            return args
+
+    @classmethod
+    def from_torch(cls, opts: object) -> "_PickleSafeOptions":
+        return cls(
+            func=opts.__class__,
+            fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
+        )
+
+    def to_torch(self) -> object:
+        opts = self.func()
+        for k, v in self.fields.items():
+            setattr(opts, k, v)
+        return opts
+
+
 class ProcessGroupBabyGloo(ProcessGroupBaby):
     """
     This is a ProcessGroup that runs Gloo in a subprocess.
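An illustrative aside (not part of the commit): the torch option objects (AllreduceOptions, AllgatherOptions, BroadcastOptions) are typically not picklable onto the worker's multiprocessing queue, which is why _run_func and _worker above route them through _PickleSafeOptions, flattening each one into a plain dataclass before the hop and rebuilding it on the other side. Below is a minimal sketch of that round trip, assuming torchft is installed and using the private _PickleSafeOptions helper purely for illustration.

# Hypothetical round-trip sketch; _PickleSafeOptions is private API and may change.
from datetime import timedelta

from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.process_group import _PickleSafeOptions

opts = AllreduceOptions()
opts.timeout = timedelta(seconds=30)

# safe_args recursively swaps option objects for plain dataclasses the parent can enqueue.
safe = _PickleSafeOptions.safe_args((opts,))
# unsafe_args rebuilds a real AllreduceOptions on the worker side of the queue.
(restored,) = _PickleSafeOptions.unsafe_args(safe)

assert isinstance(restored, AllreduceOptions)
assert restored.timeout == timedelta(seconds=30)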
31 changes: 22 additions & 9 deletions torchft/process_group_test.py
@@ -86,14 +86,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
                 check_tensors(item)
 
     # Test collectives
-    collectives = {
-        "allreduce": ([input_tensor], AllreduceOptions()),
-        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
-        "broadcast": (tensor_list, BroadcastOptions()),
-        "broadcast_one": (input_tensor, 0),
-    }
+    collectives = [
+        ("allreduce", ([input_tensor], AllreduceOptions())),
+        ("allreduce", ([input_tensor], ReduceOp.SUM)),
+        ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
+        ("broadcast", (tensor_list, BroadcastOptions())),
+        ("broadcast_one", (input_tensor, 0)),
+    ]
     works: Dict[str, dist._Work] = {}
-    for coll_str, args in collectives.items():
+    for coll_str, args in collectives:
         coll = getattr(pg, coll_str)
         work = coll(*args)
         works[coll_str] = work
@@ -246,6 +247,18 @@ def test_reconfigure_baby_process_group(self) -> None:
         assert p_2 is not None
         self.assertTrue(p_2.is_alive())
 
+    def test_baby_gloo_opts(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -367,8 +380,8 @@ def test_managed_process_group(self) -> None:
         self.assertIsInstance(list(works.values())[0], _ManagedWork)
 
         self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 1)
-        self.assertEqual(manager.wait_quorum.call_count, 1)
+        self.assertEqual(manager.wrap_future.call_count, 2)
+        self.assertEqual(manager.wait_quorum.call_count, 2)
 
 
 class DeviceMeshTest(TestCase):
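For context, a hedged usage sketch mirroring test_baby_gloo_opts above: it drives the newly covered allgather and broadcast paths through a single-rank subprocess-backed Gloo group so the options objects cross the queue. It assumes torchft is installed; the single-rank setup is only to keep the snippet self-contained.

# Sketch only; mirrors the shape of test_baby_gloo_opts rather than prescribing an API.
from datetime import timedelta

import torch
from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import AllgatherOptions, BroadcastOptions

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", 0, 1)  # store_addr, rank, world_size

inp = torch.rand(2, 3)
out = [[torch.empty_like(inp) for _ in range(pg.size())]]
pg.allgather(out, [inp], AllgatherOptions()).wait()  # options are made pickle-safe in _run_func
torch.testing.assert_close(out[0][0], inp)

buf = [inp.clone()]
pg.broadcast(buf, BroadcastOptions()).wait()
torch.testing.assert_close(buf[0], inp)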