update doc string and clean up variable naming for 2D parallel #2709

Closed · wants to merge 1 commit
22 changes: 13 additions & 9 deletions torchrec/distributed/model_parallel.py
@@ -709,7 +709,9 @@ def __init__(
         )

         self._remap_sharding_plan(
-            plan, self._global_rank, world_size // sharding_group_size
+            plan=plan,
+            rank=self._global_rank,
+            num_nodes=world_size // sharding_group_size,
         )
         super().__init__(
             module,
@@ -733,7 +735,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         """
         Syncs the DMP weights across the allreduce (inter) process group

-        This method is called after each forward pass to synchronize the weights of the sharded modules.
+        This method is called after each train step to synchronize the weights of the sharded modules.
         It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
         which averages the weights across all processes in the inter-process group.
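
For context, a minimal sketch (not TorchRec's actual implementation) of what this per-step sync does conceptually: average each sharded-module weight across the replica (inter) process group. The names `sync_weights`, `sharded_weights`, and `replica_pg` are placeholders; the real code coalesces the collectives via `dist.AllreduceCoalescedOptions` rather than issuing one all-reduce per tensor.

```python
import torch
import torch.distributed as dist


def sync_weights(
    sharded_weights: list[torch.Tensor], replica_pg: dist.ProcessGroup
) -> None:
    # Average each weight across the replica (inter) process group.
    # A plain per-tensor all-reduce is used here for clarity.
    group_size = dist.get_world_size(group=replica_pg)
    for weight in sharded_weights:
        dist.all_reduce(weight, op=dist.ReduceOp.SUM, group=replica_pg)
        weight.div_(group_size)
```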

@@ -782,10 +784,10 @@ def _create_process_groups(
             replication process group, and allreduce process group.
         """
         peer_matrix = []
-        step = world_size // local_size
+        num_nodes = world_size // local_size

         for group_rank in range(world_size // local_size):
-            peers = [step * r + group_rank for r in range(local_size)]
+            peers = [num_nodes * r + group_rank for r in range(local_size)]
             peer_matrix.append(peers)

         mesh = DeviceMesh(
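
A quick worked example of the peer-matrix arithmetic above, using illustrative values (world_size=8, local_size=4, so num_nodes=2):

```python
world_size, local_size = 8, 4          # illustrative values
num_nodes = world_size // local_size   # 2

peer_matrix = []
for group_rank in range(world_size // local_size):
    peers = [num_nodes * r + group_rank for r in range(local_size)]
    peer_matrix.append(peers)

print(peer_matrix)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Each row strides the global ranks by `num_nodes`, which is the same stride `_remap_sharding_plan` applies below.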
@@ -805,7 +807,9 @@ def _create_process_groups(

         return mesh, sharding_pg, replica_pg

-    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+    def _remap_sharding_plan(
+        self, plan: ShardingPlan, rank: int, num_nodes: int
+    ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
         ShardingPlan is remapped inplace.
@@ -816,22 +820,22 @@ def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None
         Args:
             plan (ShardingPlan): The original sharding plan.
             global_rank (int): The global rank of the current process.
-            step (int): The number of nodes.
+            num_nodes (int): The number of nodes.
         """

-        group_start = rank % step
+        group_start = rank % num_nodes
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
                 for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * step + group_start)
+                    new_ranks.append(shard_rank * num_nodes + group_start)
                 param_sharding.ranks = new_ranks
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * step + group_start
+                            shard_rank = shard.placement._rank * num_nodes + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
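
A worked example of the remapping formula above, under the same illustrative setup (num_nodes=2) and a hypothetical plan whose shards were placed on logical ranks [0, 1, 2, 3]:

```python
num_nodes = 2                 # illustrative, matches the peer-matrix example
global_rank = 5               # the current process's global rank
group_start = global_rank % num_nodes  # 1

logical_ranks = [0, 1, 2, 3]  # ranks in the original (per-group) sharding plan
remapped = [r * num_nodes + group_start for r in logical_ranks]
print(remapped)  # [1, 3, 5, 7] -- one row of the peer matrix example above
```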