From 31687b18ffbf6840a90c6fa418db051d331e0dc4 Mon Sep 17 00:00:00 2001 From: Evgenii Kolpakov Date: Wed, 19 Feb 2025 01:15:55 -0800 Subject: [PATCH] Refactor _build_args_kwards into an instance method on CallArgs + ArgInfo (#2743) Summary: Torchrec rewriting logic got a bit hairy over the years, this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future. This change: * almost all code in `_build_args_kwargs` deals with the fields of ArgInfoStep, and remaining part handles looping over `ArgInfo.steps` - so this change just colocates "behavior" (`_build_args_kwargs` logic) with data it belongs to. * introduces helper functions/factory methods for various types of ArgInfoStep * encapsulates the logic of handling a `List[ArgInfo]` into a `CallArgs` class (+changes a bit - explicitly separating args nad kwargs, vs. having them differ by empty/present `ArgInfo.name` field) Internal Diff stack navigation: 1. D69292525 and below - before refactoring 2. D69438143 - Refactor get_node_args and friends into a class 3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep 4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep (**you are here**) 5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy Differential Revision: D69461226 --- .../distributed/train_pipeline/__init__.py | 3 +- .../tests/test_train_pipelines.py | 44 +-- .../tests/test_train_pipelines_utils.py | 97 +++---- torchrec/distributed/train_pipeline/utils.py | 257 ++++++++++-------- 4 files changed, 201 insertions(+), 200 deletions(-) diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py index 15cbbe831..a3e32ed7a 100644 --- a/torchrec/distributed/train_pipeline/__init__.py +++ b/torchrec/distributed/train_pipeline/__init__.py @@ -26,7 +26,8 @@ _to_device, # noqa _wait_for_batch, # noqa ArgInfo, # noqa - ArgInfoStep, # noqa + ArgInfoStepFactory, # noqa + CallArgs, # noqa DataLoadingThread, # noqa In, # noqa Out, # noqa diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 844d3b500..8b491666b 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -1024,9 +1024,10 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: # Check pipelined args for ebc in [pipelined_ebc, pipelined_weighted_ebc]: - self.assertEqual(len(ebc.forward._args), 1) - self.assertEqual(len(ebc.forward._args[0].steps), 2) - [step1, step2] = ebc.forward._args[0].steps + self.assertEqual(len(ebc.forward._args.args), 1) + self.assertEqual(len(ebc.forward._args.kwargs), 0) + self.assertEqual(len(ebc.forward._args.args[0].steps), 2) + [step1, step2] = ebc.forward._args.args[0].steps self.assertEqual(step1.input_attr, "") self.assertEqual(step1.is_getitem, False) @@ -1036,13 +1037,13 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: self.assertIsNone(step2.postproc_module) self.assertEqual( - pipelined_ebc.forward._args[0].steps[0].postproc_module, + pipelined_ebc.forward._args.args[0].steps[0].postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_nonweighted`. pipelined_model.module.postproc_nonweighted, ) self.assertEqual( - pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module, + pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_weighted`. pipelined_model.module.postproc_weighted, @@ -1053,9 +1054,10 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: input_attr_names = {"idlist_features", "idscore_features"} for i in range(len(pipeline._pipelined_postprocs)): postproc_mod = pipeline._pipelined_postprocs[i] - self.assertEqual(len(postproc_mod._args), 1) - self.assertEqual(len(postproc_mod._args[0].steps), 2) - [step1, step2] = postproc_mod._args[0].steps + self.assertEqual(len(postproc_mod._args.args), 1) + self.assertEqual(len(postproc_mod._args.kwargs), 0) + self.assertEqual(len(postproc_mod._args.args[0].steps), 2) + [step1, step2] = postproc_mod._args.args[0].steps self.assertTrue(step2.input_attr in input_attr_names) @@ -1112,9 +1114,10 @@ def test_pipeline_postproc_recursive(self) -> None: # Check pipelined args for ebc in [pipelined_ebc, pipelined_weighted_ebc]: - self.assertEqual(len(ebc.forward._args), 1) - self.assertEqual(len(ebc.forward._args[0].steps), 2) - [step1, step2] = ebc.forward._args[0].steps + self.assertEqual(len(ebc.forward._args.args), 1) + self.assertEqual(len(ebc.forward._args.kwargs), 0) + self.assertEqual(len(ebc.forward._args.args[0].steps), 2) + [step1, step2] = ebc.forward._args.args[0].steps self.assertEqual(step1.input_attr, "") self.assertEqual(step1.is_getitem, False) @@ -1124,13 +1127,13 @@ def test_pipeline_postproc_recursive(self) -> None: self.assertIsNone(step2.postproc_module) self.assertEqual( - pipelined_ebc.forward._args[0].steps[0].postproc_module, + pipelined_ebc.forward._args.args[0].steps[0].postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_nonweighted`. pipelined_model.module.postproc_nonweighted, ) self.assertEqual( - pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module, + pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_weighted`. pipelined_model.module.postproc_weighted, @@ -1147,8 +1150,9 @@ def test_pipeline_postproc_recursive(self) -> None: # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_nonweighted`. if postproc_mod == pipelined_model.module.postproc_nonweighted: - self.assertEqual(len(postproc_mod._args), 1) - args = postproc_mod._args[0] + self.assertEqual(len(postproc_mod._args.args), 1) + self.assertEqual(len(postproc_mod._args.kwargs), 0) + args = postproc_mod._args.args[0] self.assertEqual(len(args.steps), 2) self.assertEqual( [step.input_attr for step in args.steps], ["", "idlist_features"] @@ -1161,8 +1165,9 @@ def test_pipeline_postproc_recursive(self) -> None: # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_weighted`. elif postproc_mod == pipelined_model.module.postproc_weighted: - self.assertEqual(len(postproc_mod._args), 1) - args = postproc_mod._args[0] + self.assertEqual(len(postproc_mod._args.args), 1) + self.assertEqual(len(postproc_mod._args.kwargs), 0) + args = postproc_mod._args.args[0] self.assertEqual(len(args.steps), 2) self.assertEqual( [step.input_attr for step in args.steps], ["", "idscore_features"] @@ -1173,8 +1178,9 @@ def test_pipeline_postproc_recursive(self) -> None: self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod) self.assertIsNone(args.steps[1].postproc_module) elif postproc_mod == parent_postproc_mod: - self.assertEqual(len(postproc_mod._args), 1) - args = postproc_mod._args[0] + self.assertEqual(len(postproc_mod._args.args), 1) + self.assertEqual(len(postproc_mod._args.kwargs), 0) + args = postproc_mod._args.args[0] self.assertEqual(len(args.steps), 1) self.assertEqual(args.steps[0].input_attr, "") self.assertFalse(args.steps[0].is_getitem) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 42629640d..358da4d33 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -23,10 +23,10 @@ TrainPipelineSparseDistTestBase, ) from torchrec.distributed.train_pipeline.utils import ( - _build_args_kwargs, _rewrite_model, ArgInfo, - ArgInfoStep, + ArgInfoStepFactory, + CallArgs, NodeArgsHelper, PipelinedForward, PipelinedPostproc, @@ -111,7 +111,9 @@ def test_rewrite_model(self) -> None: self.assertEqual( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `sparse`. - sharded_model.module.sparse.ebc.forward._args[0].steps[0].postproc_module, + sharded_model.module.sparse.ebc.forward._args.args[0] + .steps[0] + .postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `postproc_module`. sharded_model.module.postproc_module, @@ -119,7 +121,7 @@ def test_rewrite_model(self) -> None: self.assertEqual( # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `sparse`. - sharded_model.module.sparse.weighted_ebc.forward._args[0] + sharded_model.module.sparse.weighted_ebc.forward._args.args[0] .steps[0] .postproc_module, # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute @@ -155,7 +157,7 @@ def forward(self, x): rewritten_model.test_module = PipelinedPostproc( postproc_module=rewritten_model.test_module, fqn="test_module", - args=[], + args=CallArgs(args=[], kwargs={}), context=TrainPipelineContext(), default_stream=MagicMock(), dist_stream=MagicMock(), @@ -261,68 +263,41 @@ def test_restore_from_snapshot(self) -> None: @parameterized.expand( [ ( - [ - # Empty attrs to ignore any attr based logic. - ArgInfo( - steps=[ - ArgInfoStep( - input_attr="", - is_getitem=False, - postproc_module=None, - constant=None, - ) - ], - name="id_list_features", - ), - ArgInfo( - steps=[], - name="id_score_list_features", - ), - ], + CallArgs( + args=[], + kwargs={ + "id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]), + # Empty attrs to ignore any attr based logic. + "id_score_list_features": ArgInfo( + steps=[ArgInfoStepFactory.noop()] + ), + }, + ), 0, ["id_list_features", "id_score_list_features"], ), ( - [ - # Empty attrs to ignore any attr based logic. - ArgInfo( - steps=[ - ArgInfoStep( - input_attr="", - is_getitem=False, - postproc_module=None, - constant=None, - ) - ], - name=None, - ), - ArgInfo( - steps=[], - name=None, - ), - ], + CallArgs( + args=[ + # Empty attrs to ignore any attr based logic. + ArgInfo(steps=[ArgInfoStepFactory.noop()]), + ArgInfo(steps=[]), + ], + kwargs={}, + ), 2, [], ), ( - [ - # Empty attrs to ignore any attr based logic. - ArgInfo( - steps=[ - ArgInfoStep( - input_attr="", - is_getitem=False, - postproc_module=None, - constant=None, - ) - ], - name=None, - ), - ArgInfo( - steps=[], - name="id_score_list_features", - ), - ], + CallArgs( + args=[ + # Empty attrs to ignore any attr based logic. + ArgInfo( + steps=[ArgInfoStepFactory.noop()], + ) + ], + kwargs={"id_score_list_features": ArgInfo(steps=[])}, + ), 1, ["id_score_list_features"], ), @@ -330,11 +305,11 @@ def test_restore_from_snapshot(self) -> None: ) def test_build_args_kwargs( self, - fwd_args: List[ArgInfo], + fwd_args: CallArgs, args_len: int, kwarges_keys: List[str], ) -> None: - args, kwargs = _build_args_kwargs("initial_input", fwd_args) + args, kwargs = fwd_args.build_args_kwargs("initial_input") self.assertEqual(len(args), args_len) self.assertEqual(list(kwargs.keys()), kwarges_keys) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 104a6992a..ea496ac52 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -6,7 +6,6 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - import copy import itertools import logging @@ -188,6 +187,75 @@ class ArgInfoStep: postproc_module: Optional["PipelinedPostproc"] constant: Optional[object] + # pyre-ignore[3] + def process( + self, + arg, # pyre-ignore[2] + ) -> Any: + if self.constant is not None: + if isinstance(self.constant, list): + arg = [ + (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for v in self.constant + ] + elif isinstance(self.constant, dict): + arg = { + k: (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for k, v in self.constant.items() + } + else: + arg = self.constant + elif self.postproc_module is not None: + # postproc will internally run the same logic recursively + # if its args are derived from other postproc modules + # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it + arg = self.postproc_module(arg) + else: + if self.is_getitem: + arg = arg[self.input_attr] + elif self.input_attr != "": + arg = getattr(arg, self.input_attr) + else: + # neither is_getitem nor valid attr, no-op + arg = arg + + return arg + + +class ArgInfoStepFactory: + """ + Convenience class to reduce the amount of imports the external uses will have. + Should closely follow the constructor interfaces for the corresponding classes. + """ + + @classmethod + def noop(cls) -> ArgInfoStep: + return ArgInfoStep("", False, None, None) + + @classmethod + def get_attr(cls, name: str) -> ArgInfoStep: + return ArgInfoStep(name, False, None, None) + + @classmethod + def get_item(cls, index: str) -> ArgInfoStep: + return ArgInfoStep(index, True, None, None) + + @classmethod + def postproc(cls, pipelined_postproc_module: "PipelinedPostproc") -> ArgInfoStep: + return ArgInfoStep("", False, pipelined_postproc_module, None) + + @classmethod + def from_scalar(cls, value: object) -> ArgInfoStep: + return ArgInfoStep("", False, None, value) + + @classmethod + def from_list(cls, value: List[object]) -> ArgInfoStep: + return ArgInfoStep("", False, None, value) + + @classmethod + def from_dict(cls, value: Dict[str, object]) -> ArgInfoStep: + return ArgInfoStep("", False, None, value) + @dataclass class ArgInfo: @@ -195,8 +263,6 @@ class ArgInfo: Representation of args from a node. Attributes: - name (Optional[str]): name for kwarg of pipelined forward() call or None for a - positional arg. steps (List[ArgInfoStep]): sequence of transformations from input batch. Steps can be thought of consequtive transformations on the input, with output of previous step used as an input for the next. I.e. for 3 steps @@ -204,102 +270,48 @@ class ArgInfo: See `ArgInfoStep` docstring for details on which transformations can be applied """ - name: Optional[str] steps: List[ArgInfoStep] - def add_noop(self) -> "ArgInfo": - return self._insert_step("") + def add_step(self, step: ArgInfoStep) -> "ArgInfo": + self.steps.insert(0, step) + return self - def add_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo": - return self._insert_step(name, is_getitem) + def append_step(self, step: ArgInfoStep) -> "ArgInfo": + self.steps.append(step) + return self - def append_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo": - return self._append_step(name, is_getitem) + # pyre-ignore[3] + def process_steps( + self, + arg: Any, # pyre-ignore[2] + ) -> Any: + if not self.steps: + return None + for step in self.steps: + arg = step.process(arg) - def add_postproc(self, pipelined_postproc_module: "PipelinedPostproc") -> "ArgInfo": - return self._insert_step(postproc=pipelined_postproc_module) + return arg - def add_constant(self, value: object) -> "ArgInfo": - return self._insert_step(constant=value) - def _insert_step( - self, - name: str = "", - is_getitem: bool = False, - postproc: Optional["PipelinedPostproc"] = None, - constant: Optional[object] = None, - ) -> "ArgInfo": - self.steps.insert(0, ArgInfoStep(name, is_getitem, postproc, constant)) - return self +@dataclass +class CallArgs: + args: List[ArgInfo] + kwargs: Dict[str, ArgInfo] - def _append_step( - self, - name: str = "", - is_getitem: bool = False, - postproc: Optional["PipelinedPostproc"] = None, - constant: Optional[object] = None, - ) -> "ArgInfo": - self.steps.append(ArgInfoStep(name, is_getitem, postproc, constant)) - return self + # pyre-ignore[3] + def build_args_kwargs( + self, initial_input: Any # pyre-ignore[2] + ) -> Tuple[List[Any], Dict[str, Any]]: + """ + Creates args and kwargs for given CallArgs and input + """ + args = [arg.process_steps(initial_input) for arg in self.args] + kwargs = { + key: arg.process_steps(initial_input) for key, arg in self.kwargs.items() + } -# pyre-ignore -def _build_args_kwargs( - # pyre-ignore - initial_input: Any, - fwd_args: List[ArgInfo], -) -> Tuple[List[Any], Dict[str, Any]]: - args = [] - kwargs = {} - for arg_info in fwd_args: - if arg_info.steps: - arg = initial_input - for step in arg_info.steps: - if step.constant is not None: - if isinstance(step.constant, list): - arg = [ - ( - v - if not isinstance(v, ArgInfo) - else _build_args_kwargs(initial_input, [v])[0][0] - ) - for v in step.constant - ] - elif isinstance(step.constant, dict): - arg = { - k: ( - v - if not isinstance(v, ArgInfo) - else _build_args_kwargs(initial_input, [v])[0][0] - ) - for k, v in step.constant.items() - } - else: - arg = step.constant - break - elif step.postproc_module is not None: - # postproc will internally run the same logic recursively - # if its args are derived from other postproc modules - # we can get all inputs to postproc mod based on its recorded args_info + arg passed to it - arg = step.postproc_module(arg) - else: - if step.is_getitem: - arg = arg[step.input_attr] - elif step.input_attr != "": - arg = getattr(arg, step.input_attr) - else: - # neither is_getitem nor valid attr, no-op - arg = arg - if arg_info.name: - kwargs[arg_info.name] = arg - else: - args.append(arg) - else: - if arg_info.name: - kwargs[arg_info.name] = None - else: - args.append(None) - return args, kwargs + return args, kwargs def recursive_record_stream( @@ -343,7 +355,7 @@ class PipelinedPostproc(torch.nn.Module): Args: postproc_module (torch.nn.Module): postproc module to run fqn (str): fqn of the postproc module in the model being pipelined - args (List[ArgInfo]): list of ArgInfo for the postproc module + args (CallArgs): CallArgs for the postproc module context (TrainPipelineContext): Training context for the next iteration / batch Returns: @@ -361,7 +373,7 @@ def __init__( self, postproc_module: torch.nn.Module, fqn: str, - args: List[ArgInfo], + args: CallArgs, context: TrainPipelineContext, # TODO: make streams non-optional - skipping now to avoid ripple effect default_stream: Optional[torch.Stream], @@ -434,7 +446,7 @@ def forward(self, *input, **kwargs) -> Any: # of another postproc module call, as long as module is pipelineable # Use input[0] as _start_data_dist only passes 1 arg - args, kwargs = _build_args_kwargs(input[0], self._args) + args, kwargs = self._args.build_args_kwargs(input[0]) with record_function(f"## sdd_input_postproc {self._context.index} ##"): # should be no-op as we call this in dist stream @@ -467,7 +479,7 @@ def forward(self, *input, **kwargs) -> Any: return res @property - def args(self) -> List[ArgInfo]: + def args(self) -> CallArgs: return self._args def set_context(self, context: TrainPipelineContext) -> None: @@ -544,7 +556,7 @@ class BaseForward(Generic[TForwardContext]): def __init__( self, name: str, - args: List[ArgInfo], + args: CallArgs, module: ShardedModule, context: TForwardContext, stream: Optional[torch.Stream] = None, @@ -561,7 +573,7 @@ def name(self) -> str: return self._name @property - def args(self) -> List[ArgInfo]: + def args(self) -> CallArgs: return self._args def set_context(self, context: TForwardContext) -> None: @@ -661,7 +673,7 @@ class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]): def __init__( self, name: str, - args: List[ArgInfo], + args: CallArgs, module: ShardedModule, context: PrefetchTrainPipelineContext, prefetch_stream: Optional[torch.Stream] = None, @@ -838,7 +850,7 @@ def _start_data_dist( # False means this argument is getting while getattr # and this info was done in the _rewrite_model by tracing the # entire model to get the arg_info_list - args, kwargs = _build_args_kwargs(batch, forward.args) + args, kwargs = forward.args.build_args_kwargs(batch) # Start input distribution. module_ctx = module.create_context() @@ -995,8 +1007,7 @@ def _swap_postproc_module_recursive( def _handle_constant( self, - # pyre-ignore - arg, + arg: Any, # pyre-ignore arg_info: ArgInfo, for_postproc_module: bool = False, ) -> Optional[ArgInfo]: @@ -1004,18 +1015,19 @@ def _handle_constant( return None if isinstance(arg, fx_immutable_dict): - arg_info.add_constant( + step = ArgInfoStepFactory.from_dict( { k: self._handle_collection_element(v, for_postproc_module) for k, v in arg.items() } ) elif isinstance(arg, fx_immutable_list): - arg_info.add_constant( + step = ArgInfoStepFactory.from_list( [self._handle_collection_element(v, for_postproc_module) for v in arg] ) else: - arg_info.add_constant(arg) + step = ArgInfoStepFactory.from_scalar(arg) + arg_info.add_step(step) return arg_info # pyre-ignore[3] @@ -1046,12 +1058,12 @@ def _handle_placeholder( ph_keys = ph_key.split(".") for key in ph_keys: if "]" in key: - arg_info.append_input_attr(key[:-1], True) + arg_info.append_step(ArgInfoStepFactory.get_item(key[:-1])) else: - arg_info.append_input_attr(key, False) + arg_info.append_step(ArgInfoStepFactory.get_attr(key)) else: # no-op - arg_info.add_noop() + arg_info.add_step(ArgInfoStepFactory.noop()) return arg_info def _handle_module( @@ -1075,7 +1087,7 @@ def _handle_module( if isinstance(postproc_module, PipelinedPostproc): # Already did module swap and registered args, early exit self._pipelined_postprocs.add(postproc_module) - arg_info.add_postproc(postproc_module) + arg_info.add_step(ArgInfoStepFactory.postproc(postproc_module)) return arg_info if not isinstance(postproc_module, torch.nn.Module): @@ -1125,7 +1137,7 @@ def _handle_module( ) self._pipelined_postprocs.add(pipelined_postproc_module) - arg_info.add_postproc(pipelined_postproc_module) + arg_info.add_step(ArgInfoStepFactory.postproc(pipelined_postproc_module)) return arg_info return None @@ -1136,7 +1148,7 @@ def _get_node_args_helper_inner( arg, for_postproc_module: bool = False, ) -> Optional[ArgInfo]: - arg_info = ArgInfo(None, []) + arg_info = ArgInfo([]) while True: if not isinstance(arg, torch.fx.Node): return self._handle_constant(arg, arg_info, for_postproc_module) @@ -1152,12 +1164,16 @@ def _get_node_args_helper_inner( # pyre-fixme[16] fn_name = child_node.target.__name__ if fn_module == "builtins" and fn_name == "getattr": - # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown - arg_info.add_input_attr(child_node.args[1], False) + arg_info.add_step( + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + ArgInfoStepFactory.get_attr(child_node.args[1]) + ) arg = child_node.args[0] elif fn_module == "_operator" and fn_name == "getitem": - # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown - arg_info.add_input_attr(child_node.args[1], True) + arg_info.add_step( + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + ArgInfoStepFactory.get_item(child_node.args[1]) + ) arg = child_node.args[0] elif fn_module == "torch.utils._pytree" and fn_name == "tree_unflatten": """ @@ -1186,6 +1202,9 @@ def _get_node_args_helper_inner( if call_module_found: break + if call_module_found: + break + if "values" in child_node.kwargs: arg = child_node.kwargs["values"] else: @@ -1193,7 +1212,7 @@ def _get_node_args_helper_inner( elif child_node.op == "call_method" and child_node.target == "get": # pyre-ignore[6] - arg_info.add_input_attr(child_node.args[1], True) + arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1])) arg = child_node.args[0] else: break @@ -1232,7 +1251,7 @@ def get_node_args( self, node: Node, for_postproc_module: bool = False, - ) -> Tuple[List[ArgInfo], int]: + ) -> Tuple[CallArgs, int]: pos_arg_info_list, args_found = self._get_node_args_helper( node.args, for_postproc_module, @@ -1243,12 +1262,12 @@ def get_node_args( ) # Replace with proper names for kwargs - for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): - arg_info_list.name = name - - arg_info_list = pos_arg_info_list + kwargs_arg_info_list + kwargs_info_list = { + name: arg_info_list + for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list) + } - return (arg_info_list, args_found + kwargs_found) + return CallArgs(pos_arg_info_list, kwargs_info_list), args_found + kwargs_found def _get_leaf_module_names_helper(