From f4a49ca3f8a4056f064f45492cd80c731fd744e7 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Mon, 25 Nov 2019 09:14:22 +0000 Subject: [PATCH 01/10] introduce gradient update handler to the base estimator --- .../gluon/contrib/estimator/estimator.py | 9 +++-- .../gluon/contrib/estimator/event_handler.py | 40 +++++++++++++++---- tests/python/unittest/test_gluon_estimator.py | 18 +++++---- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 54a0b165016e..40fb69b14973 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -24,7 +24,7 @@ import sys import warnings -from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd from .event_handler import _check_event_handlers from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref @@ -302,13 +302,11 @@ def fit_batch(self, train_batch, batch_axis=0): with autograd.record(): pred = [self.net(x) for x in data] - loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)] + loss = [self.loss(y_hat, y) / batch_size for y_hat, y in zip(pred, label)] for l in loss: l.backward() - self.trainer.step(batch_size) - return data, label, pred, loss def fit(self, train_data, @@ -414,6 +412,9 @@ def _prepare_default_handlers(self, val_data, event_handlers): # no need to add to default handler check as StoppingHandler does not use metrics added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers): + added_default_handlers.append(GradientUpdateHandler()) + if not any(isinstance(handler, MetricHandler) for handler in event_handlers): added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics)) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 53ba07dc836a..10fc6227e019 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -31,7 +31,7 @@ __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd', 'StoppingHandler', 'MetricHandler', 'ValidationHandler', - 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler'] + 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler'] class EventHandler(object): @@ -130,13 +130,15 @@ class MetricHandler(EpochBegin, BatchEnd): ---------- train_metrics : List of EvalMetrics Training metrics to be updated at batch end. 
+    priority : scalar
+        Priority level of the MetricHandler
     """
 
-    def __init__(self, train_metrics):
+    def __init__(self, train_metrics, priority=-1000):
         self.train_metrics = _check_metrics(train_metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def epoch_begin(self, estimator, *args, **kwargs):
         for metric in self.train_metrics:
@@ -176,6 +178,8 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     batch_period : int, default None
         How often to run validation at batch end, by default
         :py:class:`ValidationHandler` does not validate at batch end.
+    priority: scalar, default -1000
+        Priority level of the ValidataionHandler
     """
 
     def __init__(self,
@@ -183,7 +187,8 @@ def __init__(self,
                  eval_fn,
                  val_metrics=None,
                  epoch_period=1,
-                 batch_period=None):
+                 batch_period=None,
+                 priority=-1000):
         self.val_data = val_data
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
@@ -193,7 +198,7 @@ def __init__(self,
         self.current_epoch = 0
         # order to be called among all callbacks
         # validation metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def train_begin(self, estimator, *args, **kwargs):
         # reset epoch and batch counter
@@ -235,11 +240,14 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         Training metrics to be logged, logged at batch end, epoch end, train end.
     val_metrics : list of EvalMetrics
         Validation metrics to be logged, logged at epoch end, train end.
+    priority : scalar, default np.Inf
+        Priority level of the LoggingHandler
     """
 
     def __init__(self, log_interval='epoch',
                  train_metrics=None,
-                 val_metrics=None):
+                 val_metrics=None,
+                 priority=np.Inf):
         super(LoggingHandler, self).__init__()
         if not isinstance(log_interval, int) and log_interval != 'epoch':
             raise ValueError("log_interval must be either an integer or string 'epoch'")
@@ -250,7 +258,7 @@ def __init__(self, log_interval='epoch',
         self.processed_samples = 0
         # logging handler need to be called at last to make sure all states are updated
         # it will also shut down logging at train end
-        self.priority = np.Inf
+        self.priority = priority
         self.log_interval = log_interval
 
     def train_begin(self, estimator, *args, **kwargs):
@@ -704,3 +712,21 @@ def train_end(self, estimator, *args, **kwargs):
             estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: '
                                   'early stopping due to %s not improving',
                                   self.stopped_epoch, self.monitor.get()[0])
+
+class GradientUpdateHandler(BatchEnd):
+    """Gradient Update Handler that applies gradients on network weights
+
+    :py:class:`GradientUpdateHandler` takes a priority level and updates the weight
+    parameters at the end of each batch.
+
+    Parameters
+    ----------
+    priority : scalar, default -np.Inf
+        priority level of the gradient update handler. It should be executed before all other handlers.
+ ---------- + """ + def __init__(self, priority=-np.Inf): + self.priority = priority + + def batch_end(self, estimator, *args, **kwargs): + estimator.trainer.step(1) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index cf913a6161c0..a16f0e647053 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -334,39 +334,41 @@ def test_default_handlers(): train_acc = mx.metric.RMSE() loss = gluon.loss.L2Loss() + gradient_update = GradientUpdateHandler() est = Estimator(net=net, loss=loss, metrics=train_acc, trainer=trainer, context=ctx) - # no handler(all default handlers), no warning + # no handler except gradient update handler (all default handlers), no warning with warnings.catch_warnings(record=True) as w: - est.fit(train_data=train_data, epochs=num_epochs) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[gradient_update]) # handler with prepared loss and metrics # use mix of default and user defined handlers train_metrics = est.train_metrics val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging, gradient_update]) # handler with all user defined metrics # use mix of default and user defined handlers metric = MetricHandler(train_metrics=[train_acc]) logging = LoggingHandler(train_metrics=[train_acc]) - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging, gradient_update]) # handler with mixed metrics, some handler use metrics prepared by estimator # some handler use metrics user prepared logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")]) with assert_raises(ValueError): - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging, gradient_update]) # test handler order train_metrics = est.train_metrics val_metrics = est.val_metrics early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) - assert len(handlers) == 4 - assert isinstance(handlers[0], MetricHandler) - assert isinstance(handlers[3], LoggingHandler) + assert len(handlers) == 5 + assert isinstance(handlers[0], GradientUpdateHandler) + assert isinstance(handlers[1], MetricHandler) + assert isinstance(handlers[4], LoggingHandler) From 68def7be812dde6902e76ccadef838aa31cc893c Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 26 Nov 2019 07:59:42 +0000 Subject: [PATCH 02/10] Modify the gradient update handler to include the batch size --- python/mxnet/gluon/contrib/estimator/estimator.py | 3 ++- .../mxnet/gluon/contrib/estimator/event_handler.py | 9 ++++++++- tests/python/unittest/test_gluon_estimator.py | 12 ++++++------ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 40fb69b14973..ab7018f58e1f 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -302,7 +302,7 @@ def fit_batch(self, train_batch, batch_axis=0): with autograd.record(): pred = [self.net(x) for x in 
data] - loss = [self.loss(y_hat, y) / batch_size for y_hat, y in zip(pred, label)] + loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)] for l in loss: l.backward() @@ -358,6 +358,7 @@ def fit(self, train_data, self.max_epoch = epochs self.max_batch = batches + self.batch_axis = batch_axis # provide default handlers event_handlers = self._prepare_default_handlers(val_data, event_handlers) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 10fc6227e019..2a91171b4e12 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -28,6 +28,7 @@ from ....metric import CompositeEvalMetric, EvalMetric from ....metric import Loss as metric_loss from .utils import _check_metrics +from .... import ndarray __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd', 'StoppingHandler', 'MetricHandler', 'ValidationHandler', @@ -729,4 +730,10 @@ def __init__(self, priority=-np.Inf): self.priority = priority def batch_end(self, estimator, *args, **kwargs): - estimator.trainer.step(1) + loss = kwargs['loss'] + batch_size = 1 + if isinstance(loss, list) and len(loss) > 0: + loss = loss[0] + if isinstance(loss, ndarray.ndarray.NDArray): + batch_size = loss.shape[estimator.batch_axis] + estimator.trainer.step(batch_size) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index a16f0e647053..b87b895d2cb7 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -334,36 +334,36 @@ def test_default_handlers(): train_acc = mx.metric.RMSE() loss = gluon.loss.L2Loss() - gradient_update = GradientUpdateHandler() est = Estimator(net=net, loss=loss, metrics=train_acc, trainer=trainer, context=ctx) - # no handler except gradient update handler (all default handlers), no warning + # no handler(all default handlers), no warning with warnings.catch_warnings(record=True) as w: - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[gradient_update]) + est.fit(train_data=train_data, epochs=num_epochs) # handler with prepared loss and metrics # use mix of default and user defined handlers train_metrics = est.train_metrics val_metrics = est.val_metrics logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging, gradient_update]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # handler with all user defined metrics # use mix of default and user defined handlers metric = MetricHandler(train_metrics=[train_acc]) logging = LoggingHandler(train_metrics=[train_acc]) - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging, gradient_update]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) # handler with mixed metrics, some handler use metrics prepared by estimator # some handler use metrics user prepared logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")]) with assert_raises(ValueError): - est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging, gradient_update]) + est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # test handler order + gradient_update = GradientUpdateHandler() train_metrics = est.train_metrics val_metrics = est.val_metrics early_stopping = 
EarlyStoppingHandler(monitor=val_metrics[0]) From d121f16a7168a20ea1e620c59bad8b54e25b4ad6 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 26 Nov 2019 08:11:55 +0000 Subject: [PATCH 03/10] Remove unrelated gradient update handler. --- tests/python/unittest/test_gluon_estimator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index b87b895d2cb7..21f949a0bba6 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -363,7 +363,6 @@ def test_default_handlers(): est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) # test handler order - gradient_update = GradientUpdateHandler() train_metrics = est.train_metrics val_metrics = est.val_metrics early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) From d2a6f3a082b2dd95d1aa350b4541ec2d3c6b5eb9 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 26 Nov 2019 09:00:21 +0000 Subject: [PATCH 04/10] Modify gradient update handler to take the current batch size. --- python/mxnet/gluon/contrib/estimator/event_handler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 2a91171b4e12..d9a559b87277 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -731,9 +731,11 @@ def __init__(self, priority=-np.Inf): def batch_end(self, estimator, *args, **kwargs): loss = kwargs['loss'] - batch_size = 1 - if isinstance(loss, list) and len(loss) > 0: - loss = loss[0] + batch_size = 0 if isinstance(loss, ndarray.ndarray.NDArray): - batch_size = loss.shape[estimator.batch_axis] + loss = [loss] + if isinstance(loss, list) and len(loss) > 0: + for l in loss: + batch_size += l.shape[estimator.batch_axis] + estimator.trainer.step(batch_size) From 5c55bb4d2cfa0f35da521a496692f6e11c5a7455 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 26 Nov 2019 09:29:17 +0000 Subject: [PATCH 05/10] Remove white space to avoid the sanity check failure --- python/mxnet/gluon/contrib/estimator/event_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index d9a559b87277..9ddfe4e765f6 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -737,5 +737,5 @@ def batch_end(self, estimator, *args, **kwargs): if isinstance(loss, list) and len(loss) > 0: for l in loss: batch_size += l.shape[estimator.batch_axis] - + estimator.trainer.step(batch_size) From 19142a0538e4e032ec41b33b5073ab387ae0a903 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 26 Nov 2019 09:44:42 +0000 Subject: [PATCH 06/10] add small tweak to the handler code --- python/mxnet/gluon/contrib/estimator/event_handler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 9ddfe4e765f6..f71360762b17 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -28,7 +28,6 @@ from ....metric import CompositeEvalMetric, EvalMetric from ....metric import Loss as metric_loss from .utils import _check_metrics -from .... 
import ndarray __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd', 'StoppingHandler', 'MetricHandler', 'ValidationHandler', @@ -732,7 +731,7 @@ def __init__(self, priority=-np.Inf): def batch_end(self, estimator, *args, **kwargs): loss = kwargs['loss'] batch_size = 0 - if isinstance(loss, ndarray.ndarray.NDArray): + if not isinstance(loss, list): loss = [loss] if isinstance(loss, list) and len(loss) > 0: for l in loss: From a1825e89b271bcb9bb3ebdc4cad7e4bd8348f1b3 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Wed, 27 Nov 2019 03:10:49 +0000 Subject: [PATCH 07/10] Modify the documentation of priority parameter of relevant handlers. --- .../mxnet/gluon/contrib/estimator/event_handler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index f71360762b17..65fcd28b44de 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -131,7 +131,8 @@ class MetricHandler(EpochBegin, BatchEnd): train_metrics : List of EvalMetrics Training metrics to be updated at batch end. priority : scalar - Priority level of the MetricHandler + Priority level of the MetricHandler. Priority level is sorted in ascending + order. The lower the number is, the higher priority level it is. """ def __init__(self, train_metrics, priority=-1000): @@ -179,7 +180,8 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): How often to run validation at batch end, by default :py:class:`ValidationHandler` does not validate at batch end. priority: scalar, default -1000 - Priority level of the ValidataionHandler + Priority level of the ValidationHandler. Priority level is sorted in + ascending order. The lower the number is, the higher priority level it is. """ def __init__(self, @@ -241,7 +243,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat val_metrics : list of EvalMetrics Validation metrics to be logged, logged at epoch end, train end. priority : scalar, default np.Inf - Priority level of the LoggingHandler + Priority level of the LoggingHandler. Priority level is sorted in + ascending order. The lower the number is, the higher priority level it is. """ def __init__(self, log_interval='epoch', @@ -722,10 +725,11 @@ class GradientUpdateHandler(BatchEnd): Parameters ---------- priority : scalar, default -np.Inf - priority level of the gradient update handler. It should be executed before all other handlers. + priority level of the gradient update handler. Priority level is sorted in ascending + order. The lower the number is, the higher priority level it is. ---------- """ - def __init__(self, priority=-np.Inf): + def __init__(self, priority=-2000): self.priority = priority def batch_end(self, estimator, *args, **kwargs): From 6c5bf21f658f5daf93ad5c6791520ec5fdfb658f Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Wed, 27 Nov 2019 03:13:55 +0000 Subject: [PATCH 08/10] small modification on the documentation. 
---
 python/mxnet/gluon/contrib/estimator/event_handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 65fcd28b44de..d82fd3fd5e57 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -724,7 +724,7 @@ class GradientUpdateHandler(BatchEnd):
 
     Parameters
     ----------
-    priority : scalar, default -np.Inf
+    priority : scalar, default -2000
         priority level of the gradient update handler. Priority level is sorted in ascending
         order. The lower the number is, the higher priority level it is.
     ----------

From 058c3a46c81f16af2864ae8d07a7fc0153bbadb6 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Wed, 27 Nov 2019 03:18:35 +0000
Subject: [PATCH 09/10] Add small modification on the documentation.

---
 python/mxnet/gluon/contrib/estimator/event_handler.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index d82fd3fd5e57..90c477623f16 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -132,7 +132,7 @@ class MetricHandler(EpochBegin, BatchEnd):
         Training metrics to be updated at batch end.
     priority : scalar
         Priority level of the MetricHandler. Priority level is sorted in ascending
-        order. The lower the number is, the higher priority level it is.
+        order. The lower the number is, the higher the handler's priority is.
     """
 
     def __init__(self, train_metrics, priority=-1000):
@@ -181,7 +181,8 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
         :py:class:`ValidationHandler` does not validate at batch end.
     priority: scalar, default -1000
         Priority level of the ValidationHandler. Priority level is sorted in
-        ascending order. The lower the number is, the higher priority level it is.
+        ascending order. The lower the number is, the higher the handler's
+        priority is.
     """
 
     def __init__(self,
@@ -244,7 +245,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         Validation metrics to be logged, logged at epoch end, train end.
     priority : scalar, default np.Inf
         Priority level of the LoggingHandler. Priority level is sorted in
-        ascending order. The lower the number is, the higher priority level it is.
+        ascending order. The lower the number is, the higher the handler's
+        priority is.
     """
 
     def __init__(self, log_interval='epoch',
@@ -726,7 +728,7 @@ class GradientUpdateHandler(BatchEnd):
     ----------
     priority : scalar, default -2000
         priority level of the gradient update handler. Priority level is sorted in ascending
-        order. The lower the number is, the higher priority level it is.
+        order. The lower the number is, the higher the handler's priority is.
---------- """ def __init__(self, priority=-2000): From db20abb53d4bc69b0b6ec211bed3d172ef8551ff Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Mon, 9 Dec 2019 03:02:50 +0000 Subject: [PATCH 10/10] Remove unnecessary list check --- python/mxnet/gluon/contrib/estimator/event_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 90c477623f16..64777608bef0 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -739,7 +739,7 @@ def batch_end(self, estimator, *args, **kwargs): batch_size = 0 if not isinstance(loss, list): loss = [loss] - if isinstance(loss, list) and len(loss) > 0: + if isinstance(loss, list): for l in loss: batch_size += l.shape[estimator.batch_axis]
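
Editorial note: taken together, the series moves the trainer.step call out of Estimator.fit_batch and into a default GradientUpdateHandler that runs first at each batch end (priority -2000, ahead of MetricHandler at -1000 and LoggingHandler at np.Inf), stepping by the number of samples counted along estimator.batch_axis in the loss arrays. The sketch below is a minimal, runnable illustration of the final batch_end logic from PATCH 10/10, not MXNet itself: MockTrainer and MockEstimator are hypothetical stand-ins for the two attributes the handler touches, and plain numpy arrays replace the per-device loss NDArrays since only .shape is read.

import numpy as np

class MockTrainer:
    # Hypothetical stand-in; the real handler calls mxnet.gluon.Trainer.step.
    def step(self, batch_size):
        print('trainer.step(%d)' % batch_size)

class MockEstimator:
    # Hypothetical stand-in; Estimator.fit stores batch_axis (PATCH 02/10)
    # and owns the trainer that fit_batch used to step directly.
    batch_axis = 0
    trainer = MockTrainer()

class GradientUpdateHandler:
    # batch_end as it stands after PATCH 10/10 (the BatchEnd base class is
    # omitted here to keep the sketch dependency-free).
    def __init__(self, priority=-2000):
        self.priority = priority

    def batch_end(self, estimator, *args, **kwargs):
        loss = kwargs['loss']
        batch_size = 0
        if not isinstance(loss, list):
            loss = [loss]
        if isinstance(loss, list):
            for l in loss:
                # accumulate sample counts across the per-device loss arrays
                batch_size += l.shape[estimator.batch_axis]
        estimator.trainer.step(batch_size)

# Two device slices of 16 samples each -> trainer.step(32), matching the
# old fit_batch behaviour of stepping once per batch by the full batch size.
handler = GradientUpdateHandler()
handler.batch_end(MockEstimator(), loss=[np.zeros((16, 1)), np.zeros((16, 1))])

Because _prepare_default_handlers appends the default instance only when no GradientUpdateHandler is already present in event_handlers, a user can override the update behaviour by passing a custom instance or subclass via est.fit(..., event_handlers=[...]), exactly as the PATCH 01/10 version of test_default_handlers did.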