From b750a40624b108c9322a95670136d239bd08637b Mon Sep 17 00:00:00 2001 From: Evgenii Kolpakov Date: Thu, 13 Feb 2025 23:02:10 -0800 Subject: [PATCH] Refactor monolithic ArgInfoStep into separate classes encoding different operations (#2744) Summary: Torchrec rewriting logic got a bit hairy over the years, this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future. This change: Splits monolithic ArgInfoStep into multiple classes, each handling single potential operation (+minimum data necessary to perform it). Internal Diff stack navigation: 1. D69292525 and below - before refactoring 2. D69438143 - Refactor get_node_args and friends into a class 3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep 4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep 5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy (**you are here**) Differential Revision: D69461228 --- .../tests/test_train_pipelines.py | 156 ++++++++--------- torchrec/distributed/train_pipeline/utils.py | 164 ++++++++++-------- 2 files changed, 165 insertions(+), 155 deletions(-) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 8b491666b..617fa1d7f 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -60,12 +60,14 @@ TrainPipelineSparseDistCompAutograd, ) from torchrec.distributed.train_pipeline.utils import ( - ArgInfoStep, DataLoadingThread, get_h2d_func, + GetAttrArgInfoStep, + GetItemArgInfoStep, + NoopArgInfoStep, PipelinedForward, - PipelinedPostproc, PipelineStage, + PostprocArgInfoStep, SparseDataDistUtil, StageOut, TrainPipelineContext, @@ -1023,51 +1025,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: pipelined_weighted_ebc = 
pipeline._pipelined_modules[1] # Check pipelined args - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: - self.assertEqual(len(ebc.forward._args.args), 1) - self.assertEqual(len(ebc.forward._args.kwargs), 0) - self.assertEqual(len(ebc.forward._args.args[0].steps), 2) - [step1, step2] = ebc.forward._args.args[0].steps - - self.assertEqual(step1.input_attr, "") - self.assertEqual(step1.is_getitem, False) - self.assertEqual(step2.input_attr, 0) - self.assertEqual(step2.is_getitem, True) - self.assertIsNotNone(step1.postproc_module) - self.assertIsNone(step2.postproc_module) - + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) self.assertEqual( - pipelined_ebc.forward._args.args[0].steps[0].postproc_module, - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `postproc_nonweighted`. - pipelined_model.module.postproc_nonweighted, + pipelined_ebc.forward._args.args[0].steps, + [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), + GetItemArgInfoStep(0), + ], ) + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) self.assertEqual( - pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module, - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `postproc_weighted`. - pipelined_model.module.postproc_weighted, + pipelined_weighted_ebc.forward._args.args[0].steps, + [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. 
+ PostprocArgInfoStep(pipelined_model.module.postproc_weighted), + GetItemArgInfoStep(0), + ], ) # postproc args self.assertEqual(len(pipeline._pipelined_postprocs), 2) - input_attr_names = {"idlist_features", "idscore_features"} - for i in range(len(pipeline._pipelined_postprocs)): - postproc_mod = pipeline._pipelined_postprocs[i] - self.assertEqual(len(postproc_mod._args.args), 1) - self.assertEqual(len(postproc_mod._args.kwargs), 0) - self.assertEqual(len(postproc_mod._args.args[0].steps), 2) - [step1, step2] = postproc_mod._args.args[0].steps - - self.assertTrue(step2.input_attr in input_attr_names) + # postprocs can be added in any order, so we can't assert on exact steps structures + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1) + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0) + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2) + self.assertEqual( + pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep() + ) + self.assertIsInstance( + pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep + ) - input_attr_names.remove(step2.input_attr) + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1) + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0) + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2) + self.assertEqual( + pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep() + ) + self.assertIsInstance( + pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep + ) - self.assertFalse(step1.is_getitem) - self.assertFalse(step2.is_getitem) - # no parent postproc module in FX graph - self.assertIsNone(step1.postproc_module) - self.assertIsNone(step2.postproc_module) + get_arg_infos = { + # pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep + postproc._args.args[0].steps[1].attr_name + for postproc in 
pipeline._pipelined_postprocs + } + self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"}) # pyre-ignore @unittest.skipIf( @@ -1113,37 +1120,31 @@ def test_pipeline_postproc_recursive(self) -> None: pipelined_weighted_ebc = pipeline._pipelined_modules[1] # Check pipelined args - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: - self.assertEqual(len(ebc.forward._args.args), 1) - self.assertEqual(len(ebc.forward._args.kwargs), 0) - self.assertEqual(len(ebc.forward._args.args[0].steps), 2) - [step1, step2] = ebc.forward._args.args[0].steps - - self.assertEqual(step1.input_attr, "") - self.assertEqual(step1.is_getitem, False) - self.assertEqual(step2.input_attr, 0) - self.assertEqual(step2.is_getitem, True) - self.assertIsNotNone(step1.postproc_module) - self.assertIsNone(step2.postproc_module) - + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) self.assertEqual( - pipelined_ebc.forward._args.args[0].steps[0].postproc_module, - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `postproc_nonweighted`. - pipelined_model.module.postproc_nonweighted, + pipelined_ebc.forward._args.args[0].steps, + [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), + GetItemArgInfoStep(0), + ], ) + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) self.assertEqual( - pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module, - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `postproc_weighted`. - pipelined_model.module.postproc_weighted, + pipelined_weighted_ebc.forward._args.args[0].steps, + [ + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. 
+ PostprocArgInfoStep(pipelined_model.module.postproc_weighted), + GetItemArgInfoStep(0), + ], ) # postproc args self.assertEqual(len(pipeline._pipelined_postprocs), 3) - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute - # `_postproc_module`. + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. parent_postproc_mod = pipelined_model.module._postproc_module for postproc_mod in pipeline._pipelined_postprocs: @@ -1152,39 +1153,30 @@ def test_pipeline_postproc_recursive(self) -> None: if postproc_mod == pipelined_model.module.postproc_nonweighted: self.assertEqual(len(postproc_mod._args.args), 1) self.assertEqual(len(postproc_mod._args.kwargs), 0) - args = postproc_mod._args.args[0] - self.assertEqual(len(args.steps), 2) - self.assertEqual( - [step.input_attr for step in args.steps], ["", "idlist_features"] - ) self.assertEqual( - [step.is_getitem for step in args.steps], [False, False] + postproc_mod._args.args[0].steps, + [ + PostprocArgInfoStep(parent_postproc_mod), + GetAttrArgInfoStep("idlist_features"), + ], ) - self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod) - self.assertIsNone(args.steps[1].postproc_module) + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_weighted`. 
elif postproc_mod == pipelined_model.module.postproc_weighted: self.assertEqual(len(postproc_mod._args.args), 1) self.assertEqual(len(postproc_mod._args.kwargs), 0) - args = postproc_mod._args.args[0] - self.assertEqual(len(args.steps), 2) - self.assertEqual( - [step.input_attr for step in args.steps], ["", "idscore_features"] - ) self.assertEqual( - [step.is_getitem for step in args.steps], [False, False] + postproc_mod._args.args[0].steps, + [ + PostprocArgInfoStep(parent_postproc_mod), + GetAttrArgInfoStep("idscore_features"), + ], ) - self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod) - self.assertIsNone(args.steps[1].postproc_module) elif postproc_mod == parent_postproc_mod: self.assertEqual(len(postproc_mod._args.args), 1) self.assertEqual(len(postproc_mod._args.kwargs), 0) - args = postproc_mod._args.args[0] - self.assertEqual(len(args.steps), 1) - self.assertEqual(args.steps[0].input_attr, "") - self.assertFalse(args.steps[0].is_getitem) - self.assertIsNone(args.steps[0].postproc_module) + self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()]) # pyre-ignore @unittest.skipIf( diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 7d29bb3d7..1e8803455 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +import abc import copy import itertools import logging @@ -167,59 +168,78 @@ class PipelineStage: fill_callback: Optional[Callable[[], None]] = None +class BaseArgInfoStep(abc.ABC): + @abc.abstractmethod + # pyre-ignore + def process(self, arg) -> Any: + raise Exception("Not implemented in the BaseArgInfoStep") + + @dataclass -class ArgInfoStep: - """ - Representation of args from a node. 
+class NoopArgInfoStep(BaseArgInfoStep): + # pyre-ignore + def process(self, arg) -> Any: + return arg - Attributes: - input_attr (str): accessing an attribute (or index) from an input, - is_getitem (bool): is this a getattr (output = foo.input_attr) or getitem (output = foo[input_attr]) call - postproc_module (Optional[PipelinedPostproc]): torch.nn.Modules that transform the input - output = postproc_module(input) - constant: no transformation, just constant argument - output = constant - """ - input_attr: str - is_getitem: bool - # recursive dataclass as postproc_modules.args -> arginfo.postproc_modules -> so on - postproc_module: Optional["PipelinedPostproc"] - constant: Optional[object] +@dataclass +class GetAttrArgInfoStep(BaseArgInfoStep): + attr_name: str - # pyre-ignore[3] - def process( - self, - arg, # pyre-ignore[2] - ) -> Any: - if self.constant is not None: - if isinstance(self.constant, list): - arg = [ - (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) - for v in self.constant - ] - elif isinstance(self.constant, dict): - arg = { - k: (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) - for k, v in self.constant.items() - } - else: - arg = self.constant - elif self.postproc_module is not None: - # postproc will internally run the same logic recursively - # if its args are derived from other postproc modules - # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it - arg = self.postproc_module(arg) - else: - if self.is_getitem: - arg = arg[self.input_attr] - elif self.input_attr != "": - arg = getattr(arg, self.input_attr) - else: - # neither is_getitem nor valid attr, no-op - arg = arg + # pyre-ignore + def process(self, arg) -> Any: + return getattr(arg, self.attr_name) - return arg + +@dataclass +class GetItemArgInfoStep(BaseArgInfoStep): + item_index: Union[str, int] + + # pyre-ignore + def process(self, arg) -> Any: + return arg[self.item_index] + + +@dataclass +class 
PostprocArgInfoStep(BaseArgInfoStep): + postproc_module: "PipelinedPostproc" + + # pyre-ignore + def process(self, arg) -> Any: + return self.postproc_module(arg) + + +@dataclass +class ScalarArgInfoStep(BaseArgInfoStep): + value: object + + # pyre-ignore + def process(self, _arg) -> Any: + return self.value + + +@dataclass +class ListArgInfoStep(BaseArgInfoStep): + value: List[object] + + # pyre-ignore + def process(self, arg) -> Any: + return [ + (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for v in self.value + ] + + +@dataclass +class DictArgInfoStep(BaseArgInfoStep): + value: Dict[str, object] + + # pyre-ignore + def process(self, arg) -> Any: + return { + k: (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for k, v in self.value.items() + } class ArgInfoStepFactory: @@ -229,32 +249,34 @@ class ArgInfoStepFactory: """ @classmethod - def noop(cls) -> ArgInfoStep: - return ArgInfoStep("", False, None, None) + def noop(cls) -> NoopArgInfoStep: + return NoopArgInfoStep() @classmethod - def get_attr(cls, name: str) -> ArgInfoStep: - return ArgInfoStep(name, False, None, None) + def get_attr(cls, name: str) -> GetAttrArgInfoStep: + return GetAttrArgInfoStep(name) @classmethod - def get_item(cls, index: str) -> ArgInfoStep: - return ArgInfoStep(index, True, None, None) + def get_item(cls, index: Union[str, int]) -> GetItemArgInfoStep: + return GetItemArgInfoStep(index) @classmethod - def postproc(cls, pipelined_postproc_module: "PipelinedPostproc") -> ArgInfoStep: - return ArgInfoStep("", False, pipelined_postproc_module, None) + def postproc( + cls, pipelined_postproc_module: "PipelinedPostproc" + ) -> PostprocArgInfoStep: + return PostprocArgInfoStep(pipelined_postproc_module) @classmethod - def from_scalar(cls, value: object) -> ArgInfoStep: - return ArgInfoStep("", False, None, value) + def from_scalar(cls, value: object) -> ScalarArgInfoStep: + return ScalarArgInfoStep(value) @classmethod - def from_list(cls, value: List[object]) -> 
ArgInfoStep: - return ArgInfoStep("", False, None, value) + def from_list(cls, value: List[object]) -> ListArgInfoStep: + return ListArgInfoStep(value) @classmethod - def from_dict(cls, value: Dict[str, object]) -> ArgInfoStep: - return ArgInfoStep("", False, None, value) + def from_dict(cls, value: Dict[str, object]) -> DictArgInfoStep: + return DictArgInfoStep(value) @dataclass @@ -267,16 +289,16 @@ class ArgInfo: Steps can be thought of consequtive transformations on the input, with output of previous step used as an input for the next. I.e. for 3 steps it is similar to step3(step2(step1(input))) - See `ArgInfoStep` docstring for details on which transformations can be applied + See `BaseArgInfoStep` class hierarchy for supported transformations """ - steps: List[ArgInfoStep] + steps: List[BaseArgInfoStep] - def add_step(self, step: ArgInfoStep) -> "ArgInfo": + def add_step(self, step: BaseArgInfoStep) -> "ArgInfo": self.steps.insert(0, step) return self - def append_step(self, step: ArgInfoStep) -> "ArgInfo": + def append_step(self, step: BaseArgInfoStep) -> "ArgInfo": self.steps.append(step) return self @@ -302,15 +324,10 @@ class CallArgs: def build_args_kwargs( self, initial_input: Any # pyre-ignore[2] ) -> Tuple[List[Any], Dict[str, Any]]: - """ - Creates args and kwargs for given CallArgs and input - """ args = [arg.process_steps(initial_input) for arg in self.args] - kwargs = { key: arg.process_steps(initial_input) for key, arg in self.kwargs.items() } - return args, kwargs @@ -1181,9 +1198,10 @@ def _get_node_args_helper_inner( This is for the PT2 export path where we unflatten the input to reconstruct the structure with the recorded tree spec. 
""" - assert arg_info.steps[0].is_getitem + step = arg_info.steps[0] + assert isinstance(step, GetItemArgInfoStep) # pyre-fixme[16] - arg = child_node.args[0][arg_info.step[0].input_attr] + arg = child_node.args[0][step.item_index] case ("torchrec.sparse.jagged_tensor", "KeyedJaggedTensor"): call_module_found = False