
[Auto Parallel] Sharding Pass #38502

Merged

Conversation

JZ-LIANG
Contributor

PR types

New features

PR changes

Others

Describe

Sharding optimization pass for auto parallel.
Base framework for stages 1, 2, and 3; more functionality and examples will come in the next PR.
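For context, the sharding stages follow the ZeRO-style split: stage 1 shards optimizer states, stage 2 additionally shards gradients, and stage 3 additionally shards parameters across the data-parallel ranks. Below is a minimal, hedged sketch of how the pass might be switched on through the fleet distributed strategy; the sharding_configs keys shown ("stage", "degree") are illustrative assumptions, not necessarily the exact interface added in this PR.

# Hedged sketch: enabling the auto-parallel sharding pass via DistributedStrategy.
# The sharding_configs keys below are assumptions for illustration only.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.semi_auto = True           # enable auto parallel
strategy.sharding = True            # enable the sharding optimization pass
strategy.sharding_configs = {
    "stage": 2,                     # 1: optimizer states, 2: + gradients, 3: + parameters
    "degree": 8,                    # number of ranks to shard across (assumed key)
}
fleet.init(is_collective=True, strategy=strategy)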

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@JZ-LIANG JZ-LIANG changed the title auto parallel sharding base [Auto Parallel] Sharding Pass Base Dec 28, 2021
@JZ-LIANG JZ-LIANG requested review from aoyulong and sneaxiy and removed request for aoyulong December 28, 2021 07:14
aoyulong previously approved these changes Dec 28, 2021
def _apply_post_optimization_passed(self, main_program, startup_program,
                                    rank, params_grads):

    # apply amp forward pass
Collaborator

@sneaxiy sneaxiy Dec 28, 2021

What is this comment for: a TODO, or a wrong comment? The following code is sharding, not amp. If it is a TODO, please add TODO(person responsible) at the beginning of the comment.

Contributor Author

@JZ-LIANG JZ-LIANG Dec 29, 2021

It is kind of a TODO that marks where the amp pass will go in the future. Our final goal is for every optimization pass to be applied inside this function, after the auto-parallel graph partition, and we will need several updates to get there.
The final order will be: graph_partition - amp - recompute - sharding - gradient_merge.
For now it is implemented as: amp - recompute - graph_partition - sharding - gradient_merge.

fixed~
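As a hedged sketch of the ordering described in this reply (the helper name, pass name, and config keys are placeholders, not the PR's exact code):

# Rough sketch only: intended structure of the post-optimization hook.
# Current order: amp -> recompute -> graph_partition -> sharding -> gradient_merge.
# Target order:  graph_partition -> amp -> recompute -> sharding -> gradient_merge,
# i.e. every optimization pass applied here, after the auto-parallel partition.
from paddle.distributed.passes import new_pass

def _apply_post_optimization_passes(self, main_program, startup_program,
                                    rank, params_grads):
    if self._dist_strategy.sharding:
        config = dict(self._dist_strategy.sharding_configs)
        config["params_grads"] = params_grads       # assumed config key
        config["global_rank"] = rank                # assumed config key
        sharding_pass = new_pass("auto_parallel_sharding", config)  # assumed pass name
        sharding_pass.apply([main_program], [startup_program],
                            self._pass_context)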

auto_parallel_sharding_pass.apply(
    [main_program], [startup_program], self._pass_context)

# apply recompute forward pass
Collaborator

Same as above.

Contributor Author

Same as above reply.


# apply recompute forward pass
if self._dist_strategy.gradient_merge:
    pass
Collaborator

If this code is not implemented yet, try to remove it first.

Contributor Author

fixed

@@ -63,12 +63,12 @@ def apply_no_passes(self):

def check_main(self, gpus=None, **kwargs):
Collaborator

I do not see why we really need a class that duplicates so much code from DistPassTestBase.

Contributor Author

@JZ-LIANG JZ-LIANG Dec 29, 2021

We want to override the member function "_run_gpu_main": we need to check how this pass cooperates with the other auto-parallel logic (like graph partition), so the pass has to be invoked through fleet_base, which triggers both auto parallel and this pass. That means we cannot reuse the _run_gpu_main from DistPassTestBase.

I will think of a better plan for this in the next PR.
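A hedged sketch of that idea follows; the class name, method signature, and body are placeholders, not the PR's actual test code:

# Sketch only: run the model through fleet so that the auto-parallel graph
# partition and the pass under test are both exercised, instead of applying
# the pass directly to a hand-built program as DistPassTestBase does.
import paddle
import paddle.distributed.fleet as fleet

class AutoParallelPassTestBase(DistPassTestBase):   # DistPassTestBase from the repo's test utils
    def _run_gpu_main(self, apply_pass, dump_file, **kwargs):   # placeholder signature
        strategy = fleet.DistributedStrategy()
        strategy.semi_auto = True            # turn on auto parallel
        strategy.sharding = apply_pass       # toggle the pass under test
        fleet.init(is_collective=True, strategy=strategy)

        # fleet.distributed_optimizer() is where the partition and the enabled
        # passes run, so the resulting program already reflects them.
        opt = fleet.distributed_optimizer(paddle.optimizer.Adam(0.001))
        # ... build the model, call opt.minimize(loss), run, and dump outputs to dump_file ...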

process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
mesh_shape = process_mesh.topology
# TODO replace with specific batch size dimension
Collaborator

I suggest that each TODO comment should name who is responsible for it.

Contributor Author

fixed~
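For readers unfamiliar with the snippet quoted above in this thread, here is a hedged illustration of what a dims_mapping encodes; the values and the derived data-parallel degree are hypothetical, not code from the PR:

# Illustration with hypothetical values (not the PR's code).
# process_mesh.topology gives the shape of the rank mesh; a dims_mapping maps
# each tensor axis to a mesh axis, with -1 meaning the axis is replicated.
mesh_shape = [2, 4]             # a 2 x 4 mesh of 8 ranks
input_dim_mapping = [0, -1]     # tensor axis 0 sharded along mesh axis 0;
                                # tensor axis 1 replicated

# TODO(<owner>): replace with the specific batch size dimension.
batch_size_axis = input_dim_mapping[0]       # assume axis 0 is the batch dimension
if batch_size_axis >= 0 and mesh_shape[batch_size_axis] > 1:
    dp_degree = mesh_shape[batch_size_axis]  # data-parallel degree along that mesh axis
    print("data-parallel degree:", dp_degree)  # -> 2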

def _check_conflict(self, other_pass):
    return True

def _apply_single_impl(self, main_program, startup_program, context):
Collaborator

I remember that our sharding pass does not support multiple blocks. How about adding an assertion here?
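A minimal sketch of the kind of guard being suggested (not necessarily the exact check the PR ends up adding):

# Sketch only: fail fast if the program has more than one block, since the
# sharding pass assumes a single-block main program.
def _apply_single_impl(self, main_program, startup_program, context):
    assert main_program.num_blocks == 1, \
        "auto-parallel sharding pass does not support programs with multiple blocks"
    # ... rest of the pass ...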

else:
    op._set_attr("ring_id", self.outer_dp_group.id)

main_block._sync_with_cpp
Collaborator

main_block._sync_with_cpp()

Contributor Author

@JZ-LIANG JZ-LIANG Dec 29, 2021

thx~
fixed

@JZ-LIANG JZ-LIANG changed the title [Auto Parallel] Sharding Pass Base [Auto Parallel] Sharding Pass Dec 29, 2021
Contributor

@XieYunshen XieYunshen left a comment

LGTM for set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")

@JZ-LIANG JZ-LIANG merged commit e3faf34 into PaddlePaddle:develop Dec 29, 2021