Clean up optimizer code #3587

Merged · 8 commits · Oct 21, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

### Deprecated

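For context (not part of the diff): a minimal sketch of a `configure_optimizers` return that the reworked validation rejects. Here a `ReduceLROnPlateau` scheduler is returned without the `monitor` it needs, so optimizer setup now fails immediately with a `MisconfigurationException` instead of surfacing the problem later, during the first scheduler update.

    # Sketch only; the class and metric names are illustrative.
    from torch import nn, optim
    from pytorch_lightning import LightningModule

    class PlateauWithoutMonitor(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def configure_optimizers(self):
            optimizer = optim.Adam(self.parameters(), lr=1e-3)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            # No 'monitor' key anywhere -> MisconfigurationException at setup time.
            return [optimizer], [scheduler]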
30 changes: 25 additions & 5 deletions docs/source/optimizers.rst
@@ -101,26 +101,46 @@ Every optimizer you use can be paired with any `LearningRateScheduler <https://p
# Adam + LR scheduler
def configure_optimizers(self):
optimizer = Adam(...)
scheduler = ReduceLROnPlateau(optimizer, ...)
scheduler = LambdaLR(optimizer, ...)
return [optimizer], [scheduler]

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
return {
'optimizer': Adam(...),
'scheduler': ReduceLROnPlateau(optimizer, ...),
'monitor': 'metric_to_track'
}

# Two optimizers each with a scheduler
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler1 = LambdaLR(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return [optimizer1, optimizer2], [scheduler1, scheduler2]

# Alternatively
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return (
{'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
{'optimizer': optimizer2, 'lr_scheduler': scheduler2},
)

# Same as above with additional params passed to the first scheduler
def configure_optimizers(self):
optimizers = [Adam(...), SGD(...)]
schedulers = [
{
'scheduler': ReduceLROnPlateau(optimizers[0], ...),
'monitor': 'val_recall', # Default: val_loss
'monitor': 'metric_to_track',
'interval': 'epoch',
'frequency': 1
'frequency': 1,
'strict': True,
},
LambdaLR(optimizers[1], ...)
]
@@ -144,7 +164,7 @@ To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.c

# Two optimizers, one scheduler for adam only
def configure_optimizers(self):
return [Adam(...), SGD(...)], [ReduceLROnPlateau()]
return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'}

Lightning will call each optimizer sequentially:

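The optimizers.rst section continues past this hunk. As a hedged sketch of what "Lightning will call each optimizer sequentially" means on the user side (the loss helpers here are placeholders), `training_step` receives an `optimizer_idx` argument once more than one optimizer is configured:

    # Inside the same LightningModule; sketch, not taken from the diff.
    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            return self.generator_loss(batch)      # step for the first optimizer
        if optimizer_idx == 1:
            return self.discriminator_loss(batch)  # step for the second optimizer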
54 changes: 16 additions & 38 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -16,7 +16,6 @@


class OptimizerConnector:

def __init__(self, trainer):
self.trainer = trainer

@@ -41,21 +40,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
# Take step if call to update_learning_rates matches the interval key and
# the current step modulo the schedulers frequency is zero
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
# If instance of ReduceLROnPlateau, we need to pass validation loss
# If instance of ReduceLROnPlateau, we need a monitor
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
try:
monitor_key = lr_scheduler['monitor']
except KeyError as e:
m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \
"monitor=. For example:" \
"return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}"
raise MisconfigurationException(m)

if monitor_metrics is not None:
monitor_val = monitor_metrics.get(monitor_key)
else:
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)

monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key)
if monitor_metrics is not None
else self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = self.trainer.logger_connector.callback_metrics.keys()
@@ -71,30 +64,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
RuntimeWarning,
)
continue
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
if lr_scheduler['reduce_on_plateau']:
lr_scheduler['scheduler'].step(monitor_val)
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key,
)
else:
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
lr_scheduler['scheduler'].step()
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr, new_lr
)
if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
)
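A usage sketch of the `strict` key consumed above (torch is assumed imported; 'val_loss' is illustrative): with `strict=False` a missing monitored value only triggers the RuntimeWarning and the scheduler step is skipped, while the default `strict=True` raises.

    # Inside a LightningModule; sketch, not from the diff.
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                'monitor': 'val_loss',  # looked up in monitor_metrics / callback_metrics
                'strict': False,        # warn and skip the step if 'val_loss' was never logged
            },
        }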
133 changes: 63 additions & 70 deletions pytorch_lightning/trainer/optimizers.py
@@ -21,111 +21,107 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TrainerOptimizersMixin(ABC):

def init_optimizers(
self,
model: LightningModule
) -> Tuple[List, List, List]:
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
optim_conf = model.configure_optimizers()

if optim_conf is None:
rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, '
'this fit will run with no optimizer', UserWarning)
rank_zero_warn(
'`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
UserWarning,
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers, []

elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get('monitor', None)
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []

lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take only lr wif exists and ot they are defined - not None
lr_schedulers = [
opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
]
# take only freq wif exists and ot they are defined - not None
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]

# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

optimizers = list(optim_conf)
# unknown configuration
else:
raise ValueError(
raise MisconfigurationException(
'Unknown configuration for model optimizers.'
' Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being'
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
' Output from `model.configure_optimizers()` should either be:\n'
' * `torch.optim.Optimizer`\n'
' * [`torch.optim.Optimizer`]\n'
' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)

return optimizers, lr_schedulers, optimizer_frequencies
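To make the "multiple dictionaries" branch above concrete, a hedged sketch (the `generator`/`discriminator` submodules are placeholders; torch is assumed imported) of a return value and the triple it normalizes into:

    # Inside a LightningModule; sketch only.
    def configure_optimizers(self):
        gen_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        dis_opt = torch.optim.SGD(self.discriminator.parameters(), lr=1e-3)
        dis_sch = torch.optim.lr_scheduler.LambdaLR(dis_opt, lambda epoch: 0.95 ** epoch)
        return (
            {'optimizer': gen_opt, 'frequency': 1},
            {'optimizer': dis_opt, 'frequency': 5, 'lr_scheduler': dis_sch},
        )

    # init_optimizers(model) then returns roughly:
    #   optimizers            -> [gen_opt, dis_opt]
    #   lr_schedulers         -> [normalized dict wrapping dis_sch]
    #   optimizer_frequencies -> [1, 5]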

def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
default_config = {
'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False
} # most often not ReduceLROnPlateau scheduler

if monitor is not None:
default_config['monitor'] = monitor

'scheduler': None,
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': monitor, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
for scheduler in schedulers:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
if 'scheduler' not in scheduler:
raise ValueError('Lr scheduler should have key `scheduler`',
' with item being a lr scheduler')
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
raise MisconfigurationException(
'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

if monitor is None:
raise MisconfigurationException(
'`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is a invalid input.')
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers
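For reference, a hedged sketch of the entry `configure_schedulers` builds for a bare `_LRScheduler` instance when no `monitor` is passed, i.e. `default_config` with only `scheduler` overridden:

    # `optimizer` is assumed to be an existing torch.optim.Optimizer.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    # configure_schedulers([scheduler]) returns approximately:
    # [{'scheduler': scheduler, 'interval': 'epoch', 'frequency': 1,
    #   'reduce_on_plateau': False, 'monitor': None, 'strict': True}]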

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
@@ -138,10 +134,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if (
mro == optim.lr_scheduler._LRScheduler
or mro == optim.lr_scheduler.ReduceLROnPlateau
):
if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
5 changes: 0 additions & 5 deletions tests/base/model_optimizers.py
@@ -81,11 +81,6 @@ def configure_optimizers__mixed_scheduling(self):
return [optimizer1, optimizer2], \
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]

def configure_optimizers__reduce_lr_on_plateau(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return [optimizer], [lr_scheduler]

def configure_optimizers__param_groups(self):
param_groups = [
{'params': list(self.parameters())[:2], 'lr': self.learning_rate * 0.1},
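The removed helper returned a bare `ReduceLROnPlateau` with no monitor — a shape this PR now rejects at setup, which presumably motivates the deletion. A sketch of a valid replacement (the monitor name is illustrative):

    def configure_optimizers__reduce_lr_on_plateau(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'metric_to_track'}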