[AutoParallel] add chunk_id attr for dist_op (#59719)
* [AutoParallel] add chunk_id attr for dist_op

* update utils funcs

* update dist ops

* fix dist_ctx

* fix dist_default

* add silu as dist_elemwise
zhaoyinglia authored Dec 7, 2023
1 parent a2c8c9a commit ee6d976
Showing 11 changed files with 80 additions and 110 deletions.
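The thread running through the hunks below is a new chunk_id field on the dist_attr objects, which appears to identify the pipeline chunk an op belongs to and now has to be propagated wherever a dist_attr is created or copied. A minimal, self-contained sketch of that propagation pattern (the class below is an illustrative stand-in, not Paddle's actual OperatorDistAttr):

# Illustrative stand-in for Paddle's OperatorDistAttr; only the fields touched
# by this commit are modeled (process_mesh, impl_type, impl_idx, chunk_id).
class SketchOperatorDistAttr:
    def __init__(self):
        self.process_mesh = None
        self.impl_type = "default"
        self.impl_idx = 0
        # Assumed default of 0; chunk_id marks which pipeline chunk the op
        # belongs to, so later pipeline passes can group ops per chunk.
        self.chunk_id = 0


def copy_dist_attr_from_ref(ref_attr):
    """Mirrors the naive_copy_op_dist_attr_for_program pattern in the diff:
    an op inserted during lowering inherits mesh, impl info, and (new in this
    commit) chunk_id from a reference op instead of keeping defaults."""
    new_attr = SketchOperatorDistAttr()
    new_attr.process_mesh = ref_attr.process_mesh
    new_attr.impl_type = ref_attr.impl_type
    new_attr.impl_idx = ref_attr.impl_idx
    new_attr.chunk_id = ref_attr.chunk_id
    return new_attr


if __name__ == "__main__":
    ref = SketchOperatorDistAttr()
    ref.process_mesh = [0, 1, 2, 3]
    ref.chunk_id = 1
    assert copy_dist_attr_from_ref(ref).chunk_id == 1

The real naive_copy_op_dist_attr_for_program hunk below does the same on Paddle's OperatorDistAttr and additionally copies the per-argument input/output dist_attrs.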
3 changes: 2 additions & 1 deletion python/paddle/distributed/auto_parallel/static/completion.py
@@ -29,10 +29,10 @@
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id
from .operators.common import (
_gradient_sync_by_partial_ops,
find_compatible_distributed_operator_impls,
find_distributed_operator_impl_container,
)
from .operators.common import _gradient_sync_by_partial_ops
from .process_group import get_world_process_group
from .utils import (
__no_shape_var_type__,
@@ -152,6 +152,7 @@ def _can_apply_infer_spmd_rule(dist_op):
"transpose2",
"split",
"unsqueeze2",
"silu",
]
parallel_ce = os.getenv("PARALLEL_CROSS_ENTROPY")
if parallel_ce == "true":
@@ -896,7 +896,7 @@ def copy_dist_attr_from_graph_to_program(self):
# NOTE(zhaoyingli):
# The order of process_meshes is execution order of the ops,
# which will help pipeline strategy to get pp_rank info.
self.process_meshes = process_meshes
self.process_meshes = copy.deepcopy(process_meshes)
# TODO: the completion algorithm will skip orphan tensors,
# here we just set their process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
@@ -44,6 +44,7 @@
"cast",
# "gather",
# "concat",
"silu",
"fused_softmax_mask_upper_triangle",
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
@@ -392,13 +393,15 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):


def set_comm_op_dist_attr_for_program(
new_op, process_mesh, tensor_dist_attr, ctx
new_op, process_mesh, tensor_dist_attr, ctx, **kwargs
):
assert process_mesh is not None
assert tensor_dist_attr is not None

new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = process_mesh
if "chunk_id" in kwargs:
new_op_dist_attr.chunk_id = kwargs["chunk_id"]
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in new_op.desc.output_arg_names():
@@ -410,6 +413,9 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
new_op_dist_attr.impl_type = ref_dist_attr.impl_type
new_op_dist_attr.impl_idx = ref_dist_attr.impl_idx
new_op_dist_attr.chunk_id = ref_dist_attr.chunk_id

for input_name in ref_op.input_names:
assert input_name in new_op.input_names
@@ -492,6 +498,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):

op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
chunk_id = op_dist_attr.chunk_id
dist_op_context = dist_ctx.dist_op_context
main_block = dist_op_context.work_block

@@ -541,6 +548,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
for new_op in added_ops:
new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = process_mesh
new_op_attr.chunk_id = chunk_id
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)
@@ -804,4 +812,5 @@ def copy_op_without_infer_shape(src_op, block, ctx, varname_kwargs):
new_op_desc.set_input(input_name, varname_kwargs[input_name])
for output_name in src_op.desc.output_names():
new_op_desc.set_output(output_name, varname_kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
return new_op
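The tail of copy_op_without_infer_shape shown above is the helper that the cross-entropy and default dist ops below now call instead of repeating the append-a-nop-and-copy-the-desc block by hand. A hedged reconstruction of what such a helper roughly does, using placeholder types rather than Paddle's real Block/OpDesc API (the real helper also records the original op id via set_dist_op_desc_original_id and, as its name says, skips infer_shape):

# Placeholder stand-ins (not Paddle classes) for the op-desc and block objects
# the real helper manipulates.
class SketchOpDesc:
    def __init__(self, op_type="nop", inputs=None, outputs=None):
        self.type = op_type
        self.inputs = dict(inputs or {})
        self.outputs = dict(outputs or {})

    def copy_from(self, other):
        self.type = other.type
        self.inputs = dict(other.inputs)
        self.outputs = dict(other.outputs)

    def input_names(self):
        return list(self.inputs)

    def output_names(self):
        return list(self.outputs)

    def set_input(self, name, varnames):
        self.inputs[name] = list(varnames)

    def set_output(self, name, varnames):
        self.outputs[name] = list(varnames)


class SketchBlock:
    def __init__(self):
        self.ops = []

    def append_op(self, op_type):
        op = SketchOpDesc(op_type)
        self.ops.append(op)
        return op


def copy_op_without_infer_shape_sketch(src_desc, block, varname_kwargs):
    """Roughly what the helper does: append a placeholder op, copy the source
    desc into it, then rebind every input/output argument to the (possibly
    renamed) variables supplied by the caller."""
    new_desc = block.append_op("nop")
    new_desc.copy_from(src_desc)
    for input_name in src_desc.input_names():
        new_desc.set_input(input_name, varname_kwargs[input_name])
    for output_name in src_desc.output_names():
        new_desc.set_output(output_name, varname_kwargs[output_name])
    return new_desc


if __name__ == "__main__":
    src = SketchOpDesc(
        "softmax_with_cross_entropy",
        inputs={"Logits": ["x"], "Label": ["label"]},
        outputs={"Loss": ["loss"], "Softmax": ["softmax"]},
    )
    block = SketchBlock()
    dst = copy_op_without_infer_shape_sketch(
        src,
        block,
        {
            "Logits": ["x.dist"],
            "Label": ["label.dist"],
            "Loss": ["loss.dist"],
            "Softmax": ["softmax.dist"],
        },
    )
    assert dst.inputs["Logits"] == ["x.dist"]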
@@ -13,9 +13,7 @@
# limitations under the License

import copy
import logging

from paddle.base.log_helper import get_logger
from paddle.common_ops_import import check_variable_and_dtype
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

@@ -27,24 +25,18 @@
_get_corresponding_rank,
get_dist_tensor_spec,
is_dim_shard,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
ParallelMode,
copy_op_without_infer_shape,
infer_shape,
naive_copy_op_dist_attr_for_program,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
)

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class DistributedCrossEntropy(DistributedOperatorImplContainer):
def __init__(self, op_type):
@@ -226,29 +218,7 @@ def forward(ctx, *args, **kwargs):
['bfloat16', 'float16', 'float32', 'float64'],
'cross_entropy_with_softmax',
)

cross_entropy_op = copy_op_without_infer_shape(
src_op, main_block, ctx, kwargs
)

# set dist op's dist_attr with serial op's dist_attr
copied_op_dist_attr = OperatorDistAttr()
copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh
copied_op_dist_attr.impl_type = op_dist_attr.impl_type
copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in cross_entropy_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
for output_varname in cross_entropy_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(cross_entropy_op, copied_op_dist_attr)
copy_op_without_infer_shape(src_op, main_block, ctx, kwargs)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -284,16 +254,8 @@ def backward(ctx, *args, **kwargs):
), "output [Logits@GRAD] take 1 variable but got {}".format(
kwargs['Logits@GRAD']
)

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs)


class DistributedCrossEntropyImpl1(DistributedOperatorImpl):
@@ -393,14 +355,6 @@ def forward(ctx, *args, **kwargs):
softmax_var.name
)
assert op_dist_attr_softmax is not None
loss_ref_shape = infer_shape(
main_block, loss_var, loss_dist_attr, op_dist_attr_loss
)
softmax_ref_shape = infer_shape(
main_block, softmax_var, softmax_dist_attr, op_dist_attr_softmax
)
loss_var.desc.set_shape(loss_ref_shape)
softmax_var.desc.set_shape(softmax_ref_shape)

# TODO calculate ring id
softmax_axis = src_op.desc.attr('axis')
@@ -431,27 +385,7 @@ def forward(ctx, *args, **kwargs):
OP_ROLE_KEY: src_op.attr('op_role'),
},
)

# set dist op's dist_attr with serial op's dist_attr
copied_op_dist_attr = OperatorDistAttr()
copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh
copied_op_dist_attr.impl_type = op_dist_attr.impl_type
copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_cross_entropy_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
for output_varname in c_cross_entropy_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_cross_entropy_op, copied_op_dist_attr
)
naive_copy_op_dist_attr_for_program(c_cross_entropy_op, src_op, ctx)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -536,6 +470,7 @@ def backward(ctx, *args, **kwargs):
)
scale_op_attr = OperatorDistAttr()
scale_op_attr.process_mesh = op_dist_attr.process_mesh
scale_op_attr.chunk_id = op_dist_attr.chunk_id
scale_op_attr.set_output_dims_mapping(
loss_grad_var.name, dims_mapping
)
@@ -30,16 +30,17 @@
compute_compatible_dim_mapping,
get_dist_tensor_spec,
is_prim_op,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
copy_op_without_infer_shape,
get_default_distributed_operator_impl,
gradient_synchronization,
is_parameter_related,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
set_comm_op_dist_attr_for_program,
update_op_dims_mapping,
)

@@ -113,12 +114,13 @@ def update_dims_mapping(dist_op):
op_desc = dist_op.serial_op.desc
input_arg_names = op_desc.input_arg_names()
output_arg_names = op_desc.output_arg_names()
main_block = dist_op.serial_op.block

num_inputs = len(input_arg_names)
input_specs = []
for i in range(num_inputs):
assert not is_parameter_related(
input_arg_names[i]
input_arg_names[i], main_block
), "input {} of op {} is parameter, op should not use default rule.".format(
input_arg_names[i], str(dist_op.serial_op)
)
@@ -129,7 +131,7 @@
output_specs = []
for i in range(num_outputs):
assert not is_parameter_related(
output_arg_names[i]
output_arg_names[i], main_block
), "output {} of op {} is parameter, op should not use default rule.".format(
output_arg_names[i], str(dist_op.serial_op)
)
@@ -530,15 +532,7 @@ def forward(ctx, *args, **kwargs):
), f"number of tensor for input [{output_name}] is not match"

# replicate op in dist program
dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
dst_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs)

if (
src_op.has_attr('shape')
@@ -558,7 +552,7 @@
shape_list[idx] = (
shape_list[idx] // process_mesh_shape[axis]
)
dist_op_desc._set_attr('shape', shape_list)
dst_op.desc._set_attr('shape', shape_list)

# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled
@@ -572,7 +566,7 @@
if src_op.type in __op_not_need_param_init__:
return

for varname in dist_op_desc.input_arg_names():
for varname in dst_op.desc.input_arg_names():
if (
startup_block.has_var(varname)
and startup_block.var(varname).is_parameter
@@ -614,15 +608,12 @@ def forward(ctx, *args, **kwargs):
OP_ROLE_KEY: OpRole.Forward,
},
)

# set distributed attribute
op_attr = OperatorDistAttr()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(
param.name, dims_mapping
set_comm_op_dist_attr_for_program(
new_op,
process_mesh,
param_dist_attr,
ctx,
)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -649,14 +640,7 @@ def backward(ctx, *args, **kwargs):
), f"number of tensor for input [{output_name}] is not match"

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs)

# data parallel gradient synchronization
act_grad_names = []
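The default dist op's parameter broadcast above now reuses set_comm_op_dist_attr_for_program instead of building an OperatorDistAttr by hand, and that helper (see the earlier hunk) gained an optional chunk_id keyword, which the embedding implementation below passes explicitly while older call sites stay unchanged. A minimal sketch of that keyword-passthrough pattern with placeholder types:

class SketchCommOpDistAttr:
    """Illustrative stand-in for OperatorDistAttr (not the Paddle class)."""

    def __init__(self):
        self.process_mesh = None
        self.chunk_id = 0
        self.input_dist_attrs = {}
        self.output_dist_attrs = {}


def set_comm_op_dist_attr_sketch(
    input_names, output_names, process_mesh, tensor_dist_attr, **kwargs
):
    """Mirrors the updated helper: every input/output of a communication op
    shares one tensor dist_attr, and chunk_id is applied only when the caller
    passes it, which keeps older call sites source-compatible."""
    assert process_mesh is not None
    assert tensor_dist_attr is not None
    op_attr = SketchCommOpDistAttr()
    op_attr.process_mesh = process_mesh
    if "chunk_id" in kwargs:
        op_attr.chunk_id = kwargs["chunk_id"]
    for name in input_names:
        op_attr.input_dist_attrs[name] = tensor_dist_attr
    for name in output_names:
        op_attr.output_dist_attrs[name] = tensor_dist_attr
    return op_attr


if __name__ == "__main__":
    attr = set_comm_op_dist_attr_sketch(
        ["w"], ["w"], process_mesh=[0, 1], tensor_dist_attr="replicated",
        chunk_id=2,
    )
    assert attr.chunk_id == 2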
@@ -188,7 +188,11 @@ def forward(ctx, *args, **kwargs):
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
ctx,
seed_var,
seed_var_dims_mapping,
process_mesh,
chunk_id=op_dist_attr.chunk_id,
)

# adopt for recompute
@@ -205,7 +209,11 @@
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
seed_op,
process_mesh,
seed_var_dims_mapping,
ctx,
chunk_id=op_dist_attr.chunk_id,
)

# modify dropout op
@@ -181,13 +181,16 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
intermediate_var_0,
Ids_var_dist_attr.dims_mapping,
Ids_var_dist_attr.process_mesh,
chunk_id=Ids_var_dist_attr.chunk_id,
)
set_var_dist_attr(
ctx,
xshape_var,
[-1] + list(Ids_var_dist_attr.dims_mapping),
Ids_var_dist_attr.process_mesh,
chunk_id=Ids_var_dist_attr.chunk_id,
)
# rename src_op's input
src_op._rename_input(Ids_var.name, intermediate_var_0.name)
op_dist_attr.del_input_dist_attr(Ids_var.name)
op_dist_attr.set_input_dist_attr(
@@ -198,6 +201,7 @@
new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
new_op_dist_attr.impl_type = "default"
new_op_dist_attr.impl_idx = 0
new_op_dist_attr.chunk_id = Ids_var_dist_attr.chunk_id
new_op_dist_attr.set_input_dims_mapping(
Ids_var.name, Ids_var_dist_attr.dims_mapping
)
@@ -530,6 +534,7 @@ def forward(ctx, *args, **kwargs):
op_dist_attr.process_mesh,
out_var_dist_attr,
ctx,
chunk_id=op_dist_attr.chunk_id,
)

# param initialization sync
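The dropout and embedding hunks above forward chunk_id into shared utilities such as set_var_dist_attr and naive_set_dist_op_attr_for_program_by_mesh_and_mapping. Those utilities are not visible in the loaded hunks ("update utils funcs" in the commit message), so the following is only a guess at how such a utility might thread the keyword through, again with a placeholder type:

class SketchTensorDistAttr:
    """Illustrative stand-in for TensorDistAttr (not the Paddle class)."""

    def __init__(self):
        self.process_mesh = None
        self.dims_mapping = []
        self.chunk_id = 0


def set_var_dist_attr_sketch(dist_attr_store, var_name, dims_mapping,
                             process_mesh, **kwargs):
    """Assumed shape of the updated utility: build the tensor dist_attr as
    before, then apply any extra keyword attributes (here chunk_id) so call
    sites like the dropout dist op can pass chunk_id=op_dist_attr.chunk_id."""
    attr = SketchTensorDistAttr()
    attr.process_mesh = process_mesh
    attr.dims_mapping = list(dims_mapping)
    for key, value in kwargs.items():
        setattr(attr, key, value)  # e.g. chunk_id
    dist_attr_store[var_name] = attr
    return attr


if __name__ == "__main__":
    store = {}
    set_var_dist_attr_sketch(store, "seed", [-1], [0, 1], chunk_id=3)
    assert store["seed"].chunk_id == 3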