I noticed that if a job is interrupted twice (for example, first interrupted at step 25, then resumed and run until step 45, then resumed again from step 45), the second resume behaves unexpectedly: checkpoint.load() seems to find the latest checkpoint correctly (step 45), but the loaded train_state.step still seems to be the value from the first resume (26).
Example log output from a second resume:
[rank0]:2024-09-08 10:58:15,493 - root - INFO - Loading the checkpoint at step 45.
[rank0]:2024-09-08 10:58:16,211 - root - INFO - Training starts at step 26, with local batch size 8, global batch size 8, sequence length 2048, total steps 100 (warmup 2)
Thank you very much for your help!
This seems to happen because, after the first resume, the train_state held by the checkpoint is no longer the object that the training loop increments. The hotfix below seems to solve the problem, but a better solution that eliminates the root cause should probably be considered.
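To make the suspected mechanism concrete, here is a minimal toy sketch of a stale-reference bug of this kind. It is not the hotfix and not the codebase's actual code: ToyCheckpointer, its storage dict, and the in_place flag are hypothetical names invented for illustration, and the real load path may differ. It only assumes that the checkpointer keeps its own reference to a train_state object and saves whatever that reference holds.

```python
from dataclasses import dataclass


@dataclass
class TrainState:
    step: int = 0


class ToyCheckpointer:
    """Hypothetical stand-in for a checkpoint manager (illustration only)."""

    def __init__(self, train_state, storage):
        # The checkpointer keeps its own reference to the training state.
        self.states = {"train_state": train_state}
        self.storage = storage  # maps checkpoint step -> saved train_state.step

    def save(self, step):
        # Saves whatever object the checkpointer currently references.
        self.storage[step] = self.states["train_state"].step

    def load(self, step, train_state, in_place):
        if in_place:
            # Mutate the shared object: the checkpointer and the training
            # loop keep referring to the same TrainState.
            self.states["train_state"].step = self.storage[step]
        else:
            # Rebind the entry to a fresh object and copy the value out.
            # The training loop now increments a different object than the
            # one the checkpointer will save later.
            self.states["train_state"] = TrainState(step=self.storage[step])
            train_state.step = self.storage[step]


def simulate(in_place):
    """Return the step restored by the second resume."""
    storage = {}

    # Run 1: train to step 25, save, get interrupted.
    ts = TrainState()
    ckpt = ToyCheckpointer(ts, storage)
    ts.step = 25
    ckpt.save(25)

    # Run 2: first resume from step 25, train to step 45, save, interrupted.
    ts = TrainState()
    ckpt = ToyCheckpointer(ts, storage)
    ckpt.load(25, ts, in_place)
    while ts.step < 45:
        ts.step += 1
    ckpt.save(45)

    # Run 3: second resume from the checkpoint labelled step 45.
    ts = TrainState()
    ckpt = ToyCheckpointer(ts, storage)
    ckpt.load(45, ts, in_place)
    return ts.step


if __name__ == "__main__":
    print("rebinding load restores step:", simulate(in_place=False))
    print("in-place load restores step:", simulate(in_place=True))
```

In this toy, the rebinding variant restores step 25 from the checkpoint labelled 45, so training would start at step 26 again, which is consistent with the log above. The in-place variant restores step 45. Under these assumptions, either loading in place or re-pointing the checkpointer's entry at the live train_state after loading would avoid the mismatch.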