fix flaky test due to input_kjt.weight dtype
Summary:
# context
* The [test_model_parallel_nccl](https://fb.workplace.com/groups/970281557043698/posts/1863456557726189/?comment_id=1867254224013089) has been reported to be flaky: [paste](https://www.internalfb.com/intern/everpaste/?color=0&handle=GJBrgxaEWkfR-ycEAP_fNV5sl_l1br0LAAAz)

* after an in-depth investigation, the root cause is that the dtype of the generated input KJT._weights is always `torch.float` (i.e., `torch.float32`), whereas the tested embedding table's dtype could be FP16 (i.e., `torch.float16`).
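
The mismatch described above can be reproduced in isolation. The snippet below is a minimal sketch, not the failing test itself: `table_weight` is a hypothetical stand-in for an FP16 embedding table's weight tensor, and it assumes PyTorch's default dtype has not been changed from `torch.float32`.

```python
import torch

# Hypothetical stand-in for the weight tensor of an FP16 embedding table.
table_weight = torch.zeros((4, 8), dtype=torch.float16)

# Without an explicit dtype, torch.rand falls back to the default dtype
# (torch.float32 unless torch.set_default_dtype was called).
kjt_weights = torch.rand((5,))

# The generated weights and the table disagree on dtype.
dtypes_match = kjt_weights.dtype == table_weight.dtype
print(kjt_weights.dtype, table_weight.dtype, dtypes_match)
```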

# changes
* added a `dtype` argument to the `torch.rand` call in input generation so that the generated input KJT's `_weights` have the same dtype as the corresponding weighted embedding table
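
A rough sketch of the idea behind this change is shown below. The commit itself uses `data_type_to_dtype(weighted_table.data_type)` from `torchrec.modules.embedding_configs`. To stay self-contained, this sketch substitutes a hypothetical lookup table `HYPOTHETICAL_DTYPES` and a hypothetical helper `make_weights`, neither of which exists in TorchRec.

```python
import torch

# Hypothetical stand-in for torchrec's data_type_to_dtype mapping;
# only the two precisions discussed in this commit are covered.
HYPOTHETICAL_DTYPES = {
    "FP32": torch.float32,
    "FP16": torch.float16,
}


def make_weights(num_indices: int, data_type: str, device: str = "cpu") -> torch.Tensor:
    """Generate random per-index weights in the dtype of the target table."""
    return torch.rand(
        (num_indices,),
        device=device,
        dtype=HYPOTHETICAL_DTYPES[data_type],
    )


fp16_weights = make_weights(5, "FP16")
fp32_weights = make_weights(5, "FP32")
print(fp16_weights.dtype, fp32_weights.dtype)
```

Passing the dtype at generation time, rather than casting afterwards, keeps the generated input consistent with the table from the start.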

Differential Revision: D70126859
TroyGarden authored and facebook-github-bot committed Feb 24, 2025
1 parent 7500a0f commit aa37644
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torchrec/distributed/test_utils/test_model.py
@@ -31,6 +31,7 @@
from torchrec.modules.activation import SwishLayerNorm
from torchrec.modules.embedding_configs import (
BaseEmbeddingConfig,
+data_type_to_dtype,
EmbeddingBagConfig,
EmbeddingConfig,
)
@@ -223,8 +224,8 @@ def _validate_pooling_factor(
else:
raise ValueError(f"For IdList features, unknown input type {input_type}")

-for idx in range(len(idscore_ind_ranges)):
-ind_range = idscore_ind_ranges[idx]
+for idx, ind_range in enumerate(idscore_ind_ranges):
+weighted_table = weighted_tables[idx]
lengths_ = torch.abs(
torch.randn(batch_size * world_size, device=device)
+ (
@@ -259,7 +260,11 @@ def _validate_pooling_factor(
dtype=torch.long if long_indices else torch.int32,
device=device,
)
-weights = torch.rand((num_indices,), device=device)
+weights = torch.rand(
+(num_indices,),
+device=device,
+dtype=data_type_to_dtype(weighted_table.data_type),
+)
global_idscore_lengths.append(lengths)
global_idscore_indices.append(indices)
global_idscore_weights.append(weights)