diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py
index d7b38d2b0..15cbbe831 100644
--- a/torchrec/distributed/train_pipeline/__init__.py
+++ b/torchrec/distributed/train_pipeline/__init__.py
@@ -26,6 +26,7 @@
     _to_device,  # noqa
     _wait_for_batch,  # noqa
     ArgInfo,  # noqa
+    ArgInfoStep,  # noqa
     DataLoadingThread,  # noqa
     In,  # noqa
     Out,  # noqa
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index dc23593a0..844d3b500 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -60,6 +60,7 @@
     TrainPipelineSparseDistCompAutograd,
 )
 from torchrec.distributed.train_pipeline.utils import (
+    ArgInfoStep,
     DataLoadingThread,
     get_h2d_func,
     PipelinedForward,
@@ -1024,22 +1025,24 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
         # Check pipelined args
         for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
             self.assertEqual(len(ebc.forward._args), 1)
-            self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
-            self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
-            self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
-            self.assertIsInstance(
-                ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
-            )
-            self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
+            self.assertEqual(len(ebc.forward._args[0].steps), 2)
+            [step1, step2] = ebc.forward._args[0].steps
+
+            self.assertEqual(step1.input_attr, "")
+            self.assertEqual(step1.is_getitem, False)
+            self.assertEqual(step2.input_attr, 0)
+            self.assertEqual(step2.is_getitem, True)
+            self.assertIsNotNone(step1.postproc_module)
+            self.assertIsNone(step2.postproc_module)
 
         self.assertEqual(
-            pipelined_ebc.forward._args[0].postproc_modules[0],
+            pipelined_ebc.forward._args[0].steps[0].postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_nonweighted`.
             pipelined_model.module.postproc_nonweighted,
         )
         self.assertEqual(
-            pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
+            pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_weighted`.
             pipelined_model.module.postproc_weighted,
@@ -1051,15 +1054,18 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
         for i in range(len(pipeline._pipelined_postprocs)):
             postproc_mod = pipeline._pipelined_postprocs[i]
             self.assertEqual(len(postproc_mod._args), 1)
 
+            self.assertEqual(len(postproc_mod._args[0].steps), 2)
+            [step1, step2] = postproc_mod._args[0].steps
+
+            self.assertTrue(step2.input_attr in input_attr_names)
-            input_attr_name = postproc_mod._args[0].input_attrs[1]
-            self.assertTrue(input_attr_name in input_attr_names)
-            self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
-            input_attr_names.remove(input_attr_name)
+            input_attr_names.remove(step2.input_attr)
-            self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
+            self.assertFalse(step1.is_getitem)
+            self.assertFalse(step2.is_getitem)
 
             # no parent postproc module in FX graph
-            self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])
+            self.assertIsNone(step1.postproc_module)
+            self.assertIsNone(step2.postproc_module)
 
     # pyre-ignore
     @unittest.skipIf(
@@ -1107,22 +1113,24 @@ def test_pipeline_postproc_recursive(self) -> None:
         # Check pipelined args
         for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
             self.assertEqual(len(ebc.forward._args), 1)
-            self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
-            self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
-            self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
-            self.assertIsInstance(
-                ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
-            )
-            self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
+            self.assertEqual(len(ebc.forward._args[0].steps), 2)
+            [step1, step2] = ebc.forward._args[0].steps
+
+            self.assertEqual(step1.input_attr, "")
+            self.assertEqual(step1.is_getitem, False)
+            self.assertEqual(step2.input_attr, 0)
+            self.assertEqual(step2.is_getitem, True)
+            self.assertIsNotNone(step1.postproc_module)
+            self.assertIsNone(step2.postproc_module)
 
         self.assertEqual(
-            pipelined_ebc.forward._args[0].postproc_modules[0],
+            pipelined_ebc.forward._args[0].steps[0].postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_nonweighted`.
             pipelined_model.module.postproc_nonweighted,
         )
         self.assertEqual(
-            pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
+            pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_weighted`.
             pipelined_model.module.postproc_weighted,
@@ -1141,33 +1149,36 @@ def test_pipeline_postproc_recursive(self) -> None:
             if postproc_mod == pipelined_model.module.postproc_nonweighted:
                 self.assertEqual(len(postproc_mod._args), 1)
                 args = postproc_mod._args[0]
-                self.assertEqual(args.input_attrs, ["", "idlist_features"])
-                self.assertEqual(args.is_getitems, [False, False])
-                self.assertEqual(len(args.postproc_modules), 2)
+                self.assertEqual(len(args.steps), 2)
                 self.assertEqual(
-                    args.postproc_modules[0],
-                    parent_postproc_mod,
+                    [step.input_attr for step in args.steps], ["", "idlist_features"]
                 )
-                self.assertEqual(args.postproc_modules[1], None)
+                self.assertEqual(
+                    [step.is_getitem for step in args.steps], [False, False]
+                )
+                self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
+                self.assertIsNone(args.steps[1].postproc_module)
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_weighted`.
             elif postproc_mod == pipelined_model.module.postproc_weighted:
                 self.assertEqual(len(postproc_mod._args), 1)
                 args = postproc_mod._args[0]
-                self.assertEqual(args.input_attrs, ["", "idscore_features"])
-                self.assertEqual(args.is_getitems, [False, False])
-                self.assertEqual(len(args.postproc_modules), 2)
+                self.assertEqual(len(args.steps), 2)
+                self.assertEqual(
+                    [step.input_attr for step in args.steps], ["", "idscore_features"]
+                )
                 self.assertEqual(
-                    args.postproc_modules[0],
-                    parent_postproc_mod,
+                    [step.is_getitem for step in args.steps], [False, False]
                 )
-                self.assertEqual(args.postproc_modules[1], None)
+                self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
+                self.assertIsNone(args.steps[1].postproc_module)
             elif postproc_mod == parent_postproc_mod:
                 self.assertEqual(len(postproc_mod._args), 1)
                 args = postproc_mod._args[0]
-                self.assertEqual(args.input_attrs, [""])
-                self.assertEqual(args.is_getitems, [False])
-                self.assertEqual(args.postproc_modules, [None])
+                self.assertEqual(len(args.steps), 1)
+                self.assertEqual(args.steps[0].input_attr, "")
+                self.assertFalse(args.steps[0].is_getitem)
+                self.assertIsNone(args.steps[0].postproc_module)
 
     # pyre-ignore
     @unittest.skipIf(
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index 78cd516b7..42629640d 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -26,6 +26,7 @@
     _build_args_kwargs,
     _rewrite_model,
     ArgInfo,
+    ArgInfoStep,
     NodeArgsHelper,
     PipelinedForward,
     PipelinedPostproc,
@@ -110,7 +111,7 @@ def test_rewrite_model(self) -> None:
         self.assertEqual(
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `sparse`.
-            sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0],
+            sharded_model.module.sparse.ebc.forward._args[0].steps[0].postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_module`.
             sharded_model.module.postproc_module,
@@ -118,9 +119,9 @@ def test_rewrite_model(self) -> None:
         self.assertEqual(
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `sparse`.
-            sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[
-                0
-            ],
+            sharded_model.module.sparse.weighted_ebc.forward._args[0]
+            .steps[0]
+            .postproc_module,
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             #  `postproc_module`.
             sharded_model.module.postproc_module,
@@ -263,19 +264,18 @@ def test_restore_from_snapshot(self) -> None:
             [
                 # Empty attrs to ignore any attr based logic.
                 ArgInfo(
-                    input_attrs=[
-                        "",
+                    steps=[
+                        ArgInfoStep(
+                            input_attr="",
+                            is_getitem=False,
+                            postproc_module=None,
+                            constant=None,
+                        )
                     ],
-                    is_getitems=[False],
-                    postproc_modules=[None],
-                    constants=[None],
                     name="id_list_features",
                 ),
                 ArgInfo(
-                    input_attrs=[],
-                    is_getitems=[],
-                    postproc_modules=[],
-                    constants=[],
+                    steps=[],
                     name="id_score_list_features",
                 ),
             ],
@@ -286,19 +286,18 @@ def test_restore_from_snapshot(self) -> None:
             [
                 # Empty attrs to ignore any attr based logic.
                 ArgInfo(
-                    input_attrs=[
-                        "",
+                    steps=[
+                        ArgInfoStep(
+                            input_attr="",
+                            is_getitem=False,
+                            postproc_module=None,
+                            constant=None,
+                        )
                     ],
-                    is_getitems=[False],
-                    postproc_modules=[None],
-                    constants=[None],
                     name=None,
                 ),
                 ArgInfo(
-                    input_attrs=[],
-                    is_getitems=[],
-                    postproc_modules=[],
-                    constants=[],
+                    steps=[],
                     name=None,
                 ),
             ],
@@ -309,19 +308,18 @@ def test_restore_from_snapshot(self) -> None:
             [
                 # Empty attrs to ignore any attr based logic.
                 ArgInfo(
-                    input_attrs=[
-                        "",
+                    steps=[
+                        ArgInfoStep(
+                            input_attr="",
+                            is_getitem=False,
+                            postproc_module=None,
+                            constant=None,
+                        )
                     ],
-                    is_getitems=[False],
-                    postproc_modules=[None],
-                    constants=[None],
                     name=None,
                 ),
                 ArgInfo(
-                    input_attrs=[],
-                    is_getitems=[],
-                    postproc_modules=[],
-                    constants=[],
+                    steps=[],
                     name="id_score_list_features",
                 ),
             ],
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 410a083ce..9c95de52b 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -168,68 +168,78 @@ class PipelineStage:
     fill_callback: Optional[Callable[[], None]] = None
 
 
+@dataclass
+class ArgInfoStep:
+    """
+    Representation of a single transformation step applied to an arg from a node.
+
+    Attributes:
+        input_attr (str): the attribute name (or index) to access on the input.
+        is_getitem (bool): whether this is a getattr (output = foo.input_attr) or a getitem (output = foo[input_attr]) call.
+        postproc_module (Optional[PipelinedPostproc]): a torch.nn.Module that transforms the input,
+            output = postproc_module(input)
+        constant: no transformation, just a constant argument,
+            output = constant
+    """
+
+    input_attr: str
+    is_getitem: bool
+    # recursive dataclass: postproc_module._args -> ArgInfo.steps -> ArgInfoStep.postproc_module -> and so on
+    postproc_module: Optional["PipelinedPostproc"]
+    constant: Optional[object]
+
+
 @dataclass
 class ArgInfo:
     """
     Representation of args from a node.
 
     Attributes:
-        input_attrs (List[str]): attributes of input batch,
-            e.g. `batch.attr1.attr2` will produce ["attr1", "attr2"].
-        is_getitems (List[bool]): `batch[attr1].attr2` will produce [True, False].
-        postproc_modules (List[Optional[PipelinedPostproc]]): list of torch.nn.Modules that
-            transform the input batch.
-        constants: constant arguments that are passed to postproc modules.
         name (Optional[str]): name for kwarg of pipelined forward() call
            or None for a positional arg.
+        steps (List[ArgInfoStep]): sequence of transformations applied to the input batch.
+            Steps can be thought of as consecutive transformations on the input, with
+            the output of the previous step used as the input for the next. E.g. for 3 steps
+            it is equivalent to step3(step2(step1(input))).
+            See the `ArgInfoStep` docstring for details on which transformations can be applied.
     """
 
-    input_attrs: List[str]
-    is_getitems: List[bool]
-    # recursive dataclass as postproc_modules.args -> arginfo.postproc_modules -> so on
-    postproc_modules: List[Optional["PipelinedPostproc"]]
-    constants: List[Optional[object]]
     name: Optional[str]
+    steps: List[ArgInfoStep]
 
     def add_noop(self) -> "ArgInfo":
-        return self._insert_arg("")
+        return self._insert_step("")
 
     def add_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo":
-        return self._insert_arg(name, is_getitem)
+        return self._insert_step(name, is_getitem)
 
     def append_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo":
-        return self._append_arg(name, is_getitem)
+        return self._append_step(name, is_getitem)
 
     def add_postproc(self, pipelined_postproc_module: "PipelinedPostproc") -> "ArgInfo":
-        return self._insert_arg(postproc=pipelined_postproc_module)
+        return self._insert_step(postproc=pipelined_postproc_module)
 
     def add_constant(self, value: object) -> "ArgInfo":
-        return self._insert_arg(constant=value)
+        return self._insert_step(constant=value)
 
-    def _insert_arg(
+    def _insert_step(
         self,
         name: str = "",
         is_getitem: bool = False,
         postproc: Optional["PipelinedPostproc"] = None,
         constant: Optional[object] = None,
     ) -> "ArgInfo":
-        self.input_attrs.insert(0, name)
-        self.is_getitems.insert(0, is_getitem)
-        self.postproc_modules.insert(0, postproc)
-        self.constants.insert(0, constant)
+        self.steps.insert(0, ArgInfoStep(name, is_getitem, postproc, constant))
        return self
 
-    def _append_arg(
+    def _append_step(
         self,
         name: str = "",
         is_getitem: bool = False,
         postproc: Optional["PipelinedPostproc"] = None,
         constant: Optional[object] = None,
     ) -> "ArgInfo":
-        self.input_attrs.append(name)
-        self.is_getitems.append(is_getitem)
-        self.postproc_modules.append(postproc)
-        self.constants.append(constant)
+        self.steps.append(ArgInfoStep(name, is_getitem, postproc, constant))
         return self
 
 
@@ -242,46 +252,41 @@ def _build_args_kwargs(
     args = []
     kwargs = {}
     for arg_info in fwd_args:
-        if arg_info.input_attrs:
+        if arg_info.steps:
             arg = initial_input
-            for attr, is_getitem, postproc_mod, obj in zip(
-                arg_info.input_attrs,
-                arg_info.is_getitems,
-                arg_info.postproc_modules,
-                arg_info.constants,
-            ):
-                if obj is not None:
-                    if isinstance(obj, list):
+            for step in arg_info.steps:
+                if step.constant is not None:
+                    if isinstance(step.constant, list):
                         arg = [
                             (
                                 v
                                 if not isinstance(v, ArgInfo)
                                 else _build_args_kwargs(initial_input, [v])[0][0]
                             )
-                            for v in obj
+                            for v in step.constant
                         ]
-                    elif isinstance(obj, dict):
+                    elif isinstance(step.constant, dict):
                         arg = {
                             k: (
                                 v
                                 if not isinstance(v, ArgInfo)
                                 else _build_args_kwargs(initial_input, [v])[0][0]
                             )
-                            for k, v in obj.items()
+                            for k, v in step.constant.items()
                         }
                     else:
-                        arg = obj
+                        arg = step.constant
                     break
-                elif postproc_mod is not None:
+                elif step.postproc_module is not None:
                     # postproc will internally run the same logic recursively
                     # if its args are derived from other postproc modules
                     # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it
-                    arg = postproc_mod(arg)
+                    arg = step.postproc_module(arg)
                 else:
-                    if is_getitem:
-                        arg = arg[attr]
-                    elif attr != "":
-                        arg = getattr(arg, attr)
+                    if step.is_getitem:
+                        arg = arg[step.input_attr]
+                    elif step.input_attr != "":
+                        arg = getattr(arg, step.input_attr)
                     else:
                         # neither is_getitem nor valid attr, no-op
                         arg = arg
@@ -1131,7 +1136,7 @@ def _get_node_args_helper_inner(
         arg,
         for_postproc_module: bool = False,
     ) -> Optional[ArgInfo]:
-        arg_info = ArgInfo([], [], [], [], None)
+        arg_info = ArgInfo(None, [])
         while True:
             if not isinstance(arg, torch.fx.Node):
                 return self._handle_constant(arg, arg_info, for_postproc_module)
@@ -1159,9 +1164,9 @@ def _get_node_args_helper_inner(
                 This is for the PT2 export path where we unflatten the input to
                 reconstruct the structure with the recorded tree spec.
                 """
-                assert arg_info.is_getitems[0]
+                assert arg_info.steps[0].is_getitem
                 # pyre-fixme[16]
-                arg = child_node.args[0][arg_info.input_attrs[0]]
+                arg = child_node.args[0][arg_info.steps[0].input_attr]
             elif (
                 fn_module == "torchrec.sparse.jagged_tensor"
                 and fn_name == "KeyedJaggedTensor"
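
Reviewer note (not part of the patch): a minimal sketch of how a chained `ArgInfo.steps` is meant to be evaluated, with the constant and postproc-module cases omitted; `StepSketch` and the toy `Batch` class are hypothetical stand-ins for this illustration only, while the real handling lives in `_build_args_kwargs` above.

    from dataclasses import dataclass, field
    from typing import Any, List


    @dataclass
    class StepSketch:
        # Simplified stand-in for ArgInfoStep: only the getattr/getitem cases.
        input_attr: Any
        is_getitem: bool


    def evaluate(steps: List[StepSketch], initial_input: Any) -> Any:
        # Each step consumes the previous step's output, i.e. step3(step2(step1(input))).
        arg = initial_input
        for step in steps:
            if step.is_getitem:
                arg = arg[step.input_attr]  # output = input[input_attr]
            elif step.input_attr != "":
                arg = getattr(arg, step.input_attr)  # output = input.input_attr
            # an empty input_attr with is_getitem=False is a no-op
        return arg


    @dataclass
    class Batch:
        # Hypothetical batch type used only for this illustration.
        idlist_features: List[str] = field(default_factory=lambda: ["kjt_a", "kjt_b"])


    # getattr "idlist_features", then getitem 0 -> "kjt_a"
    print(evaluate([StepSketch("idlist_features", False), StepSketch(0, True)], Batch()))

The `["", 0]` / `[False, True]` pattern asserted in the tests is the same kind of two-step chain, except that its first step invokes a pipelined postproc module (rather than the plain getattr used in this sketch) before the second step indexes into the result.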