This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNet-1349][Fit API]Add validation support and unit tests for fit() API #14442

Merged: 16 commits, Mar 25, 2019
116 changes: 92 additions & 24 deletions python/mxnet/gluon/estimator/estimator.py
@@ -19,13 +19,14 @@
# pylint: disable=wildcard-import
"""Gluon Estimator"""

import copy
import warnings

from .event_handler import LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss
from ...metric import EvalMetric, Loss, Accuracy

__all__ = ['Estimator']

@@ -62,44 +63,57 @@ def __init__(self, net,

if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
self.loss = loss
else:
self.loss = loss or []
for l in self.loss:
if not isinstance(loss, gluon.loss.Loss):
raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
raise ValueError("loss must be a Loss or a list of Loss, "
"refer to gluon.loss.Loss:{}".format(loss))

if isinstance(metrics, EvalMetric):
Member: use logic similar to above

Contributor Author: Done!

self.metrics = [metrics]
self.train_metrics = [metrics]
Member (@nswamy, Mar 21, 2019): Can you infer it from the loss function? Use 'Accuracy' as the default when not passed?

Contributor Author: 'Accuracy' will only work for classification cases; for other cases it will give inaccurate results or even fail. Also, I'm not sure how we can infer metrics from the loss function, as there isn't a direct correlation between them. Do you have any suggestions?

Member: I think we should still infer metrics from known Loss functions (at least from the examples you know).

Contributor Author: Added the Accuracy metric as the default for SoftmaxCrossEntropy loss for now. Will add more in a follow-up PR.
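A minimal usage sketch of that default (the net/estimator names and the import path are assumptions, not part of this diff): constructing the Estimator with only a SoftmaxCrossEntropyLoss and no metrics should leave Accuracy as the training metric.

```python
from mxnet import gluon
from mxnet.gluon.estimator import Estimator  # import path assumed from this PR's layout

net = gluon.nn.Dense(10)                     # hypothetical classifier head
loss = gluon.loss.SoftmaxCrossEntropyLoss()
est = Estimator(net=net, loss=loss)          # no metrics passed
# est.train_metrics should now default to [mx.metric.Accuracy()]
```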

else:
self.metrics = metrics or []
for metric in self.metrics:
if not isinstance(metric, EvalMetric):
raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
self.train_metrics = metrics or []
if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
raise ValueError("metrics must be a Metric or a list of Metric, "
"refer to mxnet.metric.EvalMetric:{}".format(metrics))

# Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
Member: Let's get this from a map of Loss -> [default metrics] in the next version.

Contributor Author: Yes, tracking it with JIRA issue: https://issues.apache.org/jira/browse/MXNET-1364

self.train_metrics = [Accuracy()]

# Use same metrics for validation
self.val_metrics = copy.deepcopy(self.train_metrics)
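A sketch of the Loss -> [default metrics] map suggested above (a possible follow-up under MXNET-1364, not part of this PR; names are illustrative):

```python
# Hypothetical follow-up: choose default metrics from a map keyed by loss type
# instead of hard-coding SoftmaxCrossEntropyLoss -> Accuracy.
DEFAULT_METRICS_BY_LOSS = {
    gluon.loss.SoftmaxCrossEntropyLoss: [Accuracy],
    # further entries (e.g. regression losses -> MSE/MAE metrics) could go here
}

def default_metrics_for(losses):
    """Instantiate the default metrics registered for the given loss objects."""
    metrics = []
    for l in losses:
        for metric_cls in DEFAULT_METRICS_BY_LOSS.get(type(l), []):
            metrics.append(metric_cls())
    return metrics
```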

self.initializer = initializer
# store training statistics
self.train_stats = {}
self.train_stats['epochs'] = []
self.train_stats['learning_rate'] = []
# current step of the epoch
self.train_stats['step'] = ''
for metric in self.metrics:
for metric in self.train_metrics:
# record a history of metrics over each epoch
self.train_stats['train_' + metric.name] = []
# only record the latest metric numbers after each batch
self.train_stats['batch_' + metric.name] = 0.
self.loss_metrics = []
for metric in self.val_metrics:
Contributor: Can we have one for loop for self.train_metrics and self.val_metrics, since the values of both are the same according to line 82? Something like: for train_m, val_m in zip(self.train_metrics, self.val_metrics). Although zip() stops after exhausting the shorter sequence, both lists have the same length here, so zip() would work.

Contributor Author: We want to keep the train and val metrics separate here. Currently we use the same metrics for validation and training, but future updates may allow separate user-specified validation metrics, in which case a combined update loop won't work.
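For reference, the combined loop the reviewer suggests would look roughly like the sketch below; it only stays correct while both lists are kept in lockstep, which is why the PR keeps the loops separate.

```python
# Combined update as suggested; assumes len(self.train_metrics) == len(self.val_metrics)
for train_m, val_m in zip(self.train_metrics, self.val_metrics):
    self.train_stats['train_' + train_m.name] = []
    self.train_stats['batch_' + train_m.name] = 0.
    self.train_stats['val_' + val_m.name] = []
```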

self.train_stats['val_' + metric.name] = []
self.train_loss_metrics = []
self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
self.loss_metrics.append(Loss(l.name))
self.train_loss_metrics.append(Loss(l.name))
self.val_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
self.train_stats['val_' + l.name] = []
# only record the latest loss numbers after each batch
self.train_stats['batch_' + l.name] = 0.
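For orientation, with one Accuracy metric and one SoftmaxCrossEntropyLoss (whose metric-wrapper name is assumed here to be 'softmaxcrossentropyloss0'), the bookkeeping above yields a dictionary shaped roughly like this:

```python
train_stats = {
    'epochs': [],                           # epoch indices seen so far
    'learning_rate': [],                    # learning rate recorded per epoch
    'step': '',                             # current step within the epoch
    'train_accuracy': [],                   # per-epoch training metric history
    'batch_accuracy': 0.,                   # latest metric value after each batch
    'val_accuracy': [],                     # per-epoch validation metric history
    'train_softmaxcrossentropyloss0': [],   # per-epoch training loss history
    'val_softmaxcrossentropyloss0': [],     # per-epoch validation loss history
    'batch_softmaxcrossentropyloss0': 0.,   # latest loss value after each batch
}
```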

# handle context
if isinstance(context, Context):
self.context = [context]
if not context:
elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
self.context = context
elif not context:
if num_gpus() > 0:
# only use 1 GPU by default
if num_gpus() > 1:
@@ -109,8 +123,13 @@ def __init__(self, net,
self.context = [gpu(0)]
else:
self.context = [cpu()]
else:
raise ValueError("context must be a Context or a list of Context, "
"refer to mxnet.Context:{}".format(context))


# initialize the network
self.initializer = initializer
if self.initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
@@ -128,13 +147,13 @@
# handle trainers
if isinstance(trainers, gluon.Trainer):
self.trainers = [trainers]
else:
self.trainers = trainers or []
if not self.trainers:
elif not trainers:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
else:
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")

def _is_initialized(self):
param_dict = self.net.collect_params()
@@ -156,7 +175,48 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
return data, label

def evaluate(self,
val_data,
batch_fn=None):
"""Evaluate model on validation data

Parameters
----------
val_data : DataLoader or DataIter
validation data with data and labels
batch_fn : function
custom batch function to extract data and label
from a data batch and load into contexts(devices)
"""

for metric in self.val_metrics + self.val_loss_metrics:
metric.reset()

for _, batch in enumerate(val_data):
if not batch_fn:
if isinstance(val_data, gluon.data.DataLoader):
Member: move this if/else into self._batch_fn

Contributor Author: This check needs to be before calling batch_fn, as val_data is not available to it.

data, label = self._batch_fn(batch, self.context)
elif isinstance(val_data, DataIter):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
Member: Same as above.

else:
raise ValueError("You are using a custom iteration, please also provide "
Contributor (@karan6181, Mar 21, 2019): It would be helpful to the end user if we could provide a more detailed exception. You can append something like the following at the end: "or you can provide the data in terms of gluon.data.DataLoader or mx.io.DataIter". Please also change this statement in the fit() method if you are changing it here.

Contributor Author: Thanks for pointing it out, updated!

"batch_fn to extract data and label. Alternatively, you "
"can provide the data as gluon.data.DataLoader or "
"mx.io.DataIter")
else:
data, label = batch_fn(batch, self.context)
pred = [self.net(x) for x in data]
losses = []
for loss in self.loss:
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
# update metrics
for metric in self.val_metrics:
metric.update(label, pred)
for loss, loss_metric, in zip(losses, self.val_loss_metrics):
loss_metric.update(0, [l for l in loss])
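For a custom iterator, the batch_fn that the error message asks for could be as small as the following sketch (it assumes each batch is a (data, label) pair and mirrors _batch_fn above; the function name is hypothetical):

```python
def my_batch_fn(batch, ctx):
    # Split a (data, label) batch across the given contexts, like _batch_fn does.
    data, label = batch
    data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
    return data, label
```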

def fit(self, train_data,
val_data=None,
Member: this shouldn't be optional

Contributor Author (@abhinavs95, Mar 19, 2019): Users might want to train without a validation set. Although this is rare, keeping it optional provides a bit of flexibility.

Member: I don't see a reason why users would not want a validation dataset; it's required to know that the model is not overfitting.
epochs=1,
batch_size=None,
event_handlers=None,
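Putting the discussion above together, both call patterns remain valid; est, train_loader and val_loader are hypothetical names:

```python
est.fit(train_data=train_loader, val_data=val_loader, epochs=5)  # validates after every epoch
est.fit(train_data=train_loader, epochs=5)                       # val_data omitted: training only

# The new evaluate() can also be called on its own; metrics are read back afterwards.
est.evaluate(val_loader)
for m in est.val_metrics + est.val_loss_metrics:
    print(m.get())  # each get() returns a (name, value) pair
```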
@@ -204,7 +264,7 @@ def fit(self, train_data,
for handler in event_handlers:
handler.epoch_begin()

for metric in self.metrics + self.loss_metrics:
for metric in self.train_metrics + self.train_loss_metrics:
metric.reset()

for i, batch in enumerate(train_data):
@@ -215,7 +275,9 @@
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
"batch_fn to extract data and label")
"batch_fn to extract data and label. Alternatively, you "
"can provide the data as gluon.data.DataLoader or "
"mx.io.DataIter")
else:
data, label = batch_fn(batch, self.context)

@@ -233,11 +295,11 @@ def fit(self, train_data,
for l in loss:
l.backward()

# update metrics
for metric in self.metrics:
# update train metrics
for metric in self.train_metrics:
metric.update(label, pred)
self.train_stats['batch_' + metric.name] = metric.get()[1]
for loss, loss_metric, in zip(losses, self.loss_metrics):
for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

@@ -253,8 +315,14 @@ def fit(self, train_data,
for handler in event_handlers:
handler.batch_end()

for metric in self.metrics + self.loss_metrics:
if val_data:
self.evaluate(val_data, batch_fn)

for metric in self.train_metrics + self.train_loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
for metric in self.val_metrics + self.val_loss_metrics:
self.train_stats['val_' + metric.name].append(metric.get()[1])

# epoch end
for handler in event_handlers:
handler.epoch_end()
2 changes: 1 addition & 1 deletion python/mxnet/gluon/estimator/event_handler.py
@@ -118,7 +118,7 @@ def epoch_end(self):
epoch = self._estimator.train_stats['epochs'][-1]
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
if key.startswith('train_') or key.startswith('test_'):
if key.startswith('train_') or key.startswith('val_'):
msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
self.logger.info(msg)
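With the 'val_' prefix fix above, an epoch-end log line would read roughly as follows (metric names and values are illustrative):

```
[Epoch 0] finished in 12.345s: train_accuracy: 0.9130 train_softmaxcrossentropyloss0: 0.2876 val_accuracy: 0.9054 val_softmaxcrossentropyloss0: 0.3012
```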
