
introduce gradient update handler to the base estimator #16900

Merged: 10 commits, Dec 9, 2019
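In short, this PR moves the weight update (`trainer.step`) out of `Estimator.fit_batch` and into a new `GradientUpdateHandler` event handler, which `fit` registers by default when the caller does not supply one. A minimal usage sketch under that API — the toy data and network below are illustrative, not part of the PR:

```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

# Toy dataset: 100 samples, 10 features, 2 classes.
X = mx.nd.random.uniform(shape=(100, 10))
y = mx.nd.random.randint(0, 2, shape=(100,)).astype('float32')
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=10)

net = gluon.nn.Dense(2)
net.initialize(mx.init.Xavier())
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
est = Estimator(net=net, loss=gluon.loss.SoftmaxCrossEntropyLoss(), trainer=trainer)

# No handlers passed: fit() now injects a default GradientUpdateHandler,
# which calls trainer.step() at the end of every batch.
est.fit(train_data=train_data, epochs=1)
```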
8 changes: 5 additions & 3 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,7 +24,7 @@
 import sys
 import warnings
 
-from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
 from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
 from .event_handler import _check_event_handlers
 from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
@@ -307,8 +307,6 @@ def fit_batch(self, train_batch, batch_axis=0):
         for l in loss:
             l.backward()
 
-        self.trainer.step(batch_size)
-
         return data, label, pred, loss
 
     def fit(self, train_data,
@@ -360,6 +358,7 @@ def fit(self, train_data,
 
         self.max_epoch = epochs
         self.max_batch = batches
+        self.batch_axis = batch_axis
 
         # provide default handlers
         event_handlers = self._prepare_default_handlers(val_data, event_handlers)
@@ -414,6 +413,9 @@ def _prepare_default_handlers(self, val_data, event_handlers):
         # no need to add to default handler check as StoppingHandler does not use metrics
         added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
 
+        if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers):
+            added_default_handlers.append(GradientUpdateHandler())
+
         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
             added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
 
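With `trainer.step` removed from `fit_batch`, the batch size is no longer passed down explicitly; the new handler recovers it from the loss arrays along `estimator.batch_axis`, which `fit` now stores (see the hunk above). A quick illustration of that recovery, with made-up shapes:

```python
import mxnet as mx

batch_axis = 0
# e.g. two per-device loss arrays from a data-parallel split of one batch
losses = [mx.nd.zeros((32, 1)), mx.nd.zeros((32, 1))]
batch_size = sum(l.shape[batch_axis] for l in losses)
print(batch_size)  # 64 -- the count passed to trainer.step()
```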
54 changes: 47 additions & 7 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -31,7 +31,7 @@
 
 __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
-           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler']
 
 
 class EventHandler(object):
@@ -130,13 +130,16 @@ class MetricHandler(EpochBegin, BatchEnd):
     ----------
     train_metrics : List of EvalMetrics
         Training metrics to be updated at batch end.
+    priority : scalar
+        Priority level of the MetricHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's priority.
     """
 
-    def __init__(self, train_metrics):
+    def __init__(self, train_metrics, priority=-1000):
         self.train_metrics = _check_metrics(train_metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def epoch_begin(self, estimator, *args, **kwargs):
         for metric in self.train_metrics:
@@ -176,14 +179,19 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     batch_period : int, default None
         How often to run validation at batch end, by default
         :py:class:`ValidationHandler` does not validate at batch end.
+    priority : scalar, default -1000
+        Priority level of the ValidationHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's
+        priority.
     """
 
     def __init__(self,
                  val_data,
                  eval_fn,
                  val_metrics=None,
                  epoch_period=1,
-                 batch_period=None):
+                 batch_period=None,
+                 priority=-1000):
         self.val_data = val_data
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
@@ -193,7 +201,7 @@ def __init__(self,
         self.current_epoch = 0
         # order to be called among all callbacks
         # validation metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def train_begin(self, estimator, *args, **kwargs):
         # reset epoch and batch counter
@@ -235,11 +243,16 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
         Training metrics to be logged, logged at batch end, epoch end, train end.
     val_metrics : list of EvalMetrics
         Validation metrics to be logged, logged at epoch end, train end.
+    priority : scalar, default np.Inf
+        Priority level of the LoggingHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's
+        priority.
     """
 
     def __init__(self, log_interval='epoch',
                  train_metrics=None,
-                 val_metrics=None):
+                 val_metrics=None,
+                 priority=np.Inf):
         super(LoggingHandler, self).__init__()
         if not isinstance(log_interval, int) and log_interval != 'epoch':
             raise ValueError("log_interval must be either an integer or string 'epoch'")
@@ -250,7 +263,7 @@ def __init__(self, log_interval='epoch',
         self.processed_samples = 0
         # logging handler need to be called at last to make sure all states are updated
         # it will also shut down logging at train end
-        self.priority = np.Inf
+        self.priority = priority
         self.log_interval = log_interval
 
     def train_begin(self, estimator, *args, **kwargs):
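These `priority` parameters matter because the estimator invokes its handlers in ascending priority order, so gradient updates (-2000) run before metric updates (-1000), which run before logging (np.Inf). A standalone sketch of that ordering, using stand-in classes rather than the library's:

```python
import numpy as np

class Handler:
    def __init__(self, name, priority):
        self.name, self.priority = name, priority

handlers = [Handler('LoggingHandler', np.inf),
            Handler('MetricHandler', -1000),
            Handler('GradientUpdateHandler', -2000)]

# Lower value == called earlier, mirroring the estimator's ascending sort.
print([h.name for h in sorted(handlers, key=lambda h: h.priority)])
# ['GradientUpdateHandler', 'MetricHandler', 'LoggingHandler']
```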
@@ -704,3 +717,30 @@ def train_end(self, estimator, *args, **kwargs):
         estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: '
                               'early stopping due to %s not improving',
                               self.stopped_epoch, self.monitor.get()[0])
+
+
+class GradientUpdateHandler(BatchEnd):
+    """Gradient update handler that applies the accumulated gradients to the
+    network weights at the end of each batch.
+
+    Parameters
+    ----------
+    priority : scalar, default -2000
+        Priority level of the GradientUpdateHandler. Priority levels are sorted
+        in ascending order: the lower the number, the higher the handler's
+        priority.
+    """
+    def __init__(self, priority=-2000):
+        self.priority = priority
+
+    def batch_end(self, estimator, *args, **kwargs):
+        loss = kwargs['loss']
+        batch_size = 0
+        if not isinstance(loss, list):
+            loss = [loss]
+        # batch size is the total number of samples along the batch axis,
+        # summed across the per-device loss arrays
+        for l in loss:
+            batch_size += l.shape[estimator.batch_axis]
+
+        estimator.trainer.step(batch_size)
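Because the update now lives in a handler, users can replace it, which is the point of the refactor. As one illustration (not part of this PR), a subclass could clip gradients before stepping; the `max_norm` parameter and the single-context assumption are mine:

```python
from mxnet import gluon
from mxnet.gluon.contrib.estimator.event_handler import GradientUpdateHandler

class ClippingGradientUpdateHandler(GradientUpdateHandler):
    """Illustrative subclass: rescale gradients by their global norm
    before the weight update. Assumes training on a single context."""

    def __init__(self, max_norm=5.0, priority=-2000):
        super(ClippingGradientUpdateHandler, self).__init__(priority=priority)
        self.max_norm = max_norm

    def batch_end(self, estimator, *args, **kwargs):
        grads = [p.grad() for p in estimator.net.collect_params().values()
                 if p.grad_req != 'null']
        gluon.utils.clip_global_norm(grads, self.max_norm)
        # delegate the actual trainer.step(batch_size) to the base handler
        super(ClippingGradientUpdateHandler, self).batch_end(estimator, *args, **kwargs)
```

Passed via `est.fit(..., event_handlers=[ClippingGradientUpdateHandler()])`, it replaces the default, since `_prepare_default_handlers` skips adding a `GradientUpdateHandler` when one is already present.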
7 changes: 4 additions & 3 deletions tests/python/unittest/test_gluon_estimator.py
@@ -367,6 +367,7 @@ def test_default_handlers():
     val_metrics = est.val_metrics
     early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
     handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
-    assert len(handlers) == 4
-    assert isinstance(handlers[0], MetricHandler)
-    assert isinstance(handlers[3], LoggingHandler)
+    assert len(handlers) == 5
+    assert isinstance(handlers[0], GradientUpdateHandler)
+    assert isinstance(handlers[1], MetricHandler)
+    assert isinstance(handlers[4], LoggingHandler)