Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how TorchFT manages user_state_dict #87

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
# We have to set weights_only to True as there are some non-tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# We have to set weights_only to True as there are some non-tensor
# We have to set weights_only to False as there are some non-tensor

# states like lr_scheduler.
return torch.load(reader, weights_only=False)

def address(self) -> str:
"""
Expand Down
17 changes: 11 additions & 6 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar

import torch
from torch.distributed import ReduceOp, TCPStore
Expand Down Expand Up @@ -87,8 +87,8 @@ class Manager:
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
load_state_dict: Optional[Callable[[T], None]],
state_dict: Optional[Callable[[], T]],
min_replica_size: int,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
transfering checkpoints to recovering replicas
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
self._user_state_dict = state_dict
self._pending_state_dict: Optional[Dict[str, object]] = None
self._use_async_quorum = use_async_quorum
self._timeout = timeout
Expand All @@ -159,8 +159,6 @@ def __init__(
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

self._user_state_dict = state_dict

if checkpoint_transport is None:
checkpoint_transport = CheckpointServer[Dict[str, T]](
timeout=timeout,
Expand Down Expand Up @@ -226,6 +224,12 @@ def __init__(
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

def set_state_dict_fns(
self, load_state_dict: Callable[T, None], state_dict: Callable[[], T]
) -> None:
self._load_state_dict = load_state_dict
self._user_state_dict = state_dict

def shutdown(self, wait: bool = True) -> None:
"""
Shutdown the manager and checkpoint server.
Expand Down Expand Up @@ -533,6 +537,7 @@ def _apply_pending_state_dict(self) -> None:
assert self._pending_state_dict is not None, "checkpoint was not staged"
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None
self._logger.info("Loaded state dict.")

def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
Expand Down
37 changes: 34 additions & 3 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from datetime import timedelta
from typing import Optional
from unittest import TestCase
from unittest.mock import MagicMock, create_autospec, patch
from unittest.mock import create_autospec, MagicMock, patch

import torch
from torch.distributed import TCPStore

from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
from torchft.process_group import ProcessGroup, _DummyWork
from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode
from torchft.process_group import _DummyWork, ProcessGroup
from torchft.torchft import QuorumResult


Expand Down Expand Up @@ -95,6 +95,37 @@ def test_state_dict(self, client_mock: MagicMock) -> None:
self.assertEqual(manager.current_step(), 1234)
self.assertEqual(manager.batches_committed(), 2345)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_user_state_dict(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

self.assertEqual(
manager._manager_state_dict(),
{
"user": {},
"torchft": {
"step": 0,
"batches_committed": 0,
},
},
)

manager.set_state_dict_fns(
self.load_state_dict,
lambda: {"new_state": 1},
)

self.assertEqual(
manager._manager_state_dict(),
{
"user": {"new_state": 1},
"torchft": {
"step": 0,
"batches_committed": 0,
},
},
)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_happy(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
Expand Down
Loading