[BugFix] Resolve manual optimization (#5852)
* resolve manual_optimization

* update

* update

Co-authored-by: Ubuntu <[email protected]>
tchaton and Ubuntu authored Feb 6, 2021
1 parent 7bb9d9f commit 13ae1ff
Showing 10 changed files with 77 additions and 86 deletions.
39 changes: 15 additions & 24 deletions pytorch_lightning/accelerators/accelerator.py
@@ -15,6 +15,7 @@

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
@@ -228,8 +229,8 @@ def predict(self, args):
return self.training_type_plugin.predict(*args)

def process_dataloader(
self, dataloader: Union[Iterable, torch.utils.data.DataLoader]
) -> Union[Iterable, torch.utils.data.DataLoader]:
self, dataloader: Union[Iterable, DataLoader]
) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary
Args:
@@ -240,7 +241,7 @@ def process_dataloader(
def backward(
self,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
@@ -254,17 +255,17 @@ def backward(
opt_idx: the index of the optimizer
should_accumulate: whether to accumulate gradients
"""
self.training_type_plugin.pre_backward(closure_loss, optimizer, opt_idx)
self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)

output = self.precision_plugin.backward(
self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

self.training_type_plugin.post_backward(closure_loss, optimizer, opt_idx)
self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, opt_idx)

return output

def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
"""performs the actual optimizer step.
Args:
@@ -273,33 +274,23 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_
lambda_closure: closure calculating the loss value
"""

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
# apex does not support passing a closure to the optimizer, call it by itself
lambda_closure()
lambda_closure = None

optimizer.step(closure=lambda_closure, **kwargs)

make_optimizer_step = self.precision_plugin.pre_optimizer_step(
self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs)
if make_optimizer_step:
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)

if self.rpc_enabled and self.training_type_plugin.is_main_rpc_process:

# Initialize optimizer step on main process
self.training_type_plugin.worker_optimizer_step(model=self.lightning_module, opt_idx=opt_idx, **kwargs)
def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)

def optimizer_zero_grad(
self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int
) -> None:
"""Zeros all model parameter's gradients"""
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""clips all the optimizer parameters to the given value"""

self.precision_plugin.clip_gradients(optimizer, clip_val)
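Taken together, the accelerator changes hand control of the step to the precision plugin: pre_optimizer_step now receives the LightningModule and the closure and returns a bool, and the raw optimizer.step is isolated in run_optimizer_step so subclasses can override it. Below is a compressed sketch of the resulting flow, assuming an accelerator instance wired up as above; the helper name step_flow_sketch is illustrative, not library code.

from typing import Callable

from torch.optim import Optimizer


def step_flow_sketch(accelerator, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
    # The precision plugin may run the closure itself and return False to claim the step.
    make_optimizer_step = accelerator.precision_plugin.pre_optimizer_step(
        accelerator.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
    )
    if make_optimizer_step:
        # The default run_optimizer_step is simply optimizer.step(closure=lambda_closure).
        accelerator.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    # Post hooks run whether or not the raw step was taken.
    accelerator.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    accelerator.training_type_plugin.post_optimizer_step(optimizer, opt_idx)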
19 changes: 2 additions & 17 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,6 +1,6 @@
from typing import Callable

import torch
from torch.optim import Optimizer

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
@@ -26,20 +26,5 @@ def setup(self, trainer, model):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
"""performs the actual optimizer step.
Args:
optimizer: the optimizer performing the step
opt_idx: index of the current optimizer
lambda_closure: closure calculating the loss value
"""

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
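
With the pre/post hooks handled in the base Accelerator, the TPU accelerator now only overrides the raw step. For context, here is a standalone sketch of the xm.optimizer_step pattern that override relies on; it assumes a host with torch_xla installed, and tpu_step is an illustrative helper, not library code.

import torch_xla.core.xla_model as xm


def tpu_step(optimizer, closure):
    # Reduces gradients across TPU cores, then calls optimizer.step(closure=closure).
    xm.optimizer_step(optimizer, optimizer_args={'closure': closure})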
13 changes: 4 additions & 9 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from torch.nn import Module
from abc import ABC, abstractmethod
from typing import Any, Generator, Optional, overload, Sequence, Tuple
from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple

import torch

@@ -22,18 +23,12 @@ class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

@abstractmethod
def connect(self, model: torch.nn.Module, *args: Sequence,
**kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
def connect(self, model: Module, *args: Sequence,
**kwargs: Sequence) -> Optional[Tuple[Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something before each optimizer step."""

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

def pre_training(self) -> None:
"""Hook to do something before the training starts."""

48 changes: 29 additions & 19 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator
from typing import Callable, Generator

import torch
from torch.optim import LBFGS, Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -33,25 +34,11 @@ def __init__(self):
self.backend = AMPType.NATIVE
self.scaler = torch.cuda.amp.GradScaler()

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
"""
if isinstance(optimizer, torch.optim.LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.update()

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
@@ -69,16 +56,39 @@ def backward(
"""
closure_loss = self.scaler.scale(closure_loss)

automatic_optimization = model.automatic_optimization

closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)

# unscale gradient to allow analyze within `on_after_backward`
if not should_accumulate and automatic_optimization:
if not should_accumulate and model.automatic_optimization:
self.scaler.unscale_(optimizer)

return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
) -> bool:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
"""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
lambda_closure()

if not pl_module.automatic_optimization:
self.scaler.unscale_(optimizer)

pl_module.trainer.call_hook("on_after_backward")

return False

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.step(optimizer)
self.scaler.update()

@contextmanager
def train_step_context(self) -> Generator[autocast, None, None]:
"""Enable autocast context"""
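The reordering above follows the standard native-AMP recipe: the closure runs the scaled backward, gradients are unscaled so on_after_backward sees true values, and the GradScaler rather than the raw optimizer performs the step before updating its scale. A plain-PyTorch sketch of the equivalent manual-optimization step (illustrative helper, not library code):

import torch

scaler = torch.cuda.amp.GradScaler()


def amp_manual_step(model, optimizer, batch):
    with torch.cuda.amp.autocast():
        loss = model(batch).sum()
    scaler.scale(loss).backward()  # what the plugin's backward/closure performs
    scaler.unscale_(optimizer)     # gradients become inspectable/clippable at true scale
    scaler.step(optimizer)         # skips the step if inf/NaN gradients were found
    scaler.update()                # adjusts the loss scale for the next iteration
    optimizer.zero_grad()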
20 changes: 15 additions & 5 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Generator, Sequence, Tuple, Union
from typing import Any, Callable, Generator, Sequence, Tuple, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
@@ -28,7 +29,7 @@ class PrecisionPlugin(Plugin):
EPSILON = 1e-6
precision = 32

def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Tensor, None, None]:
def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]:
"""The master params of the model. Returns the plain model params here.
Maybe different in other precision plugins.
@@ -37,16 +38,16 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten
for p in group["params"]:
yield p

def connect(self, model: torch.nn.Module, optimizers: Sequence,
lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]:
def connect(self, model: Module, optimizers: Sequence,
lr_schedulers: Sequence) -> Tuple[Module, Sequence, Sequence]:
"""Connects this plugin to the accelerator and the training process"""
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
@@ -75,6 +76,15 @@ def backward(

return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
) -> bool:
"""Hook to do something before each optimizer step."""
return True

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
"""Clips the gradients to a specific value"""
# TODO: separate TPU case from here
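The new base hooks define a simple contract: return True from pre_optimizer_step and the accelerator performs the step through run_optimizer_step; return False and the plugin is expected to handle the step itself, as the native AMP plugin above does in post_optimizer_step. A sketch of a custom plugin using that contract; CustomPrecisionSketch is illustrative and not part of this PR.

from typing import Callable

from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class CustomPrecisionSketch(PrecisionPlugin):
    def pre_optimizer_step(
        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
    ) -> bool:
        closure()     # run forward/backward ourselves
        return False  # tell the accelerator not to call run_optimizer_step

    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
        optimizer.step()  # perform the actual step here instead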
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -268,7 +268,7 @@ def barrier(self, *args, **kwargs):
def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -239,7 +239,7 @@ def model_to_device(self):
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -116,7 +116,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = hvd.broadcast_object(obj, src)
return obj

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
optimizer.synchronize()

def model_to_device(self):
12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,11 +13,11 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.base_plugin import Plugin
@@ -69,19 +69,19 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
"""Reduce the early stopping decision across all possibly spawned processes"""
return should_stop

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run after precision plugin executes backward"""

@property
def model(self) -> torch.nn.Module:
def model(self) -> Module:
"""Returns the potentially wrapped LightningModule"""
return self._model

@model.setter
def model(self, new_model: torch.nn.Module) -> None:
def model(self, new_model: Module) -> None:
self._model = new_model

@property
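The pre_backward/post_backward hooks now also receive should_accumulate, so a training type plugin can tell accumulation micro-batches apart from the backward that precedes an actual optimizer step. A plain-PyTorch sketch of the distinction that flag roughly encodes (illustrative, not library code):

import torch


def run_window(model, optimizer, batches, accumulate_grad_batches=2):
    for i, batch in enumerate(batches):
        loss = model(batch).sum()
        # True for every backward except the last one of the accumulation window;
        # this is roughly the value forwarded to pre_backward/post_backward above.
        should_accumulate = (i + 1) % accumulate_grad_batches != 0
        loss.backward()
        if not should_accumulate:
            optimizer.step()
            optimizer.zero_grad()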
6 changes: 3 additions & 3 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -538,7 +538,7 @@ def training_step(self, batch, batch_idx):
if self.should_update:

self.manual_backward(loss, opt)
opt.step()
opt.step(make_optimizer_step=self.should_have_updated)

return loss.detach() if self.detach else loss

@@ -557,7 +557,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert torch.sum(self.layer.weight.grad) != 0
self.count += 1

def on_train_end(self):
def on_train_epoch_end(self, *_, **__):
assert self.called["training_step"] == 20
assert self.called["on_train_batch_start"] == 20
assert self.called["on_train_batch_end"] == 20
@@ -828,7 +828,7 @@ def optimizer_closure():
retain_graph = num_backward != backward_idx # noqa E225
self.manual_backward(loss_1, opt, retain_graph=retain_graph)

opt.step(closure=optimizer_closure)
opt.step(closure=optimizer_closure, make_optimizer_step=True)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
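For reference, a condensed manual-optimization module in the style of the updated tests above. ToyManualModel, the layer sizes, and the alternating make_optimizer_step condition are illustrative; the API shown (manual_backward plus LightningOptimizer.step(make_optimizer_step=...)) is the one exercised by these tests in this era of the library.

import torch

from pytorch_lightning import LightningModule


class ToyManualModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        return False  # opt into manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)
        # Only every second call actually applies the step; otherwise the wrapped
        # optimizer skips it, which is what make_optimizer_step controls.
        opt.step(make_optimizer_step=(batch_idx % 2 == 0))
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)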
