
Commit

Address comments
hadim committed Feb 17, 2020
1 parent bec9064 commit 4dd0f91
Showing 6 changed files with 35 additions and 55 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -21,6 +21,7 @@ dependencies:
   - pytest-cov
   - pytest-flake8
   - flake8
+  - autopep8
   - check-manifest
   - twine==1.13.0
   - pillow<7.0.0
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -1,4 +1,4 @@
-from .callback import Callback
+from .base import Callback
 from .early_stopping import EarlyStopping
 from .model_checkpoint import ModelCheckpoint
 from .gradient_accumulation_scheduler import GradientAccumulationScheduler
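The re-export above keeps the public import path stable while the module is (apparently) renamed from callback.py to base.py underneath; a quick sketch of the user-facing import that stays unchanged:

    from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint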
pytorch_lightning/callbacks/base.py
@@ -5,10 +5,13 @@
 Callbacks supported by Lightning
 """
 
+import abc
+
+
 _NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
 
 
-class Callback(object):
+class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""
 
     def __init__(self):
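For orientation, a minimal sketch of the subclassing pattern the new abc.ABC base supports; MyCallback is hypothetical, while the hook and attribute names (set_trainer, self.trainer, callback_metrics) are the ones used by the callbacks touched in this commit:

    from pytorch_lightning.callbacks.base import Callback


    class MyCallback(Callback):
        """Hypothetical callback; the trainer attaches itself via set_trainer()."""

        def on_epoch_end(self):
            # same metrics dict that EarlyStopping and ModelCheckpoint read below
            metrics = self.trainer.callback_metrics
            print(f"epoch ended with metrics: {metrics}")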
29 changes: 10 additions & 19 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from .callback import Callback
+from .base import Callback
 
 
 class EarlyStopping(Callback):
@@ -38,9 +38,8 @@ class EarlyStopping(Callback):
         Trainer(early_stop_callback=early_stopping)
     """
 
-    def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
-        super(EarlyStopping, self).__init__()
+    def __init__(self, monitor='val_loss', min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
+        super().__init__()
 
         self.monitor = monitor
         self.patience = patience
@@ -55,20 +54,13 @@ def __init__(self, monitor='val_loss',
             log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'
 
-        if mode == 'min':
-            self.monitor_op = np.less
-        elif mode == 'max':
-            self.monitor_op = np.greater
-        else:
-            if 'acc' in self.monitor:
-                self.monitor_op = np.greater
-            else:
-                self.monitor_op = np.less
-
-        if self.monitor_op == np.greater:
-            self.min_delta *= 1
-        else:
-            self.min_delta *= -1
+        mode_dict = {
+            'min': np.less,
+            'max': np.greater,
+            'auto': np.greater if 'acc' in self.monitor else np.less
+        }
+        self.monitor_op = mode_dict[mode]
+        self.min_delta *= 1 if self.monitor_op == np.greater else -1
 
         self.on_train_begin()
 
@@ -95,7 +87,6 @@ def on_train_begin(self):
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
     def on_epoch_end(self):
-
         logs = self.trainer.callback_metrics
         stop_training = False
         if not self.check_metrics(logs):
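The mode_dict rewrite keeps the original semantics: 'auto' resolves to np.greater only when the monitored name contains 'acc', otherwise np.less, and min_delta's sign follows the comparison direction. A usage sketch consistent with the constructor and docstring above:

    from pytorch_lightning.callbacks import EarlyStopping

    # 'auto' plus a monitor containing 'acc' resolves monitor_op to np.greater,
    # so min_delta stays positive; with 'val_loss' it flips to np.less and -min_delta
    early_stopping = EarlyStopping(monitor='val_acc', min_delta=0.001, patience=3, mode='auto')

    # wired into training as the class docstring shows:
    # Trainer(early_stop_callback=early_stopping)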
pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -1,6 +1,6 @@
 import warnings
 
-from .callback import Callback
+from .base import Callback
 
 
 class GradientAccumulationScheduler(Callback):
@@ -25,10 +25,10 @@ class GradientAccumulationScheduler(Callback):
     def __init__(self, scheduling: dict):
         super().__init__()
 
-        if scheduling == {}:  # empty dict error
+        if not scheduling:  # empty dict error
             raise TypeError("Empty dict cannot be interpreted correct")
 
-        for key in scheduling.keys():
+        for key in scheduling:
             if not isinstance(key, int) or not isinstance(scheduling[key], int):
                 raise TypeError("All epoches and accumulation factor must be integers")
 
@@ -45,7 +45,6 @@ def __init__(self, scheduling: dict):
         self.epochs = sorted(scheduling.keys())
 
     def on_epoch_begin(self):
-
         trainer = self.trainer
         # indexing epochs from 1 (until v0.6.x)
         # In v0.8.0, ` + 1` should be removed.
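The tightened validation pins down the scheduling contract: a non-empty dict mapping integer epochs to integer accumulation factors, raising TypeError otherwise. A small sketch under exactly those assumptions:

    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # accumulate gradients over 2 batches from epoch 5, then over 4 from epoch 10;
    # non-int keys or values would raise the TypeError added above
    accumulator = GradientAccumulationScheduler(scheduling={5: 2, 10: 4})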
46 changes: 16 additions & 30 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -5,12 +5,11 @@
 
 import numpy as np
 
-from .callback import Callback
+from .base import Callback
 
 
 class ModelCheckpoint(Callback):
     r"""
     Save the model after every epoch.
-
     Args:
@@ -27,14 +26,14 @@ class ModelCheckpoint(Callback):
         save_top_k (int): if `save_top_k == k`,
             the best k models according to
             the quantity monitored will be saved.
-            if `save_top_k == 0`, no models are saved.
-            if `save_top_k == -1`, all models are saved.
+            if ``save_top_k == 0``, no models are saved.
+            if ``save_top_k == -1``, all models are saved.
             Please note that the monitors are checked every `period` epochs.
-            if `save_top_k >= 2` and the callback is called multiple
+            if ``save_top_k >= 2`` and the callback is called multiple
             times inside an epoch, the name of the saved file will be
             appended with a version count starting with `v0`.
         mode (str): one of {auto, min, max}.
-            If `save_top_k != 0`, the decision
+            If ``save_top_k != 0``, the decision
             to overwrite the current save file is made
             based on either the maximization or the
             minimization of the monitored quantity. For `val_acc`,
@@ -60,11 +59,11 @@ class ModelCheckpoint(Callback):
     def __init__(self, filepath, monitor='val_loss', verbose=0,
                  save_top_k=1, save_weights_only=False,
                  mode='auto', period=1, prefix=''):
-        super(ModelCheckpoint, self).__init__()
+        super().__init__()
         if (
-            save_top_k and
-            os.path.isdir(filepath) and
-            len(os.listdir(filepath)) > 0
+            save_top_k
+            and os.path.isdir(filepath)
+            and len(os.listdir(filepath)) > 0
         ):
             warnings.warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
@@ -111,34 +110,26 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
             self.mode = 'min'
 
     def _del_model(self, filepath):
-        dirpath = os.path.dirname(filepath)
-
-        # make paths
-        os.makedirs(dirpath, exist_ok=True)
 
         try:
             shutil.rmtree(filepath)
         except OSError:
             os.remove(filepath)
 
     def _save_model(self, filepath):
-        dirpath = os.path.dirname(filepath)
-
-        # make paths
-        os.makedirs(dirpath, exist_ok=True)
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the model
         assert self.save_function is not None, ".save_function() not set"
         self.save_function(filepath)
 
     def check_monitor_top_k(self, current):
-        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
+        less_than_k_models = len(self.best_k_models) < self.save_top_k
         if less_than_k_models:
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
     def on_validation_end(self):
-
         logs = self.trainer.callback_metrics
         epoch = self.trainer.current_epoch
         self.epochs_since_last_check += 1
@@ -174,18 +165,13 @@ def on_validation_end(self):
             self.best_k_models[filepath] = current
             if len(self.best_k_models.keys()) == self.save_top_k:
                 # monitor dict has reached k elements
-                if self.mode == 'min':
-                    self.kth_best_model = max(
-                        self.best_k_models, key=self.best_k_models.get)
-                else:
-                    self.kth_best_model = min(
-                        self.best_k_models, key=self.best_k_models.get)
+                _op = max if self.mode == 'min' else min
+                self.kth_best_model = _op(self.best_k_models, key=self.best_k_models.get)
                 self.kth_value = self.best_k_models[self.kth_best_model]
 
-            if self.mode == 'min':
-                self.best = min(self.best_k_models.values())
-            else:
-                self.best = max(self.best_k_models.values())
+            _op = min if self.mode == 'min' else max
+            self.best = _op(self.best_k_models.values())
 
             if self.verbose > 0:
                 log.info(
                     f'\nEpoch {epoch:05d}: {self.monitor} reached'
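To make the refactored bookkeeping concrete: best_k_models maps checkpoint path to monitored value, kth_best_model is the weakest checkpoint still kept (the max when lower is better), and best is the strongest score. A standalone sketch of that selection logic, with made-up paths and values:

    best_k_models = {'ckpt_a.ckpt': 0.31, 'ckpt_b.ckpt': 0.27, 'ckpt_c.ckpt': 0.40}
    mode = 'min'  # lower monitored value is better, e.g. val_loss

    # weakest checkpoint still kept -> candidate for eviction
    _op = max if mode == 'min' else min
    kth_best_model = _op(best_k_models, key=best_k_models.get)  # 'ckpt_c.ckpt'
    kth_value = best_k_models[kth_best_model]                   # 0.40

    # best score across the kept checkpoints
    _op = min if mode == 'min' else max
    best = _op(best_k_models.values())                          # 0.27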
