Accelerator Refactor: Precision Plugins (#5718)

* add basic accelerator class.
Co-Authored with @awaelchi

* add basic training type plugin.
Co-Authored with @awaelchi

* pep8

Co-authored-by: @awaelchi

* update copyright

Co-authored-by: Adrian Wälchli <[email protected]>

* add apex_amp

Co-authored-by: Adrian Wälchli <[email protected]>

* add mixed base class

Co-authored-by: Adrian Wälchli <[email protected]>

* add native amp

Co-authored-by: Adrian Wälchli <[email protected]>

* add native amp sharded

Co-authored-by: Adrian Wälchli <[email protected]>

* add tpu bfloat

Co-authored-by: Adrian Wälchli <[email protected]>

* add inits

Co-authored-by: Adrian Wälchli <[email protected]>

* Update precision_plugin.py

Co-authored-by: Adrian Wälchli <[email protected]>
justusschock and awaelchli authored Jan 31, 2021
1 parent 3bacac7 commit 069ae27
Showing 11 changed files with 385 additions and 54 deletions.
45 changes: 24 additions & 21 deletions pytorch_lightning/accelerators/accelerator.py
@@ -17,9 +17,16 @@
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins import TrainingTypePlugin
from pytorch_lightning.plugins.precision import (
ApexMixedPrecisionPlugin,
MixedPrecisionPlugin,
NativeMixedPrecisionPlugin,
PrecisionPlugin,
)
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.enums import AMPType, LightningEnum


class Accelerator(object):
@@ -39,7 +46,7 @@ class Accelerator(object):

def __init__(
self,
precision_plugin, #: PrecisionPlugin # fixme
precision_plugin: PrecisionPlugin,
training_type_plugin: TrainingTypePlugin,
) -> None:
"""
@@ -230,9 +237,8 @@ def backward(
)

# TODO: this is a hack, find a better solution for this (hook?)
# fixme: uncomment when this class is added
# if isinstance(self.training_type_plugin, HorovodPlugin):
# optimizer.synchronize()
if isinstance(self.training_type_plugin, HorovodPlugin):
optimizer.synchronize()

return output

@@ -256,11 +262,9 @@ def optimizer_step(
"""
model_ref = self.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
# fixme: uncomment when this class is added
# is_native_amp = (
# isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
# )
is_native_amp = False
native_amp = (
isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
)

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)
@@ -273,7 +277,7 @@
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=is_native_amp,
using_native_amp=native_amp,
using_lbfgs=is_lbfgs,
)

@@ -326,7 +330,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
"""
plugin.connect(model)

def connect_precision_plugin(self, plugin): #: PrecisionPlugin # fixme
def connect_precision_plugin(self, plugin: PrecisionPlugin):
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
@@ -339,13 +343,12 @@ def to_device(self, batch: Any) -> Any:

@property
def amp_backend(self) -> Optional[LightningEnum]:
# fixme: uncomment when this class is added
# if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
# return AMPType.APEX
# elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
# return AMPType.NATIVE
# return None
pass
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
return AMPType.NATIVE
else:
return None

@property
def precision(self) -> int:
@@ -372,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
return optimizer.state_dict()

def on_save(self, checkpoint):
return checkpoint
return checkpoint
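For orientation, the accelerator now resolves its AMP backend purely from the concrete type of the attached precision plugin, replacing the commented-out placeholder. The standalone sketch below mirrors that isinstance-based resolution with stand-in classes; none of these stubs are Lightning APIs, they only reproduce the shape of the hierarchy this commit introduces:

```python
from enum import Enum
from typing import Optional


class AMPType(Enum):
    """Stand-in for pytorch_lightning.utilities.enums.AMPType."""
    APEX = "apex"
    NATIVE = "native"


class PrecisionPlugin:  # stand-in base class (plain 32-bit precision)
    pass


class MixedPrecisionPlugin(PrecisionPlugin):
    pass


class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    backend = AMPType.NATIVE


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
    backend = AMPType.APEX


def amp_backend(precision_plugin: PrecisionPlugin) -> Optional[AMPType]:
    # Same isinstance-based dispatch as Accelerator.amp_backend above.
    if isinstance(precision_plugin, ApexMixedPrecisionPlugin):
        return AMPType.APEX
    if isinstance(precision_plugin, NativeMixedPrecisionPlugin):
        return AMPType.NATIVE
    return None


assert amp_backend(NativeMixedPrecisionPlugin()) is AMPType.NATIVE
assert amp_backend(PrecisionPlugin()) is None
```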
25 changes: 13 additions & 12 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,46 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Generator, Optional, overload, Sequence, Tuple

import torch


class Plugin(object):
class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

def connect(self, model: torch.nn.Module, *args, **kwargs):
@abstractmethod
def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""
pass

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something before each optimizer step."""
pass

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""
pass

def pre_training(self):
def pre_training(self) -> None:
"""Hook to do something before the training starts."""
pass

def post_training(self):
def post_training(self) -> None:
"""Hook to do something after the training finishes."""
pass

@contextlib.contextmanager
def train_step_context(self):
def train_step_context(self) -> Generator:
"""A contextmanager for the trainstep"""
yield

@contextlib.contextmanager
def val_step_context(self):
def val_step_context(self) -> Generator:
"""A contextmanager for the validation step"""
yield

@contextlib.contextmanager
def test_step_context(self):
def test_step_context(self) -> Generator:
"""A contextmanager for the teststep"""
yield
yield
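Because `Plugin` is now an `ABC` with `connect` as its only abstract method, a subclass must provide `connect` and may selectively override the remaining hooks. A minimal sketch, assuming the `pytorch_lightning.plugins.base_plugin` module path exactly as introduced in this diff (the `LoggingPlugin` class itself is hypothetical):

```python
from typing import Optional, Sequence, Tuple

import torch

from pytorch_lightning.plugins.base_plugin import Plugin


class LoggingPlugin(Plugin):
    """Hypothetical plugin that only implements the now-abstract connect hook."""

    def connect(
        self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence
    ) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
        print(f"connecting {type(model).__name__}")
        # The signature allows returning None or a (model, optimizers, lr_schedulers) tuple.
        return None

    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
        print(f"about to step optimizer {optimizer_idx}")


plugin = LoggingPlugin()
plugin.connect(torch.nn.Linear(2, 2))
```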
7 changes: 6 additions & 1 deletion pytorch_lightning/plugins/precision/__init__.py
@@ -1 +1,6 @@

from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
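With these re-exports in place, downstream code can pick a precision plugin from user-facing settings by class alone. A hypothetical selector sketch, assuming only the import paths added above (the function and its arguments are illustrative, not part of this commit):

```python
from pytorch_lightning.plugins.precision import (
    ApexMixedPrecisionPlugin,
    NativeMixedPrecisionPlugin,
    PrecisionPlugin,
    TPUHalfPrecisionPlugin,
)


def select_precision_plugin_cls(precision: int, on_tpu: bool = False, amp_backend: str = "native"):
    """Illustrative mapping from (precision, device, backend) settings to a plugin class."""
    if precision == 16 and on_tpu:
        return TPUHalfPrecisionPlugin
    if precision == 16 and amp_backend == "apex":
        return ApexMixedPrecisionPlugin
    if precision == 16:
        return NativeMixedPrecisionPlugin
    return PrecisionPlugin  # plain 32-bit precision


assert select_precision_plugin_cls(16) is NativeMixedPrecisionPlugin
assert select_precision_plugin_cls(32) is PrecisionPlugin
```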
146 changes: 146 additions & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn

if _APEX_AVAILABLE:
from apex import amp


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
"""performs the actual backpropagation
Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step lateron
opt_idx: the optimizer's index
should_accumulate: whether to accumulate gradients or not
"""
closure_loss = amp.scale_loss(closure_loss, optimizer)

# enter apex context
context = closure_loss
closure_loss = closure_loss.__enter__()

# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, LightningModule):
model.backward(closure_loss, optimizer, opt_idx)
else:
closure_loss.backward(*args, **kwargs)

# exit amp context
a, b, c = None, None, None
error = context.__exit__(a, b, c)
if error:
rank_zero_warn(a, b, c)
raise Exception("apex unscale error")

# once backward has been applied, release graph
closure_loss = closure_loss.detach()
return closure_loss

def configure_apex(
self,
amp: object,
model: LightningModule,
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)
Return:
Apex wrapped model and optimizers
Examples:
.. code-block:: python
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)
return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
scheduler = scheduler["scheduler"]

for optimizer in optimizers:
state = None
idx = 0

# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
state = None

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)
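The `reinit_scheduler_properties` helper re-runs the base scheduler `__init__` so that properties bound at scheduler construction are re-established against the apex-initialized optimizer, while `state_dict`/`load_state_dict` preserve the schedule position. The sketch below replays that MRO walk with a plain torch scheduler and no apex involved, so it illustrates the mechanism rather than the plugin itself (it assumes the torch 1.x-era hierarchy, where `_LRScheduler` is the common base class):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, _LRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# Locate the base scheduler class in the MRO, as reinit_scheduler_properties does.
base_idx = None
for i, klass in enumerate(type(scheduler).__mro__):
    if klass in (_LRScheduler, ReduceLROnPlateau):
        base_idx = i
        break

if base_idx is not None:
    # Snapshot the schedule position, re-run the base __init__ against the optimizer
    # (with apex this would be the amp-initialized optimizer), then restore the state.
    state = scheduler.state_dict()
    type(scheduler).__mro__[base_idx].__init__(scheduler, optimizer)
    scheduler.load_state_dict(state)
```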
23 changes: 23 additions & 0 deletions pytorch_lightning/plugins/precision/mixed.py
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import AMPType


class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON = 1e-5
backend: AMPType
precision = "mixed"
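Concrete mixed-precision plugins are expected to fill in the `backend` attribute declared here, as the Apex plugin above does with `AMPType.APEX`. A minimal hypothetical subclass, assuming the import paths added in this commit:

```python
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import AMPType


class MyMixedPrecisionPlugin(MixedPrecisionPlugin):
    """Hypothetical subclass; the real native-AMP implementation is NativeMixedPrecisionPlugin."""

    def __init__(self) -> None:
        self.backend = AMPType.NATIVE
```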