Multi-node checkpoint benchmark improvements #149

Merged: 30 commits, Oct 15, 2024

MattIrv (Collaborator) commented Oct 15, 2024

Several minor fixes/improvements for the multi-node checkpoint benchmark:

  • Update the README with some missing instructions

  • Run as a module rather than directly with Python so that imports work correctly

  • Rename the WORLD_SIZE env var to NUM_NODES, because PyTorch / PyTorch Lightning appear to use WORLD_SIZE for other purposes

  • Set use_orig_params=False when creating the strategy, since this avoids the time-consuming step of flattening the optimizer state dict

  • Make NUM_DEVICES a configurable environment variable

  • Add the ability to configure the strategy with either the AdamW or SGD optimizer. AdamW appears to produce larger checkpoint files, roughly 20% of the GPU memory utilization

  • Tests pass

  • Appropriate changes to documentation are included in the PR
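
For reviewers who want a picture of the configuration changes listed above, here is a minimal sketch of how the environment-driven settings might be read. Only the `NUM_NODES` and `NUM_DEVICES` names come from this PR. The `OPTIMIZER` variable name, the default values, and the `BenchmarkConfig` / `load_config` helpers are assumptions made for illustration and may not match the benchmark's actual code.

```python
import os
from dataclasses import dataclass

# Optimizer choices named in the PR description.
SUPPORTED_OPTIMIZERS = ("adamw", "sgd")


@dataclass(frozen=True)
class BenchmarkConfig:
    """Hypothetical container for the benchmark's environment settings."""
    num_nodes: int
    num_devices: int
    optimizer: str


def load_config(env=None):
    """Read settings from a mapping (defaults to os.environ).

    NUM_NODES replaces the earlier WORLD_SIZE variable. The OPTIMIZER
    variable name and all defaults below are assumptions, not taken
    from the PR.
    """
    env = os.environ if env is None else env
    num_nodes = int(env.get("NUM_NODES", "1"))
    num_devices = int(env.get("NUM_DEVICES", "1"))
    optimizer = env.get("OPTIMIZER", "adamw").lower()

    if num_nodes < 1 or num_devices < 1:
        raise ValueError("NUM_NODES and NUM_DEVICES must be positive integers")
    if optimizer not in SUPPORTED_OPTIMIZERS:
        raise ValueError(f"optimizer must be one of {SUPPORTED_OPTIMIZERS}")

    return BenchmarkConfig(num_nodes, num_devices, optimizer)


if __name__ == "__main__":
    print(load_config())
```

Using a dedicated `NUM_NODES` name keeps the benchmark's node count separate from `WORLD_SIZE`, which distributed launchers commonly set to the total number of processes.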

MattIrv requested a review from awonak October 15, 2024 18:34
MattIrv requested a review from a team as a code owner October 15, 2024 18:34
MattIrv merged commit f12fbec into main Oct 15, 2024
5 checks passed
MattIrv deleted the mirvine/distckpt branch October 15, 2024 19:14

3 participants