[BugFix] Resolve manual optimization #5852

Merged
39 changes: 15 additions & 24 deletions pytorch_lightning/accelerators/accelerator.py
@@ -15,6 +15,7 @@

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
@@ -228,8 +229,8 @@ def predict(self, args):
return self.training_type_plugin.predict(*args)

def process_dataloader(
self, dataloader: Union[Iterable, torch.utils.data.DataLoader]
) -> Union[Iterable, torch.utils.data.DataLoader]:
self, dataloader: Union[Iterable, DataLoader]
) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary

Args:
@@ -240,7 +241,7 @@ def process_dataloader(
def backward(
self,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
@@ -254,17 +255,17 @@ def backward(
opt_idx: the index of the optimizer
should_accumulate: whether to accumulate gradients
"""
self.training_type_plugin.pre_backward(closure_loss, optimizer, opt_idx)
self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)

output = self.precision_plugin.backward(
self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

self.training_type_plugin.post_backward(closure_loss, optimizer, opt_idx)
self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, opt_idx)

return output

def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
"""performs the actual optimizer step.

Args:
@@ -273,33 +274,23 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_
lambda_closure: closure calculating the loss value

"""

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
# apex does not support passing a closure to the optimizer, call it by itself
lambda_closure()
lambda_closure = None

optimizer.step(closure=lambda_closure, **kwargs)

make_optimizer_step = self.precision_plugin.pre_optimizer_step(
self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs)
if make_optimizer_step:
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)

if self.rpc_enabled and self.training_type_plugin.is_main_rpc_process:

# Initialize optimizer step on main process
self.training_type_plugin.worker_optimizer_step(model=self.lightning_module, opt_idx=opt_idx, **kwargs)
def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)

def optimizer_zero_grad(
self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int
) -> None:
"""Zeros all model parameter's gradients"""
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None:
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""clips all the optimizer parameters to the given value"""

self.precision_plugin.clip_gradients(optimizer, clip_val)
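Because the rendered diff interleaves the removed and added lines, here is the resulting optimizer_step/run_optimizer_step pair as a condensed sketch, read from the added lines above (hook bodies elided, RPC handling omitted):

class Accelerator:
    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
        # The precision plugin decides whether the accelerator should perform the step;
        # e.g. native AMP returns False because it steps through its GradScaler instead.
        make_optimizer_step = self.precision_plugin.pre_optimizer_step(
            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs)
        if make_optimizer_step:
            self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)

    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
        # the raw step now lives here so subclasses (e.g. TPU) can override just this call
        optimizer.step(closure=lambda_closure, **kwargs)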
19 changes: 2 additions & 17 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,6 +1,6 @@
from typing import Callable

import torch
from torch.optim import Optimizer

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
@@ -26,20 +26,5 @@ def setup(self, trainer, model):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def optimizer_step(self, optimizer: torch.optim.Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
"""performs the actual optimizer step.

Args:
optimizer: the optimizer performing the step
opt_idx: index of the current optimizer
lambda_closure: closure calculating the loss value

"""

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx)
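With the pre/post hooks now handled once in the base Accelerator.optimizer_step, the TPU accelerator only overrides the actual step call. Condensed from the added lines above (xm here is the torch_xla.core.xla_model module this file already relies on):

class TPUAccelerator(Accelerator):
    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
        # xm.optimizer_step consolidates gradients across TPU replicas before stepping,
        # forwarding the closure and any extra kwargs to optimizer.step
        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})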
13 changes: 4 additions & 9 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from torch.nn import Module
from abc import ABC, abstractmethod
from typing import Any, Generator, Optional, overload, Sequence, Tuple
from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple

import torch

@@ -22,18 +23,12 @@ class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

@abstractmethod
def connect(self, model: torch.nn.Module, *args: Sequence,
**kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
def connect(self, model: Module, *args: Sequence,
**kwargs: Sequence) -> Optional[Tuple[Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something before each optimizer step."""

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

def pre_training(self) -> None:
"""Hook to do something before the training starts."""

48 changes: 29 additions & 19 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator
from typing import Callable, Generator

import torch
from torch.optim import LBFGS, Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -33,25 +34,11 @@ def __init__(self):
self.backend = AMPType.NATIVE
self.scaler = torch.cuda.amp.GradScaler()

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
"""
if isinstance(optimizer, torch.optim.LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.update()

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
@@ -69,16 +56,39 @@ def backward(
"""
closure_loss = self.scaler.scale(closure_loss)

automatic_optimization = model.automatic_optimization

closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)

# unscale gradient to allow analyze within `on_after_backward`
if not should_accumulate and automatic_optimization:
if not should_accumulate and model.automatic_optimization:
self.scaler.unscale_(optimizer)

return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
) -> bool:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
"""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
lambda_closure()

if not pl_module.automatic_optimization:
self.scaler.unscale_(optimizer)

pl_module.trainer.call_hook("on_after_backward")

return False

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.step(optimizer)
self.scaler.update()

@contextmanager
def train_step_context(self) -> Generator[autocast, None, None]:
"""Enable autocast context"""
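The reshuffled hooks follow the standard torch.cuda.amp recipe: backward on the scaled loss, unscale before gradients are inspected, then step and update through the scaler. A minimal standalone sketch of that ordering in plain PyTorch (the model, data, and shapes are placeholders):

import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()  # backward(): backward pass on the scaled loss
    scaler.unscale_(optimizer)     # done in backward() or pre_optimizer_step, so hooks see real gradients
    scaler.step(optimizer)         # post_optimizer_step: step through the scaler
    scaler.update()                # post_optimizer_step: refresh the scale factor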
20 changes: 15 additions & 5 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Generator, Sequence, Tuple, Union
from typing import Any, Callable, Generator, Sequence, Tuple, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
@@ -28,7 +29,7 @@ class PrecisionPlugin(Plugin):
EPSILON = 1e-6
precision = 32

def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Tensor, None, None]:
def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]:
"""The master params of the model. Returns the plain model params here.
Maybe different in other precision plugins.

@@ -37,16 +38,16 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten
for p in group["params"]:
yield p

def connect(self, model: torch.nn.Module, optimizers: Sequence,
lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]:
def connect(self, model: Module, optimizers: Sequence,
lr_schedulers: Sequence) -> Tuple[Module, Sequence, Sequence]:
"""Connects this plugin to the accelerator and the training process"""
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
@@ -75,6 +76,15 @@ def backward(

return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, closure: Callable, **kwargs
) -> bool:
"""Hook to do something before each optimizer step."""
return True
Contributor:

What's the meaning of the bool return value?

Contributor Author:

It is used to bypass run_optimizer_step, as AMP is going to run self.scaler.step(optimizer) itself.

It was the only way I found to make accelerator.optimizer_step plugin agnostic.


def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
"""Clips the gradients to a specific value"""
# TODO: separate TPU case from here
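As the review thread above explains, the bool returned by pre_optimizer_step is effectively a "should the accelerator still run optimizer.step?" flag. Condensed from the two plugins touched in this diff (bodies abridged to the parts relevant to the flag):

class PrecisionPlugin(Plugin):
    def pre_optimizer_step(self, pl_module, optimizer, optimizer_idx, closure, **kwargs) -> bool:
        return True   # default: the accelerator goes on to call run_optimizer_step

class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    def pre_optimizer_step(self, pl_module, optimizer, optimizer_idx, closure, **kwargs) -> bool:
        closure()     # forward/backward happen here
        return False  # the GradScaler performs the step in post_optimizer_step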
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -268,7 +268,7 @@ def barrier(self, *args, **kwargs):
def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -239,7 +239,7 @@ def model_to_device(self):
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -116,7 +116,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = hvd.broadcast_object(obj, src)
return obj

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
optimizer.synchronize()

def model_to_device(self):
12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,11 +13,11 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.base_plugin import Plugin
@@ -69,19 +69,19 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
"""Reduce the early stopping decision across all possibly spawned processes"""
return should_stop

def pre_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""

def post_backward(self, closure_loss: torch.Tensor, optimizer: Optimizer, opt_idx: int):
def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run after precision plugin executes backward"""

@property
def model(self) -> torch.nn.Module:
def model(self) -> Module:
"""Returns the potentially wrapped LightningModule"""
return self._model

@model.setter
def model(self, new_model: torch.nn.Module) -> None:
def model(self, new_model: Module) -> None:
self._model = new_model

@property
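The remaining training-type changes all add should_accumulate to the pre_backward/post_backward signatures, matching the call sites in Accelerator.backward shown in the first file. Abridged from that hunk, the flag is threaded straight through:

self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)
output = self.precision_plugin.backward(
    self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)
self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, opt_idx)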
6 changes: 3 additions & 3 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -538,7 +538,7 @@ def training_step(self, batch, batch_idx):
if self.should_update:

self.manual_backward(loss, opt)
opt.step()
opt.step(make_optimizer_step=self.should_have_updated)

return loss.detach() if self.detach else loss

@@ -557,7 +557,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert torch.sum(self.layer.weight.grad) != 0
self.count += 1

def on_train_end(self):
def on_train_epoch_end(self, *_, **__):
assert self.called["training_step"] == 20
assert self.called["on_train_batch_start"] == 20
assert self.called["on_train_batch_end"] == 20
@@ -828,7 +828,7 @@ def optimizer_closure():
retain_graph = num_backward != backward_idx # noqa E225
self.manual_backward(loss_1, opt, retain_graph=retain_graph)

opt.step(closure=optimizer_closure)
opt.step(closure=optimizer_closure, make_optimizer_step=True)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
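The test updates exercise the LightningOptimizer path in manual optimization, where opt.step accepts a make_optimizer_step flag to force or skip the actual parameter update. A minimal usage sketch along the lines of the tests above (the model, loss, and the flag's schedule are placeholders):

import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        return False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)
        # skip the actual update on odd batches; gradients accumulate until the next real step
        opt.step(make_optimizer_step=(batch_idx % 2 == 0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)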