[Auto Parallel] Sharding Pass #38502

Merged

Changes from 5 commits

1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -45,6 +45,7 @@ message ShardingConfig {
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
}

message HybridConfig {
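Note (illustrative, not part of the diff): the new `stage` field can be reached from Python through fleet's `DistributedStrategy.sharding_configs`, whose keys are validated against this ShardingConfig message. The sketch below assumes ZeRO-style semantics for the stage values (1 = optimizer-state sharding, 2 = plus gradient sharding, 3 = plus parameter sharding); the PR only adds the field, so those semantics are an assumption here.

```python
# Minimal sketch, assuming fleet's DistributedStrategy maps sharding_configs keys
# onto the ShardingConfig proto fields shown above. The stage values follow the
# usual ZeRO convention, which is an assumption and not stated in this diff.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 2,  # existing ShardingConfig field
    "stage": 2,            # new field added in this diff (default = 1)
}
```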
9 changes: 9 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -15,6 +15,7 @@
from ..dist_attribute import OperatorDistributedAttribute

_g_distributed_operator_impl_registries = {}
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale'}


class DistributedOperatorImplContainer:
@@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
return best_compatible_impl, idx


def is_parameter_related(varname, block):
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname)
var = block.var(varname)
return var.is_parameter


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block.var(src_var.name).shape
var_topoloy = src_var_dist_attr.process_mesh.topology
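Note (illustrative, not part of the diff): `is_parameter_related` exists because amp may feed ops a casted copy of a parameter named `<param>.cast_fp16` (or `.cast_fp32`), which is not itself a Parameter in the block; stripping the `.cast_fp` suffix before the lookup maps the name back to the original parameter. The self-contained sketch below uses hypothetical MiniVar/MiniBlock stand-ins for fluid's Variable/Block only to make that rule runnable.

```python
# Stand-ins for fluid Variable/Block; only has_var/var/is_parameter are mimicked.
from dataclasses import dataclass

@dataclass
class MiniVar:
    name: str
    is_parameter: bool

class MiniBlock:
    def __init__(self, variables):
        self._vars = {v.name: v for v in variables}

    def has_var(self, name):
        return name in self._vars

    def var(self, name):
        return self._vars[name]

def is_parameter_related(varname, block):
    # "linear_0.w_0.cast_fp16" -> "linear_0.w_0" before the parameter lookup
    if ".cast_fp" in varname:
        varname = varname[:varname.index(".cast_fp")]
    assert block.has_var(varname)
    return block.var(varname).is_parameter

block = MiniBlock([MiniVar("linear_0.w_0", True), MiniVar("tmp_1", False)])
assert is_parameter_related("linear_0.w_0.cast_fp16", block)  # casted parameter
assert not is_parameter_related("tmp_1", block)               # activation
```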
@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -183,8 +183,8 @@ def backward(ctx, *args, **kwargs):
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not main_block.var(
varname).is_parameter:
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):

# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.process_mesh
@@ -210,8 +210,8 @@ def backward(ctx, *args, **kwargs):
allreduce_vars = []
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and main_block.var(
varname).is_parameter:
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
61 changes: 31 additions & 30 deletions python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -16,7 +16,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -26,7 +26,7 @@
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
@@ -283,34 +283,35 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# param initialization sync
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping

# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)

startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()
if Weight_var.is_parameter:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping

# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)

startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()

@staticmethod
def backward(ctx, *args, **kwargs):
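Note (illustrative, not part of the diff): the broadcast loop above synchronizes a freshly initialized weight across every mesh axis on which it is not sharded. The NumPy sketch below is a hypothetical stand-in for `_get_comm_group` and only illustrates that group-construction rule; the real helper in auto_parallel.utils may differ in details.

```python
# Hypothetical stand-in for _get_comm_group: ranks sharing all mesh coordinates
# with `rank_id` except along `axis` form one broadcast group.
import numpy as np

def comm_group_along_axis(processes, topology, axis, rank_id):
    mesh = np.array(processes).reshape(topology)
    coord = list(zip(*np.where(mesh == rank_id)))[0]
    index = list(coord)
    index[axis] = slice(None)              # free the chosen mesh axis
    return mesh[tuple(index)].ravel().tolist()

processes = list(range(8))                 # hypothetical 2x4 process mesh
topology = [2, 4]
dims_mapping = [0, -1]                     # 2-D weight, dim 0 sharded on mesh axis 0

for axis, size in enumerate(topology):
    if size <= 1 or axis in dims_mapping:
        continue                           # sharded (or trivial) axis: no sync needed
    print(axis, comm_group_along_axis(processes, topology, axis, rank_id=5))
    # axis 1 -> [4, 5, 6, 7]: ranks holding replicas of rank 5's weight shard
```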
11 changes: 7 additions & 4 deletions python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -18,7 +18,7 @@
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0])

assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
assert not is_parameter_related(
X_var.name, main_block
), "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name)

Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
@@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Y_var_partitioned = True
break

if Y_var.is_parameter and Y_var_partitioned:
if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul
@@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)

if need_gradient_allreduce and Y_var.is_parameter:
if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
@@ -444,6 +446,7 @@ def is_auto_compatible(self, dist_op):
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False

if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
42 changes: 24 additions & 18 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -27,6 +27,7 @@
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
@@ -139,30 +140,34 @@ def _generate_backward(self, main_program, startup_program, loss,

def _apply_optimize(self, main_program, startup_program, params_grads):

if self._dist_strategy.sharding:
auto_parallel_sharding_pass = new_pass(
"auto_parallel_sharding_pass", self._dist_strategy)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)

if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)

else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(self._optimizer).apply_gradients(
params_grads)

# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)

return optimize_ops

def _apply_post_optimization_passed(self, main_program, startup_program,
rank, params_grads):

# apply amp forward pass
[Review] sneaxiy (Collaborator), Dec 28, 2021:
What is this comment for? Is it a TODO or a wrong comment? The code that follows is sharding, not amp. If it is a TODO, please add TODO(owner) at the beginning of the comment.

[Reply] JZ-LIANG (Contributor, Author), Dec 29, 2021:
It is essentially a TODO marking where the amp pass will live in the future. The final goal is for all optimization passes to be applied within this function after the auto-parallel graph partition, and we will get there over several updates. The final order will be graph_partition -> amp -> recompute -> sharding -> gradient_merge; at the moment it is implemented as amp -> recompute -> graph_partition -> sharding -> gradient_merge.
Fixed~

if self._dist_strategy.sharding:
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply(
[main_program], [startup_program], self._pass_context)

# apply recompute forward pass
[Review] Collaborator:
Same as above.

[Reply] JZ-LIANG (Contributor, Author):
Same as the reply above.

if self._dist_strategy.gradient_merge:
pass
[Review] Collaborator:
If this code is not implemented yet, remove it for now.

[Reply] JZ-LIANG (Contributor, Author):
Fixed.


def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
serial_main_program = self._main_program.clone()
@@ -203,7 +208,8 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)

self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
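Note (illustrative, not part of the diff): a condensed sketch of the post-optimization flow added above, using the `new_pass`/`apply` calls exactly as they appear in the diff. The helper name and the locally created pass context are assumptions for illustration; in the parallelizer the existing `self._pass_context` is reused and `strategy` is the fleet DistributedStrategy.

```python
# Minimal sketch of applying the auto_parallel_sharding pass after partition and
# reshard, mirroring _apply_post_optimization_passed above.
from paddle.distributed.passes import new_pass, PassContext

def apply_post_optimization_passes(main_prog, startup_prog, dist_context,
                                   params_grads, rank, strategy):
    pass_context = PassContext()
    if strategy.sharding:
        config = dict(strategy.sharding_configs)
        config["dist_context"] = dist_context
        config["params_grads"] = params_grads
        config["global_rank"] = rank
        sharding_pass = new_pass("auto_parallel_sharding", config)
        # Programs are passed as lists, matching the pass framework's signature.
        sharding_pass.apply([main_prog], [startup_prog], pass_context)
    return pass_context
```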
56 changes: 30 additions & 26 deletions python/paddle/distributed/auto_parallel/partitioner.py
@@ -24,7 +24,8 @@
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]

@@ -102,22 +103,17 @@ def partition_startup_program(self, serial_main_program,
partitioned_startup_prog = fluid.Program()
ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block()
param2shape = {}
var2shape = {}
temp_varname_map = {}

# tensors
for var in serial_startup_program.list_vars():
if isinstance(var, Parameter):
# TODO if var not belong to this rank, should be filtered
serial_main_var = ref_block.var(var.name)
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_main_var)
target_shape = _get_dist_shape(serial_main_var, dist_attr)
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
_partition_parameter(self._dist_context, serial_main_var,
target_block, new_name, target_shape)
param2shape[new_name] = target_shape
assert var.persistable
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
target_shape = _partition_var(self._dist_context, ref_block,
target_block, var.name, new_name)
var2shape[new_name] = target_shape

# ops
for op in serial_startup_program.global_block().ops:
@@ -128,14 +124,14 @@ def partition_startup_program(self, serial_main_program,
) == 1, "initializer should output only ONE variable, but got [{}]".format(
str(op.desc))
assert temp_varname_map[output_vars[
0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
0]] in var2shape, "try to initialize [{}] which is not a persistable var".format(
output_vars[0])
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0],
temp_varname_map[output_vars[0]])
new_op_desc._set_attr("shape",
param2shape[temp_varname_map[output_vars[0]]])
var2shape[temp_varname_map[output_vars[0]]])
target_block._sync_with_cpp()

# set distribute atrribute
@@ -211,7 +207,6 @@ def partition_main_program(self, serial_main_program, params_and_grads):
**koutputs)

elif is_backward_op(op):
print(str(op))
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
@@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
name=dst_varname,
persistable=True,
stop_gradient=True)
target_shape = None
else:
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr)
@@ -361,6 +357,7 @@
else:
_partition_intermediate_var(dist_context, src_var, dst_block,
dst_varname, target_shape)
return target_shape


def _get_dist_op_backward_implement(backward_op, dist_context,
@@ -371,25 +368,32 @@
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
dist_op = get_distributed_operator_impl_container(forward_op.type)

# TODO backward should have its own impl_idx
if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_ops.get_impl(forward_op_dist_attr.impl_idx)
return dist_op.get_impl(forward_op_dist_attr.impl_idx)

dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
# NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
return get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)

dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)


def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
dist_op = get_distributed_operator_impl_container(forward_op.type)

if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl(
if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
dist_attr.impl_idx)._forward_implemented:
return dist_ops.get_impl(dist_attr.impl_idx)
return dist_op.get_impl(dist_attr.impl_idx)

else:
dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
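Note (illustrative, not part of the diff): the backward lookup above falls back in three steps: the forward op's registered impl when it declares _backward_implemented, the op's own impl for backward-only ops such as check_finite_and_unscale (inserted by amp and therefore without a forward counterpart), and finally the default container. The sketch below restates that rule over a plain dict registry; the registry and impl indices are hypothetical stand-ins, simplified to index 0.

```python
# Simplified dispatch rule; the dict is a stand-in for the real impl registries.
BACKWARD_ONLY_DIST_OPS = {"check_finite_and_unscale"}

def resolve_backward_impl(backward_op_type, forward_info, registries):
    """forward_info = (forward_op_type, impl_idx, backward_implemented) or None."""
    if forward_info is not None:
        fwd_type, impl_idx, backward_implemented = forward_info
        if fwd_type in registries and impl_idx >= 0 and backward_implemented:
            return registries[fwd_type][impl_idx]
    if backward_op_type in BACKWARD_ONLY_DIST_OPS:
        return registries[backward_op_type][0]
    return registries["default"][0]

registries = {
    "lookup_table_v2": ["row_parallel_embedding_impl"],
    "check_finite_and_unscale": ["sharded_check_finite_impl"],
    "default": ["replicated_default_impl"],
}
print(resolve_backward_impl("lookup_table_v2_grad", ("lookup_table_v2", 0, True),
                            registries))                      # forward op's impl
print(resolve_backward_impl("check_finite_and_unscale", None, registries))
```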