From 7fee3cfbf52a8b96bd2c6e527d89a83fc50a19b4 Mon Sep 17 00:00:00 2001 From: Iris Z <31293777+wz337@users.noreply.github.com> Date: Tue, 12 Mar 2024 15:06:24 -0700 Subject: [PATCH] [TorchTrain][Checkpoint] Fix TrainState state_dict to unblock loading (#131) This fix temporarily unblocks loading, so we no longer run into the following error: ``` [rank0]:[rank0]: train_state.losses.append(train_state.current_loss) [rank0]:[rank0]: AttributeError: 'float' object has no attribute 'append' ``` However, current_loss and losses are still not correct: with the current setup, current_loss and losses would differ across ranks. Also, we don't know the size of losses in advance, because it depends on the number of steps. So loading still works, but the values of current_loss and losses are not loaded correctly. I will follow up with further fixes. --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 2f9a130cd..603364d01 100644 --- a/train.py +++ b/train.py @@ -51,7 +51,7 @@ def state_dict(self) -> Dict[str, Any]: return { "step": torch.tensor(self.step, dtype=torch.int32), "current_loss": torch.tensor(self.current_loss, dtype=torch.float32), - "losses": torch.tensor(self.current_loss, dtype=torch.float32), + "losses": torch.tensor(self.losses, dtype=torch.float32), } def load_state_dict(self, state_dict) -> None: