diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 3fd3c0fd0..c4ffa6b12 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -31,6 +31,7 @@ from torchrec.modules.activation import SwishLayerNorm from torchrec.modules.embedding_configs import ( BaseEmbeddingConfig, + data_type_to_dtype, EmbeddingBagConfig, EmbeddingConfig, ) @@ -223,8 +224,8 @@ def _validate_pooling_factor( else: raise ValueError(f"For IdList features, unknown input type {input_type}") - for idx in range(len(idscore_ind_ranges)): - ind_range = idscore_ind_ranges[idx] + for idx, ind_range in enumerate(idscore_ind_ranges): + weighted_table = weighted_tables[idx] lengths_ = torch.abs( torch.randn(batch_size * world_size, device=device) + ( @@ -259,7 +260,11 @@ def _validate_pooling_factor( dtype=torch.long if long_indices else torch.int32, device=device, ) - weights = torch.rand((num_indices,), device=device) + weights = torch.rand( + (num_indices,), + device=device, + dtype=data_type_to_dtype(weighted_table.data_type), + ) global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights)