Skip to content

Commit

Permalink
Add APIs to detach/attach model to sdd pipeline (#2076)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2076

## Summary
Adds 2 new user-facing APIs for `TrainPipelineSparseDist`:
- `detach()` -> torch.nn.Module. Detaches original model so it can be used outside of current train pipeline.
- `attach(model: Optional[torch.nn.Module] = none)` -> None. Attaches model to pipeline (i.e. override trec module forward and input dist fwds). If no model specified, uses original model.

## Sample use cases:
- Bulk eval on trec sharded modules (e.g. ShardedEBC) after/during pipelined training. Currently this causes issues because the model forward is swapped with pipelined forward call.
- Train on one pipeline (e.g. full-sync SDD), then swap to another pipeline (e.g. semi-sync)
- Swap out model during training by calling `attach()` on another model (no current use case but is supported)

Reviewed By: joshuadeng

Differential Revision: D57882281

fbshipit-source-id: e4687bd59bfd419922cd22f646644f647d614d32
  • Loading branch information
sarckk authored and facebook-github-bot committed Jun 6, 2024
1 parent f0156b5 commit 27a1667
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 11 deletions.
197 changes: 197 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,203 @@ def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
self.assertEqual(pred_gpu.device, self.device)
self.assertEqual(pred_gpu.cpu().size(), pred.size())

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_model_detach_during_train(self) -> None:
"""
Test the scenario in which:
1) Model training with pipeline.progress()
2) Mid-training, model is detached
3) Check that fwd of detached model is same as non-pipelined model
4) Pipeline progress() re-attaches the model and we can continue progressing
"""
data = self._generate_data(
num_batches=7,
batch_size=32,
)
dataloader = iter(data)

sharding_type = ShardingType.TABLE_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
fused_params = {}

model = self._setup_model()
sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)

(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)
copy_state_dict(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
execute_all_batches=True,
)

for i in range(3):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

pred_pipelined = pipeline.progress(dataloader)
self.assertTrue(torch.equal(pred, pred_pipelined))

# Check internal states
ebcs = [
sharded_model_pipelined.module.sparse.ebc,
sharded_model_pipelined.module.sparse.weighted_ebc,
]
for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)

detached_model = pipeline.detach()

# Check internal states
for ebc in ebcs:
self.assertNotIsInstance(ebc.forward, PipelinedForward)

# Check fwd of detached model is same as non-pipelined model
with torch.no_grad():
batch = data[3].to(self.device)
_, detached_out = detached_model(batch)
_, out = sharded_model(batch)
self.assertTrue(torch.equal(detached_out, out))

# Check that pipeline re-attaches the model again without issues
for i in range(3, 7):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

pred_pipelined = pipeline.progress(dataloader)
self.assertTrue(torch.equal(pred, pred_pipelined))

for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)

# Check pipeline exhausted
self.assertRaises(StopIteration, pipeline.progress, dataloader)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_model_detach_after_train(self) -> None:
"""
Test the scenario in which:
1) Model training with pipeline.progress()
2) Pipeline exhausts dataloader and raises StopIteration
4) Model is detached
5) Check that fwd of detached model is same as non-pipelined model
6) Pipeline progress() with new dataloader re-attaches model
"""
data = self._generate_data(
num_batches=7,
batch_size=32,
)
dataloader = iter(data)

sharding_type = ShardingType.TABLE_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
fused_params = {}

model = self._setup_model()
sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)

(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)
copy_state_dict(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
execute_all_batches=True,
)

for i in range(7):
batch = data[i]
# Forward + backward w/o pipelining
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

pred_pipelined = pipeline.progress(dataloader)
self.assertTrue(torch.equal(pred, pred_pipelined))

# Check pipeline exhausted
self.assertRaises(StopIteration, pipeline.progress, dataloader)

detached_model = pipeline.detach()

# Check internal states
ebcs = [
sharded_model_pipelined.module.sparse.ebc,
sharded_model_pipelined.module.sparse.weighted_ebc,
]
for ebc in ebcs:
self.assertNotIsInstance(ebc.forward, PipelinedForward)

# Check fwd of detached model is same as non-pipelined model
with torch.no_grad():
for i in range(2):
batch = data[i].to(self.device)
_, detached_out = detached_model(batch)
_, out = sharded_model(batch)
self.assertTrue(torch.equal(detached_out, out))

# Provide new loaded dataloader and check model is re-attached
data = self._generate_data(
num_batches=4,
batch_size=32,
)
dataloader = iter(data)

for i in range(4):
batch = data[i]
batch = batch.to(self.device)
optim.zero_grad()
loss, pred = sharded_model(batch)
loss.backward()
optim.step()

pred_pipelined = pipeline.progress(dataloader)
self.assertTrue(torch.equal(pred, pred_pipelined))

# Check pipeline exhausted
self.assertRaises(StopIteration, pipeline.progress, dataloader)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
not torch.cuda.is_available(),
Expand Down
88 changes: 77 additions & 11 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
from collections import deque
from typing import (
Any,
Callable,
cast,
Deque,
Expand All @@ -25,9 +26,11 @@

import torch
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.model_parallel import ShardedModule
from torchrec.distributed.train_pipeline.utils import (
_override_input_dist_forwards,
_pipeline_detach_model,
_rewrite_model,
_start_data_dist,
_start_embedding_lookup,
Expand All @@ -50,12 +53,17 @@
)
from torchrec.distributed.types import Awaitable
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable


logger: logging.Logger = logging.getLogger(__name__)


class ModelDetachedException(Exception):
pass


class TrainPipeline(abc.ABC, Generic[In, Out]):
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
Expand Down Expand Up @@ -187,6 +195,15 @@ def __init__(
(torch.cuda.Stream(priority=-1)) if device.type == "cuda" else None
)

# pyre-ignore
self._original_forwards: List[Callable[..., Any]] = []

self._original_kjt_dist_forwards: List[
Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
] = []

self._model_attached = True

self._next_index: int = 0
self.contexts: Deque[TrainPipelineContext] = deque()
self._pipelined_modules: List[ShardedModule] = []
Expand All @@ -201,6 +218,42 @@ def __init__(
self._batch_ip2: Optional[In] = None
self._context: TrainPipelineContext = context_type(version=0)

def detach(self) -> torch.nn.Module:
"""
Detaches the model from sparse data dist (SDD) pipeline.
To use the pipeline after detaching the model, pipeline.attach(model)
needs to be called.
Inflight batches are kept so pipeline.progress(data_iter) can be resumed normally.
Returns the original model.
"""
if self._pipelined_modules:
_pipeline_detach_model(
pipelined_modules=self._pipelined_modules,
original_forwards=self._original_forwards,
original_kjt_dist_forwards=self._original_kjt_dist_forwards,
)

self._model_attached = False
return self._model

def attach(self, model: Optional[torch.nn.Module] = None) -> None:
if model:
self._model = model

self._model_attached = True
if self.contexts:
self._pipeline_model(
batch=self.batches[0],
context=self.contexts[0],
pipelined_forward=PipelinedForward,
)
else:
# attaching the model after end of train pipeline
# model rewrite for SDD needs context but self.contexts is empty
# reset _pipelined_modules so _fill_pipeline will rewrite model on progress()
self._pipelined_modules = []

def _set_module_context(self, context: TrainPipelineContext) -> None:
for module in self._pipelined_modules:
module.forward.set_context(context)
Expand Down Expand Up @@ -247,6 +300,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
return

def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._model_attached:
self.attach(self._model)

self.fill_pipeline(dataloader_iter)
if not self.batches:
raise StopIteration
Expand Down Expand Up @@ -293,6 +349,26 @@ def _create_context(self) -> TrainPipelineContext:
self._next_index += 1
return context

def _pipeline_model(
self,
batch: Optional[In],
context: TrainPipelineContext,
pipelined_forward: Type[PipelinedForward] = PipelinedForward,
) -> None:
self._pipelined_modules, self._model, self._original_forwards = _rewrite_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
batch=batch,
apply_jit=self._apply_jit,
pipelined_forward=pipelined_forward,
)
# initializes input dist, so we can override input dist forwards
self.start_sparse_data_dist(batch, context)
self._original_kjt_dist_forwards = _override_input_dist_forwards(
self._pipelined_modules
)

def _init_pipelined_modules(
self,
batch: In,
Expand All @@ -309,17 +385,7 @@ def _init_pipelined_modules(
self.start_sparse_data_dist(batch, context)
return

self._pipelined_modules, self._model, _ = _rewrite_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
batch=batch,
apply_jit=self._apply_jit,
pipelined_forward=pipelined_forward,
)
# initializes input dist, so we can override input dist forwards
self.start_sparse_data_dist(batch, context)
_override_input_dist_forwards(self._pipelined_modules)
self._pipeline_model(batch, context, pipelined_forward)

def copy_batch_to_gpu(
self,
Expand Down

0 comments on commit 27a1667

Please sign in to comment.