From 4dd79a712f249d147e6a595be9db4df3b5d00959 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 29 Jan 2025 14:47:07 -0800 Subject: [PATCH] ProcessGroupBaby: support full suite of PG tests --- torchft/process_group.py | 140 +++++++++++++++++++++++++++++++--- torchft/process_group_test.py | 84 +++++++++++++++++--- 2 files changed, 206 insertions(+), 18 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index 0afc7d3..0110c6e 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -19,8 +19,21 @@ import logging import queue import threading +from dataclasses import dataclass from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) import torch import torch.distributed as dist @@ -29,7 +42,6 @@ # pyre-fixme[21]: no attribute ProcessGroupNCCL # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( - BroadcastOptions, DeviceMesh, PrefixStore, ProcessGroup as BaseProcessGroup, @@ -40,7 +52,14 @@ get_rank, init_device_mesh, ) -from torch.distributed.distributed_c10d import Work, _world +from torch.distributed.distributed_c10d import ( + AllgatherOptions, + AllreduceOptions, + BroadcastOptions, + ReduceOp, + Work, + _world, +) from torch.futures import Future if TYPE_CHECKING: @@ -54,6 +73,9 @@ _FUTURE_EXCEPTION = "fut_exception" +T = TypeVar("T") + + def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object: """ Gets an item from a queue with a timeout. If the timeout is exceeded then @@ -122,7 +144,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + def allreduce( + self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + ) -> Work: raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override @@ -130,7 +154,7 @@ def allgather( self, output_tensors: List[List[torch.Tensor]], input_tensor: List[torch.Tensor], - opts: object, + opts: AllgatherOptions, ) -> Work: """ Gathers tensors from the whole group in a list. @@ -140,7 +164,9 @@ def allgather( raise NotImplementedError("not implemented") # pyre-fixme[14]: inconsistent override - def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: + def broadcast( + self, tensor_list: List[torch.Tensor], opts: BroadcastOptions + ) -> Work: """ Broadcasts the tensor to the whole group. 
@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool: def get_future(self) -> Future[object]: return self._pg._get_future(self._op_id) + def __del__(self) -> None: + self._tx.put(("del", self._op_id), timeout=self._timeout) + class _BabyWorkNCCL(_BabyWork): def wait(self, timeout: Optional[timedelta] = None) -> bool: @@ -695,6 +724,7 @@ def _worker( cmd = op[0] if cmd == "func": func_name, args, kwargs = op[1:] + args = _PickleSafeOptions.unsafe_args(args) fn = getattr(pg, func_name) work[next_op_id] = fn(*args, **kwargs) tx.put(next_op_id) @@ -702,8 +732,10 @@ def _worker( elif cmd == "wait": op_id: int = op[1] work[op_id].wait() - del work[op_id] tx.put(op_id) + elif cmd == "del": + op_id: int = op[1] + del work[op_id] elif cmd == "future": op_id: int = op[1] @@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None: del work[op_id] tx.put((op_id, event)) + elif cmd == "num_active_work": + tx.put(len(work)) else: raise ValueError(f"unknown cmd: {cmd}") @@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: assert rx is not None assert tx is not None - tx.put(("func", func, args, kwargs), timeout=self._timeout) + tx.put( + ("func", func, _PickleSafeOptions.safe_args(args), kwargs), + timeout=self._timeout, + ) op_id = _get(rx, self._timeout) assert isinstance(op_id, int), f"invalid return {op_id}" @@ -784,7 +821,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout ) - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + def allreduce( + self, + tensors: List[torch.Tensor], + opts: Union[dist.AllreduceOptions, dist.ReduceOp], + ) -> Work: assert isinstance(tensors, list), "input must be list" for tensor in tensors: @@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: return self._run_func("allreduce", tensors, opts) + def allgather( + self, + output_tensors: List[List[torch.Tensor]], + input_tensor: List[torch.Tensor], + opts: AllgatherOptions, + ) -> Work: + assert isinstance(output_tensors, list), "input must be list" + assert isinstance(input_tensor, list), "input must be list" + + for tensor_list in output_tensors: + for tensor in tensor_list: + if not tensor.is_shared(): + tensor.share_memory_() + + for tensor in input_tensor: + if not tensor.is_shared(): + tensor.share_memory_() + + return self._run_func("allgather", output_tensors, input_tensor, opts) + + def broadcast( + self, + tensor_list: List[torch.Tensor], + opts: BroadcastOptions, + ) -> Work: + assert isinstance(tensor_list, list), "input must be list" + + for tensor in tensor_list: + if not tensor.is_shared(): + tensor.share_memory_() + + return self._run_func("broadcast", tensor_list, opts) + def size(self) -> int: return self._world_size + def num_active_work(self) -> int: + assert self._tx is not None + self._tx.put(("num_active_work",), timeout=self._timeout) + + assert self._rx is not None + return cast(int, _get(self._rx, self._timeout)) + + +@dataclass +class _PickleSafeOptions: + func: Callable[[], object] + fields: Dict[str, object] + + @classmethod + def safe_args(cls, args: T) -> T: + if isinstance(args, tuple): + return tuple(cls.safe_args(arg) for arg in args) + elif isinstance(args, list): + return [cls.safe_args(arg) for arg in args] + elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)): + return cls.from_torch(args) + else: + return args + + @classmethod + def 
unsafe_args(cls, args: T) -> T: + if isinstance(args, tuple): + return tuple(cls.unsafe_args(arg) for arg in args) + elif isinstance(args, list): + return [cls.unsafe_args(arg) for arg in args] + elif isinstance(args, cls): + return args.to_torch() + else: + return args + + @classmethod + def from_torch(cls, opts: object) -> "_PickleSafeOptions": + return cls( + func=opts.__class__, + fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")}, + ) + + def to_torch(self) -> object: + opts = self.func() + for k, v in self.fields.items(): + setattr(opts, k, v) + return opts + class ProcessGroupBabyGloo(ProcessGroupBaby): """ diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 75a3e53..6abeb39 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc import io import multiprocessing import os @@ -87,14 +88,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] check_tensors(item) # Test collectives - collectives = { - "allreduce": ([input_tensor], AllreduceOptions()), - "allgather": (output_tensors, [input_tensor], AllgatherOptions()), - "broadcast": (tensor_list, BroadcastOptions()), - "broadcast_one": (input_tensor, 0), - } + collectives = [ + ("allreduce", ([input_tensor], AllreduceOptions())), + ("allreduce", ([input_tensor], ReduceOp.SUM)), + ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), + ("broadcast", (tensor_list, BroadcastOptions())), + ("broadcast_one", (input_tensor, 0)), + ] works: Dict[str, dist._Work] = {} - for coll_str, args in collectives.items(): + for coll_str, args in collectives: coll = getattr(pg, coll_str) work = coll(*args) works[coll_str] = work @@ -247,6 +249,23 @@ def test_reconfigure_baby_process_group(self) -> None: assert p_2 is not None self.assertTrue(p_2.is_alive()) + def test_baby_gloo_apis(self) -> None: + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + + store_addr = f"localhost:{store.port}/prefix" + + a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10)) + a.configure(store_addr, 0, 1) + + _test_pg(a) + + # force collection to ensure no BabyWork objects remain + gc.collect() + + self.assertEqual(a.num_active_work(), 0) + def test_dummy(self) -> None: pg = ProcessGroupDummy(0, 1) m = nn.Linear(3, 4) @@ -368,5 +387,52 @@ def test_managed_process_group(self) -> None: self.assertIsInstance(list(works.values())[0], _ManagedWork) self.assertEqual(manager.report_error.call_count, 0) - self.assertEqual(manager.wrap_future.call_count, 1) - self.assertEqual(manager.wait_quorum.call_count, 1) + self.assertEqual(manager.wrap_future.call_count, 2) + self.assertEqual(manager.wait_quorum.call_count, 2) + + +class DeviceMeshTest(TestCase): + @staticmethod + def _test_init_device_mesh(world_size: int, rank: int) -> None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12346) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(4) + + testcase = TestCase() + + manager = Mock(spec=Manager) + # Even though we only have 4 workers, we can still initialize (2, 4) mesh. + # That's because the replicate group is NOT phystically created in the + # real mesh but is virtually added to the mesh via ManagedDeviceMesh. 
+ device_mesh = ft_init_device_mesh( + device_type="cpu", + mesh_shape=(2, world_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + replicate_dim=0, + manager=manager, + ) + + testcase.assertTrue( + isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup) + ) + testcase.assertTrue( + not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup) + ) + replicate_group = device_mesh.get_group("dp_replicate") + testcase.assertEqual( + cast(ManagedProcessGroup, replicate_group)._manager, manager + ) + replicate_mesh = device_mesh["dp_replicate"] + testcase.assertEqual(replicate_mesh.get_group(), replicate_group) + flatten_mesh = device_mesh._flatten("dp") + manager.num_participants.return_value = 1 + testcase.assertEqual(flatten_mesh.size(), world_size) + testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank()) + + def test_init_device_mesh(self) -> None: + with ProcessPoolExecutor(max_workers=4) as executor: + futures = [] + for i in range(4): + future = executor.submit(self._test_init_device_mesh, 4, i) + futures.append(future)
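
Reviewer note (not part of the patch): the key mechanism here is that the c10d option objects (AllreduceOptions, AllgatherOptions, BroadcastOptions) are pybind11 types without pickle support, so they cannot cross the multiprocessing queue to the baby worker as-is. _PickleSafeOptions captures the option class plus its public fields in a plain dataclass and rebuilds the real object on the worker side. Below is a minimal sketch of that round trip, assuming the patched torchft.process_group above is importable; pickle.dumps/loads stands in for the mp.Queue, which pickles internally.

    import pickle
    from datetime import timedelta

    from torch.distributed.distributed_c10d import BroadcastOptions

    from torchft.process_group import _PickleSafeOptions

    opts = BroadcastOptions()
    opts.rootRank = 1
    opts.timeout = timedelta(seconds=30)

    # _run_func wraps args before putting them on the queue; _worker unwraps
    # them with unsafe_args before calling into the real ProcessGroup.
    safe = _PickleSafeOptions.safe_args((opts,))
    restored = _PickleSafeOptions.unsafe_args(pickle.loads(pickle.dumps(safe)))

    assert isinstance(restored[0], BroadcastOptions)
    assert restored[0].rootRank == 1

The tensor arguments take a different path: the new allgather and broadcast implementations move them into shared memory via share_memory_() so the worker process operates on the same storage rather than on pickled copies. Work-object lifetime is also made explicit: _BabyWork.__del__ now sends a "del" command so the worker can drop its Work handle, and num_active_work lets test_baby_gloo_apis assert that no handles remain after gc.collect().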