In gradient clipping, if DTensors are used, need to first convert them to local tensors (#2091)

Summary:
Pull Request resolved: #2091

X-link: pytorch/pytorch#128317

For gradient clipping, handle the situation where DTensors are used to implement module parameters (which is the case when FSDP-2 is used).

The issue is that if we pass DTensors directly to the gradient clipping code, they need to go through the DTensor dispatch path, which is currently quite slow. For instance, here is a trace of gradient clipping on DTensors without this diff:
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/aps_traces/tree/traces/dynocli/aps-ckluk-d141c794fb/0/rank-0.Jun_05_14_22_04.3562.pt.trace.json.gz The clipping takes 93 ms per training step.

With this diff, we first convert parameters and their gradients that are DTensors to local tensors, and then pass those to the gradient clipping. For instance, here is a trace of gradient clipping with this diff: https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/aps_traces/tree/traces/dynocli/aps-ckluk-d247c392e3/0/rank-0.Jun_09_17_05_48.3748.pt.trace.json.gz Now the clipping takes only 4 ms per training step.
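
As a usage illustration, here is a minimal sketch of enabling the new behavior. The toy parameter dict and the KeyedOptimizerWrapper setup are assumptions made for the sketch; only the use_dtensors argument comes from this diff.

import torch

from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper

# Toy parameter dict; in a real FSDP-2 setup these parameters would be DTensors.
params = {"linear.weight": torch.nn.Parameter(torch.randn(8, 8))}

keyed_optimizer = KeyedOptimizerWrapper(
    params, lambda p: torch.optim.SGD(p, lr=0.1)
)

clipping_optimizer = GradientClippingOptimizer(
    optimizer=keyed_optimizer,
    clipping=GradientClipping.NORM,
    max_gradient=1.0,
    # New flag from this diff: unwrap DTensor params/grads to local tensors
    # once, up front, so clipping avoids the slow DTensor dispatch path.
    use_dtensors=True,
)

# The training loop is unchanged; clipping now operates on local tensors.
clipping_optimizer.step()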

Differential Revision: D58325350
ckluk2 authored and facebook-github-bot committed Jun 10, 2024
1 parent 733e42d commit afdab4a
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions torchrec/optim/clipping.py
@@ -11,6 +11,9 @@
from typing import Any, List

import torch

from torch.distributed._tensor.api import DTensor

from torchrec.optim.keyed import KeyedOptimizer, OptimizerWrapper


@@ -36,6 +39,7 @@ def __init__(
        optimizer: KeyedOptimizer,
        clipping: GradientClipping = GradientClipping.NONE,
        max_gradient: float = 0.1,
        use_dtensors: bool = False,
    ) -> None:
        super().__init__(optimizer)
        self._clipping = clipping
@@ -46,6 +50,34 @@ def __init__(
        for param_group in self.param_groups:
            self._params += list(param_group["params"])

        if use_dtensors:
            # If DTensors are used, we need to convert them to local tensors
            # before doing gradient clipping for performance reasons;
            # otherwise, it needs to go thru dtensor dispatch, which is
            # quite slow currently.
            with torch.autograd.profiler.record_function(
                "Dtensors => Tensors in GradientClippingOptimizer::__init__()"
            ):
                new_params: List[torch.Tensor] = []
                for p in self._params:
                    if p is not None and isinstance(p, DTensor):
                        local_p = p._local_tensor
                        if local_p.numel() == 0:
                            # skip empty tensors
                            continue
                        if p.grad is not None and isinstance(p.grad, DTensor):
                            local_p.grad = p.grad._local_tensor
                            # p and p.grad should have the same number of elements.
                            assert local_p.grad.numel() > 0
                        else:
                            # simply use p.grad as local_p.grad
                            local_p.grad = p.grad
                        new_params.append(local_p)
                    else:
                        # simply use p as it is.
                        new_params.append(p)
                self._params = new_params

    # pyre-ignore [2]
    def step(self, closure: Any = None) -> None:
        if self._check_meta:
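
For context, the body of step() is not shown in this diff. With norm-based clipping it typically boils down to a single clip_grad_norm_ call over self._params, which is why storing plain local tensors there keeps the hot path off the DTensor dispatcher. A rough sketch of that consumption, as an assumption rather than the committed implementation:

import torch

def clip_by_norm(local_params, max_gradient: float) -> None:
    # Illustrative only: `local_params` are the plain local tensors prepared
    # in __init__ above, each with .grad pointing at a local gradient, so this
    # single call never goes through the DTensor dispatcher.
    torch.nn.utils.clip_grad_norm_(local_params, max_norm=max_gradient)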
