Allow passing planner to _shard_modules (#2732)
Summary:
Pull Request resolved: #2732

The `_shard_modules` function is used in the fx_traceability tests for the SDD and SemiSync pipelines. It currently builds a default ShardingPlanner and Topology, which hardcode the batch size (512) and the HBM memory limit (32 GB), respectively. This change allows callers to specify the ShardingPlanner and Topology so they more accurately reflect the machine's capabilities. The change is intentionally limited to the private `_shard_modules` and does not touch the public `shard_modules`, so the latter's contract is unchanged.
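For illustration, a caller could now supply a planner built around the actual machine limits instead of the defaults. A minimal sketch follows; the table config, batch size, and HBM capacity are placeholder values chosen for this example (not part of this diff), and it assumes `torch.distributed` has already been initialized as `_shard_modules` normally requires:

```python
import torch
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.shard import _shard_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Placeholder model: a single small embedding table on the meta device.
model = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ],
    device=torch.device("meta"),
)

# Hypothetical usage: build a planner whose topology matches the real
# machine instead of relying on the hardcoded defaults (batch size 512,
# 32 GB HBM). The concrete numbers here are illustrative.
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=2,
        compute_device="cuda",
        hbm_cap=80 * 1024**3,  # e.g. 80 GB of HBM per device
    ),
    batch_size=64,
)

sharded_model = _shard_modules(module=model, planner=planner)
```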

Reviewed By: sarckk

Differential Revision: D69163227

fbshipit-source-id: 42852df294787e4d64ff2ea81bb1d238b5ec16ab
che-sh authored and facebook-github-bot committed Feb 14, 2025
1 parent fd45bdc commit 7d161d9
Showing 1 changed file with 8 additions and 6 deletions.
torchrec/distributed/shard.py
@@ -204,6 +204,7 @@ def _shard_modules( # noqa: C901
     plan: Optional[ShardingPlan] = None,
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
     init_params: Optional[bool] = False,
+    planner: Optional[EmbeddingShardingPlanner] = None,
 ) -> nn.Module:
     """
     See shard_modules
@@ -238,13 +239,14 @@ def _shard_modules( # noqa: C901
         assert isinstance(
             env, ShardingEnv
         ), "Currently hybrid sharding only support use manual sharding plan"
-        planner = EmbeddingShardingPlanner(
-            topology=Topology(
-                local_world_size=get_local_size(env.world_size),
-                world_size=env.world_size,
-                compute_device=device.type,
+        if planner is None:
+            planner = EmbeddingShardingPlanner(
+                topology=Topology(
+                    local_world_size=get_local_size(env.world_size),
+                    world_size=env.world_size,
+                    compute_device=device.type,
+                )
             )
-        )
         pg = env.process_group
         if pg is not None:
             plan = planner.collective_plan(module, sharders, pg)
