From f53b3c4c06b3c7f8778cbcb5d617dca3c64d96df Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Mon, 20 Nov 2023 17:18:57 +0800 Subject: [PATCH 01/11] add fused_linear_promotion pass --- .../distributed/auto_parallel/constants.py | 6 + .../auto_parallel/static/parallelizer_v2.py | 24 + .../distributed/auto_parallel/strategy.py | 9 + python/paddle/distributed/passes/__init__.py | 1 + .../auto_parallel_fused_linear_promotion.py | 787 ++++++++++++++++++ 5 files changed, 827 insertions(+) create mode 100644 python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index e80e68281c09a..b037965bcd83a 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -144,6 +144,12 @@ def set_field_default_config(category, field, default_value): set_field_default_config(DATASET, "enable", False) set_field_default_config(DATASET, "num_shards", 1) +# ######################################### +# # offload configuration +# ######################################### +FUSEDPROMOTION = "fused_promotion" +set_field_default_config(FUSEDPROMOTION, "enable", True) + ######################################### # fused passes configuration ######################################### diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index b59cfea194551..f1aaa8519ba5c 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -338,6 +338,30 @@ def _apply_post_optimization( if self._strategy is None: return + # apply fused linear promotion pass + if ( + self.is_train + and self._strategy.fused_promotion.fused_promotion + and self._strategy.fused_passes.enable + ): + amp_config = None + if self._strategy.amp.enable: + amp_config = copy.deepcopy(self._strategy.amp.to_dict()) + config = {} + config["dist_context"] = self._dist_context + config["global_rank"] = rank + config["params_grads"] = params_grads + config["enable_sp"] = False + config["amp_level"] = ( + amp_config['level'] if amp_config is not None else "o0" + ) + fused_promotion_pass = new_pass( + "auto_parallel_fused_linear_promotion", config + ) + fused_promotion_pass.apply( + [main_program], [startup_program], self._pass_context + ) + # data parallel optimization if self._strategy.dp_optimization.enable: config = copy.deepcopy(self._strategy.dp_optimization.to_dict()) diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 958d7dc565304..559d33139171e 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -82,6 +82,12 @@ def __init__(self, config_dict=None): super().__init__(category, config_dict) +class FusedPromotion(BaseConfig): + def __init__(self, config_dict=None): + category = constants.FUSEDPROMOTION + super().__init__(category, config_dict) + + class AMPConfig(BaseConfig): def __init__(self, config_dict=None): category = constants.AMP @@ -218,6 +224,9 @@ def __init__(self, config=None): config_dict = self._config_dict.get(constants.FUSED_PASSES, None) self.fused_passes = FusedPassesConfig(config_dict) + config_dict = self._config_dict.get(constants.FUSEDPROMOTION, None) + self.fused_passes = 
FusedPromotion(config_dict) + config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 8c1f4ab6e5350..39e956ba0dddf 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -22,6 +22,7 @@ from .auto_parallel_quantization import * # noqa: F403 from .auto_parallel_data_parallel_optimization import * # noqa: F403 from .auto_parallel_grad_clip import * # noqa: F403 +from .auto_parallel_fused_linear_promotion import * # noqa: F403 from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .auto_parallel_pipeline import * # noqa: F403 from .allreduce_matmul_grad_overlapping import * # noqa: F403 diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py new file mode 100644 index 0000000000000..a0a777b713815 --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -0,0 +1,787 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +from paddle.distributed.auto_parallel.static.utils import ( + is_optimize_op, + is_recompute_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, +) +from paddle.utils import unique_name + +from ..utils.log_utils import get_logger +from .pass_base import PassBase, register_pass + +logger = get_logger(logging.INFO, "FusedLinearPromotionPass") + +_supported_optimizer_type = [ + "adam", + "adamax", + "adamw", + "decayed_adagrad", + "momentum", + "dgc_momentum", + "lars_momentum", + "merged_momentum", + "lamb", + "sgd", +] + +FUSED_LINEAR_SOURCE_PATTERNS_LIST = [ + { + "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], + "backward": ["elementwise_add_grad", "matmul_v2_grad"], + }, + { + "forward": ["matmul_v2", "c_reduce_scatter", "elementwise_add"], + "backward": ["elementwise_add_grad", "c_allgather", "matmul_v2_grad"], + }, + { + "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "matmul_v2_grad", + ], + }, + { + "forward": ["matmul_v2", "c_reduce_scatter", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "c_allgather", + "matmul_v2_grad", + ], + }, + { + "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], + "backward": ["elementwise_add_grad", "matmul_v2_grad"], + }, + { + "forward": ["matmul_v2", "c_reduce_scatter", "cast", "elementwise_add"], + "backward": ["elementwise_add_grad", "c_allgather", "matmul_v2_grad"], + }, + { + "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "matmul_v2_grad", + ], + }, + { + "forward": ["matmul_v2", "c_reduce_scatter", "cast", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "c_allgather", + "matmul_v2_grad", + ], + }, +] + + +@register_pass("auto_parallel_fused_linear_promotion") +class FusedLinearPromotionPass(PassBase): + """ + Apply pre-promotion that specialized for fused_linear_pass in tensor parallelism or sequence parallelism in Auto Parallel. 
+ """ + + def __init__(self): + super().__init__() + self.set_attr("dist_context", None) + self.set_attr("global_rank", -1) + self.set_attr("enable_sp", False) + self.set_attr("amp_level", "o0") + self.set_attr("params_grads", None) + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr( + "global_rank" + ) < 0: + return False + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + self._dist_context = self.get_attr("dist_context") + self._global_rank = int(self.get_attr("global_rank")) + self._enable_sp = self.get_attr("enable_sp") + self._params_grads = self.get_attr("params_grads") + self._amp_level = self.get_attr("amp_level") + self._is_amp_o1 = self._amp_level == 'o1' + self._source_patterns = {} + self._enable_dp, self._enable_mp = self._is_enable_dp_mp( + self._dist_context + ) + + pattern_offset = 4 if self._is_amp_o1 else 0 + if self._enable_sp: + if self._enable_dp: + self._source_patterns = FUSED_LINEAR_SOURCE_PATTERNS_LIST[ + 3 + pattern_offset + ] + else: + self._source_patterns = FUSED_LINEAR_SOURCE_PATTERNS_LIST[ + 1 + pattern_offset + ] + elif self._enable_mp: + if self._enable_dp: + self._source_patterns = FUSED_LINEAR_SOURCE_PATTERNS_LIST[ + 2 + pattern_offset + ] + else: + self._source_patterns = FUSED_LINEAR_SOURCE_PATTERNS_LIST[ + 0 + pattern_offset + ] + else: + logger.warning("Neither of sp and mp is enabled, skip this pass") + return + # 1. get whether the current rank is first rank in mp + self._is_first_rank = self._is_tp_sp_first_rank( + self._dist_context, self._global_rank + ) + + # 2. get the forward and backward op list idexs in source patterns + ( + forward_segments, + backward_segments, + ) = self._get_forward_backward_op_segments(main_program) + + # 3 transform the forward ops + rename_var_names_map, deleted_bias_names = self._transform_forward( + main_program, + forward_segments, + backward_segments, + self._is_first_rank, + self._enable_sp, + self._is_amp_o1, + ) + # 4 transform the backward ops + self._transform_backward( + main_program, + backward_segments, + rename_var_names_map, + self._is_first_rank, + self._enable_sp, + ) + # 5. transform the optimizer ops + self._transform_opt( + main_program, + deleted_bias_names, + self._params_grads, + self._is_first_rank, + self._is_amp_o1, + ) + logger.info(f"deleted_bias_names: {deleted_bias_names}") + # 6. 
transform the startup program + self._transform_startup_program( + startup_program, deleted_bias_names, self._is_first_rank + ) + + def _is_tp_sp_first_rank(self, dist_context, rank): + for process_mesh in dist_context.process_meshes: + if len(process_mesh._shape) == 1: + return rank == min(process_mesh.process_ids) + else: + inner_mesh = process_mesh.mesh + inner_mesh_shape = inner_mesh.shape + if len(inner_mesh.shape) == 2: + for id0 in range(inner_mesh_shape[0]): + if rank == min(inner_mesh[id0, :]): + return True + elif len(inner_mesh.shape) == 3: + for id0 in range(inner_mesh_shape[0]): + for id1 in range(inner_mesh_shape[1]): + if rank == min(inner_mesh[id0, id1, :]): + return True + else: + raise ValueError("inner mesh shape is not supported") + return False + + def _is_enable_dp_mp(self, dist_context): + for process_mesh in dist_context.process_meshes: + if len(process_mesh._shape) == 1: + return False, False + else: + inner_mesh = process_mesh.mesh + inner_mesh_shape = inner_mesh.shape + if len(inner_mesh.shape) == 2: # Dp * Mp + return inner_mesh_shape[0] > 1, inner_mesh_shape[1] > 1 + elif len(inner_mesh.shape) == 3: + return inner_mesh_shape[-2] > 1, inner_mesh_shape[-1] > 1 + else: + raise ValueError("inner mesh shape is not supported") + return False, False + + def _reset_op_dist_attr(self, op, new_var_name, is_input=True): + """ + Reset the dist_attr of the input and output of the operator. + """ + op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) + var_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + self._main_block.var(new_var_name) + ) + assert ( + op_dist_attr is not None + ), f"Reset op {op.desc.type()}'s dist_attr, but its dist_attr is None" + if is_input: + op_dist_attr.set_input_dist_attr(new_var_name, var_dist_attr) + if not is_input: + op_dist_attr.set_output_dist_attr(new_var_name, var_dist_attr) + + def _get_forward_backward_op_segments(self, main_program): + """ + Get the operator segments according to the source patterns. + """ + + def can_match_pattern( + ops, start_id, pattern, forward_matmul_inputs, is_backward=False + ): + """ + Check whether the ops in the range [start_id, start_id + len(pattern)] can match the pattern. + If the ops is in forward pass, check it directly. However, when the ops is in backward pass, + we need to additionally check whether the input of the last op in pattern is in forward_matmul_inputs to + deal the case of enabling recompute. 
+ """ + new_id = start_id + if not is_backward: + for op_name in pattern: + if ops[new_id].type != op_name: + return False + new_id += 1 + forward_matmul_inputs.extend(ops[start_id].input_arg_names) + return True + else: + for op_name in pattern: + if ops[new_id].type != op_name: + return False + new_id += 1 + for input_name in (ops[new_id - 1].input_arg_names)[1:]: + if input_name not in forward_matmul_inputs: + return False + return True + + global_block = main_program.global_block() + forward_segments = [] + backward_segmnets = [] + ops_len = len(global_block.ops) + + self._forward_patterns_len = len(self._source_patterns["forward"]) + self._backward_patterns_len = len(self._source_patterns["backward"]) + forward_matmul_inputs = [] + for id, op in enumerate(global_block.ops): + if id > ops_len - self._backward_patterns_len: + break + if int(op.desc.attr('op_role')) == 0 or ( + is_recompute_op(op) and not op.type.endswith("_grad") + ): # forward + if can_match_pattern( + global_block.ops, + id, + self._source_patterns["forward"], + forward_matmul_inputs, + is_backward=False, + ): + forward_segments.append( + [id, id + self._forward_patterns_len] + ) + elif int(op.desc.attr('op_role')) == 1: # backward + if can_match_pattern( + global_block.ops, + id, + self._source_patterns["backward"], + forward_matmul_inputs, + is_backward=True, + ): + backward_segmnets.append( + [id, id + self._backward_patterns_len] + ) + else: + pass + logger.info(f"forward_segments: {forward_segments}") + logger.info(f"backward_segmnets: {backward_segmnets}") + return forward_segments, backward_segmnets + + def _transform_forward( + self, + main_program, + forward_segments, + backward_segments, + is_first_rank, + is_sp, + is_amp_o1, + ): + """ + Transform the forward pass. + """ + + def _transform_forward_segment( + global_block, + forward_segment, + backward_segments, + is_first_rank, + is_sp, + is_amp_o1, + ): + """ + Transform one forward segment. + """ + # 1. prepare the forward_segment + # 1.1 check whether the forward_segment is right + origin_matmul_op = global_block.ops[forward_segment[0]] + origin_comm_op = global_block.ops[forward_segment[0] + 1] + origin_add_op = global_block.ops[forward_segment[1] - 1] + origin_cast_op = global_block.ops[forward_segment[1] - 2] + origin_matmul_output_name = origin_matmul_op.output_arg_names[0] + origin_comm_input_name = origin_comm_op.input_arg_names[0] + assert ( + origin_matmul_output_name == origin_comm_input_name + ), f"The 0th op output name {origin_matmul_output_name} is not equal to the 1st op input name {origin_comm_input_name}" + origin_comm_output_name = origin_comm_op.output_arg_names[0] + origin_add_input_names = origin_add_op.input_arg_names + assert ( + origin_comm_output_name == origin_add_input_names[0] + ), f"The 1st op output name {origin_comm_output_name} is not equal to the 2nd op input name {origin_add_input_names[0]}" + # 1.2 get the origin dist_attr + origin_add_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program(origin_add_op) + ) + assert ( + origin_add_dist_attr is not None + ), f"Origin add op {origin_add_op.type} has no dist attr" + ref_mesh = origin_add_dist_attr.process_mesh + in_var_dist_attr = origin_add_dist_attr.get_input_dist_attr( + origin_add_op.input_arg_names[0] + ) + ref_mapping = in_var_dist_attr.dims_mapping + + # 2. 
deal matmul_v2 op + origin_matmul_output_new_name = unique_name.generate( + origin_matmul_output_name + "@promote" + ) + global_block.create_var( + name=origin_matmul_output_new_name, + dtype=global_block.var(origin_matmul_output_name).dtype, + shape=global_block.var(origin_matmul_output_name).shape, + persistable=False, + stop_gradient=False, + ) + rename_vars_map[ + origin_matmul_output_name + ] = origin_matmul_output_new_name + origin_matmul_op._rename_output( + origin_matmul_output_name, origin_matmul_output_new_name + ) + + # 3. deal add op and cast op + if is_first_rank: + # insert the "elementwise_add" op before reduce_sum + new_add_op = global_block._insert_op_without_sync( + forward_segment[0] + 1, + type="nop", + ) + new_op_desc = new_add_op.desc + new_op_desc.copy_from(origin_add_op.desc) + # create new var of new_add_op output + origin_add_output_name = origin_add_op.output_arg_names[0] + new_add_op_output_name = unique_name.generate( + origin_add_output_name + "@promote" + ) + new_shape_var_name = ( + origin_add_output_name + if not is_sp + else origin_matmul_output_name + ) + global_block.create_var( + name=new_add_op_output_name, + dtype=global_block.var(origin_add_output_name).dtype, + shape=global_block.var(new_shape_var_name).shape, + persistable=False, + stop_gradient=False, + ) + global_block._remove_var( + origin_matmul_output_name + ) # We can remove the origin_matmul_output now. + global_block._remove_var(origin_add_output_name) + # rename_vars_map[origin_add_output_name] = new_add_op_output_name + new_add_op._rename_output( + origin_add_output_name, new_add_op_output_name + ) + # rename input of new_add_op + rename_vars_map[ + origin_add_op.input_arg_names[0] + ] = origin_matmul_output_new_name + new_add_op._rename_input( + origin_add_op.input_arg_names[0], + origin_matmul_output_new_name, + ) + # deal dist_attr + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_add_op, ref_mesh, ref_mapping, self._dist_context + ) + # 'cast' op also need to adjust + if is_amp_o1: + new_cast_op = global_block._insert_op_without_sync( + forward_segment[0] + 1, + type="nop", + ) + new_op_desc = new_cast_op.desc + new_op_desc.copy_from(origin_cast_op.desc) + if ( + new_cast_op.input_arg_names[0] + not in delete_bias_vars_name + ): # fp16 = cast(fp32) + delete_bias_vars_name.append( + new_cast_op.input_arg_names[0] + ) + else: + if ( + new_add_op.input_arg_names[1] + not in delete_bias_vars_name + ): + delete_bias_vars_name.append( + new_add_op.input_arg_names[1] + ) + else: + origin_add_output_name = origin_add_op.output_arg_names[0] + global_block._remove_var(origin_add_output_name) + # 4. 
deal comm op + # The input of c_allreduce_sum only be used once, so we don't need add it in the rename_vars_map + if is_first_rank: + origin_comm_op._rename_input( + origin_comm_op.input_arg_names[0], + new_add_op.output_arg_names[0], + ) + else: + origin_comm_op._rename_input( + origin_comm_op.input_arg_names[0], + origin_matmul_output_new_name, + ) + if origin_comm_op.type == "c_allreduce_sum": + new_comm_var_name = origin_comm_op.input_arg_names[0] + else: + new_comm_var_name = unique_name.generate( + origin_comm_output_name + "@promote" + ) + global_block.create_var( + name=new_comm_var_name, + dtype=global_block.var(origin_comm_output_name).dtype, + shape=global_block.var(origin_comm_output_name).shape, + persistable=False, + stop_gradient=False, + ) + rename_vars_map[origin_comm_output_name] = new_comm_var_name + if global_block.has_var(origin_comm_output_name): + global_block._remove_var(origin_comm_output_name) + rename_vars_map[ + origin_add_output_name + ] = new_comm_var_name # the output of comm op inplace the output of add op for next ops + origin_comm_op._rename_output( + origin_comm_output_name, new_comm_var_name + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + origin_comm_op, ref_mesh, ref_mapping, self._dist_context + ) + + # 5. remove elementwise_add op and cast op + if is_first_rank: + if is_amp_o1: + global_block._remove_op(forward_segment[0] + 5) + global_block._remove_op(forward_segment[0] + 4) + else: + global_block._remove_op(forward_segment[0] + 3) + else: + global_block._remove_op( + forward_segment[1] - 1 + ) # remove elementwise_add op + if is_amp_o1: + if ( + origin_cast_op.input_arg_names[0] + not in delete_bias_vars_name + ): + delete_bias_vars_name.append( + origin_cast_op.input_arg_names[0] + ) + global_block._remove_var(origin_cast_op.output_arg_names[0]) + global_block._remove_op( + forward_segment[1] - 2 + ) # remove cast op + else: + if origin_add_input_names[1] not in delete_bias_vars_name: + delete_bias_vars_name.append(origin_add_input_names[1]) + # update backward forward_segment + for back_seg in reversed(backward_segments): + if is_amp_o1: + if back_seg[0] > forward_segment[0]: + back_seg[0] -= 2 + back_seg[1] -= 2 + else: + break + else: + if back_seg[0] > forward_segment[0]: + back_seg[0] -= 1 + back_seg[1] -= 1 + else: + break + + global_block = main_program.global_block() + rename_vars_map = {} # origin_name -> new_name + delete_bias_vars_name = [] + for segment in reversed(forward_segments): + _transform_forward_segment( + global_block, + segment, + backward_segments, + is_first_rank, + is_sp, + is_amp_o1, + ) + global_block._sync_with_cpp() + return rename_vars_map, delete_bias_vars_name + + def _transform_backward( + self, + main_program, + backward_segments, + rename_var_names_map, + is_first_rank, + is_sp, + ): + global_block = main_program.global_block() + to_delete_grad_of_param = [] + if is_first_rank: + if is_sp: + # place the comm_op(c_allgather) before the elementwise_add_grad + for segment in reversed(backward_segments): + add_grad_op = global_block.ops[segment[0]] + matmul_grad_op = global_block.ops[segment[-1] - 1] + origin_comm_op_id = segment[-2] + orgin_comm_op = global_block.ops[origin_comm_op_id] + new_comm_op = global_block._insert_op( + segment[0], + type="nop", + ) + new_comm_op.desc.copy_from(orgin_comm_op.desc) + # rename input and output + # TODO(lizhiyu): The input and ouput of all_gather may not suitable. 
+ new_comm_op._rename_input( + orgin_comm_op.input_arg_names[0], + add_grad_op.input_arg_names[0], + ) + add_grad_op._rename_input( + add_grad_op.input_arg_names[0], + new_comm_op.output_arg_names[0], + ) + matmul_grad_op._rename_input( + matmul_grad_op.input_arg_names[0], + add_grad_op.output_arg_names[1], + ) + global_block._remove_op(segment[-2] + 1) + global_block._sync_with_cpp() + else: # not is_first_rank_in tp or sp + # need to delete the grad op assosiated with the deleted bias var + if not is_sp: + for segment in reversed(backward_segments): + add_grad_op = global_block.ops[segment[0]] + rename_var_names_map[ + add_grad_op.output_arg_names[0] + ] = add_grad_op.input_arg_names[0] + global_block._remove_var(add_grad_op.output_arg_names[0]) + to_delete_grad_of_param.append( + add_grad_op.output_arg_names[1] + ) + if self._enable_dp: + c_all_reduce_op = global_block.ops[segment[0] + 1] + scale_op = global_block.ops[segment[0] + 2] + global_block._remove_op(segment[0] + 2) + global_block._remove_op(segment[0] + 1) + global_block._remove_op(segment[0]) + global_block._sync_with_cpp() + else: + # place the comm_op(c_allgather) before the elementwise_add_grad + for segment in reversed(backward_segments): + add_grad_op = global_block.ops[segment[0]] + matmul_grad_op = global_block.ops[segment[-1] - 1] + origin_comm_op_id = segment[-2] + orgin_comm_op = global_block.ops[origin_comm_op_id] + new_comm_op = global_block._insert_op( + segment[0], + type="nop", + ) + new_comm_op.desc.copy_from(orgin_comm_op.desc) + new_comm_op._rename_input( + orgin_comm_op.input_arg_names[0], + add_grad_op.input_arg_names[0], + ) + matmul_grad_op._rename_input( + matmul_grad_op.input_arg_names[0], + new_comm_op.output_arg_names[0], + ) + global_block._remove_op(segment[-2] + 1) + global_block._remove_var(add_grad_op.output_arg_names[0]) + global_block._remove_var(add_grad_op.output_arg_names[1]) + # remove vars and op + if self._enable_dp: # DP + c_all_reduce_op = global_block.ops[segment[1]] + scale_op = global_block.ops[segment[2]] + global_block._remove_var( + c_all_reduce_op.input_arg_names[0] + ) + global_block._remove_var(scale_op.outpu_arg_names[0]) + global_block._remove_op(segment[2]) + global_block._remove_op(segment[1]) + global_block._remove_op(segment[0]) + global_block._sync_with_cpp() + + # rename input vars in gloabl_block + for op in global_block.ops: + if is_optimize_op(op): + continue + for var_name in op.input_arg_names: + if var_name in rename_var_names_map: + op._rename_input(var_name, rename_var_names_map[var_name]) + if self._is_amp_o1: + for var_name in to_delete_grad_of_param: + global_block._remove_var(var_name) + global_block._sync_with_cpp() + + def _transform_opt( + self, + main_program, + deleted_bias_names, + params_grads, + is_first_rank, + is_amp_o1, + ): + """ + Only support ClipGradByGlobalNorm and AMP-O2 + """ + if is_first_rank: + return + + deleted_bias_grads_names = [] + to_delete_params_grads = [] + for id, (param, grad) in enumerate(params_grads): + if param.name in deleted_bias_names: + deleted_bias_grads_names.append(grad.name) + to_delete_params_grads.append(id) + + to_delete_op_ids = [] + for id in reversed(range(len(main_program.global_block().ops))): + global_block = main_program.global_block() + op = global_block.ops[id] + op_input_names = op.input_arg_names + for op_input in op_input_names: + if op_input in deleted_bias_grads_names: + if op.type in _supported_optimizer_type: + for output_var in op.output_arg_names: + global_block._remove_var(output_var) + 
grad_var = op.input('Grad')[0] + global_block._remove_var(grad_var) + to_delete_op_ids.append(id) + if ( + op.type == "squared_l2_norm" + or op.type == "clip_by_norm" + ): + output_var_name = op.output_arg_names[0] + global_block._remove_var(output_var_name) + to_delete_op_ids.append(id) + for intra_id in range(id + 1, len(global_block.ops)): + intra_op = global_block.ops[intra_id] + if ( + output_var_name in intra_op.input_arg_names + and intra_op.type == "stack" + ): + origin_vars = intra_op.input("X") + origin_vars.remove(output_var_name) + intra_op.desc.set_input("X", origin_vars) + break + if op.type == "elementwise_mul": + to_delete_op_ids.append(id) + # check_finite_and_unscale and update_loss_scaling + if ( + op.type == "check_finite_and_unscale" + or op.type == "update_loss_scaling" + ): + origin_vars = op.input("X") + origin_vars.remove(op_input) + op.desc.set_input("X", origin_vars) + origin_vars = op.output("Out") + origin_vars.remove(op_input) + op.desc.set_output("Out", origin_vars) + + if is_amp_o1: + for output_name in op.output_arg_names: + if ( + output_name in deleted_bias_grads_names + and op.type == 'cast' + ): + to_delete_op_ids.append(id) + + for id in to_delete_op_ids: + global_block._remove_op(id) + main_program.global_block()._sync_with_cpp() + if not is_first_rank: + for id in reversed(to_delete_params_grads): + del params_grads[id] + return + + def _transform_startup_program( + self, startup_program, deleted_bias_names, is_first_rank + ): + """ + Delete the vars and ops assosiated with deleted_bias_names in startup program. + TODO(lizhiyu): If the amp-o2, there are extra variables to delete, such as 'opt_linear_1.b_0_fp32_master_0'. + """ + logger.debug(f"Before transform startup_program: {startup_program}") + cur_glock = startup_program.global_block() + to_delete_op_ids = [] + to_delete_extra_vars = [] + for id, op in enumerate(cur_glock.ops): + if not is_first_rank: + output_var = op.output_arg_names[0] + if output_var in deleted_bias_names: + to_delete_op_ids.append(id) + input_vars = op.input_arg_names + if len(input_vars) > 0 and input_vars[0] in deleted_bias_names: + if id not in to_delete_op_ids: + to_delete_op_ids.append(id) + if len(op.output_arg_names) == 1: + to_delete_extra_vars.append(op.output_arg_names[0]) + else: + if op.type == "c_broadcast": + input_vars = op.input_arg_names + if input_vars[0] in deleted_bias_names: + if id not in to_delete_op_ids: + to_delete_op_ids.append(id) + for to_delete_id in reversed(to_delete_op_ids): + cur_glock._remove_op(to_delete_id) + if not is_first_rank: + for var_name in deleted_bias_names: + cur_glock._remove_var(var_name) + for var_name in to_delete_extra_vars: + if cur_glock.has_var(var_name): + cur_glock._remove_var(var_name) + cur_glock._sync_with_cpp() + logger.debug(f"After transform startup_program: {startup_program}") From 6400677d7088640aab21c87a0175b364be3f8988 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 21 Nov 2023 10:58:06 +0800 Subject: [PATCH 02/11] add promote_fusedlinear pass --- .../auto_parallel/static/parallelizer_v2.py | 24 ++-- .../auto_parallel_fused_linear_promotion.py | 111 +++++++++++------- 2 files changed, 79 insertions(+), 56 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 4bf0f9e682a2f..772a7c86a676f 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ 
b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -338,10 +338,19 @@ def _apply_post_optimization( if self._strategy is None: return - # apply fused linear promotion pass + # sequence parallel optimization + if self._strategy.sp_optimization.enable: + config = copy.deepcopy(self._strategy.sp_optimization.to_dict()) + config["dist_context"] = self._dist_context + config["global_rank"] = rank + sp_pass = new_pass( + "auto_parallel_sequence_parallel_optimization", config + ) + sp_pass.apply([main_program], [startup_program], self._pass_context) + if ( self.is_train - and self._strategy.fused_promotion.fused_promotion + # and self._strategy.fused_promotion.fused_promotion and self._strategy.fused_passes.enable ): amp_config = None @@ -350,8 +359,8 @@ def _apply_post_optimization( config = {} config["dist_context"] = self._dist_context config["global_rank"] = rank + config["enable_sp"] = self._strategy.sp_optimization.enable config["params_grads"] = params_grads - config["enable_sp"] = False config["amp_level"] = ( amp_config['level'] if amp_config is not None else "o0" ) @@ -361,15 +370,6 @@ def _apply_post_optimization( fused_promotion_pass.apply( [main_program], [startup_program], self._pass_context ) - # sequence parallel optimization - if self._strategy.sp_optimization.enable: - config = copy.deepcopy(self._strategy.sp_optimization.to_dict()) - config["dist_context"] = self._dist_context - config["global_rank"] = rank - sp_pass = new_pass( - "auto_parallel_sequence_parallel_optimization", config - ) - sp_pass.apply([main_program], [startup_program], self._pass_context) # data parallel optimization if self._strategy.dp_optimization.enable: diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index a0a777b713815..c029d85b19357 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import logging from paddle.distributed.auto_parallel.static.utils import ( @@ -40,13 +41,20 @@ ] FUSED_LINEAR_SOURCE_PATTERNS_LIST = [ + # amp_level == 'o0' or 'o2' { "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], "backward": ["elementwise_add_grad", "matmul_v2_grad"], }, { - "forward": ["matmul_v2", "c_reduce_scatter", "elementwise_add"], - "backward": ["elementwise_add_grad", "c_allgather", "matmul_v2_grad"], + "forward": ["matmul_v2", "c_reducescatter", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "c_allgather", + "matmul_v2_grad", + ], }, { "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], @@ -58,7 +66,7 @@ ], }, { - "forward": ["matmul_v2", "c_reduce_scatter", "elementwise_add"], + "forward": ["matmul_v2", "c_reducescatter", "elementwise_add"], "backward": [ "elementwise_add_grad", "c_allreduce_sum", @@ -67,13 +75,20 @@ "matmul_v2_grad", ], }, + # amp_level == 'o1' { "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], "backward": ["elementwise_add_grad", "matmul_v2_grad"], }, { - "forward": ["matmul_v2", "c_reduce_scatter", "cast", "elementwise_add"], - "backward": ["elementwise_add_grad", "c_allgather", "matmul_v2_grad"], + "forward": ["matmul_v2", "c_reducescatter", "cast", "elementwise_add"], + "backward": [ + "elementwise_add_grad", + "c_allreduce_sum", + "scale", + "c_allgather", + "matmul_v2_grad", + ], }, { "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], @@ -85,7 +100,7 @@ ], }, { - "forward": ["matmul_v2", "c_reduce_scatter", "cast", "elementwise_add"], + "forward": ["matmul_v2", "c_reducescatter", "cast", "elementwise_add"], "backward": [ "elementwise_add_grad", "c_allreduce_sum", @@ -126,9 +141,9 @@ def _check_conflict(self, other_pass): def _apply_single_impl(self, main_program, startup_program, context): self._dist_context = self.get_attr("dist_context") self._global_rank = int(self.get_attr("global_rank")) - self._enable_sp = self.get_attr("enable_sp") self._params_grads = self.get_attr("params_grads") self._amp_level = self.get_attr("amp_level") + self._enable_sp = self.get_attr("enable_sp") self._is_amp_o1 = self._amp_level == 'o1' self._source_patterns = {} self._enable_dp, self._enable_mp = self._is_enable_dp_mp( @@ -169,6 +184,7 @@ def _apply_single_impl(self, main_program, startup_program, context): ) = self._get_forward_backward_op_segments(main_program) # 3 transform the forward ops + logger.info(f"before main_program: {main_program}") rename_var_names_map, deleted_bias_names = self._transform_forward( main_program, forward_segments, @@ -194,6 +210,7 @@ def _apply_single_impl(self, main_program, startup_program, context): self._is_amp_o1, ) logger.info(f"deleted_bias_names: {deleted_bias_names}") + logger.info(f"after main_program: {main_program}") # 6. transform the startup program self._transform_startup_program( startup_program, deleted_bias_names, self._is_first_rank @@ -234,21 +251,22 @@ def _is_enable_dp_mp(self, dist_context): raise ValueError("inner mesh shape is not supported") return False, False - def _reset_op_dist_attr(self, op, new_var_name, is_input=True): - """ - Reset the dist_attr of the input and output of the operator. 
- """ - op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) - var_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - self._main_block.var(new_var_name) - ) - assert ( - op_dist_attr is not None - ), f"Reset op {op.desc.type()}'s dist_attr, but its dist_attr is None" - if is_input: - op_dist_attr.set_input_dist_attr(new_var_name, var_dist_attr) - if not is_input: - op_dist_attr.set_output_dist_attr(new_var_name, var_dist_attr) + def _is_enable_sp(self, main_program): + for op in main_program.global_block().ops: + forward_has_scatter_reduce_op = False + backward_has_all_gather_op = False + if ( + int(op.desc.attr('op_role')) == 0 + and op.type == 'c_reducescatter' + ): # forward + forward_has_scatter_reduce_op = True + elif ( + int(op.desc.attr('op_role')) == 1 and op.type == 'c_allgather' + ): # backward + backward_has_all_gather_op = True + if forward_has_scatter_reduce_op and backward_has_all_gather_op: + return True + return False def _get_forward_backward_op_segments(self, main_program): """ @@ -277,9 +295,13 @@ def can_match_pattern( if ops[new_id].type != op_name: return False new_id += 1 - for input_name in (ops[new_id - 1].input_arg_names)[1:]: - if input_name not in forward_matmul_inputs: - return False + matmul_grad_input_names = ops[new_id - 1].input_arg_names + # for refined-recompute + if ( + matmul_grad_input_names[1] not in forward_matmul_inputs + and matmul_grad_input_names[2] not in forward_matmul_inputs + ): + return False return True global_block = main_program.global_block() @@ -576,7 +598,7 @@ def _transform_backward( for segment in reversed(backward_segments): add_grad_op = global_block.ops[segment[0]] matmul_grad_op = global_block.ops[segment[-1] - 1] - origin_comm_op_id = segment[-2] + origin_comm_op_id = segment[-1] - 2 orgin_comm_op = global_block.ops[origin_comm_op_id] new_comm_op = global_block._insert_op( segment[0], @@ -597,7 +619,9 @@ def _transform_backward( matmul_grad_op.input_arg_names[0], add_grad_op.output_arg_names[1], ) - global_block._remove_op(segment[-2] + 1) + global_block._remove_op( + segment[-1] + ) # origin idx is segment[-1] - 1 global_block._sync_with_cpp() else: # not is_first_rank_in tp or sp # need to delete the grad op assosiated with the deleted bias var @@ -619,28 +643,27 @@ def _transform_backward( global_block._remove_op(segment[0]) global_block._sync_with_cpp() else: - # place the comm_op(c_allgather) before the elementwise_add_grad for segment in reversed(backward_segments): add_grad_op = global_block.ops[segment[0]] - matmul_grad_op = global_block.ops[segment[-1] - 1] - origin_comm_op_id = segment[-2] - orgin_comm_op = global_block.ops[origin_comm_op_id] - new_comm_op = global_block._insert_op( - segment[0], - type="nop", - ) - new_comm_op.desc.copy_from(orgin_comm_op.desc) - new_comm_op._rename_input( + orgin_comm_op = global_block.ops[segment[-1] - 1] + rename_var_names_map[ + add_grad_op.output_arg_names[0] + ] = add_grad_op.input_arg_names[0] + global_block._remove_var(add_grad_op.output_arg_names[0]) + orgin_comm_op._rename_input( orgin_comm_op.input_arg_names[0], - add_grad_op.input_arg_names[0], + add_grad_op.output_arg_names[0], ) - matmul_grad_op._rename_input( - matmul_grad_op.input_arg_names[0], - new_comm_op.output_arg_names[0], + to_delete_grad_of_param.append( + add_grad_op.output_arg_names[1] ) - global_block._remove_op(segment[-2] + 1) - global_block._remove_var(add_grad_op.output_arg_names[0]) - global_block._remove_var(add_grad_op.output_arg_names[1]) + + # remove 
'elementwise_add_grad' 'c_allreduce_sum' 'scale' + global_block._remove_op( + segment[0] + 3 + ) # elementwise_add_grad + global_block._remove_op(segment[0] + 2) # c_allreduce_sum + global_block._remove_op(segment[0] + 1) # scale # remove vars and op if self._enable_dp: # DP c_all_reduce_op = global_block.ops[segment[1]] From 576dde0d3a4f2757ef7734732a354fbe2165728a Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 21 Nov 2023 20:49:35 +0800 Subject: [PATCH 03/11] support sp without dp --- .../distributed/auto_parallel/constants.py | 2 +- .../auto_parallel/static/parallelizer_v2.py | 3 +- .../distributed/auto_parallel/strategy.py | 4 +- .../auto_parallel_fused_linear_promotion.py | 130 ++++++++++-------- 4 files changed, 78 insertions(+), 61 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index c51201a1b9a29..616c3c7f5b3bf 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -148,7 +148,7 @@ def set_field_default_config(category, field, default_value): # # offload configuration # ######################################### FUSEDPROMOTION = "fused_promotion" -set_field_default_config(FUSEDPROMOTION, "enable", True) +set_field_default_config(FUSEDPROMOTION, "enable", False) ######################################### # fused passes configuration diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 772a7c86a676f..b2e47309e052d 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -348,9 +348,10 @@ def _apply_post_optimization( ) sp_pass.apply([main_program], [startup_program], self._pass_context) + # apply fused linear promotion pass if ( self.is_train - # and self._strategy.fused_promotion.fused_promotion + and self._strategy.fused_promotion.enable and self._strategy.fused_passes.enable ): amp_config = None diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 928fc2dc9e7c9..ce2e34985c824 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -82,7 +82,7 @@ def __init__(self, config_dict=None): super().__init__(category, config_dict) -class FusedPromotion(BaseConfig): +class FusedPromotionConfig(BaseConfig): def __init__(self, config_dict=None): category = constants.FUSEDPROMOTION super().__init__(category, config_dict) @@ -231,7 +231,7 @@ def __init__(self, config=None): self.fused_passes = FusedPassesConfig(config_dict) config_dict = self._config_dict.get(constants.FUSEDPROMOTION, None) - self.fused_passes = FusedPromotion(config_dict) + self.fused_promotion = FusedPromotionConfig(config_dict) config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index c029d85b19357..eb22020c0a0ff 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -41,7 +41,7 @@ ] FUSED_LINEAR_SOURCE_PATTERNS_LIST = [ - # amp_level == 'o0' or 'o2' + # amp_level == 'o2' or 
'o3' { "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], "backward": ["elementwise_add_grad", "matmul_v2_grad"], @@ -150,6 +150,12 @@ def _apply_single_impl(self, main_program, startup_program, context): self._dist_context ) + # TODO(lizhiyu): We the TP with DP is ready, add this support. + if self._enable_dp and self._enable_sp: + logger.warning( + "Don't support the case of enable_dp and enable_sp because the TP parallelism not ready, skip this pass" + ) + pattern_offset = 4 if self._is_amp_o1 else 0 if self._enable_sp: if self._enable_dp: @@ -182,9 +188,13 @@ def _apply_single_impl(self, main_program, startup_program, context): forward_segments, backward_segments, ) = self._get_forward_backward_op_segments(main_program) - + if len(forward_segments) == 0 or len(backward_segments) == 0: + logger.warning( + "No forward and backward op segments, skip this pass" + ) + return # 3 transform the forward ops - logger.info(f"before main_program: {main_program}") + logger.debug(f"before main_program: {main_program}") rename_var_names_map, deleted_bias_names = self._transform_forward( main_program, forward_segments, @@ -193,6 +203,7 @@ def _apply_single_impl(self, main_program, startup_program, context): self._enable_sp, self._is_amp_o1, ) + # 4 transform the backward ops self._transform_backward( main_program, @@ -201,6 +212,7 @@ def _apply_single_impl(self, main_program, startup_program, context): self._is_first_rank, self._enable_sp, ) + # 5. transform the optimizer ops self._transform_opt( main_program, @@ -210,7 +222,8 @@ def _apply_single_impl(self, main_program, startup_program, context): self._is_amp_o1, ) logger.info(f"deleted_bias_names: {deleted_bias_names}") - logger.info(f"after main_program: {main_program}") + logger.debug(f"after main_program: {main_program}") + # 6. 
transform the startup program self._transform_startup_program( startup_program, deleted_bias_names, self._is_first_rank @@ -218,37 +231,32 @@ def _apply_single_impl(self, main_program, startup_program, context): def _is_tp_sp_first_rank(self, dist_context, rank): for process_mesh in dist_context.process_meshes: - if len(process_mesh._shape) == 1: + inner_mesh = process_mesh.mesh + inner_mesh_shape = inner_mesh.shape + if len(inner_mesh_shape) == 1: return rank == min(process_mesh.process_ids) - else: - inner_mesh = process_mesh.mesh - inner_mesh_shape = inner_mesh.shape - if len(inner_mesh.shape) == 2: - for id0 in range(inner_mesh_shape[0]): - if rank == min(inner_mesh[id0, :]): + elif len(inner_mesh.shape) == 2: + for id0 in range(inner_mesh_shape[0]): + if rank == min(inner_mesh[id0, :]): + return True + elif len(inner_mesh.shape) == 3: + for id0 in range(inner_mesh_shape[0]): + for id1 in range(inner_mesh_shape[1]): + if rank == min(inner_mesh[id0, id1, :]): return True - elif len(inner_mesh.shape) == 3: - for id0 in range(inner_mesh_shape[0]): - for id1 in range(inner_mesh_shape[1]): - if rank == min(inner_mesh[id0, id1, :]): - return True - else: - raise ValueError("inner mesh shape is not supported") + else: + raise ValueError("inner mesh shape is not supported") return False def _is_enable_dp_mp(self, dist_context): for process_mesh in dist_context.process_meshes: - if len(process_mesh._shape) == 1: - return False, False + inner_mesh = process_mesh.mesh + inner_mesh_shape = inner_mesh.shape + if len(inner_mesh_shape) == 1: + return False, inner_mesh_shape[0] > 1 else: - inner_mesh = process_mesh.mesh - inner_mesh_shape = inner_mesh.shape - if len(inner_mesh.shape) == 2: # Dp * Mp - return inner_mesh_shape[0] > 1, inner_mesh_shape[1] > 1 - elif len(inner_mesh.shape) == 3: - return inner_mesh_shape[-2] > 1, inner_mesh_shape[-1] > 1 - else: - raise ValueError("inner mesh shape is not supported") + # DP * MP + return inner_mesh_shape[-2] > 1, inner_mesh_shape[-1] > 1 return False, False def _is_enable_sp(self, main_program): @@ -341,6 +349,9 @@ def can_match_pattern( ) else: pass + assert len(forward_segments) >= len( + backward_segmnets + ), "The number of forward segments should be not shorter than the number of backward segments." logger.info(f"forward_segments: {forward_segments}") logger.info(f"backward_segmnets: {backward_segmnets}") return forward_segments, backward_segmnets @@ -374,7 +385,9 @@ def _transform_forward_segment( origin_matmul_op = global_block.ops[forward_segment[0]] origin_comm_op = global_block.ops[forward_segment[0] + 1] origin_add_op = global_block.ops[forward_segment[1] - 1] - origin_cast_op = global_block.ops[forward_segment[1] - 2] + origin_cast_op = ( + global_block.ops[forward_segment[1] - 2] if is_amp_o1 else None + ) origin_matmul_output_name = origin_matmul_op.output_arg_names[0] origin_comm_input_name = origin_comm_op.input_arg_names[0] assert ( @@ -446,11 +459,9 @@ def _transform_forward_segment( origin_matmul_output_name ) # We can remove the origin_matmul_output now. 
global_block._remove_var(origin_add_output_name) - # rename_vars_map[origin_add_output_name] = new_add_op_output_name new_add_op._rename_output( origin_add_output_name, new_add_op_output_name ) - # rename input of new_add_op rename_vars_map[ origin_add_op.input_arg_names[0] ] = origin_matmul_output_new_name @@ -486,8 +497,12 @@ def _transform_forward_segment( new_add_op.input_arg_names[1] ) else: - origin_add_output_name = origin_add_op.output_arg_names[0] + origin_add_output_name = origin_add_op.output_arg_names[ + 0 + ] # We can remove the origin_matmul_output now. global_block._remove_var(origin_add_output_name) + global_block._remove_var(origin_matmul_output_name) + # 4. deal comm op # The input of c_allreduce_sum only be used once, so we don't need add it in the rename_vars_map if is_first_rank: @@ -606,7 +621,6 @@ def _transform_backward( ) new_comm_op.desc.copy_from(orgin_comm_op.desc) # rename input and output - # TODO(lizhiyu): The input and ouput of all_gather may not suitable. new_comm_op._rename_input( orgin_comm_op.input_arg_names[0], add_grad_op.input_arg_names[0], @@ -617,11 +631,14 @@ def _transform_backward( ) matmul_grad_op._rename_input( matmul_grad_op.input_arg_names[0], - add_grad_op.output_arg_names[1], + add_grad_op.output_arg_names[0], ) + global_block._remove_op( - segment[-1] - ) # origin idx is segment[-1] - 1 + segment[-1] - 1 + ) # remove origin comm_op + global_block._remove_op(segment[0] + 3) # scale + global_block._remove_op(segment[0] + 2) # c_allreduce_sum global_block._sync_with_cpp() else: # not is_first_rank_in tp or sp # need to delete the grad op assosiated with the deleted bias var @@ -645,25 +662,22 @@ def _transform_backward( else: for segment in reversed(backward_segments): add_grad_op = global_block.ops[segment[0]] - orgin_comm_op = global_block.ops[segment[-1] - 1] + orgin_comm_op = global_block.ops[segment[-1] - 2] rename_var_names_map[ add_grad_op.output_arg_names[0] ] = add_grad_op.input_arg_names[0] - global_block._remove_var(add_grad_op.output_arg_names[0]) orgin_comm_op._rename_input( orgin_comm_op.input_arg_names[0], - add_grad_op.output_arg_names[0], + add_grad_op.input_arg_names[0], ) + global_block._remove_var(add_grad_op.output_arg_names[0]) + to_delete_grad_of_param.append( add_grad_op.output_arg_names[1] ) - # remove 'elementwise_add_grad' 'c_allreduce_sum' 'scale' - global_block._remove_op( - segment[0] + 3 - ) # elementwise_add_grad - global_block._remove_op(segment[0] + 2) # c_allreduce_sum - global_block._remove_op(segment[0] + 1) # scale + global_block._remove_op(segment[0] + 2) # scale + global_block._remove_op(segment[0] + 1) # c_allreduce_sum # remove vars and op if self._enable_dp: # DP c_all_reduce_op = global_block.ops[segment[1]] @@ -702,7 +716,7 @@ def _transform_opt( """ if is_first_rank: return - + print(f"+++++ len params_grads: {len(params_grads)}") deleted_bias_grads_names = [] to_delete_params_grads = [] for id, (param, grad) in enumerate(params_grads): @@ -765,9 +779,10 @@ def _transform_opt( for id in to_delete_op_ids: global_block._remove_op(id) main_program.global_block()._sync_with_cpp() - if not is_first_rank: - for id in reversed(to_delete_params_grads): - del params_grads[id] + + for id in reversed(to_delete_params_grads): + del params_grads[id] + print(f"+++++ len params_grads: {len(params_grads)}") return def _transform_startup_program( @@ -775,23 +790,24 @@ def _transform_startup_program( ): """ Delete the vars and ops assosiated with deleted_bias_names in startup program. 
- TODO(lizhiyu): If the amp-o2, there are extra variables to delete, such as 'opt_linear_1.b_0_fp32_master_0'. """ logger.debug(f"Before transform startup_program: {startup_program}") cur_glock = startup_program.global_block() to_delete_op_ids = [] - to_delete_extra_vars = [] + to_delete_extra_vars = ( + [] + ) # for variables assosiated with deleted_bias_names in amp-o2, such as 'opt_linear_1.b_0_fp32_master_0' for id, op in enumerate(cur_glock.ops): if not is_first_rank: output_var = op.output_arg_names[0] if output_var in deleted_bias_names: to_delete_op_ids.append(id) - input_vars = op.input_arg_names - if len(input_vars) > 0 and input_vars[0] in deleted_bias_names: - if id not in to_delete_op_ids: - to_delete_op_ids.append(id) - if len(op.output_arg_names) == 1: - to_delete_extra_vars.append(op.output_arg_names[0]) + else: + for var_name in deleted_bias_names: + if var_name in output_var: + to_delete_op_ids.append(id) + if output_var not in to_delete_extra_vars: + to_delete_extra_vars.append(output_var) else: if op.type == "c_broadcast": input_vars = op.input_arg_names From ef6acc9f714cf8a6f04c9a38691c308526701d64 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 21 Nov 2023 21:03:00 +0800 Subject: [PATCH 04/11] delete some log --- .../distributed/passes/auto_parallel_fused_linear_promotion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index eb22020c0a0ff..ac837f76b66d5 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -716,7 +716,6 @@ def _transform_opt( """ if is_first_rank: return - print(f"+++++ len params_grads: {len(params_grads)}") deleted_bias_grads_names = [] to_delete_params_grads = [] for id, (param, grad) in enumerate(params_grads): @@ -782,7 +781,6 @@ def _transform_opt( for id in reversed(to_delete_params_grads): del params_grads[id] - print(f"+++++ len params_grads: {len(params_grads)}") return def _transform_startup_program( From b9703a4d4f8f00de39f35f2b47f87fc3c19199b2 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Mon, 27 Nov 2023 21:09:39 +0800 Subject: [PATCH 05/11] fix bug in process_mesh --- .../auto_parallel_fused_linear_promotion.py | 53 +++++++++++-------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index ac837f76b66d5..43c3d9d842628 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -15,6 +15,8 @@ import logging +import numpy as np + from paddle.distributed.auto_parallel.static.utils import ( is_optimize_op, is_recompute_op, @@ -42,11 +44,11 @@ FUSED_LINEAR_SOURCE_PATTERNS_LIST = [ # amp_level == 'o2' or 'o3' - { + { # only MP "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], "backward": ["elementwise_add_grad", "matmul_v2_grad"], }, - { + { # MP + SP "forward": ["matmul_v2", "c_reducescatter", "elementwise_add"], "backward": [ "elementwise_add_grad", @@ -56,7 +58,7 @@ "matmul_v2_grad", ], }, - { + { # DP + MP "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], "backward": [ "elementwise_add_grad", @@ -65,7 +67,7 @@ "matmul_v2_grad", ], }, - { + { # 
DP + MP + SP "forward": ["matmul_v2", "c_reducescatter", "elementwise_add"], "backward": [ "elementwise_add_grad", @@ -155,6 +157,7 @@ def _apply_single_impl(self, main_program, startup_program, context): logger.warning( "Don't support the case of enable_dp and enable_sp because the TP parallelism not ready, skip this pass" ) + return pattern_offset = 4 if self._is_amp_o1 else 0 if self._enable_sp: @@ -178,6 +181,7 @@ def _apply_single_impl(self, main_program, startup_program, context): else: logger.warning("Neither of sp and mp is enabled, skip this pass") return + # 1. get whether the current rank is first rank in mp self._is_first_rank = self._is_tp_sp_first_rank( self._dist_context, self._global_rank @@ -231,8 +235,10 @@ def _apply_single_impl(self, main_program, startup_program, context): def _is_tp_sp_first_rank(self, dist_context, rank): for process_mesh in dist_context.process_meshes: - inner_mesh = process_mesh.mesh - inner_mesh_shape = inner_mesh.shape + inner_mesh_shape = process_mesh.shape + inner_mesh = (np.array(process_mesh.process_ids)).reshape( + inner_mesh_shape + ) if len(inner_mesh_shape) == 1: return rank == min(process_mesh.process_ids) elif len(inner_mesh.shape) == 2: @@ -250,8 +256,10 @@ def _is_tp_sp_first_rank(self, dist_context, rank): def _is_enable_dp_mp(self, dist_context): for process_mesh in dist_context.process_meshes: - inner_mesh = process_mesh.mesh - inner_mesh_shape = inner_mesh.shape + inner_mesh_shape = process_mesh.shape + inner_mesh = (np.array(process_mesh.process_ids)).reshape( + inner_mesh_shape + ) if len(inner_mesh_shape) == 1: return False, inner_mesh_shape[0] > 1 else: @@ -652,11 +660,11 @@ def _transform_backward( to_delete_grad_of_param.append( add_grad_op.output_arg_names[1] ) - if self._enable_dp: - c_all_reduce_op = global_block.ops[segment[0] + 1] - scale_op = global_block.ops[segment[0] + 2] - global_block._remove_op(segment[0] + 2) - global_block._remove_op(segment[0] + 1) + # if self._enable_dp: + # c_all_reduce_op = global_block.ops[segment[0] + 1] + # scale_op = global_block.ops[segment[0] + 2] + # global_block._remove_op(segment[0] + 2) + # global_block._remove_op(segment[0] + 1) global_block._remove_op(segment[0]) global_block._sync_with_cpp() else: @@ -679,15 +687,16 @@ def _transform_backward( global_block._remove_op(segment[0] + 2) # scale global_block._remove_op(segment[0] + 1) # c_allreduce_sum # remove vars and op - if self._enable_dp: # DP - c_all_reduce_op = global_block.ops[segment[1]] - scale_op = global_block.ops[segment[2]] - global_block._remove_var( - c_all_reduce_op.input_arg_names[0] - ) - global_block._remove_var(scale_op.outpu_arg_names[0]) - global_block._remove_op(segment[2]) - global_block._remove_op(segment[1]) + # if self._enable_dp: # DP + # c_all_reduce_op = global_block.ops[segment[1]] + # scale_op = global_block.ops[segment[2]] + # global_block._remove_var( + # c_all_reduce_op.input_arg_names[0] + # ) + # global_block._remove_var(scale_op.outpu_arg_names[0]) + # global_block._remove_op(segment[2]) + # global_block._remove_op(segment[1]) + global_block._remove_op(segment[0]) global_block._sync_with_cpp() From d7f98268ed416bdde063550b419223cb2c3b535f Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 28 Nov 2023 19:17:14 +0800 Subject: [PATCH 06/11] add sp+dp support --- .../auto_parallel_fused_linear_promotion.py | 151 ++++++++++-------- 1 file changed, 84 insertions(+), 67 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py 
b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index 43c3d9d842628..5369673825cd6 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -21,10 +21,17 @@ is_optimize_op, is_recompute_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_var_dist_attr, ) from paddle.utils import unique_name from ..utils.log_utils import get_logger +from .auto_parallel_sharding import ( + _inference_data_parallel_group_for_operator, + _is_reshard_op, + _skip_ops, + is_forward_op, +) from .pass_base import PassBase, register_pass logger = get_logger(logging.INFO, "FusedLinearPromotionPass") @@ -73,6 +80,8 @@ "elementwise_add_grad", "c_allreduce_sum", "scale", + "c_allreduce_sum", + "scale", "c_allgather", "matmul_v2_grad", ], @@ -107,6 +116,8 @@ "elementwise_add_grad", "c_allreduce_sum", "scale", + "c_allreduce_sum", + "scale", "c_allgather", "matmul_v2_grad", ], @@ -152,13 +163,6 @@ def _apply_single_impl(self, main_program, startup_program, context): self._dist_context ) - # TODO(lizhiyu): We the TP with DP is ready, add this support. - if self._enable_dp and self._enable_sp: - logger.warning( - "Don't support the case of enable_dp and enable_sp because the TP parallelism not ready, skip this pass" - ) - return - pattern_offset = 4 if self._is_amp_o1 else 0 if self._enable_sp: if self._enable_dp: @@ -181,12 +185,17 @@ def _apply_single_impl(self, main_program, startup_program, context): else: logger.warning("Neither of sp and mp is enabled, skip this pass") return + dp_group = None + if self._enable_dp: + dp_group = self._collective_data_parallel_groups( + main_program.global_block() + ) # 1. get whether the current rank is first rank in mp self._is_first_rank = self._is_tp_sp_first_rank( self._dist_context, self._global_rank ) - + logger.debug(f"before main_program: {main_program}") # 2. get the forward and backward op list idexs in source patterns ( forward_segments, @@ -198,7 +207,6 @@ def _apply_single_impl(self, main_program, startup_program, context): ) return # 3 transform the forward ops - logger.debug(f"before main_program: {main_program}") rename_var_names_map, deleted_bias_names = self._transform_forward( main_program, forward_segments, @@ -230,7 +238,7 @@ def _apply_single_impl(self, main_program, startup_program, context): # 6. transform the startup program self._transform_startup_program( - startup_program, deleted_bias_names, self._is_first_rank + startup_program, deleted_bias_names, dp_group, self._is_first_rank ) def _is_tp_sp_first_rank(self, dist_context, rank): @@ -267,23 +275,6 @@ def _is_enable_dp_mp(self, dist_context): return inner_mesh_shape[-2] > 1, inner_mesh_shape[-1] > 1 return False, False - def _is_enable_sp(self, main_program): - for op in main_program.global_block().ops: - forward_has_scatter_reduce_op = False - backward_has_all_gather_op = False - if ( - int(op.desc.attr('op_role')) == 0 - and op.type == 'c_reducescatter' - ): # forward - forward_has_scatter_reduce_op = True - elif ( - int(op.desc.attr('op_role')) == 1 and op.type == 'c_allgather' - ): # backward - backward_has_all_gather_op = True - if forward_has_scatter_reduce_op and backward_has_all_gather_op: - return True - return False - def _get_forward_backward_op_segments(self, main_program): """ Get the operator segments according to the source patterns. 
@@ -364,6 +355,21 @@ def can_match_pattern( logger.info(f"backward_segmnets: {backward_segmnets}") return forward_segments, backward_segmnets + def _collective_data_parallel_groups(self, main_block): + for op in main_block.ops: + if not is_forward_op(op) or op.type in _skip_ops: + continue + # NOTE: there aren't dist_attr in the ops which reshard insert, + # and should be skip in sharding. + if _is_reshard_op(op): + continue + group = _inference_data_parallel_group_for_operator( + self._global_rank, op, self._dist_context + ) + if group is not None: + return group + return None + def _transform_forward( self, main_program, @@ -423,19 +429,28 @@ def _transform_forward_segment( origin_matmul_output_new_name = unique_name.generate( origin_matmul_output_name + "@promote" ) - global_block.create_var( + origin_matmul_output_new_var = global_block.create_var( name=origin_matmul_output_new_name, dtype=global_block.var(origin_matmul_output_name).dtype, shape=global_block.var(origin_matmul_output_name).shape, persistable=False, stop_gradient=False, ) + set_var_dist_attr( + self._dist_context, + origin_matmul_output_new_var, + ref_mapping, + ref_mesh, + ) rename_vars_map[ origin_matmul_output_name ] = origin_matmul_output_new_name origin_matmul_op._rename_output( origin_matmul_output_name, origin_matmul_output_new_name ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + origin_matmul_op, ref_mesh, ref_mapping, self._dist_context + ) # 3. deal add op and cast op if is_first_rank: @@ -505,9 +520,8 @@ def _transform_forward_segment( new_add_op.input_arg_names[1] ) else: - origin_add_output_name = origin_add_op.output_arg_names[ - 0 - ] # We can remove the origin_matmul_output now. + # We can remove the origin_matmul_output now. + origin_add_output_name = origin_add_op.output_arg_names[0] global_block._remove_var(origin_add_output_name) global_block._remove_var(origin_matmul_output_name) @@ -642,11 +656,17 @@ def _transform_backward( add_grad_op.output_arg_names[0], ) - global_block._remove_op( - segment[-1] - 1 - ) # remove origin comm_op - global_block._remove_op(segment[0] + 3) # scale - global_block._remove_op(segment[0] + 2) # c_allreduce_sum + global_block._remove_op(segment[-1] - 1) + if self._enable_dp: + global_block._remove_op(segment[0] + 5) # scale + global_block._remove_op( + segment[0] + 4 + ) # c_allreduce_sum + else: + global_block._remove_op(segment[0] + 3) # scale + global_block._remove_op( + segment[0] + 2 + ) # c_allreduce_sum global_block._sync_with_cpp() else: # not is_first_rank_in tp or sp # need to delete the grad op assosiated with the deleted bias var @@ -660,11 +680,11 @@ def _transform_backward( to_delete_grad_of_param.append( add_grad_op.output_arg_names[1] ) - # if self._enable_dp: - # c_all_reduce_op = global_block.ops[segment[0] + 1] - # scale_op = global_block.ops[segment[0] + 2] - # global_block._remove_op(segment[0] + 2) - # global_block._remove_op(segment[0] + 1) + if self._enable_dp: + global_block._remove_op(segment[0] + 2) # scale op + global_block._remove_op( + segment[0] + 1 + ) # c_allreduce_sum op global_block._remove_op(segment[0]) global_block._sync_with_cpp() else: @@ -679,25 +699,23 @@ def _transform_backward( add_grad_op.input_arg_names[0], ) global_block._remove_var(add_grad_op.output_arg_names[0]) - to_delete_grad_of_param.append( add_grad_op.output_arg_names[1] ) - # remove 'elementwise_add_grad' 'c_allreduce_sum' 'scale' - global_block._remove_op(segment[0] + 2) # scale - global_block._remove_op(segment[0] + 1) # c_allreduce_sum - # remove 
vars and op - # if self._enable_dp: # DP - # c_all_reduce_op = global_block.ops[segment[1]] - # scale_op = global_block.ops[segment[2]] - # global_block._remove_var( - # c_all_reduce_op.input_arg_names[0] - # ) - # global_block._remove_var(scale_op.outpu_arg_names[0]) - # global_block._remove_op(segment[2]) - # global_block._remove_op(segment[1]) - - global_block._remove_op(segment[0]) + if self._enable_dp: # DP + global_block._remove_op( + segment[0] + 4 + ) # scale op for dp + global_block._remove_op( + segment[0] + 3 + ) # c_allreduce_sum op for dp + global_block._remove_op(segment[0] + 2) # scale op for sp + global_block._remove_op( + segment[0] + 1 + ) # c_allreduce_sum op for sp + global_block._remove_op( + segment[0] + ) # elementwise_add_grad op global_block._sync_with_cpp() # rename input vars in gloabl_block @@ -720,9 +738,6 @@ def _transform_opt( is_first_rank, is_amp_o1, ): - """ - Only support ClipGradByGlobalNorm and AMP-O2 - """ if is_first_rank: return deleted_bias_grads_names = [] @@ -793,7 +808,7 @@ def _transform_opt( return def _transform_startup_program( - self, startup_program, deleted_bias_names, is_first_rank + self, startup_program, deleted_bias_names, dp_group, is_first_rank ): """ Delete the vars and ops assosiated with deleted_bias_names in startup program. @@ -801,9 +816,8 @@ def _transform_startup_program( logger.debug(f"Before transform startup_program: {startup_program}") cur_glock = startup_program.global_block() to_delete_op_ids = [] - to_delete_extra_vars = ( - [] - ) # for variables assosiated with deleted_bias_names in amp-o2, such as 'opt_linear_1.b_0_fp32_master_0' + # for variables assosiated with deleted_bias_names in amp-o2, such as 'opt_linear_1.b_0_fp32_master_0' + to_delete_extra_vars = [] for id, op in enumerate(cur_glock.ops): if not is_first_rank: output_var = op.output_arg_names[0] @@ -818,9 +832,12 @@ def _transform_startup_program( else: if op.type == "c_broadcast": input_vars = op.input_arg_names - if input_vars[0] in deleted_bias_names: - if id not in to_delete_op_ids: - to_delete_op_ids.append(id) + if ( + input_vars[0] in deleted_bias_names + and op.attr("ring_id") != dp_group.id + and id not in to_delete_op_ids + ): + to_delete_op_ids.append(id) for to_delete_id in reversed(to_delete_op_ids): cur_glock._remove_op(to_delete_id) if not is_first_rank: From 531171712590d8065b56958aea139dffae2f4861 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 28 Nov 2023 19:54:37 +0800 Subject: [PATCH 07/11] fix bug when dp_group is None --- .../passes/auto_parallel_fused_linear_promotion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py index 5369673825cd6..cce8eaef8a354 100644 --- a/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py +++ b/python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py @@ -834,10 +834,13 @@ def _transform_startup_program( input_vars = op.input_arg_names if ( input_vars[0] in deleted_bias_names - and op.attr("ring_id") != dp_group.id and id not in to_delete_op_ids ): - to_delete_op_ids.append(id) + if dp_group is None or ( + dp_group is not None + and op.attr("ring_id") != dp_group.id + ): + to_delete_op_ids.append(id) for to_delete_id in reversed(to_delete_op_ids): cur_glock._remove_op(to_delete_id) if not is_first_rank: From 0c795615adef50284f63deed481f82ae502de1b5 Mon Sep 17 00:00:00 2001 From: 
lzydev <1528794076@qq.com> Date: Thu, 30 Nov 2023 16:51:03 +0800 Subject: [PATCH 08/11] modify code according to review --- .../distributed/auto_parallel/constants.py | 4 +- .../auto_parallel/static/parallelizer_v2.py | 41 +++++++++++-------- .../distributed/auto_parallel/strategy.py | 2 +- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 43fee53c80674..3ff3d3d6df4cc 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -148,8 +148,8 @@ def set_field_default_config(category, field, default_value): # ######################################### # # offload configuration # ######################################### -FUSEDPROMOTION = "fused_promotion" -set_field_default_config(FUSEDPROMOTION, "enable", False) +FUSEDLINEARPROMOTION = "fused_linear_promotion" +set_field_default_config(FUSEDLINEARPROMOTION, "enable", False) ######################################### # fused passes configuration diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index b2e47309e052d..c23054ab1421d 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -351,26 +351,31 @@ def _apply_post_optimization( # apply fused linear promotion pass if ( self.is_train - and self._strategy.fused_promotion.enable + and self._strategy.fused_linear_promotion.enable and self._strategy.fused_passes.enable ): - amp_config = None - if self._strategy.amp.enable: - amp_config = copy.deepcopy(self._strategy.amp.to_dict()) - config = {} - config["dist_context"] = self._dist_context - config["global_rank"] = rank - config["enable_sp"] = self._strategy.sp_optimization.enable - config["params_grads"] = params_grads - config["amp_level"] = ( - amp_config['level'] if amp_config is not None else "o0" - ) - fused_promotion_pass = new_pass( - "auto_parallel_fused_linear_promotion", config - ) - fused_promotion_pass.apply( - [main_program], [startup_program], self._pass_context - ) + if ( + len(self._strategy.fused_passes.fused_passes_list) > 0 + and "fuse_gemm_epilogue" + in self._strategy.fused_passes.fused_passes_list + ): + amp_config = None + if self._strategy.amp.enable: + amp_config = copy.deepcopy(self._strategy.amp.to_dict()) + config = {} + config["dist_context"] = self._dist_context + config["global_rank"] = rank + config["enable_sp"] = self._strategy.sp_optimization.enable + config["params_grads"] = params_grads + config["amp_level"] = ( + amp_config['level'] if amp_config is not None else "o0" + ) + fused_linear_promotion_pass = new_pass( + "auto_parallel_fused_linear_promotion", config + ) + fused_linear_promotion_pass.apply( + [main_program], [startup_program], self._pass_context + ) # data parallel optimization if self._strategy.dp_optimization.enable: diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index ce2e34985c824..2b0ce8698b173 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -231,7 +231,7 @@ def __init__(self, config=None): self.fused_passes = FusedPassesConfig(config_dict) config_dict = self._config_dict.get(constants.FUSEDPROMOTION, None) - self.fused_promotion = 
FusedPromotionConfig(config_dict) + self.fused_linear_promotion = FusedPromotionConfig(config_dict) config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) From 2c98caf31c882f637726c5447a56cefa5a296a01 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Fri, 1 Dec 2023 16:06:02 +0800 Subject: [PATCH 09/11] add unit_test --- .../distributed/auto_parallel/strategy.py | 6 +- test/distributed_passes/CMakeLists.txt | 1 + ...to_parallel_fused_linear_promotion_pass.py | 204 ++++++++++++++++++ 3 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 test/distributed_passes/test_auto_parallel_fused_linear_promotion_pass.py diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 2b0ce8698b173..7c8dcb3aaa7a5 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -84,7 +84,7 @@ def __init__(self, config_dict=None): class FusedPromotionConfig(BaseConfig): def __init__(self, config_dict=None): - category = constants.FUSEDPROMOTION + category = constants.FUSEDLINEARPROMOTION super().__init__(category, config_dict) @@ -230,7 +230,9 @@ def __init__(self, config=None): config_dict = self._config_dict.get(constants.FUSED_PASSES, None) self.fused_passes = FusedPassesConfig(config_dict) - config_dict = self._config_dict.get(constants.FUSEDPROMOTION, None) + config_dict = self._config_dict.get( + constants.FUSEDLINEARPROMOTION, None + ) self.fused_linear_promotion = FusedPromotionConfig(config_dict) config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) diff --git a/test/distributed_passes/CMakeLists.txt b/test/distributed_passes/CMakeLists.txt index 12018ff20deee..99cd70bd616e9 100644 --- a/test/distributed_passes/CMakeLists.txt +++ b/test/distributed_passes/CMakeLists.txt @@ -20,6 +20,7 @@ if((NOT WITH_GPU) AND (NOT WITH_XPU)) list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_data_parallel_optimization_pass") + list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fused_linear_promotion_pass") endif() if(NOT ((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6))) diff --git a/test/distributed_passes/test_auto_parallel_fused_linear_promotion_pass.py b/test/distributed_passes/test_auto_parallel_fused_linear_promotion_pass.py new file mode 100644 index 0000000000000..77b75489f9c51 --- /dev/null +++ b/test/distributed_passes/test_auto_parallel_fused_linear_promotion_pass.py @@ -0,0 +1,204 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest + +import paddle + +sys.path.append("../legacy_test") + +import paddle.nn.functional as F +from paddle import nn, static, utils +from paddle.base import ParamAttr +from paddle.distributed.auto_parallel.static.dist_context import ( + DistributedContext, +) +from paddle.distributed.auto_parallel.static.parallelizer_v2 import Parallelizer +from paddle.distributed.auto_parallel.static.planner_v2 import Planner +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.fleet import auto + +paddle.enable_static() +BATCH_SIZE = 4 +SEQ_LEN = 512 +HIDDEN_SIZE = 1024 +MESH_0 = auto.ProcessMesh([0, 1, 2, 3], dim_names=["x"]) + + +class MLPLayer(nn.Layer): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02, + enable_sp=False, + ): + super().__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = ParamAttr( + initializer=paddle.nn.initializer.Normal( + mean=0.0, std=initializer_range + ) + ) + self.enable_sp = enable_sp + bias_attr = True + + self.norm0 = paddle.nn.LayerNorm(d_model, epsilon=1e-5) + self.norm0.bias.stop_gradient = True + self.norm1 = paddle.nn.LayerNorm(d_model, epsilon=1e-5) + self.norm1.bias.stop_gradient = True + self.linear0 = paddle.nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr + ) + auto.shard_tensor(self.linear0.weight, MESH_0, [None, "x"]) + self.linear1 = paddle.nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr + ) + auto.shard_tensor(self.linear1.weight, MESH_0, ["x", None]) + self.dropout = paddle.nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + if self.enable_sp: + # sp region + auto.shard_tensor(input, MESH_0, ["x", None, None]) + out = self.norm0(input) + auto.shard_tensor(input, MESH_0, ["x", None, None]) + out = F.gelu(out, approximate=True) + else: + out = self.norm0(input) + out = F.gelu(out, approximate=True) + + # tp region + auto.shard_tensor(out, MESH_0, [None, None, None]) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + auto.shard_tensor(out, MESH_0, [None, None, None]) + + if self.enable_sp: + # sp region + out = self.dropout(out) + auto.shard_tensor(out, MESH_0, ["x", None, None]) + out = F.gelu(out, approximate=True) + out = self.norm1(out) + else: + out = self.dropout(out) + out = F.gelu(out, approximate=True) + out = self.norm1(out) + + return out + + +class HybridParallelNet(nn.Layer): + def __init__( + self, + hidden_size=1024, + enable_sp=False, + ): + super().__init__() + self.mlp0 = MLPLayer(hidden_size, hidden_size * 4, enable_sp=enable_sp) + self.mlp1 = MLPLayer(hidden_size, hidden_size * 4, enable_sp=enable_sp) + + def forward(self, input): + out = self.mlp0(input) + out = self.mlp1(out) + + return out + + +def get_hybrid_parallel_model(train_program, start_program, enable_sp=False): + with static.program_guard( + train_program, start_program + ), utils.unique_name.guard(): + batch_size = BATCH_SIZE + hidden_size = HIDDEN_SIZE + sequence_len = SEQ_LEN + + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32', + ) + network = HybridParallelNet( + hidden_size=HIDDEN_SIZE, enable_sp=enable_sp + ) + + predict = network(input) + error_cost = paddle.sum(predict) + + return error_cost, train_program, start_program + + +def get_dist_prog(rank=0, enable_fused_linear_promotion=False, enable_sp=False): + train_program = 
paddle.static.Program() + startup_program = paddle.static.Program() + + loss, train_program, startup_program = get_hybrid_parallel_model( + train_program, startup_program, enable_sp=enable_sp + ) + opt = paddle.optimizer.AdamW(learning_rate=0.00001) + strategy = Strategy() + strategy.auto_mode = "semi" + strategy.fused_passes.enable = True + strategy.sp_optimization.enable = enable_sp + strategy.fused_linear_promotion.enable = enable_fused_linear_promotion + strategy.fused_passes.fused_passes_list = ["fuse_gemm_epilogue"] + dist_context = DistributedContext( + train_program, startup_program, opt, loss, strategy=strategy + ) + planner = Planner("train", dist_context) + planner.plan() + + parallelizer = Parallelizer( + "train", + planner.completer, + dist_context, + ) + parallelizer.parallel(rank=rank) + return ( + dist_context.dist_main_programs[rank], + dist_context.dist_startup_programs[rank], + ) + + +class TestFusedLinerPromotion(unittest.TestCase): + def test_fused_linear_promotion_mp(self): + dist_main_prog, _ = get_dist_prog( + rank=0, enable_fused_linear_promotion=False, enable_sp=False + ) + ops_without_promotion = dist_main_prog.global_block().ops + oringin_fused_gemm_epilogue_ops = [ + op + for op in ops_without_promotion + if op.type == "fused_gemm_epilogue" + ] + + dist_main_prog_pro, _ = get_dist_prog( + rank=0, enable_fused_linear_promotion=True, enable_sp=False + ) + ops_with_promotion = dist_main_prog_pro.global_block().ops + fused_gemm_epilogue_ops = [ + op for op in ops_with_promotion if op.type == "fused_gemm_epilogue" + ] + self.assertEqual( + len(fused_gemm_epilogue_ops), + len(oringin_fused_gemm_epilogue_ops) + 2, + ) + + +if __name__ == "__main__": + unittest.main() From 146c0a830fe9e31cc5dbe1d755a3e20c7fbf0dd7 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Fri, 1 Dec 2023 16:08:44 +0800 Subject: [PATCH 10/11] add unit_test --- python/paddle/distributed/auto_parallel/strategy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 7c8dcb3aaa7a5..0ee9c29d23348 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -82,7 +82,7 @@ def __init__(self, config_dict=None): super().__init__(category, config_dict) -class FusedPromotionConfig(BaseConfig): +class FusedLinearPromotionConfig(BaseConfig): def __init__(self, config_dict=None): category = constants.FUSEDLINEARPROMOTION super().__init__(category, config_dict) @@ -233,7 +233,7 @@ def __init__(self, config=None): config_dict = self._config_dict.get( constants.FUSEDLINEARPROMOTION, None ) - self.fused_linear_promotion = FusedPromotionConfig(config_dict) + self.fused_linear_promotion = FusedLinearPromotionConfig(config_dict) config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) From 7233d7d4883dbca6a8c5f257749a232511b9e623 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Fri, 1 Dec 2023 20:30:53 +0800 Subject: [PATCH 11/11] fix the test --- test/distributed_passes/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed_passes/CMakeLists.txt b/test/distributed_passes/CMakeLists.txt index 99cd70bd616e9..d9ee247cae2ba 100644 --- a/test/distributed_passes/CMakeLists.txt +++ b/test/distributed_passes/CMakeLists.txt @@ -20,11 +20,11 @@ if((NOT WITH_GPU) AND (NOT WITH_XPU)) 
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_data_parallel_optimization_pass") - list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fused_linear_promotion_pass") endif() if(NOT ((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6))) list(REMOVE_ITEM TEST_OPS test_dist_fuse_gemm_epilogue_pass) + list(REMOVE_ITEM TEST_OPS test_auto_parallel_fused_linear_promotion_pass) endif() foreach(TEST_OP ${TEST_OPS})
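
Usage sketch (illustration only, not applied by any patch above): the promotion pass is opt-in and only takes effect when the fused passes include "fuse_gemm_epilogue" and the new switch is turned on. The configuration below uses only Strategy fields that this series introduces or reads; the surrounding model, optimizer, and engine wiring are assumed and not shown.

    # Illustrative configuration only; these Strategy fields come from this
    # patch series (constants.py / strategy.py / parallelizer_v2.py). Any
    # training-loop wiring around it is assumed and omitted.
    from paddle.distributed.auto_parallel.strategy import Strategy

    strategy = Strategy()
    strategy.auto_mode = "semi"
    # The pass is applied in Parallelizer._apply_post_optimization only when
    # fused passes are enabled and "fuse_gemm_epilogue" is in the pass list.
    strategy.fused_passes.enable = True
    strategy.fused_passes.fused_passes_list = ["fuse_gemm_epilogue"]
    # New switch added by this series; defaults to False in constants.py.
    strategy.fused_linear_promotion.enable = True
    # Set to True to match the sequence-parallel source patterns
    # (c_reducescatter forward / c_allgather backward) instead of the pure
    # tensor-parallel ones.
    strategy.sp_optimization.enable = False

The new unit test builds this same configuration twice (with the promotion switch off and on, rank 0, sp disabled) and checks that the transformed main program gains exactly two additional fused_gemm_epilogue ops when promotion is enabled.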