
Commit

revise code according to review
JZ-LIANG committed Dec 29, 2021
1 parent 7becc2c commit d8d7c91
Showing 2 changed files with 5 additions and 10 deletions.
5 changes: 0 additions & 5 deletions python/paddle/distributed/auto_parallel/parallelizer.py
100644 → 100755
@@ -153,7 +153,6 @@ def _apply_optimize(self, main_program, startup_program, params_grads):
     def _apply_post_optimization_passed(self, main_program, startup_program,
                                         rank, params_grads):
 
-        # apply amp forward pass
         if self._dist_strategy.sharding:
             config = copy.deepcopy(self._dist_strategy.sharding_configs)
             config["dist_context"] = self._dist_context
@@ -164,10 +163,6 @@ def _apply_post_optimization_passed(self, main_program, startup_program,
             auto_parallel_sharding_pass.apply(
                 [main_program], [startup_program], self._pass_context)
 
-        # apply recompute forward pass
-        if self._dist_strategy.gradient_merge:
-            pass
-
     def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
         completed_main_program = None
         serial_main_program = self._main_program.clone()
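Note: the hunk above keeps only the sharding branch of _apply_post_optimization_passed and drops the dead gradient_merge placeholder. A rough sketch of the pass-application pattern it relies on, assuming Paddle's pass framework (new_pass) as imported elsewhere in this file; the wrapper function and the params_grads config key are illustrative assumptions, not the exact upstream code.

# Illustrative sketch of the pattern in the hunk above (assumes Paddle's
# pass framework; the wrapper function and the params_grads key are
# assumptions, not necessarily the exact upstream code).
import copy

from paddle.distributed.passes import new_pass


def apply_sharding_pass(dist_strategy, dist_context, params_grads,
                        main_program, startup_program, pass_context):
    # Build the pass config from the user strategy, then attach the
    # distributed context the pass needs to resolve shardings.
    config = copy.deepcopy(dist_strategy.sharding_configs)
    config["dist_context"] = dist_context
    config["params_grads"] = params_grads  # assumed key, for parity with other passes

    # Instantiate the registered pass by name and apply it in place to the
    # main/startup programs under the shared pass context.
    sharding_pass = new_pass("auto_parallel_sharding", config)
    sharding_pass.apply([main_program], [startup_program], pass_context)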
10 changes: 5 additions & 5 deletions python/paddle/distributed/passes/auto_parallel_sharding.py
100644 → 100755
@@ -101,7 +101,7 @@ def _collective_data_parallel_groups(self, main_block):
             if group is not None:
                 self.dp_groups.add(group)
 
-        # TODO allow more than one dp groups in network, support more general distribution
+        # TODO(JZ-LIANG) allow more than one dp groups in network, support more general distribution
         # genetated by auto search
         if len(self.dp_groups) != 1:
             raise NotImplementedError(
@@ -136,7 +136,7 @@ def _build_sharding_infos(self, params_grads):
             else:
                 sharding_group = dp_group
 
-            # TODO when support multiple dp groups in future, should group param and bind them to corresponding dp group
+            # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
             params_in_group = [p for p, g in params_grads]
             assert len(params_in_group) == len(set(
                 params_in_group)), "found duplicated param in params_grads"
@@ -192,7 +192,7 @@ def _shard_gradient_clip(self, main_block):
         if self.stage < 2:
             return
 
-        # TODO support calculate global norm with tensor parallelism
+        # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
         is_clip_grad_by_global_norm = False
         for idx, op in list(enumerate(main_block.ops)):
             if not _is_gradient_clip_op(op):
@@ -350,7 +350,7 @@ def _shard_gradient_synchronization(self, main_block):
             else:
                 op._set_attr("ring_id", self.outer_dp_group.id)
 
-        main_block._sync_with_cpp
+        main_block._sync_with_cpp()
 
     def _shard_parameter(self, main_block, startup_block):
 
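Note on the fix above: without the parentheses, main_block._sync_with_cpp only evaluates the bound-method object and discards it, so the Python-side block is never synced to the C++ program desc. A minimal, Paddle-free illustration:

# Minimal, Paddle-free illustration of the bug fixed above: naming a bound
# method without calling it is a silent no-op.
class Block:
    def __init__(self):
        self.synced = False

    def _sync_with_cpp(self):
        # In Paddle this would flush Python-side op/var edits to the C++
        # program desc; here we just record that it ran.
        self.synced = True


block = Block()

block._sync_with_cpp       # creates a method object, then drops it; nothing runs
assert block.synced is False

block._sync_with_cpp()     # actually executes the method
assert block.synced is True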
@@ -603,7 +603,7 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
         process_mesh = dist_attr.process_mesh
         input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
         mesh_shape = process_mesh.topology
-        # TODO replace with specific batch size dimension
+        # TODO(JZ-LIANG) replace with specific batch size dimension
         batch_size_axis = input_dim_mapping[0]
         if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
             group_ranks = _get_comm_group(process_mesh.processes,
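Note on the hunk above: batch_size_axis is the mesh axis that shards the input's first (assumed batch) dimension; when that axis has size > 1, the ranks that differ only along it form the data-parallel group. A standalone sketch of that idea (the helper below is hypothetical, not Paddle's _get_comm_group):

# Hypothetical helper (not Paddle's _get_comm_group) showing how a data-parallel
# group can be derived from a process mesh and the mesh axis that shards the
# batch dimension: keep every mesh coordinate of rank_id fixed except the
# batch axis, and collect the ranks reached by varying that axis.
import numpy as np


def data_parallel_group(processes, topology, batch_size_axis, rank_id):
    """`processes` is the flat list of global ranks laid out over `topology`."""
    mesh = np.array(processes).reshape(topology)
    coord = list(zip(*np.where(mesh == rank_id)))[0]  # rank_id's mesh coordinate
    index = list(coord)
    group = []
    for i in range(topology[batch_size_axis]):
        index[batch_size_axis] = i
        group.append(int(mesh[tuple(index)]))
    return group


# 2 x 4 mesh: axis 0 is data parallel, axis 1 is, say, tensor parallel.
print(data_parallel_group(list(range(8)), [2, 4], batch_size_axis=0, rank_id=5))
# -> [1, 5]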
