[Auto Parallel] Sharding Pass (#38502)
* auto parallel sharding base

* chmod

* add unit test

* set unit test cmake dist label

* revise code according to review

* chmod
JZ-LIANG authored Dec 29, 2021
1 parent 9456170 commit e3faf34
Showing 13 changed files with 931 additions and 92 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -45,6 +45,7 @@ message ShardingConfig {
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
}

message HybridConfig {
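
For context, a minimal usage sketch (not part of this commit) showing how the new stage field could be driven from user code, assuming the fleet DistributedStrategy API; key names other than "stage" are taken from the existing ShardingConfig fields, and the stage semantics are assumed rather than documented in this diff:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# "stage" maps onto the proto field added above; higher stages are assumed to
# shard additional optimizer/gradient state.
strategy.sharding_configs = {
    "sharding_degree": 8,
    "stage": 2,
}
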
9 changes: 9 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -15,6 +15,7 @@
from ..dist_attribute import OperatorDistributedAttribute

_g_distributed_operator_impl_registries = {}
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale'}


class DistributedOperatorImplContainer:
@@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
return best_compatible_impl, idx


def is_parameter_related(varname, block):
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname)
var = block.var(varname)
return var.is_parameter


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block.var(src_var.name).shape
var_topoloy = src_var_dist_attr.process_mesh.topology
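
The new is_parameter_related helper strips an AMP cast suffix before looking the variable up, so cast copies of a parameter are still treated as parameter-related. A rough standalone sketch (module path taken from this diff; the program and variable names are illustrative):

import paddle
from paddle.distributed.auto_parallel.operators.common import is_parameter_related

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    out = paddle.static.nn.fc(x, size=16)  # creates a weight parameter, e.g. "fc_0.w_0"

block = main_prog.global_block()
print(is_parameter_related("x", block))         # False: plain data var
print(is_parameter_related("fc_0.w_0", block))  # True: the fc weight is a Parameter
# An AMP alias such as "fc_0.w_0.cast_fp16" would be mapped back to "fc_0.w_0" first.
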
@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -183,8 +183,8 @@ def backward(ctx, *args, **kwargs):
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not main_block.var(
varname).is_parameter:
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):

# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.process_mesh
@@ -210,8 +210,8 @@ def backward(ctx, *args, **kwargs):
allreduce_vars = []
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and main_block.var(
varname).is_parameter:
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
61 changes: 31 additions & 30 deletions python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -16,7 +16,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -26,7 +26,7 @@
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
@@ -283,34 +283,35 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# param initialization sync
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping

# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)

startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()
if Weight_var.is_parameter:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping

# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)

startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()

@staticmethod
def backward(ctx, *args, **kwargs):
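
The parameter-initialization sync above (now guarded by Weight_var.is_parameter) broadcasts the sharded embedding weight along every mesh axis the weight is not split on. A tiny standalone sketch of that axis selection, with made-up mesh values:

# Mesh of 2 x 4 ranks; dims_mapping says dim 0 of the weight is sharded on mesh
# axis 1 and dim 1 is replicated.
process_mesh_topology = [2, 4]
dims_mapping = [1, -1]

broadcast_axes = [
    axis for axis, size in enumerate(process_mesh_topology)
    if size > 1 and axis not in dims_mapping
]
# Only axis 0 qualifies, so ranks {0, 4}, {1, 5}, {2, 6}, {3, 7} would each form
# a c_broadcast group.
print(broadcast_axes)  # [0]
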
11 changes: 7 additions & 4 deletions python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -18,7 +18,7 @@
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0])

assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
assert not is_parameter_related(
X_var.name, main_block
), "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name)

Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
@@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Y_var_partitioned = True
break

if Y_var.is_parameter and Y_var_partitioned:
if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul
@@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)

if need_gradient_allreduce and Y_var.is_parameter:
if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
@@ -444,6 +446,7 @@ def is_auto_compatible(self, dist_op):
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False

if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
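
In the matmul backward path above, whether Y (the weight operand) is partitioned is read off its dims_mapping; a small illustrative check with invented values:

def is_partitioned(dims_mapping):
    # A tensor dim mapped to a mesh axis (>= 0) is sharded; -1 means replicated.
    return any(axis >= 0 for axis in dims_mapping)

row_parallel = [0, -1]   # first dim sharded: the "row parallel" branch in the diff
col_parallel = [-1, 0]   # second dim sharded (column-parallel case, not shown in this hunk)
replicated = [-1, -1]    # Y_var_partitioned stays False, plain grad path

assert is_partitioned(row_parallel) and is_partitioned(col_parallel)
assert not is_partitioned(replicated)
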
37 changes: 19 additions & 18 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -27,6 +27,7 @@
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
@@ -139,30 +140,29 @@ def _generate_backward(self, main_program, startup_program, loss,

def _apply_optimize(self, main_program, startup_program, params_grads):

if self._dist_strategy.sharding:
auto_parallel_sharding_pass = new_pass(
"auto_parallel_sharding_pass", self._dist_strategy)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)

if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)

else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(self._optimizer).apply_gradients(
params_grads)

# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)

return optimize_ops

def _apply_post_optimization_passed(self, main_program, startup_program,
rank, params_grads):

if self._dist_strategy.sharding:
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply(
[main_program], [startup_program], self._pass_context)

def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
serial_main_program = self._main_program.clone()
@@ -203,7 +203,8 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)

self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
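
With this change the sharding pass runs after partition and reshard instead of inside _apply_optimize. A hedged, simplified sketch of that hook written as a standalone helper (mirroring _apply_post_optimization_passed above; the new_pass and PassContext imports are the ones added in this diff):

from paddle.distributed.passes import new_pass, PassContext

def apply_sharding_after_partition(dist_main_prog, dist_startup_prog, rank,
                                   dist_params_grads, dist_context,
                                   sharding_configs, pass_context=None):
    """Simplified mirror of _apply_post_optimization_passed in this PR."""
    config = dict(sharding_configs)  # e.g. {"stage": 2, "sharding_degree": 8}
    config["dist_context"] = dist_context
    config["params_grads"] = dist_params_grads
    config["global_rank"] = rank
    sharding_pass = new_pass("auto_parallel_sharding", config)
    # The pass API takes lists of programs and rewrites them in place.
    sharding_pass.apply([dist_main_prog], [dist_startup_prog],
                        pass_context or PassContext())
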
56 changes: 30 additions & 26 deletions python/paddle/distributed/auto_parallel/partitioner.py
@@ -24,7 +24,8 @@
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]

@@ -102,22 +103,17 @@ def partition_startup_program(self, serial_main_program,
partitioned_startup_prog = fluid.Program()
ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block()
param2shape = {}
var2shape = {}
temp_varname_map = {}

# tensors
for var in serial_startup_program.list_vars():
if isinstance(var, Parameter):
# TODO if var not belong to this rank, should be filtered
serial_main_var = ref_block.var(var.name)
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_main_var)
target_shape = _get_dist_shape(serial_main_var, dist_attr)
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
_partition_parameter(self._dist_context, serial_main_var,
target_block, new_name, target_shape)
param2shape[new_name] = target_shape
assert var.persistable
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
target_shape = _partition_var(self._dist_context, ref_block,
target_block, var.name, new_name)
var2shape[new_name] = target_shape

# ops
for op in serial_startup_program.global_block().ops:
@@ -128,14 +124,14 @@ def partition_startup_program(self, serial_main_program,
) == 1, "initializer should output only ONE variable, but got [{}]".format(
str(op.desc))
assert temp_varname_map[output_vars[
0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
0]] in var2shape, "try to initialize [{}] which is not a persistable var".format(
output_vars[0])
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0],
temp_varname_map[output_vars[0]])
new_op_desc._set_attr("shape",
param2shape[temp_varname_map[output_vars[0]]])
var2shape[temp_varname_map[output_vars[0]]])
target_block._sync_with_cpp()

# set distribute atrribute
@@ -211,7 +207,6 @@ def partition_main_program(self, serial_main_program, params_and_grads):
**koutputs)

elif is_backward_op(op):
print(str(op))
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
@@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
name=dst_varname,
persistable=True,
stop_gradient=True)
target_shape = None
else:
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr)
@@ -361,6 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
else:
_partition_intermediate_var(dist_context, src_var, dst_block,
dst_varname, target_shape)
return target_shape


def _get_dist_op_backward_implement(backward_op, dist_context,
@@ -371,25 +368,32 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
dist_op = get_distributed_operator_impl_container(forward_op.type)

# TODO backward should have its own impl_idx
if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_ops.get_impl(forward_op_dist_attr.impl_idx)
return dist_op.get_impl(forward_op_dist_attr.impl_idx)

dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
# NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
return get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)

dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)


def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
dist_op = get_distributed_operator_impl_container(forward_op.type)

if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl(
if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
dist_attr.impl_idx)._forward_implemented:
return dist_ops.get_impl(dist_attr.impl_idx)
return dist_op.get_impl(dist_attr.impl_idx)

else:
dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
(Remaining changed files not shown.)
