Commit 26eaeac

Add set_to_none param to TrainingStateAverager.step()
borzunov committed Mar 28, 2023
1 parent ad774b6 commit 26eaeac
Showing 3 changed files with 20 additions and 14 deletions.
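For context: the new parameter mirrors the same-named argument of PyTorch's own Optimizer.zero_grad, whose default switched to True in PyTorch 2.0. A minimal illustration of the difference, in plain PyTorch and independent of hivemind:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(1, 4)).sum().backward()
opt.zero_grad(set_to_none=False)  # gradients stay allocated and are filled with zeros
assert torch.all(model.weight.grad == 0)

model(torch.randn(1, 4)).sum().backward()
opt.zero_grad(set_to_none=True)  # gradient tensors are dropped entirely (the PyTorch 2.0+ default)
assert model.weight.grad is None

Setting gradients to None avoids an extra fill with zeros and can modestly reduce memory use, which is why PyTorch made it the default in 2.0.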
12 changes: 4 additions & 8 deletions hivemind/optim/optimizer.py
@@ -7,7 +7,6 @@
 from typing import Callable, Optional, Sequence, Union
 
 import torch
-from packaging.version import Version
 
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
@@ -16,6 +15,7 @@
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -638,9 +638,7 @@ def _load_local_gradients_into_optimizer(self):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()
 
-    _SET_TO_NONE_DEFAULT = Version(torch.__version__).major >= 2
-
-    def zero_grad(self, set_to_none: bool = _SET_TO_NONE_DEFAULT):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
@@ -649,11 +647,9 @@ def zero_grad(self, set_to_none: bool = _SET_TO_NONE_DEFAULT):
             )
         for param_group in self.param_groups:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()
 
     def _should_load_state_from_peers(self) -> bool:
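The rewritten loop above also drops the explicit "if param.grad is None: pass" branch: assigning None to an already-missing gradient is a no-op, and the "elif param.grad is not None" guard skips zeroing when there is nothing to zero. A standalone sketch of the equivalent logic (clear_grads is a hypothetical helper for illustration, not part of hivemind's API):

import torch

def clear_grads(params, set_to_none: bool) -> None:
    # Same branching as the updated Optimizer.zero_grad loop: both branches are
    # safe when a parameter's grad is already None, so the old explicit
    # `if param.grad is None: pass` case is unnecessary.
    for param in params:
        if set_to_none:
            param.grad = None
        elif param.grad is not None:
            param.grad.zero_()

model = torch.nn.Linear(3, 1)
model(torch.randn(3)).sum().backward()
clear_grads(model.parameters(), set_to_none=True)
assert all(p.grad is None for p in model.parameters())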
16 changes: 13 additions & 3 deletions hivemind/optim/state_averager.py
@@ -8,6 +8,7 @@
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
@@ -22,9 +23,11 @@
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-try:
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
     LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
-except AttributeError:  # torch < 2.0.0
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
     LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
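The old try/except AttributeError probe only detected the LRScheduler rename; the explicit version gate above can additionally select the matching zero_grad default. A quick check of how the gate evaluates (the version strings below are illustrative examples, not pinned requirements):

from packaging.version import Version

for torch_version in ("1.13.1", "2.0.0+cu117"):
    is_torch2 = Version(torch_version).major >= 2
    # torch >= 2.0 exposes torch.optim.lr_scheduler.LRScheduler and defaults
    # zero_grad(set_to_none=True); older releases expose _LRScheduler and default to False.
    print(torch_version, "->", is_torch2)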
@@ -335,6 +338,7 @@ def step(
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -356,6 +360,8 @@
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+            (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
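A rough usage sketch of the new parameter follows. It is a hypothetical single-peer setup: the constructor arguments copy the pattern used in tests/test_optimizer.py (a DHT, a parameter iterator, an optimizer factory, a prefix, start=True) and are assumptions rather than a prescribed recipe; offload_optimizer=True is chosen because that is the code path this commit changes in _do.

from functools import partial

import torch
import torch.nn.functional as F

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(2, 3)

averager = TrainingStateAverager(
    dht=dht,
    params=model.parameters(),
    optimizer=partial(torch.optim.SGD, lr=0.1),
    offload_optimizer=True,  # grads of the original (main) parameters are cleared by the updated loop
    prefix="demo_run",
    target_group_size=2,
    start=True,
)

F.mse_loss(model(torch.rand(2)), torch.zeros(3)).backward()

# zero_grad=True clears local gradients after the optimizer step;
# set_to_none chooses between `grad = None` and in-place `grad.zero_()`.
averager.step(optimizer_step=True, zero_grad=True, set_to_none=True, delay_averaging=False)
assert model.weight.grad is None  # with set_to_none=False, the grad would instead be a zero tensor

averager.shutdown()
dht.shutdown()

Passing set_to_none=False can be useful for code that assumes .grad is always a tensor, for example manual gradient inspection between steps.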
@@ -433,6 +439,7 @@
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)
@@ -475,6 +482,7 @@ def _do(
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):
@@ -518,7 +526,9 @@ def _do(
                     self.optimizer.zero_grad()
                     if self.offload_optimizer:
                         for parameter in self.main_parameters:
-                            if parameter.grad is not None:
+                            if set_to_none:
+                                parameter.grad = None
+                            elif parameter.grad is not None:
                                 parameter.grad.zero_()
 
                 self._update_scheduler()
6 changes: 3 additions & 3 deletions tests/test_optimizer.py
@@ -15,7 +15,7 @@
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -79,7 +79,7 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert torch.allclose(model2.w.grad, ref_average)
 
     # after no longer use_averaged_gradients
-    if hivemind.Optimizer._SET_TO_NONE_DEFAULT:
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
         assert model1.w.grad is None and model2.w.grad is None
     else:
         assert not torch.allclose(model1.w.grad, ref_average) and not torch.allclose(model2.w.grad, ref_average)
@@ -153,7 +153,7 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
 
-        if hivemind.Optimizer._SET_TO_NONE_DEFAULT:
+        if ZERO_GRAD_SET_TO_NONE_DEFAULT:
             assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
         else:
             assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
