[MXNet-1349][Fit API] Add validation support and unit tests for fit() API #14442
Changes from all commits
9f334da
c81132a
69e118b
7750027
1eafd3a
d9b7480
5d7b58e
7d9137a
353e3d3
b843f56
305d1bf
d07052a
abf6a68
f88515f
282957e
5f77df9
@@ -19,13 +19,14 @@
# pylint: disable=wildcard-import
"""Gluon Estimator"""

import copy
import warnings

from .event_handler import LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss
from ...metric import EvalMetric, Loss, Accuracy

__all__ = ['Estimator']

@@ -62,44 +63,57 @@ def __init__(self, net,

        if isinstance(loss, gluon.loss.Loss):
            self.loss = [loss]
        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
            self.loss = loss
        else:
            self.loss = loss or []
            for l in self.loss:
                if not isinstance(loss, gluon.loss.Loss):
                    raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
            raise ValueError("loss must be a Loss or a list of Loss, "
                             "refer to gluon.loss.Loss:{}".format(loss))

        if isinstance(metrics, EvalMetric):
            self.metrics = [metrics]
            self.train_metrics = [metrics]
Review comment: Can you infer from the loss function? Use 'Accuracy' as the default when not passed?
Reply: 'Accuracy' will only work for classification cases; for other cases it will give inaccurate results or even fail. Also, I'm not sure how we can infer metrics from the loss function, as there isn't a direct correlation between them. Do you have any suggestions?
Reply: I think we should still infer metrics from known loss functions (at least from the examples you know).
Reply: Added the Accuracy metric as the default for SoftmaxCrossEntropy loss for now. Will add more in a follow-up PR.
        else:
            self.metrics = metrics or []
            for metric in self.metrics:
                if not isinstance(metric, EvalMetric):
                    raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
            self.train_metrics = metrics or []
            if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
                raise ValueError("metrics must be a Metric or a list of Metric, "
                                 "refer to mxnet.metric.EvalMetric:{}".format(metrics))

        # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
        if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
Review comment: Let's get this from a map of Loss -> [default metrics] in the next version.
Reply: Yes, tracking it with JIRA issue: https://issues.apache.org/jira/browse/MXNET-1364
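A rough sketch of what such a Loss-to-default-metric map could look like (purely illustrative: the map name, the helper, and the set of entries are assumptions for the follow-up, not part of this PR):

```python
# Hypothetical sketch only -- not part of this PR (follow-up tracked in MXNET-1364).
from mxnet import gluon, metric

# Map known loss types to a sensible default metric class.
_DEFAULT_METRICS = {
    gluon.loss.SoftmaxCrossEntropyLoss: metric.Accuracy,
    gluon.loss.SigmoidBinaryCrossEntropyLoss: metric.Accuracy,
    gluon.loss.L2Loss: metric.MSE,
    gluon.loss.L1Loss: metric.MAE,
}

def _infer_metrics(losses):
    """Return default metric instances for recognized loss types."""
    return [_DEFAULT_METRICS[type(l)]() for l in losses
            if type(l) in _DEFAULT_METRICS]
```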
            self.train_metrics = [Accuracy()]

        # Use same metrics for validation
        self.val_metrics = copy.deepcopy(self.train_metrics)

        self.initializer = initializer
        # store training statistics
        self.train_stats = {}
        self.train_stats['epochs'] = []
        self.train_stats['learning_rate'] = []
        # current step of the epoch
        self.train_stats['step'] = ''
        for metric in self.metrics:
        for metric in self.train_metrics:
            # record a history of metrics over each epoch
            self.train_stats['train_' + metric.name] = []
            # only record the latest metric numbers after each batch
            self.train_stats['batch_' + metric.name] = 0.
        self.loss_metrics = []
        for metric in self.val_metrics:
Review comment: Can we have one for loop for
Reply: We want to keep the train and val metrics separate here. Currently we are using the same metrics for train and val, but future updates may involve separate, user-specified val metrics, in which case combining this update loop won't work.
            self.train_stats['val_' + metric.name] = []
        self.train_loss_metrics = []
        self.val_loss_metrics = []
        # using the metric wrapper for loss to record loss value
        for l in self.loss:
            self.loss_metrics.append(Loss(l.name))
            self.train_loss_metrics.append(Loss(l.name))
            self.val_loss_metrics.append(Loss(l.name))
            self.train_stats['train_' + l.name] = []
            self.train_stats['val_' + l.name] = []
            # only record the latest loss numbers after each batch
            self.train_stats['batch_' + l.name] = 0.

        # handle context
        if isinstance(context, Context):
            self.context = [context]
        if not context:
        elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
            self.context = context
        elif not context:
            if num_gpus() > 0:
                # only use 1 GPU by default
                if num_gpus() > 1:

@@ -109,8 +123,13 @@ def __init__(self, net,
                self.context = [gpu(0)]
            else:
                self.context = [cpu()]
        else:
            raise ValueError("context must be a Context or a list of Context, "
                             "refer to mxnet.Context:{}".format(context))


        # initialize the network
        self.initializer = initializer
        if self.initializer:
            if self._is_initialized():
                # if already initialized, re-init with user specified initializer

@@ -128,13 +147,13 @@ def __init__(self, net,
        # handle trainers
        if isinstance(trainers, gluon.Trainer):
            self.trainers = [trainers]
        else:
            self.trainers = trainers or []
        if not self.trainers:
        elif not trainers:
            warnings.warn("No trainer specified, default SGD optimizer "
                          "with learning rate 0.001 is used.")
            self.trainers = [gluon.Trainer(self.net.collect_params(),
                                           'sgd', {'learning_rate': 0.001})]
        else:
            raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")

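For reference, a minimal sketch of exercising the stricter constructor checks above; the import path and hyperparameters are assumptions for illustration, not taken from the diff itself:

```python
# Illustrative sketch; module path and hyperparameters are assumptions.
from mxnet import cpu, gluon, metric
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Dense(10)
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})

# A single Loss/EvalMetric/Context/Trainer or a list of them is accepted;
# anything else now raises a ValueError that echoes the offending argument.
est = Estimator(net=net,
                loss=loss,
                metrics=metric.Accuracy(),
                trainers=trainer,
                context=[cpu()])
```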
    def _is_initialized(self):
        param_dict = self.net.collect_params()

@@ -156,7 +175,48 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
        return data, label

    def evaluate(self,
                 val_data,
                 batch_fn=None):
        """Evaluate model on validation data

        Parameters
        ----------
        val_data : DataLoader or DataIter
            validation data with data and labels
        batch_fn : function
            custom batch function to extract data and label
            from a data batch and load into contexts(devices)
        """

        for metric in self.val_metrics + self.val_loss_metrics:
            metric.reset()

        for _, batch in enumerate(val_data):
            if not batch_fn:
                if isinstance(val_data, gluon.data.DataLoader):
Review comment: Move this if/else into self._batch_fn.
Reply: This check needs to be before calling
                    data, label = self._batch_fn(batch, self.context)
                elif isinstance(val_data, DataIter):
                    data, label = self._batch_fn(batch, self.context, is_iterator=True)
Review comment: Same as above.
                else:
                    raise ValueError("You are using a custom iteration, please also provide "
Review comment: It would be helpful to the end user if you could provide a more detailed exception; you can append the statement below at the end, something like this:
Reply: Thanks for pointing it out, updated!
                                     "batch_fn to extract data and label. Alternatively, you "
                                     "can provide the data as gluon.data.DataLoader or "
                                     "mx.io.DataIter")
            else:
                data, label = batch_fn(batch, self.context)
            pred = [self.net(x) for x in data]
            losses = []
            for loss in self.loss:
                losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
            # update metrics
            for metric in self.val_metrics:
                metric.update(label, pred)
            for loss, loss_metric, in zip(losses, self.val_loss_metrics):
                loss_metric.update(0, [l for l in loss])

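If the validation data is a custom iterator (neither a gluon.data.DataLoader nor an mx.io.DataIter), a batch_fn must be supplied so evaluate() can split each batch across devices. A minimal sketch of such a function, assuming each batch is a (data, label) pair of NDArrays (the function name and batch layout are illustrative):

```python
# Illustrative sketch; assumes each batch is a (data, label) NDArray pair.
from mxnet import gluon

def my_batch_fn(batch, ctx):
    """Split one batch of data and labels across the given context list."""
    data, label = batch
    data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
    return data, label

# est.evaluate(my_custom_val_iter, batch_fn=my_batch_fn)
```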
    def fit(self, train_data,
            val_data=None,
Review comment: This shouldn't be optional.
Reply: Users might want to train without a validation set. Although this is rare, keeping it optional provides a bit of flexibility.
Reply: I don't see a reason why users would not use a validation dataset; it's required to know that the model is not overfitting.
            epochs=1,
            batch_size=None,
            event_handlers=None,

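A sketch of calling fit() with the new validation support; the toy data, batch size, and epoch count are placeholders, and est is an Estimator instance such as the one sketched earlier:

```python
# Illustrative call; the toy arrays and hyperparameters are placeholders.
import mxnet as mx
from mxnet import gluon

X = mx.nd.random.uniform(shape=(100, 20))
y = mx.nd.floor(mx.nd.random.uniform(0, 10, shape=(100,)))

train_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y),
                                     batch_size=32, shuffle=True)
val_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y),
                                   batch_size=32)

# Validation metrics and losses are computed at the end of each epoch
# when val_data is given; omit it to train without validation.
est.fit(train_data=train_loader,
        val_data=val_loader,
        epochs=5)
```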
@@ -204,7 +264,7 @@ def fit(self, train_data,
            for handler in event_handlers:
                handler.epoch_begin()

            for metric in self.metrics + self.loss_metrics:
            for metric in self.train_metrics + self.train_loss_metrics:
                metric.reset()

            for i, batch in enumerate(train_data):

@@ -215,7 +275,9 @@ def fit(self, train_data,
                        data, label = self._batch_fn(batch, self.context, is_iterator=True)
                    else:
                        raise ValueError("You are using a custom iteration, please also provide "
                                         "batch_fn to extract data and label")
                                         "batch_fn to extract data and label. Alternatively, you "
                                         "can provide the data as gluon.data.DataLoader or "
                                         "mx.io.DataIter")
                else:
                    data, label = batch_fn(batch, self.context)

@@ -233,11 +295,11 @@ def fit(self, train_data,
                for l in loss:
                    l.backward()

                # update metrics
                for metric in self.metrics:
                # update train metrics
                for metric in self.train_metrics:
                    metric.update(label, pred)
                    self.train_stats['batch_' + metric.name] = metric.get()[1]
                for loss, loss_metric, in zip(losses, self.loss_metrics):
                for loss, loss_metric, in zip(losses, self.train_loss_metrics):
                    loss_metric.update(0, [l for l in loss])
                    self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

@@ -253,8 +315,14 @@ def fit(self, train_data,
                for handler in event_handlers:
                    handler.batch_end()

            for metric in self.metrics + self.loss_metrics:
            if val_data:
                self.evaluate(val_data, batch_fn)

            for metric in self.train_metrics + self.train_loss_metrics:
                self.train_stats['train_' + metric.name].append(metric.get()[1])
            for metric in self.val_metrics + self.val_loss_metrics:
                self.train_stats['val_' + metric.name].append(metric.get()[1])

            # epoch end
            for handler in event_handlers:
                handler.epoch_end()

Review comment: Use logic similar to above.
Reply: Done!
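After fit() returns, the per-epoch histories recorded above can be read back from train_stats. The key names follow the 'train_'/'val_' + metric-name convention set up in __init__; the exact keys below assume the default Accuracy metric (whose name is 'accuracy'):

```python
# Keys follow 'train_'/'val_' + metric.name; loss keys use the loss block's
# auto-generated name (e.g. 'softmaxcrossentropyloss0').
print(est.train_stats['train_accuracy'])  # training accuracy per epoch
print(est.train_stats['val_accuracy'])    # validation accuracy per epoch
```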