【AutoParallel】Promote fuselinear pass in auto-parallel #59188

Merged (13 commits) on Dec 6, 2023
python/paddle/distributed/auto_parallel/constants.py (6 additions, 0 deletions)
@@ -145,6 +145,12 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)

#########################################
# fused linear promotion configuration
#########################################
FUSEDLINEARPROMOTION = "fused_linear_promotion"
set_field_default_config(FUSEDLINEARPROMOTION, "enable", False)

#########################################
# fused passes configuration
#########################################
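For context, a minimal sketch of how the new switch would be toggled from user code, assuming the usual auto-parallel `Strategy` entry point in `paddle.distributed.fleet.auto`; the attribute names mirror the categories registered in this file:

```python
# Hedged sketch (not part of this PR): enabling the promotion pass via the
# auto-parallel Strategy; attribute names follow constants.py above.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

# The promotion pass rides on the fused-passes machinery (see the
# parallelizer_v2.py hunk below), so fuse_gemm_epilogue must be requested.
strategy.fused_passes.enable = True
strategy.fused_passes.fused_passes_list = ["fuse_gemm_epilogue"]

# New switch added here; defaults to False as registered above.
strategy.fused_linear_promotion.enable = True
```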
python/paddle/distributed/auto_parallel/static/parallelizer_v2.py (29 additions, 0 deletions)
@@ -348,6 +348,35 @@ def _apply_post_optimization(
        )
        sp_pass.apply([main_program], [startup_program], self._pass_context)

        # apply fused linear promotion pass
        if (
            self.is_train
            and self._strategy.fused_linear_promotion.enable
            and self._strategy.fused_passes.enable
        ):
            if (
                len(self._strategy.fused_passes.fused_passes_list) > 0
                and "fuse_gemm_epilogue"
                in self._strategy.fused_passes.fused_passes_list
            ):
                amp_config = None
                if self._strategy.amp.enable:
                    amp_config = copy.deepcopy(self._strategy.amp.to_dict())
                config = {}
                config["dist_context"] = self._dist_context
                config["global_rank"] = rank
                config["enable_sp"] = self._strategy.sp_optimization.enable
                config["params_grads"] = params_grads
                config["amp_level"] = (
                    amp_config['level'] if amp_config is not None else "o0"
                )
                fused_linear_promotion_pass = new_pass(
                    "auto_parallel_fused_linear_promotion", config
                )
                fused_linear_promotion_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )

        # data parallel optimization
        if self._strategy.dp_optimization.enable:
            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
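The gating above applies the pass only during training, with both `fused_linear_promotion.enable` and `fused_passes.enable` set and `fuse_gemm_epilogue` present in the fused-passes list; `amp_level` falls back to `"o0"` when AMP is off. Below is a hedged sketch of an equivalent dict-style configuration, assuming `Strategy` accepts a config dict keyed by the categories in constants.py (the AMP level value is illustrative):

```python
# Hedged sketch: a dict-style Strategy config that satisfies the conditions
# checked in _apply_post_optimization above.
from paddle.distributed.fleet import auto

config = {
    "fused_passes": {
        "enable": True,
        "fused_passes_list": ["fuse_gemm_epilogue"],
    },
    "fused_linear_promotion": {"enable": True},
    # amp_level handed to the pass is read from amp.level when AMP is
    # enabled; otherwise the code above falls back to "o0".
    "amp": {"enable": True, "level": "o2"},
}
strategy = auto.Strategy(config)
```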
python/paddle/distributed/auto_parallel/strategy.py (9 additions, 0 deletions)
@@ -82,6 +82,12 @@ def __init__(self, config_dict=None):
        super().__init__(category, config_dict)


class FusedPromotionConfig(BaseConfig):
    def __init__(self, config_dict=None):
        category = constants.FUSEDLINEARPROMOTION
        super().__init__(category, config_dict)


class AMPConfig(BaseConfig):
    def __init__(self, config_dict=None):
        category = constants.AMP
@@ -224,6 +230,9 @@ def __init__(self, config=None):
        config_dict = self._config_dict.get(constants.FUSED_PASSES, None)
        self.fused_passes = FusedPassesConfig(config_dict)

        config_dict = self._config_dict.get(constants.FUSEDLINEARPROMOTION, None)
        self.fused_linear_promotion = FusedPromotionConfig(config_dict)

        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
        self.dp_optimization = DPOptimizationConfig(config_dict)

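The two hunks above follow the existing pattern for adding a strategy option. A hedged recap with hypothetical names (`my_feature` and `MyFeatureConfig` are illustrative only, not part of this PR):

```python
# Hedged recap of the three touch points for a new strategy option;
# "my_feature" / MyFeatureConfig are hypothetical names.
from paddle.distributed.auto_parallel import constants
from paddle.distributed.auto_parallel.strategy import BaseConfig

# 1) constants.py: register the category and its default field(s).
constants.set_field_default_config("my_feature", "enable", False)


# 2) strategy.py: a thin BaseConfig subclass bound to that category.
class MyFeatureConfig(BaseConfig):
    def __init__(self, config_dict=None):
        super().__init__("my_feature", config_dict)


# 3) Strategy.__init__: fetch the user sub-dict and expose it as an
#    attribute, mirroring self.fused_linear_promotion above, e.g.
#    self.my_feature = MyFeatureConfig(self._config_dict.get("my_feature"))
```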
python/paddle/distributed/passes/__init__.py (1 addition, 0 deletions)
@@ -22,6 +22,7 @@
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_fused_linear_promotion import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403
from .auto_parallel_sequence_parallel_optimization import * # noqa: F403
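Registering the module here is what lets the parallelizer resolve the pass by name. A hedged standalone sketch of that lookup (config values are placeholders; in practice `_apply_post_optimization` fills them as shown earlier):

```python
# Hedged sketch: looking up the newly registered pass by name; the config
# keys mirror the parallelizer_v2.py hunk, values here are placeholders.
from paddle.distributed.passes import new_pass

pass_config = {
    "dist_context": None,   # the parallelizer supplies its DistributedContext
    "global_rank": 0,
    "enable_sp": False,
    "params_grads": [],
    "amp_level": "o0",      # "o0" when AMP is disabled
}
promotion_pass = new_pass("auto_parallel_fused_linear_promotion", pass_config)
# promotion_pass.apply([main_program], [startup_program], pass_context)
```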