Remove OOM-Driven FSDP Deadlocks and Increase Throughput of Automicrobatching #3510
Conversation
Great write-up in the PR description!
First pass. Need to think through this carefully after comments are addressed, but I think it's right.
This is a huge PR 🎉
Could you please also add the before/after training loss comparison, and make sure the loss is on par before merging?
Thanks for the clean PR description and nice debugging work! Two qq:
One high-level comment: it might be good to split this up into 4 different PRs, each containing one of the four sources of deadlock along with a repro example + throughput improvement. Would be good for posterity when we are reviewing/thinking through each of the components of the deadlocks.
1 is a good catch, I just switched the two headings - thrashing decreases dtms but still increases throughput.
Looks good as a v1.
I agree with @j316chuck in general we should cut this up into smaller PRs, but I think it's fine at this point to merge given we've reviewed and signed off.
@dakinggg feel free to block if you're concerned about thrashing, but I'm pro merging with the cautious approach and we can revisit in follow-on PR
@mvpatel2000 yeah I think it's ok to merge as is and revisit as needed
Just wanted to check - has this been tested for scenarios where you are starting from "older" composer checkpoints?
@jacobfulano it should be independent of checkpoint loading
AFAIK, checkpointing saves automicrobatching info, but once it's loaded, it's never used, so I don't think this should be a problem? Let me know though if you're referring to something else and I'll take a look!
Prior to this PR, automicrobatching suffered from both low reliability, in the form of consistent deadlocks when using FSDP, and decreased throughput, compared to if `device_train_microbatch_size` were manually set to the value `auto` found. This PR addresses both those issues, removing the sources of OOM-driven FSDP deadlocks and implementing a more intelligent sync hook adding/dropping method that allows `auto` to perform as well as when `dtms` is manually set.

Reliability:

Hooks:
With FSDP, there are 4 different sources of deadlocks, which become consistent especially when model size grows to 30b (a minimal sketch of the sync hooks follows this list):

1. Before the forward of an FSDP module, some ranks may OOM and some ranks may not, leading to the OOMing ranks all_reducing when they are trying to see if other ranks are OOMing, and the non-OOMing ranks all_gathering as they unshard and reshard while continuing into the FSDP modules.
   a) Solution: Register `module.register_forward_pre_hook(sync_hook, prepend=True)` on FSDP modules.
2. Before the backward of an FSDP module, some ranks may OOM and some ranks may not, again leading to the OOMing ranks all_reducing when they are trying to see if other ranks are OOMing, and the non-OOMing ranks all_gathering as they unshard and reshard.
   a) Solution: Register `module.register_full_backward_pre_hook(sync_hook, prepend=True)` on FSDP modules.
3. In the FSDP post-backward hook, `post_backward_reshard` prefetches the unshard, moving the unshard from the beginning of the next backward to the end of this backward.
   a) Solution: Register `module.register_full_backward_hook(sync_hook)` on the non-FSDP original modules.
4. Within the FSDP post-backward hook, the call to unshard reallocates memory, causing some ranks to OOM and others not to. Since there is no syncing within FSDP's native hook, this can lead to deadlock when half the ranks OOM during that realloc and half don't.
   a) Solution: Monkeypatch a sync hook right before the realloc; proof of adaptive patch working: `test-patch-mpt-125m-chinchilla-regression-LfuGDp`
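All four fixes rely on the same idea: every rank agrees on whether any rank has OOMed before the next FSDP collective, so no rank is left waiting in an all_gather while another sits in an all_reduce. Below is a minimal, illustrative sketch of what such a sync hook and its registration could look like; it is not the exact Composer implementation, and `oom_state`, `sync_hook`, and `register_sync_hooks` are hypothetical names.

```python
# Illustrative sketch only: not the exact Composer implementation.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class _OOMState:
    """Per-rank flag that the trainer would set when it catches a CUDA OOM."""
    found_cuda_oom: bool = False


oom_state = _OOMState()


def sync_hook(*args, **kwargs):
    """Make every rank agree on whether any rank OOMed before the next FSDP collective."""
    # Assumes the default process group has already been initialized by the trainer.
    flag = torch.tensor([1.0 if oom_state.found_cuda_oom else 0.0], device='cuda')
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    if flag.item() > 0:
        # Every rank raises together, so no rank is stranded in an unshard all_gather.
        raise RuntimeError('CUDA out of memory encountered on a different rank')


def register_sync_hooks(model: torch.nn.Module) -> list:
    """Register the sync hooks described in sources 1-3 above; returns the hook handles."""
    handles = []
    for module in model.modules():
        if isinstance(module, FSDP):
            # Sources 1 and 2: sync before FSDP's pre-forward and pre-backward unshards.
            handles.append(module.register_forward_pre_hook(sync_hook, prepend=True))
            handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True))
        else:
            # Source 3: sync on the original (non-FSDP) modules so the check runs
            # before the post-backward reshard prefetches the next unshard.
            handles.append(module.register_full_backward_hook(sync_hook))
    # Source 4 (the realloc inside FSDP's post-backward hook) has no module hook
    # point, which is why it needs a monkeypatch instead.
    return handles
```

Registering with `prepend=True` is what puts the check ahead of FSDP's own pre-forward/pre-backward logic, so the agreement happens before any rank issues its unshard collectives.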
Thrashing:

Thrashing occurs when we are close to the GPU's maximum memory and `alloc_retries` consistently occur, leading to lower throughput and a risk of OOMing. If we detect `alloc_retries` for two consecutive batches, we consider it thrashing and search downwards for a smaller microbatch size, treating it as if it were an OOM (see the sketch below).

Run with thrashing check: `mpt-30b-auto-fix-egrtom`
Run without thrashing check: `mpt-30b-auto-fix-ZQmoZp`
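As an illustration of the two-consecutive-batches rule, here is a minimal sketch of how thrashing could be detected from the CUDA caching allocator's statistics. `ThrashingDetector` is a hypothetical name and the real trainer logic is more involved, but `num_alloc_retries` is the relevant counter exposed by `torch.cuda.memory_stats()`.

```python
# Illustrative sketch only (assumed names, not Composer's exact code): new
# allocator retries on two consecutive batches are treated like an OOM.
import torch


class ThrashingDetector:

    def __init__(self) -> None:
        self._prev_num_alloc_retries = 0
        self._consecutive_thrashing_batches = 0

    def _batch_thrashed(self) -> bool:
        """Return True if the allocator reported new alloc retries since the last call."""
        num_alloc_retries = torch.cuda.memory_stats().get('num_alloc_retries', 0)
        thrashed = num_alloc_retries > self._prev_num_alloc_retries
        self._prev_num_alloc_retries = num_alloc_retries
        return thrashed

    def should_shrink_microbatch(self) -> bool:
        """Call once per batch; True after thrashing on two consecutive batches."""
        if self._batch_thrashed():
            self._consecutive_thrashing_batches += 1
        else:
            self._consecutive_thrashing_batches = 0
        return self._consecutive_thrashing_batches >= 2


# Hypothetical usage inside a training loop:
#     if detector.should_shrink_microbatch():
#         # Treat it like an OOM: search downward for a smaller microbatch size.
#         device_train_microbatch_size //= 2
```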
Throughput
We turn all hooks on when we are searching for a microbatch size for the first time, right after we finish an eval in case of a memory spike, when we detect thrashing, and if we hit even a single OOM. We leave these hooks on until we can successfully run batches with the selected microbatch size for 3 consecutive batches. Then, to increase throughput, we drop all hooks until we hit one of the events listed above.
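A minimal sketch of that enable/run/drop policy is below; `SyncHookManager` is a hypothetical name (not Composer's actual API), and it assumes the `register_sync_hooks` helper from the earlier sketch is in scope.

```python
# Illustrative sketch only of the hook enable/disable policy described above.
import torch

SUCCESSES_BEFORE_DROPPING_HOOKS = 3


class SyncHookManager:

    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model
        self._handles: list = []
        self._consecutive_successes = 0

    @property
    def hooks_active(self) -> bool:
        return len(self._handles) > 0

    def enable_hooks(self) -> None:
        """Turn all sync hooks on: first microbatch-size search, right after an eval,
        on detected thrashing, or on any OOM."""
        if not self.hooks_active:
            # Hypothetical helper from the earlier sketch.
            self._handles = register_sync_hooks(self._model)
        self._consecutive_successes = 0

    def record_successful_batch(self) -> None:
        """Drop all hooks once the selected microbatch size runs cleanly 3 batches in a row."""
        self._consecutive_successes += 1
        if self.hooks_active and self._consecutive_successes >= SUCCESSES_BEFORE_DROPPING_HOOKS:
            for handle in self._handles:
                handle.remove()
            self._handles = []
```

The trainer would call `enable_hooks()` on each of the trigger events above and `record_successful_batch()` after every batch that completes without an OOM or thrashing.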
This allows us to achieve the same throughput as if we had manually set the `dtms` to what `auto` found.

Experiments
The throughput improvement depends on whether training is GPU- or CPU-bound, which in turn depends on model size and architecture.
MPT-125M, 1 Node:
MPT-1B, 1 Node:
MPT-7B, 1 Node:
MPT-13B, 2 Node:
MPT-30B, 2 Node:
Credits: One unit test, `test_automicrobatching_fsdp`, borrows classes introduced in an earlier WIP PR by @bigning.

Successful regression test: `mcli logs -f llm-foundry-regression-tests-runner-PW7pJe`