Clean up optimizer code (#3587)
* Update optimizer code

* Update CHANGELOG

* Fix tuple of one list case

* Update docs

* Fix pep issue

* Minor typo [skip-ci]

* Use minimal match

Co-authored-by: Adrian Wälchli <[email protected]>

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
3 people authored Oct 21, 2020
1 parent 0ec4107 commit 2549ca4
Showing 6 changed files with 223 additions and 172 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

### Deprecated

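To illustrate the changelog entry above: an unsupported `configure_optimizers` return value is now reported as a `MisconfigurationException` that lists the accepted formats, instead of a generic `ValueError`. Below is a minimal sketch, assuming a 1.0-era `pytorch_lightning` in which `Trainer` mixes in the `TrainerOptimizersMixin.init_optimizers` shown further down in this diff; the `BadOptimizerModule` is hypothetical and not part of this commit.

```python
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class BadOptimizerModule(pl.LightningModule):
    """Hypothetical module whose configure_optimizers return is unsupported."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        # A plain string is none of the supported return formats
        return "adam"


if __name__ == "__main__":
    trainer = pl.Trainer()
    try:
        # init_optimizers comes from TrainerOptimizersMixin (see the diff below)
        trainer.init_optimizers(BadOptimizerModule())
    except MisconfigurationException as err:
        print(f"Improved error message:\n{err}")
```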
30 changes: 25 additions & 5 deletions docs/source/optimizers.rst
@@ -101,26 +101,46 @@ Every optimizer you use can be paired with any `LearningRateScheduler <https://p
# Adam + LR scheduler
def configure_optimizers(self):
optimizer = Adam(...)
scheduler = ReduceLROnPlateau(optimizer, ...)
scheduler = LambdaLR(optimizer, ...)
return [optimizer], [scheduler]
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
return {
'optimizer': Adam(...),
'scheduler': ReduceLROnPlateau(optimizer, ...),
'monitor': 'metric_to_track'
}
# Two optimizers each with a scheduler
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler1 = LambdaLR(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return [optimizer1, optimizer2], [scheduler1, scheduler2]
# Alternatively
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return (
{'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
{'optimizer': optimizer2, 'lr_scheduler': scheduler2},
)
# Same as above with additional params passed to the first scheduler
def configure_optimizers(self):
optimizers = [Adam(...), SGD(...)]
schedulers = [
{
'scheduler': ReduceLROnPlateau(optimizers[0], ...),
'monitor': 'val_recall', # Default: val_loss
'monitor': 'metric_to_track',
'interval': 'epoch',
'frequency': 1
'frequency': 1,
'strict': True,
},
LambdaLR(optimizers[1], ...)
]
@@ -144,7 +164,7 @@ To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.c
# Two optimizers, one scheduler for adam only
def configure_optimizers(self):
return [Adam(...), SGD(...)], [ReduceLROnPlateau()]
return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'}
Lightning will call each optimizer sequentially:
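To make the updated documentation concrete, here is a self-contained sketch of the dictionary format with a `ReduceLROnPlateau` scheduler and the now-required `monitor` key. The toy model, the random data, the `val_loss` name, and the `Trainer` flags (1.0-era API) are assumptions for illustration, not code from this diff.

```python
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class PlateauModule(pl.LightningModule):
    """Illustrative module pairing Adam with ReduceLROnPlateau and a monitor."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = F.mse_loss(self.layer(x), y)
        # Log the value that the scheduler entry monitors
        self.log('val_loss', val_loss)
        return val_loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
        # Dict format from the docs above; 'monitor' is mandatory with ReduceLROnPlateau
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}


def _toy_loader():
    x, y = torch.randn(64, 8), torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=2, logger=False, checkpoint_callback=False)
    trainer.fit(PlateauModule(), train_dataloader=_toy_loader(), val_dataloaders=_toy_loader())
```

Returning the same scheduler bare, e.g. `return [optimizer], [scheduler]`, would now fail the monitor check added in `configure_schedulers` below.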
54 changes: 16 additions & 38 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -16,7 +16,6 @@


class OptimizerConnector:

def __init__(self, trainer):
self.trainer = trainer

@@ -41,21 +40,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
# Take step if call to update_learning_rates matches the interval key and
# the current step modulo the schedulers frequency is zero
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
# If instance of ReduceLROnPlateau, we need to pass validation loss
# If instance of ReduceLROnPlateau, we need a monitor
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
try:
monitor_key = lr_scheduler['monitor']
except KeyError as e:
m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \
"monitor=. For example:" \
"return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}"
raise MisconfigurationException(m)

if monitor_metrics is not None:
monitor_val = monitor_metrics.get(monitor_key)
else:
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)

monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key)
if monitor_metrics is not None
else self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = self.trainer.logger_connector.callback_metrics.keys()
@@ -71,30 +64,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
RuntimeWarning,
)
continue
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
if lr_scheduler['reduce_on_plateau']:
lr_scheduler['scheduler'].step(monitor_val)
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key,
)
else:
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
lr_scheduler['scheduler'].step()
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr, new_lr
)
if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
)
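Below is a framework-free sketch of the control flow that the refactored `update_learning_rates` follows for one scheduler entry: look up the monitored value, raise or warn depending on the `strict` flag, and otherwise step the scheduler. The dict layout mirrors the `configure_schedulers` defaults in the next file; the name `step_scheduler_entry` is invented here and the error/warning wording is paraphrased, not copied from the diff.

```python
import warnings

import torch
from torch import nn, optim


def step_scheduler_entry(lr_scheduler: dict, monitor_metrics: dict = None, callback_metrics: dict = None):
    """Approximate the simplified monitor lookup and stepping logic from the diff."""
    monitor_key = monitor_val = None
    if lr_scheduler['reduce_on_plateau']:
        monitor_key = lr_scheduler['monitor']
        metrics = monitor_metrics if monitor_metrics is not None else (callback_metrics or {})
        monitor_val = metrics.get(monitor_key)
        if monitor_val is None:
            if lr_scheduler.get('strict', True):
                raise RuntimeError(
                    f'ReduceLROnPlateau conditioned on metric {monitor_key!r} which is not available.'
                )
            warnings.warn(f'Metric {monitor_key!r} not available, skipping update.', RuntimeWarning)
            return

    old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    if lr_scheduler['reduce_on_plateau']:
        lr_scheduler['scheduler'].step(monitor_val)
    else:
        lr_scheduler['scheduler'].step()
    new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    print(f'lr: {old_lr:.4g} -> {new_lr:.4g} (monitor={monitor_key})')


if __name__ == "__main__":
    opt = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
    entry = {
        'scheduler': optim.lr_scheduler.ReduceLROnPlateau(opt, patience=0),
        'reduce_on_plateau': True,
        'monitor': 'val_loss',
        'strict': True,
    }
    # Two epochs with a non-improving metric, so the second call reduces the LR
    step_scheduler_entry(entry, monitor_metrics={'val_loss': torch.tensor(1.0)})
    step_scheduler_entry(entry, monitor_metrics={'val_loss': torch.tensor(1.0)})
```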
133 changes: 63 additions & 70 deletions pytorch_lightning/trainer/optimizers.py
@@ -21,111 +21,107 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TrainerOptimizersMixin(ABC):

def init_optimizers(
self,
model: LightningModule
) -> Tuple[List, List, List]:
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
optim_conf = model.configure_optimizers()

if optim_conf is None:
rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, '
'this fit will run with no optimizer', UserWarning)
rank_zero_warn(
'`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
UserWarning,
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers, []

elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get('monitor', None)
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []

lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take only lr wif exists and ot they are defined - not None
lr_schedulers = [
opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
]
# take only freq wif exists and ot they are defined - not None
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]

# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

optimizers = list(optim_conf)
# unknown configuration
else:
raise ValueError(
raise MisconfigurationException(
'Unknown configuration for model optimizers.'
' Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being'
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
' Output from `model.configure_optimizers()` should either be:\n'
' * `torch.optim.Optimizer`\n'
' * [`torch.optim.Optimizer`]\n'
' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)

return optimizers, lr_schedulers, optimizer_frequencies

def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
default_config = {
'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False
} # most often not ReduceLROnPlateau scheduler

if monitor is not None:
default_config['monitor'] = monitor

'scheduler': None,
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': monitor, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
for scheduler in schedulers:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
if 'scheduler' not in scheduler:
raise ValueError('Lr scheduler should have key `scheduler`',
' with item being a lr scheduler')
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
raise MisconfigurationException(
'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

if monitor is None:
raise MisconfigurationException(
'`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is a invalid input.')
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
@@ -138,10 +134,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if (
mro == optim.lr_scheduler._LRScheduler
or mro == optim.lr_scheduler.ReduceLROnPlateau
):
if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
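To summarize the reworked `init_optimizers`, the sketch below lists the return shapes named in the new error message and shows how `configure_schedulers` normalizes a user-provided scheduler dict over its defaults. It is an approximation written against the diff above; the concrete optimizers, schedulers, and metric name are placeholders.

```python
from torch import nn, optim

model = nn.Linear(4, 2)  # placeholder model
opt1 = optim.Adam(model.parameters(), lr=1e-3)
opt2 = optim.SGD(model.parameters(), lr=1e-2)

# Roughly the return shapes listed in the new MisconfigurationException message:
single_optimizer = opt1                                            # torch.optim.Optimizer
list_of_optimizers = [opt1, opt2]                                  # [torch.optim.Optimizer]
optimizers_and_schedulers = (                                      # ([Optimizer], [lr_scheduler])
    [opt1],
    [optim.lr_scheduler.LambdaLR(opt1, lr_lambda=lambda epoch: 0.95 ** epoch)],
)
single_dict = {                                                    # {"optimizer": ..., "lr_scheduler": ...}
    'optimizer': opt1,
    'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(opt1),
    'monitor': 'val_loss',  # required because of ReduceLROnPlateau
}
multiple_dicts = (                                                 # dicts with an optional "frequency" key
    {'optimizer': opt1, 'frequency': 5},
    {'optimizer': opt2, 'frequency': 1},
)

# Approximate normalization done by configure_schedulers: merge the user dict
# over the defaults shown in the diff, so unspecified keys pick up the defaults.
default_config = {
    'scheduler': None,
    'interval': 'epoch',
    'frequency': 1,
    'reduce_on_plateau': False,
    'monitor': None,
    'strict': True,
}
user_scheduler = {'scheduler': optim.lr_scheduler.StepLR(opt2, step_size=10), 'interval': 'step'}
normalized = {**default_config, **user_scheduler}
print(normalized['interval'], normalized['frequency'], normalized['strict'])  # step 1 True
```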
5 changes: 0 additions & 5 deletions tests/base/model_optimizers.py
@@ -81,11 +81,6 @@ def configure_optimizers__mixed_scheduling(self):
return [optimizer1, optimizer2], \
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]

def configure_optimizers__reduce_lr_on_plateau(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return [optimizer], [lr_scheduler]

def configure_optimizers__param_groups(self):
param_groups = [
{'params': list(self.parameters())[:2], 'lr': self.learning_rate * 0.1},
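The deleted test hook above returned a bare `ReduceLROnPlateau` without a monitor, which the stricter validation in this commit rejects. A hedged sketch of how such a hook could look after the change; the monitor name and the exact shape are illustrative, and this replacement is not claimed to exist in the test suite.

```python
from torch import optim


def configure_optimizers__reduce_lr_on_plateau(self):
    # After this commit, a ReduceLROnPlateau must be wrapped in a dict naming a monitor
    optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
    lr_scheduler = {
        'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'early_stop_on',  # assumed name of a logged metric
    }
    return [optimizer], [lr_scheduler]
```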