Add methods to detach model from sparse data dist staged pipeline
Summary: Sparse data dist (SDD) pipelining replaces the forward of sharded TorchRec modules with `PipelinedForward` variants that use the pipeline context to fetch data for the current rank. However, there are use cases where we want to run a plain forward on the sharded TorchRec modules without the pipeline (e.g. a simple local debug eval during training). For such cases, it is useful to be able to detach the model from SDD pipelining and re-attach it later.

Reviewed By: zzzwen

Differential Revision: D57688338
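
For illustration only (not part of this diff), a minimal sketch of the intended detach/re-attach flow around a local debug eval. The model and data variables (`sharded_model`, `train_dataloader`, `eval_batch`) and the import paths are assumptions that mirror the test added below:

import torch

# Import paths are indicative, based on the files touched in this commit.
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.train_pipeline.train_pipelines import StagedTrainPipeline
from torchrec.distributed.train_pipeline.utils import PipelineStage, SparseDataDistUtil

# Assumed to already exist: `sharded_model` (a sharded, DMP-wrapped model),
# `train_dataloader` (an iterator of ModelInput), `eval_batch` (a local ModelInput).
sdd = SparseDataDistUtil[ModelInput](
    model=sharded_model,
    stream=torch.cuda.Stream(),
    apply_jit=False,
)
pipeline = StagedTrainPipeline(
    pipeline_stages=[
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.stream,
            fill_callback=sdd.wait_sparse_data_dist,
        ),
    ],
    compute_stream=torch.cuda.current_stream(),
)

# Pipelined training step: sharded module forwards are PipelinedForward variants.
batch = pipeline.progress(train_dataloader)

# Detach for a simple local debug eval; the original forwards are restored and
# the pipeline refuses to progress until a model is attached again.
detached_model = sdd.detach_model()
with torch.no_grad():
    eval_out = detached_model(eval_batch)

# Re-attach to resume SDD pipelining.
sdd.attach_model(detached_model)
batch = pipeline.progress(train_dataloader)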
sarckk authored and facebook-github-bot committed May 28, 2024
1 parent 3ca8e8b commit 23dafde
Showing 3 changed files with 123 additions and 13 deletions.
76 changes: 76 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -959,3 +959,79 @@ def on_flush_end() -> None:

# Flush end not called this time
self.assertEqual(flush_end_called, 1)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_model_detach(self) -> None:
model = self._setup_model()

sharding_type = ShardingType.TABLE_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value

(
sharded_model_pipelined,
optim_pipelined,
) = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type
)

sdd = SparseDataDistUtil[ModelInput](
model=sharded_model_pipelined,
stream=torch.cuda.Stream(),
apply_jit=False,
)

pipeline_stages = [
PipelineStage(
name="data_copy",
runnable=partial(get_h2d_func, device=self.device),
stream=torch.cuda.Stream(),
),
PipelineStage(
name="start_sparse_data_dist",
runnable=sdd.start_sparse_data_dist,
stream=sdd.stream,
fill_callback=sdd.wait_sparse_data_dist,
),
]

pipeline = StagedTrainPipeline(
pipeline_stages=pipeline_stages,
compute_stream=torch.cuda.current_stream(),
)

data = self._generate_data(
num_batches=10,
batch_size=32,
)
dataloader = iter(data)

ebcs = [
sharded_model_pipelined.module.sparse.ebc,
sharded_model_pipelined.module.sparse.weighted_ebc,
]

pipeline.progress(dataloader)
for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)

detached_model = sdd.detach_model()
for ebc in ebcs:
self.assertNotIsInstance(ebc.forward, PipelinedForward)

self.assertEqual(len(detached_model._forward_hooks.items()), 0)

# Check that we cannot progress pipeline as model is not attached to the pipeline for SDD
error_msg = "No model attached! Call attach_model(model) with the model first."
with self.assertRaises(AssertionError, msg=error_msg):
pipeline.progress(dataloader)

# Check we can re-attach model and progress
sdd.attach_model(detached_model)
pipeline.progress(dataloader)
self.assertEqual(len(detached_model._forward_hooks.items()), 1)
for ebc in ebcs:
self.assertIsInstance(ebc.forward, PipelinedForward)
2 changes: 1 addition & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -309,7 +309,7 @@ def _init_pipelined_modules(
self.start_sparse_data_dist(batch, context)
return

self._pipelined_modules, self._model = _rewrite_model(
self._pipelined_modules, self._model, _ = _rewrite_model(
model=self._model,
context=context,
dist_stream=self._data_dist_stream,
58 changes: 46 additions & 12 deletions torchrec/distributed/train_pipeline/utils.py
@@ -671,14 +671,15 @@ def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> b
return len(sharded_children) > 0


# pyre-ignore[3]
def _rewrite_model( # noqa C901
model: torch.nn.Module,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
batch: Optional[In] = None,
apply_jit: bool = False,
pipelined_forward: Type[BaseForward] = PipelinedForward,
) -> Tuple[List[ShardedModule], torch.nn.Module]:
) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]:
input_model = model
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@@ -714,6 +715,7 @@ def _rewrite_model( # noqa C901
# Select sharded modules, which are top-level in the forward call graph,
# i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'.
pipelined_forwards = []
original_forwards = []
for node in graph.nodes:
if node.op == "call_module" and node.target in sharded_modules:
total_num_args = len(node.args) + len(node.kwargs)
@@ -724,6 +726,7 @@ def _rewrite_model( # noqa C901
if num_found == total_num_args:
logger.info(f"Module '{node.target}'' will be pipelined")
child = sharded_modules[node.target]
original_forwards.append(child.forward)
child.forward = pipelined_forward(
node.target,
arg_info_list,
@@ -744,7 +747,7 @@ def _rewrite_model( # noqa C901
if isinstance(input_model, DistributedModelParallel):
input_model.module = graph_model

return pipelined_forwards, input_model
return pipelined_forwards, input_model, original_forwards


def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> None:
Expand Down Expand Up @@ -864,15 +867,37 @@ def __init__(
apply_jit: bool = False,
) -> None:
super().__init__()
self.model = model
self.stream = stream
self.apply_jit = apply_jit
self.attach_model(model)

def detach_model(self) -> torch.nn.Module:
"""
Removes sparse data dist (SDD) pipelining from the model forward.
Returns the original model.
To pipeline SDD again, call attach_model(model)
"""
# Remove model-level fwd hook
self.fwd_hook.remove()

# Restore original fwd for trec sharded modules
for mod, original_fwd in zip(self._pipelined_modules, self._original_forwards):
# pyre-ignore
mod.forward = original_fwd

model = self.model
self.model = None
return model

def attach_model(self, model: torch.nn.Module) -> None:
self.model = model

self.context = TrainPipelineContext(version=0)
self.initialized = False
self._pipelined_modules: List[ShardedModule] = []

# pyre-ignore
self.original_forward = self.model.forward
self._original_forwards: List[Callable[..., Any]] = []

def forward_hook(
module: torch.nn.Module,
@@ -881,16 +906,24 @@ def forward_hook(
) -> None:
self.wait_sparse_data_dist()

self.model.register_forward_hook(forward_hook)
self.fwd_hook = self.model.register_forward_hook(forward_hook)

def assert_model_attached(self) -> None:
assert (
self.model is not None
), "No model attached! Call attach_model(model) with the model first."

def start_sparse_data_dist(self, batch: In) -> In:
self.assert_model_attached()
if not self.initialized:
self._pipelined_modules, self.model = _rewrite_model(
model=self.model,
context=self.context,
dist_stream=self.stream,
batch=batch,
apply_jit=self.apply_jit,
self._pipelined_modules, self.model, self._original_forwards = (
_rewrite_model(
model=self.model,
context=self.context,
dist_stream=self.stream,
batch=batch,
apply_jit=self.apply_jit,
)
)
# initializes input dist, so we can override input dist forwards
_start_data_dist(self._pipelined_modules, batch, self.context)
@@ -902,6 +935,7 @@ def start_sparse_data_dist(self, batch: In) -> In:
return batch

def wait_sparse_data_dist(self) -> None:
self.assert_model_attached()
with record_function("## wait_sparse_data_dist ##"):
with torch.cuda.stream(self.stream):
self.context.module_contexts = (
