Skip to content

Commit

Permalink
Refactor monolithic ArgInfoStep into separate classes encoding differ…
Browse files Browse the repository at this point in the history
…ent operations (#2744)

Summary:

Torchrec rewriting logic got a bit hairy over the years, this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future.

This change: Splits monolithic ArgInfoStep into multiple classes, each handling single potential operation (+minimum data necessary to perform it).

Internal

Diff stack navigation:
1. D69292525 and below - before refactoring
2. D69438143 - Refactor get_node_args and friends into a class 
3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep
4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep 
5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy (**you are here**)

Differential Revision: D69461228
  • Loading branch information
che-sh authored and facebook-github-bot committed Feb 14, 2025
1 parent 3ee073d commit b750a40
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 155 deletions.
156 changes: 74 additions & 82 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.utils import (
ArgInfoStep,
DataLoadingThread,
get_h2d_func,
GetAttrArgInfoStep,
GetItemArgInfoStep,
NoopArgInfoStep,
PipelinedForward,
PipelinedPostproc,
PipelineStage,
PostprocArgInfoStep,
SparseDataDistUtil,
StageOut,
TrainPipelineContext,
Expand Down Expand Up @@ -1023,51 +1025,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
pipelined_weighted_ebc = pipeline._pipelined_modules[1]

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
pipelined_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
GetItemArgInfoStep(0),
],
)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
pipelined_weighted_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
GetItemArgInfoStep(0),
],
)

# postproc args
self.assertEqual(len(pipeline._pipelined_postprocs), 2)
input_attr_names = {"idlist_features", "idscore_features"}
for i in range(len(pipeline._pipelined_postprocs)):
postproc_mod = pipeline._pipelined_postprocs[i]
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(len(postproc_mod._args.args[0].steps), 2)
[step1, step2] = postproc_mod._args.args[0].steps

self.assertTrue(step2.input_attr in input_attr_names)
# postprocs can be added in any order, so we can't assert on exact steps structures
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep
)

input_attr_names.remove(step2.input_attr)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep
)

self.assertFalse(step1.is_getitem)
self.assertFalse(step2.is_getitem)
# no parent postproc module in FX graph
self.assertIsNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)
get_arg_infos = {
# pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep
postproc._args.args[0].steps[1].attr_name
for postproc in pipeline._pipelined_postprocs
}
self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"})

# pyre-ignore
@unittest.skipIf(
Expand Down Expand Up @@ -1113,37 +1120,31 @@ def test_pipeline_postproc_recursive(self) -> None:
pipelined_weighted_ebc = pipeline._pipelined_modules[1]

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
pipelined_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
GetItemArgInfoStep(0),
],
)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
pipelined_weighted_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
GetItemArgInfoStep(0),
],
)

# postproc args
self.assertEqual(len(pipeline._pipelined_postprocs), 3)

# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `_postproc_module`.
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
parent_postproc_mod = pipelined_model.module._postproc_module

for postproc_mod in pipeline._pipelined_postprocs:
Expand All @@ -1152,39 +1153,30 @@ def test_pipeline_postproc_recursive(self) -> None:
if postproc_mod == pipelined_model.module.postproc_nonweighted:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
[step.input_attr for step in args.steps], ["", "idlist_features"]
)
self.assertEqual(
[step.is_getitem for step in args.steps], [False, False]
postproc_mod._args.args[0].steps,
[
PostprocArgInfoStep(parent_postproc_mod),
GetAttrArgInfoStep("idlist_features"),
],
)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)

# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
elif postproc_mod == pipelined_model.module.postproc_weighted:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
[step.input_attr for step in args.steps], ["", "idscore_features"]
)
self.assertEqual(
[step.is_getitem for step in args.steps], [False, False]
postproc_mod._args.args[0].steps,
[
PostprocArgInfoStep(parent_postproc_mod),
GetAttrArgInfoStep("idscore_features"),
],
)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
elif postproc_mod == parent_postproc_mod:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 1)
self.assertEqual(args.steps[0].input_attr, "")
self.assertFalse(args.steps[0].is_getitem)
self.assertIsNone(args.steps[0].postproc_module)
self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()])

# pyre-ignore
@unittest.skipIf(
Expand Down
Loading

0 comments on commit b750a40

Please sign in to comment.