From aa37644d24bcc331262e00330bd2899c9e397859 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Mon, 24 Feb 2025 14:24:13 -0800 Subject: [PATCH] fix flaky test due to input_jkt.weight dtype Summary: # context * The [test_model_parallel_nccl](https://fb.workplace.com/groups/970281557043698/posts/1863456557726189/?comment_id=1867254224013089) has been reported to be flaky: [paste](https://www.internalfb.com/intern/everpaste/?color=0&handle=GJBrgxaEWkfR-ycEAP_fNV5sl_l1br0LAAAz) * after an in-depth investigation, the root cause is that when the dtype of the generated input KJT._weights is always `torch.float` (i.e., `torch.float32`), but the tested embedding table's dtype could be `torch.FP16`. # changes * added `dtype` argument to the `torch.rand` function in input generation so that the generated input kjt can have the correct dtype for the kjt._weights Differential Revision: D70126859 --- torchrec/distributed/test_utils/test_model.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 3fd3c0fd0..c4ffa6b12 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -31,6 +31,7 @@ from torchrec.modules.activation import SwishLayerNorm from torchrec.modules.embedding_configs import ( BaseEmbeddingConfig, + data_type_to_dtype, EmbeddingBagConfig, EmbeddingConfig, ) @@ -223,8 +224,8 @@ def _validate_pooling_factor( else: raise ValueError(f"For IdList features, unknown input type {input_type}") - for idx in range(len(idscore_ind_ranges)): - ind_range = idscore_ind_ranges[idx] + for idx, ind_range in enumerate(idscore_ind_ranges): + weighted_table = weighted_tables[idx] lengths_ = torch.abs( torch.randn(batch_size * world_size, device=device) + ( @@ -259,7 +260,11 @@ def _validate_pooling_factor( dtype=torch.long if long_indices else torch.int32, device=device, ) - weights = torch.rand((num_indices,), device=device) + weights = torch.rand( + (num_indices,), + device=device, + dtype=data_type_to_dtype(weighted_table.data_type), + ) global_idscore_lengths.append(lengths) global_idscore_indices.append(indices) global_idscore_weights.append(weights)