From e522dc7550edb38648aca7e5fd66a636ba2473bf Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 12:11:18 -0700 Subject: [PATCH 1/7] added check for empty params file and unknown param (not arg/aux) --- python/mxnet/model.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index aee4a8ce2b45..f26bbb9b322d 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -38,7 +38,7 @@ from .optimizer import get_updater from .executor_manager import DataParallelExecutorManager, _check_arguments, _load_data from .io import DataDesc -from .base import mx_real_t +from .base import mx_real_t, MXNetError BASE_ESTIMATOR = object @@ -451,12 +451,18 @@ def load_checkpoint(prefix, epoch): save_dict = nd.load('%s-%04d.params' % (prefix, epoch)) arg_params = {} aux_params = {} - for k, v in save_dict.items(): - tp, name = k.split(':', 1) - if tp == 'arg': - arg_params[name] = v - if tp == 'aux': - aux_params[name] = v + #load any params in the dict, skip if params are empty + if(save_dict): + for k, v in save_dict.items(): + tp, name = k.split(':', 1) + if tp == 'arg': + arg_params[name] = v + elif tp == 'aux': + aux_params[name] = v + else: + raise MXNetError("Params file '%s' contains unknown param '%s'" % + ('%s-%04d.params' % (prefix, epoch), + k)) return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position From 31fa25614d81c954ddd30b9312f533e84430cf44 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 12:26:51 -0700 Subject: [PATCH 2/7] changed exception to warning for unknown params --- python/mxnet/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index f26bbb9b322d..aab5986545ac 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -460,7 +460,7 @@ def load_checkpoint(prefix, epoch): elif tp == 'aux': aux_params[name] = v else: - raise MXNetError("Params file '%s' contains unknown param '%s'" % + logging.warning("Params file '%s' contains unknown param '%s'" % ('%s-%04d.params' % (prefix, epoch), k)) return (symbol, arg_params, aux_params) From e2c6b056840493fa12a22a82c3d9285972de567e Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 12:28:30 -0700 Subject: [PATCH 3/7] removed unnecessary MXNetError import --- python/mxnet/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index aab5986545ac..965543180f9d 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -38,7 +38,7 @@ from .optimizer import get_updater from .executor_manager import DataParallelExecutorManager, _check_arguments, _load_data from .io import DataDesc -from .base import mx_real_t, MXNetError +from .base import mx_real_t BASE_ESTIMATOR = object From b4d9f7c5956823d38b0169c170beb8a98be1af0a Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 12:51:41 -0700 Subject: [PATCH 4/7] added warning message is params is empty --- python/mxnet/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 965543180f9d..ed5aa3dd04a0 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -452,7 +452,9 @@ def load_checkpoint(prefix, epoch): arg_params = {} aux_params = {} #load any params in the dict, skip if params are empty - if(save_dict): + if not save_dict: + logging.warning("Params file '%s' is empty") + else: for k, v in save_dict.items(): tp, name = k.split(':', 1) if tp == 'arg': From e11e5ceb8678105bdfe5a1080b79f7c5e4a2e180 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 12:53:54 -0700 Subject: [PATCH 5/7] fixed print --- python/mxnet/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index ed5aa3dd04a0..429e8d4c91c1 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -453,7 +453,7 @@ def load_checkpoint(prefix, epoch): aux_params = {} #load any params in the dict, skip if params are empty if not save_dict: - logging.warning("Params file '%s' is empty") + logging.warning("Params file '%s' is empty" % ('%s-%04d.params' % (prefix, epoch))) else: for k, v in save_dict.items(): tp, name = k.split(':', 1) From 62ca94e9f5856e07bafd45f2da17356cd1453c33 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 13:03:04 -0700 Subject: [PATCH 6/7] fixed formatting --- python/mxnet/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 429e8d4c91c1..32e43dd9cc53 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -453,7 +453,7 @@ def load_checkpoint(prefix, epoch): aux_params = {} #load any params in the dict, skip if params are empty if not save_dict: - logging.warning("Params file '%s' is empty" % ('%s-%04d.params' % (prefix, epoch))) + logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch)) else: for k, v in save_dict.items(): tp, name = k.split(':', 1) @@ -462,9 +462,8 @@ def load_checkpoint(prefix, epoch): elif tp == 'aux': aux_params[name] = v else: - logging.warning("Params file '%s' contains unknown param '%s'" % - ('%s-%04d.params' % (prefix, epoch), - k)) + logging.warning("Params file '%s' contains unknown param '%s'", + '%s-%04d.params' % (prefix, epoch), k)) return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position From 8df4c3d8466e3094bde0f7ff94593c476a24c7cf Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Thu, 15 Aug 2019 13:06:54 -0700 Subject: [PATCH 7/7] missing paren --- python/mxnet/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 32e43dd9cc53..1ff1ee04643f 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -463,7 +463,7 @@ def load_checkpoint(prefix, epoch): aux_params[name] = v else: logging.warning("Params file '%s' contains unknown param '%s'", - '%s-%04d.params' % (prefix, epoch), k)) + '%s-%04d.params' % (prefix, epoch), k) return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position