From a51ece3e0700a97835f233216cc39abbdefa8cc2 Mon Sep 17 00:00:00 2001 From: wz337 Date: Tue, 12 Mar 2024 14:18:58 -0700 Subject: [PATCH] unblock loading --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 15f0e493..525c22fa 100644 --- a/train.py +++ b/train.py @@ -51,7 +51,7 @@ def state_dict(self) -> Dict[str, Any]: return { "step": torch.tensor(self.step, dtype=torch.int32), "current_loss": torch.tensor(self.current_loss, dtype=torch.float32), - "losses": torch.tensor(self.current_loss, dtype=torch.float32), + "losses": torch.tensor(self.losses, dtype=torch.float32), } def load_state_dict(self, state_dict) -> None: