
Enable sharded checkpoint save and load (support local, sharded, and full state dicts for FSDP) #1902

Merged: 52 commits into mosaicml:dev on Feb 17, 2023

Conversation

@eracah (Contributor) commented on Jan 23, 2023

What does this PR do?

This PR enables configuring which type of state dict to use for models and optimizers in State. The user sets fsdp_config['state_dict_type'] to one of three values:

  • 'full' (aka torch.distributed.fsdp.StateDictType.FULL_STATE_DICT) - the full unflattened, unsharded state dict, materialized only on rank 0 with CPU offloading
  • 'local' (aka torch.distributed.fsdp.StateDictType.LOCAL_STATE_DICT) - materializes just the local flattened shard of the state_dict on each rank
  • 'sharded' (aka torch.distributed.fsdp.StateDictType.SHARDED_STATE_DICT) - returns the unflattened, sharded state_dict

When a user specifies Trainer(model=my_model, optimizers=my_optimizer, fsdp_config={..., 'state_dict_type': 'local'}, ...), they enable local sharding for their model parameters and optimizer states.
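
For illustration, here is a hedged usage sketch of that configuration. It assumes an existing ComposerModel my_model, optimizer my_optimizer, and dataloader my_dataloader; the 'sharding_strategy' key is just one example of another fsdp_config entry, not a requirement.

```python
from composer import Trainer

# Illustrative only: my_model, my_optimizer, and my_dataloader are assumed to
# already exist in the surrounding script.
trainer = Trainer(
    model=my_model,
    optimizers=my_optimizer,
    train_dataloader=my_dataloader,
    max_duration='1ep',
    fsdp_config={
        'sharding_strategy': 'FULL_SHARD',
        'state_dict_type': 'local',  # or 'full' / 'sharded'
    },
)
trainer.fit()
```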

To enable this functionality, this PR:

  • Adds logic to State.state_dict() for model and optimizer parameters that supports the three FSDP state dict types (a minimal sketch follows this list)
  • Adds logic to State.load_state_dict() for the model and optimizer that supports the three state dict types
  • Modifies save_checkpoint to allow non-zero ranks to save checkpoints
  • Modifies load_checkpoint to allow non-zero ranks to load checkpoints
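
As a rough sketch (assumed, not the PR's exact implementation), the model-side mapping from fsdp_config['state_dict_type'] to PyTorch's FSDP state dict APIs could look like the following; the helper name get_fsdp_model_state_dict is hypothetical.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def get_fsdp_model_state_dict(model, state_dict_type: str):
    """Return the model state dict under the requested FSDP state dict type."""
    if state_dict_type == 'full':
        # Unflattened, unsharded parameters; offload to CPU and materialize
        # only on rank 0 to keep memory usage manageable.
        config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        context = FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config)
    elif state_dict_type == 'local':
        # Each rank returns only its flattened local shard.
        context = FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT)
    elif state_dict_type == 'sharded':
        # Unflattened parameters, but still sharded across ranks.
        context = FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
    else:
        raise ValueError(f'Unknown state_dict_type: {state_dict_type}')
    with context:
        return model.state_dict()
```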

Todo:

  • Add documentation on how to use sharded checkpointing
  • Add unit tests

Tests

Ran a number of manual tests.

Also ran manual tests to make sure the proper error is raised if:

What issue(s) does this change relate to?

Fixes CO-1433
Fixes CO-1682

@eracah changed the title from "Enable local, sharded, and full state dicts for FSDP" to "Enable sharded checkpoint save and load (support local, sharded, and full state dicts for FSDP)" on Feb 4, 2023
@eracah marked this pull request as ready for review on February 4, 2023 01:52
Review thread on composer/core/state.py (resolved)
@dakinggg (Contributor) left a comment

Nice! Please add unit tests :) For the manual test, could you also test a resumption run? I.e., I'd like to see that training for 100 batches is the same as training for 50 batches, checkpointing, and then resuming and training for another 50. It would also be good to check with all three types.
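
For reference, a rough sketch of the kind of resumption check being requested is below. It is not the PR's actual test code: build_model() and build_dataloader() are assumed helpers, and the fsdp_config values, checkpoint filenames, and paths are illustrative.

```python
from composer import Trainer

FSDP_CONFIG = {'sharding_strategy': 'FULL_SHARD', 'state_dict_type': 'sharded'}

def run(duration, save_folder=None, load_path=None):
    # build_model() and build_dataloader() are hypothetical user helpers.
    trainer = Trainer(
        model=build_model(),
        train_dataloader=build_dataloader(),
        max_duration=duration,
        seed=17,
        fsdp_config=FSDP_CONFIG,
        save_folder=save_folder,
        save_filename='ba{batch}-rank{rank}.pt',
        save_interval='50ba',
        load_path=load_path,
    )
    trainer.fit()
    return trainer

# Baseline: 100 batches straight through.
baseline = run('100ba')

# Resumption: 50 batches with a checkpoint, then resume and finish the last 50.
run('50ba', save_folder='./ckpts')
resumed = run('100ba', load_path='./ckpts/ba50-rank{rank}.pt')

# Compare baseline and resumed model weights (e.g., after gathering a full
# state dict) to confirm the two runs end up identical.
```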

Review threads on composer/trainer/dist_strategy.py, composer/core/state.py, and composer/utils/checkpoint.py (resolved)
@eracah (Contributor, Author) commented on Feb 9, 2023

> For the manual test, could you also test a resumption run? I.e., I'd like to see that training for 100 batches is the same as training for 50 batches, checkpointing, and then resuming and training for another 50. It would also be good to check with all three types.

@dakinggg, already done; see this.

@eracah requested a review from a team as a code owner on February 14, 2023 00:02
@eracah requested reviews from dakinggg and bcui19 on February 14, 2023 02:56
@eracah (Contributor, Author) commented on Feb 14, 2023

Ok, @dakinggg, I added unit tests and docs. Give it another look when you get the chance.

@dakinggg (Contributor) left a comment

LGTM

Review threads on docs/source/notes/distributed_training.rst and tests/trainer/test_sharded_checkpoint.py (resolved)
eracah and others added 8 commits February 15, 2023 13:15
    * check that non-rank0 ranks only hold a shard for the "full" state dict
    * "local" state dicts have flattened shards
    * "sharded" state dicts have unflattened shards
@eracah (Contributor, Author) commented on Feb 16, 2023

@dakinggg, okay, I added the tests you asked for. Can you give it one more quick once-over, and then I'll merge it?

@dakinggg (Contributor) left a comment

LGTM, thanks for adding the extra tests!

Review threads on tests/trainer/test_sharded_checkpoint.py (resolved)
@eracah requested a review from dakinggg on February 17, 2023 01:57
@mvpatel2000 (Contributor) left a comment

LGTM

@eracah merged commit 6fd5b2d into mosaicml:dev on Feb 17, 2023