diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py
index 48a5d51..16c62cd 100644
--- a/torchft/checkpointing.py
+++ b/torchft/checkpointing.py
@@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()
 
         reader = io.BytesIO(data)
-        return torch.load(reader, weights_only=True)
+        # We have to set weights_only to False as there are some non-tensor
+        # states like lr_scheduler.
+        return torch.load(reader, weights_only=False)
 
     def address(self) -> str:
         """
diff --git a/torchft/manager.py b/torchft/manager.py
index dc5ab30..99ad410 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -87,8 +87,8 @@ class Manager:
     def __init__(
         self,
         pg: "ProcessGroup",
-        load_state_dict: Callable[[T], None],
-        state_dict: Callable[[], T],
+        load_state_dict: Optional[Callable[[T], None]],
+        state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
         use_async_quorum: bool = True,
         timeout: timedelta = timedelta(seconds=60),
@@ -144,7 +144,7 @@ def __init__(
                 transfering checkpoints to recovering replicas
         """
         self._load_state_dict = load_state_dict
-        self._state_dict = state_dict
+        self._user_state_dict = state_dict
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
@@ -159,8 +159,6 @@ def __init__(
         world_size = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
 
-        self._user_state_dict = state_dict
-
         if checkpoint_transport is None:
             checkpoint_transport = CheckpointServer[Dict[str, T]](
                 timeout=timeout,
@@ -226,6 +224,12 @@ def __init__(
         self._participating_rank: Optional[int] = None
         self._participating_world_size: int = 0
 
+    def set_state_dict_fns(
+        self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
+    ) -> None:
+        self._load_state_dict = load_state_dict
+        self._user_state_dict = state_dict
+
     def shutdown(self, wait: bool = True) -> None:
         """
         Shutdown the manager and checkpoint server.
@@ -531,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
        self._logger.info("applying pending state dict")
 
         assert self._pending_state_dict is not None, "checkpoint was not staged"
+        assert (
+            self._load_state_dict is not None
+        ), "user load_state_dict is not initialized."
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
+        self._logger.info("Loaded state dict.")
 
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         """
@@ -602,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]
 
     def _manager_state_dict(self) -> Dict[str, object]:
+        assert self._user_state_dict is not None, "user state_dict is not initialized."
         return {
             "user": self._user_state_dict(),
             "torchft": self.state_dict(),
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 01adddf..2e5b293 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -95,6 +95,37 @@ def test_state_dict(self, client_mock: MagicMock) -> None:
         self.assertEqual(manager.current_step(), 1234)
         self.assertEqual(manager.batches_committed(), 2345)
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_user_state_dict(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager()
+
+        self.assertEqual(
+            manager._manager_state_dict(),
+            {
+                "user": {},
+                "torchft": {
+                    "step": 0,
+                    "batches_committed": 0,
+                },
+            },
+        )
+
+        manager.set_state_dict_fns(
+            self.load_state_dict,
+            lambda: {"new_state": 1},
+        )
+
+        self.assertEqual(
+            manager._manager_state_dict(),
+            {
+                "user": {"new_state": 1},
+                "torchft": {
+                    "step": 0,
+                    "batches_committed": 0,
+                },
+            },
+        )
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
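For context, the new set_state_dict_fns hook lets callers construct a Manager before the model and optimizer exist (the constructor now accepts None for both callbacks) and register the real functions later. A minimal sketch of that flow follows; the surrounding setup (pg, model, optimizer, and the min_replica_size value) is illustrative and not part of this diff:

    from torchft.manager import Manager

    # Hypothetical setup: pg, model, and optimizer are assumed to be built
    # elsewhere; only Manager and set_state_dict_fns come from this diff.
    manager = Manager(
        pg=pg,
        load_state_dict=None,  # now legal: Optional[Callable[[T], None]]
        state_dict=None,       # now legal: Optional[Callable[[], T]]
        min_replica_size=2,
    )

    def state_dict():
        # May include non-tensor state (e.g. an lr_scheduler), which is why
        # checkpointing.py now loads with weights_only=False.
        return {"model": model.state_dict(), "optim": optimizer.state_dict()}

    def load_state_dict(sd):
        model.load_state_dict(sd["model"])
        optimizer.load_state_dict(sd["optim"])

    manager.set_state_dict_fns(load_state_dict, state_dict)

If the callbacks are never registered, the new asserts in _apply_pending_state_dict and _manager_state_dict fail with "user load_state_dict is not initialized." / "user state_dict is not initialized." rather than raising an opaque TypeError on a None call.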