diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 846cb343b..a7379c3cb 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -71,6 +71,7 @@ def generate(
         weighted_tables_pooling: Optional[List[int]] = None,
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
+        max_feature_lengths: Optional[List[int]] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
@@ -102,11 +103,17 @@ def _validate_pooling_factor(
 
         idlist_features_to_num_embeddings = {}
         idlist_features_to_pooling_factor = {}
+        idlist_features_to_max_length = {}
+        feature_idx = 0
         for idx in range(len(tables)):
             for feature in tables[idx].feature_names:
                 idlist_features_to_num_embeddings[feature] = tables[idx].num_embeddings
+                idlist_features_to_max_length[feature] = (
+                    max_feature_lengths[feature_idx] if max_feature_lengths else None
+                )
                 if tables_pooling is not None:
                     idlist_features_to_pooling_factor[feature] = tables_pooling[idx]
+                feature_idx += 1
 
         idlist_features = list(idlist_features_to_num_embeddings.keys())
         idscore_features = [
@@ -119,6 +126,8 @@ def _validate_pooling_factor(
         idlist_pooling_factor = list(idlist_features_to_pooling_factor.values())
         idscore_pooling_factor = weighted_tables_pooling
 
+        idlist_max_lengths = list(idlist_features_to_max_length.values())
+
         # Generate global batch.
         global_idlist_lengths = []
         global_idlist_indices = []
@@ -142,6 +151,10 @@ def _validate_pooling_factor(
                 lengths_ = torch.abs(
                     torch.randn(batch_size * world_size, device=device) + pooling_avg,
                 ).int()
+
+            if idlist_max_lengths[idx]:
+                lengths_ = torch.clamp(lengths_, max=idlist_max_lengths[idx])
+
             if variable_batch_size:
                 lengths = torch.zeros(batch_size * world_size, device=device).int()
                 for r in range(world_size):
@@ -152,6 +165,7 @@ def _validate_pooling_factor(
                 )
             else:
                 lengths = lengths_
+
             num_indices = cast(int, torch.sum(lengths).item())
             if randomize_indices:
                 indices = torch.randint(
diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
index deee22f5e..3e00c7f82 100644
--- a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
+++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
@@ -28,6 +28,8 @@
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+DEFAULT_MAX_FEATURE_LENGTH = 12
+
 
 class SparseArch(nn.Module):
     def __init__(
@@ -35,9 +37,21 @@ def __init__(
         tables: List[EmbeddingBagConfig],
         use_fp_collection: bool,
         device: torch.device,
+        max_feature_lengths: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
 
+        feature_names = [
+            feature_name for table in tables for feature_name in table.feature_names
+        ]
+
+        if max_feature_lengths is None:
+            max_feature_lengths = [DEFAULT_MAX_FEATURE_LENGTH] * len(feature_names)
+
+        assert len(max_feature_lengths) == len(
+            feature_names
+        ), "Expect max_feature_lengths to have the same number of items as feature_names"
+
         self._fp_ebc: FeatureProcessedEmbeddingBagCollection = (
             FeatureProcessedEmbeddingBagCollection(
                 EmbeddingBagCollection(
@@ -49,20 +63,19 @@ def __init__(
                 cast(
                     Dict[str, FeatureProcessor],
                     {
-                        "feature_0": PositionWeightedModule(max_feature_length=10),
-                        "feature_1": PositionWeightedModule(max_feature_length=10),
-                        "feature_2": PositionWeightedModule(max_feature_length=12),
-                        "feature_3": PositionWeightedModule(max_feature_length=12),
+                        feature_name: PositionWeightedModule(
+                            max_feature_length=max_feature_length
+                        )
+                        for feature_name, max_feature_length in zip(
+                            feature_names, max_feature_lengths
+                        )
                     },
                 )
                 if not use_fp_collection
                 else PositionWeightedModuleCollection(
-                    max_feature_lengths={
-                        "feature_0": 10,
-                        "feature_1": 10,
-                        "feature_2": 12,
-                        "feature_3": 12,
-                    }
+                    max_feature_lengths=dict(
+                        zip(feature_names, max_feature_lengths)
+                    ),
                 )
             ),
         ).to(device)
@@ -85,9 +98,10 @@ def create_module_and_freeze(
     tables: List[EmbeddingBagConfig],
     use_fp_collection: bool,
     device: torch.device,
+    max_feature_lengths: Optional[List[int]] = None,
 ) -> SparseArch:
 
-    sparse_arch = SparseArch(tables, use_fp_collection, device)
+    sparse_arch = SparseArch(tables, use_fp_collection, device, max_feature_lengths)
 
     torch.manual_seed(0)
     for param in sparse_arch.parameters():
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 916c3975b..b41de4346 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -178,20 +178,23 @@ def __init__(self, sparse_arch):
             def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
                 return self.m(model_input.idlist_features)
 
+        max_feature_lengths = [10, 10, 12, 12]
         sparse_arch = DummyWrapper(
             create_module_and_freeze(
                 tables=self.tables,
                 device=self.device,
                 use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
             )
         )
+        compute_kernel = EmbeddingComputeKernel.FUSED.value
         module_sharding_plan = construct_module_sharding_plan(
             sparse_arch.m._fp_ebc,
             per_param_sharding={
-                "table_0": table_wise(rank=0),
-                "table_1": table_wise(rank=0),
-                "table_2": table_wise(rank=0),
-                "table_3": table_wise(rank=0),
+                "table_0": table_wise(rank=0, compute_kernel=compute_kernel),
+                "table_1": table_wise(rank=0, compute_kernel=compute_kernel),
+                "table_2": table_wise(rank=0, compute_kernel=compute_kernel),
+                "table_3": table_wise(rank=0, compute_kernel=compute_kernel),
             },
             local_size=1,
             world_size=1,
@@ -219,7 +222,9 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
             sharded_sparse_arch_pipeline.state_dict(),
         )
 
-        data = self._generate_data(num_batches=5, batch_size=1)
+        data = self._generate_data(
+            num_batches=5, batch_size=1, max_feature_lengths=max_feature_lengths
+        )
         dataloader = iter(data)
 
         optimizer_no_pipeline = optim.SGD(
@@ -236,6 +241,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
         )
 
         for batch in data[:-2]:
+            batch = batch.to(self.device)
             optimizer_no_pipeline.zero_grad()
             loss, pred = sharded_sparse_arch_no_pipeline(batch)
             loss.backward()
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 76105d7d3..8317f2354 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -68,6 +68,7 @@ def _generate_data(
         self,
         num_batches: int = 5,
         batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
     ) -> List[ModelInput]:
         return [
             ModelInput.generate(
@@ -76,6 +77,7 @@ def _generate_data(
                 batch_size=batch_size,
                 world_size=1,
                 num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
             )[0]
             for i in range(num_batches)
         ]
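
Usage sketch (not part of the patch): a minimal example of how the new max_feature_lengths argument to ModelInput.generate caps the per-sample lengths of each id-list feature, assuming the surrounding signature shown in the diff. The table config, batch size, and cap value below are illustrative assumptions, not values taken from the patch.

    from typing import List

    import torch
    from torchrec.distributed.test_utils.test_model import ModelInput
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    # Hypothetical single-table config, for illustration only.
    tables: List[EmbeddingBagConfig] = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]

    # One cap per id-list feature, in feature order (assumed cap of 10).
    max_feature_lengths = [10]

    global_batch, _ = ModelInput.generate(
        batch_size=4,
        world_size=1,
        num_float_features=10,
        tables=tables,
        weighted_tables=[],
        max_feature_lengths=max_feature_lengths,
    )

    # With the patch applied, every per-sample length in the generated
    # KeyedJaggedTensor is clamped to the corresponding cap.
    assert torch.all(global_batch.idlist_features.lengths() <= max_feature_lengths[0])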