
Remove OOM-Driven FSDP Deadlocks and Increase Throughput of Automicrobatching #3510

Merged: 32 commits into mosaicml:dev on Aug 2, 2024

Conversation

@JackZ-db (Contributor) commented on Jul 31, 2024

Prior to this PR, automicrobatching suffered from both low reliability, in the form of consistent deadlocks when using FSDP, and decreased throughput relative to manually setting device_train_microbatch_size to the value auto found. This PR addresses both issues: it removes the sources of OOM-driven FSDP deadlocks and implements a smarter sync-hook adding/dropping scheme that lets auto perform as well as a manually set dtms.
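
At its core, the sync hook referred to throughout this PR is an all_reduce of a per-rank OOM flag: every rank learns whether any rank hit a CUDA OOM before entering the next FSDP collective. A minimal sketch of that idea (not necessarily Composer's exact implementation; `make_sync_hook` and `state.found_cuda_oom` are illustrative names):

```python
# Illustrative sketch of a sync hook (hypothetical names, not Composer's exact code).
# Non-OOMing ranks run this hook before the next FSDP collective; a rank that OOMed is
# assumed to perform the matching all_reduce in its exception handling with flag set to 1.
import torch
import torch.distributed as dist


def make_sync_hook(state):
    def sync_hook(*args, **kwargs):
        flag = torch.tensor(
            [1 if state.found_cuda_oom else 0],
            dtype=torch.int32,
            device='cuda',
        )
        # MAX reduction: if any rank OOMed, every rank sees 1.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        if flag.item() == 1:
            # Raise on every rank so the microbatch search can shrink
            # device_train_microbatch_size without stranding ranks in a collective.
            raise RuntimeError('CUDA out of memory encountered on a different rank')
    return sync_hook
```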

Reliability:

Hooks:

With FSDP, there are four distinct sources of deadlocks, which become consistent especially as model size grows to 30B:

  1. Before the forward of an FSDP module, some ranks may OOM while others do not. The OOMing ranks enter an all_reduce to check whether any other rank has OOMed, while the non-OOMing ranks enter an all_gather as they unshard and reshard on their way into the FSDP module, and the mismatched collectives deadlock.
    a) Solution: Register module.register_forward_pre_hook(sync_hook, prepend=True) on FSDP modules (see the registration sketch after this list)

  2. Before the backward of an FSDP module, the same divergence can occur: the OOMing ranks all_reduce to check whether other ranks are OOMing while the non-OOMing ranks all_gather as they unshard and reshard into the FSDP module.
    a) Solution: Register module.register_full_backward_pre_hook(sync_hook, prepend=True) on FSDP modules

  3. In the FSDP post-backward hook, post_backward_reshard prefetches the unshard, moving it from the beginning of the next backward to the end of the current one, so its all_gather runs at a point the backward pre-hook does not cover.
    a) Solution: Register module.register_full_backward_hook(sync_hook) on the original, non-FSDP modules

  4. Within the FSDP post-backward hook, the call to unshard reallocates memory, so some ranks may OOM there while others do not. Since there is no syncing within the FSDP-native hook, this can deadlock when half the ranks OOM during that realloc and half don't.
    a) Solution: Monkeypatch a sync hook right before the realloc; proof of adaptive patch working: test-patch-mpt-125m-chinchilla-regression-LfuGDp
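
A rough sketch of where the hooks from solutions 1-3 attach, assuming the `sync_hook` sketched above (the exact module traversal in Composer may differ, and the monkeypatch from solution 4 is omitted since it touches FSDP internals):

```python
# Sketch only: attach sync hooks per solutions 1-3 and keep the handles so the hooks
# can be dropped later once a stable microbatch size has been found.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def register_sync_hooks(model, sync_hook):
    handles = []
    for module in model.modules():
        if isinstance(module, FSDP):
            # 1. Sync before each FSDP forward, ahead of its unshard all_gather.
            handles.append(module.register_forward_pre_hook(sync_hook, prepend=True))
            # 2. Sync before each FSDP backward, ahead of its unshard all_gather.
            handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True))
        else:
            # 3. Sync after the backward of the original (non-FSDP) modules, covering
            # the unshard that post_backward_reshard prefetches to the end of this backward.
            handles.append(module.register_full_backward_hook(sync_hook))
    return handles


# Dropping the hooks later is just:
#   for handle in handles:
#       handle.remove()
```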

Thrashing:

Thrashing occurs when we are close to the GPU's maximum memory and alloc_retries occur batch after batch, lowering throughput and risking an OOM. If we detect alloc_retries on two consecutive batches, we consider it thrashing and search downwards for a smaller microbatch size, treating it as if it were an OOM.
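
A minimal sketch of that heuristic, assuming the allocator stats exposed by torch.cuda.memory_stats() (the surrounding trainer state management is simplified):

```python
# Sketch of the thrashing check: two consecutive batches that triggered new
# allocator retries are treated like an OOM, prompting a downward microbatch search.
import torch


def update_thrashing_state(prev_num_alloc_retries, consecutive_thrashing_batches):
    num_alloc_retries = torch.cuda.memory_stats().get('num_alloc_retries', 0)
    if num_alloc_retries > prev_num_alloc_retries:
        consecutive_thrashing_batches += 1
    else:
        consecutive_thrashing_batches = 0
    # Thrashing == alloc retries on two consecutive batches; the caller should then
    # shrink device_train_microbatch_size as if an OOM had occurred.
    is_thrashing = consecutive_thrashing_batches >= 2
    return num_alloc_retries, consecutive_thrashing_batches, is_thrashing
```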

Run with thrashing check: mpt-30b-auto-fix-egrtom

  • Samples/Sec: 7.25
  • dtms: 4

Run without thrashing check: mpt-30b-auto-fix-ZQmoZp

  • Samples/Sec: 6.9
  • dtms: 8

Throughput

We turn all hooks on when we search for a microbatch size for the first time, right after we finish an eval (in case of a memory spike), when we detect thrashing, and whenever we hit even a single OOM. We leave these hooks on until the selected microbatch size runs successfully for 3 consecutive batches; then, to increase throughput, we drop all hooks until we hit one of the events listed above.

This allows us to achieve the same throughput as if we had manually set the dtms to what auto found.
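
Sketch of that add/drop policy (a hypothetical helper built on the `register_sync_hooks` sketch above; the real logic lives inside Composer's Trainer):

```python
# Sketch: hooks go on at the listed events and come off after the chosen microbatch
# size has run cleanly for three consecutive batches.
SUCCESSES_BEFORE_DROPPING_HOOKS = 3


class SyncHookPolicy:
    def __init__(self, model, sync_hook):
        self.model = model
        self.sync_hook = sync_hook
        self.handles = []
        self.consecutive_successes = 0

    def on_event(self, event):
        # Events that require re-synchronizing ranks: the initial microbatch search,
        # the first batch after an eval, detected thrashing, or any OOM.
        if event in ('search_start', 'after_eval', 'thrashing', 'oom'):
            if not self.handles:
                self.handles = register_sync_hooks(self.model, self.sync_hook)
            self.consecutive_successes = 0

    def on_successful_batch(self):
        self.consecutive_successes += 1
        if self.handles and self.consecutive_successes >= SUCCESSES_BEFORE_DROPPING_HOOKS:
            for handle in self.handles:
                handle.remove()
            self.handles = []
```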

Experiments

The throughput improvement depends on whether training is GPU- or CPU-bound, which in turn depends on model size and architecture.

MPT-125M, 1 Node:

MPT-1B, 1 Node:

MPT-7B, 1 Node:

MPT-13B, 2 Node:

MPT-30B, 2 Node:

Credits: One unit test, test_automicrobatching_fsdp, borrows classes introduced in an earlier WIP PR by @bigning.

Successful regression test: mcli logs -f llm-foundry-regression-tests-runner-PW7pJe

  • Same loss, higher throughput, lower memory usage (due to removal of hooks)

@JackZ-db marked this pull request as ready for review on August 1, 2024 02:33
@bigning (Contributor) left a comment:

Great write-up in the PR description!

@mvpatel2000 (Contributor) left a comment:

First pass. Need to think through this carefully after comments are addressed, but I think it's right.

This is a huge PR 🎉

@JackZ-db requested review from mvpatel2000 and bigning on August 1, 2024 21:16
@JackZ-db requested a review from bigning on August 1, 2024 22:56
@bigning (Contributor) left a comment:

Could you please also add the before/after training loss comparison, and make sure the loss is on par before merging.

@j316chuck (Contributor) commented on Aug 2, 2024

Thanks for the clean PR description and nice debugging work! Two qq:

  1. On the thrashing section, why is auto throughput samples/sec lower after fix?

  2. Throughput: iiuc, we can: turn on hooks in train forward -> find micro batch size A -> turn off hooks -> oom in eval forward -> turn on hooks again in eval forward -> find microbatch size B -> turn off hooks in eval forward? If so, then the first train batch will != second train batch size. Is this expected?

One high-level comment: it might be good to split this up into four different PRs, one for each source of deadlock, each with a repro example + throughput improvement. Would be good for posterity when we are reviewing/trying to understand each of the deadlock components.

@JackZ-db (Contributor, Author) commented on Aug 2, 2024

> 1. On the thrashing section, why is auto throughput samples/sec lower after fix?
> 2. Throughput: iiuc, we can: turn on hooks in train forward -> find micro batch size A -> turn off hooks -> oom in eval forward -> turn on hooks again in eval forward -> find microbatch size B -> turn off hooks in eval forward? If so, then the first train batch will != second train batch size. Is this expected?

1 is a good catch, I just switched the two headings - the thrashing check decreases dtms but still increases throughput.

  2. The hook re-adding isn't meant for OOMs in evaluation - it's meant to account for a memory spike that we see directly after the first evaluation (future evaluations don't introduce new memory allocations), which may then cause the first train batch after evaluation to OOM.

@mvpatel2000 (Contributor) left a comment:

Looks good as a v1.

I agree with @j316chuck in general we should cut this up into smaller PRs, but I think it's fine at this point to merge given we've reviewed and signed off.

@dakinggg feel free to block if you're concerned about thrashing, but I'm pro merging with the cautious approach and we can revisit in follow-on PR

@dakinggg (Contributor) commented on Aug 2, 2024

@mvpatel2000 yeah I think it's ok to merge as is and revisit as needed

@JackZ-db merged commit deb39cf into mosaicml:dev on Aug 2, 2024 (15 checks passed)
@jacobfulano (Contributor) commented:
Just wanted to check - has this been tested for scenarios where you are starting from "older" composer checkpoints?

@mvpatel2000 (Contributor) commented:

> Just wanted to check - has this been tested for scenarios where you are starting from "older" composer checkpoints?

@jacobfulano it should be independent of checkpoint loading

@JackZ-db (Contributor, Author) commented on Aug 5, 2024

> Just wanted to check - has this been tested for scenarios where you are starting from "older" composer checkpoints?

afaik, checkpointing saves automicrobatching info, but once it's loaded, it's never used, so I don't think this should be a problem? Let me know though if you're referring to something else and I'll take a look!
