Add methods to detach model from sparse data dist staged pipeline (#2049)

Summary:

Sparse data dist (SDD) pipelining replaces the forward of sharded TorchRec modules with `PipelinedForward` variants that use the pipeline context to fetch data for the current rank. However, there are use cases where we want to run a plain forward on the sharded modules without the pipeline (e.g. a quick local debug eval during training). For such cases, it is useful to be able to detach the model from SDD pipelining and re-attach it afterwards.
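
As a rough usage sketch, distilled from the test added in this commit (the sharded model, optimizer, device, dataloader, and eval batch are assumed to exist already and are placeholders; import paths are best-effort assumptions, not part of this change):

from functools import partial

import torch
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.train_pipeline import StagedTrainPipeline
from torchrec.distributed.train_pipeline.utils import (
    PipelineStage,
    SparseDataDistUtil,
    get_h2d_func,
)

# SDD utility wrapping the (placeholder) sharded_model; apply_jit=False as in the test.
sdd = SparseDataDistUtil[ModelInput](
    model=sharded_model,
    stream=torch.cuda.Stream(),
    apply_jit=False,
)

# Two-stage pipeline: H2D copy, then sparse data dist (mirrors the new test).
pipeline = StagedTrainPipeline(
    pipeline_stages=[
        PipelineStage(
            name="data_copy",
            runnable=partial(get_h2d_func, device=device),
            stream=torch.cuda.Stream(),
        ),
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.stream,
            fill_callback=sdd.wait_sparse_data_dist,
        ),
    ],
    compute_stream=torch.cuda.current_stream(),
)

# Pipelined training step: sharded module forwards are PipelinedForward variants.
model_in = pipeline.progress(dataloader)
loss, pred = sharded_model(model_in)
loss.backward()

# Detach for a plain, non-pipelined forward (e.g. a quick local debug eval).
detached_model = sdd.detach()
with torch.no_grad():
    _, eval_pred = detached_model(eval_batch)

# Resuming the pipeline re-attaches SDD pipelining on the next progress() call.
model_in = pipeline.progress(dataloader)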

Reviewed By: zzzwen

Differential Revision: D57688338
sarckk authored and facebook-github-bot committed May 29, 2024
1 parent 808c243 commit 7b64dea
Showing 3 changed files with 191 additions and 20 deletions.
124 changes: 124 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -983,3 +983,127 @@ def on_flush_end() -> None:

# Flush end not called this time
self.assertEqual(flush_end_called, 1)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_model_detach(self) -> None:
model = self._setup_model()

sharding_type = ShardingType.TABLE_WISE.value
fused_params = {}
kernel_type = EmbeddingComputeKernel.FUSED.value

sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)

(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type
)

copy_state_dict(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
apply_jit=False,
)

pipeline_stages = [
PipelineStage(
name="data_copy",
runnable=partial(get_h2d_func, device=self.device),
stream=torch.cuda.Stream(),
),
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]

pipeline = StagedTrainPipeline(
pipeline_stages=pipeline_stages,
compute_stream=torch.cuda.current_stream(),
)

data = self._generate_data(
num_batches=12,
batch_size=32,
)
dataloader = iter(data)

for i in range(5):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

model_in = pipeline.progress(dataloader)
optim_pipelined.zero_grad()
loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
loss_pred.backward()
optim_pipelined.step()

self.assertTrue(torch.equal(pred, pred_pipelined))

# Check internal states
ebcs = [
sharded_model_pipelined.module.sparse.ebc,
sharded_model_pipelined.module.sparse.weighted_ebc,
]
for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)

detached_model = sdd.detach()

# Check internal states
for ebc in ebcs:
self.assertNotIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 0)

# Check fwd of detached model is same as non-pipelined model
with torch.no_grad():
batch = data[5].to(self.device)
_, detached_out = detached_model(batch)
_, out = sharded_model(batch)
self.assertTrue(torch.equal(detached_out, out))

# Check that pipeline re-attaches the model again without issues
for i in range(5, 12):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

model_in = pipeline.progress(dataloader)
optim_pipelined.zero_grad()
loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
loss_pred.backward()
optim_pipelined.step()

self.assertTrue(torch.equal(pred, pred_pipelined))

for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)
self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)

# Check pipeline exhausted
preproc_input = pipeline.progress(dataloader)
self.assertIsNone(preproc_input)
2 changes: 1 addition & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -309,7 +309,7 @@ def _init_pipelined_modules(
self.start_sparse_data_dist(batch, context)
return

self._pipelined_modules, self._model = _rewrite_model(
self._pipelined_modules, self._model, _ = _rewrite_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
85 changes: 66 additions & 19 deletions torchrec/distributed/train_pipeline/utils.py
@@ -33,7 +33,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.fx.node import Node
from torch.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAll
from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
from torchrec.distributed.embedding_sharding import (
FusedKJTListSplitsAwaitable,
KJTListSplitsAwaitable,
@@ -671,14 +671,15 @@ def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> b
return len(sharded_children) > 0


# pyre-ignore[3]
def _rewrite_model( # noqa C901
model: torch.nn.Module,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
batch: Optional[In] = None,
apply_jit: bool = False,
pipelined_forward: Type[BaseForward] = PipelinedForward,
) -> Tuple[List[ShardedModule], torch.nn.Module]:
) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]:
input_model = model
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@@ -714,6 +715,7 @@ def _rewrite_model( # noqa C901
# Select sharded modules, which are top-level in the forward call graph,
# i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'.
pipelined_forwards = []
original_forwards = []
for node in graph.nodes:
if node.op == "call_module" and node.target in sharded_modules:
total_num_args = len(node.args) + len(node.kwargs)
@@ -724,6 +726,7 @@ def _rewrite_model( # noqa C901
if num_found == total_num_args:
logger.info(f"Module '{node.target}'' will be pipelined")
child = sharded_modules[node.target]
original_forwards.append(child.forward)
child.forward = pipelined_forward(
node.target,
arg_info_list,
@@ -744,14 +747,17 @@ def _rewrite_model( # noqa C901
if isinstance(input_model, DistributedModelParallel):
input_model.module = graph_model

return pipelined_forwards, input_model
return pipelined_forwards, input_model, original_forwards


def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> None:
def _override_input_dist_forwards(
pipelined_modules: List[ShardedModule],
) -> List[Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]]:
"""
Overrides each input dist forward to support fusing the splits collective.
NOTE: this can only be called after the input dists are initialized.
"""
original_kjt_dist_forwards = []
for module in pipelined_modules:
for child_fqn, child_module in module.named_modules():
if hasattr(child_module, "_has_uninitialized_input_dist"):
@@ -765,11 +771,13 @@ def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> Non
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
assert isinstance(input_dist._dist, KJTAllToAll)
original_kjt_dist_forwards.append(input_dist._dist.forward)
input_dist._dist.forward = KJTAllToAllForward(
pg=input_dist._dist._pg,
splits=input_dist._dist._splits,
stagger=input_dist._dist._stagger,
)
return original_kjt_dist_forwards


def get_h2d_func(batch: In, device: torch.device) -> Pipelineable:
@@ -870,31 +878,70 @@ def __init__(
self.context = TrainPipelineContext(version=0)
self.initialized = False
self._pipelined_modules: List[ShardedModule] = []
# pyre-ignore
self.fwd_hook = None

# pyre-ignore
self.original_forward = self.model.forward
self._original_forwards: List[Callable[..., Any]] = []
self._original_kjt_dist_forwards: List[
Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
] = []

def forward_hook(
module: torch.nn.Module,
input: Union[torch.Tensor, Tuple[torch.Tensor]],
output: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> None:
self.wait_sparse_data_dist()

self.model.register_forward_hook(forward_hook)

def detach(self) -> torch.nn.Module:
"""
Removes sparse data dist (SDD) pipelining from the model forward.
Returns the original model.

To pipeline SDD again, call attach_model(model)
"""
if self.initialized:
assert self.fwd_hook is not None
self.fwd_hook.remove()

kjt_dist_fwds = iter(self._original_kjt_dist_forwards)
for mod, original_fwd in zip(
self._pipelined_modules, self._original_forwards
):
# pyre-ignore
mod.forward = original_fwd
for _, child_module in mod.named_modules():
if not hasattr(child_module, "_input_dists"):
continue
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
input_dist._dist.forward = next(kjt_dist_fwds)

self.initialized = False
return self.model

def start_sparse_data_dist(self, batch: In) -> In:
if not self.initialized:
self._pipelined_modules, self.model = _rewrite_model(
model=self.model,
context=self.context,
dist_stream=self.stream,
batch=batch,
apply_jit=self.apply_jit,
# Step 1: Pipeline input dist in trec sharded modules
self._pipelined_modules, self.model, self._original_forwards = (
_rewrite_model(
model=self.model,
context=self.context,
dist_stream=self.stream,
batch=batch,
apply_jit=self.apply_jit,
)
)
# initializes input dist, so we can override input dist forwards
_start_data_dist(self._pipelined_modules, batch, self.context)
_override_input_dist_forwards(self._pipelined_modules)
self._original_kjt_dist_forwards = _override_input_dist_forwards(
self._pipelined_modules
)

# Step 2: Register post-forward hook to wait SDD
def forward_hook(
module: torch.nn.Module,
input: Union[torch.Tensor, Tuple[torch.Tensor]],
output: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> None:
self.wait_sparse_data_dist()

self.fwd_hook = self.model.register_forward_hook(forward_hook)

self.initialized = True

_start_data_dist(self._pipelined_modules, batch, self.context)
