2D Parallelism in TorchRec (pytorch#2554)
Summary:
Pull Request resolved: pytorch#2554

In this diff we introduce a new parallelism strategy for scaling recommendation model training, called 2D parallelism. In this scheme we scale model parallelism through data parallelism, hence the 2D name.

Our new entry point, DMPCollection, subclasses DMP and is meant to be a drop-in replacement for integrating 2D parallelism into distributed training. By setting the total number of GPUs to train across and the number of GPUs to shard across locally (i.e., one sharding group), users can keep the same training loop while training over a larger number of GPUs.
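
For context, here is a minimal usage sketch (the planner setup, device selection, and the name unsharded_model are illustrative assumptions; only the DMPCollection constructor arguments and sync() come from this diff):

    import torch
    import torch.distributed as dist
    from torchrec.distributed.model_parallel import DMPCollection
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

    # assumes torch.distributed is already initialized over all ranks,
    # e.g. 2 nodes x 4 GPUs -> world_size = 8
    world_size = dist.get_world_size()
    local_size = 4  # GPUs one model replica is sharded across (one sharding group)
    device = torch.device(f"cuda:{torch.cuda.current_device()}")

    # plan against the local (sharding group) world size, broadcast to all ranks
    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=local_size, compute_device="cuda")
    )
    plan = planner.collective_plan(unsharded_model, pg=dist.GroupMember.WORLD)

    model = DMPCollection(
        module=unsharded_model,
        device=device,
        plan=plan,
        local_size=local_size,
        world_size=world_size,
        global_pg=dist.GroupMember.WORLD,
    )
    # train as usual; model.sync() all-reduces the replicated embedding weights
    # across sharding groups (e.g. called periodically between optimizer steps)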

The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth.

Example Use Case:
        Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
            - Group 0, DMP 0: [0, 2, 4, 6]
            - Group 1, DMP 1: [1, 3, 5, 7]

        Each group receives an identical sharding plan for their local world size and ranks.
        If we have one table sharded in each DMP, with one shard on each rank in the group,
        each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1.
        The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].
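
The group assignment above follows a stride of step = world_size // local_size; a short standalone snippet (not part of the diff) reproduces it:

    world_size, local_size = 8, 4
    step = world_size // local_size  # number of sharding groups (here 2)

    # sharding group g takes every step-th rank, starting at rank g
    sharding_groups = [[step * r + g for r in range(local_size)] for g in range(step)]
    # ranks holding copies of the same shard form a replication group
    replication_groups = [[i * step + r for r in range(step)] for i in range(local_size)]

    print(sharding_groups)     # [[0, 2, 4, 6], [1, 3, 5, 7]]
    print(replication_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]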

NOTE: We have to pass the global process group to the DDPWrapper; otherwise the optimizer will not be applied to some of the unsharded parameters, which results in numerically inaccurate results.

Differential Revision: D61643328
iamzainhuda authored and facebook-github-bot committed Nov 12, 2024
1 parent be4b9c7 commit 8ed3c32
Showing 14 changed files with 982 additions and 54 deletions.
1 change: 0 additions & 1 deletion torchrec/distributed/batched_embedding_kernel.py
@@ -1445,7 +1445,6 @@ def __init__(
fused_params = config.fused_params or {}
if "cache_precision" not in fused_params:
fused_params["cache_precision"] = weights_precision

self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=list(
56 changes: 56 additions & 0 deletions torchrec/distributed/comm.py
@@ -151,3 +151,59 @@ def intra_and_cross_node_pg(
dist.barrier()

return _INTRA_PG, _CROSS_PG


def intra_and_cross_node_pg_2D(
local_pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
backend: Optional[str] = None,
) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
"""
    Creates sub process groups (intra and cross node) under the 2D parallelism scheme
"""
if device is not None and device.type == "meta":
return None, None

    # we keep the names for backwards compatibility, but they are no longer "intra" and "cross" in the traditional sense under 2D parallelism
global _INTRA_PG # intra node process group
global _CROSS_PG # cross node process group

my_rank = dist.get_rank()

local_size = dist.get_world_size(local_pg) # Local replica group world size
world_size = dist.get_world_size() # Global world size
step = world_size // local_size
devices_per_node = get_local_size(world_size)

if _INTRA_PG is None:
for group_rank in range(step):
local_pg_peers = [step * r + group_rank for r in range(local_size)]
for group in range(len(local_pg_peers) // devices_per_node):
intra_pg_peers = local_pg_peers[
group * devices_per_node : (group + 1) * devices_per_node
]
curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers)
if my_rank in intra_pg_peers:
logger.warning(
f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}"
)
_INTRA_PG = curr_intra_pg
dist.barrier()

if _CROSS_PG is None:
for group_rank in range(step):
local_pg_peers = [step * r + group_rank for r in range(local_size)]
            for cross_group_rank in range(devices_per_node):
                # take one peer from each devices_per_node-sized chunk of the
                # sharding group peers (the transpose of the intra grouping above)
                cross_pg_peers = [
                    local_pg_peers[cross_group_rank + g * devices_per_node]
                    for g in range(len(local_pg_peers) // devices_per_node)
                ]
curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers)
if my_rank in cross_pg_peers:
logger.warning(
f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}"
)
_CROSS_PG = curr_cross_pg
dist.barrier()

return _INTRA_PG, _CROSS_PG
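
For intuition, a standalone snippet (an illustration, not code from this diff) shows how one sharding group's peers are chunked into "intra" groups and how the "cross" groups are the transpose of that chunking, for world_size = 16 (4 nodes x 4 GPUs) and local_size = 8:

    world_size, local_size, devices_per_node = 16, 8, 4
    step = world_size // local_size  # 2 sharding groups
    group_rank = 0
    local_pg_peers = [step * r + group_rank for r in range(local_size)]
    # -> [0, 2, 4, 6, 8, 10, 12, 14]

    num_chunks = len(local_pg_peers) // devices_per_node
    intra = [
        local_pg_peers[g * devices_per_node : (g + 1) * devices_per_node]
        for g in range(num_chunks)
    ]  # [[0, 2, 4, 6], [8, 10, 12, 14]]
    cross = [
        [local_pg_peers[i + g * devices_per_node] for g in range(num_chunks)]
        for i in range(devices_per_node)
    ]  # [[0, 8], [2, 10], [4, 12], [6, 14]]
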
16 changes: 15 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -65,6 +65,7 @@
QuantizedCommCodecs,
ShardedTensor,
ShardingEnv,
ShardingEnv2D,
ShardingType,
ShardMetadata,
)
@@ -155,6 +156,8 @@ def create_embedding_bag_sharding(
EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
]:
sharding_type = sharding_infos[0].param_sharding.sharding_type
is_2D_parallel: bool = isinstance(env, ShardingEnv2D)

if device is not None and device.type == "meta":
replace_placement_with_meta_device(sharding_infos)
if sharding_type == ShardingType.TABLE_WISE.value:
@@ -163,13 +166,15 @@
env,
device,
qcomm_codecs_registry=qcomm_codecs_registry,
is_2D_parallel=is_2D_parallel,
)
elif sharding_type == ShardingType.ROW_WISE.value:
return RwPooledEmbeddingSharding(
sharding_infos,
env,
device,
qcomm_codecs_registry=qcomm_codecs_registry,
is_2D_parallel=is_2D_parallel,
)
elif sharding_type == ShardingType.DATA_PARALLEL.value:
return DpPooledEmbeddingSharding(sharding_infos, env, device)
@@ -179,6 +184,7 @@
env,
device,
qcomm_codecs_registry=qcomm_codecs_registry,
is_2D_parallel=is_2D_parallel,
)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return CwPooledEmbeddingSharding(
@@ -187,6 +193,7 @@
device,
permute_embeddings=permute_embeddings,
qcomm_codecs_registry=qcomm_codecs_registry,
is_2D_parallel=is_2D_parallel,
)
elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value:
return TwCwPooledEmbeddingSharding(
@@ -202,6 +209,7 @@
env,
device,
qcomm_codecs_registry=qcomm_codecs_registry,
is_2D_parallel=is_2D_parallel,
)
else:
raise ValueError(f"Sharding type not supported {sharding_type}")
@@ -942,9 +950,14 @@ def _initialize_torch_state(self) -> None: # noqa
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
process_group=(
self._env.local_pg
if isinstance(self._env, ShardingEnv2D)
else self._env.process_group
),
)
)

def extract_sharded_kvtensors(
module: ShardedEmbeddingBagCollection,
@@ -967,6 +980,7 @@ def post_state_dict_hook(
prefix: str,
_local_metadata: Dict[str, Any],
) -> None:
print(f"in EBC post state dict hook!")
# Adjust dense MP
for (
table_name,
183 changes: 183 additions & 0 deletions torchrec/distributed/model_parallel.py
@@ -14,21 +14,29 @@

import torch
import torch.distributed as dist
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
SplitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as ddp_default_hooks,
)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.types import (
EnumerableShardingSpec,
ModuleSharder,
ShardedModule,
ShardingEnv,
ShardingEnv2D,
ShardingPlan,
)
from torchrec.distributed.utils import (
@@ -440,6 +448,7 @@ def state_dict(
prefix: str = "",
keep_vars: bool = False,
) -> Dict[str, Any]:
print(f"DMP state dict called! {destination=}")
state_dict = get_module(self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
@@ -599,3 +608,177 @@ def _reset_parameters(module: nn.Module) -> None:
for _, m in module.named_modules():
if hasattr(m, "reset_parameters"):
m.reset_parameters()


class DMPCollection(DistributedModelParallel):
"""
A wrapper around DistributedModelParallel that allows for multiple DMPs to be created and managed together.
This class implements a 2D parallelism model where a DMP is sharded over a subset of ranks.
The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node.
This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth.
Example Use Case:
Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
- Group 0, DMP 0: [0, 2, 4, 6]
- Group 1, DMP 1: [1, 3, 5, 7]
    Each group receives an identical sharding plan for its local world size and ranks.
If we have one table sharded in each DMP, with one shard on each rank in the group,
each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1.
The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].
Notes:
- DTensor must be used for state dict for checkpointing to work correctly.
    - The expected sharding plan should be sharded across local_size (the sharding group world size)
      and broadcast to all ranks (planner.collective_plan(..)).
"""

def __init__(
self,
module: nn.Module,
device: torch.device,
plan: ShardingPlan,
local_size: int,
world_size: int,
global_pg: dist.ProcessGroup,
sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
init_data_parallel: bool = True,
init_parameters: bool = True,
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
) -> None:
assert device.type == "cuda", "DMPCollection only supports CUDA"
self._pg: dist.ProcessGroup = global_pg
self._plan = plan
self._device = device
self._device_mesh: DeviceMesh = None
self._local_pg: dist.ProcessGroup = None # pyre-ignore[8]
self._inter_pg: dist.ProcessGroup = None # pyre-ignore[8]
self._global_rank: int = dist.get_rank(global_pg)

self._device_mesh, self._local_pg, self._inter_pg = self._create_process_groups(
global_rank=self._global_rank,
world_size=world_size,
local_size=local_size,
)
self._remap_sharding_plan(plan, self._global_rank, world_size // local_size)
super().__init__(
module,
ShardingEnv2D(
pg=self._pg,
local_pg=self._local_pg,
device_mesh=self._device_mesh,
),
device,
plan,
sharders,
init_data_parallel,
init_parameters,
data_parallel_wrapper,
)
# post DMP init, we group sharded modules for parameter sync
self._modules_to_sync: List[nn.Module] = self._group_sharded_modules()

def sync(self) -> None:
"""
Syncs the DMP weights across the allreduce (inter) process group
"""
assert self._inter_pg is not None, "inter_pg is not initialized!"
opts = dist.AllreduceCoalescedOptions()
# pyre-fixme[8]
opts.reduceOp = dist.ReduceOp.AVG
# pyre-fixme[8]
all_weights = [
w
for emb_kernel in self._modules_to_sync
for w in emb_kernel.split_embedding_weights()
]
handle = self._inter_pg.allreduce_coalesced(all_weights, opts=opts)
handle.wait()

def _create_process_groups(
self, global_rank: int, world_size: int, local_size: int
) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
peer_matrix = []
local_pg, inter_pg = None, None
rank = global_rank
step = world_size // local_size

my_group_rank = rank % step
for group_rank in range(world_size // local_size):
peers = [step * r + group_rank for r in range(local_size)]
backend = dist.get_backend(self._pg)
curr_pg = dist.new_group(backend=backend, ranks=peers)
peer_matrix.append(peers)
if my_group_rank == group_rank:
local_pg = curr_pg
dist.barrier()

my_inter_rank = rank // step
for inter_rank in range(local_size):
peers = [inter_rank * step + r for r in range(step)]
backend = dist.get_backend(self._pg)
curr_pg = dist.new_group(backend=backend, ranks=peers)
if my_inter_rank == inter_rank:
inter_pg = curr_pg
dist.barrier()

mesh = DeviceMesh(
device_type=self._device.type,
mesh=peer_matrix,
mesh_dim_names=("replicate", "shard"),
)

assert local_pg is not None, "local_pg is not initialized!"
assert inter_pg is not None, "inter_pg is not initialized!"
return mesh, local_pg, inter_pg

def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
"""
Remaps the sharding plan to the local replica process group ranks
ShardingPlan is remapped inplace.
        As an example, a ShardingPlan created for ranks [0, 2, 4, 6] is remapped to ranks [1, 3, 5, 7].
"""

group_start = rank % step
for key in plan.plan:
# pyre-ignore[16]
for _, param_sharding in plan.plan[key].items():
new_ranks = []
for shard_rank in param_sharding.ranks:
new_ranks.append(shard_rank * step + group_start)
param_sharding.ranks = new_ranks
if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
shards = param_sharding.sharding_spec.shards
if shards is not None:
for shard in shards:
shard_rank = shard.placement._rank * step + group_start
shard.placement = _remote_device(
f"rank:{shard_rank}/cuda:{shard_rank}"
)
return

def _group_sharded_modules(
self,
    ) -> List[nn.Module]:
        # TODO: look into moving the embedding sync into the forward call?
        # Post DMP init, save the embedding kernels
sharded_modules: List[nn.Module] = []

def _find_sharded_modules(
module: nn.Module,
) -> None:
if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
sharded_modules.append(module)
if isinstance(module, ShardedEmbeddingBagCollection):
for lookup in module._lookups:
_find_sharded_modules(lookup)
return
for _, child in module.named_children():
_find_sharded_modules(child)

_find_sharded_modules(self)
print(f"Found {sharded_modules=}")
return sharded_modules
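
To make the plan remapping concrete, a tiny standalone illustration (not part of the diff) of the rule used in _remap_sharding_plan, new_rank = shard_rank * step + group_start:

    step = 2                         # world_size // local_size
    local_plan_ranks = [0, 1, 2, 3]  # ranks the planner assigned for the local world size

    for group_start in range(step):  # group_start = global_rank % step
        print(group_start, [r * step + group_start for r in local_plan_ranks])
    # 0 [0, 2, 4, 6]   <- sharding group / DMP 0
    # 1 [1, 3, 5, 7]   <- sharding group / DMP 1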