From 369b64ebe9db5aa8e9b20fcf7ea202ee6efc4889 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 11:28:32 -0800 Subject: [PATCH 1/5] Change how TorchFT manages user_state_dict This PR closes some state_dict gaps when integrating with TorchTitan: 1. User state_dict() and load_state_dict() functions can be initialized lazily. 2. Change weights_only to False for torch.load as we may have to load some non-tensor states. --- torchft/checkpointing.py | 4 +++- torchft/manager.py | 17 +++++++++++------ torchft/manager_test.py | 37 ++++++++++++++++++++++++++++++++++--- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 48a5d51..bc3eb7b 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: data = f.read() reader = io.BytesIO(data) - return torch.load(reader, weights_only=True) + # We have to set weights_only to True as there are some non-tensor + # states like lr_scheduler. + return torch.load(reader, weights_only=False) def address(self) -> str: """ diff --git a/torchft/manager.py b/torchft/manager.py index dc5ab30..9c4b159 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -33,7 +33,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast +from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar import torch from torch.distributed import ReduceOp, TCPStore @@ -87,8 +87,8 @@ class Manager: def __init__( self, pg: "ProcessGroup", - load_state_dict: Callable[[T], None], - state_dict: Callable[[], T], + load_state_dict: Optional[Callable[[T], None]], + state_dict: Optional[Callable[[], T]], min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = timedelta(seconds=60), @@ -144,7 +144,7 @@ def __init__( transfering checkpoints to recovering replicas """ self._load_state_dict = load_state_dict - self._state_dict = state_dict + self._user_state_dict = state_dict self._pending_state_dict: Optional[Dict[str, object]] = None self._use_async_quorum = use_async_quorum self._timeout = timeout @@ -159,8 +159,6 @@ def __init__( world_size = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size - self._user_state_dict = state_dict - if checkpoint_transport is None: checkpoint_transport = CheckpointServer[Dict[str, T]]( timeout=timeout, @@ -226,6 +224,12 @@ def __init__( self._participating_rank: Optional[int] = None self._participating_world_size: int = 0 + def set_state_dict_fns( + self, load_state_dict: Callable[T, None], state_dict: Callable[[], T] + ) -> None: + self._load_state_dict = load_state_dict + self._user_state_dict = state_dict + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. @@ -533,6 +537,7 @@ def _apply_pending_state_dict(self) -> None: assert self._pending_state_dict is not None, "checkpoint was not staged" self._load_state_dict(self._pending_state_dict["user"]) self._pending_state_dict = None + self._logger.info("Loaded state dict.") def should_commit(self, timeout: Optional[timedelta] = None) -> bool: """ diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 01adddf..e80d51e 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -7,13 +7,13 @@ from datetime import timedelta from typing import Optional from unittest import TestCase -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import create_autospec, MagicMock, patch import torch from torch.distributed import TCPStore -from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode -from torchft.process_group import ProcessGroup, _DummyWork +from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode +from torchft.process_group import _DummyWork, ProcessGroup from torchft.torchft import QuorumResult @@ -95,6 +95,37 @@ def test_state_dict(self, client_mock: MagicMock) -> None: self.assertEqual(manager.current_step(), 1234) self.assertEqual(manager.batches_committed(), 2345) + @patch("torchft.manager.ManagerClient", autospec=True) + def test_user_state_dict(self, client_mock: MagicMock) -> None: + manager = self._create_manager() + + self.assertEqual( + manager._manager_state_dict(), + { + "user": {}, + "torchft": { + "step": 0, + "batches_committed": 0, + }, + }, + ) + + manager.set_state_dict_fns( + self.load_state_dict, + lambda: {"new_state": 1}, + ) + + self.assertEqual( + manager._manager_state_dict(), + { + "user": {"new_state": 1}, + "torchft": { + "step": 0, + "batches_committed": 0, + }, + }, + ) + @patch("torchft.manager.ManagerClient", autospec=True) def test_quorum_happy(self, client_mock: MagicMock) -> None: manager = self._create_manager() From 1115a379104635b0121590db4ad7dd79cbc40bc5 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 14:01:22 -0800 Subject: [PATCH 2/5] lintrunner --- torchft/manager.py | 2 +- torchft/manager_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchft/manager.py b/torchft/manager.py index 9c4b159..e31d35c 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -33,7 +33,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from enum import Enum -from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore diff --git a/torchft/manager_test.py b/torchft/manager_test.py index e80d51e..2e5b293 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -7,13 +7,13 @@ from datetime import timedelta from typing import Optional from unittest import TestCase -from unittest.mock import create_autospec, MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import torch from torch.distributed import TCPStore -from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode -from torchft.process_group import _DummyWork, ProcessGroup +from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode +from torchft.process_group import ProcessGroup, _DummyWork from torchft.torchft import QuorumResult From ff0f31fb7137cc640cf597bccbd757de95b4e64a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 15:14:22 -0800 Subject: [PATCH 3/5] Typing fix --- torchft/checkpointing.py | 2 +- torchft/manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index bc3eb7b..16c62cd 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -172,7 +172,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: data = f.read() reader = io.BytesIO(data) - # We have to set weights_only to True as there are some non-tensor + # We have to set weights_only to False as there are some non-tensor # states like lr_scheduler. return torch.load(reader, weights_only=False) diff --git a/torchft/manager.py b/torchft/manager.py index e31d35c..739bfee 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -225,7 +225,7 @@ def __init__( self._participating_world_size: int = 0 def set_state_dict_fns( - self, load_state_dict: Callable[T, None], state_dict: Callable[[], T] + self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T] ) -> None: self._load_state_dict = load_state_dict self._user_state_dict = state_dict From 58e6028dfecf8363eade4df35c858af87353e94f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 15:24:17 -0800 Subject: [PATCH 4/5] Typing --- torchft/manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchft/manager.py b/torchft/manager.py index 739bfee..fb26110 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -534,8 +534,11 @@ def _apply_pending_state_dict(self) -> None: self._logger.info("applying pending state dict") - assert self._pending_state_dict is not None, "checkpoint was not staged" + assert ( + self._load_state_dict is not None + ), "user load_state_dict is not initialized." self._load_state_dict(self._pending_state_dict["user"]) + assert self._pending_state_dict is not None, "checkpoint was not staged" self._pending_state_dict = None self._logger.info("Loaded state dict.") @@ -607,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._batches_committed = state_dict["batches_committed"] def _manager_state_dict(self) -> Dict[str, object]: + assert self._user_state_dict is not None, "user state_dict is not initialized." return { "user": self._user_state_dict(), "torchft": self.state_dict(), From af23135b14506e215280d7e680be4f01b8b066c9 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 15:42:35 -0800 Subject: [PATCH 5/5] typing --- torchft/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchft/manager.py b/torchft/manager.py index fb26110..99ad410 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -534,11 +534,11 @@ def _apply_pending_state_dict(self) -> None: self._logger.info("applying pending state dict") + assert self._pending_state_dict is not None, "checkpoint was not staged" assert ( self._load_state_dict is not None ), "user load_state_dict is not initialized." self._load_state_dict(self._pending_state_dict["user"]) - assert self._pending_state_dict is not None, "checkpoint was not staged" self._pending_state_dict = None self._logger.info("Loaded state dict.")