diff --git a/train.py b/train.py index 15f0e493a..525c22fae 100644 --- a/train.py +++ b/train.py @@ -51,7 +51,7 @@ def state_dict(self) -> Dict[str, Any]: return { "step": torch.tensor(self.step, dtype=torch.int32), "current_loss": torch.tensor(self.current_loss, dtype=torch.float32), - "losses": torch.tensor(self.current_loss, dtype=torch.float32), + "losses": torch.tensor(self.losses, dtype=torch.float32), } def load_state_dict(self, state_dict) -> None: