
[Auto Parallel] Update Gradient Synchronization in Static Mode (#59057)
* completion bw partial

* debug

* bugfix

* insert param grad allreduce by partial

* reorder allreduce for opt

* fix typos

* add grad sync unittest

* sp unittest

* fixed unittest
JZ-LIANG authored Dec 4, 2023
1 parent 05a3536 commit 7e5f101
Showing 5 changed files with 736 additions and 80 deletions.
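In short: backward ops listed in _gradient_sync_by_partial_ops now have their parameter gradients marked as Partial, and gradient synchronization allreduces each such gradient over one communication group per partial mesh dimension, falling back to the old single data-parallel group for ops not yet adapted. A minimal sketch of that dispatch follows (illustrative names such as pick_sync_groups and plan_sync_ops, plain dicts for groups; not the PaddlePaddle API):

# Illustrative sketch (not PaddlePaddle API): partial-aware backward ops get
# one allreduce group per partial mesh dimension; all other ops keep the
# legacy single data-parallel group. Each c_allreduce_sum is followed by a
# scale op of 1/group_size when gradient averaging is enabled.

PARTIAL_SYNC_OPS = {
    "matmul_v2_grad",
    "elementwise_add_grad",
    "layer_norm_grad",
    "lookup_table_v2_grad",
}


def pick_sync_groups(op_type, partial_groups, dp_group):
    """Choose the communication groups for one backward op's parameter grads."""
    if op_type in PARTIAL_SYNC_OPS:
        return partial_groups  # one group per partial mesh dim
    return [dp_group] if dp_group is not None else []


def plan_sync_ops(groups, gradient_scale=True):
    """List the (op_type, attrs) pairs that would be appended for one grad."""
    ops = []
    for group in groups:
        ops.append(("c_allreduce_sum", {"ring_id": group["id"]}))
        if gradient_scale:
            ops.append(("scale", {"scale": 1.0 / len(group["ranks"])}))
    return ops


if __name__ == "__main__":
    dp_group = {"id": 0, "ranks": [0, 1]}
    groups = pick_sync_groups("relu_grad", partial_groups=[], dp_group=dp_group)
    print(plan_sync_ops(groups))
    # [('c_allreduce_sum', {'ring_id': 0}), ('scale', {'scale': 0.5})]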
86 changes: 86 additions & 0 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -32,6 +32,7 @@
find_compatible_distributed_operator_impls,
find_distributed_operator_impl_container,
)
from .operators.common import _gradient_sync_by_partial_ops
from .process_group import get_world_process_group
from .utils import (
__no_shape_var_type__,
@@ -166,6 +167,7 @@ def _update_op_dims_mapping_and_distoperatorimpl(
dist_op.serial_op.type, dist_op_container.type
)
)

updated = dist_op_container.update_dims_mapping(dist_op)
changed = updated or changed
        # TODO(ljz) remove the below code once we introduce general reshard to replace specific distopimpls
@@ -202,6 +204,7 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False

tensor_dims_mapping = tensor_dist_attr.dims_mapping
if fwd:
dims_mapping_list = []
@@ -472,6 +475,7 @@ def _update_dims_mapping_for_special(self):
def _update_dims_mapping(self):
# Complete dims_mapping for each node
reach_fix_point = False

while not reach_fix_point:
changed = False
for is_fwd in [True, False]:
@@ -496,6 +500,7 @@ def _update_dims_mapping(self):
graph_changed = self._update_dims_mapping_between_graphs()
if graph_changed:
changed = True

if changed:
reach_fix_point = False
else:
@@ -1509,6 +1514,87 @@ def _complete_grad_op_with_forward_op(forward_op, grad_op, vars):

grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx

            # infer partial status for backward ops
def infer_backward_op_partial_status(
vars, grad_op, grad_op_dist_attr
):
                # NOTE Composite ops in static mode may perform an implicit reduction over broadcast axes when calculating a parameter's gradient.
                # Those implicit reductions hinder normal Partial inference, so we need a special method to handle them.
param_grads = []
activation_grad = None
broadcast_axis_indies = []
if (
grad_op.type == "matmul_v2_grad"
and len(grad_op.output("Y@GRAD")) > 0
):
activation_grad = grad_op.input("Out@GRAD")[0]
param_grads.extend(grad_op.output("Y@GRAD"))
act_ndim = len(vars[activation_grad].shape)
param_ndim = len(vars[grad_op.output("Y@GRAD")[0]].shape)
# TODO handle case where trans_x or trans_y is true
                    # NOTE we regard axis m as a broadcast axis since it is the contracting axis when calculating the param grad.
if param_ndim <= 2:
if act_ndim > 1:
broadcast_axis_indies = list(range(act_ndim - 1))
elif act_ndim > param_ndim:
broadcast_axis_indies = list(
range(act_ndim - param_ndim)
)
elif grad_op.type == "elementwise_add_grad":
activation_grad = grad_op.input("Out@GRAD")[0]
param_grads.extend(grad_op.output("Y@GRAD"))
param_var = grad_op.input("Y")[0]
broadcast_axis_indies = list(
range(
len(vars[activation_grad].shape)
- len(vars[param_var].shape)
)
)
elif grad_op.type == "layer_norm_grad":
activation_grad = grad_op.input("Y@GRAD")[0]
param_grads.extend(grad_op.output("Bias@GRAD"))
param_grads.extend(grad_op.output("Scale@GRAD"))
begin_norm_axis = int(grad_op.attr("begin_norm_axis"))
broadcast_axis_indies = list(range(begin_norm_axis))
elif grad_op.type == "lookup_table_v2_grad":
activation_grad = grad_op.input("Out@GRAD")[0]
param_grads.extend(grad_op.output("W@GRAD"))
broadcast_axis_indies = list(
range(len(vars[activation_grad].shape) - 1)
)
else:
raise NotImplementedError(
f"Backward Partial is not adapted for {str(grad_op)}"
)

                # resolve partial
                # NOTE We set the Partial status in op_dist_attr instead of tensor_dist_attr,
                # since the Partial will be resharded to Replicated immediately after the op output in static mode.
if len(param_grads) > 0:
activation_grad_dims_mapping = (
grad_op_dist_attr.get_input_dims_mapping(
activation_grad
)
)
for axis in broadcast_axis_indies:
if activation_grad_dims_mapping[axis] != -1:
partial_dim = activation_grad_dims_mapping[axis]
for p_grad_name in param_grads:
p_grad_dist_attr = (
grad_op_dist_attr.get_output_dist_attr(
p_grad_name
)
)
p_grad_dist_attr._set_partial_dims(
[partial_dim]
)

if grad_op.type in _gradient_sync_by_partial_ops:
infer_backward_op_partial_status(
vars, grad_op, grad_op_dist_attr
)

self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
)
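The matmul_v2_grad branch above marks Y@GRAD as Partial along every mesh dimension that shards a broadcast axis of Out@GRAD, because those axes are contracted when the weight gradient is computed; each shard therefore only holds a partial sum. A small numpy check of that reasoning (independent of Paddle; shapes are hypothetical):

# Why the weight grad is only a partial sum on each rank when a broadcast
# axis (here, the batch axis) of the activation grad is sharded.
import numpy as np

np.random.seed(0)
B, H_in, H_out = 8, 4, 3
x = np.random.randn(B, H_in)        # forward input
dy = np.random.randn(B, H_out)      # activation grad (Out@GRAD)

# Full weight grad for y = x @ w: dW = x^T @ dy, contracting the batch axis.
dw_full = x.T @ dy

# Shard the batch axis across two ranks: each local dW sums only its slice
# of the contracted axis, i.e. it is Partial along that mesh dimension.
dw_rank0 = x[:4].T @ dy[:4]
dw_rank1 = x[4:].T @ dy[4:]

# An allreduce-sum over the mesh dim that shards the batch axis restores dW.
assert np.allclose(dw_rank0 + dw_rank1, dw_full)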
170 changes: 122 additions & 48 deletions python/paddle/distributed/auto_parallel/static/operators/common.py
@@ -48,6 +48,14 @@
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}

_gradient_sync_by_partial_ops = [
"matmul_v2_grad",
"elementwise_add_grad",
"layer_norm_grad",
"lookup_table_v2_grad",
# "conv",
]


class ParallelMode:
"""
@@ -438,7 +446,6 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator, which might need gradient synchronization.
        act_grad_names (list): list of input activation grad variable names of the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
        rank (int): global rank index of the current process.
"""
dp_group = None
@@ -466,11 +473,13 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
)
dp_group = new_process_group(group_ranks)
break

return dp_group
if dp_group is not None:
return [dp_group]
else:
return []


def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
"""
    insert the allreduce and scale ops for gradients of model
    parameters for the operator in data parallelism.
@@ -485,55 +494,113 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
process_mesh = op_dist_attr.process_mesh
dist_op_context = dist_ctx.dist_op_context
main_block = dist_op_context.work_block
dp_degree = len(dp_group.ranks)

for var_name in allreduce_var_names:
added_ops = []
grad_var = main_block.var(var_name)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [grad_var]},
outputs={'Out': [grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward,
},
)
allreduce_op._set_attr('op_namescope', '/' + ParallelMode.DataParallel)
added_ops.append(allreduce_op)

if dist_ctx.gradient_scale:
scale_op = main_block.append_op(
type='scale',
inputs={'X': grad_var},
outputs={'Out': grad_var},
attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward},

for group in groups:
group_size = len(group.ranks)

for var_name in allreduce_var_names:
added_ops = []
grad_var = main_block.var(var_name)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [grad_var]},
outputs={'Out': [grad_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward,
},
)
allreduce_op._set_attr(
'op_namescope', '/' + ParallelMode.DataParallel
)
scale_op._set_attr('op_namescope', '/' + ParallelMode.DataParallel)
added_ops.append(scale_op)
added_ops.append(allreduce_op)

if dist_ctx.gradient_scale:
scale_op = main_block.append_op(
type='scale',
inputs={'X': grad_var},
outputs={'Out': grad_var},
attrs={
'scale': 1.0 / group_size,
OP_ROLE_KEY: OpRole.Backward,
},
)
scale_op._set_attr(
'op_namescope', '/' + ParallelMode.DataParallel
)
added_ops.append(scale_op)

dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert (
dims_mapping is not None
), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for new_op in added_ops:
new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = process_mesh
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)
dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert (
dims_mapping is not None
), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for new_op in added_ops:
new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = process_mesh
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)


def get_partial_groups(dist_ctx, op, out_grad_names, rank):
"""
    deduce the partial communication groups for the current operator's output vars.
Args:
dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator, which might need gradient synchronization.
        out_grad_names (list): list of the output parameter grad variable names of the current operator.
        rank (int): global rank index of the current process.
"""
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
mesh_shape = process_mesh.shape

groups = []

partial_dims = None
for var_name in out_grad_names:
var_dist_attr = op_dist_attr.get_output_dist_attr(var_name)
if partial_dims is None:
partial_dims = var_dist_attr._partial_dims()
else:
assert (
partial_dims == var_dist_attr._partial_dims()
), "Partial dims of outputs {} of op [{}] is not consistent".format(
out_grad_names, op.type
)

partial_dims = list(partial_dims)
partial_dims.sort()

# FIXME Hack for Pipeline Parallelism where the current operator
    # does not belong to the mesh the current rank belongs to.
if rank not in process_mesh.process_ids:
rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

for dim in partial_dims:
if mesh_shape[dim] > 1:
group_ranks = _get_comm_group(
process_mesh.process_ids,
process_mesh.shape,
dim,
rank,
)
groups.append(new_process_group(group_ranks))

return groups


def gradient_synchronization(
dist_ctx, op, act_grad_names, out_grad_names, rank
):
"""
conduct the allreudce and scaling(dp size)for gradients of model
parameters for operator in data parallelism.
    conduct the allreduce and scaling for gradients of model
    parameters for operators in parallel training.
Args:
dist_ctx (DistributedContext): dist context.
@@ -553,12 +620,19 @@ def gradient_synchronization(
):
return

dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
if op.type in _gradient_sync_by_partial_ops:
sync_groups = get_partial_groups(dist_ctx, op, out_grad_names, rank)
    # NOTE we keep the following old branch to support operators (e.g. fused operators) that haven't been adapted for partial inferspmd,
    # and will remove this branch once all operators are adapted for partial inferspmd.
else:
sync_groups = get_data_parallel_group(
dist_ctx, op, act_grad_names, rank
)

if not dp_group:
if len(sync_groups) < 1:
return

sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)
sync_and_scale_gradients(dist_ctx, op, sync_groups, out_grad_names)


def is_data_parallel_scale_op(op):
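For each partial mesh dimension of the output grads, get_partial_groups builds the communication group of ranks that differ from the current rank only along that dimension (via _get_comm_group). A simplified sketch of that deduction on a 2 x 2 process mesh (assumed semantics; numpy slicing stands in for the real helper):

# Simplified sketch of deducing one comm group per partial mesh dim
# (assumed semantics; the real code goes through _get_comm_group).
import numpy as np


def partial_groups(mesh, rank, partial_dims):
    mesh = np.asarray(mesh)
    # mesh coordinate of the current rank
    coord = tuple(int(axis_idx[0]) for axis_idx in np.where(mesh == rank))
    groups = []
    for dim in sorted(partial_dims):
        if mesh.shape[dim] <= 1:
            continue                     # nothing to synchronize on this dim
        index = list(coord)
        index[dim] = slice(None)         # vary only this mesh dimension
        groups.append(mesh[tuple(index)].tolist())
    return groups


if __name__ == "__main__":
    mesh = [[0, 1], [2, 3]]              # 2 x 2 process mesh
    print(partial_groups(mesh, rank=1, partial_dims=[0, 1]))
    # [[1, 3], [0, 1]] -> one allreduce over mesh dim 0, one over mesh dim 1

Each such group then receives its own c_allreduce_sum followed by a scale of 1/len(group.ranks), as in sync_and_scale_gradients above.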
