-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG] partition_balanced return wrong result. #4312
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
zjjMaiMai
requested review from
jeffra,
mrwyattii and
tjruwase
as code owners
September 12, 2023 18:01
@microsoft-github-policy-service agree |
already fixed! cc @tjruwase |
ShadenSmith
approved these changes
Oct 6, 2023
Thanks for this PR, @zjjMaiMai! A note on balance: the objective function in the original code minimizes the maximum load per partition. The maximum load on a pipeline stage determines the pipeline throughput, and so the original result is also balanced. Example paper: Fast Optimal Load Balancing Algorithms for 1D Partitioning |
github-merge-queue
bot
removed this pull request from the merge queue due to failed status checks
Oct 6, 2023
github-merge-queue
bot
removed this pull request from the merge queue due to failed status checks
Dec 4, 2023
mauryaavinash95
pushed a commit
to mauryaavinash95/DeepSpeed
that referenced
this pull request
Feb 17, 2024
# Background In pipeline parallelism, deepspeed uses `ds_utils.partition_balanced` to balance the partitioning of the model according to the number of parameters or class names. https://github.com/microsoft/DeepSpeed/blob/581e44dd1ab3c409a5905335867c761d5cb4db5b/deepspeed/runtime/pipe/module.py#L380-L395 # What wrong? ``` >>> import deepspeed >>> deepspeed.__version__ '0.10.3+542dc0d5' >>> from deepspeed.runtime import utils as ds_utils >>> ds_utils.partition_balanced([1, 1, 1, 1, 1], 4) [0, 2, 4, 5, 5] >>> ``` the result [0, 2, 4, 5, 5] means [2, 2, 1, 0] layers for each part, which is not balanced at all. the last part will throw an exception because there are no parameters to training. i add some unit test for this function, and i will fix it later if anyone need it. --------- Co-authored-by: Olatunji Ruwase <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Background
In pipeline parallelism, deepspeed uses
ds_utils.partition_balanced
to balance the partitioning of the model according to the number of parameters or class names.DeepSpeed/deepspeed/runtime/pipe/module.py
Lines 380 to 395 in 581e44d
What wrong?
the result [0, 2, 4, 5, 5] means [2, 2, 1, 0] layers for each part, which is not balanced at all. the last part will throw an exception because there are no parameters to training.
i add some unit test for this function, and i will fix it later if anyone need it.