fix flaky test due to input_jkt.weight dtype (pytorch#2763)
Summary:

# context
* The [test_model_parallel_nccl](https://fb.workplace.com/groups/970281557043698/posts/1863456557726189/?comment_id=1867254224013089) has been reported to be flaky: [paste](https://www.internalfb.com/intern/everpaste/?color=0&handle=GJBrgxaEWkfR-ycEAP_fNV5sl_l1br0LAAAz)

* After an in-depth investigation, the root cause is that the dtype of the generated input `KJT._weights` is always `torch.float` (i.e., `torch.float32`), while the embedding table under test can be `torch.float16`, as sketched below.
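
A minimal, hedged repro sketch of the mismatch, using a plain `nn.EmbeddingBag` as a stand-in for the TorchRec module (shapes and values are illustrative; half-precision `EmbeddingBag` on CPU may require a recent PyTorch): float32 per-sample weights against an FP16 table typically fail with a dtype-mismatch error, while casting the weights to the table dtype works.

```python
import torch
import torch.nn as nn

# Stand-in for an FP16 embedding table under test (illustrative sizes).
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum").to(torch.float16)

values = torch.tensor([1, 2, 3])
offsets = torch.tensor([0])
weights = torch.rand(3)  # torch.rand defaults to torch.float32

try:
    # float32 per_sample_weights vs FP16 table: expected to raise.
    bag(values, offsets, per_sample_weights=weights)
except RuntimeError as e:
    print(f"dtype mismatch: {e}")

# Casting the per-sample weights to the table dtype avoids the mismatch.
out = bag(values, offsets, per_sample_weights=weights.to(bag.weight.dtype))
```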

# changes
* Added a `dtype` argument to the `torch.rand` call in input generation so that the generated input KJT's `_weights` match the dtype of the embedding tables under test (see the sketch below).
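
A hedged sketch of the input-generation change; the function name and signature below are illustrative, not the exact TorchRec test utility. The point is that the per-sample weights are generated in the dtype of the tables under test instead of the `torch.rand` default of `torch.float32`.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def generate_weighted_kjt(
    keys,                         # e.g. ["f1"]
    values,                       # flattened feature ids
    lengths,                      # per-key, per-example lengths
    weights_dtype=torch.float32,  # assumption: the table dtype is passed in here
):
    # Before the fix: torch.rand(values.numel()) -> always torch.float32,
    # which can mismatch an FP16 embedding table.
    weights = torch.rand(values.numel(), dtype=weights_dtype)
    return KeyedJaggedTensor.from_lengths_sync(
        keys=keys, values=values, lengths=lengths, weights=weights
    )

# Illustrative usage: the generated weights now match an FP16 table.
kjt = generate_weighted_kjt(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([3]),
    weights_dtype=torch.float16,
)
```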

Differential Revision: D70126859
TroyGarden authored and facebook-github-bot committed Feb 24, 2025
1 parent 7500a0f commit 3477ac1
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchrec/modules/embedding_modules.py
```diff
@@ -248,7 +248,11 @@ def forward(
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=(
+                        f.weights().to(embedding_bag.weight.dtype)
+                        if self._is_weighted
+                        else None
+                    ),
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
```
