From 7b64dea94fe217ba9acab2bff85abce963f1c81a Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Wed, 29 May 2024 12:42:27 -0700
Subject: [PATCH] Add methods to detach model from sparse data dist staged pipeline (#2049)

Summary:
Sparse data dist pipelining replaces the forward of the sharded trec modules
with `PipelinedForward` variants that use the context to fetch data for the
current rank. However, there are use cases where we want to perform a simple
forward pass on the trec sharded modules without using a pipeline (e.g. for
simple local debug evals during training). In such cases, it is useful to have
a way to detach the model from SDD pipelining and re-attach it later. A usage
sketch is included after the patch.

Reviewed By: zzzwen

Differential Revision: D57688338
---
 .../tests/test_train_pipelines.py             | 124 ++++++++++++++++++
 .../train_pipeline/train_pipelines.py         |   2 +-
 torchrec/distributed/train_pipeline/utils.py  |  85 +++++++++---
 3 files changed, 191 insertions(+), 20 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index fe9c71931..08892a6ab 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -983,3 +983,127 @@ def on_flush_end() -> None:
 
         # Flush end not called this time
         self.assertEqual(flush_end_called, 1)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_model_detach(self) -> None:
+        model = self._setup_model()
+
+        sharding_type = ShardingType.TABLE_WISE.value
+        fused_params = {}
+        kernel_type = EmbeddingComputeKernel.FUSED.value
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type
+        )
+
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        sdd = SparseDataDistUtil[ModelInput](
+            model=sharded_model_pipelined,
+            stream=torch.cuda.Stream(),
+            apply_jit=False,
+        )
+
+        pipeline_stages = [
+            PipelineStage(
+                name="data_copy",
+                runnable=partial(get_h2d_func, device=self.device),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.stream,
+                fill_callback=sdd.wait_sparse_data_dist,
+            ),
+        ]
+
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages,
+            compute_stream=torch.cuda.current_stream(),
+        )
+
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        for i in range(5):
+            batch = data[i]
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            model_in = pipeline.progress(dataloader)
+            optim_pipelined.zero_grad()
+            loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
+            loss_pred.backward()
+            optim_pipelined.step()
+
+            self.assertTrue(torch.equal(pred, pred_pipelined))
+
+        # Check internal states
+        ebcs = [
+            sharded_model_pipelined.module.sparse.ebc,
+            sharded_model_pipelined.module.sparse.weighted_ebc,
+        ]
+        for ebc in ebcs:
+            self.assertIsInstance(ebc.forward, PipelinedForward)
+        self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)
+
+        detached_model = sdd.detach()
+
+        # Check internal states
+        for ebc in ebcs:
+            self.assertNotIsInstance(ebc.forward, PipelinedForward)
+        self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 0)
+
+        # Check fwd of detached model is same as non-pipelined model
+        with torch.no_grad():
+            batch = data[5].to(self.device)
+            _, detached_out = detached_model(batch)
+            _, out = sharded_model(batch)
+            self.assertTrue(torch.equal(detached_out, out))
+
+        # Check that pipeline re-attaches the model again without issues
+        for i in range(5, 12):
+            batch = data[i]
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            model_in = pipeline.progress(dataloader)
+            optim_pipelined.zero_grad()
+            loss_pred, pred_pipelined = sharded_model_pipelined(model_in)
+            loss_pred.backward()
+            optim_pipelined.step()
+
+            self.assertTrue(torch.equal(pred, pred_pipelined))
+
+        for ebc in ebcs:
+            self.assertIsInstance(ebc.forward, PipelinedForward)
+        self.assertEqual(len(sharded_model_pipelined._forward_hooks.items()), 1)
+
+        # Check pipeline exhausted
+        preproc_input = pipeline.progress(dataloader)
+        self.assertIsNone(preproc_input)
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index be3581df5..1dca3c956 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -309,7 +309,7 @@ def _init_pipelined_modules(
             self.start_sparse_data_dist(batch, context)
             return
 
-        self._pipelined_modules, self._model = _rewrite_model(
+        self._pipelined_modules, self._model, _ = _rewrite_model(
             model=self._model,
             context=context,
             dist_stream=self._data_dist_stream,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index b91a8f2c5..34784272e 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -33,7 +33,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.fx.node import Node
 from torch.profiler import record_function
-from torchrec.distributed.dist_data import KJTAllToAll
+from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
 from torchrec.distributed.embedding_sharding import (
     FusedKJTListSplitsAwaitable,
     KJTListSplitsAwaitable,
@@ -671,6 +671,7 @@ def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> b
     return len(sharded_children) > 0
 
 
+# pyre-ignore[3]
 def _rewrite_model(  # noqa C901
     model: torch.nn.Module,
     context: TrainPipelineContext,
@@ -678,7 +679,7 @@ def _rewrite_model(  # noqa C901
     batch: Optional[In] = None,
     apply_jit: bool = False,
     pipelined_forward: Type[BaseForward] = PipelinedForward,
-) -> Tuple[List[ShardedModule], torch.nn.Module]:
+) -> Tuple[List[ShardedModule], torch.nn.Module, List[Callable[..., Any]]]:
     input_model = model
     # Get underlying nn.Module
     if isinstance(model, DistributedModelParallel):
@@ -714,6 +715,7 @@ def _rewrite_model(  # noqa C901
     # Select sharded modules, which are top-level in the forward call graph,
     # i.e. don't have input transformations, i.e. rely only on 'builtins.getattr'.
     pipelined_forwards = []
+    original_forwards = []
     for node in graph.nodes:
         if node.op == "call_module" and node.target in sharded_modules:
             total_num_args = len(node.args) + len(node.kwargs)
@@ -724,6 +726,7 @@ def _rewrite_model(  # noqa C901
             if num_found == total_num_args:
                 logger.info(f"Module '{node.target}'' will be pipelined")
                 child = sharded_modules[node.target]
+                original_forwards.append(child.forward)
                 child.forward = pipelined_forward(
                     node.target,
                     arg_info_list,
@@ -744,14 +747,17 @@ def _rewrite_model(  # noqa C901
     if isinstance(input_model, DistributedModelParallel):
         input_model.module = graph_model
 
-    return pipelined_forwards, input_model
+    return pipelined_forwards, input_model, original_forwards
 
 
-def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> None:
+def _override_input_dist_forwards(
+    pipelined_modules: List[ShardedModule],
+) -> List[Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]]:
     """
     Overrides each input dist forward to support fusing the splits collective.
 
     NOTE: this can only be called after the input dists are initialized.
     """
+    original_kjt_dist_forwards = []
     for module in pipelined_modules:
         for child_fqn, child_module in module.named_modules():
             if hasattr(child_module, "_has_uninitialized_input_dist"):
@@ -765,11 +771,13 @@ def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> Non
             for input_dist in child_module._input_dists:
                 if hasattr(input_dist, "_dist"):
                     assert isinstance(input_dist._dist, KJTAllToAll)
+                    original_kjt_dist_forwards.append(input_dist._dist.forward)
                     input_dist._dist.forward = KJTAllToAllForward(
                         pg=input_dist._dist._pg,
                         splits=input_dist._dist._splits,
                         stagger=input_dist._dist._stagger,
                     )
+    return original_kjt_dist_forwards
 
 
 def get_h2d_func(batch: In, device: torch.device) -> Pipelineable:
@@ -870,31 +878,70 @@ def __init__(
         self.context = TrainPipelineContext(version=0)
         self.initialized = False
         self._pipelined_modules: List[ShardedModule] = []
+        # pyre-ignore
+        self.fwd_hook = None
         # pyre-ignore
-        self.original_forward = self.model.forward
+        self._original_forwards: List[Callable[..., Any]] = []
+        self._original_kjt_dist_forwards: List[
+            Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
+        ] = []
 
-        def forward_hook(
-            module: torch.nn.Module,
-            input: Union[torch.Tensor, Tuple[torch.Tensor]],
-            output: Union[torch.Tensor, Tuple[torch.Tensor]],
-        ) -> None:
-            self.wait_sparse_data_dist()
+    def detach(self) -> torch.nn.Module:
+        """
+        Removes sparse data dist (SDD) pipelining from the model forward.
+        Returns the original model.
 
-        self.model.register_forward_hook(forward_hook)
+        To pipeline SDD again, call attach_model(model)
+        """
+        if self.initialized:
+            assert self.fwd_hook is not None
+            self.fwd_hook.remove()
+
+            kjt_dist_fwds = iter(self._original_kjt_dist_forwards)
+            for mod, original_fwd in zip(
+                self._pipelined_modules, self._original_forwards
+            ):
+                # pyre-ignore
+                mod.forward = original_fwd
+                for _, child_module in mod.named_modules():
+                    if not hasattr(child_module, "_input_dists"):
+                        continue
+                    for input_dist in child_module._input_dists:
+                        if hasattr(input_dist, "_dist"):
+                            input_dist._dist.forward = next(kjt_dist_fwds)
+
+        self.initialized = False
+        return self.model
 
     def start_sparse_data_dist(self, batch: In) -> In:
         if not self.initialized:
-            self._pipelined_modules, self.model = _rewrite_model(
-                model=self.model,
-                context=self.context,
-                dist_stream=self.stream,
-                batch=batch,
-                apply_jit=self.apply_jit,
+            # Step 1: Pipeline input dist in trec sharded modules
+            self._pipelined_modules, self.model, self._original_forwards = (
+                _rewrite_model(
+                    model=self.model,
+                    context=self.context,
+                    dist_stream=self.stream,
+                    batch=batch,
+                    apply_jit=self.apply_jit,
+                )
             )
             # initializes input dist, so we can override input dist forwards
             _start_data_dist(self._pipelined_modules, batch, self.context)
-            _override_input_dist_forwards(self._pipelined_modules)
+            self._original_kjt_dist_forwards = _override_input_dist_forwards(
+                self._pipelined_modules
+            )
+
+            # Step 2: Register post-forward hook to wait SDD
+            def forward_hook(
+                module: torch.nn.Module,
+                input: Union[torch.Tensor, Tuple[torch.Tensor]],
+                output: Union[torch.Tensor, Tuple[torch.Tensor]],
+            ) -> None:
+                self.wait_sparse_data_dist()
+
+            self.fwd_hook = self.model.register_forward_hook(forward_hook)
+
             self.initialized = True
         _start_data_dist(self._pipelined_modules, batch, self.context)
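
Usage sketch (illustrative only, not part of the diff): assuming the same
wiring as `test_model_detach` above, i.e. a `StagedTrainPipeline` whose
`start_sparse_data_dist` stage is backed by a `SparseDataDistUtil` named
`sdd`, a local debug eval between pipelined training steps could look like
the sketch below. The helper name `run_local_eval_then_resume`, the
`eval_batch` argument, and the exact import paths are assumptions made for
illustration, not part of this change.

    from typing import Any, Iterator

    import torch
    from torchrec.distributed.train_pipeline import StagedTrainPipeline
    from torchrec.distributed.train_pipeline.utils import SparseDataDistUtil


    def run_local_eval_then_resume(
        pipeline: StagedTrainPipeline,
        sdd: SparseDataDistUtil,
        pipelined_model: torch.nn.Module,
        dataloader: Iterator[Any],
        eval_batch: Any,
        device: torch.device,
    ) -> None:
        # detach() restores the original sharded-module forwards and removes
        # the post-forward hook that waits on sparse data dist.
        detached_model = sdd.detach()

        # Plain, non-pipelined forward pass for a quick local debug eval.
        with torch.no_grad():
            detached_model(eval_batch.to(device))

        # Resuming the staged pipeline re-attaches SDD pipelining: detach()
        # reset sdd.initialized, so the next start_sparse_data_dist stage
        # rewrites the sharded-module forwards again.
        model_in = pipeline.progress(dataloader)
        # Assumes the model returns (loss, pred), as the test model does.
        loss, _ = pipelined_model(model_in)
        loss.backward()

As in the test, re-attachment needs no explicit call: running the
`start_sparse_data_dist` stage again is sufficient, because `detach()` resets
`initialized` and the stage then re-runs `_rewrite_model` and
`_override_input_dist_forwards`.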