ProcessGroupBaby: support full suite of PG tests #89

Merged · 1 commit · Jan 29, 2025
140 changes: 131 additions & 9 deletions torchft/process_group.py
@@ -19,8 +19,21 @@
import logging
import queue
import threading
from dataclasses import dataclass
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)

import torch
import torch.distributed as dist
@@ -29,7 +42,6 @@
# pyre-fixme[21]: no attribute ProcessGroupNCCL
# pyre-fixme[21]: no attribute ProcessGroupGloo
from torch.distributed import (
BroadcastOptions,
DeviceMesh,
PrefixStore,
ProcessGroup as BaseProcessGroup,
@@ -40,7 +52,14 @@
get_rank,
init_device_mesh,
)
from torch.distributed.distributed_c10d import Work, _world
from torch.distributed.distributed_c10d import (
AllgatherOptions,
AllreduceOptions,
BroadcastOptions,
ReduceOp,
Work,
_world,
)
from torch.futures import Future

if TYPE_CHECKING:
@@ -54,6 +73,9 @@
_FUTURE_EXCEPTION = "fut_exception"


T = TypeVar("T")


def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
"""
Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +144,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
def allreduce(
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
) -> Work:
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def allgather(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
opts: object,
opts: AllgatherOptions,
) -> Work:
"""
Gathers tensors from the whole group in a list.
@@ -140,7 +164,9 @@ def allgather(
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
def broadcast(
self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
) -> Work:
"""
Broadcasts the tensor to the whole group.

@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
def get_future(self) -> Future[object]:
return self._pg._get_future(self._op_id)

def __del__(self) -> None:
self._tx.put(("del", self._op_id), timeout=self._timeout)


class _BabyWorkNCCL(_BabyWork):
def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +724,18 @@ def _worker(
cmd = op[0]
if cmd == "func":
func_name, args, kwargs = op[1:]
args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)
work[next_op_id] = fn(*args, **kwargs)
tx.put(next_op_id)
next_op_id += 1
elif cmd == "wait":
op_id: int = op[1]
work[op_id].wait()
del work[op_id]
tx.put(op_id)
elif cmd == "del":
op_id: int = op[1]
del work[op_id]
Contributor: Should we wait first before deleting it?

Member Author: I don't think we need to? In normal situations the user code will have waited for it, or they will have gotten a future. If they have a future, we still have a reference and don't need it in this map anymore.

The only other common case would be if an error occurred and we're cleaning things up/shutting down -- in that case waiting likely won't succeed, so there's no reason to wait.
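To make the lifecycle concrete, here is a minimal sketch of the "del" path under discussion (it mirrors the `test_baby_gloo_apis` test added in this PR; the store setup and tensor shape are illustrative only):

```python
# Sketch only: exercises the ("del", op_id) cleanup path added in this PR.
import gc
from datetime import timedelta

import torch
import torch.distributed as dist
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)

pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.rand(2, 3)
work = pg.allreduce([t], dist.ReduceOp.SUM)
work.wait()

# Dropping the last reference runs _BabyWork.__del__, which enqueues ("del", op_id)
# so the worker can drop its handle without waiting on it again.
del work
gc.collect()

assert pg.num_active_work() == 0
```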

elif cmd == "future":
op_id: int = op[1]

@@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None:

del work[op_id]
tx.put((op_id, event))
elif cmd == "num_active_work":
tx.put(len(work))
else:
raise ValueError(f"unknown cmd: {cmd}")

@@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
assert rx is not None
assert tx is not None

tx.put(("func", func, args, kwargs), timeout=self._timeout)
tx.put(
("func", func, _PickleSafeOptions.safe_args(args), kwargs),
timeout=self._timeout,
)

op_id = _get(rx, self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +821,11 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
def allreduce(
self,
tensors: List[torch.Tensor],
opts: Union[dist.AllreduceOptions, dist.ReduceOp],
) -> Work:
assert isinstance(tensors, list), "input must be list"

for tensor in tensors:
@@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:

return self._run_func("allreduce", tensors, opts)

def allgather(
self,
output_tensors: List[List[torch.Tensor]],
input_tensor: List[torch.Tensor],
opts: AllgatherOptions,
) -> Work:
assert isinstance(output_tensors, list), "input must be list"
assert isinstance(input_tensor, list), "input must be list"

for tensor_list in output_tensors:
for tensor in tensor_list:
if not tensor.is_shared():
tensor.share_memory_()

for tensor in input_tensor:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("allgather", output_tensors, input_tensor, opts)

def broadcast(
self,
tensor_list: List[torch.Tensor],
opts: BroadcastOptions,
) -> Work:
assert isinstance(tensor_list, list), "input must be list"

for tensor in tensor_list:
if not tensor.is_shared():
tensor.share_memory_()

return self._run_func("broadcast", tensor_list, opts)

def size(self) -> int:
return self._world_size

def num_active_work(self) -> int:
assert self._tx is not None
self._tx.put(("num_active_work",), timeout=self._timeout)

assert self._rx is not None
return cast(int, _get(self._rx, self._timeout))


@dataclass
class _PickleSafeOptions:
func: Callable[[], object]
fields: Dict[str, object]

@classmethod
def safe_args(cls, args: T) -> T:
if isinstance(args, tuple):
return tuple(cls.safe_args(arg) for arg in args)
elif isinstance(args, list):
return [cls.safe_args(arg) for arg in args]
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
return cls.from_torch(args)
else:
return args

@classmethod
def unsafe_args(cls, args: T) -> T:
if isinstance(args, tuple):
return tuple(cls.unsafe_args(arg) for arg in args)
elif isinstance(args, list):
return [cls.unsafe_args(arg) for arg in args]
elif isinstance(args, cls):
return args.to_torch()
else:
return args

@classmethod
def from_torch(cls, opts: object) -> "_PickleSafeOptions":
return cls(
func=opts.__class__,
fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
)

def to_torch(self) -> object:
opts = self.func()
for k, v in self.fields.items():
setattr(opts, k, v)
return opts


class ProcessGroupBabyGloo(ProcessGroupBaby):
"""
84 changes: 75 additions & 9 deletions torchft/process_group_test.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import io
import multiprocessing
import os
@@ -87,14 +88,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
check_tensors(item)

# Test collectives
collectives = {
"allreduce": ([input_tensor], AllreduceOptions()),
"allgather": (output_tensors, [input_tensor], AllgatherOptions()),
"broadcast": (tensor_list, BroadcastOptions()),
"broadcast_one": (input_tensor, 0),
}
collectives = [
("allreduce", ([input_tensor], AllreduceOptions())),
("allreduce", ([input_tensor], ReduceOp.SUM)),
("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
("broadcast", (tensor_list, BroadcastOptions())),
("broadcast_one", (input_tensor, 0)),
]
works: Dict[str, dist._Work] = {}
for coll_str, args in collectives.items():
for coll_str, args in collectives:
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work
@@ -247,6 +249,23 @@ def test_reconfigure_baby_process_group(self) -> None:
assert p_2 is not None
self.assertTrue(p_2.is_alive())

def test_baby_gloo_apis(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
a.configure(store_addr, 0, 1)

_test_pg(a)

# force collection to ensure no BabyWork objects remain
gc.collect()

self.assertEqual(a.num_active_work(), 0)

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
@@ -368,5 +387,52 @@ def test_managed_process_group(self) -> None:
self.assertIsInstance(list(works.values())[0], _ManagedWork)

self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)
self.assertEqual(manager.wait_quorum.call_count, 1)
self.assertEqual(manager.wrap_future.call_count, 2)
self.assertEqual(manager.wait_quorum.call_count, 2)


class DeviceMeshTest(TestCase):
@staticmethod
def _test_init_device_mesh(world_size: int, rank: int) -> None:
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(12346)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(4)

testcase = TestCase()

manager = Mock(spec=Manager)
# Even though we only have 4 workers, we can still initialize a (2, 4) mesh.
# That's because the replicate group is NOT physically created in the
# real mesh but is virtually added to the mesh via ManagedDeviceMesh.
device_mesh = ft_init_device_mesh(
device_type="cpu",
mesh_shape=(2, world_size),
mesh_dim_names=("dp_replicate", "dp_shard"),
replicate_dim=0,
manager=manager,
)

testcase.assertTrue(
isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
)
testcase.assertTrue(
not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
)
replicate_group = device_mesh.get_group("dp_replicate")
testcase.assertEqual(
cast(ManagedProcessGroup, replicate_group)._manager, manager
)
replicate_mesh = device_mesh["dp_replicate"]
testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
flatten_mesh = device_mesh._flatten("dp")
manager.num_participants.return_value = 1
testcase.assertEqual(flatten_mesh.size(), world_size)
testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())

def test_init_device_mesh(self) -> None:
with ProcessPoolExecutor(max_workers=4) as executor:
futures = []
for i in range(4):
future = executor.submit(self._test_init_device_mesh, 4, i)
futures.append(future)