add custom all reduce support for 2D parallel (#2758)
Summary:

Add support for a user-defined all reduce function for embedding weight and optimizer sync.

Differential Revision: D69990461
iamzainhuda authored and facebook-github-bot committed Feb 21, 2025
1 parent e00868c commit 5d25f8e
Showing 4 changed files with 130 additions and 21 deletions.
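For context, a minimal usage sketch of the new constructor argument (not part of this commit). Although the parameter is annotated Callable[[torch.Tensor], torch.Tensor], _allreduce_tensors invokes it with a list of tensors (hence the pyre-ignore in the diff), and the test below registers a Callable[[List[torch.Tensor]], None]. Everything outside the DMPCollection keyword arguments shown in this diff (model, sharding_group_size, global_pg, plan, sharders, replica_pg) is a placeholder.

from typing import List

import torch
import torch.distributed as dist

from torchrec.distributed.model_parallel import DMPCollection


def make_avg_all_reduce(replica_pg: dist.ProcessGroup):
    # Mirrors the default path in sync(): a coalesced all reduce with
    # ReduceOp.AVG over the replica process group.
    def _all_reduce(tensors: List[torch.Tensor]) -> None:
        opts = dist.AllreduceCoalescedOptions()
        opts.reduceOp = dist.ReduceOp.AVG
        replica_pg.allreduce_coalesced(tensors, opts=opts).wait()

    return _all_reduce


dmp = DMPCollection(
    module=model,                              # placeholder: an unsharded TorchRec model
    sharding_group_size=sharding_group_size,   # placeholder
    world_size=dist.get_world_size(),
    global_pg=global_pg,                       # placeholder: the global ProcessGroup
    plan=plan,                                 # placeholder: sharding plan
    sharders=sharders,                         # placeholder
    device=torch.device("cuda"),
    custom_all_reduce=make_avg_all_reduce(replica_pg),  # replica_pg: placeholder
)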
66 changes: 51 additions & 15 deletions torchrec/distributed/model_parallel.py
@@ -11,7 +11,7 @@
import copy
import logging as logger
from collections import OrderedDict
from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Type
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type

import torch
import torch.distributed as dist
@@ -691,6 +691,7 @@ def __init__(
init_parameters: bool = True,
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
use_inter_host_allreduce: bool = False,
custom_all_reduce: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> None:
assert device.type == "cuda", "DMPCollection only supports CUDA"
self._device = device
@@ -700,6 +701,9 @@ def __init__(
self._sharding_pg: dist.ProcessGroup = None # pyre-ignore[8]
self._replica_pg: dist.ProcessGroup = None # pyre-ignore[8]
self._global_rank: int = dist.get_rank(global_pg)
self._custom_all_reduce: Optional[Callable[[torch.Tensor], torch.Tensor]] = (
custom_all_reduce
)

self._device_mesh, self._sharding_pg, self._replica_pg = (
self._create_process_groups(
@@ -748,29 +752,61 @@ def sync(self, include_optimizer_state: bool = True) -> None:
include_optimizer_state (bool): Flag to include optimizer state syncing upon call
"""
assert self._replica_pg is not None, "replica_pg is not initialized!"
opts = dist.AllreduceCoalescedOptions()
opts.reduceOp = dist.ReduceOp.AVG
all_weights = [
all_weights: List[torch.Tensor] = [
w
for emb_kernel in self._modules_to_sync
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for w in emb_kernel.split_embedding_weights()
]
handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
handle.wait()

opts = None
if self._custom_all_reduce is None:
opts = dist.AllreduceCoalescedOptions()
opts.reduceOp = dist.ReduceOp.AVG
self._allreduce_tensors(all_weights, opts)

if include_optimizer_state:
# Sync accumulated square of grad of local optimizer shards
optim_list = []
optimizer_tensors = []
for emb_kernel in self._modules_to_sync:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
all_optimizer_states = emb_kernel.get_optimizer_state()
momentum1 = [optim["sum"] for optim in all_optimizer_states]
optim_list.extend(momentum1)
# Some optimizers do not have states to sync; check that states exist before the collective call
if optim_list:
handle = self._replica_pg.allreduce_coalesced(optim_list, opts=opts)
handle.wait()
optimizer_states = emb_kernel.get_optimizer_state()
optimizer_tensors.extend([state["sum"] for state in optimizer_states])
if optimizer_tensors:
self._allreduce_tensors(optimizer_tensors, opts)

def _allreduce_tensors(
self,
tensors: List[torch.Tensor],
opts: Optional[dist.AllreduceCoalescedOptions] = None,
) -> None:
"""
Helper to perform all reduce on the given tensors; uses the custom all reduce function if provided
"""
if self._custom_all_reduce is not None:
# pyre-ignore[6]
self._custom_all_reduce(tensors)
else:
handle = self._replica_pg.allreduce_coalesced(tensors, opts=opts)
handle.wait()

def set_all_reduce_hook(
self, reduce_hook: Callable[[torch.Tensor], torch.Tensor]
) -> None:
"""
Registers a custom all reduce function to use in place of the default. Users can
alternatively pass the custom all reduce function through the constructor. The hook
is expected to handle the distributed communication call, the associated process
group, and any additional details.
Args:
reduce_hook (Callable[[torch.Tensor], torch.Tensor]): The custom all reduce function to use for
embedding weights and optimizer states
"""
if self._custom_all_reduce is not None:
logger.warning(
"[TorchRec 2D Parallel] Custom all reduce function already defined, overriding with new callable"
)
self._custom_all_reduce = reduce_hook

def _create_process_groups(
self,
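The same callable can also be registered after construction via the new set_all_reduce_hook method; a constructor-supplied callable is overridden with a warning, and when no hook is set sync() falls back to allreduce_coalesced with ReduceOp.AVG. A short sketch, reusing the hypothetical make_avg_all_reduce helper and placeholder objects from the sketch above:

# Sketch only: dmp and replica_pg are the placeholders introduced earlier.
dmp.set_all_reduce_hook(make_avg_all_reduce(replica_pg))
# Embedding weights and the optimizer "sum" states are routed through the hook.
dmp.sync(include_optimizer_state=True)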
2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -151,6 +151,7 @@ def _test_sharding(
data_type: DataType = DataType.FP32,
use_inter_host_allreduce: bool = False,
allow_zero_batch_size: bool = False,
custom_all_reduce: bool = False,
) -> None:
self._build_tables_and_groups(data_type=data_type)
self._run_multi_process_test(
@@ -174,6 +175,7 @@
global_constant_batch=global_constant_batch,
use_inter_host_allreduce=use_inter_host_allreduce,
allow_zero_batch_size=allow_zero_batch_size,
custom_all_reduce=custom_all_reduce,
)


68 changes: 62 additions & 6 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -9,13 +9,24 @@

import random
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Protocol, Tuple, Type, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Protocol,
Tuple,
Type,
Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch.distributed._tensor import DTensor
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.optim import (
_apply_optimizer_in_backward as apply_optimizer_in_backward,
)
@@ -314,11 +325,12 @@ def sharding_single_rank_test(
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
variable_batch_per_feature: bool = False, # VBE
global_constant_batch: bool = False,
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
use_inter_host_allreduce: bool = False,
world_size_2D: Optional[int] = None, # 2D parallel
node_group_size: Optional[int] = None, # 2D parallel
use_inter_host_allreduce: bool = False, # 2D parallel
input_type: str = "kjt", # "kjt" or "td"
allow_zero_batch_size: bool = False,
custom_all_reduce: bool = False, # 2D parallel
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
batch_size = (
@@ -428,17 +440,36 @@
)

assert ctx.pg is not None
hook_called: bool = False
if world_size_2D is not None:
all_reduce_func = None
if custom_all_reduce:
all_reduce_pg: dist.ProcessGroup = create_device_mesh_for_2D(
use_inter_host_allreduce,
world_size=ctx.world_size,
local_size=world_size_2D,
).get_group(mesh_dim="replicate")

def _custom_hook(input: List[torch.Tensor]) -> None:
nonlocal hook_called
opts = dist.AllreduceCoalescedOptions()
opts.reduceOp = dist.ReduceOp.AVG
all_reduce_pg.allreduce_coalesced(input, opts=opts)
hook_called = True

all_reduce_func = _custom_hook

local_model = DMPCollection(
module=local_model,
sharding_group_size=world_size_2D,
world_size=ctx.world_size,
global_pg=ctx.pg,
global_pg=ctx.pg, # pyre-ignore[6]
node_group_size=node_group_size,
plan=plan,
sharders=sharders,
device=ctx.device,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=all_reduce_func, # pyre-ignore[6]
)
else:
local_model = DistributedModelParallel(
@@ -469,6 +500,9 @@ def sharding_single_rank_test(
local_input,
)

if world_size_2D is not None and custom_all_reduce:
assert hook_called, "custom all reduce hook was not called"

# TODO: support non-sharded forward with zero batch size KJT
if not allow_zero_batch_size:
all_local_pred = []
@@ -501,6 +535,28 @@ def sharding_single_rank_test(
)


def create_device_mesh_for_2D(
use_inter_host_allreduce: bool, world_size: int, local_size: int
) -> DeviceMesh:
if use_inter_host_allreduce:
peer_matrix = [
list(range(i, i + local_size)) for i in range(0, world_size, local_size)
]
else:
peer_matrix = []
step = world_size // local_size
for group_rank in range(world_size // local_size):
peer_matrix.append([step * r + group_rank for r in range(local_size)])

mesh = DeviceMesh(
device_type="cuda",
mesh=peer_matrix,
mesh_dim_names=("replicate", "shard"),
)

return mesh


def gen_full_pred_after_one_step(
model: nn.Module,
opt: torch.optim.Optimizer,
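As a point of reference for the test helper above, a small worked example (sizes assumed, not from the diff) of the peer matrix that create_device_mesh_for_2D builds. The test then obtains the replica process group for the custom hook via get_group(mesh_dim="replicate") on the resulting mesh.

# Worked example with assumed sizes: world_size=8, local_size=4.
world_size, local_size = 8, 4

# use_inter_host_allreduce=True: consecutive ranks form each row.
peer_matrix_inter = [
    list(range(i, i + local_size)) for i in range(0, world_size, local_size)
]
assert peer_matrix_inter == [[0, 1, 2, 3], [4, 5, 6, 7]]

# use_inter_host_allreduce=False: rows are strided by world_size // local_size.
step = world_size // local_size
peer_matrix_default = [
    [step * r + group_rank for r in range(local_size)]
    for group_rank in range(world_size // local_size)
]
assert peer_matrix_default == [[0, 2, 4, 6], [1, 3, 5, 7]]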
15 changes: 15 additions & 0 deletions torchrec/distributed/tests/test_2d_sharding.py
@@ -87,6 +87,7 @@ def setUp(self, backend: str = "nccl") -> None:
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
custom_all_reduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_cw_2D(
@@ -99,6 +100,7 @@ def test_sharding_cw_2D(
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
custom_all_reduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -133,6 +135,7 @@ def test_sharding_cw_2D(
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=custom_all_reduce,
)

@unittest.skipIf(
@@ -176,6 +179,7 @@ def test_sharding_cw_2D(
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
custom_all_reduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_tw_2D(
@@ -188,6 +192,7 @@ def test_sharding_tw_2D(
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
custom_all_reduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -223,6 +228,7 @@ def test_sharding_tw_2D(
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=custom_all_reduce,
)

@unittest.skipIf(
@@ -266,6 +272,7 @@ def test_sharding_tw_2D(
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
custom_all_reduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_grid_2D(
@@ -278,6 +285,7 @@
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
custom_all_reduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -335,6 +343,7 @@
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=custom_all_reduce,
)

@unittest.skipIf(
@@ -375,6 +384,7 @@ def test_sharding_grid_2D(
variable_batch_size=st.booleans(),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
custom_all_reduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_rw_2D(
@@ -388,6 +398,7 @@ def test_sharding_rw_2D(
variable_batch_size: bool,
pooling: PoolingType,
use_inter_host_allreduce: bool,
custom_all_reduce: bool,
) -> None:
if self.backend == "gloo":
self.skipTest(
@@ -421,6 +432,7 @@ def test_sharding_rw_2D(
variable_batch_size=variable_batch_size,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=custom_all_reduce,
)

@unittest.skipIf(
@@ -464,6 +476,7 @@ def test_sharding_rw_2D(
),
pooling=st.sampled_from([PoolingType.SUM]),
use_inter_host_allreduce=st.booleans(),
custom_all_reduce=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_sharding_twrw_2D(
@@ -476,6 +489,7 @@ def test_sharding_twrw_2D(
],
pooling: PoolingType,
use_inter_host_allreduce: bool,
custom_all_reduce: bool,
) -> None:
if (
self.device == torch.device("cpu")
@@ -511,6 +525,7 @@ def test_sharding_twrw_2D(
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
use_inter_host_allreduce=use_inter_host_allreduce,
custom_all_reduce=custom_all_reduce,
)


