In gradient clipping, if DTensors are used, first convert them to local tensors (#2091)

Summary:
Pull Request resolved: #2091

X-link: pytorch/pytorch#128317

For gradient clipping, handle the situation where DTensors are used to implement module parameters (which is the case when FSDP2 is used). The issue is that if we pass DTensors directly to the gradient clipping code, they have to go through the DTensor dispatch code, which is currently quite slow. For instance, below is a trace of gradient clipping on DTensors without this diff:

https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/aps_traces/tree/traces/dynocli/aps-ckluk-d141c794fb/0/rank-0.Jun_05_14_22_04.3562.pt.trace.json.gz

Here, the clipping takes 93 ms per training step.

With this diff, we first convert parameters and their gradients that are DTensors to local tensors, and then pass those to the gradient clipping code. For instance, below is a trace of gradient clipping with this diff:

https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/aps_traces/tree/traces/dynocli/aps-ckluk-d247c392e3/0/rank-0.Jun_09_17_05_48.3748.pt.trace.json.gz

Now, the clipping takes only 4 ms per training step.

Differential Revision: D58325350
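The diff itself is not shown on this page, so the following is only a minimal sketch of the idea: unwrap DTensor gradients into their local shards, compute the global norm with a cross-rank reduction, and scale in place. It assumes FSDP2-style non-overlapping Shard placements, gradients on a single device per rank, and the current `torch.distributed.tensor.DTensor` import path (`torch.distributed._tensor` on older releases); the helper name `clip_grad_norm_local_` is hypothetical.

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases


def clip_grad_norm_local_(parameters, max_norm: float) -> torch.Tensor:
    """Clip gradients by global norm, operating on local shards of DTensors."""
    # Swap DTensor gradients for their local shards so the norm computation
    # and the in-place scaling below run on plain torch.Tensors, bypassing
    # the (currently slow) DTensor dispatch path.
    grads = []
    for p in parameters:
        if p.grad is None:
            continue
        g = p.grad
        grads.append(g.to_local() if isinstance(g, DTensor) else g)
    if not grads:
        return torch.tensor(0.0)

    # Each rank only sees its own shard, so sum the squared shard norms
    # across ranks to recover the global gradient norm. This assumes the
    # shards do not overlap (true for FSDP2-style Shard placements).
    total_sq = torch.stack([g.norm(2) ** 2 for g in grads]).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)
    total_norm = total_sq.sqrt()

    # to_local() returns a view of the DTensor's local shard, so scaling in
    # place also updates the original DTensor gradients.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm
```

Because `to_local()` returns a view rather than a copy, the conversion itself is cheap; the speedup comes from keeping the many small norm and scaling ops off the per-op DTensor dispatch path.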