[Auto Parallel] Sharding Pass #38502
Conversation
Thanks for your contribution!
def _apply_post_optimization_passed(self, main_program, startup_program,
                                    rank, params_grads):

    # apply amp forward pass
What is this comment for: a TODO, or just a wrong comment? The code below is sharding, not amp. If it is a TODO, please add TODO(owner), i.e. who is responsible for it, at the beginning of the comment.
It is kind of a TODO that marks where the amp pass will be placed in the future. Our final goal is that all optimization passes are applied within this function, after the auto-parallel graph partition, and we will need several updates to get there.
The final order will be: graph_partition - amp - recompute - sharding - gradient_merge.
At the moment, however, it is implemented as: amp - recompute - graph_partition - sharding - gradient_merge.
Fixed~
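For illustration, a rough sketch of the intended final structure. Only the sharding `apply()` call mirrors this diff; the pass name, config keys, and TODO placeholders are assumptions, not the code in this PR:

```python
from paddle.distributed.passes import new_pass

# Sketch of a method on the parallelizer; `self` provides _pass_context here.
def _apply_post_optimization_passed(self, main_program, startup_program,
                                    rank, params_grads):
    # Intended final order once everything moves behind the graph partition:
    #   graph_partition -> amp -> recompute -> sharding -> gradient_merge

    # TODO(owner): amp pass (to be moved here in a later PR)
    # TODO(owner): recompute pass (to be moved here in a later PR)

    # Sharding (added by this PR); the registered name and config keys below
    # are assumptions for illustration.
    auto_parallel_sharding_pass = new_pass(
        "auto_parallel_sharding",
        {"global_rank": rank, "params_grads": params_grads})
    auto_parallel_sharding_pass.apply(
        [main_program], [startup_program], self._pass_context)

    # TODO(owner): gradient_merge pass (not implemented yet)
```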
auto_parallel_sharding_pass.apply(
    [main_program], [startup_program], self._pass_context)

# apply recompute forward pass
Same as above.
Same as above reply.
# apply recompute forward pass
if self._dist_strategy.gradient_merge:
    pass
If this code is not implemented yet, please remove it for now.
fixed
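As an aside, if a placeholder branch were kept for a while, one alternative to a silent `pass` would be to fail loudly. A sketch only, not what this PR does:

```python
if self._dist_strategy.gradient_merge:
    # Not implemented yet: raise instead of silently doing nothing.
    raise NotImplementedError(
        "gradient_merge is not supported by the auto parallel passes yet")
```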
@@ -63,12 +63,12 @@ def apply_no_passes(self):

def check_main(self, gpus=None, **kwargs):
I do not see why we really need a class that duplicates so much code from DistPassTestBase.
We need to override the member function `_run_gpu_main`: since we want to check how this pass cooperates with the rest of the auto_parallel logic (like graph partition), we have to trigger it from fleet_base, where both auto parallel and this pass are applied, so we cannot reuse the `_run_gpu_main` from DistPassTestBase.
I will think about a better plan for this in the next PR.
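A rough sketch of that intent; the class name, the base-class import, the method signature, and the `semi_auto` flag are assumptions, not the actual test API:

```python
import paddle.distributed.fleet as fleet
# DistPassTestBase lives in Paddle's distributed_passes unittest directory.
from dist_pass_test_base import DistPassTestBase

class AutoParallelPassTestBase(DistPassTestBase):
    # Assumed signature: override the GPU runner so programs are built through
    # fleet, which triggers the auto-parallel graph partition before the pass
    # under test; DistPassTestBase's own runner compiles the program directly.
    def _run_gpu_main(self, apply_pass, dump_file, **kwargs):
        strategy = fleet.DistributedStrategy()
        strategy.semi_auto = True    # assumed flag enabling auto parallel
        fleet.init(is_collective=True, strategy=strategy)
        # ... build the model via fleet so partition + this pass are applied,
        # run it, and dump outputs for comparison with the baseline run ...
```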
process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
mesh_shape = process_mesh.topology
# TODO replace with specific batch size dimension
I suggest that each TODO comment should name who is responsible for it.
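For example (the username is just a placeholder):

```python
# TODO(your-github-id): replace with the specific batch size dimension
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
```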
fixed~
def _check_conflict(self, other_pass):
    return True

def _apply_single_impl(self, main_program, startup_program, context):
I remember that our sharding pass does not support multiple blocks. How about adding an assertion here?
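For instance, something like this at the top of `_apply_single_impl` (a sketch; `num_blocks` is a `Program` property, but the exact check and message are up to the author):

```python
def _apply_single_impl(self, main_program, startup_program, context):
    # The sharding pass currently assumes a single-block program.
    assert main_program.num_blocks == 1, \
        "The auto parallel sharding pass only supports single-block programs."
```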
else:
    op._set_attr("ring_id", self.outer_dp_group.id)

main_block._sync_with_cpp
main_block._sync_with_cpp()
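For context: without the parentheses the statement only references the bound method and discards it, so the Python block is never synced with its underlying C++ desc.

```python
main_block._sync_with_cpp    # no-op: references the method without calling it
main_block._sync_with_cpp()  # actually syncs the Python Block with its C++ desc
```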
thx~
fixed
LGTM for set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")
PR types
New features
PR changes
Others
Describe
Sharding optimization pass for auto parallel: the base framework for stages 1, 2, and 3. More functionality and examples will follow in the next PR.
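For readers wondering how the sharding stages might eventually be selected, a heavily hedged sketch: the flag and config-key names below are borrowed from fleet's existing sharding optimizer and are only assumptions about how this auto-parallel pass could be configured; this PR does not confirm them.

```python
import paddle.distributed.fleet as fleet

# Assumption: stage selection mirrors fleet's existing sharding_configs.
# Stage 1 shards optimizer states, stage 2 also shards gradients,
# stage 3 additionally shards the parameters themselves.
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True            # assumed: enable (semi-)auto parallel
strategy.sharding = True             # assumed: enable the sharding pass
strategy.sharding_configs = {
    "stage": 2,                      # assumed key: which sharding stage to use
    "sharding_degree": 8,            # assumed key: number of ranks to shard across
}
fleet.init(is_collective=True, strategy=strategy)
```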