From a48d0ffa96db80b62bc1f0a8ed02fb098eafba66 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Wed, 12 Feb 2025 06:55:36 -0800 Subject: [PATCH] change sharding topology for inter host all reduce (#2740) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2740 Users can now change network topology such that all reduce is now inter host. This design uses the ShardingEnv to inform the rank placements since the topology now changes how the rank is remapped back. Future work: make this extensible where user can define a sharding topology and torchrec can MP/DP across it Reviewed By: kausv Differential Revision: D68729283 fbshipit-source-id: 696cd190ca0e599cec7519470283011a6b14ce77 --- torchrec/distributed/comm.py | 11 +++- torchrec/distributed/model_parallel.py | 62 +++++++++++++++---- torchrec/distributed/sharding/cw_sharding.py | 3 +- .../distributed/sharding/grid_sharding.py | 2 +- torchrec/distributed/sharding/tw_sharding.py | 9 +-- .../test_utils/test_model_parallel.py | 2 + .../distributed/test_utils/test_sharding.py | 2 + .../distributed/tests/test_2d_sharding.py | 15 +++++ torchrec/distributed/types.py | 28 +++++++++ 9 files changed, 113 insertions(+), 21 deletions(-) diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py index 2b445fe58..ce41552a2 100644 --- a/torchrec/distributed/comm.py +++ b/torchrec/distributed/comm.py @@ -226,9 +226,14 @@ def intra_and_cross_node_pg_2D( if _INTRA_PG_2D is None: for group_rank in range(step): - sharding_pg_peers = [ - step * r + group_rank for r in range(sharding_group_size) - ] + if env.use_inter_host_allreduce: + # for inter host all reduce, we change the sharding group calculation to be continuous + ranks = group_rank * sharding_group_size + sharding_pg_peers = list(range(ranks, ranks + sharding_group_size)) + else: + sharding_pg_peers = [ + step * r + group_rank for r in range(sharding_group_size) + ] for group in range(len(sharding_pg_peers) // devices_per_node): intra_pg_peers = sharding_pg_peers[ group * devices_per_node : (group + 1) * devices_per_node diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 492b8e348..57e451a06 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -690,6 +690,7 @@ def __init__( init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None, + use_inter_host_allreduce: bool = False, ) -> None: assert device.type == "cuda", "DMPCollection only supports CUDA" self._device = device @@ -705,13 +706,16 @@ def __init__( global_rank=self._global_rank, world_size=world_size, local_size=sharding_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, ) ) self._remap_sharding_plan( plan=plan, rank=self._global_rank, - num_nodes=world_size // sharding_group_size, + step=world_size // sharding_group_size, + sharding_group_size=sharding_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, ) super().__init__( module, @@ -720,6 +724,7 @@ def __init__( sharding_pg=self._sharding_pg, device_mesh=self._device_mesh, node_group_size=node_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, ), device, plan, @@ -768,7 +773,11 @@ def sync(self, include_optimizer_state: bool = True) -> None: handle.wait() def _create_process_groups( - self, global_rank: int, world_size: int, local_size: int + self, + global_rank: int, + world_size: int, + local_size: int, + use_inter_host_allreduce: bool = False, ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: """ Creates process groups for sharding and replication, the process groups @@ -784,17 +793,29 @@ def _create_process_groups( replication process group, and allreduce process group. """ peer_matrix = [] - num_nodes = world_size // local_size + mesh, sharding_pg, replica_pg = None, None, None - for group_rank in range(world_size // local_size): - peers = [num_nodes * r + group_rank for r in range(local_size)] - peer_matrix.append(peers) + logger.warning(f"[2D] Use inter host all reduce: {use_inter_host_allreduce}") + + if use_inter_host_allreduce: + # We shard on continuous set of ranks and nodes. Thereby forcing our all reduce to be inter host. + # Under this scheme sharding types such as TWRW and GRID will now take + # advantage of intra node comms as a result of the continuous set of ranks. + peer_matrix = [ + list(range(i, i + local_size)) for i in range(0, world_size, local_size) + ] + else: + step = world_size // local_size + for group_rank in range(world_size // local_size): + peers = [step * r + group_rank for r in range(local_size)] + peer_matrix.append(peers) mesh = DeviceMesh( device_type=self._device.type, mesh=peer_matrix, mesh_dim_names=("replicate", "shard"), ) + logger.warning(f"[Connection] 2D Device Mesh created: {mesh}") sharding_pg = mesh.get_group(mesh_dim="shard") logger.warning( @@ -808,7 +829,12 @@ def _create_process_groups( return mesh, sharding_pg, replica_pg def _remap_sharding_plan( - self, plan: ShardingPlan, rank: int, num_nodes: int + self, + plan: ShardingPlan, + rank: int, + step: int, + sharding_group_size: int, + use_inter_host_allreduce: bool = False, ) -> None: """ Remaps the sharding plan to the local replica process group ranks @@ -822,20 +848,32 @@ def _remap_sharding_plan( global_rank (int): The global rank of the current process. num_nodes (int): The number of nodes. """ - - group_start = rank % num_nodes + group_start = rank % step for key in plan.plan: # pyre-ignore[16] for _, param_sharding in plan.plan[key].items(): new_ranks = [] - for shard_rank in param_sharding.ranks: - new_ranks.append(shard_rank * num_nodes + group_start) + if use_inter_host_allreduce: + group = rank // sharding_group_size + new_ranks = [ + shard_rank + (group * sharding_group_size) + for shard_rank in param_sharding.ranks + ] + else: + for shard_rank in param_sharding.ranks: + new_ranks.append(shard_rank * step + group_start) param_sharding.ranks = new_ranks + if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): shards = param_sharding.sharding_spec.shards if shards is not None: for shard in shards: - shard_rank = shard.placement._rank * num_nodes + group_start + if use_inter_host_allreduce: + shard_rank = shard.placement._rank + ( + (rank // sharding_group_size) * sharding_group_size + ) + else: + shard_rank = shard.placement._rank * step + group_start shard.placement = _remote_device( f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}" ) diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index aa4fafa2b..7f8586093 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -45,6 +45,7 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingType, ShardMetadata, ) from torchrec.distributed.utils import none_throws @@ -191,7 +192,7 @@ def _shard( for i, rank in enumerate(info.param_sharding.ranks): # Remap rank by number of replica groups if 2D parallelism is enabled rank = ( - rank // self._env.num_sharding_groups() # pyre-ignore[16] + self._env.remap_rank(rank, ShardingType.COLUMN_WISE) # pyre-ignore[16] if self._is_2D_parallel else rank ) diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index c5cc31e87..04328f4fb 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -250,7 +250,7 @@ def _shard( # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): rank = ( - rank // self._env.num_sharding_groups() # pyre-ignore[16] + self._env.remap_rank(rank, ShardingType.GRID_SHARD) # pyre-ignore[16] if self._is_2D_parallel else rank ) diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index 2752290de..78421e0e5 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -48,6 +48,7 @@ ShardedTensorMetadata, ShardingEnv, ShardingEnv2D, + ShardingType, ShardMetadata, ) from torchrec.distributed.utils import none_throws @@ -128,7 +129,7 @@ def _shard( ) dtensor_metadata = None - if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + if self._env.output_dtensor: dtensor_metadata = DTensorMetadata( mesh=( self._env.device_mesh["replicate"] # pyre-ignore[16] @@ -142,12 +143,12 @@ def _shard( ), stride=info.param.stride(), ) - # to not pass onto TBE - info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] rank = ( # pyre-ignore [16] - info.param_sharding.ranks[0] // self._env.num_sharding_groups() + self._env.remap_rank( + info.param_sharding.ranks[0], ShardingType.TABLE_WISE # pyre-ignore[16] + ) if self._is_2D_parallel else info.param_sharding.ranks[0] ) diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 879e3a3c7..b9f03bd5b 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -149,6 +149,7 @@ def _test_sharding( global_constant_batch: bool = False, pooling: PoolingType = PoolingType.SUM, data_type: DataType = DataType.FP32, + use_inter_host_allreduce: bool = False, ) -> None: self._build_tables_and_groups(data_type=data_type) self._run_multi_process_test( @@ -170,6 +171,7 @@ def _test_sharding( apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=global_constant_batch, + use_inter_host_allreduce=use_inter_host_allreduce, ) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 48b9a90ab..13d16de96 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -315,6 +315,7 @@ def sharding_single_rank_test( global_constant_batch: bool = False, world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, + use_inter_host_allreduce: bool = False, input_type: str = "kjt", # "kjt" or "td" ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: @@ -432,6 +433,7 @@ def sharding_single_rank_test( plan=plan, sharders=sharders, device=ctx.device, + use_inter_host_allreduce=use_inter_host_allreduce, ) else: local_model = DistributedModelParallel( diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py index 4215d5dbc..a7755695f 100644 --- a/torchrec/distributed/tests/test_2d_sharding.py +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -78,6 +78,7 @@ def setUp(self, backend: str = "nccl") -> None: ] ), pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) def test_sharding_cw_2D( @@ -89,6 +90,7 @@ def test_sharding_cw_2D( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], pooling: PoolingType, + use_inter_host_allreduce: bool, ) -> None: if ( self.device == torch.device("cpu") @@ -122,6 +124,7 @@ def test_sharding_cw_2D( backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, ) @unittest.skipIf( @@ -164,6 +167,7 @@ def test_sharding_cw_2D( ] ), pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) def test_sharding_tw_2D( @@ -175,6 +179,7 @@ def test_sharding_tw_2D( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], pooling: PoolingType, + use_inter_host_allreduce: bool, ) -> None: if ( self.device == torch.device("cpu") @@ -209,6 +214,7 @@ def test_sharding_tw_2D( backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, ) @unittest.skipIf( @@ -251,6 +257,7 @@ def test_sharding_tw_2D( ] ), pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) def test_sharding_grid_2D( @@ -262,6 +269,7 @@ def test_sharding_grid_2D( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], pooling: PoolingType, + use_inter_host_allreduce: bool, ) -> None: if ( self.device == torch.device("cpu") @@ -318,6 +326,7 @@ def test_sharding_grid_2D( backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, ) @unittest.skipIf( @@ -357,6 +366,7 @@ def test_sharding_grid_2D( ), variable_batch_size=st.booleans(), pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) def test_sharding_rw_2D( @@ -369,6 +379,7 @@ def test_sharding_rw_2D( ], variable_batch_size: bool, pooling: PoolingType, + use_inter_host_allreduce: bool, ) -> None: if self.backend == "gloo": self.skipTest( @@ -401,6 +412,7 @@ def test_sharding_rw_2D( apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, ) @unittest.skipIf( @@ -443,6 +455,7 @@ def test_sharding_rw_2D( ] ), pooling=st.sampled_from([PoolingType.SUM]), + use_inter_host_allreduce=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) def test_sharding_twrw_2D( @@ -454,6 +467,7 @@ def test_sharding_twrw_2D( Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], pooling: PoolingType, + use_inter_host_allreduce: bool, ) -> None: if ( self.device == torch.device("cpu") @@ -488,4 +502,5 @@ def test_sharding_twrw_2D( backend=self.backend, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, + use_inter_host_allreduce=use_inter_host_allreduce, ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 8085c415f..d1d4aa953 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -879,6 +879,7 @@ def __init__( global_pg: dist.ProcessGroup, device_mesh: DeviceMesh, node_group_size: Optional[int] = None, + use_inter_host_allreduce: bool = False, ) -> None: assert device_mesh.ndim == 2, "DeviceMesh must be two dimensional!" self.world_size: int = dist.get_world_size(sharding_pg) @@ -892,6 +893,7 @@ def __init__( self.device_mesh: DeviceMesh = device_mesh self.node_group_size: Optional[int] = node_group_size self.output_dtensor: bool = True + self.use_inter_host_allreduce: bool = use_inter_host_allreduce def num_sharding_groups(self) -> int: """ @@ -899,6 +901,32 @@ def num_sharding_groups(self) -> int: """ return self.global_world_size // self.world_size + def remap_rank(self, rank: int, sharding_type: ShardingType) -> int: + """ + Remap from current rank to the appropriate rank in a continuous [0, ..., world size] array for the given sharding type. + + Args: + rank (int): rank to remap. + sharding_type (ShardingType): sharding type to remap to. + + Returns: + int: remapped rank. + """ + if sharding_type in ( + ShardingType.COLUMN_WISE, + ShardingType.TABLE_WISE, + ShardingType.GRID_SHARD, + ): + return ( + rank % self.world_size + if self.use_inter_host_allreduce + else rank // self.num_sharding_groups() + ) + else: + raise ValueError( + f"Do not need 2D specific remapping logic for sharding type: {sharding_type}" + ) + class NullShardingContext(Multistreamable): def record_stream(self, stream: torch.Stream) -> None: