From ee6d976b2709e7e4a2ac0037e4549fab09ed6a88 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Thu, 7 Dec 2023 15:36:27 +0800 Subject: [PATCH] [AutoParallel] add chunk_id attr for dist_op (#59719) * [AutoParallel] add chunk_id attr for dist_op * update utils funcs * update dist ops * fix dist_ctx * fix dist_default * add silu as dist_elemwise --- .../auto_parallel/static/completion.py | 3 +- .../auto_parallel/static/dist_context.py | 2 +- .../auto_parallel/static/operators/common.py | 11 ++- .../static/operators/dist_cross_entropy.py | 73 +------------------ .../static/operators/dist_default.py | 44 ++++------- .../static/operators/dist_dropout.py | 12 ++- .../static/operators/dist_embedding.py | 5 ++ .../operators/dist_fused_dropout_add.py | 12 ++- .../static/operators/dist_matmul.py | 6 +- .../static/operators/dist_pnorm.py | 3 + .../distributed/auto_parallel/static/utils.py | 19 ++++- 11 files changed, 80 insertions(+), 110 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/completion.py b/python/paddle/distributed/auto_parallel/static/completion.py index b9d3038ad87ec..912e921691453 100644 --- a/python/paddle/distributed/auto_parallel/static/completion.py +++ b/python/paddle/distributed/auto_parallel/static/completion.py @@ -29,10 +29,10 @@ from .dist_attribute import OperatorDistAttr, TensorDistAttr from .dist_context import _node_id from .operators.common import ( + _gradient_sync_by_partial_ops, find_compatible_distributed_operator_impls, find_distributed_operator_impl_container, ) -from .operators.common import _gradient_sync_by_partial_ops from .process_group import get_world_process_group from .utils import ( __no_shape_var_type__, @@ -152,6 +152,7 @@ def _can_apply_infer_spmd_rule(dist_op): "transpose2", "split", "unsqueeze2", + "silu", ] parallel_ce = os.getenv("PARALLEL_CROSS_ENTROPY") if parallel_ce == "true": diff --git a/python/paddle/distributed/auto_parallel/static/dist_context.py b/python/paddle/distributed/auto_parallel/static/dist_context.py index aa9dfd2128aa7..d52f1604d3a9d 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_context.py +++ b/python/paddle/distributed/auto_parallel/static/dist_context.py @@ -896,7 +896,7 @@ def copy_dist_attr_from_graph_to_program(self): # NOTE(zhaoyingli): # The order of process_meshes is execution order of the ops, # which will help pipeline strategy to get pp_rank info. - self.process_meshes = process_meshes + self.process_meshes = copy.deepcopy(process_meshes) # TODO: the completion algorithm will skipped orphan tensors, # here we just set there process_mesh to the first one. 
for orphan_node in self._serial_orphan_tensor_nodes: diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 3fb2f4fe72619..a79a3766c5a19 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -44,6 +44,7 @@ "cast", # "gather", # "concat", + "silu", "fused_softmax_mask_upper_triangle", ] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} @@ -392,13 +393,15 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): def set_comm_op_dist_attr_for_program( - new_op, process_mesh, tensor_dist_attr, ctx + new_op, process_mesh, tensor_dist_attr, ctx, **kwargs ): assert process_mesh is not None assert tensor_dist_attr is not None new_op_dist_attr = OperatorDistAttr() new_op_dist_attr.process_mesh = process_mesh + if "chunk_id" in kwargs: + new_op_dist_attr.chunk_id = kwargs["chunk_id"] for input_varname in new_op.desc.input_arg_names(): new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr) for output_varname in new_op.desc.output_arg_names(): @@ -410,6 +413,9 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op) new_op_dist_attr = OperatorDistAttr() new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh + new_op_dist_attr.impl_type = ref_dist_attr.impl_type + new_op_dist_attr.impl_idx = ref_dist_attr.impl_idx + new_op_dist_attr.chunk_id = ref_dist_attr.chunk_id for input_name in ref_op.input_names: assert input_name in new_op.input_names @@ -492,6 +498,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) process_mesh = op_dist_attr.process_mesh + chunk_id = op_dist_attr.chunk_id dist_op_context = dist_ctx.dist_op_context main_block = dist_op_context.work_block @@ -541,6 +548,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): for new_op in added_ops: new_op_attr = OperatorDistAttr() new_op_attr.process_mesh = process_mesh + new_op_attr.chunk_id = chunk_id new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr) @@ -804,4 +812,5 @@ def copy_op_without_infer_shape(src_op, block, ctx, varname_kwargs): new_op_desc.set_input(input_name, varname_kwargs[input_name]) for output_name in src_op.desc.output_names(): new_op_desc.set_output(output_name, varname_kwargs[output_name]) + # TODO: should we add a new dist attr for the new op here? 
return new_op diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py b/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py index 67d0b52f701f7..b1593a61cd73a 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_cross_entropy.py @@ -13,9 +13,7 @@ # limitations under the License import copy -import logging -from paddle.base.log_helper import get_logger from paddle.common_ops_import import check_variable_and_dtype from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole @@ -27,24 +25,18 @@ _get_corresponding_rank, get_dist_tensor_spec, is_dim_shard, - set_dist_op_desc_original_id, ) from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, ParallelMode, copy_op_without_infer_shape, - infer_shape, naive_copy_op_dist_attr_for_program, register_distributed_operator_impl, register_distributed_operator_impl_container, update_op_dims_mapping, ) -_logger = get_logger( - __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' -) - class DistributedCrossEntropy(DistributedOperatorImplContainer): def __init__(self, op_type): @@ -226,29 +218,7 @@ def forward(ctx, *args, **kwargs): ['bfloat16', 'float16', 'float32', 'float64'], 'cross_entropy_with_softmax', ) - - cross_entropy_op = copy_op_without_infer_shape( - src_op, main_block, ctx, kwargs - ) - - # set dist op's dist_attr with serial op's dist_attr - copied_op_dist_attr = OperatorDistAttr() - copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh - copied_op_dist_attr.impl_type = op_dist_attr.impl_type - copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx - for input_varname in cross_entropy_op.desc.input_arg_names(): - input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) - assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}" - copied_op_dist_attr.set_input_dist_attr( - input_varname, input_dist_attr - ) - for output_varname in cross_entropy_op.desc.output_arg_names(): - output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) - assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}" - copied_op_dist_attr.set_output_dist_attr( - output_varname, output_dist_attr - ) - ctx.set_op_dist_attr_for_program(cross_entropy_op, copied_op_dist_attr) + copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) @staticmethod def backward(ctx, *args, **kwargs): @@ -284,16 +254,8 @@ def backward(ctx, *args, **kwargs): ), "output [Logits@GRAD] take 1 variable but got {}".format( kwargs['Logits@GRAD'] ) - # replicate op in dist program - dist_op_desc = main_block.append_op(type='nop').desc - dist_op_desc.copy_from(backward_op.desc) - # Refer to the related dist op - set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) - for input_name in backward_op.desc.input_names(): - dist_op_desc.set_input(input_name, kwargs[input_name]) - for output_name in backward_op.desc.output_names(): - dist_op_desc.set_output(output_name, kwargs[output_name]) + copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs) class DistributedCrossEntropyImpl1(DistributedOperatorImpl): @@ -393,14 +355,6 @@ def forward(ctx, *args, **kwargs): softmax_var.name ) assert op_dist_attr_softmax is not None - loss_ref_shape = infer_shape( - main_block, loss_var, loss_dist_attr, op_dist_attr_loss - ) - softmax_ref_shape = infer_shape( - main_block, softmax_var, softmax_dist_attr, op_dist_attr_softmax - 
) - loss_var.desc.set_shape(loss_ref_shape) - softmax_var.desc.set_shape(softmax_ref_shape) # TODO calculate ring id softmax_axis = src_op.desc.attr('axis') @@ -431,27 +385,7 @@ def forward(ctx, *args, **kwargs): OP_ROLE_KEY: src_op.attr('op_role'), }, ) - - # set dist op's dist_attr with serial op's dist_attr - copied_op_dist_attr = OperatorDistAttr() - copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh - copied_op_dist_attr.impl_type = op_dist_attr.impl_type - copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx - for input_varname in c_cross_entropy_op.desc.input_arg_names(): - input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) - assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}" - copied_op_dist_attr.set_input_dist_attr( - input_varname, input_dist_attr - ) - for output_varname in c_cross_entropy_op.desc.output_arg_names(): - output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) - assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}" - copied_op_dist_attr.set_output_dist_attr( - output_varname, output_dist_attr - ) - ctx.set_op_dist_attr_for_program( - c_cross_entropy_op, copied_op_dist_attr - ) + naive_copy_op_dist_attr_for_program(c_cross_entropy_op, src_op, ctx) @staticmethod def backward(ctx, *args, **kwargs): @@ -536,6 +470,7 @@ def backward(ctx, *args, **kwargs): ) scale_op_attr = OperatorDistAttr() scale_op_attr.process_mesh = op_dist_attr.process_mesh + scale_op_attr.chunk_id = op_dist_attr.chunk_id scale_op_attr.set_output_dims_mapping( loss_grad_var.name, dims_mapping ) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_default.py b/python/paddle/distributed/auto_parallel/static/operators/dist_default.py index 04afe623cbafc..b4303e3ab8c72 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_default.py @@ -30,16 +30,17 @@ compute_compatible_dim_mapping, get_dist_tensor_spec, is_prim_op, - set_dist_op_desc_original_id, ) from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + copy_op_without_infer_shape, get_default_distributed_operator_impl, gradient_synchronization, is_parameter_related, register_distributed_operator_impl, register_distributed_operator_impl_container, + set_comm_op_dist_attr_for_program, update_op_dims_mapping, ) @@ -113,12 +114,13 @@ def update_dims_mapping(dist_op): op_desc = dist_op.serial_op.desc input_arg_names = op_desc.input_arg_names() output_arg_names = op_desc.output_arg_names() + main_block = dist_op.serial_op.block num_inputs = len(input_arg_names) input_specs = [] for i in range(num_inputs): assert not is_parameter_related( - input_arg_names[i] + input_arg_names[i], main_block ), "input {} of op {} is parameter, op should not use default rule.".format( input_arg_names[i], str(dist_op.serial_op) ) @@ -129,7 +131,7 @@ def update_dims_mapping(dist_op): output_specs = [] for i in range(num_outputs): assert not is_parameter_related( - output_arg_names[i] + output_arg_names[i], main_block ), "output {} of op {} is parameter, op should not use default rule.".format( output_arg_names[i], str(dist_op.serial_op) ) @@ -530,15 +532,7 @@ def forward(ctx, *args, **kwargs): ), f"number of tensor for input [{output_name}] is not match" # replicate op in dist program - dist_op = main_block.append_op(type='nop') - dist_op_desc = dist_op.desc - dist_op_desc.copy_from(src_op.desc) - set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) - for 
input_name in src_op.desc.input_names(): - dist_op_desc.set_input(input_name, kwargs[input_name]) - for output_name in src_op.desc.output_names(): - dist_op_desc.set_output(output_name, kwargs[output_name]) - # TODO: should we add a new dist attr for the new op here? + dst_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) if ( src_op.has_attr('shape') @@ -558,7 +552,7 @@ def forward(ctx, *args, **kwargs): shape_list[idx] = ( shape_list[idx] // process_mesh_shape[axis] ) - dist_op_desc._set_attr('shape', shape_list) + dst_op.desc._set_attr('shape', shape_list) # data parallel synchronization for primtive operators from paddle.incubate.autograd import prim_enabled @@ -572,7 +566,7 @@ def forward(ctx, *args, **kwargs): if src_op.type in __op_not_need_param_init__: return - for varname in dist_op_desc.input_arg_names(): + for varname in dst_op.desc.input_arg_names(): if ( startup_block.has_var(varname) and startup_block.var(varname).is_parameter @@ -614,15 +608,12 @@ def forward(ctx, *args, **kwargs): OP_ROLE_KEY: OpRole.Forward, }, ) - - # set distributed attribute - op_attr = OperatorDistAttr() - op_attr.process_mesh = process_mesh - op_attr.set_output_dims_mapping( - param.name, dims_mapping + set_comm_op_dist_attr_for_program( + new_op, + process_mesh, + param_dist_attr, + ctx, ) - op_attr.set_input_dims_mapping(param.name, dims_mapping) - ctx.set_op_dist_attr_for_program(new_op, op_attr) @staticmethod def backward(ctx, *args, **kwargs): @@ -649,14 +640,7 @@ def backward(ctx, *args, **kwargs): ), f"number of tensor for input [{output_name}] is not match" # replicate op in dist program - dist_op_desc = main_block.append_op(type='nop').desc - dist_op_desc.copy_from(backward_op.desc) - # Refer to the related dist op - set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) - for input_name in backward_op.desc.input_names(): - dist_op_desc.set_input(input_name, kwargs[input_name]) - for output_name in backward_op.desc.output_names(): - dist_op_desc.set_output(output_name, kwargs[output_name]) + copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs) # data parallel gradient synchronization act_grad_names = [] diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py index 913e1340b32f0..ccc663f90e297 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py @@ -188,7 +188,11 @@ def forward(ctx, *args, **kwargs): # set new seed_var's dist_attr seed_var_dims_mapping = [-1] seed_var_dist_attr = set_var_dist_attr( - ctx, seed_var, seed_var_dims_mapping, process_mesh + ctx, + seed_var, + seed_var_dims_mapping, + process_mesh, + chunk_id=op_dist_attr.chunk_id, ) # adopt for recompute @@ -205,7 +209,11 @@ def forward(ctx, *args, **kwargs): seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed') # set new seed op's dist_attr naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - seed_op, process_mesh, seed_var_dims_mapping, ctx + seed_op, + process_mesh, + seed_var_dims_mapping, + ctx, + chunk_id=op_dist_attr.chunk_id, ) # modify dropout op diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py index e4f9ff09fca20..bb9dbe8cd4694 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py +++ 
b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py @@ -181,13 +181,16 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): intermediate_var_0, Ids_var_dist_attr.dims_mapping, Ids_var_dist_attr.process_mesh, + chunk_id=Ids_var_dist_attr.chunk_id, ) set_var_dist_attr( ctx, xshape_var, [-1] + list(Ids_var_dist_attr.dims_mapping), Ids_var_dist_attr.process_mesh, + chunk_id=Ids_var_dist_attr.chunk_id, ) + # rename src_op's input src_op._rename_input(Ids_var.name, intermediate_var_0.name) op_dist_attr.del_input_dist_attr(Ids_var.name) op_dist_attr.set_input_dist_attr( @@ -198,6 +201,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh new_op_dist_attr.impl_type = "default" new_op_dist_attr.impl_idx = 0 + new_op_dist_attr.chunk_id = Ids_var_dist_attr.chunk_id new_op_dist_attr.set_input_dims_mapping( Ids_var.name, Ids_var_dist_attr.dims_mapping ) @@ -530,6 +534,7 @@ def forward(ctx, *args, **kwargs): op_dist_attr.process_mesh, out_var_dist_attr, ctx, + chunk_id=op_dist_attr.chunk_id, ) # param initialization sync diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py index 44a99efcebfe1..f715f6f283c4b 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py @@ -145,7 +145,11 @@ def forward(ctx, *args, **kwargs): # set new seed_var's dist_attr seed_var_dims_mapping = [-1] seed_var_dist_attr = set_var_dist_attr( - ctx, seed_var, seed_var_dims_mapping, process_mesh + ctx, + seed_var, + seed_var_dims_mapping, + process_mesh, + chunk_id=op_dist_attr.chunk_id, ) # adopt for recompute @@ -162,7 +166,11 @@ def forward(ctx, *args, **kwargs): seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed') # set new seed op's dist_attr naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - seed_op, process_mesh, seed_var_dims_mapping, ctx + seed_op, + process_mesh, + seed_var_dims_mapping, + ctx, + chunk_id=op_dist_attr.chunk_id, ) # modify dropout op diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py index e35f257a57c6f..a8a0638889a49 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py @@ -392,7 +392,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ctx, main_block, backward_op, **kwargs ) else: - # col parallel: matmul + allreduce + # col parallel: matmul_grad + allreduce col_parallel = True assert Y_var_dim_mapping[0] < 0 parallel_axis = Y_var_dim_mapping[1] @@ -457,6 +457,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): dist_attr.process_mesh, X_grad_dist_attr, ctx, + chunk_id=dist_attr.chunk_id, ) if trans_x: @@ -1115,6 +1116,7 @@ def forward(ctx, *args, **kwargs): op_dist_attr.process_mesh, out_var_dist_attr, ctx, + chunk_id=op_dist_attr.chunk_id, ) # init param sync @@ -1806,6 +1808,7 @@ def forward(ctx, *args, **kwargs): op_dist_attr.process_mesh, out_var_dist_attr, ctx, + chunk_id=op_dist_attr.chunk_id, ) # init param sync @@ -2475,6 +2478,7 @@ def forward(ctx, *args, **kwargs): op_dist_attr.process_mesh, out_var_dist_attr, ctx, + chunk_id=op_dist_attr.chunk_id, ) # init param 
sync diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/static/operators/dist_pnorm.py index 9f322cb5caf8a..6342056a4af95 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_pnorm.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_pnorm.py @@ -213,6 +213,7 @@ def forward(ctx, *args, **kwargs): # set allgather_out tensor dist_attr allgather_out_dist_attr = TensorDistAttr() allgather_out_dist_attr.process_mesh = op_dist_attr.process_mesh + allgather_out_dist_attr.chunk_id = op_dist_attr.chunk_id allgather_out_dist_attr.dims_mapping = [ -1 for i in range(len(allgather_out.shape)) ] @@ -233,6 +234,7 @@ def forward(ctx, *args, **kwargs): # set c_allgather op dist_attr allgather_op_dist_attr = OperatorDistAttr() allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allgather_op_dist_attr.chunk_id = op_dist_attr.chunk_id allgather_op_dist_attr.set_input_dims_mapping( X_var.name, in_dims_mapping ) @@ -366,6 +368,7 @@ def backward(ctx, *args, **kwargs): ) slice_op_dist_attr = OperatorDistAttr() slice_op_dist_attr.process_mesh = op_dist_attr.process_mesh + slice_op_dist_attr.chunk_id = op_dist_attr.chunk_id slice_op_dist_attr.set_input_dims_mapping( new_X_grad.name, new_X_var_dist_attr.dims_mapping ) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 905d04782a007..00b5d081648a3 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -1258,6 +1258,12 @@ def is_gradient_clip_op(op): ).startswith("/gradient_clip") +def is_reshard_op(op): + return op.desc.has_attr( + "op_namescope" + ) and "/auto_parallel/reshard" in op.desc.attr('op_namescope') + + def is_prim_op(op): return op.type.endswith("_p") @@ -1296,12 +1302,14 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): if "mark_annotated" in kwargs and kwargs["mark_annotated"]: tensor_dist_attr.mark_annotated("dims_mapping") tensor_dist_attr.mark_annotated("process_mesh") + if "chunk_id" in kwargs and kwargs["chunk_id"]: + tensor_dist_attr.chunk_id = kwargs["chunk_id"] dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) return tensor_dist_attr def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_op, process_mesh, ref_mapping, ctx + new_op, process_mesh, ref_mapping, ctx, **kwargs ): assert process_mesh is not None assert ref_mapping is not None @@ -1314,11 +1322,13 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping) new_op_dist_attr.process_mesh = process_mesh + if "chunk_id" in kwargs and kwargs["chunk_id"]: + new_op_dist_attr.chunk_id = kwargs["chunk_id"] ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) def naive_set_dist_op_attr_for_program_by_mesh( - new_op, process_mesh, ctx, is_recompute=False + new_op, process_mesh, ctx, **kwargs ): assert process_mesh is not None @@ -1334,7 +1344,10 @@ def naive_set_dist_op_attr_for_program_by_mesh( new_op_dist_attr.set_output_dims_mapping(output_varname, mapping) new_op_dist_attr.process_mesh = process_mesh - new_op_dist_attr.is_recompute = is_recompute + if "is_recompute" in kwargs: + new_op_dist_attr.is_recompute = kwargs["is_recompute"] + if "chunk_id" in kwargs: + new_op_dist_attr.chunk_id = kwargs["chunk_id"] ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
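
Note on the new chunk_id plumbing: the diffs above thread an optional chunk_id keyword through the dist-attr helpers (set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_comm_op_dist_attr_for_program) so that tensors and ops created during partitioning land in the same pipeline chunk as the serial op they derive from. The snippet below is a minimal, self-contained sketch of that pattern using hypothetical stand-ins (FakeTensorDistAttr, a plain dict as the "context"); it is not Paddle's real API, only an illustration of how the optional kwarg is consumed.

# Minimal sketch of the chunk_id propagation pattern added to the helpers in
# python/paddle/distributed/auto_parallel/static/utils.py. The class and the
# dict-based "context" below are illustrative stand-ins, not Paddle's real
# TensorDistAttr / DistributedContext.


class FakeTensorDistAttr:
    def __init__(self):
        self.process_mesh = None
        self.dims_mapping = []
        self.chunk_id = 0  # default: first (or only) pipeline chunk


def fake_set_var_dist_attr(ctx, var_name, dims_mapping, process_mesh, **kwargs):
    """Mirrors the updated helper signature: extra kwargs may carry chunk_id."""
    dist_attr = FakeTensorDistAttr()
    dist_attr.process_mesh = process_mesh
    dist_attr.dims_mapping = dims_mapping
    # Same guard as in the patch: only a truthy chunk_id overrides the default,
    # so chunk_id=0 keeps the default value.
    if "chunk_id" in kwargs and kwargs["chunk_id"]:
        dist_attr.chunk_id = kwargs["chunk_id"]
    ctx[var_name] = dist_attr
    return dist_attr


if __name__ == "__main__":
    ctx = {}
    # Callers such as dist_dropout pass the owning op's chunk_id along, e.g.
    # chunk_id=op_dist_attr.chunk_id when creating the tensor-parallel seed var.
    attr = fake_set_var_dist_attr(
        ctx, "seed_var", [-1], process_mesh=[0, 1], chunk_id=2
    )
    assert attr.chunk_id == 2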
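
A second pattern in the patch is the consolidation of op replication: dist_default and dist_cross_entropy now call copy_op_without_infer_shape instead of hand-rolling the append-nop / copy_from / set_input / set_output sequence, and naive_copy_op_dist_attr_for_program additionally copies impl_type, impl_idx, and chunk_id from the reference op. The sketch below, again using hypothetical stand-ins rather than Paddle's classes, shows the effect of copying chunk_id there: every replicated op keeps the pipeline-chunk assignment of the serial op it replaces.

# Illustrative sketch (hypothetical FakeOpDistAttr, dict-based context) of the
# extended "naive copy" of an op's dist_attr, matching the fields added in
# operators/common.py: impl_type, impl_idx, and chunk_id are now carried over.


class FakeOpDistAttr:
    def __init__(self):
        self.process_mesh = None
        self.impl_type = "default"
        self.impl_idx = 0
        self.chunk_id = 0


def fake_naive_copy_op_dist_attr(new_op, ref_attr, ctx):
    new_attr = FakeOpDistAttr()
    new_attr.process_mesh = ref_attr.process_mesh
    # Newly copied fields: the replicated op reuses the serial op's impl choice
    # and stays in the same pipeline chunk.
    new_attr.impl_type = ref_attr.impl_type
    new_attr.impl_idx = ref_attr.impl_idx
    new_attr.chunk_id = ref_attr.chunk_id
    ctx[new_op] = new_attr
    return new_attr


if __name__ == "__main__":
    ref = FakeOpDistAttr()
    ref.process_mesh = [0, 1, 2, 3]
    ref.impl_idx = 1
    ref.chunk_id = 3
    ctx = {}
    copied = fake_naive_copy_op_dist_attr("replicated_softmax_grad", ref, ctx)
    assert (copied.impl_idx, copied.chunk_id) == (1, 3)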