diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 53fae9001..78cd516b7 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -24,9 +24,9 @@ ) from torchrec.distributed.train_pipeline.utils import ( _build_args_kwargs, - _get_node_args, _rewrite_model, ArgInfo, + NodeArgsHelper, PipelinedForward, PipelinedPostproc, TrainPipelineContext, @@ -367,10 +367,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None: {}, ) - num_found = 0 - _, num_found = _get_node_args( - MagicMock(), kjt_node, set(), TrainPipelineContext(), False - ) + node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False) + + _, num_found = node_args_helper.get_node_args(kjt_node) # Weights is call_module node, so we should only find 2 args unmodified self.assertEqual(num_found, len(kjt_args) - 1) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 8f0c9d569..410a083ce 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -191,6 +191,47 @@ class ArgInfo: constants: List[Optional[object]] name: Optional[str] + def add_noop(self) -> "ArgInfo": + return self._insert_arg("") + + def add_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo": + return self._insert_arg(name, is_getitem) + + def append_input_attr(self, name: str, is_getitem: bool) -> "ArgInfo": + return self._append_arg(name, is_getitem) + + def add_postproc(self, pipelined_postproc_module: "PipelinedPostproc") -> "ArgInfo": + return self._insert_arg(postproc=pipelined_postproc_module) + + def add_constant(self, value: object) -> "ArgInfo": + return self._insert_arg(constant=value) + + def _insert_arg( + self, + name: str = "", + is_getitem: bool = False, + postproc: Optional["PipelinedPostproc"] = None, + constant: Optional[object] = None, + ) -> "ArgInfo": + self.input_attrs.insert(0, name) + self.is_getitems.insert(0, is_getitem) + self.postproc_modules.insert(0, postproc) + self.constants.insert(0, constant) + return self + + def _append_arg( + self, + name: str = "", + is_getitem: bool = False, + postproc: Optional["PipelinedPostproc"] = None, + constant: Optional[object] = None, + ) -> "ArgInfo": + self.input_attrs.append(name) + self.is_getitems.append(is_getitem) + self.postproc_modules.append(postproc) + self.constants.append(constant) + return self + # pyre-ignore def _build_args_kwargs( @@ -900,377 +941,306 @@ def _find_postproc_module_recursive( return None -def _swap_postproc_module_recursive( - module: torch.nn.Module, - to_swap_module: torch.nn.Module, - postproc_module_fqn: str, - path: str = "", -) -> torch.nn.Module: - """ - Swaps the postproc module in the model. 
- """ - if isinstance(module, PipelinedPostproc): - return module +class NodeArgsHelper: + def __init__( + self, + model: torch.nn.Module, + context: TrainPipelineContext, + pipeline_postproc: bool, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, + ) -> None: + self._model = model + self._context = context + self._pipeline_postproc = pipeline_postproc + self._default_stream = default_stream + self._dist_stream = dist_stream + self._pipelined_postprocs: Set[PipelinedPostproc] = set() - if path == postproc_module_fqn: - return to_swap_module + @property + def pipelined_postprocs(self) -> Set[PipelinedPostproc]: + return self._pipelined_postprocs - for name, child in module.named_children(): - child = _swap_postproc_module_recursive( - child, - to_swap_module, - postproc_module_fqn, - path + "." + name if path else name, - ) - setattr(module, name, child) + def _swap_postproc_module_recursive( + self, + module: torch.nn.Module, + to_swap_module: torch.nn.Module, + postproc_module_fqn: str, + path: str = "", + ) -> torch.nn.Module: + """ + Swaps the postproc module in the model. + """ + if isinstance(module, PipelinedPostproc): + return module - return module + if path == postproc_module_fqn: + return to_swap_module + for name, child in module.named_children(): + child = self._swap_postproc_module_recursive( + child, + to_swap_module, + postproc_module_fqn, + path + "." + name if path else name, + ) + setattr(module, name, child) -def _get_node_args_helper_inner( - model: torch.nn.Module, - # pyre-ignore - arg, - arg_info: ArgInfo, - num_found: int, - pipelined_postprocs: Set[PipelinedPostproc], - context: TrainPipelineContext, - pipeline_postproc: bool, - for_postproc_module: bool = False, - default_stream: Optional[torch.Stream] = None, - dist_stream: Optional[torch.Stream] = None, -) -> int: - num_found = 0 - while True: - if not isinstance(arg, torch.fx.Node): - if pipeline_postproc: - arg_info.input_attrs.insert(0, "") - arg_info.is_getitems.insert(0, False) - arg_info.postproc_modules.insert(0, None) - - if isinstance(arg, fx_immutable_dict): - fx_nested_dict = {} - - for k, v in arg.items(): - if isinstance(v, torch.fx.Node): - arg_info_nested = ArgInfo([], [], [], [], None) - _get_node_args_helper_inner( - model, - v, - arg_info_nested, - num_found, - pipelined_postprocs, - context, - pipeline_postproc, - for_postproc_module, - default_stream=default_stream, - dist_stream=dist_stream, - ) - fx_nested_dict[k] = arg_info_nested - else: - fx_nested_dict[k] = v - - arg_info.constants.insert(0, fx_nested_dict) - elif isinstance(arg, fx_immutable_list): - fx_nested_list = [] - for v in arg: - if isinstance(v, torch.fx.Node): - arg_info_nested = ArgInfo([], [], [], [], None) - _get_node_args_helper_inner( - model, - v, - arg_info_nested, - num_found, - pipelined_postprocs, - context, - pipeline_postproc, - for_postproc_module, - default_stream=default_stream, - dist_stream=dist_stream, - ) - fx_nested_list.append(arg_info_nested) - else: - fx_nested_list.append(v) + return module - arg_info.constants.insert(0, fx_nested_list) - else: - arg_info.constants.insert(0, arg) - num_found += 1 - break - child_node = arg - - if child_node.op == "placeholder": - if hasattr(child_node, "ph_key"): - # pyre-ignore[16] - ph_key: str = child_node.ph_key - # example: ph_key = 'event_id_list_features_seqs[marketplace]' - ph_key = ph_key.replace("[", ".") - ph_keys = ph_key.split(".") - for key in ph_keys: - if "]" in key: - arg_info.input_attrs.append(key[:-1]) - 
arg_info.is_getitems.append(True) - else: - arg_info.input_attrs.append(key) - arg_info.is_getitems.append(False) - arg_info.postproc_modules.append(None) - arg_info.constants.append(None) - else: - # no-op - arg_info.input_attrs.insert(0, "") - arg_info.is_getitems.insert(0, False) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - - num_found += 1 - break - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "builtins" - # pyre-ignore[16] - and child_node.target.__name__ == "getattr" - ): - # pyre-fixme[6]: For 2nd argument expected `str` but got - # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, - # complex, float, int, range, slice, str, device, dtype, layout, - # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, False) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "_operator" - # pyre-ignore[16] - and child_node.target.__name__ == "getitem" - ): - # pyre-fixme[6]: For 2nd argument expected `str` but got - # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, - # complex, float, int, range, slice, str, device, dtype, layout, - # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, True) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torch.utils._pytree" - # pyre-ignore[16] - and child_node.target.__name__ == "tree_unflatten" - ): - """ - This is for the PT2 export path where we unflatten the input to reconstruct - the structure with the recorded tree spec. 
- """ - assert arg_info.is_getitems[0] - # pyre-fixme[16] - arg = child_node.args[0][arg_info.input_attrs[0]] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" - # pyre-fixme[16] - and child_node.target.__name__ == "KeyedJaggedTensor" - ): - call_module_found = False + def _handle_constant( + self, + # pyre-ignore + arg, + arg_info: ArgInfo, + for_postproc_module: bool = False, + ) -> Optional[ArgInfo]: + if not self._pipeline_postproc: + return None - for arg_node in chain(child_node.args, child_node.kwargs.values()): - if isinstance(arg_node, torch.fx.Node) and _check_args_for_call_module( - arg_node - ): - call_module_found = True - break + if isinstance(arg, fx_immutable_dict): + arg_info.add_constant( + { + k: self._handle_collection_element(v, for_postproc_module) + for k, v in arg.items() + } + ) + elif isinstance(arg, fx_immutable_list): + arg_info.add_constant( + [self._handle_collection_element(v, for_postproc_module) for v in arg] + ) + else: + arg_info.add_constant(arg) + return arg_info - if call_module_found: - break + # pyre-ignore[3] + def _handle_collection_element( + self, + # pyre-ignore[2] + arg: Any, + for_postproc_module: bool = False, + ) -> Any: + if not isinstance(arg, torch.fx.Node): + return arg - if "values" in child_node.kwargs: - arg = child_node.kwargs["values"] - else: - arg = child_node.args[1] - elif child_node.op == "call_method" and child_node.target == "get": - # pyre-ignore[6] - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, True) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif child_node.op == "call_module": - postproc_module_fqn = str(child_node.target) - postproc_module = _find_postproc_module_recursive( - model, postproc_module_fqn + arg_info_nested = self._get_node_args_helper_inner( + arg, + for_postproc_module, + ) + return arg_info_nested + + def _handle_placeholder( + self, child_node: torch.fx.Node, arg_info: ArgInfo + ) -> ArgInfo: + # note: mutates arg_info + if hasattr(child_node, "ph_key"): + # pyre-fixme[16] + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_key = ph_key.replace("[", ".") + ph_keys = ph_key.split(".") + for key in ph_keys: + if "]" in key: + arg_info.append_input_attr(key[:-1], True) + else: + arg_info.append_input_attr(key, False) + else: + # no-op + arg_info.add_noop() + return arg_info + + def _handle_module( + self, child_node: torch.fx.Node, arg_info: ArgInfo + ) -> Optional[ArgInfo]: + postproc_module_fqn = str(child_node.target) + postproc_module = _find_postproc_module_recursive( + self._model, postproc_module_fqn + ) + + if not self._pipeline_postproc: + logger.warning( + f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" ) + return None - if not pipeline_postproc: - logger.warning( - f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. 
To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" - ) - break + if not postproc_module: + # Could not find such module, should not happen + return None - if not postproc_module: - # Could not find such module, should not happen - break + if isinstance(postproc_module, PipelinedPostproc): + # Already did module swap and registered args, early exit + self._pipelined_postprocs.add(postproc_module) + arg_info.add_postproc(postproc_module) + return arg_info - if isinstance(postproc_module, PipelinedPostproc): - # Already did module swap and registered args, early exit - arg_info.input_attrs.insert(0, "") # dummy value - arg_info.is_getitems.insert(0, False) - pipelined_postprocs.add(postproc_module) - arg_info.postproc_modules.insert(0, postproc_module) - arg_info.constants.insert(0, None) - num_found += 1 - break + if not isinstance(postproc_module, torch.nn.Module): + logger.warning( + f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" + ) + return None - if not isinstance(postproc_module, torch.nn.Module): - logger.warning( - f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" - ) - break + # check if module is safe to pipeline i.e.no trainable param + if not _check_postproc_pipelineable(postproc_module): + return None - # check if module is safe to pipeline i.e.no trainable param - if not _check_postproc_pipelineable(postproc_module): - break + # For module calls, `self` isn't counted + total_num_args = len(child_node.args) + len(child_node.kwargs) + if total_num_args == 0: + # module call without any args, assume KJT modified + return None - # For module calls, `self` isn't counted - total_num_args = len(child_node.args) + len(child_node.kwargs) - if total_num_args == 0: - # module call without any args, assume KJT modified - break + # recursive call to check that all inputs to this postproc module + # is either made of postproc module or non-modifying train batch input + # transformations + postproc_args, num_found_safe_postproc_args = self.get_node_args( + child_node, + for_postproc_module=True, + ) + if num_found_safe_postproc_args == total_num_args: + logger.info( + f"""Module {postproc_module} is a valid postproc module (no + trainable params and inputs can be derived from train batch input + via a series of either valid postproc modules or non-modifying + transformations) and will be applied during sparse data dist + stage""" + ) - # recursive call to check that all inputs to this postproc module - # is either made of postproc module or non-modifying train batch input - # transformations - postproc_args, num_found_safe_postproc_args = _get_node_args( - model, - child_node, - pipelined_postprocs, - context, - pipeline_postproc, - True, - default_stream=default_stream, - dist_stream=dist_stream, + pipelined_postproc_module = PipelinedPostproc( + postproc_module, + postproc_module_fqn, + postproc_args, + self._context, + default_stream=self._default_stream, + dist_stream=self._dist_stream, ) - if num_found_safe_postproc_args == total_num_args: - logger.info( - f"""Module {postproc_module} is a valid postproc module (no - trainable params and inputs can be derived from train batch input - via a series of either valid postproc modules or non-modifying - transformations) and will be applied during sparse data dist - stage""" - ) - pipelined_postproc_module = PipelinedPostproc( - postproc_module, - postproc_module_fqn, - postproc_args, - context, - default_stream=default_stream, - 
dist_stream=dist_stream, - ) + # module swap + self._model = self._swap_postproc_module_recursive( + self._model, pipelined_postproc_module, postproc_module_fqn + ) - # module swap - _swap_postproc_module_recursive( - model, pipelined_postproc_module, postproc_module_fqn - ) + self._pipelined_postprocs.add(pipelined_postproc_module) + arg_info.add_postproc(pipelined_postproc_module) + return arg_info - arg_info.input_attrs.insert(0, "") # dummy value - arg_info.is_getitems.insert(0, False) - pipelined_postprocs.add(pipelined_postproc_module) - arg_info.postproc_modules.insert(0, pipelined_postproc_module) - arg_info.constants.insert(0, None) + return None - num_found += 1 + def _get_node_args_helper_inner( + self, + # pyre-ignore + arg, + for_postproc_module: bool = False, + ) -> Optional[ArgInfo]: + arg_info = ArgInfo([], [], [], [], None) + while True: + if not isinstance(arg, torch.fx.Node): + return self._handle_constant(arg, arg_info, for_postproc_module) + + child_node = arg + + if child_node.op == "placeholder": + return self._handle_placeholder(arg, arg_info) + elif child_node.op == "call_module": + return self._handle_module(arg, arg_info) + elif child_node.op == "call_function": + fn_module = child_node.target.__module__ + # pyre-fixme[16] + fn_name = child_node.target.__name__ + if fn_module == "builtins" and fn_name == "getattr": + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + arg_info.add_input_attr(child_node.args[1], False) + arg = child_node.args[0] + elif fn_module == "_operator" and fn_name == "getitem": + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + arg_info.add_input_attr(child_node.args[1], True) + arg = child_node.args[0] + elif fn_module == "torch.utils._pytree" and fn_name == "tree_unflatten": + """ + This is for the PT2 export path where we unflatten the input to reconstruct + the structure with the recorded tree spec. + """ + assert arg_info.is_getitems[0] + # pyre-fixme[16] + arg = child_node.args[0][arg_info.input_attrs[0]] + elif ( + fn_module == "torchrec.sparse.jagged_tensor" + and fn_name == "KeyedJaggedTensor" + ): + call_module_found = False - # we cannot set any other `arg` value here - # break to avoid infinite loop - break - else: - break + for arg_node in chain(child_node.args, child_node.kwargs.values()): + if isinstance( + arg_node, torch.fx.Node + ) and _check_args_for_call_module(arg_node): + call_module_found = True + break - return num_found + if call_module_found: + break + if "values" in child_node.kwargs: + arg = child_node.kwargs["values"] + else: + arg = child_node.args[1] -def _get_node_args_helper( - model: torch.nn.Module, - # pyre-ignore - arguments, - num_found: int, - pipelined_postprocs: Set[PipelinedPostproc], - context: TrainPipelineContext, - pipeline_postproc: bool, - # Add `None` constants to arg info only for postproc modules - # Defaults to False for backward compatibility - for_postproc_module: bool = False, - default_stream: Optional[torch.Stream] = None, - dist_stream: Optional[torch.Stream] = None, -) -> Tuple[List[ArgInfo], int]: - """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. 
- """ - arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] - for arg, arg_info in zip(arguments, arg_info_list): - if not for_postproc_module and arg is None: - num_found += 1 - continue - num_found += _get_node_args_helper_inner( - model, - arg, - arg_info, - num_found, - pipelined_postprocs, - context, - pipeline_postproc, - for_postproc_module, - default_stream=default_stream, - dist_stream=dist_stream, - ) - return arg_info_list, num_found + elif child_node.op == "call_method" and child_node.target == "get": + # pyre-ignore[6] + arg_info.add_input_attr(child_node.args[1], True) + arg = child_node.args[0] + else: + break + # if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found" + return None -def _get_node_args( - model: torch.nn.Module, - node: Node, - pipelined_postprocs: Set[PipelinedPostproc], - context: TrainPipelineContext, - pipeline_postproc: bool, - for_postproc_module: bool = False, - default_stream: Optional[torch.Stream] = None, - dist_stream: Optional[torch.Stream] = None, -) -> Tuple[List[ArgInfo], int]: - num_found = 0 + def _get_node_args_helper( + self, + # pyre-ignore + arguments, + # Add `None` constants to arg info only for postproc modules + # Defaults to False for backward compatibility + for_postproc_module: bool = False, + ) -> Tuple[List[ArgInfo], int]: + """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. + """ + num_found = 0 + arg_info_list = [] + for arg in arguments: + if not for_postproc_module and arg is None: + num_found += 1 + continue + arg_info = self._get_node_args_helper_inner( + arg, + for_postproc_module, + ) + if arg_info is not None: + num_found += 1 + arg_info_list.append(arg_info) + return arg_info_list, num_found - pos_arg_info_list, num_found = _get_node_args_helper( - model, - node.args, - num_found, - pipelined_postprocs, - context, - pipeline_postproc, - for_postproc_module, - default_stream=default_stream, - dist_stream=dist_stream, - ) - kwargs_arg_info_list, num_found = _get_node_args_helper( - model, - node.kwargs.values(), - num_found, - pipelined_postprocs, - context, - pipeline_postproc, - for_postproc_module, - default_stream=default_stream, - dist_stream=dist_stream, - ) + def get_node_args( + self, + node: Node, + for_postproc_module: bool = False, + ) -> Tuple[List[ArgInfo], int]: + pos_arg_info_list, args_found = self._get_node_args_helper( + node.args, + for_postproc_module, + ) + kwargs_arg_info_list, kwargs_found = self._get_node_args_helper( + node.kwargs.values(), + for_postproc_module, + ) - # Replace with proper names for kwargs - for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): - arg_info_list.name = name + # Replace with proper names for kwargs + for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): + arg_info_list.name = name - arg_info_list = pos_arg_info_list + kwargs_arg_info_list + arg_info_list = pos_arg_info_list + kwargs_arg_info_list - return (arg_info_list, num_found) + return (arg_info_list, args_found + kwargs_found) def _get_leaf_module_names_helper( @@ -1431,23 +1401,18 @@ def _rewrite_model( # noqa C901 pipelined_forwards = [] original_forwards = [] - pipelined_postprocs: Set[PipelinedPostproc] = set() non_pipelined_sharded_modules = [] + args_helper = NodeArgsHelper( + model, context, pipeline_postproc, default_stream, dist_stream + ) + for node in graph.nodes: if node.op == "call_module" and node.target 
in sharded_modules: total_num_args = len(node.args) + len(node.kwargs) if total_num_args == 0: continue - arg_info_list, num_found = _get_node_args( - model, - node, - pipelined_postprocs, - context, - pipeline_postproc, - default_stream=default_stream, - dist_stream=dist_stream, - ) + arg_info_list, num_found = args_helper.get_node_args(node) if num_found == total_num_args: logger.info(f"Module '{node.target}' will be pipelined") @@ -1485,7 +1450,7 @@ def _rewrite_model( # noqa C901 pipelined_forwards, input_model, original_forwards, - list(pipelined_postprocs), + list(args_helper.pipelined_postprocs), non_pipelined_sharded_modules, )
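
For orientation, a minimal usage sketch of the refactored API above, mirroring the updated unit test rather than production wiring: the model, context, and `pipeline_postproc` flag are now bound once on a `NodeArgsHelper` instance instead of being threaded through free functions, and discovered postproc modules are read off the `pipelined_postprocs` property rather than an out-parameter set. The `MagicMock` model and the `kjt_node` FX node are stand-ins taken from the test; the attribute name passed to `add_input_attr` is purely illustrative.

    from unittest.mock import MagicMock

    from torchrec.distributed.train_pipeline.utils import (
        ArgInfo,
        NodeArgsHelper,
        TrainPipelineContext,
    )

    # Before: _, num_found = _get_node_args(model, node, set(), context, False)
    # After: shared state lives on the helper; the per-node call takes only the node.
    helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), pipeline_postproc=False)
    arg_info_list, num_found = helper.get_node_args(kjt_node)  # kjt_node: torch.fx.Node from the test setup

    # Postproc modules found while walking the graph accumulate on the instance.
    discovered = helper.pipelined_postprocs

    # ArgInfo now grows through small fluent helpers instead of four parallel
    # insert() calls; e.g. the builtins.getattr branch reduces to:
    info = ArgInfo([], [], [], [], None).add_input_attr("some_attr", False)

Binding the shared state on the helper also lets `_rewrite_model` construct it once and reuse it for every graph node, which is what the final hunk does with `args_helper`.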