Skip to content

Commit

Permalink
[TorchTrain][Checkpoint] Fix TrainState state_dict to unblock loading (
Browse files Browse the repository at this point in the history
…pytorch#131)

This fix would temporarily unblock loading. So we won't run into the
issue of:

```
[rank0]:[rank0]:     train_state.losses.append(train_state.current_loss)
[rank0]:[rank0]: AttributeError: 'float' object has no attribute 'append'
```

However, current_loss and losses are still not correct, since by current
setup, losses and current_losses would be different across different
ranks. Also, we don't know the size of losses because this is based on
the # of steps. So loading still work but the value of current_loss and
losses are not being loaded correctly.

I will follow up with further fixes.
  • Loading branch information
wz337 authored Mar 12, 2024
1 parent 10229d6 commit 7fee3cf
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def state_dict(self) -> Dict[str, Any]:
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
"losses": torch.tensor(self.current_loss, dtype=torch.float32),
"losses": torch.tensor(self.losses, dtype=torch.float32),
}

def load_state_dict(self, state_dict) -> None:
Expand Down

0 comments on commit 7fee3cf

Please sign in to comment.