Enable sharded checkpoint save and load (support local, sharded, and full state dicts for FSDP) #1902
Conversation
Nice! Please add unit tests :) For a manual test, could you also test a resumption run? I.e., I'd like to see that training for 100 batches is the same as training for 50 batches, checkpointing, and then resuming and training for another 50. It would also be good to check with all three types.
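For reference, a minimal, framework-agnostic sketch of the resumption check being asked for here, using a tiny plain-PyTorch loop rather than Composer's `Trainer` or FSDP; the model, data, and step counts are illustrative only, not the PR's actual tests:

```python
import torch

def make_model():
    torch.manual_seed(0)  # same initialization for both runs
    return torch.nn.Linear(4, 4)

def batch_for_step(step):
    # Deterministic per-step data so a resumed run sees the same batches.
    g = torch.Generator().manual_seed(step)
    return torch.randn(8, 4, generator=g)

def train(model, opt, start_step, end_step):
    for step in range(start_step, end_step):
        loss = model(batch_for_step(step)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

# Run A: 100 steps straight through.
model_a = make_model()
opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
train(model_a, opt_a, 0, 100)

# Run B: 50 steps, checkpoint, restore into fresh objects, then 50 more steps.
model_b = make_model()
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)
train(model_b, opt_b, 0, 50)
ckpt = {'model': model_b.state_dict(), 'optim': opt_b.state_dict()}

model_b = make_model()
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)
model_b.load_state_dict(ckpt['model'])
opt_b.load_state_dict(ckpt['optim'])
train(model_b, opt_b, 50, 100)

# The two runs should end with identical weights.
for key, value in model_a.state_dict().items():
    torch.testing.assert_close(value, model_b.state_dict()[key])
```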
Ok, @dakinggg, I added unit tests and docs. Give it another look when you get the chance.
LGTM
* check non-rank0 ranks only hold a shard for "full" state dict modified
* "local" state dicts have flattened shards
* "sharded" state dicts have unflattened shards
Co-authored-by: Daniel King <[email protected]>
@dakinggg, okay I added the tests you asked for. Can you give it one more quick once over and then I'll merge it? |
LGTM, thanks for adding the extra tests!
LGTM
What does this PR do?
Enables configuring what type of state dict to use for models and optimizers in `State`.
We enable the user to set `fsdp_config['state_dict_type']` to configure these types of state dicts. The types are:

- `'full'` (aka `torch.distributed.fsdp.StateDictType.FULL_STATE_DICT`) - the full unflattened, unsharded state dict, materialized only on rank 0 with CPU offloading
- `'local'` (aka `torch.distributed.fsdp.StateDictType.LOCAL_STATE_DICT`) - materializes just the local flattened shard of the state_dict on each rank
- `'sharded'` (aka `torch.distributed.fsdp.StateDictType.SHARDED_STATE_DICT`) - returns the unflattened, sharded state_dict

When a user specifies `Trainer(model=my_model, optimizers=my_optimizer, fsdp_config={..., 'state_dict_type': 'local'}, ...)`, they enable local sharding for their model parameters and their optimizer states.
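As a rough sketch of how these three strings correspond to PyTorch's FSDP state dict machinery (this is not the PR's actual implementation, just the underlying `torch.distributed.fsdp` API being selected between; the helper name is made up for illustration):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def fsdp_state_dict_type_context(fsdp_model, state_dict_type: str):
    """Return a context manager under which fsdp_model.state_dict() yields the requested format."""
    if state_dict_type == 'full':
        # Unflattened, unsharded state dict, materialized only on rank 0 and offloaded to CPU.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        return FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg)
    if state_dict_type == 'local':
        # Each rank keeps only its local, flattened shard.
        return FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT)
    if state_dict_type == 'sharded':
        # Unflattened parameter names/shapes, but each rank holds only its shard.
        return FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
    raise ValueError(f'Unknown state_dict_type: {state_dict_type}')

# Usage (inside an initialized process group, with an FSDP-wrapped model):
# with fsdp_state_dict_type_context(model, 'sharded'):
#     state = model.state_dict()
```

Note that the same context also has to be active when calling `load_state_dict`, so the save and load paths must agree on the format.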
In order to enable this functionality, this PR:

- Modifies `State.state_dict()` for model parameters and optimizer parameters to support the three FSDP state dict types
- Modifies `State.load_state_dict()` for model and optimizer to support the state dict types
- Modifies `save_checkpoint` to allow non-zero ranks to save checkpoints
- Modifies `load_checkpoint` to allow non-zero ranks to load checkpoints
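To make the user-facing flow concrete, here is a hedged end-to-end sketch of saving and resuming with sharded state dicts. The `Trainer` arguments shown (`save_folder`, `save_interval`, `load_path`, `fsdp_config`) follow Composer's public API as I understand it around this release; `my_model`, `my_dataloader`, `my_optimizer`, and the checkpoint filename pattern are placeholders:

```python
from composer import Trainer

# First run: train and have every rank save its own (sharded) checkpoint.
trainer = Trainer(
    model=my_model,                  # a ComposerModel (placeholder)
    train_dataloader=my_dataloader,  # placeholder dataloader
    optimizers=my_optimizer,
    max_duration='100ba',
    fsdp_config={'sharding_strategy': 'FULL_SHARD', 'state_dict_type': 'sharded'},
    save_folder='checkpoints',
    save_interval='50ba',
)
trainer.fit()

# Second run: every rank loads its own shard and training resumes.
resumed = Trainer(
    model=my_model,
    train_dataloader=my_dataloader,
    optimizers=my_optimizer,
    max_duration='100ba',
    fsdp_config={'sharding_strategy': 'FULL_SHARD', 'state_dict_type': 'sharded'},
    load_path='checkpoints/ba50-rank{rank}.pt',  # illustrative per-rank filename pattern
)
resumed.fit()
```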
Todo:

- Tests

Did a bunch of manual tests.
Also did manual tests to make sure the proper error is raised if:
What issue(s) does this change relate to?
fix CO-1433
fix CO-1682