[WIP] Use torch 2.2 distributed checkpoint APIs for FSDP #19497
Conversation
Force-pushed from 97469de to fb62091
Force-pushed from fb62091 to bf021fe
Currently blocked by pytorch/pytorch#119800 (comment)
if _TORCH_GREATER_EQUAL_2_2:
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

# `cpu_offload` disabled because when used with `full_state_dict` only rank 0 loads the state dict
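For context, a minimal sketch (not the PR's actual code) of how the torch 2.2 `set_model_state_dict` API can restore a full, unsharded state dict; the `model` and `checkpoint` names are placeholders:

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_state_dict(model: torch.nn.Module, checkpoint: dict) -> None:
    # full_state_dict=True: each rank receives the complete, unsharded state dict.
    # cpu_offload is left at its default (False) per the note above: combined with
    # full_state_dict it would mean only rank 0 loads the state dict.
    options = StateDictOptions(full_state_dict=True, cpu_offload=False)
    set_model_state_dict(model, checkpoint["state_dict"], options=options)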
Notice that the other path sets `rank0_only=False`. I asked if this could be configurable in pytorch/pytorch#112837 (comment)
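For comparison, a hedged sketch of the pre-2.2 path referred to here, where `FullStateDictConfig(rank0_only=False)` makes every rank materialize the full state dict; the function and argument choices are illustrative, not Lightning's actual code:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def legacy_load_full_state_dict(module: torch.nn.Module, checkpoint: dict) -> None:
    # rank0_only=False means every rank gathers and loads the full state dict,
    # which is the behavior the linked PyTorch comment asks to make configurable.
    config = FullStateDictConfig(rank0_only=False)
    with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, config):
        module.load_state_dict(checkpoint["state_dict"])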
@@ -440,6 +439,7 @@ def save_checkpoint(
)
if filter is not None and self._state_dict_type == "sharded":
    # https://github.com/pytorch/pytorch/issues/105379
    # FIXME: revisit support with new APIs
Reminder to myself
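For context, a rough sketch of what applying such a key filter before a sharded (distributed) save could look like; the function name, filter signature, and the `dcp.save` call are assumptions for illustration, not the strategy's actual implementation:

from pathlib import Path
from typing import Any, Callable, Dict, Optional

import torch.distributed.checkpoint as dcp


def save_filtered_sharded(
    state_dict: Dict[str, Any],
    path: Path,
    filter_fn: Optional[Callable[[str, Any], bool]] = None,
) -> None:
    # Drop filtered-out keys before handing the sharded state dict to the
    # distributed checkpoint writer; every rank applies the same filter.
    if filter_fn is not None:
        state_dict = {k: v for k, v in state_dict.items() if filter_fn(k, v)}
    dcp.save(state_dict, checkpoint_id=str(path))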
Closing since the new model parallel strategy is most likely the future. Implemented in #19852
What does this PR do?
Fixes #19462
Resources
TODO:
- `_TORCH_GREATER_EQUAL_2_2 = False` since CI only tests 2.2 (version-guard sketch below)

📚 Documentation preview 📚: https://pytorch-lightning--19497.org.readthedocs.build/en/19497/
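As a side note, a minimal sketch of how a version guard like `_TORCH_GREATER_EQUAL_2_2` is typically derived (the helper actually used in the repo may differ); hard-coding it to False forces the legacy path even on newer torch, which is useful when CI only exercises 2.2:

import torch
from packaging.version import Version

# True when the installed torch is expected to provide the new distributed checkpoint APIs.
_TORCH_GREATER_EQUAL_2_2 = Version(torch.__version__.split("+")[0]) >= Version("2.2.0")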