From 05be002d2e7b8e222420daca7f0f86e9d2f41dc5 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Thu, 21 Nov 2024 10:06:41 -0800 Subject: [PATCH] TorchRec 2D Parallel (#2554) Summary: In this diff we introduce a new parallelism strategy for scaling recommendation model training called 2D parallel. In this case, we scale model parallel through data parallel, hence, the 2D name. Our new entry point, DMPCollection, subclasses DMP and is meant to be a drop in replacement to integrate 2D parallelism in distributed training. By setting the total number of GPUs to train across and the number of GPUs to locally shard across (aka one replication group), users can train their models in the same training loop but now over a larger number of GPUs. The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. Under this scheme the supported sharding types are RW, CW, and GRID. TWRW is not supported due to no longer being able to take advantage of the intra node bandwidth in the 2D scheme. Example Use Case: Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be: - Group 0, DMP 0: [0, 2, 4, 6] - Group 1, DMP 1: [1, 3, 5, 7] Each group receives an identical sharding plan for their local world size and ranks. If we have one table sharded in each DMP, with one shard on each rank in the group, each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1. The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7]. NOTE: We have to pass global process group to the DDPWrapper otherwise some of the unsharded parameters will not get optimizer applied to them. will result in numerically inaccurate results Reviewed By: dstaay-fb Differential Revision: D61643328 --- .../distributed/batched_embedding_kernel.py | 5 +- torchrec/distributed/comm.py | 110 +++++ torchrec/distributed/embeddingbag.py | 14 +- torchrec/distributed/model_parallel.py | 270 ++++++++++++ torchrec/distributed/sharding/cw_sharding.py | 16 +- .../distributed/sharding/grid_sharding.py | 59 ++- torchrec/distributed/sharding/rw_sharding.py | 18 +- torchrec/distributed/sharding/tw_sharding.py | 30 +- .../test_utils/test_model_parallel.py | 4 + .../distributed/test_utils/test_sharding.py | 61 ++- .../distributed/tests/test_2d_sharding.py | 404 ++++++++++++++++++ torchrec/distributed/types.py | 47 ++ 12 files changed, 991 insertions(+), 47 deletions(-) create mode 100644 torchrec/distributed/tests/test_2d_sharding.py diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index ddac29c44..bc9348aad 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -46,7 +46,7 @@ PartiallyMaterializedTensor, ) from torch import nn -from torchrec.distributed.comm import get_local_rank, get_local_size +from torchrec.distributed.comm import get_local_rank, get_node_group_size from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, ) @@ -303,7 +303,7 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata( ) # for grid sharding, the row dimension is replicated CW shard times grid_shard_nodes = ( - len(table_global_shards_metadata) // get_local_size() + len(table_global_shards_metadata) // get_node_group_size() if is_grid_sharded else 1 ) @@ -1445,7 +1445,6 @@ def __init__( fused_params = config.fused_params or {} if "cache_precision" not in fused_params: fused_params["cache_precision"] = weights_precision - self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( SplitTableBatchedEmbeddingBagsCodegen( embedding_specs=list( diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py index f3edd56e7..2b445fe58 100644 --- a/torchrec/distributed/comm.py +++ b/torchrec/distributed/comm.py @@ -13,6 +13,7 @@ import torch import torch.distributed as dist +from torchrec.distributed.types import ShardingEnv2D logger: logging.Logger = logging.getLogger(__name__) @@ -20,6 +21,11 @@ _INTRA_PG: Optional[dist.ProcessGroup] = None _CROSS_PG: Optional[dist.ProcessGroup] = None +# For 2D parallel +_INTRA_PG_2D: Optional[dist.ProcessGroup] = None +_CROSS_PG_2D: Optional[dist.ProcessGroup] = None +_NODE_GROUP_SIZE_2D: Optional[int] = None + def _env2int(env_list: List[str], default: int = -1) -> int: for e in env_list: @@ -54,6 +60,15 @@ def get_local_size(world_size: Optional[int] = None) -> int: return local_size +def get_node_group_size(world_size: Optional[int] = None) -> int: + """ + Get the local world size accounting for 2D environment, if not set, we fallback to global environment + """ + if _NODE_GROUP_SIZE_2D is None: + return get_local_size(world_size) + return _NODE_GROUP_SIZE_2D + + def get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) -> int: """ Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html) @@ -151,3 +166,98 @@ def intra_and_cross_node_pg( dist.barrier() return _INTRA_PG, _CROSS_PG + + +def intra_and_cross_node_pg_2D( + env: ShardingEnv2D, + device: Optional[torch.device] = None, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: + """ + Creates sub process groups (intra and cross node) under 2D parallelism scheme + The concept of "intra" and "cross" node is lost under a 2D parallelism scheme + due to the ranks that exist under a sharding group do not have gurantee of the typical + node topology. And as such there are no guarantees of "intra" group exploiting intra node bandwidth. + + NOTE: + These process groups are created for sharding schemes (ie: GRID) that were designed to exploit + intra node bandwidth for optimized comms. There will be future work to redesign the comms for GRID + sharding to be optimized under a 2D setup. + + Example:: + Here is what "intra" and "cross" groups look like in a 2D environment, + Sharding Groups: + Group 0: [0, 2, 4, 6] + Group 1: [1, 3, 5, 7] + devices_per_node = 2: + "intra" groups for each sharding group, + Group 0: [0, 2], [4, 6] + Group 1: [1, 3], [5, 7] + "cross" groups for each sharding group, + Group 0: [0, 4], [2, 6] + Group 1: [1, 5], [3, 7] + + We can see as this scales to real world topologies how the "intra" and "cross" node ideas in a traditional + sense are not applicable here. + """ + if device is not None and device.type == "meta": + return None, None + + global _INTRA_PG_2D + global _CROSS_PG_2D + global _NODE_GROUP_SIZE_2D + + backend = dist.get_backend(env.sharding_pg) + my_rank = dist.get_rank() + + sharding_group_size = dist.get_world_size( + env.sharding_pg + ) # Local replica group world size + world_size = dist.get_world_size() # Global world size + step = world_size // sharding_group_size + devices_per_node = ( + env.node_group_size if env.node_group_size else get_local_size(world_size) + ) + _NODE_GROUP_SIZE_2D = devices_per_node + + assert ( + sharding_group_size % devices_per_node == 0 + ), f"node group size is not divisible by sharding group size, {devices_per_node=}, {sharding_group_size=}" + intra_pg_groups: List[List[List[int]]] = [[] for _ in range(step)] + + if _INTRA_PG_2D is None: + for group_rank in range(step): + sharding_pg_peers = [ + step * r + group_rank for r in range(sharding_group_size) + ] + for group in range(len(sharding_pg_peers) // devices_per_node): + intra_pg_peers = sharding_pg_peers[ + group * devices_per_node : (group + 1) * devices_per_node + ] + intra_pg_groups[group_rank].append(intra_pg_peers) + curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers) + if my_rank in intra_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}" + ) + _INTRA_PG_2D = curr_intra_pg + assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!" + dist.barrier() + + if _CROSS_PG_2D is None: + for group_rank in range(step): + intra_pg_group = intra_pg_groups[group_rank] + for cross_group_rank in range(devices_per_node): + cross_pg_peers = [ + intra_pg_group[j][cross_group_rank] + for j in range(len(intra_pg_group)) + ] + curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers) + if my_rank in cross_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}" + ) + _CROSS_PG_2D = curr_cross_pg + assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!" + dist.barrier() + + return _INTRA_PG_2D, _CROSS_PG_2D diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 56d0465db..5f98218b5 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, TensorProperties, @@ -149,6 +150,7 @@ def create_embedding_bag_sharding( EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ]: sharding_type = sharding_infos[0].param_sharding.sharding_type + if device is not None and device.type == "meta": replace_placement_with_meta_device(sharding_infos) if sharding_type == ShardingType.TABLE_WISE.value: @@ -949,10 +951,14 @@ def _initialize_torch_state(self) -> None: # noqa ) self._model_parallel_name_to_sharded_tensor[table_name] = ( - ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=local_shards, - sharded_tensor_metadata=metadata, - process_group=none_throws(self._env.process_group), + ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) ) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 11164e3e0..5bc162179 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -9,16 +9,22 @@ import abc import copy +import logging as logger from collections import OrderedDict from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Type import torch import torch.distributed as dist +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as ddp_default_hooks, ) from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size @@ -26,9 +32,11 @@ from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( + EnumerableShardingSpec, ModuleSharder, ShardedModule, ShardingEnv, + ShardingEnv2D, ShardingPlan, ) from torchrec.distributed.utils import ( @@ -599,3 +607,265 @@ def _reset_parameters(module: nn.Module) -> None: for _, m in module.named_modules(): if hasattr(m, "reset_parameters"): m.reset_parameters() + + +class DMPCollection(DistributedModelParallel): + """ + A wrapper around DistributedModelParallel that allows for multiple DMPs to be created and managed together. + + This class implements a 2D parallelism model where a DMP is sharded over a subset of ranks. + The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. + This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. + + Example Use Case: + Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be: + - Group 0, DMP 0: [0, 2, 4, 6] + - Group 1, DMP 1: [1, 3, 5, 7] + + Each group receives an identical sharding plan for their local world size and ranks. + If we have one table sharded in each DMP, with one shard on each rank in the group, + each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1. + The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7]. + + Notes: + - DTensor must be used for state dict for checkpointing to work correctly. + - The expected sharding plan should be sharded across sharding_group_size (sharding group world size) + and broadcasted to all ranks (`planner.collective_plan(..)`). + + Args: + module (nn.Module): The module to be sharded. + device (torch.device): The device to use for the sharded module. + plan (ShardingPlan): The sharding plan to use, created for sharding group world size. + sharding_group_size (int): The number of GPUs to model parallel shard the embedding tables over + world_size (int): The total number of GPUs. + global_pg (dist.ProcessGroup): The global process group. + node_group_size (Optional[int]): Specify a logical group size for a node for TWRW/GRID sharding schemes + sharders (Optional[List[ModuleSharder[torch.nn.Module]]]): The sharders to use. + init_data_parallel (bool): Whether to initialize data parallelism. + init_parameters (bool): Whether to initialize parameters. + data_parallel_wrapper (Optional[DataParallelWrapper]): The data parallel wrapper to use. + + Example:: + + @torch.no_grad() + def init_weights(m): + if isinstance(m, nn.Linear): + m.weight.fill_(1.0) + elif isinstance(m, EmbeddingBagCollection): + for param in m.parameters(): + init.kaiming_normal_(param) + + m = MyModel(device='meta') + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size=global_world_size, + local_world_size=sharding_group_size, + ), + constraints=constraints, + ) + plan = planner.collective_plan(m, sharders, global_pg) + m = DMPCollection( + module=m, + sharding_group_size=sharding_group_size, + world_size=global_world_size, + global_pg=global_pg, + plan=plan, + ) + m.apply(init_weights) + """ + + def __init__( + self, + module: nn.Module, + device: torch.device, + plan: ShardingPlan, + world_size: int, + sharding_group_size: int, + global_pg: dist.ProcessGroup, + node_group_size: Optional[int] = None, + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + init_data_parallel: bool = True, + init_parameters: bool = True, + data_parallel_wrapper: Optional[DataParallelWrapper] = None, + ) -> None: + assert device.type == "cuda", "DMPCollection only supports CUDA" + self._device = device + self._pg: dist.ProcessGroup = global_pg + self._plan: ShardingPlan = plan + self._device_mesh: DeviceMesh = None # pyre-ignore[8] + self._sharding_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._replica_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._global_rank: int = dist.get_rank(global_pg) + + self._device_mesh, self._sharding_pg, self._replica_pg = ( + self._create_process_groups( + global_rank=self._global_rank, + world_size=world_size, + local_size=sharding_group_size, + ) + ) + + self._remap_sharding_plan( + plan, self._global_rank, world_size // sharding_group_size + ) + super().__init__( + module, + ShardingEnv2D( + global_pg=self._pg, + sharding_pg=self._sharding_pg, + device_mesh=self._device_mesh, + node_group_size=node_group_size, + ), + device, + plan, + sharders, + init_data_parallel, + init_parameters, + data_parallel_wrapper, + ) + # post DMP init, we group sharded modules for parameter sync + self._modules_to_sync: List[nn.Module] = self._group_sharded_modules() + + def sync(self, include_optimizer_state: bool = True) -> None: + """ + Syncs the DMP weights across the allreduce (inter) process group + + This method is called after each forward pass to synchronize the weights of the sharded modules. + It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights, + which averages the weights across all processes in the inter-process group. + + Args: + include_optimizer_state (bool): Flag to include optimizer state syncing upon call + """ + assert self._replica_pg is not None, "replica_pg is not initialized!" + opts = dist.AllreduceCoalescedOptions() + opts.reduceOp = dist.ReduceOp.AVG + all_weights = [ + w + for emb_kernel in self._modules_to_sync + for w in emb_kernel.split_embedding_weights() + ] + handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts) + handle.wait() + + if include_optimizer_state: + # Sync accumulated square of grad of local optimizer shards + optim_list = [] + for emb_kernel in self._modules_to_sync: + all_optimizer_states = emb_kernel.get_optimizer_state() + momentum1 = [optim["sum"] for optim in all_optimizer_states] + optim_list.extend(momentum1) + # Some optimizers do not have states to sync, we check if states exist before collective call + if optim_list: + handle = self._replica_pg.allreduce_coalesced(optim_list, opts=opts) + handle.wait() + + def _create_process_groups( + self, global_rank: int, world_size: int, local_size: int + ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: + """ + Creates process groups for sharding and replication, the process groups + are created in the same exact order on all ranks as per `dist.new_group` API. + + Args: + global_rank (int): The global rank of the current process. + world_size (int): The total number of ranks. + local_size (int): The number of ranks per sharding group. + + Returns: + Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh, + replication process group, and allreduce process group. + """ + # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a + peer_matrix = [] + sharding_pg, replica_pg = None, None + step = world_size // local_size + + my_group_rank = global_rank % step + for group_rank in range(world_size // local_size): + peers = [step * r + group_rank for r in range(local_size)] + backend = dist.get_backend(self._pg) + curr_pg = dist.new_group(backend=backend, ranks=peers) + peer_matrix.append(peers) + if my_group_rank == group_rank: + logger.warning( + f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]" + ) + sharding_pg = curr_pg + assert sharding_pg is not None, "sharding_pg is not initialized!" + dist.barrier() + + my_inter_rank = global_rank // step + for inter_rank in range(local_size): + peers = [inter_rank * step + r for r in range(step)] + backend = dist.get_backend(self._pg) + curr_pg = dist.new_group(backend=backend, ranks=peers) + if my_inter_rank == inter_rank: + logger.warning( + f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]" + ) + replica_pg = curr_pg + assert replica_pg is not None, "replica_pg is not initialized!" + dist.barrier() + + mesh = DeviceMesh( + device_type=self._device.type, + mesh=peer_matrix, + mesh_dim_names=("replicate", "shard"), + ) + logger.warning(f"[Connection] 2D Device Mesh created: {mesh}") + + return mesh, sharding_pg, replica_pg + + def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None: + """ + Remaps the sharding plan to the local replica process group ranks + ShardingPlan is remapped inplace. + + As an example, + ShardingPlan for created for ranks [0, 2, 4, 6] is remapped to ranks [1, 3, 5, 7] + + Args: + plan (ShardingPlan): The original sharding plan. + global_rank (int): The global rank of the current process. + step (int): The number of nodes. + """ + + group_start = rank % step + for key in plan.plan: + # pyre-ignore[16] + for _, param_sharding in plan.plan[key].items(): + new_ranks = [] + for shard_rank in param_sharding.ranks: + new_ranks.append(shard_rank * step + group_start) + param_sharding.ranks = new_ranks + if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): + shards = param_sharding.sharding_spec.shards + if shards is not None: + for shard in shards: + shard_rank = shard.placement._rank * step + group_start + shard.placement = _remote_device( + f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}" + ) + return + + def _group_sharded_modules( + self, + ) -> List[nn.Module]: + # Post init DMP, save the embedding kernels + sharded_modules: List[nn.Module] = [] + + def _find_sharded_modules( + module: nn.Module, + ) -> None: + if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen): + sharded_modules.append(module) + if hasattr(module, "_lookups"): + for lookup in module._lookups: + _find_sharded_modules(lookup) + return + for _, child in module.named_children(): + _find_sharded_modules(child) + + _find_sharded_modules(self._dmp_wrapped_module) + return sharded_modules diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index 0f9a89034..940f1a0ca 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -14,7 +14,7 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torch.distributed._tensor import Shard +from torch.distributed._tensor import Replicate, Shard from torchrec.distributed.dist_data import EmbeddingsAllToOne from torchrec.distributed.embedding_lookup import ( GroupedPooledEmbeddingsLookup, @@ -145,9 +145,9 @@ def _shard( self, sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: - world_size: int = self._env.world_size + world_size: int = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -173,7 +173,9 @@ def _shard( if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, - placements=(Shard(1),), + placements=( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ), size=( ( info.embedding_config.num_embeddings_post_pruning @@ -190,6 +192,12 @@ def _shard( # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + # Remap rank by number of replica groups if 2D parallelism is enabled + rank = ( + rank // self._env.num_sharding_groups() # pyre-ignore[16] + if self._is_2D_parallel + else rank + ) tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index ef49cbb30..a0da146ea 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -14,7 +14,12 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torch.distributed._tensor import Replicate, Shard +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) from torchrec.distributed.dist_data import ( PooledEmbeddingsAllToAll, PooledEmbeddingsReduceScatter, @@ -33,6 +38,7 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, @@ -44,6 +50,7 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -70,8 +77,14 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._env: ShardingEnv = env + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + # pyre-ignore[16] + self._env.sharding_pg + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank self._device = device @@ -82,9 +95,17 @@ def __init__( self._combined_embedding_names: List[str] = [] self._combined_embedding_dims: List[int] = [] - intra_pg, cross_pg = intra_and_cross_node_pg( - device, backend=dist.get_backend(self._pg) - ) + + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-fixme[6] + self._env, + device=device, + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) self._intra_pg: Optional[dist.ProcessGroup] = intra_pg self._cross_pg: Optional[dist.ProcessGroup] = cross_pg self._local_size: int = ( @@ -193,7 +214,7 @@ def _shard( """ world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -210,9 +231,32 @@ def _shard( ), ) + dtensor_metadata = None + if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ) + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=placements, + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + + # to not pass onto TBE + info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] + # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + rank = ( + rank // self._env.num_sharding_groups() # pyre-ignore[16] + if self._is_2D_parallel + else rank + ) tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, @@ -231,6 +275,7 @@ def _shard( ), local_metadata=shards[i], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index ccba69a78..303db3b4c 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.placement_types import Replicate, Shard from torchrec.distributed.dist_data import ( EmbeddingsAllToOneReduce, KJTAllToAll, @@ -51,6 +51,7 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -119,9 +120,13 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + self._env.sharding_pg # pyre-ignore[16] + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank if device is None: @@ -147,7 +152,7 @@ def _shard( sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(self._world_size) + [] for _ in range(self._world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -171,9 +176,12 @@ def _shard( dtensor_metadata = None if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),) + ) dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, - placements=(Shard(0),), + placements=placements, size=( ( info.embedding_config.num_embeddings_post_pruning diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index 056295f65..2752290de 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -47,6 +47,7 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, ShardMetadata, ) from torchrec.distributed.utils import none_throws @@ -73,11 +74,17 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._device = device - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._env: ShardingEnv = env + self._device: Optional[torch.device] = device + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + self._env.sharding_pg # pyre-ignore[16] + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank + sharded_tables_per_rank = self._shard(sharding_infos) self._sharded_tables_per_rank: List[List[ShardedEmbeddingTable]] = ( @@ -98,7 +105,7 @@ def _shard( ) -> List[List[ShardedEmbeddingTable]]: world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -123,7 +130,11 @@ def _shard( dtensor_metadata = None if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] dtensor_metadata = DTensorMetadata( - mesh=self._env.device_mesh, + mesh=( + self._env.device_mesh["replicate"] # pyre-ignore[16] + if self._is_2D_parallel + else self._env.device_mesh + ), placements=(Replicate(),), size=( info.embedding_config.num_embeddings, @@ -134,8 +145,13 @@ def _shard( # to not pass onto TBE info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] - # pyre-fixme [16] - tables_per_rank[info.param_sharding.ranks[0]].append( + rank = ( + # pyre-ignore [16] + info.param_sharding.ranks[0] // self._env.num_sharding_groups() + if self._is_2D_parallel + else info.param_sharding.ranks[0] + ) + tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, embedding_dim=info.embedding_config.embedding_dim, diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 372eb6c75..1ba371e21 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -119,6 +119,8 @@ def _test_sharding( backend: str = "gloo", world_size: int = 2, local_size: Optional[int] = None, + world_size_2D: Optional[int] = None, + node_group_size: Optional[int] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, model_class: Type[TestSparseNNBase] = TestSparseNN, qcomms_config: Optional[QCommsConfig] = None, @@ -135,6 +137,8 @@ def _test_sharding( callable=sharding_single_rank_test, world_size=world_size, local_size=local_size, + world_size_2D=world_size_2D, + node_group_size=node_group_size, model_class=model_class, tables=self.tables if pooling == PoolingType.SUM else self.mean_tables, weighted_tables=self.weighted_tables if has_weighted_tables else None, diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 02fafafeb..dbd8f1007 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -24,7 +24,7 @@ get_qcomm_codecs_registry, QCommsConfig, ) -from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.model_parallel import DistributedModelParallel, DMPCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, ParameterConstraints, @@ -41,6 +41,7 @@ ) from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, + EnumerableShardingSpec, ModuleSharder, ShardedTensor, ShardingEnv, @@ -242,12 +243,13 @@ def copy_state_dict( raise ValueError("Tensors with ndim > 2 are not supported") local_shard.tensor.copy_(t) elif isinstance(tensor, DTensor): - shard_offsets = tensor.to_local().local_offsets() # pyre-ignore[16] - for i, local_shard in enumerate(tensor.to_local().local_shards()): + for local_shard, global_offset in zip( + tensor.to_local().local_shards(), + tensor.to_local().local_offsets(), # pyre-ignore[16] + ): assert global_tensor.ndim == local_shard.ndim t = global_tensor.detach() local_shape = local_shard.shape - global_offset = shard_offsets[i] if t.ndim == 1: t = t[global_offset[0] : global_offset[0] + local_shape[0]] elif t.ndim == 2: @@ -283,6 +285,8 @@ def sharding_single_rank_test( feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, variable_batch_per_feature: bool = False, # VBE global_constant_batch: bool = False, + world_size_2D: Optional[int] = None, + node_group_size: Optional[int] = None, ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: @@ -336,15 +340,20 @@ def sharding_single_rank_test( assert name in local_model_named_params_as_dict local_param = local_model_named_params_as_dict[name] apply_optimizer_in_backward( - optimizer_type, [param], optimizer_kwargs + optimizer_type, + [param], + optimizer_kwargs, ) apply_optimizer_in_backward( optimizer_type, [local_param], optimizer_kwargs ) + # For 2D parallelism, we use single group world size and local world size planner = EmbeddingShardingPlanner( topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size + world_size=world_size_2D if world_size_2D else world_size, + compute_device=ctx.device.type, + local_world_size=node_group_size if node_group_size else ctx.local_size, ), constraints=constraints, ) @@ -359,7 +368,6 @@ def sharding_single_rank_test( TODO: may need to add some checks that only does this if we're running on a single GPU (which should be most cases). """ - for group in plan.plan: for _, parameter_sharding in cast( EmbeddingModuleShardingPlan, plan.plan[group] @@ -384,15 +392,26 @@ def sharding_single_rank_test( f"rank:{rank}/cuda:{rank}" ) - local_model = DistributedModelParallel( - local_model, - # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) + assert ctx.pg is not None + if world_size_2D is not None: + local_model = DMPCollection( + module=local_model, + sharding_group_size=world_size_2D, + world_size=ctx.world_size, + global_pg=ctx.pg, + node_group_size=node_group_size, + plan=plan, + sharders=sharders, + device=ctx.device, + ) + else: + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=sharders, + device=ctx.device, + ) dense_optim = KeyedOptimizerWrapper( dict(in_backward_optimizer_filter(local_model.named_parameters())), @@ -408,7 +427,11 @@ def sharding_single_rank_test( ) # Run a single training step of the sharded model. - local_pred = gen_full_pred_after_one_step(local_model, local_opt, local_input) + local_pred = gen_full_pred_after_one_step( + local_model, + local_opt, + local_input, + ) all_local_pred = [] for _ in range(world_size): @@ -452,6 +475,10 @@ def gen_full_pred_after_one_step( loss.backward() opt.step() + # Sync embedding weights if 2D paralleism is used. + if isinstance(model, DMPCollection): + model.sync() + # Run a forward pass of the global model. with torch.no_grad(): model.train(False) diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py new file mode 100644 index 000000000..4d0ce7b41 --- /dev/null +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Any, cast, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared +from torchrec.distributed.test_utils.test_sharding import ( + create_test_sharder, + SharderType, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import PoolingType +from torchrec.test_utils import skip_if_asan_class + + +@skip_if_asan_class +class Test2DSharding(ModelParallelTestShared): + """ + Tests for 2D parallelism of embedding tables + """ + + WORLD_SIZE = 8 + WORLD_SIZE_2D = 4 + + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_cw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.COLUMN_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_tw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE_2D // 2, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.GRID_SHARD.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE // 4, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_rw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 141ae049c..4734f4cd7 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -843,6 +843,53 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": return cls(world_size, rank, None) +class ShardingEnv2D(ShardingEnv): + """ + Creates a sharding environment for 2D parallelism, enables usage of 2D parallelism in sharding + by seamlessly switching to the sub process group (sharding_pg) for a rank. This class is used + as source of truth for TorchRec to understand if we're in a 2D parallel environment. + + NOTE: + - global pg is part of `process_group` attribute to keep the same API as ShardingEnv, + some parts of TorchRec require the global pg to work appropriately (ie: `DDPWrapper` in `DistributedModelParallel`) + - `world_size` and `rank` attributes return values relative to `sharding_pg`, this is different + from default ShardingEnv returning values relative to `global_pg` + + Attributes: + sharding_pg: The process group containing the ranks to shard on. + global_pg: The process group representing global ranks. + device_mesh: A 2D device mesh representing the topology of the global world size + on "replicate" and "shard" dimensions. + node_group_size (Optional[int]): The size of each node group. If not provided, it will be inferred + from env var `LOCAL_WORLD_SIZE`. + """ + + def __init__( + self, + sharding_pg: dist.ProcessGroup, + global_pg: dist.ProcessGroup, + device_mesh: DeviceMesh, + node_group_size: Optional[int] = None, + ) -> None: + assert device_mesh.ndim == 2, "DeviceMesh must be two dimensional!" + self.world_size: int = dist.get_world_size(sharding_pg) + self.global_world_size: int = dist.get_world_size(global_pg) + self.rank: int = dist.get_rank(sharding_pg) + self.global_rank: int = dist.get_rank(global_pg) + self.process_group: dist.ProcessGroup = ( + global_pg # to keep consistent naming between ShardingEnv and ShardingEnv2D + ) + self.sharding_pg: dist.ProcessGroup = sharding_pg + self.device_mesh: DeviceMesh = device_mesh + self.node_group_size: Optional[int] = node_group_size + + def num_sharding_groups(self) -> int: + """ + Return number of sharding groups, also known as the number of times model parallel is replicated + """ + return self.global_world_size // self.world_size + + class NullShardingContext(Multistreamable): def record_stream(self, stream: torch.Stream) -> None: pass