change sharding topology for inter host all reduce (#2740)
Summary:
Pull Request resolved: #2740

Users can now change the network topology so that the all reduce is inter host. The design uses the ShardingEnv to inform rank placements, since the new topology changes how ranks are remapped back.
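To make the topology change concrete, here is a minimal sketch (illustration only, not part of this commit) assuming 2 hosts with 4 GPUs each, i.e. world_size=8 and sharding_group_size=4, with ranks 0-3 on host 0 and ranks 4-7 on host 1:

# Illustration only, not part of this commit. Assumes 2 hosts x 4 GPUs:
# world_size=8, sharding_group_size=4, ranks 0-3 on host 0, ranks 4-7 on host 1.
world_size, sharding_group_size = 8, 4
step = world_size // sharding_group_size  # number of replica (all reduce) groups

# Default topology: sharding groups are strided across hosts, so the replica
# groups (mesh columns) stay within a host -> the all reduce is intra host.
default_peer_matrix = [
    [step * r + group_rank for r in range(sharding_group_size)]
    for group_rank in range(step)
]
assert default_peer_matrix == [[0, 2, 4, 6], [1, 3, 5, 7]]  # replica pairs (0,1), (2,3), ...

# use_inter_host_allreduce=True: sharding groups are contiguous, so each
# sharding group stays within one host and the replica groups span hosts
# -> the all reduce is inter host, and TWRW/GRID get intra host sharding comms.
inter_host_peer_matrix = [
    list(range(i, i + sharding_group_size))
    for i in range(0, world_size, sharding_group_size)
]
assert inter_host_peer_matrix == [[0, 1, 2, 3], [4, 5, 6, 7]]  # replica pairs (0,4), (1,5), ...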

Future work: make this extensible so that users can define their own sharding topology and TorchRec can apply MP/DP across it.

Reviewed By: kausv

Differential Revision: D68729283

fbshipit-source-id: 696cd190ca0e599cec7519470283011a6b14ce77
iamzainhuda authored and facebook-github-bot committed Feb 12, 2025
1 parent 1afbf08 commit a48d0ff
Showing 9 changed files with 113 additions and 21 deletions.
11 changes: 8 additions & 3 deletions torchrec/distributed/comm.py
@@ -226,9 +226,14 @@ def intra_and_cross_node_pg_2D(

if _INTRA_PG_2D is None:
for group_rank in range(step):
sharding_pg_peers = [
step * r + group_rank for r in range(sharding_group_size)
]
if env.use_inter_host_allreduce:
# for inter host all reduce, we change the sharding group calculation to be continuous
ranks = group_rank * sharding_group_size
sharding_pg_peers = list(range(ranks, ranks + sharding_group_size))
else:
sharding_pg_peers = [
step * r + group_rank for r in range(sharding_group_size)
]
for group in range(len(sharding_pg_peers) // devices_per_node):
intra_pg_peers = sharding_pg_peers[
group * devices_per_node : (group + 1) * devices_per_node
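As a companion sketch to the intra_and_cross_node_pg_2D change above (again not part of the diff; world_size=16, sharding_group_size=8, devices_per_node=4 are assumed values), this shows how the sharding peer groups split into per-node intra groups under the two schemes:

# Assumed values for illustration; at runtime these come from the ShardingEnv
# and the torch.distributed setup.
world_size, sharding_group_size, devices_per_node = 16, 8, 4
step = world_size // sharding_group_size  # 2 sharding (model parallel) groups

for use_inter_host_allreduce in (False, True):
    for group_rank in range(step):
        if use_inter_host_allreduce:
            # contiguous ranks: the sharding group covers whole hosts
            start = group_rank * sharding_group_size
            sharding_pg_peers = list(range(start, start + sharding_group_size))
        else:
            # strided ranks: the sharding group takes every step-th rank
            sharding_pg_peers = [
                step * r + group_rank for r in range(sharding_group_size)
            ]
        # split each sharding group into per-node chunks for the intra-node PG
        intra_pg_peers = [
            sharding_pg_peers[g * devices_per_node : (g + 1) * devices_per_node]
            for g in range(len(sharding_pg_peers) // devices_per_node)
        ]
        print(use_inter_host_allreduce, group_rank, sharding_pg_peers, intra_pg_peers)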
62 changes: 50 additions & 12 deletions torchrec/distributed/model_parallel.py
@@ -690,6 +690,7 @@ def __init__(
init_data_parallel: bool = True,
init_parameters: bool = True,
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
use_inter_host_allreduce: bool = False,
) -> None:
assert device.type == "cuda", "DMPCollection only supports CUDA"
self._device = device
@@ -705,13 +706,16 @@ def __init__(
global_rank=self._global_rank,
world_size=world_size,
local_size=sharding_group_size,
use_inter_host_allreduce=use_inter_host_allreduce,
)
)

self._remap_sharding_plan(
plan=plan,
rank=self._global_rank,
num_nodes=world_size // sharding_group_size,
step=world_size // sharding_group_size,
sharding_group_size=sharding_group_size,
use_inter_host_allreduce=use_inter_host_allreduce,
)
super().__init__(
module,
@@ -720,6 +724,7 @@ def __init__(
sharding_pg=self._sharding_pg,
device_mesh=self._device_mesh,
node_group_size=node_group_size,
use_inter_host_allreduce=use_inter_host_allreduce,
),
device,
plan,
@@ -768,7 +773,11 @@ def sync(self, include_optimizer_state: bool = True) -> None:
handle.wait()

def _create_process_groups(
self, global_rank: int, world_size: int, local_size: int
self,
global_rank: int,
world_size: int,
local_size: int,
use_inter_host_allreduce: bool = False,
) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
"""
Creates process groups for sharding and replication, the process groups
@@ -784,17 +793,29 @@ def _create_process_groups(
replication process group, and allreduce process group.
"""
peer_matrix = []
num_nodes = world_size // local_size
mesh, sharding_pg, replica_pg = None, None, None

for group_rank in range(world_size // local_size):
peers = [num_nodes * r + group_rank for r in range(local_size)]
peer_matrix.append(peers)
logger.warning(f"[2D] Use inter host all reduce: {use_inter_host_allreduce}")

if use_inter_host_allreduce:
# We shard on a continuous set of ranks and nodes, thereby forcing the all reduce to be inter host.
# Under this scheme, sharding types such as TWRW and GRID take
# advantage of intra node comms as a result of the continuous set of ranks.
peer_matrix = [
list(range(i, i + local_size)) for i in range(0, world_size, local_size)
]
else:
step = world_size // local_size
for group_rank in range(world_size // local_size):
peers = [step * r + group_rank for r in range(local_size)]
peer_matrix.append(peers)

mesh = DeviceMesh(
device_type=self._device.type,
mesh=peer_matrix,
mesh_dim_names=("replicate", "shard"),
)

logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
sharding_pg = mesh.get_group(mesh_dim="shard")
logger.warning(
@@ -808,7 +829,12 @@ def _remap_sharding_plan(
return mesh, sharding_pg, replica_pg

def _remap_sharding_plan(
self, plan: ShardingPlan, rank: int, num_nodes: int
self,
plan: ShardingPlan,
rank: int,
step: int,
sharding_group_size: int,
use_inter_host_allreduce: bool = False,
) -> None:
"""
Remaps the sharding plan to the local replica process group ranks
@@ -822,20 +848,32 @@ def _remap_sharding_plan(
global_rank (int): The global rank of the current process.
num_nodes (int): The number of nodes.
"""

group_start = rank % num_nodes
group_start = rank % step
for key in plan.plan:
# pyre-ignore[16]
for _, param_sharding in plan.plan[key].items():
new_ranks = []
for shard_rank in param_sharding.ranks:
new_ranks.append(shard_rank * num_nodes + group_start)
if use_inter_host_allreduce:
group = rank // sharding_group_size
new_ranks = [
shard_rank + (group * sharding_group_size)
for shard_rank in param_sharding.ranks
]
else:
for shard_rank in param_sharding.ranks:
new_ranks.append(shard_rank * step + group_start)
param_sharding.ranks = new_ranks

if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
shards = param_sharding.sharding_spec.shards
if shards is not None:
for shard in shards:
shard_rank = shard.placement._rank * num_nodes + group_start
if use_inter_host_allreduce:
shard_rank = shard.placement._rank + (
(rank // sharding_group_size) * sharding_group_size
)
else:
shard_rank = shard.placement._rank * step + group_start
shard.placement = _remote_device(
f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
)
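A small sketch of the remapping arithmetic in _remap_sharding_plan above (sizes assumed: world_size=8, sharding_group_size=4; this mirrors the logic but is not the shipped code):

# Mirrors the remapping arithmetic in _remap_sharding_plan; assumed sizes only.
world_size, sharding_group_size = 8, 4
step = world_size // sharding_group_size  # number of replica groups


def remap(shard_rank: int, global_rank: int, use_inter_host_allreduce: bool) -> int:
    """Map a logical shard rank (0..sharding_group_size-1) to a global rank."""
    if use_inter_host_allreduce:
        # contiguous groups: offset by the start of this rank's sharding group
        group = global_rank // sharding_group_size
        return shard_rank + group * sharding_group_size
    # strided groups: stride by the number of replica groups, offset by group start
    group_start = global_rank % step
    return shard_rank * step + group_start


# Seen from global rank 5, logical shard rank 2 lands on:
assert remap(2, 5, use_inter_host_allreduce=False) == 5  # 2 * 2 + 1
assert remap(2, 5, use_inter_host_allreduce=True) == 6   # 2 + 1 * 4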
3 changes: 2 additions & 1 deletion torchrec/distributed/sharding/cw_sharding.py
@@ -45,6 +45,7 @@
QuantizedCommCodecs,
ShardedTensorMetadata,
ShardingEnv,
ShardingType,
ShardMetadata,
)
from torchrec.distributed.utils import none_throws
@@ -191,7 +192,7 @@ def _shard(
for i, rank in enumerate(info.param_sharding.ranks):
# Remap rank by number of replica groups if 2D parallelism is enabled
rank = (
rank // self._env.num_sharding_groups() # pyre-ignore[16]
self._env.remap_rank(rank, ShardingType.COLUMN_WISE) # pyre-ignore[16]
if self._is_2D_parallel
else rank
)
2 changes: 1 addition & 1 deletion torchrec/distributed/sharding/grid_sharding.py
@@ -250,7 +250,7 @@ def _shard(
# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
rank = (
rank // self._env.num_sharding_groups() # pyre-ignore[16]
self._env.remap_rank(rank, ShardingType.GRID_SHARD) # pyre-ignore[16]
if self._is_2D_parallel
else rank
)
9 changes: 5 additions & 4 deletions torchrec/distributed/sharding/tw_sharding.py
@@ -48,6 +48,7 @@
ShardedTensorMetadata,
ShardingEnv,
ShardingEnv2D,
ShardingType,
ShardMetadata,
)
from torchrec.distributed.utils import none_throws
@@ -128,7 +129,7 @@ def _shard(
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
if self._env.output_dtensor:
dtensor_metadata = DTensorMetadata(
mesh=(
self._env.device_mesh["replicate"] # pyre-ignore[16]
@@ -142,12 +143,12 @@
),
stride=info.param.stride(),
)
# to not pass onto TBE
info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

rank = (
# pyre-ignore [16]
info.param_sharding.ranks[0] // self._env.num_sharding_groups()
self._env.remap_rank(
info.param_sharding.ranks[0], ShardingType.TABLE_WISE # pyre-ignore[16]
)
if self._is_2D_parallel
else info.param_sharding.ranks[0]
)
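The new ShardingEnv2D.remap_rank(rank, sharding_type) call used here (and in cw_sharding.py and grid_sharding.py above) is defined in a changed file that did not render on this page. The following is only a plausible sketch, consistent with the old rank // num_sharding_groups() expression and with the remapping in model_parallel.py; it ignores the sharding_type argument, which the real implementation presumably uses:

# Hypothetical sketch of the remapping remap_rank might perform; NOT the
# actual ShardingEnv2D implementation, and it ignores sharding_type entirely.
def remap_rank_sketch(
    rank: int,
    sharding_group_size: int,
    num_sharding_groups: int,
    use_inter_host_allreduce: bool,
) -> int:
    """Map a plan rank (already remapped to a global rank) back to a rank local to the sharding group."""
    if use_inter_host_allreduce:
        # contiguous groups: the local rank is the offset within the group
        return rank % sharding_group_size
    # strided groups: divide out the number of replica groups, matching the old
    # `rank // self._env.num_sharding_groups()` expression
    return rank // num_sharding_groups


# Consistent with the example after _remap_sharding_plan above:
assert remap_rank_sketch(5, 4, 2, use_inter_host_allreduce=False) == 2
assert remap_rank_sketch(6, 4, 2, use_inter_host_allreduce=True) == 2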
2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -149,6 +149,7 @@ def _test_sharding(
global_constant_batch: bool = False,
pooling: PoolingType = PoolingType.SUM,
data_type: DataType = DataType.FP32,
use_inter_host_allreduce: bool = False,
) -> None:
self._build_tables_and_groups(data_type=data_type)
self._run_multi_process_test(
@@ -170,6 +171,7 @@
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=global_constant_batch,
use_inter_host_allreduce=use_inter_host_allreduce,
)


2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -315,6 +315,7 @@ def sharding_single_rank_test(
global_constant_batch: bool = False,
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
use_inter_host_allreduce: bool = False,
input_type: str = "kjt", # "kjt" or "td"
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
@@ -432,6 +433,7 @@ def sharding_single_rank_test(
plan=plan,
sharders=sharders,
device=ctx.device,
use_inter_host_allreduce=use_inter_host_allreduce,
)
else:
local_model = DistributedModelParallel(
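For completeness, enabling the new topology from user code is a single flag on DMPCollection, as the test helper above shows. A hedged usage sketch follows; the wrapper function, its inputs, and the concrete device choice are placeholders, and the exact constructor signature may differ from what is visible in this diff:

from typing import List

import torch
from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.types import ModuleSharder, ShardingPlan


def wrap_with_inter_host_allreduce(
    module: torch.nn.Module,
    plan: ShardingPlan,
    sharders: List[ModuleSharder[torch.nn.Module]],
    world_size: int,
    sharding_group_size: int,
) -> DMPCollection:
    """Placeholder helper: wrap an already-planned module with the new flag.

    Assumes torch.distributed is initialized and the ShardingPlan came from the
    usual planner flow; keyword names mirror the ones visible in this diff.
    """
    return DMPCollection(
        module=module,
        device=torch.device("cuda", torch.cuda.current_device()),
        plan=plan,
        sharders=sharders,
        world_size=world_size,
        sharding_group_size=sharding_group_size,
        use_inter_host_allreduce=True,  # new flag introduced by this PR
    )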
15 changes: 15 additions & 0 deletions torchrec/distributed/tests/test_2d_sharding.py
@@ -78,6 +78,7 @@ def setUp(self, backend: str = "nccl") -> None:
]
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_cw_2D(
@@ -89,6 +90,7 @@ def test_sharding_cw_2D(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -122,6 +124,7 @@ def test_sharding_cw_2D(
backend=self.backend,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
)

@unittest.skipIf(
@@ -164,6 +167,7 @@ def test_sharding_cw_2D(
]
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_tw_2D(
@@ -175,6 +179,7 @@ def test_sharding_tw_2D(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -209,6 +214,7 @@ def test_sharding_tw_2D(
backend=self.backend,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
)

@unittest.skipIf(
@@ -251,6 +257,7 @@ def test_sharding_tw_2D(
]
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_grid_2D(
@@ -262,6 +269,7 @@ def test_sharding_grid_2D(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -318,6 +326,7 @@ def test_sharding_grid_2D(
backend=self.backend,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
)

@unittest.skipIf(
@@ -357,6 +366,7 @@ def test_sharding_grid_2D(
),
variable_batch_size=st.booleans(),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_rw_2D(
@@ -369,6 +379,7 @@ def test_sharding_rw_2D(
],
variable_batch_size: bool,
pooling: PoolingType,
use_inter_host_allreduce: bool,
) -> None:
if self.backend == "gloo":
self.skipTest(
@@ -401,6 +412,7 @@ def test_sharding_rw_2D(
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
)

@unittest.skipIf(
@@ -443,6 +455,7 @@ def test_sharding_rw_2D(
]
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_twrw_2D(
@@ -454,6 +467,7 @@ def test_sharding_twrw_2D(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -488,4 +502,5 @@ def test_sharding_twrw_2D(
backend=self.backend,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
)