diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 010abb459..5a75be4a1 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
+from torch import ones_like
 from torchrec.distributed.embedding_tower_sharding import (
     EmbeddingTowerCollectionSharder,
     EmbeddingTowerSharder,
@@ -43,6 +44,8 @@
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Pipelineable
 
+torch.fx.wrap("ones_like")
+
 
 @dataclass
 class ModelInput(Pipelineable):
@@ -1832,6 +1835,110 @@ def forward(
         return pred.sum(), pred
 
 
+class TestModelWithPreprocCollectionArgs(nn.Module):
+    """
+    Basic module with up to 4 postproc modules:
+    - postproc on idlist_features for non-weighted EBC
+    - postproc on idscore_features for weighted EBC
+    - postproc_inner applied to the model input shared by both EBCs
+    - postproc_outer providing input to postproc_inner (aka nested postproc)
+
+    Args:
+        tables,
+        weighted_tables,
+        device,
+        postproc_module_outer,
+        postproc_module_nested,
+        num_float_features,
+
+    Example:
+        >>> TestModelWithPreprocCollectionArgs(tables, weighted_tables, device)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]
+    """
+
+    CONST_LIST_IDX = 0
+    CONST_DICT_KEY = "const"
+    STATIC_TENSOR_LIST_IDX = 1
+    STATIC_TENSOR_DICT_KEY = "static_tensor"
+    INPUT_TENSOR_LIST_IDX = 2
+    INPUT_TENSOR_DICT_KEY = "tensor_from_input"
+    POSTPROC_TENSOR_LIST_IDX = 3
+    POSTPROC_TENSOR_DICT_KEY = "tensor_from_postproc"
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        postproc_module_outer: nn.Module,
+        postproc_module_nested: nn.Module,
+        num_float_features: int = 10,
+    ) -> None:
+        super().__init__()
+        self.dense = TestDenseArch(num_float_features, device)
+
+        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
+            tables=tables,
+            device=device,
+        )
+        self.weighted_ebc = EmbeddingBagCollection(
+            tables=weighted_tables,
+            is_weighted=True,
+            device=device,
+        )
+        self.postproc_nonweighted = TestPreprocNonWeighted()
+        self.postproc_weighted = TestPreprocWeighted()
+        self._postproc_module_outer = postproc_module_outer
+        self._postproc_module_nested = postproc_module_nested
+
+    def forward(
+        self,
+        input: ModelInput,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Runs the outer postproc on the input, passes its output to the nested
+        postproc via list and dict args, then runs the per-EBC postprocs.
+
+        Args:
+            input
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]
+        """
+        modified_input = input
+
+        outer_postproc_input = self._postproc_module_outer(modified_input)
+
+        preproc_input_list = [
+            1,
+            ones_like(outer_postproc_input, dtype=torch.float32),
+            modified_input.float_features,
+            outer_postproc_input,
+        ]
+        preproc_input_dict = {
+            self.CONST_DICT_KEY: 1,
+            self.STATIC_TENSOR_DICT_KEY: ones_like(outer_postproc_input),
+            self.INPUT_TENSOR_DICT_KEY: modified_input.float_features,
+            self.POSTPROC_TENSOR_DICT_KEY: outer_postproc_input,
+        }
+
+        modified_input = self._postproc_module_nested(
+            modified_input, preproc_input_list, preproc_input_dict
+        )
+
+        modified_idlist_features = self.postproc_nonweighted(
+            modified_input.idlist_features
+        )
+        modified_idscore_features = self.postproc_weighted(
+            modified_input.idscore_features
+        )
+        ebc_out = self.ebc(modified_idlist_features[0])
+        weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])
+
+        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
+        return pred.sum(), pred
+
+
 class TestNegSamplingModule(torch.nn.Module):
     """
     Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index bf708b1f5..e3bf54dfd 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -13,7 +13,7 @@
 from contextlib import ExitStack
 from dataclasses import dataclass
 from functools import partial
-from typing import cast, List, Optional, Tuple, Type, Union
+from typing import cast, Dict, List, Optional, Tuple, Type, Union
 from unittest.mock import MagicMock
 
 import torch
@@ -21,6 +21,7 @@
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
+from torch.fx._symbolic_trace import is_fx_tracing
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -36,6 +37,7 @@
     ModelInput,
     TestEBCSharder,
     TestModelWithPreproc,
+    TestModelWithPreprocCollectionArgs,
    TestNegSamplingModule,
     TestPositionWeightedPreprocModule,
     TestSparseNN,
@@ -1448,6 +1450,81 @@ def forward(
         self.assertEqual(len(pipeline._pipelined_modules), 2)
         self.assertEqual(len(pipeline._pipelined_postprocs), 1)
 
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pipeline_postproc_with_collection_args(self) -> None:
+        """
+        Exercises the scenario where a postproc module takes a list or dict argument
+        whose elements are a mix of:
+        * static scalars
+        * static tensors (e.g. torch.ones())
+        * tensors derived from the input batch (e.g. input.idlist_features["feature_0"])
+        * tensors derived from the input batch and another postproc module (e.g.
+          other_postproc(input.idlist_features["feature_0"]))
+        """
+        test_runner = self
+
+        class PostprocOuter(nn.Module):
+            def __init__(
+                self,
+            ) -> None:
+                super().__init__()
+
+            def forward(
+                self,
+                model_input: ModelInput,
+            ) -> torch.Tensor:
+                return model_input.float_features * 0.1
+
+        class PostprocInner(nn.Module):
+            def __init__(
+                self,
+            ) -> None:
+                super().__init__()
+
+            def forward(
+                self,
+                model_input: ModelInput,
+                input_list: List[Union[torch.Tensor, int]],
+                input_dict: Dict[str, Union[torch.Tensor, int]],
+            ) -> ModelInput:
+                if not is_fx_tracing():
+                    for idx, value in enumerate(input_list):
+                        if isinstance(value, torch.fx.Node):
+                            test_runner.fail(
+                                f"input_list[{idx}] was a fx.Node: {value}"
+                            )
+                        model_input.float_features += value
+
+                    for key, value in input_dict.items():
+                        if isinstance(value, torch.fx.Node):
+                            test_runner.fail(
+                                f"input_dict[{key}] was a fx.Node: {value}"
+                            )
+                        model_input.float_features += value
+
+                return model_input
+
+        model = TestModelWithPreprocCollectionArgs(
+            tables=self.tables[:-1],  # ignore last table as postproc will remove
+            weighted_tables=self.weighted_tables[:-1],  # ignore last table
+            device=self.device,
+            postproc_module_outer=PostprocOuter(),
+            postproc_module_nested=PostprocInner(),
+        )
+
+        pipelined_model, pipeline = self._check_output_equal(
+            model,
+            self.sharding_type,
+        )
+
+        # both EBC and weighted EBC are pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 2)
+        # both the outer and nested postprocs are pipelined
+        self.assertEqual(len(pipeline._pipelined_postprocs), 2)
+
 
 class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
     @unittest.skipIf(
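The `torch.fx.wrap("ones_like")` registration added in test_model.py tells FX to treat calls to the free function `ones_like` as leaf `call_function` nodes during symbolic tracing (the function body is not executed on proxy values), which is presumably what lets the pipeline trace through `TestModelWithPreprocCollectionArgs.forward` and hand real tensors, not fx.Node objects, to the nested postproc. The following is a minimal sketch of that wrap behavior in isolation; it is not part of the patch, and `SketchModule` is an illustrative name only.

import torch
import torch.fx
from torch import ones_like

# Same registration the patch adds at module scope in test_model.py: calls to the
# name "ones_like" in this module become leaf nodes during symbolic tracing.
torch.fx.wrap("ones_like")


class SketchModule(torch.nn.Module):  # hypothetical module, not from the patch
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Recorded as a single call_function node targeting ones_like; the call is
        # not evaluated against the Proxy that stands in for `x` while tracing.
        return x + ones_like(x, dtype=torch.float32)


traced = torch.fx.symbolic_trace(SketchModule())
print(traced.graph)               # graph should contain a call_function node for ones_like
print(traced(torch.randn(2, 3)))  # the traced module still runs eagerly as usual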