Refactor callbacks #776

Merged · 13 commits · Feb 16, 2020
107 changes: 55 additions & 52 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -12,60 +12,62 @@

 import numpy as np
 
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-
 
 class Callback(object):
-    r"""Abstract base class used to build new callbacks.
-    """
+    """Abstract base class used to build new callbacks."""
Member (review comment):
as abstract, inherit from abstract ABC... class Callback(ABC): ?
but if it will be ABC, then you have to implement all methods all the time... :/

Contributor (author):
I am not familiar with all this ABC stuff, so I don't really know :)
But anyway I think it would be better to do it in another PR (if it really should be done)
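
For reference, a minimal sketch of the ABC variant discussed above (illustrative only, not part of this PR). Inheriting from ABC on its own changes nothing unless hooks are marked with @abstractmethod, and marking them all would force every subclass to implement every hook, which is exactly the concern raised here:

from abc import ABC


class Callback(ABC):
    """Abstract base class used to build new callbacks.

    Note: the hooks are intentionally NOT decorated with @abstractmethod,
    otherwise every subclass would have to implement all of them.
    """

    def __init__(self):
        self._trainer = None

    def set_trainer(self, trainer):
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""

    def on_epoch_end(self):
        """Called when the epoch ends."""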


     def __init__(self):
-        self.validation_data = None
-        self.model = None
-
-    def set_params(self, params):
-        self.params = params
+        self._trainer = None
 
-    def set_model(self, model):
-        if isinstance(model, LightningDistributedDataParallel):
-            model = model.module
-        self.model = model
+    def set_trainer(self, trainer):
+        """Make a link to the trainer, so different things like `trainer.current_epoch`,
+        `trainer.batch_idx`, `trainer.global_step` can be used."""
+        self._trainer = trainer
 
-    def on_epoch_begin(self, epoch, logs=None):
-        """
-        called when the epoch begins
-
-        Args:
-            epoch (int): current epoch
-            logs (dict): key-value pairs of quantities to monitor
-
-        Example:
-
-            on_epoch_begin(epoch=2, logs={'val_loss': 0.2})
-        """
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
         pass
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
         pass
 
-    def on_batch_begin(self, batch, logs=None):
-        """
-        called when the batch starts.
-
-        Args:
-            batch (Tensor): current batch tensor
-            logs (dict): key-value pairs of quantities to monitor
-        """
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
         pass
 
-    def on_batch_end(self, batch, logs=None):
+    def on_batch_end(self):
+        """Called when the training batch ends."""
         pass
 
-    def on_train_begin(self, logs=None):
+    def on_train_begin(self):
+        """Called when the train begins."""
         pass
 
-    def on_train_end(self, logs=None):
+    def on_train_end(self):
+        """Called when the train ends."""
         pass
+
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        pass
+
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
+        pass
+
+    def on_test_begin(self):
+        """Called when the test begins."""
+        pass
+
+    def on_test_end(self):
+        """Called when the test ends."""
+        pass
 
 
+_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
+
+
 class EarlyStopping(Callback):
     r"""
     Stop training when a monitored quantity has stopped improving.
@@ -148,13 +150,16 @@ def check_metrics(self, logs):
 
         return True
 
-    def on_train_begin(self, logs=None):
+    def on_train_begin(self):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
         stop_training = False
         if not self.check_metrics(logs):
             return stop_training
@@ -166,13 +171,13 @@ def on_epoch_end(self, epoch, logs=None):
             else:
                 self.wait += 1
                 if self.wait >= self.patience:
-                    self.stopped_epoch = epoch
+                    self.stopped_epoch = self._trainer.current_epoch
                     stop_training = True
                     self.on_train_end()
 
         return stop_training
 
-    def on_train_end(self, logs=None):
+    def on_train_end(self):
         if self.stopped_epoch > 0 and self.verbose > 0:
             warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
@@ -306,8 +311,11 @@ def check_monitor_top_k(self, current):
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
-    def on_epoch_end(self, epoch, logs=None):
-        logs = logs or {}
+    def on_validation_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
+        epoch = self._trainer.current_epoch
         self.epochs_since_last_check += 1
 
         if self.save_top_k == 0:
@@ -389,6 +397,8 @@ class GradientAccumulationScheduler(Callback):
     """
 
     def __init__(self, scheduling: dict):
+        super().__init__()
+
         if scheduling == {}:  # empty dict error
             raise TypeError("Empty dict cannot be interpreted correct")
 
@@ -408,21 +418,14 @@ def __init__(self, scheduling: dict):
         self.scheduling = scheduling
         self.epochs = sorted(scheduling.keys())
 
-    def on_epoch_begin(self, epoch, trainer):
+    def on_epoch_begin(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        trainer = self._trainer
         # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, `epoch += 1` should be removed.
-        epoch += 1
+        # In v0.8.0, ` + 1` should be removed.
+        epoch = trainer.current_epoch + 1
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                 break
-
-
-# if __name__ == '__main__':
-#     c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
-#     losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
-#     for i, loss in enumerate(losses):
-#         should_stop = c.on_epoch_end(i, logs={'val_loss': loss})
-#         log.info(loss)
-#         if should_stop:
-#             break
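
Taken together, the new contract in this file is: the trainer calls set_trainer(self) on each callback, and the hooks take no arguments, reading whatever they need from self._trainer. A minimal sketch of a user callback under the refactored API (the class name and the printed keys are made up for illustration):

from pytorch_lightning.callbacks.pt_callbacks import Callback


class MetricsPrinter(Callback):
    """Toy callback: prints the logged metrics at the end of every epoch."""

    def on_epoch_end(self):
        assert self._trainer is not None, ".set_trainer() was not called"
        epoch = self._trainer.current_epoch
        metrics = self._trainer.callback_metrics  # dict filled by the training loop
        print(f"epoch {epoch}: {metrics}")


# The trainer is expected to wire the link itself, roughly:
# callback = MetricsPrinter()
# callback.set_trainer(trainer)   # done in the Trainer's configure_* methods below
# callback.on_epoch_end()         # called from the training loop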
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/callback_config.py
@@ -48,6 +48,9 @@ def configure_checkpoint_callback(self):
             # if checkpoint callback used, then override the weights path
             self.weights_save_path = self.checkpoint_callback.filepath
 
+            # link to the trainer
+            self.checkpoint_callback.set_trainer(self)
+
         # if weights_save_path is still none here, set to current working dir
         if self.weights_save_path is None:
             self.weights_save_path = self.default_save_path
@@ -77,3 +80,6 @@ def configure_early_stopping(self, early_stop_callback):
         else:
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
+
+        if self.early_stop_callback is not None:
+            self.early_stop_callback.set_trainer(self)
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -330,8 +330,7 @@ def run_evaluation(self, test=False):

         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
-            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
-                                                  logs=self.callback_metrics)
+            self.checkpoint_callback.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
         # make dataloader_idx arg in validation_step optional
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -328,7 +328,7 @@ def train(self):
             self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin(epoch, self)
+            self.accumulation_scheduler.on_epoch_begin()
 
             # -----------------
             # RUN TNG EPOCH
@@ -352,8 +352,7 @@ def train(self):
             met_min_epochs = epoch >= self.min_epochs - 1
             if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
                     (met_min_epochs or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
-                                                                    logs=self.callback_metrics)
+                should_stop = self.early_stop_callback.on_epoch_end()
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/training_tricks.py
@@ -39,3 +39,5 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
+
+        self.accumulation_scheduler.set_trainer(self)
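
As a sanity check on the scheduler logic changed above: the scheduling dict maps a (1-indexed, until v0.8.0) epoch to an accumulation factor, and on_epoch_begin picks the entry with the largest key not exceeding the current epoch. A standalone sketch of that lookup, as a plain function for illustration only (the fallback of 1 is an assumption; the real callback simply leaves trainer.accumulate_grad_batches untouched before the first threshold):

def resolve_accumulation(scheduling: dict, current_epoch: int) -> int:
    """Return the accumulation factor for `current_epoch` (0-indexed, as in the trainer).

    Mirrors GradientAccumulationScheduler.on_epoch_begin: the keys are sorted
    and scanned in reverse, so the last threshold reached wins.
    """
    epochs = sorted(scheduling.keys())
    epoch = current_epoch + 1  # the +1 keeps the old 1-indexed behaviour until v0.8.0
    for e in reversed(epochs):
        if epoch >= e:
            return scheduling[e]
    return 1  # stand-in: before the first threshold, no accumulation


# e.g. with scheduling = {1: 1, 5: 2, 10: 4}:
# epochs 0-3 -> 1, epochs 4-8 -> 2, epoch 9 onwards -> 4
assert resolve_accumulation({1: 1, 5: 2, 10: 4}, 0) == 1
assert resolve_accumulation({1: 1, 5: 2, 10: 4}, 4) == 2
assert resolve_accumulation({1: 1, 5: 2, 10: 4}, 9) == 4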
72 changes: 54 additions & 18 deletions tests/test_trainer.py
@@ -229,10 +229,16 @@ def mock_save_function(filepath):

     # -----------------
     # CASE K=-1 (all)
-    w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))

@@ -247,10 +253,16 @@ def mock_save_function(filepath):

     # -----------------
     # CASE K=0 (none)
-    w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = os.listdir(save_dir)

@@ -261,10 +273,16 @@ def mock_save_function(filepath):

     # -----------------
     # CASE K=1 (2.5, epoch 4)
-    w = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))

@@ -278,11 +296,17 @@ def mock_save_function(filepath):
     # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
     # make sure other files don't get deleted
 
-    w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
     open(f'{save_dir}/other_file.ckpt', 'a').close()
-    w.save_function = mock_save_function
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = i
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))

@@ -298,10 +322,16 @@ def mock_save_function(filepath):
     # CASE K=4 (save all 4 models)
     # multiple checkpoints within same epoch
 
-    w = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for loss in losses:
-        w.on_epoch_end(0, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = 0
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))

@@ -314,10 +344,16 @@ def mock_save_function(filepath):
     # CASE K=3 (save the 2nd, 3rd, 4th model)
     # multiple checkpoints within same epoch
 
-    w = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
-    w.save_function = mock_save_function
+    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
+    checkpoint_callback.save_function = mock_save_function
+    trainer = Trainer()
+    checkpoint_callback.set_trainer(trainer)
+
+    # emulate callback's calls during the training
    for loss in losses:
-        w.on_epoch_end(0, logs={'val_loss': loss})
+        checkpoint_callback._trainer.current_epoch = 0
+        checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
+        checkpoint_callback.on_validation_end()
 
     file_lists = set(os.listdir(save_dir))

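
The six cases above repeat the same three-line emulation of a finished validation loop. If that pattern keeps growing, a small test-local helper could fold it into one call; a sketch, not part of the PR, and _emulate_validation_end is a made-up name:

def _emulate_validation_end(checkpoint_callback, epoch, val_loss):
    """Pretend one validation loop finished: set the trainer state the
    callback reads, then fire the hook (mirrors the pattern used above)."""
    checkpoint_callback._trainer.current_epoch = epoch
    checkpoint_callback._trainer.callback_metrics = {'val_loss': val_loss}
    checkpoint_callback.on_validation_end()


# usage in the K=-1 case, for example:
# for i, loss in enumerate(losses):
#     _emulate_validation_end(checkpoint_callback, i, loss)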