diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 010abb459..722325ebf 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -1832,6 +1832,103 @@ def forward( return pred.sum(), pred +class TestModelWithPreprocCollectionArgs(nn.Module): + """ + Basic module with up to 3 postproc modules: + - postproc on idlist_features for non-weighted EBC + - postproc on idscore_features for weighted EBC + - postproc_inner on model input shared by both EBCs + - postproc_outer providing input to postproc_b (aka nested postproc) + + Args: + tables, + weighted_tables, + device, + postproc_module_a, + postproc_module_b, + num_float_features, + + Example: + >>> TestModelWithPreprocWithListArg(tables, weighted_tables, device) + + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + + CONST_DICT_KEY = "const" + INPUT_TENSOR_DICT_KEY = "tensor_from_input" + POSTPTOC_TENSOR_DICT_KEY = "tensor_from_postproc" + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: torch.device, + postproc_module_outer: nn.Module, + postproc_module_nested: nn.Module, + num_float_features: int = 10, + ) -> None: + super().__init__() + self.dense = TestDenseArch(num_float_features, device) + + self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=tables, + device=device, + ) + self.weighted_ebc = EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + self.postproc_nonweighted = TestPreprocNonWeighted() + self.postproc_weighted = TestPreprocWeighted() + self._postproc_module_outer = postproc_module_outer + self._postproc_module_nested = postproc_module_nested + + def forward( + self, + input: ModelInput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Runs preproc for EBC and weighted EBC, optionally runs postproc for input + + Args: + input + Returns: + Tuple[torch.Tensor, torch.Tensor] + """ + modified_input = input + + outer_postproc_input = self._postproc_module_outer(modified_input) + + preproc_input_list = [ + 1, + modified_input.float_features, + outer_postproc_input, + ] + preproc_input_dict = { + self.CONST_DICT_KEY: 1, + self.INPUT_TENSOR_DICT_KEY: modified_input.float_features, + self.POSTPTOC_TENSOR_DICT_KEY: outer_postproc_input, + } + + modified_input = self._postproc_module_nested( + modified_input, preproc_input_list, preproc_input_dict + ) + + modified_idlist_features = self.postproc_nonweighted( + modified_input.idlist_features + ) + modified_idscore_features = self.postproc_weighted( + modified_input.idscore_features + ) + ebc_out = self.ebc(modified_idlist_features[0]) + weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0]) + + pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1) + return pred.sum(), pred + + class TestNegSamplingModule(torch.nn.Module): """ Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index bf708b1f5..dc23593a0 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -13,7 +13,7 @@ from contextlib import ExitStack from dataclasses import dataclass from functools import partial -from typing import cast, List, Optional, Tuple, Type, Union +from typing import cast, Dict, List, Optional, Tuple, Type, Union from unittest.mock import MagicMock import torch @@ -21,6 +21,7 @@ from torch import nn, optim from torch._dynamo.testing import reduce_to_scalar_loss from torch._dynamo.utils import counters +from torch.fx._symbolic_trace import is_fx_tracing from torchrec.distributed import DistributedModelParallel from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder @@ -36,6 +37,7 @@ ModelInput, TestEBCSharder, TestModelWithPreproc, + TestModelWithPreprocCollectionArgs, TestNegSamplingModule, TestPositionWeightedPreprocModule, TestSparseNN, @@ -1448,6 +1450,81 @@ def forward( self.assertEqual(len(pipeline._pipelined_modules), 2) self.assertEqual(len(pipeline._pipelined_postprocs), 1) + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_pipeline_postproc_with_collection_args(self) -> None: + """ + Exercises scenario when postproc module has an argument that is a list or dict + with some elements being: + * static scalars + * static tensors (e.g. torch.ones()) + * tensors derived from input batch (e.g. input.idlist_features["feature_0"]) + * tensors derived from input batch and other postproc module (e.g. other_postproc(input.idlist_features["feature_0"])) + """ + test_runner = self + + class PostprocOuter(nn.Module): + def __init__( + self, + ) -> None: + super().__init__() + + def forward( + self, + model_input: ModelInput, + ) -> torch.Tensor: + return model_input.float_features * 0.1 + + class PostprocInner(nn.Module): + def __init__( + self, + ) -> None: + super().__init__() + + def forward( + self, + model_input: ModelInput, + input_list: List[Union[torch.Tensor, int]], + input_dict: Dict[str, Union[torch.Tensor, int]], + ) -> ModelInput: + if not is_fx_tracing(): + for idx, value in enumerate(input_list): + if isinstance(value, torch.fx.Node): + test_runner.fail( + f"input_list[{idx}] was a fx.Node: {value}" + ) + model_input.float_features += value + + for key, value in input_dict.items(): + if isinstance(value, torch.fx.Node): + test_runner.fail( + f"input_dict[{key}] was a fx.Node: {value}" + ) + model_input.float_features += value + + return model_input + + model = TestModelWithPreprocCollectionArgs( + tables=self.tables[:-1], # ignore last table as postproc will remove + weighted_tables=self.weighted_tables[:-1], # ignore last table + device=self.device, + postproc_module_outer=PostprocOuter(), + postproc_module_nested=PostprocInner(), + ) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # both EC end EBC are pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + # both outer and nested postproces are pipelined + self.assertEqual(len(pipeline._pipelined_postprocs), 4) + class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase): @unittest.skipIf( diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 5b76c1e2d..054218d60 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -210,7 +210,26 @@ def _build_args_kwargs( arg_info.constants, ): if obj is not None: - arg = obj + if isinstance(obj, list): + arg = [ + ( + v + if not isinstance(v, ArgInfo) + else _build_args_kwargs(initial_input, [v])[0][0] + ) + for v in obj + ] + elif isinstance(obj, dict): + arg = { + k: ( + v + if not isinstance(v, ArgInfo) + else _build_args_kwargs(initial_input, [v])[0][0] + ) + for k, v in obj.items() + } + else: + arg = obj break elif postproc_mod is not None: # postproc will internally run the same logic recursively @@ -908,232 +927,305 @@ def _swap_postproc_module_recursive( return module -def _get_node_args_helper( +def _get_node_args_helper_inner( model: torch.nn.Module, # pyre-ignore - arguments, + arg, + arg_info: ArgInfo, num_found: int, pipelined_postprocs: Set[PipelinedPostproc], context: TrainPipelineContext, pipeline_postproc: bool, - # Add `None` constants to arg info only for postproc modules - # Defaults to False for backward compatibility for_postproc_module: bool = False, default_stream: Optional[torch.Stream] = None, dist_stream: Optional[torch.Stream] = None, -) -> Tuple[List[ArgInfo], int]: - """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. - """ - arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] - for arg, arg_info in zip(arguments, arg_info_list): - if not for_postproc_module and arg is None: - num_found += 1 - continue - while True: - if not isinstance(arg, torch.fx.Node): - if pipeline_postproc: - arg_info.input_attrs.insert(0, "") - arg_info.is_getitems.insert(0, False) - arg_info.postproc_modules.insert(0, None) - if isinstance(arg, (fx_immutable_dict, fx_immutable_list)): - # Make them mutable again, in case in-place updates are made - arg_info.constants.insert(0, arg.copy()) - else: - arg_info.constants.insert(0, arg) - num_found += 1 - break - child_node = arg - - if child_node.op == "placeholder": - if hasattr(child_node, "ph_key"): - # pyre-ignore[16] - ph_key: str = child_node.ph_key - # example: ph_key = 'event_id_list_features_seqs[marketplace]' - ph_key = ph_key.replace("[", ".") - ph_keys = ph_key.split(".") - for key in ph_keys: - if "]" in key: - arg_info.input_attrs.append(key[:-1]) - arg_info.is_getitems.append(True) +) -> int: + num_found = 0 + while True: + if not isinstance(arg, torch.fx.Node): + if pipeline_postproc: + arg_info.input_attrs.insert(0, "") + arg_info.is_getitems.insert(0, False) + arg_info.postproc_modules.insert(0, None) + + if isinstance(arg, fx_immutable_dict): + fx_nested_dict = {} + + for k, v in arg.items(): + if isinstance(v, torch.fx.Node): + arg_info_nested = ArgInfo([], [], [], [], None) + _get_node_args_helper_inner( + model, + v, + arg_info_nested, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + fx_nested_dict[k] = arg_info_nested else: - arg_info.input_attrs.append(key) - arg_info.is_getitems.append(False) - arg_info.postproc_modules.append(None) - arg_info.constants.append(None) - else: - # no-op - arg_info.input_attrs.insert(0, "") - arg_info.is_getitems.insert(0, False) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) + fx_nested_dict[k] = v + + arg_info.constants.insert(0, fx_nested_dict) + elif isinstance(arg, fx_immutable_list): + fx_nested_list = [] + for v in arg: + if isinstance(v, torch.fx.Node): + arg_info_nested = ArgInfo([], [], [], [], None) + _get_node_args_helper_inner( + model, + v, + arg_info_nested, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) + fx_nested_list.append(arg_info_nested) + else: + fx_nested_list.append(v) + arg_info.constants.insert(0, fx_nested_list) + else: + arg_info.constants.insert(0, arg) num_found += 1 - break - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "builtins" + break + child_node = arg + + if child_node.op == "placeholder": + if hasattr(child_node, "ph_key"): # pyre-ignore[16] - and child_node.target.__name__ == "getattr" - ): - # pyre-fixme[6]: For 2nd argument expected `str` but got - # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, - # complex, float, int, range, slice, str, device, dtype, layout, - # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. - arg_info.input_attrs.insert(0, child_node.args[1]) + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_key = ph_key.replace("[", ".") + ph_keys = ph_key.split(".") + for key in ph_keys: + if "]" in key: + arg_info.input_attrs.append(key[:-1]) + arg_info.is_getitems.append(True) + else: + arg_info.input_attrs.append(key) + arg_info.is_getitems.append(False) + arg_info.postproc_modules.append(None) + arg_info.constants.append(None) + else: + # no-op + arg_info.input_attrs.insert(0, "") arg_info.is_getitems.insert(0, False) arg_info.postproc_modules.insert(0, None) arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "_operator" - # pyre-ignore[16] - and child_node.target.__name__ == "getitem" - ): - # pyre-fixme[6]: For 2nd argument expected `str` but got - # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, - # complex, float, int, range, slice, str, device, dtype, layout, - # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, True) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torch.utils._pytree" - # pyre-ignore[16] - and child_node.target.__name__ == "tree_unflatten" - ): - """ - This is for the PT2 export path where we unflatten the input to reconstruct - the structure with the recorded tree spec. - """ - assert arg_info.is_getitems[0] - # pyre-fixme[16] - arg = child_node.args[0][arg_info.input_attrs[0]] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" - # pyre-fixme[16] - and child_node.target.__name__ == "KeyedJaggedTensor" - ): - call_module_found = False - for arg_node in chain(child_node.args, child_node.kwargs.values()): - if isinstance( - arg_node, torch.fx.Node - ) and _check_args_for_call_module(arg_node): - call_module_found = True - break + num_found += 1 + break + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "builtins" + # pyre-ignore[16] + and child_node.target.__name__ == "getattr" + ): + # pyre-fixme[6]: For 2nd argument expected `str` but got + # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, + # complex, float, int, range, slice, str, device, dtype, layout, + # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, False) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "_operator" + # pyre-ignore[16] + and child_node.target.__name__ == "getitem" + ): + # pyre-fixme[6]: For 2nd argument expected `str` but got + # `Union[None, Dict[str, typing.Any], List[typing.Any], Node, bool, + # complex, float, int, range, slice, str, device, dtype, layout, + # memory_format, Tensor, typing.Tuple[typing.Any, ...]]`. + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, True) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torch.utils._pytree" + # pyre-ignore[16] + and child_node.target.__name__ == "tree_unflatten" + ): + """ + This is for the PT2 export path where we unflatten the input to reconstruct + the structure with the recorded tree spec. + """ + assert arg_info.is_getitems[0] + # pyre-fixme[16] + arg = child_node.args[0][arg_info.input_attrs[0]] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" + # pyre-fixme[16] + and child_node.target.__name__ == "KeyedJaggedTensor" + ): + call_module_found = False - if call_module_found: + for arg_node in chain(child_node.args, child_node.kwargs.values()): + if isinstance(arg_node, torch.fx.Node) and _check_args_for_call_module( + arg_node + ): + call_module_found = True break - if "values" in child_node.kwargs: - arg = child_node.kwargs["values"] - else: - arg = child_node.args[1] - elif child_node.op == "call_method" and child_node.target == "get": - # pyre-ignore[6] - arg_info.input_attrs.insert(0, child_node.args[1]) - arg_info.is_getitems.insert(0, True) - arg_info.postproc_modules.insert(0, None) - arg_info.constants.insert(0, None) - arg = child_node.args[0] - elif child_node.op == "call_module": - postproc_module_fqn = str(child_node.target) - postproc_module = _find_postproc_module_recursive( - model, postproc_module_fqn + if call_module_found: + break + + if "values" in child_node.kwargs: + arg = child_node.kwargs["values"] + else: + arg = child_node.args[1] + elif child_node.op == "call_method" and child_node.target == "get": + # pyre-ignore[6] + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, True) + arg_info.postproc_modules.insert(0, None) + arg_info.constants.insert(0, None) + arg = child_node.args[0] + elif child_node.op == "call_module": + postproc_module_fqn = str(child_node.target) + postproc_module = _find_postproc_module_recursive( + model, postproc_module_fqn + ) + + if not pipeline_postproc: + logger.warning( + f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" ) + break - if not pipeline_postproc: - logger.warning( - f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" - ) - break + if not postproc_module: + # Could not find such module, should not happen + break - if not postproc_module: - # Could not find such module, should not happen - break + if isinstance(postproc_module, PipelinedPostproc): + # Already did module swap and registered args, early exit + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_postprocs.add(postproc_module) + arg_info.postproc_modules.insert(0, postproc_module) + arg_info.constants.insert(0, None) + num_found += 1 + break - if isinstance(postproc_module, PipelinedPostproc): - # Already did module swap and registered args, early exit - arg_info.input_attrs.insert(0, "") # dummy value - arg_info.is_getitems.insert(0, False) - pipelined_postprocs.add(postproc_module) - arg_info.postproc_modules.insert(0, postproc_module) - arg_info.constants.insert(0, None) - num_found += 1 - break + if not isinstance(postproc_module, torch.nn.Module): + logger.warning( + f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" + ) + break - if not isinstance(postproc_module, torch.nn.Module): - logger.warning( - f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" - ) - break + # check if module is safe to pipeline i.e.no trainable param + if not _check_postproc_pipelineable(postproc_module): + break - # check if module is safe to pipeline i.e.no trainable param - if not _check_postproc_pipelineable(postproc_module): - break + # For module calls, `self` isn't counted + total_num_args = len(child_node.args) + len(child_node.kwargs) + if total_num_args == 0: + # module call without any args, assume KJT modified + break - # For module calls, `self` isn't counted - total_num_args = len(child_node.args) + len(child_node.kwargs) - if total_num_args == 0: - # module call without any args, assume KJT modified - break + # recursive call to check that all inputs to this postproc module + # is either made of postproc module or non-modifying train batch input + # transformations + postproc_args, num_found_safe_postproc_args = _get_node_args( + model, + child_node, + pipelined_postprocs, + context, + pipeline_postproc, + True, + default_stream=default_stream, + dist_stream=dist_stream, + ) + if num_found_safe_postproc_args == total_num_args: + logger.info( + f"""Module {postproc_module} is a valid postproc module (no + trainable params and inputs can be derived from train batch input + via a series of either valid postproc modules or non-modifying + transformations) and will be applied during sparse data dist + stage""" + ) - # recursive call to check that all inputs to this postproc module - # is either made of postproc module or non-modifying train batch input - # transformations - postproc_args, num_found_safe_postproc_args = _get_node_args( - model, - child_node, - pipelined_postprocs, + pipelined_postproc_module = PipelinedPostproc( + postproc_module, + postproc_module_fqn, + postproc_args, context, - pipeline_postproc, - True, default_stream=default_stream, dist_stream=dist_stream, ) - if num_found_safe_postproc_args == total_num_args: - logger.info( - f"""Module {postproc_module} is a valid postproc module (no - trainable params and inputs can be derived from train batch input - via a series of either valid postproc modules or non-modifying - transformations) and will be applied during sparse data dist - stage""" - ) - pipelined_postproc_module = PipelinedPostproc( - postproc_module, - postproc_module_fqn, - postproc_args, - context, - default_stream=default_stream, - dist_stream=dist_stream, - ) + # module swap + _swap_postproc_module_recursive( + model, pipelined_postproc_module, postproc_module_fqn + ) - # module swap - _swap_postproc_module_recursive( - model, pipelined_postproc_module, postproc_module_fqn - ) + arg_info.input_attrs.insert(0, "") # dummy value + arg_info.is_getitems.insert(0, False) + pipelined_postprocs.add(pipelined_postproc_module) + arg_info.postproc_modules.insert(0, pipelined_postproc_module) + arg_info.constants.insert(0, None) - arg_info.input_attrs.insert(0, "") # dummy value - arg_info.is_getitems.insert(0, False) - pipelined_postprocs.add(pipelined_postproc_module) - arg_info.postproc_modules.insert(0, pipelined_postproc_module) - arg_info.constants.insert(0, None) + num_found += 1 - num_found += 1 + # we cannot set any other `arg` value here + # break to avoid infinite loop + break + else: + break - # we cannot set any other `arg` value here - # break to avoid infinite loop - break - else: - break + return num_found + + +def _get_node_args_helper( + model: torch.nn.Module, + # pyre-ignore + arguments, + num_found: int, + pipelined_postprocs: Set[PipelinedPostproc], + context: TrainPipelineContext, + pipeline_postproc: bool, + # Add `None` constants to arg info only for postproc modules + # Defaults to False for backward compatibility + for_postproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, +) -> Tuple[List[ArgInfo], int]: + """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. + """ + arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] + for arg, arg_info in zip(arguments, arg_info_list): + if not for_postproc_module and arg is None: + num_found += 1 + continue + num_found += _get_node_args_helper_inner( + model, + arg, + arg_info, + num_found, + pipelined_postprocs, + context, + pipeline_postproc, + for_postproc_module, + default_stream=default_stream, + dist_stream=dist_stream, + ) return arg_info_list, num_found