[Auto Parallel] Sharding Pass #38502
@@ -27,6 +27,7 @@
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
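The only change in this hunk is the new import of `new_pass` and `PassContext` from `paddle.distributed.passes`. As a hedged sketch of the calling pattern this import enables (the wrapper function below is hypothetical, not part of this diff), a registered pass is looked up by name, bound to a config dict, and applied to lists of programs:

```python
# Hedged sketch only: illustrates the new_pass / PassContext calling pattern
# used later in this file. apply_registered_pass is a hypothetical helper;
# the pass_name passed in must already be registered.
from paddle.distributed.passes import new_pass, PassContext


def apply_registered_pass(pass_name, config, main_program, startup_program):
    """Look up a registered pass by name, bind its config, and apply it."""
    pass_context = PassContext()
    # new_pass() resolves the pass from the registry and binds its config.
    a_pass = new_pass(pass_name, config)
    # The new-style passes in this PR take lists of programs plus a shared
    # context object.
    a_pass.apply([main_program], [startup_program], pass_context)
    return pass_context
```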
@@ -139,30 +140,34 @@ def _generate_backward(self, main_program, startup_program, loss,

    def _apply_optimize(self, main_program, startup_program, params_grads):

        if self._dist_strategy.sharding:
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding_pass", self._dist_strategy)
            params_grads = auto_parallel_sharding_pass.apply(
                main_program, startup_program, params_grads, self._pass_context)

        if self._dist_strategy.gradient_merge:
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass",
                self._dist_strategy.gradient_merge_configs)
            auto_parallel_gradient_merge_pass.apply(
                main_program, startup_program, params_grads, self._pass_context)

        else:
            with program_guard(main_program, startup_program):
                optimizer = copy.deepcopy(self._optimizer)
                optimize_ops = optimizer.apply_gradients(params_grads)
        with program_guard(main_program, startup_program):
            optimize_ops = copy.deepcopy(self._optimizer).apply_gradients(
                params_grads)

        # update completion
        complete_update_annotation(
            main_program, dist_context=self._dist_context)

        return optimize_ops

    def _apply_post_optimization_passed(self, main_program, startup_program,
                                        rank, params_grads):

        # apply amp forward pass
        if self._dist_strategy.sharding:
            config = copy.deepcopy(self._dist_strategy.sharding_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
                                                   config)
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context)

        # apply recompute forward pass
Reviewer comment: Same as above.
Author reply: Same as the reply above.

        if self._dist_strategy.gradient_merge:
            pass
Reviewer comment: If this code is not implemented yet, try to remove it first.
Author reply: fixed

    def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
        completed_main_program = None
        serial_main_program = self._main_program.clone()
@@ -203,7 +208,8 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):

        make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)

        self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
                                             rank, dist_params_grads)
        g_process_group_map = None
        if not relaunch_phase:
            g_process_group_map = copy.deepcopy(_g_process_group_map)
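The new `_apply_post_optimization_passed` step is driven by the fleet `DistributedStrategy` flags checked above (`sharding`, `gradient_merge`). Below is a hedged, user-side sketch of how a sharding setting could reach this code path; the specific `sharding_configs` key shown is an illustrative assumption rather than an interface defined by this diff.

```python
# Hedged, user-side sketch: how a sharding setting could feed the new
# _apply_post_optimization_passed path. The sharding_configs key below is an
# illustrative assumption; additional keys (e.g. a sharding stage) may exist.
import paddle.distributed.fleet as fleet

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True      # plan the program with auto parallel
dist_strategy.sharding = True       # enable the optimizer-state sharding pass
dist_strategy.sharding_configs = {
    "sharding_degree": 2,           # assumed key: ranks to shard states across
}

# Inside the parallelizer this strategy is turned into a pass config, roughly:
#   config = copy.deepcopy(dist_strategy.sharding_configs)
#   config["dist_context"] = dist_context
#   config["params_grads"] = params_grads
#   config["global_rank"] = rank
#   new_pass("auto_parallel_sharding", config).apply(
#       [dist_main_prog], [dist_startup_prog], pass_context)
```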
Reviewer comment: What is this comment for? Is it a TODO or a wrong comment? The following code is sharding, not amp. If it is a TODO, please add TODO(owner) at the beginning of the comment.

Author reply: It is a kind of TODO that marks where the amp pass will go in the future. Our final goal is for all optimization passes to be applied within that function after the auto-parallel graph partition; we will land several updates to reach that goal.
The final order will be: graph_partition - amp - recompute - sharding - gradient_merge.
At the moment it is implemented as: amp - recompute - graph_partition - sharding - gradient_merge.
fixed~
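To make the ordering described in the reply concrete, here is a small hypothetical sketch of applying configured passes in the target order. Only `auto_parallel_sharding` is added by this PR; the other pass names are assumed placeholders for future work, and the graph partition step is not a registered pass here.

```python
# Hypothetical sketch of the target pass ordering from the reply above.
# Pass names other than "auto_parallel_sharding" are assumed placeholders.
from paddle.distributed.passes import new_pass, PassContext

TARGET_PASS_ORDER = [
    # graph partition happens before these passes in the final design
    "auto_parallel_amp",             # assumed future pass name
    "auto_parallel_recompute",       # assumed future pass name
    "auto_parallel_sharding",        # added by this PR
    "auto_parallel_gradient_merge",  # assumed future pass name
]


def apply_configured_passes(configs, main_prog, startup_prog):
    """Apply whichever passes have a config, in the target order."""
    pass_context = PassContext()
    for name in TARGET_PASS_ORDER:
        if name in configs:
            new_pass(name, configs[name]).apply(
                [main_prog], [startup_prog], pass_context)
    return pass_context
```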