[AutoParallel] add chunk_id attr for dist_op #59719

Merged, 8 commits, Dec 7, 2023
python/paddle/distributed/auto_parallel/static/completion.py (3 changes: 2 additions & 1 deletion)
@@ -29,10 +29,10 @@
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id
from .operators.common import (
_gradient_sync_by_partial_ops,
find_compatible_distributed_operator_impls,
find_distributed_operator_impl_container,
)
from .operators.common import _gradient_sync_by_partial_ops
from .process_group import get_world_process_group
from .utils import (
__no_shape_var_type__,
@@ -152,6 +152,7 @@ def _can_apply_infer_spmd_rule(dist_op):
"transpose2",
"split",
"unsqueeze2",
"silu",
]
parallel_ce = os.getenv("PARALLEL_CROSS_ENTROPY")
if parallel_ce == "true":
(next changed file; filename not shown)
@@ -896,7 +896,7 @@ def copy_dist_attr_from_graph_to_program(self):
# NOTE(zhaoyingli):
# The order of process_meshes is execution order of the ops,
# which will help pipeline strategy to get pp_rank info.
self.process_meshes = process_meshes
self.process_meshes = copy.deepcopy(process_meshes)
# TODO: the completion algorithm will skip orphan tensors,
# here we just set their process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
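The switch from plain assignment to copy.deepcopy presumably keeps later in-place edits of self.process_meshes from leaking into, or out of, the list built during graph completion. A minimal standalone sketch of the aliasing hazard, with nested lists standing in for ProcessMesh objects:

import copy

process_meshes = [[0, 1], [2, 3]]           # stand-ins for ProcessMesh objects

aliased = process_meshes                    # plain assignment shares the same list
independent = copy.deepcopy(process_meshes)

process_meshes[0].append(4)                 # a later pass mutates the source

print(aliased[0])      # [0, 1, 4] -- the alias sees the mutation
print(independent[0])  # [0, 1]    -- the deep copy stays isolated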
(next changed file; filename not shown)
@@ -44,6 +44,7 @@
"cast",
# "gather",
# "concat",
"silu",
"fused_softmax_mask_upper_triangle",
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
@@ -392,13 +393,15 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):


def set_comm_op_dist_attr_for_program(
new_op, process_mesh, tensor_dist_attr, ctx
new_op, process_mesh, tensor_dist_attr, ctx, **kwargs
):
assert process_mesh is not None
assert tensor_dist_attr is not None

new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = process_mesh
if "chunk_id" in kwargs:
new_op_dist_attr.chunk_id = kwargs["chunk_id"]
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in new_op.desc.output_arg_names():
@@ -410,6 +413,9 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
new_op_dist_attr.impl_type = ref_dist_attr.impl_type
new_op_dist_attr.impl_idx = ref_dist_attr.impl_idx
new_op_dist_attr.chunk_id = ref_dist_attr.chunk_id

for input_name in ref_op.input_names:
assert input_name in new_op.input_names
@@ -492,6 +498,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):

op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
chunk_id = op_dist_attr.chunk_id
dist_op_context = dist_ctx.dist_op_context
main_block = dist_op_context.work_block

@@ -541,6 +548,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
for new_op in added_ops:
new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = process_mesh
new_op_attr.chunk_id = chunk_id
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)
@@ -804,4 +812,5 @@ def copy_op_without_infer_shape(src_op, block, ctx, varname_kwargs):
new_op_desc.set_input(input_name, varname_kwargs[input_name])
for output_name in src_op.desc.output_names():
new_op_desc.set_output(output_name, varname_kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
return new_op
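Taken together, the hunks above thread the new chunk_id field through the helpers that build an OperatorDistAttr for inserted or replicated ops: set_comm_op_dist_attr_for_program now accepts it as an optional keyword, naive_copy_op_dist_attr_for_program copies it from the reference op, and sync_and_scale_gradients stamps it onto the added allreduce/scale ops. A minimal standalone sketch of the optional-keyword pattern, with a made-up FakeOpDistAttr standing in for Paddle's OperatorDistAttr:

from dataclasses import dataclass, field

@dataclass
class FakeOpDistAttr:
    process_mesh: object = None
    chunk_id: int = 0
    input_dist_attrs: dict = field(default_factory=dict)
    output_dist_attrs: dict = field(default_factory=dict)

def set_comm_op_dist_attr(input_names, output_names, process_mesh,
                          tensor_dist_attr, **kwargs):
    attr = FakeOpDistAttr(process_mesh=process_mesh)
    if "chunk_id" in kwargs:            # only set when the caller provides it
        attr.chunk_id = kwargs["chunk_id"]
    for name in input_names:
        attr.input_dist_attrs[name] = tensor_dist_attr
    for name in output_names:
        attr.output_dist_attrs[name] = tensor_dist_attr
    return attr

attr = set_comm_op_dist_attr(["X"], ["Out"], process_mesh=[0, 1],
                             tensor_dist_attr="replicated", chunk_id=2)
print(attr.chunk_id)  # 2

Callers that do not pass chunk_id keep the previous behaviour, leaving the field at its default.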
(next changed file; filename not shown)
@@ -13,9 +13,7 @@
# limitations under the License

import copy
import logging

from paddle.base.log_helper import get_logger
from paddle.common_ops_import import check_variable_and_dtype
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

@@ -27,24 +25,18 @@
_get_corresponding_rank,
get_dist_tensor_spec,
is_dim_shard,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
ParallelMode,
copy_op_without_infer_shape,
infer_shape,
naive_copy_op_dist_attr_for_program,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
)

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class DistributedCrossEntropy(DistributedOperatorImplContainer):
def __init__(self, op_type):
@@ -226,29 +218,7 @@ def forward(ctx, *args, **kwargs):
['bfloat16', 'float16', 'float32', 'float64'],
'cross_entropy_with_softmax',
)

cross_entropy_op = copy_op_without_infer_shape(
src_op, main_block, ctx, kwargs
)

# set dist op's dist_attr with serial op's dist_attr
copied_op_dist_attr = OperatorDistAttr()
copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh
copied_op_dist_attr.impl_type = op_dist_attr.impl_type
copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in cross_entropy_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
for output_varname in cross_entropy_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(cross_entropy_op, copied_op_dist_attr)
copy_op_without_infer_shape(src_op, main_block, ctx, kwargs)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -284,16 +254,8 @@ def backward(ctx, *args, **kwargs):
), "output [Logits@GRAD] take 1 variable but got {}".format(
kwargs['Logits@GRAD']
)

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs)


class DistributedCrossEntropyImpl1(DistributedOperatorImpl):
@@ -393,14 +355,6 @@ def forward(ctx, *args, **kwargs):
softmax_var.name
)
assert op_dist_attr_softmax is not None
loss_ref_shape = infer_shape(
main_block, loss_var, loss_dist_attr, op_dist_attr_loss
)
softmax_ref_shape = infer_shape(
main_block, softmax_var, softmax_dist_attr, op_dist_attr_softmax
)
loss_var.desc.set_shape(loss_ref_shape)
softmax_var.desc.set_shape(softmax_ref_shape)

# TODO calculate ring id
softmax_axis = src_op.desc.attr('axis')
@@ -431,27 +385,7 @@ def forward(ctx, *args, **kwargs):
OP_ROLE_KEY: src_op.attr('op_role'),
},
)

# set dist op's dist_attr with serial op's dist_attr
copied_op_dist_attr = OperatorDistAttr()
copied_op_dist_attr.process_mesh = op_dist_attr.process_mesh
copied_op_dist_attr.impl_type = op_dist_attr.impl_type
copied_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_cross_entropy_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
for output_varname in c_cross_entropy_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
copied_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_cross_entropy_op, copied_op_dist_attr
)
naive_copy_op_dist_attr_for_program(c_cross_entropy_op, src_op, ctx)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -536,6 +470,7 @@ def backward(ctx, *args, **kwargs):
)
scale_op_attr = OperatorDistAttr()
scale_op_attr.process_mesh = op_dist_attr.process_mesh
scale_op_attr.chunk_id = op_dist_attr.chunk_id
scale_op_attr.set_output_dims_mapping(
loss_grad_var.name, dims_mapping
)
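The net effect of this file's diff is to drop the hand-rolled op replication and per-field dist-attr copies in favour of the shared helpers copy_op_without_infer_shape and naive_copy_op_dist_attr_for_program, which (per the common.py hunks above) now also carry impl_type, impl_idx and chunk_id. A rough standalone sketch of the copy the helper performs, with plain dicts standing in for OperatorDistAttr:

def naive_copy_dist_attr(ref_attr):
    # Fields copied from the reference (serial) op's dist attr;
    # impl_type, impl_idx and chunk_id are the ones newly copied in this PR.
    return {
        "process_mesh": ref_attr["process_mesh"],
        "impl_type": ref_attr["impl_type"],
        "impl_idx": ref_attr["impl_idx"],
        "chunk_id": ref_attr["chunk_id"],
        "inputs": dict(ref_attr["inputs"]),
        "outputs": dict(ref_attr["outputs"]),
    }

ref = {"process_mesh": [0, 1], "impl_type": "default", "impl_idx": 0,
       "chunk_id": 1, "inputs": {"Logits": "[-1, -1]"},
       "outputs": {"Loss": "[-1, -1]"}}
print(naive_copy_dist_attr(ref)["chunk_id"])  # 1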
(next changed file; filename not shown)
@@ -30,16 +30,17 @@
compute_compatible_dim_mapping,
get_dist_tensor_spec,
is_prim_op,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
copy_op_without_infer_shape,
get_default_distributed_operator_impl,
gradient_synchronization,
is_parameter_related,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
set_comm_op_dist_attr_for_program,
update_op_dims_mapping,
)

@@ -113,12 +114,13 @@ def update_dims_mapping(dist_op):
op_desc = dist_op.serial_op.desc
input_arg_names = op_desc.input_arg_names()
output_arg_names = op_desc.output_arg_names()
main_block = dist_op.serial_op.block

num_inputs = len(input_arg_names)
input_specs = []
for i in range(num_inputs):
assert not is_parameter_related(
input_arg_names[i]
input_arg_names[i], main_block
), "input {} of op {} is parameter, op should not use default rule.".format(
input_arg_names[i], str(dist_op.serial_op)
)
@@ -129,7 +131,7 @@ def update_dims_mapping(dist_op):
output_specs = []
for i in range(num_outputs):
assert not is_parameter_related(
output_arg_names[i]
output_arg_names[i], main_block
), "output {} of op {} is parameter, op should not use default rule.".format(
output_arg_names[i], str(dist_op.serial_op)
)
@@ -530,15 +532,7 @@ def forward(ctx, *args, **kwargs):
), f"number of tensor for input [{output_name}] is not match"

# replicate op in dist program
dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
dst_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs)

if (
src_op.has_attr('shape')
@@ -558,7 +552,7 @@ def forward(ctx, *args, **kwargs):
shape_list[idx] = (
shape_list[idx] // process_mesh_shape[axis]
)
dist_op_desc._set_attr('shape', shape_list)
dst_op.desc._set_attr('shape', shape_list)

# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled
@@ -572,7 +566,7 @@ def forward(ctx, *args, **kwargs):
if src_op.type in __op_not_need_param_init__:
return

for varname in dist_op_desc.input_arg_names():
for varname in dst_op.desc.input_arg_names():
if (
startup_block.has_var(varname)
and startup_block.var(varname).is_parameter
@@ -614,15 +608,12 @@ def forward(ctx, *args, **kwargs):
OP_ROLE_KEY: OpRole.Forward,
},
)

# set distributed attribute
op_attr = OperatorDistAttr()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(
param.name, dims_mapping
set_comm_op_dist_attr_for_program(
new_op,
process_mesh,
param_dist_attr,
ctx,
)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)

@staticmethod
def backward(ctx, *args, **kwargs):
@@ -649,14 +640,7 @@ def backward(ctx, *args, **kwargs):
), f"number of tensor for input [{output_name}] is not match"

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
copy_op_without_infer_shape(backward_op, main_block, ctx, kwargs)

# data parallel gradient synchronization
act_grad_names = []
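One small fix folded into this file: is_parameter_related is now called with the op's owning block, which it presumably needs to look the variable name up before answering the parameter check. A made-up miniature of that lookup, not Paddle's real implementation:

class FakeVar:
    def __init__(self, is_parameter):
        self.is_parameter = is_parameter

class FakeBlock:
    def __init__(self, named_vars):
        self._vars = named_vars

    def var(self, name):
        return self._vars[name]

def is_parameter_related(varname, block):
    # Without the block there is nothing to resolve the bare name against.
    return block.var(varname).is_parameter

block = FakeBlock({"linear_0.w_0": FakeVar(True), "x": FakeVar(False)})
print(is_parameter_related("linear_0.w_0", block))  # True
print(is_parameter_related("x", block))             # False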
(next changed file; filename not shown)
@@ -188,7 +188,11 @@ def forward(ctx, *args, **kwargs):
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
ctx,
seed_var,
seed_var_dims_mapping,
process_mesh,
chunk_id=op_dist_attr.chunk_id,
)

# adopt for recompute
@@ -205,7 +209,11 @@ def forward(ctx, *args, **kwargs):
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
seed_op,
process_mesh,
seed_var_dims_mapping,
ctx,
chunk_id=op_dist_attr.chunk_id,
)

# modify dropout op
(next changed file; filename not shown)
@@ -181,13 +181,16 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
intermediate_var_0,
Ids_var_dist_attr.dims_mapping,
Ids_var_dist_attr.process_mesh,
chunk_id=Ids_var_dist_attr.chunk_id,
)
set_var_dist_attr(
ctx,
xshape_var,
[-1] + list(Ids_var_dist_attr.dims_mapping),
Ids_var_dist_attr.process_mesh,
chunk_id=Ids_var_dist_attr.chunk_id,
)
# rename src_op's input
src_op._rename_input(Ids_var.name, intermediate_var_0.name)
op_dist_attr.del_input_dist_attr(Ids_var.name)
op_dist_attr.set_input_dist_attr(
@@ -198,6 +201,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
new_op_dist_attr.impl_type = "default"
new_op_dist_attr.impl_idx = 0
new_op_dist_attr.chunk_id = Ids_var_dist_attr.chunk_id
new_op_dist_attr.set_input_dims_mapping(
Ids_var.name, Ids_var_dist_attr.dims_mapping
)
@@ -530,6 +534,7 @@ def forward(ctx, *args, **kwargs):
op_dist_attr.process_mesh,
out_var_dist_attr,
ctx,
chunk_id=op_dist_attr.chunk_id,
)

# param initialization sync