diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index ddac29c44..653126e9c 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -1445,7 +1445,6 @@ def __init__( fused_params = config.fused_params or {} if "cache_precision" not in fused_params: fused_params["cache_precision"] = weights_precision - self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( SplitTableBatchedEmbeddingBagsCodegen( embedding_specs=list( diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py index e3e50a2d1..066e3b218 100644 --- a/torchrec/distributed/comm.py +++ b/torchrec/distributed/comm.py @@ -151,3 +151,59 @@ def intra_and_cross_node_pg( dist.barrier() return _INTRA_PG, _CROSS_PG + + +def intra_and_cross_node_pg_2D( + local_pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + backend: Optional[str] = None, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: + """ + Creates sub process groups (intra and cross node) under 2D parallelism scheme + """ + if device is not None and device.type == "meta": + return None, None + + # we keep names same for backwards compat, but they are not "intra" and "cross" in the traditional sense anymore under 2D parallelism + global _INTRA_PG # intra node process group + global _CROSS_PG # cross node process group + + my_rank = dist.get_rank() + + local_size = dist.get_world_size(local_pg) # Local replica group world size + world_size = dist.get_world_size() # Global world size + step = world_size // local_size + devices_per_node = get_local_size(world_size) + + if _INTRA_PG is None: + for group_rank in range(step): + local_pg_peers = [step * r + group_rank for r in range(local_size)] + for group in range(len(local_pg_peers) // devices_per_node): + intra_pg_peers = local_pg_peers[ + group * devices_per_node : (group + 1) * devices_per_node + ] + curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers) + if my_rank in intra_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}" + ) + _INTRA_PG = curr_intra_pg + dist.barrier() + + if _CROSS_PG is None: + for group_rank in range(step): + local_pg_peers = [step * r + group_rank for r in range(local_size)] + for cross_group_rank in range(devices_per_node): + cross_pg_peers = [ + local_pg_peers[cross_group_rank + g * devices_per_node] + for g in range(devices_per_node) + ] + curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers) + if my_rank in cross_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}" + ) + _CROSS_PG = curr_cross_pg + dist.barrier() + + return _INTRA_PG, _CROSS_PG diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index c737df185..95415a0ad 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -155,6 +156,8 @@ def create_embedding_bag_sharding( EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ]: sharding_type = sharding_infos[0].param_sharding.sharding_type + is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + if device is not None and device.type == "meta": replace_placement_with_meta_device(sharding_infos) if sharding_type == ShardingType.TABLE_WISE.value: @@ -163,6 +166,7 @@ def 
create_embedding_bag_sharding( env, device, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) elif sharding_type == ShardingType.ROW_WISE.value: return RwPooledEmbeddingSharding( @@ -170,6 +174,7 @@ def create_embedding_bag_sharding( env, device, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) elif sharding_type == ShardingType.DATA_PARALLEL.value: return DpPooledEmbeddingSharding(sharding_infos, env, device) @@ -179,6 +184,7 @@ def create_embedding_bag_sharding( env, device, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) elif sharding_type == ShardingType.COLUMN_WISE.value: return CwPooledEmbeddingSharding( @@ -187,6 +193,7 @@ def create_embedding_bag_sharding( device, permute_embeddings=permute_embeddings, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: return TwCwPooledEmbeddingSharding( @@ -202,6 +209,7 @@ def create_embedding_bag_sharding( env, device, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) else: raise ValueError(f"Sharding type not supported {sharding_type}") @@ -942,9 +950,14 @@ def _initialize_torch_state(self) -> None: # noqa ShardedTensor._init_from_local_shards( local_shards, self._name_to_table_size[table_name], - process_group=self._env.process_group, + process_group=( + self._env.local_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) ) + # pass def extract_sharded_kvtensors( module: ShardedEmbeddingBagCollection, @@ -967,6 +980,7 @@ def post_state_dict_hook( prefix: str, _local_metadata: Dict[str, Any], ) -> None: + print(f"in EBC post state dict hook!") # Adjust dense MP for ( table_name, diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 11164e3e0..506ef2c30 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -14,21 +14,29 @@ import torch import torch.distributed as dist +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as ddp_default_hooks, ) from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( + EnumerableShardingSpec, ModuleSharder, ShardedModule, ShardingEnv, + ShardingEnv2D, ShardingPlan, ) from torchrec.distributed.utils import ( @@ -440,6 +448,7 @@ def state_dict( prefix: str = "", keep_vars: bool = False, ) -> Dict[str, Any]: + print(f"DMP state dict called! 
{destination=}")
         state_dict = get_module(self).state_dict(
             destination=destination, prefix=prefix, keep_vars=keep_vars
         )
@@ -599,3 +608,177 @@ def _reset_parameters(module: nn.Module) -> None:
     for _, m in module.named_modules():
         if hasattr(m, "reset_parameters"):
             m.reset_parameters()
+
+
+class DMPCollection(DistributedModelParallel):
+    """
+    A wrapper around DistributedModelParallel that allows multiple DMPs to be created and managed together.
+
+    This class implements a 2D parallelism model where a DMP is sharded over a subset of ranks.
+    The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node.
+    This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth.
+
+    Example Use Case:
+        Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
+            - Group 0, DMP 0: [0, 2, 4, 6]
+            - Group 1, DMP 1: [1, 3, 5, 7]
+
+        Each group receives an identical sharding plan for its local world size and ranks.
+        If we have one table sharded in each DMP, with one shard on each rank in the group,
+        each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1.
+        The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].
+
+    Notes:
+        - DTensor must be used for the state dict for checkpointing to work correctly.
+        - The expected sharding plan should be sharded across local_size (the world size of
+          each sharding group) and broadcasted to all ranks (planner.collective_plan(..)).
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        device: torch.device,
+        plan: ShardingPlan,
+        local_size: int,
+        world_size: int,
+        global_pg: dist.ProcessGroup,
+        sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
+        init_data_parallel: bool = True,
+        init_parameters: bool = True,
+        data_parallel_wrapper: Optional[DataParallelWrapper] = None,
+    ) -> None:
+        assert device.type == "cuda", "DMPCollection only supports CUDA"
+        self._pg: dist.ProcessGroup = global_pg
+        self._plan = plan
+        self._device = device
+        self._device_mesh: DeviceMesh = None  # pyre-ignore[8]
+        self._local_pg: dist.ProcessGroup = None  # pyre-ignore[8]
+        self._inter_pg: dist.ProcessGroup = None  # pyre-ignore[8]
+        self._global_rank: int = dist.get_rank(global_pg)
+
+        self._device_mesh, self._local_pg, self._inter_pg = self._create_process_groups(
+            global_rank=self._global_rank,
+            world_size=world_size,
+            local_size=local_size,
+        )
+        self._remap_sharding_plan(plan, self._global_rank, world_size // local_size)
+        super().__init__(
+            module,
+            ShardingEnv2D(
+                pg=self._pg,
+                local_pg=self._local_pg,
+                device_mesh=self._device_mesh,
+            ),
+            device,
+            plan,
+            sharders,
+            init_data_parallel,
+            init_parameters,
+            data_parallel_wrapper,
+        )
+        # post DMP init, we group sharded modules for parameter sync
+        self._modules_to_sync: List[nn.Module] = self._group_sharded_modules()
+
+    def sync(self) -> None:
+        """
+        Syncs the DMP weights across the all-reduce (inter) process group.
+        """
+        assert self._inter_pg is not None, "inter_pg is not initialized!"
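+        # For intuition (illustrative sketch only, not the call used below): the
+        # coalesced all-reduce is semantically equivalent to averaging each
+        # embedding weight tensor across the replica (inter) process group, i.e.
+        #
+        #     for w in all_weights:
+        #         dist.all_reduce(w, op=dist.ReduceOp.AVG, group=self._inter_pg)
+        #
+        # Coalescing batches the tensors into fewer collective calls, which matters
+        # because each TBE kernel holds one weight tensor per table.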
+        opts = dist.AllreduceCoalescedOptions()
+        # pyre-fixme[8]
+        opts.reduceOp = dist.ReduceOp.AVG
+        # pyre-fixme[8]
+        all_weights = [
+            w
+            for emb_kernel in self._modules_to_sync
+            for w in emb_kernel.split_embedding_weights()
+        ]
+        handle = self._inter_pg.allreduce_coalesced(all_weights, opts=opts)
+        handle.wait()
+
+    def _create_process_groups(
+        self, global_rank: int, world_size: int, local_size: int
+    ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
+        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
+        peer_matrix = []
+        local_pg, inter_pg = None, None
+        rank = global_rank
+        step = world_size // local_size
+
+        my_group_rank = rank % step
+        for group_rank in range(world_size // local_size):
+            peers = [step * r + group_rank for r in range(local_size)]
+            backend = dist.get_backend(self._pg)
+            curr_pg = dist.new_group(backend=backend, ranks=peers)
+            peer_matrix.append(peers)
+            if my_group_rank == group_rank:
+                local_pg = curr_pg
+        dist.barrier()
+
+        my_inter_rank = rank // step
+        for inter_rank in range(local_size):
+            peers = [inter_rank * step + r for r in range(step)]
+            backend = dist.get_backend(self._pg)
+            curr_pg = dist.new_group(backend=backend, ranks=peers)
+            if my_inter_rank == inter_rank:
+                inter_pg = curr_pg
+        dist.barrier()
+
+        mesh = DeviceMesh(
+            device_type=self._device.type,
+            mesh=peer_matrix,
+            mesh_dim_names=("replicate", "shard"),
+        )
+
+        assert local_pg is not None, "local_pg is not initialized!"
+        assert inter_pg is not None, "inter_pg is not initialized!"
+        return mesh, local_pg, inter_pg
+
+    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+        """
+        Remaps the sharding plan to the ranks of this replica's sharding process group.
+        The ShardingPlan is remapped in place.
+
+        As an example, with two sharding groups a plan created for local ranks
+        [0, 1, 2, 3] is remapped to global ranks [0, 2, 4, 6] on the first group
+        and to [1, 3, 5, 7] on the second.
+        """
+
+        group_start = rank % step
+        for key in plan.plan:
+            # pyre-ignore[16]
+            for _, param_sharding in plan.plan[key].items():
+                new_ranks = []
+                for shard_rank in param_sharding.ranks:
+                    new_ranks.append(shard_rank * step + group_start)
+                param_sharding.ranks = new_ranks
+                if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
+                    shards = param_sharding.sharding_spec.shards
+                    if shards is not None:
+                        for shard in shards:
+                            shard_rank = shard.placement._rank * step + group_start
+                            shard.placement = _remote_device(
+                                f"rank:{shard_rank}/cuda:{shard_rank}"
+                            )
+        return
+
+    def _group_sharded_modules(
+        self,
+    ) -> List[nn.Module]:
+        # TODO: look into moving the embedding sync into the forward call?
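+        # Worked example for the two helpers above (hypothetical sizes, matching the
+        # class docstring): with world_size=8 and local_size=4, step = 8 // 4 = 2, so
+        # _create_process_groups builds sharding groups [0, 2, 4, 6] and [1, 3, 5, 7]
+        # and replica (inter) groups [0, 1], [2, 3], [4, 5], [6, 7]. _remap_sharding_plan
+        # then maps a plan written for local ranks [0, 1, 2, 3] via
+        # new_rank = old_rank * step + group_start, giving [0, 2, 4, 6] on even global
+        # ranks (group_start = 0) and [1, 3, 5, 7] on odd global ranks (group_start = 1).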
+ # Post init DMP, save the embedding kernels + sharded_modules: List[nn.Module] = [] + + def _find_sharded_modules( + module: nn.Module, + ) -> None: + if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen): + sharded_modules.append(module) + if isinstance(module, ShardedEmbeddingBagCollection): + for lookup in module._lookups: + _find_sharded_modules(lookup) + return + for _, child in module.named_children(): + _find_sharded_modules(child) + + _find_sharded_modules(self) + print(f"Found {sharded_modules=}") + return sharded_modules diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index 0f9a89034..e41df2722 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -14,7 +14,7 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torch.distributed._tensor import Shard +from torch.distributed._tensor import Replicate, Shard from torchrec.distributed.dist_data import EmbeddingsAllToOne from torchrec.distributed.embedding_lookup import ( GroupedPooledEmbeddingsLookup, @@ -70,12 +70,14 @@ def __init__( device: Optional[torch.device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + is_2D_parallel: bool = False, ) -> None: super().__init__( sharding_infos, env, device, qcomm_codecs_registry=qcomm_codecs_registry, + is_2D_parallel=is_2D_parallel, ) self._permute_embeddings = permute_embeddings if self._permute_embeddings: @@ -145,10 +147,13 @@ def _shard( self, sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: - world_size: int = self._env.world_size + world_size: int = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] + num_2D_groups: int = ( + dist.get_world_size() // world_size if self._is_2D_parallel else world_size + ) for info in sharding_infos: # pyre-fixme [16] shards: List[ShardMetadata] = info.param_sharding.sharding_spec.shards @@ -173,7 +178,7 @@ def _shard( if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, - placements=(Shard(1),), + placements=(Replicate(), Shard(1)), size=( ( info.embedding_config.num_embeddings_post_pruning @@ -190,6 +195,8 @@ def _shard( # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + # Remap rank by number of replica groups if 2D parallelism is enabled + rank = rank // num_2D_groups if self._is_2D_parallel else rank tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index ef49cbb30..f32dc7262 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -14,7 +14,12 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torch.distributed._tensor import Replicate, Shard +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) from torchrec.distributed.dist_data import ( PooledEmbeddingsAllToAll, PooledEmbeddingsReduceScatter, @@ -33,6 +38,7 @@ ) from torchrec.distributed.embedding_types import ( 
BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, @@ -68,10 +74,14 @@ def __init__( device: Optional[torch.device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + is_2D_parallel: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel = is_2D_parallel + self._pg: Optional[dist.ProcessGroup] = ( + self._env.local_pg if self._is_2D_parallel else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank self._device = device @@ -82,9 +92,18 @@ def __init__( self._combined_embedding_names: List[str] = [] self._combined_embedding_dims: List[int] = [] - intra_pg, cross_pg = intra_and_cross_node_pg( - device, backend=dist.get_backend(self._pg) - ) + + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-ignore [6] + self._env.local_pg, + device=device, + backend=dist.get_backend(self._env.local_pg), + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) self._intra_pg: Optional[dist.ProcessGroup] = intra_pg self._cross_pg: Optional[dist.ProcessGroup] = cross_pg self._local_size: int = ( @@ -193,8 +212,13 @@ def _shard( """ world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] + num_2D_groups: int = ( + dist.get_world_size() // world_size + if self._is_2D_parallel + else world_size # global world size / local shard group world size + ) for info in sharding_infos: # pyre-fixme [16] shards = info.param_sharding.sharding_spec.shards @@ -210,9 +234,28 @@ def _shard( ), ) + dtensor_metadata = None + if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ) + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=placements, + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + + # to not pass onto TBE + info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] + # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + rank = rank // num_2D_groups if self._is_2D_parallel else rank tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, @@ -231,6 +274,7 @@ def _shard( ), local_metadata=shards[i], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index ccba69a78..40eb7fc33 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.placement_types import Replicate, Shard from torchrec.distributed.dist_data import ( EmbeddingsAllToOneReduce, KJTAllToAll, @@ -117,11 +117,14 @@ def 
__init__( device: Optional[torch.device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + is_2D_parallel: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel = is_2D_parallel + self._pg: Optional[dist.ProcessGroup] = ( + self._env.local_pg if self._is_2D_parallel else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank if device is None: @@ -147,7 +150,7 @@ def _shard( sharding_infos: List[EmbeddingShardingInfo], ) -> List[List[ShardedEmbeddingTable]]: tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(self._world_size) + [] for _ in range(self._world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -171,9 +174,12 @@ def _shard( dtensor_metadata = None if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),) + ) dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, - placements=(Shard(0),), + placements=placements, size=( ( info.embedding_config.num_embeddings_post_pruning diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index 056295f65..334ad2848 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -11,6 +11,7 @@ import torch import torch.distributed as dist +from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.placement_types import Replicate from torchrec.distributed.dist_data import ( EmbeddingsAllToOne, @@ -71,13 +72,18 @@ def __init__( env: ShardingEnv, device: Optional[torch.device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + is_2D_parallel: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._env = env self._device = device - self._pg: Optional[dist.ProcessGroup] = self._env.process_group - self._world_size: int = self._env.world_size - self._rank: int = self._env.rank + self._is_2D_parallel = is_2D_parallel + self._pg: Optional[dist.ProcessGroup] = ( + self._env.local_pg if self._is_2D_parallel else self._env.process_group + ) + self._world_size = self._env.world_size + self._rank = self._env.rank + sharded_tables_per_rank = self._shard(sharding_infos) self._sharded_tables_per_rank: List[List[ShardedEmbeddingTable]] = ( @@ -98,8 +104,13 @@ def _shard( ) -> List[List[ShardedEmbeddingTable]]: world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] + num_2D_groups: int = ( + dist.get_world_size() // world_size + if self._is_2D_parallel + else world_size # global world size / local shard group world size + ) for info in sharding_infos: # pyre-fixme [16] shards = info.param_sharding.sharding_spec.shards @@ -123,7 +134,11 @@ def _shard( dtensor_metadata = None if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] dtensor_metadata = DTensorMetadata( - mesh=self._env.device_mesh, + mesh=( + self._env.device_mesh["replicate"] + if self._env.device_mesh.ndim == 2 + else self._env.device_mesh + ), placements=(Replicate(),), size=( info.embedding_config.num_embeddings, @@ -134,8 +149,14 @@ def _shard( # to not pass onto TBE info.fused_params.pop("output_dtensor", None) # 
pyre-ignore[16] + rank = ( + # pyre-ignore [16] + info.param_sharding.ranks[0] // num_2D_groups + if self._is_2D_parallel + else info.param_sharding.ranks[0] + ) # pyre-fixme [16] - tables_per_rank[info.param_sharding.ranks[0]].append( + tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, embedding_dim=info.embedding_config.embedding_dim, diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index 22651f75a..000ea57e8 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -13,7 +13,13 @@ import torch import torch.distributed as dist -from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torch.distributed.distributed_c10d import get_process_group_ranks +from torch.distributed.tensor.placement_types import Replicate, Shard +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) from torchrec.distributed.dist_data import ( KJTAllToAll, PooledEmbeddingsAllToAll, @@ -34,6 +40,7 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, @@ -68,17 +75,29 @@ def __init__( device: Optional[torch.device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + is_2D_parallel: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._is_2D_parallel = is_2D_parallel + self._pg: Optional[dist.ProcessGroup] = ( + self._env.local_pg if self._is_2D_parallel else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank self._device = device self._need_pos = need_pos - intra_pg, cross_pg = intra_and_cross_node_pg( - device, backend=dist.get_backend(self._pg) - ) + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-ignore [6] + self._env.local_pg, + device=device, + backend=dist.get_backend(self._env.local_pg), + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) self._intra_pg: Optional[dist.ProcessGroup] = intra_pg self._cross_pg: Optional[dist.ProcessGroup] = cross_pg self._local_size: int = ( @@ -112,11 +131,23 @@ def _shard( world_size = self._world_size local_size = self._local_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] + peer_group = ( + # pyre-ignore [6] + get_process_group_ranks(self._env.local_pg) + if self._is_2D_parallel + else None + ) for info in sharding_infos: - # pyre-ignore [16] - table_node = info.param_sharding.ranks[0] // local_size + # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme + rank = ( + # pyre-ignore [16] + peer_group.index(info.param_sharding.ranks[0]) + if peer_group is not None + else info.param_sharding.ranks[0] + ) + table_node = rank // local_size # pyre-fixme [16] shards = info.param_sharding.sharding_spec.shards @@ -131,6 +162,23 @@ def _shard( ), ) + dtensor_metadata = None + if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),) + ) + dtensor_metadata = DTensorMetadata( + 
mesh=self._env.device_mesh, + placements=placements, + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + # to not pass onto TBE + info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] + for rank in range( table_node * local_size, (table_node + 1) * local_size, @@ -154,6 +202,7 @@ def _shard( ), local_metadata=shards[rank_idx], global_metadata=global_metadata, + dtensor_metadata=dtensor_metadata, weight_init_max=info.embedding_config.weight_init_max, weight_init_min=info.embedding_config.weight_init_min, fused_params=info.fused_params, diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py index 15f0f65be..059cc4d8c 100644 --- a/torchrec/distributed/shards_wrapper.py +++ b/torchrec/distributed/shards_wrapper.py @@ -123,6 +123,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def handle_all_gather_into_tensor(args, kwargs): + print(f"ALL GATHERING INTO TENSOR {args=} {kwargs=}") dim = args[0].local_sizes()[0][1] cat_tensor = torch.cat( [t.view(-1) for t in args[0].local_shards()], dim=0 diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 372eb6c75..7285e79f7 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -119,6 +119,7 @@ def _test_sharding( backend: str = "gloo", world_size: int = 2, local_size: Optional[int] = None, + world_size_2D: Optional[int] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, model_class: Type[TestSparseNNBase] = TestSparseNN, qcomms_config: Optional[QCommsConfig] = None, @@ -135,6 +136,7 @@ def _test_sharding( callable=sharding_single_rank_test, world_size=world_size, local_size=local_size, + world_size_2D=world_size_2D, model_class=model_class, tables=self.tables if pooling == PoolingType.SUM else self.mean_tables, weighted_tables=self.weighted_tables if has_weighted_tables else None, diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 02fafafeb..8382d98e8 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -24,7 +24,7 @@ get_qcomm_codecs_registry, QCommsConfig, ) -from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.model_parallel import DistributedModelParallel, DMPCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, ParameterConstraints, @@ -41,6 +41,7 @@ ) from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, + EnumerableShardingSpec, ModuleSharder, ShardedTensor, ShardingEnv, @@ -242,12 +243,12 @@ def copy_state_dict( raise ValueError("Tensors with ndim > 2 are not supported") local_shard.tensor.copy_(t) elif isinstance(tensor, DTensor): - shard_offsets = tensor.to_local().local_offsets() # pyre-ignore[16] - for i, local_shard in enumerate(tensor.to_local().local_shards()): + for local_shard, global_offset in zip( + tensor.to_local().local_shards(), tensor.to_local().local_offsets() + ): assert global_tensor.ndim == local_shard.ndim t = global_tensor.detach() local_shape = local_shard.shape - global_offset = shard_offsets[i] if t.ndim == 1: t = t[global_offset[0] : global_offset[0] + local_shape[0]] elif t.ndim == 2: @@ -283,6 
+284,7 @@ def sharding_single_rank_test( feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, variable_batch_per_feature: bool = False, # VBE global_constant_batch: bool = False, + world_size_2D: Optional[int] = None, ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: @@ -310,6 +312,7 @@ def sharding_single_rank_test( global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) local_input = inputs[0][1][rank].to(ctx.device) + print(f"rank: {rank} local_input: {local_input.idlist_features}") # Shard model. local_model = model_class( @@ -336,15 +339,20 @@ def sharding_single_rank_test( assert name in local_model_named_params_as_dict local_param = local_model_named_params_as_dict[name] apply_optimizer_in_backward( - optimizer_type, [param], optimizer_kwargs + optimizer_type, + [param], + optimizer_kwargs, ) apply_optimizer_in_backward( optimizer_type, [local_param], optimizer_kwargs ) + # For 2D parallelism, we use single group world size and local world size planner = EmbeddingShardingPlanner( topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size + world_size=world_size_2D if world_size_2D else world_size, + compute_device=ctx.device.type, + local_world_size=ctx.local_size, ), constraints=constraints, ) @@ -359,7 +367,6 @@ def sharding_single_rank_test( TODO: may need to add some checks that only does this if we're running on a single GPU (which should be most cases). """ - for group in plan.plan: for _, parameter_sharding in cast( EmbeddingModuleShardingPlan, plan.plan[group] @@ -384,21 +391,25 @@ def sharding_single_rank_test( f"rank:{rank}/cuda:{rank}" ) - local_model = DistributedModelParallel( - local_model, - # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) - - dense_optim = KeyedOptimizerWrapper( - dict(in_backward_optimizer_filter(local_model.named_parameters())), - lambda params: torch.optim.SGD(params, lr=0.1), - ) - local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) + assert ctx.pg is not None + if world_size_2D is not None: + local_model = DMPCollection( + module=local_model, + local_size=world_size_2D, + world_size=ctx.world_size, + global_pg=ctx.pg, + plan=plan, + sharders=sharders, + device=ctx.device, + ) + else: + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=sharders, + device=ctx.device, + ) # Load model state from the global model. copy_state_dict( @@ -407,8 +418,18 @@ def sharding_single_rank_test( exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", ) + dense_optim = KeyedOptimizerWrapper( + dict(in_backward_optimizer_filter(local_model.named_parameters())), + lambda params: torch.optim.SGD(params, lr=0.1), + ) + local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) + # Run a single training step of the sharded model. - local_pred = gen_full_pred_after_one_step(local_model, local_opt, local_input) + local_pred = gen_full_pred_after_one_step( + local_model, + local_opt, + local_input, + ) all_local_pred = [] for _ in range(world_size): @@ -452,6 +473,10 @@ def gen_full_pred_after_one_step( loss.backward() opt.step() + # Sync embedding weights if 2D paralleism is used. + if isinstance(model, DMPCollection): + model.sync() + # Run a forward pass of the global model. 
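+    # Note: with 2D parallelism, model.sync() above leaves every replica group holding
+    # the averaged embedding weights, so the prediction gathered from the sharded model
+    # is expected to match the unsharded global model evaluated below. Roughly, the
+    # invariant exercised by the caller is (illustrative, not the actual assertion code):
+    #
+    #     torch.testing.assert_close(sharded_full_pred, global_full_pred)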
with torch.no_grad(): model.train(False) diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py new file mode 100644 index 000000000..485770e6c --- /dev/null +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Any, cast, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared +from torchrec.distributed.test_utils.test_sharding import ( + create_test_sharder, + SharderType, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import PoolingType +from torchrec.test_utils import skip_if_asan_class + + +@skip_if_asan_class +class Test2DSharding(ModelParallelTestShared): + """ + Tests for 2D parallelism of embedding tables + """ + + WORLD_SIZE = 8 + WORLD_SIZE_2D = 4 + + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_cw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.COLUMN_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + local_size=self.WORLD_SIZE_2D // 2, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + 
torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_tw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + local_size=self.WORLD_SIZE_2D // 2, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_twrw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_ROW_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + local_size=self.WORLD_SIZE_2D // 2, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + 
qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.GRID_SHARD.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + local_size=self.WORLD_SIZE // 4, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + 
forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_rw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + + self._test_sharding( + world_size=self.WORLD_SIZE, + local_size=self.WORLD_SIZE // 4, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 141ae049c..8a3a2bb63 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -20,6 +20,7 @@ Iterator, List, Optional, + Protocol, Tuple, Type, TypeVar, @@ -843,6 +844,33 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": return cls(world_size, rank, None) +class ShardingEnv2D: + """ + Creates a sharding environment for 2D parallelism. + This environment includes an additional local process group and device mesh. + The attributes of this class are changed to reflect the local process group and device mesh. + + Attributes: + local_pg: The process group containing the ranks of the local sharding group. + device_mesh: A 2D device mesh representing the topology of the global world size + on "replicate" and "shard" dimensions. + rank: The rank of device relative to local process group. + world_size: The world size of the local process group. + """ + + def __init__( + self, + pg: dist.ProcessGroup, + local_pg: dist.ProcessGroup, + device_mesh: DeviceMesh, + ) -> None: + self.world_size = dist.get_world_size(local_pg) + self.rank = dist.get_rank(local_pg) + self.process_group = pg + self.local_pg: dist.ProcessGroup = local_pg + self.device_mesh: dist.ProcessGroup = device_mesh + + class NullShardingContext(Multistreamable): def record_stream(self, stream: torch.Stream) -> None: pass
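
For reference, a minimal end-to-end sketch of how the pieces in this patch fit together, modeled on sharding_single_rank_test above. The helper name shard_model_2d, the backend/device setup, and the training-loop snippet are illustrative assumptions, not APIs introduced by this change:

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology


def shard_model_2d(model: torch.nn.Module, sharding_group_size: int) -> DMPCollection:
    # One rank per GPU across all nodes; NCCL is required since DMPCollection is CUDA-only.
    dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{dist.get_rank() % torch.cuda.device_count()}")

    # Plan over the sharding-group world size, then use collective_plan so every
    # rank (including the other replica groups) receives the identical plan.
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size=sharding_group_size,
            compute_device=device.type,
        )
    )
    plan = planner.collective_plan(model, pg=dist.group.WORLD)

    # DMPCollection remaps the plan onto this rank's sharding group and sets up
    # the replica (inter) process group used by sync().
    return DMPCollection(
        module=model,
        device=device,
        plan=plan,
        local_size=sharding_group_size,  # ranks per model replica
        world_size=world_size,           # all ranks in the job
        global_pg=dist.group.WORLD,
    )


# Training step: after the (fused) optimizer update, average the embedding weights
# across replica groups so replicas of the same shard stay in sync:
#     loss.backward()
#     optimizer.step()
#     model.sync()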