Refactor get_node_args and friends into a class (pytorch#2741)
Summary:

TorchRec's rewriting logic has gotten a bit hairy over the years. This sequence of changes refactors the rewrite logic to make it less convoluted and more maintainable.

This change: `_get_node_args` and related functions pass around a lot of "context" (train_pipeline_context, streams, etc.) that rarely or never changes, plus some "state" (model, pipelined_preprocs) that is accumulated during the run. Refactoring `_get_node_args` (and friends) into a class lets that context and state be passed to the class constructor once, which simplifies the call signatures considerably.
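The shape of this refactor can be sketched with a toy example. Everything below is hypothetical and stands in for the real code: `ToyContext`, `ToyNode`, `ToyNodeArgsHelper`, and the `module:` prefix convention are illustrative assumptions, not TorchRec's actual classes or argument-resolution logic. Only the overall pattern (five positional arguments on every call versus a constructor plus a one-argument method) mirrors the test change in this commit.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass
class ToyContext:
    """Hypothetical stand-in for the pipeline context (not TorchRec's class)."""

    version: int = 0


@dataclass
class ToyNode:
    """Hypothetical stand-in for a graph node: a name plus argument names."""

    name: str
    args: List[str]


# Before: context and accumulated state are threaded through every call.
def get_node_args_before(
    model: object,
    node: ToyNode,
    pipelined: Set[str],
    context: ToyContext,
    pipeline_flag: bool,
) -> Tuple[List[str], int]:
    # Toy rule (an assumption): args prefixed with "module:" are not "found".
    found = [a for a in node.args if not a.startswith("module:")]
    if pipeline_flag:
        pipelined.update(a for a in node.args if a.startswith("module:"))
    return found, len(found)


# After: rarely-changing context goes to the constructor; state lives on self.
class ToyNodeArgsHelper:
    def __init__(
        self, model: object, context: ToyContext, pipeline_flag: bool
    ) -> None:
        self._model = model
        self._context = context
        self._pipeline_flag = pipeline_flag
        self._pipelined: Set[str] = set()  # accumulated across calls

    def get_node_args(self, node: ToyNode) -> Tuple[List[str], int]:
        found = [a for a in node.args if not a.startswith("module:")]
        if self._pipeline_flag:
            self._pipelined.update(
                a for a in node.args if a.startswith("module:")
            )
        return found, len(found)


node = ToyNode("kjt", ["keys", "values", "module:weights"])
helper = ToyNodeArgsHelper(object(), ToyContext(), False)
args, num_found = helper.get_node_args(node)
```

With this layout, callers that process many nodes construct the helper once and pass only the node on each call, and the accumulated set no longer has to be created and handed in by the caller.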

Internal

Diff stack navigation:
1. D69292525 and below - before refactoring
2. D69438143 - Refactor get_node_args and friends into a class (**you are here**)
3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep
4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep
5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy

Differential Revision: D69438143
che-sh authored and facebook-github-bot committed Feb 20, 2025
1 parent 244d91f commit d73b968
Showing 2 changed files with 319 additions and 355 deletions.
@@ -24,9 +24,9 @@
 )
 from torchrec.distributed.train_pipeline.utils import (
     _build_args_kwargs,
-    _get_node_args,
     _rewrite_model,
     ArgInfo,
+    NodeArgsHelper,
     PipelinedForward,
     PipelinedPostproc,
     TrainPipelineContext,
@@ -367,10 +367,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
             {},
         )
 
-        num_found = 0
-        _, num_found = _get_node_args(
-            MagicMock(), kjt_node, set(), TrainPipelineContext(), False
-        )
+        node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
+
+        _, num_found = node_args_helper.get_node_args(kjt_node)
 
         # Weights is call_module node, so we should only find 2 args unmodified
         self.assertEqual(num_found, len(kjt_args) - 1)