Formatting changes for gradient scaling (pytorch#33832)
Summary:
Hard to get right locally... I can build the docs, but the result never quite matches what it looks like live. The bullet point indentation was just an oversight.

Removing the `Returns:` formatting blocks because they take up a lot of space when rendered and add no clarity. Some functions in PyTorch [do use them](https://pytorch.org/docs/master/torch.html#torch.eye), but [many don't bother](https://pytorch.org/docs/master/torch.html#torch.is_tensor), so apparently some people shared my feelings (not using them is in line with existing practice).
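For context, the two docstring conventions being compared look roughly like this (an illustrative sketch with made-up function names, not code from this diff):

```python
def with_returns_block(n):
    """Build an identity-like result.

    Returns:
        A Python int equal to ``n`` (explicit ``Returns:`` block style).
    """
    return n

def with_inline_returns(n):
    """Build an identity-like result.

    Returns a Python int equal to ``n`` (return description folded into the prose).
    """
    return n
```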
Pull Request resolved: pytorch#33832

Differential Revision: D20135581

Pulled By: ngimel

fbshipit-source-id: bc788a7e57b142f95c4fa5baf3fe01f94c45abd8
definitelynotmcarilli authored and facebook-github-bot committed Feb 28, 2020
1 parent 5dde8cd commit a726827
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions torch/cuda/amp/grad_scaler.py
@@ -62,11 +62,11 @@ class GradScaler(object):
     ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
     * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
-    themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
+      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
     * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
-    If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
-    ``growth_factor``.
+      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
+      ``growth_factor``.
     The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
     value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
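A minimal sketch of the scale/step/update cycle this docstring describes; the model, data, and hyperparameters below are placeholders, not part of the diff:

```python
import torch

# Placeholder model and optimizer; any CUDA model works the same way.
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(100):
    inputs = torch.randn(16, 8, device="cuda")
    targets = torch.randn(16, 1, device="cuda")
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()   # backprop through the scaled loss
    scaler.step(optimizer)          # internally skipped if inf/NaN grads are found
    scaler.update()                 # backs off, or grows after growth_interval clean steps
```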
@@ -125,11 +125,11 @@ def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Arguments:
outputs (Tensor or iterable of Tensors): Outputs to scale.
Returns:
Scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified.
"""
if not self._enabled:
return outputs
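As a rough illustration of the behavior described in this docstring (assuming a CUDA device is available):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
loss = torch.ones((), device="cuda")

scaled = scaler.scale(loss)               # loss multiplied by the current scale factor
scaled_pair = scaler.scale([loss, loss])  # iterables of tensors are scaled element-wise

# With enabled=False, outputs come back unmodified.
disabled = torch.cuda.amp.GradScaler(enabled=False)
assert disabled.scale(loss) is loss
```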
@@ -234,14 +234,13 @@ def step(self, optimizer, *args, **kwargs):
         ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
+        Returns the return value of ``optimizer.step(*args, **kwargs)``.
         Arguments:
             optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
             args: Any arguments.
             kwargs: Any keyword arguments.
-        Returns:
-            The return value of ``optimizer.step(*args, **kwargs)``.
         .. warning::
             Closure use is not currently supported.
         """
@@ -342,8 +341,7 @@ def _get_scale_async(self):

     def get_scale(self):
         """
-        Returns:
-            A Python float containing the current scale, or 1.0 if scaling is disabled.
+        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
         .. warning::
             :meth:`get_scale` incurs a CPU-GPU sync.
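For example (assuming the documented default init_scale; the printed values are illustrative):

```python
import torch

scaler = torch.cuda.amp.GradScaler(init_scale=2.**16)
print(scaler.get_scale())    # 65536.0 before any updates; reading it may sync CPU and GPU

disabled = torch.cuda.amp.GradScaler(enabled=False)
print(disabled.get_scale())  # 1.0 when scaling is disabled
```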
@@ -355,8 +353,7 @@ def get_scale(self):

     def get_growth_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale growth factor.
+        Returns a Python float containing the scale growth factor.
         """
         return self._growth_factor

@@ -369,8 +366,7 @@ def set_growth_factor(self, new_factor):

     def get_backoff_factor(self):
         r"""
-        Returns:
-            A Python float containing the scale backoff factor.
+        Returns a Python float containing the scale backoff factor.
         """
         return self._backoff_factor

@@ -383,8 +379,7 @@ def set_backoff_factor(self, new_factor):

     def get_growth_interval(self):
         r"""
-        Returns:
-            A Python int containing the growth interval.
+        Returns a Python int containing the growth interval.
         """
         return self._growth_interval

@@ -403,8 +398,7 @@ def _get_growth_tracker(self):

     def is_enabled(self):
         r"""
-        Returns:
-            A bool indicating whether this instance is enabled.
+        Returns a bool indicating whether this instance is enabled.
         """
         return self._enabled
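A quick sketch of the accessor methods touched above (the constructor values shown are the documented defaults, passed explicitly for illustration):

```python
import torch

scaler = torch.cuda.amp.GradScaler(growth_factor=2.0,
                                    backoff_factor=0.5,
                                    growth_interval=2000)

print(scaler.get_growth_factor())    # 2.0
print(scaler.get_backoff_factor())   # 0.5
print(scaler.get_growth_interval())  # 2000
print(scaler.is_enabled())           # True on a CUDA-enabled build

scaler.set_growth_factor(4.0)        # setters shown in the hunks above
scaler.set_backoff_factor(0.25)
```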

