fix flaky test due to input_jkt.weight dtype (pytorch#2763)
Summary:

# context
* The [test_model_parallel_nccl](https://fb.workplace.com/groups/970281557043698/posts/1863456557726189/?comment_id=1867254224013089) test has been reported to be flaky: [paste](https://www.internalfb.com/intern/everpaste/?color=0&handle=GJBrgxaEWkfR-ycEAP_fNV5sl_l1br0LAAAz)
* After an in-depth investigation, the root cause is that the dtype of the generated input KJT's `_weights` is always `torch.float` (i.e., `torch.float32`), while the tested embedding table's dtype could be `torch.float16`.

# changes
* Added a `dtype` argument to the `torch.rand` call in input generation so that the generated input KJT has the correct dtype for `kjt._weights`.

Differential Revision: D70126859
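The fix above can be sketched as follows. This is a minimal illustration, not the actual TorchRec input-generation code: the helper name `gen_kjt_weights` and its parameters are hypothetical, but it shows the core change of threading the table's dtype into `torch.rand` so the generated weights match the embedding table instead of defaulting to `torch.float32`.

```python
import torch


def gen_kjt_weights(num_values: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Hypothetical helper: pass the embedding table's dtype through so the
    # generated KJT weights match it (e.g. torch.float16 for FP16 tables),
    # instead of always producing torch.float32 values.
    return torch.rand(num_values, dtype=dtype)


# Without the dtype argument, weights default to torch.float32 even when
# the table under test is FP16, which caused the dtype mismatch.
w_fp32 = gen_kjt_weights(8)
w_fp16 = gen_kjt_weights(8, dtype=torch.float16)
```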