From 149589c58358088ad3232bd64774d7f1548fa7c0 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Thu, 19 May 2016 17:47:11 -0700
Subject: [PATCH] [PYTHON] Refactor training API to use callbacks

---
 Makefile                                    |   7 +-
 demo/guide-python/cross_validation.py       |  12 +-
 python-package/.pylintrc                    |   4 +-
 python-package/xgboost/callback.py          | 217 +++++++++
 python-package/xgboost/compat.py            |   4 +-
 python-package/xgboost/core.py              |  50 +-
 python-package/xgboost/rabit.py             |   5 +-
 python-package/xgboost/sklearn.py           |   7 +-
 python-package/xgboost/training.py          | 426 ++++++++----------
 rabit                                       |   2 +-
 tests/python/test_basic.py                  |  20 +-
 tests/python/test_early_stopping.py         |   2 +-
 tests/python/test_eval_metrics.py           |   2 +-
 tests/python/test_plotting.py               |   2 +-
 tests/python/test_training_continuation.py  |   2 +-
 tests/python/test_with_pandas.py            |   2 +-
 tests/python/test_with_sklearn.py           |   2 +-
 .../xgboost => tests/python}/testing.py     |   2 +-
 18 files changed, 491 insertions(+), 277 deletions(-)
 create mode 100644 python-package/xgboost/callback.py
 rename {python-package/xgboost => tests/python}/testing.py (87%)

diff --git a/Makefile b/Makefile
index e3f3134e4a9e..abe8ccfaab6c 100644
--- a/Makefile
+++ b/Makefile
@@ -73,7 +73,7 @@ endif
 
 # specify tensor path
 
-.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
+.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint
 
 all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
 
@@ -131,8 +131,11 @@ rcpplint:
	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src
 
 lint: rcpplint
-	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
+	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package
 
+pylint:
+	flake8 --ignore E501 python-package
+	flake8 --ignore E501 tests/python
 
 clean:
	$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost

diff --git a/demo/guide-python/cross_validation.py b/demo/guide-python/cross_validation.py
index 6ca13d46035d..5c8ee0b1b54f 100755
--- a/demo/guide-python/cross_validation.py
+++ b/demo/guide-python/cross_validation.py
@@ -12,15 +12,18 @@
 # [iteration]  metric_name:mean_value+std_value
 # std_value is standard deviation of the metric
 xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0)
+       metrics={'error'}, seed = 0,
+       callbacks=[xgb.callback.print_evaluation(show_stdv=True)])
 
 print ('running cross validation, disable standard deviation display')
 # do cross validation, this will print result out as
 # [iteration]  metric_name:mean_value+std_value
 # std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0, show_stdv = False)
-
+res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
+             metrics={'error'}, seed = 0,
+             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
+                        xgb.callback.early_stop(3)])
+print (res)
 print ('running cross validation, with preprocessing function')
 # define the preprocessing function
 # used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ def evalerror(preds, dtrain):
 # train with customized objective
 xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
        obj = logregobj, feval=evalerror)
-

diff --git a/python-package/.pylintrc b/python-package/.pylintrc
index 1e63cdabe703..e8e957d2b1eb 100644
--- a/python-package/.pylintrc
+++ b/python-package/.pylintrc
@@ -2,8 +2,8 @@
 ignore=tests
 
-unexpected-special-method-signature,too-many-nested-blocks
+disable=unexpected-special-method-signature,too-many-nested-blocks
 
 dummy-variables-rgx=(unused|)_.*
 
-reports=no
\ No newline at end of file
+reports=no
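For reference, the callback list shown in the demo above drives xgb.train the same way it drives xgb.cv. A minimal sketch of the new-style API; the data paths and parameter values are assumptions borrowed from the demo, not something this patch ships:

    import xgboost as xgb

    dtrain = xgb.DMatrix('../data/agaricus.txt.train')  # assumed demo path
    dtest = xgb.DMatrix('../data/agaricus.txt.test')    # assumed demo path
    param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}

    history = {}
    bst = xgb.train(param, dtrain, num_boost_round=10,
                    evals=[(dtest, 'eval'), (dtrain, 'train')],
                    callbacks=[xgb.callback.print_evaluation(period=2),
                               xgb.callback.early_stop(3),
                               xgb.callback.record_evaluation(history)])
    print(history['eval']['error'])  # per-round metric values recorded by the callback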
diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
new file mode 100644
index 000000000000..3683ea2dd61f
--- /dev/null
+++ b/python-package/xgboost/callback.py
@@ -0,0 +1,217 @@
+# coding: utf-8
+# pylint: disable= invalid-name
+"""Training Library containing training routines."""
+from __future__ import absolute_import
+
+from . import rabit
+from .core import EarlyStopException
+
+
+def _fmt_metric(value, show_stdv=True):
+    """format metric string"""
+    if len(value) == 2:
+        return '%s:%g' % (value[0], value[1])
+    elif len(value) == 3:
+        if show_stdv:
+            return '%s:%g+%g' % (value[0], value[1], value[2])
+        else:
+            return '%s:%g' % (value[0], value[1])
+    else:
+        raise ValueError("wrong metric value")
+
+
+def print_evaluation(period=1, show_stdv=True):
+    """Create a callback that prints the evaluation results.
+
+    Parameters
+    ----------
+    period : int
+        The period, in iterations, at which to log the evaluation results.
+
+    show_stdv : bool, optional
+        Whether to show the standard deviation when it is provided.
+
+    Returns
+    -------
+    callback : function
+        A callback that prints the evaluation results every ``period`` iterations.
+    """
+    def callback(env):
+        """internal function"""
+        if env.rank != 0 or len(env.evaluation_result_list) == 0:
+            return
+        i = env.iteration
+        if i % period == 0 or i + 1 == env.end_iteration:  # also log the final iteration
+            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
+            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
+    return callback
+
+
+def record_evaluation(eval_result):
+    """Create a callback that records the evaluation history into eval_result.
+
+    Parameters
+    ----------
+    eval_result : dict
+        A dictionary to store the evaluation results.
+
+    Returns
+    -------
+    callback : function
+        The requested callback function.
+    """
+    if not isinstance(eval_result, dict):
+        raise TypeError('eval_result has to be a dictionary')
+    eval_result.clear()
+
+    def init(env):
+        """internal function"""
+        for k, _ in env.evaluation_result_list:
+            key, metric = k.split('-')
+            if key not in eval_result:
+                eval_result[key] = {}
+            if metric not in eval_result[key]:
+                eval_result[key][metric] = []
+
+    def callback(env):
+        """internal function"""
+        if len(eval_result) == 0:
+            init(env)
+        for k, v in env.evaluation_result_list:
+            key, metric = k.split('-')
+            eval_result[key][metric].append(v)
+    return callback
+ """ + def callback(env): + """internal function""" + bst = env.model + i = env.iteration + if isinstance(learning_rates, list): + if len(learning_rates) != env.end_iteration: + raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.") + bst.set_param('learning_rate', learning_rates[i]) + else: + bst.set_param('learning_rate', learning_rates(i, env.end_iteration)) + callback.before_iteration = True + return callback + + +def early_stop(stopping_rounds, maximize=False, verbose=True): + """Create a callback that activates early stoppping. + + Validation error needs to decrease at least + every round(s) to continue training. + Requires at least one item in evals. + If there's more than one, will use the last. + Returns the model from the last iteration (not the best one). + If early stopping occurs, the model will have three additional fields: + bst.best_score, bst.best_iteration and bst.best_ntree_limit. + (Use bst.best_ntree_limit to get the correct value if num_parallel_tree + and/or num_class appears in the parameters) + + Parameters + ---------- + stopp_rounds : int + The stopping rounds before the trend occur. + + maximize : bool + Whether to maximize evaluation metric. + + verbose : optional, bool + Whether to print message about early stopping information. + + Returns + ------- + callback : function + The requested callback function. + """ + state = {} + + def init(env): + """internal function""" + bst = env.model + + if len(env.evaluation_result_list) == 0: + raise ValueError('For early stopping you need at least one set in evals.') + if len(env.evaluation_result_list) > 1 and verbose: + msg = ("Multiple eval metrics have been passed: " + "'{0}' will be used for early stopping.\n\n") + rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0])) + maximize_metrics = ('auc', 'map', 'ndcg') + maximize_score = maximize + metric = env.evaluation_result_list[-1][0] + if any(env.evaluation_result_list[-1][0].split('-')[1].startswith(x) + for x in maximize_metrics): + maximize_score = True + + if verbose and env.rank == 0: + msg = "Will train until {} hasn't improved in {} rounds.\n" + rabit.tracker_print(msg.format(metric, stopping_rounds)) + + state['maximize_score'] = maximize_score + state['best_iteration'] = 0 + if maximize_score: + state['best_score'] = float('-inf') + else: + state['best_score'] = float('inf') + + if bst is not None: + if bst.attr('best_score') is not None: + state['best_score'] = float(bst.attr('best_score')) + state['best_iteration'] = int(bst.attr('best_iteration')) + state['best_msg'] = bst.attr('best_msg') + else: + bst.set_attr(best_iteration=str(state['best_iteration'])) + bst.set_attr(best_score=str(state['best_score'])) + else: + assert env.cvfolds is not None + + def callback(env): + """internal function""" + score = env.evaluation_result_list[-1][1] + if len(state) == 0: + init(env) + best_score = state['best_score'] + best_iteration = state['best_iteration'] + maximize_score = state['maximize_score'] + if (maximize_score and score > best_score) or \ + (not maximize_score and score < best_score): + msg = '[%d]\t%s' % ( + env.iteration, + '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])) + state['best_msg'] = msg + state['best_score'] = score + state['best_iteration'] = env.iteration + # save the property to attributes, so they will occur in checkpoint. 
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 44707c539cea..8237b1249fef 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=unused-import, invalid-name, wrong-import-position
+# pylint: disable= invalid-name, unused-import
 """For compatibility"""
 
 from __future__ import absolute_import
@@ -14,12 +14,14 @@
     STRING_TYPES = str,
 
     def py_str(x):
+        """convert c string back to python string"""
         return x.decode('utf-8')
 else:
     # pylint: disable=invalid-name
     STRING_TYPES = basestring,
 
     def py_str(x):
+        """convert c string back to python string"""
         return x
 
 try:
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index f22ca7ef1dd5..e31f622cf0e2 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1,5 +1,6 @@
 # coding: utf-8
-# pylint: disable=too-many-arguments, too-many-branches
+# pylint: disable=too-many-arguments, too-many-branches, invalid-name
+# pylint: disable=too-many-branches, too-many-lines, W0141
 """Core XGBoost Library."""
 from __future__ import absolute_import
 
@@ -22,6 +23,31 @@ class XGBoostError(Exception):
     pass
 
 
+class EarlyStopException(Exception):
+    """Exception to signal early stopping.
+
+    Parameters
+    ----------
+    best_iteration : int
+        The best iteration before training stopped.
+    """
+    def __init__(self, best_iteration):
+        super(EarlyStopException, self).__init__()
+        self.best_iteration = best_iteration
+
+
+# Callback environment used by callbacks
+CallbackEnv = collections.namedtuple(
+    "XGBoostCallbackEnv",
+    ["model",
+     "cvfolds",
+     "iteration",
+     "begin_iteration",
+     "end_iteration",
+     "rank",
+     "evaluation_result_list"])
+
+
 def from_pystr_to_cstr(data):
     """Convert a list of Python str to C pointer
 
@@ -657,7 +683,7 @@ def __setstate__(self, state):
     def __copy__(self):
         return self.__deepcopy__(None)
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, _):
         return Booster(model_file=self.save_raw())
 
     def copy(self):
@@ -975,7 +1001,6 @@ def load_model(self, fname):
         _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
 
     def dump_model(self, fout, fmap='', with_stats=False):
-        # pylint: disable=consider-using-enumerate
         """
         Dump model into a text file.
 
@@ -1143,10 +1168,12 @@ def _validate_features(self, data):
             msg = 'feature_names mismatch: {0} {1}'
 
             if dat_missing:
-                msg += '\nexpected ' + ', '.join(str(s) for s in dat_missing) + ' in input data'
+                msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
+                        ' in input data')
 
             if my_missing:
-                msg += '\ntraining data did not have the following fields: ' + ', '.join(str(s) for s in my_missing)
+                msg += ('\ntraining data did not have the following fields: ' +
+                        ', '.join(str(s) for s in my_missing))
 
             raise ValueError(msg.format(self.feature_names, data.feature_names))
 
@@ -1161,23 +1188,25 @@ def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True)
             The name of feature map file.
         bin: int, default None
             The maximum number of bins.
- Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique. + Number of bins equals number of unique split values n_unique, + if bins == None or bins > n_unique. as_pandas : bool, default True Return pd.DataFrame when pandas is installed. If False or pandas is not installed, return numpy ndarray. Returns ------- - a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame. + a histogram of used splitting values for the specified feature + either as numpy array or pandas DataFrame. """ xgdump = self.get_dump(fmap=fmap) values = [] - regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature)) + regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature)) for i in range(len(xgdump)): m = re.findall(regexp, xgdump[i]) values.extend(map(float, m)) - n_unique = np.unique(values).shape[0] + n_unique = len(np.unique(values)) bins = max(min(n_unique, bins) if bins is not None else n_unique, 1) nph = np.histogram(values, bins=bins) @@ -1187,7 +1216,8 @@ def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True) if as_pandas and PANDAS_INSTALLED: return DataFrame(nph, columns=['SplitValue', 'Count']) elif as_pandas and not PANDAS_INSTALLED: - sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).") + sys.stderr.write( + "Returning histogram as ndarray (as_pandas == True, but pandas is not installed).") return nph else: return nph diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index af85b2dd043a..89b2a4ec651d 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -1,3 +1,6 @@ +# coding: utf-8 +# pylint: disable= invalid-name + """Distributed XGBoost Rabit related API.""" from __future__ import absolute_import import sys @@ -179,7 +182,7 @@ def allreduce(data, op, prepare_fun=None): else: func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - def pfunc(args): + def pfunc(_): """prepare function.""" prepare_fun(data) _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index cafbe073fa88..4e3251f818b9 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1,5 +1,5 @@ # coding: utf-8 -# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme +# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912 """Scikit-Learn Wrapper interface for XGBoost.""" from __future__ import absolute_import @@ -42,6 +42,7 @@ def _objective_decorator(func): ``dmatrix.get_label()`` """ def inner(preds, dmatrix): + """internal function""" labels = dmatrix.get_label() return func(labels, preds) return inner @@ -183,7 +184,7 @@ def get_xgb_params(self): def fit(self, X, y, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True): - # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init, redefined-variable-type + # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init """ Fit the gradient boosting model @@ -351,7 +352,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True): - # pylint: disable = attribute-defined-outside-init,arguments-differ, redefined-variable-type + # pylint: disable = attribute-defined-outside-init,arguments-differ """ Fit gradient boosting classifier 
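The training.py rewrite below is where the legacy keyword arguments get translated into callbacks, so the old and new call styles stay equivalent. Roughly, a sketch of that mapping (param, dtrain, watchlist, and history are placeholder names, not defined by this patch):

    # Legacy call:
    #   xgb.train(param, dtrain, 10, watchlist, early_stopping_rounds=3,
    #             learning_rates=[0.3] * 10, evals_result=history)
    # is rewired internally to the equivalent of:
    #   xgb.train(param, dtrain, 10, watchlist,
    #             callbacks=[xgb.callback.print_evaluation(),
    #                        xgb.callback.early_stop(3, maximize=False, verbose=True),
    #                        xgb.callback.reset_learning_rate([0.3] * 10),
    #                        xgb.callback.record_evaluation(history)])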
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index d21edd30d201..3da92ff51345 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -1,20 +1,122 @@
 # coding: utf-8
 # pylint: disable=too-many-locals, too-many-arguments, invalid-name
-# pylint: disable=too-many-branches
+# pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 from __future__ import absolute_import
-import sys
-import re
+
 import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import rabit
+from . import callback
+
+
+def _train_internal(params, dtrain,
+                    num_boost_round=10, evals=(),
+                    obj=None, feval=None,
+                    xgb_model=None, callbacks=None):
+    """internal training function"""
+    callbacks = [] if callbacks is None else callbacks
+    evals = list(evals)
+    if isinstance(params, dict) \
+            and 'eval_metric' in params \
+            and isinstance(params['eval_metric'], list):
+        params = dict((k, v) for k, v in params.items())
+        eval_metrics = params['eval_metric']
+        params.pop("eval_metric", None)
+        params = list(params.items())
+        for eval_metric in eval_metrics:
+            params += [('eval_metric', eval_metric)]
+
+    nboost = 0
+    num_parallel_tree = 1
+
+    if xgb_model is not None:
+        if not isinstance(xgb_model, STRING_TYPES):
+            xgb_model = xgb_model.save_raw()
+        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+        nboost = len(bst.get_dump())
+    else:
+        bst = Booster(params, [dtrain] + [d[0] for d in evals])
+
+    _params = dict(params) if isinstance(params, list) else params
+
+    if 'num_parallel_tree' in _params:
+        num_parallel_tree = _params['num_parallel_tree']
+        nboost //= num_parallel_tree
+    if 'num_class' in _params:
+        nboost //= _params['num_class']
+
+    # Distributed code: Load the checkpoint from rabit.
+    version = bst.load_rabit_checkpoint()
+    assert rabit.get_world_size() != 1 or version == 0
+    rank = rabit.get_rank()
+    start_iteration = int(version / 2)
+    nboost += start_iteration
+
+    callbacks_before_iter = [
+        cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
+    callbacks_after_iter = [
+        cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
+
+    for i in range(start_iteration, num_boost_round):
+        for cb in callbacks_before_iter:
+            cb(CallbackEnv(model=bst,
+                           cvfolds=None,
+                           iteration=i,
+                           begin_iteration=start_iteration,
+                           end_iteration=num_boost_round,
+                           rank=rank,
+                           evaluation_result_list=None))
+        # Distributed code: need to resume to this point.
+        # Skip the first update if it is a recovery step.
+        if version % 2 == 0:
+            bst.update(dtrain, i, obj)
+            bst.save_rabit_checkpoint()
+            version += 1
+
+        assert rabit.get_world_size() == 1 or version == rabit.version_number()
+
+        nboost += 1
+        evaluation_result_list = []
+        # check evaluation result.
+        if len(evals) != 0:
+            bst_eval_set = bst.eval_set(evals, i, feval)
+            if isinstance(bst_eval_set, STRING_TYPES):
+                msg = bst_eval_set
+            else:
+                msg = bst_eval_set.decode()
+            res = [x.split(':') for x in msg.split()]
+            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
+        try:
+            for cb in callbacks_after_iter:
+                cb(CallbackEnv(model=bst,
+                               cvfolds=None,
+                               iteration=i,
+                               begin_iteration=start_iteration,
+                               end_iteration=num_boost_round,
+                               rank=rank,
+                               evaluation_result_list=evaluation_result_list))
+        except EarlyStopException:
+            break
+        # do checkpoint after evaluation, in case evaluation also updates booster.
+        bst.save_rabit_checkpoint()
+        version += 1
+
+    if bst.attr('best_score') is not None:
+        bst.best_score = float(bst.attr('best_score'))
+        bst.best_iteration = int(bst.attr('best_iteration'))
+    else:
+        bst.best_iteration = nboost - 1
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+    return bst
 
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
           maximize=False, early_stopping_rounds=None, evals_result=None,
-          verbose_eval=True, learning_rates=None, xgb_model=None):
+          verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
     # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
     """Train a booster with given parameters.
 
@@ -70,176 +172,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     xgb_model : file name of stored xgb model or 'Booster' instance
         Xgb model to be loaded before training (allows training continuation).
 
+    callbacks : list of callback functions
+        List of callback functions that are applied at the end of each iteration.
+
     Returns
     -------
     booster : a trained booster model
     """
-    evals = list(evals)
-    if isinstance(params, dict) \
-            and 'eval_metric' in params \
-            and isinstance(params['eval_metric'], list):
-        params = dict((k, v) for k, v in params.items())
-        eval_metrics = params['eval_metric']
-        params.pop("eval_metric", None)
-        params = list(params.items())
-        for eval_metric in eval_metrics:
-            params += [('eval_metric', eval_metric)]
-
-    bst = Booster(params, [dtrain] + [d[0] for d in evals])
-    nboost = 0
-    num_parallel_tree = 1
+    callbacks = [] if callbacks is None else callbacks
 
-    if isinstance(verbose_eval, bool):
-        verbose_eval_every_line = False
+    # Most of the legacy advanced options become callbacks
+    if isinstance(verbose_eval, bool) and verbose_eval:
+        callbacks.append(callback.print_evaluation())
     else:
         if isinstance(verbose_eval, int):
-            verbose_eval_every_line = verbose_eval
-            verbose_eval = True if verbose_eval_every_line > 0 else False
+            callbacks.append(callback.print_evaluation(verbose_eval))
 
-    if rabit.get_rank() != 0:
-        verbose_eval = False
-
-    if xgb_model is not None:
-        if not isinstance(xgb_model, STRING_TYPES):
-            xgb_model = xgb_model.save_raw()
-        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
-        nboost = len(bst.get_dump())
-    else:
-        bst = Booster(params, [dtrain] + [d[0] for d in evals])
-
-    _params = dict(params) if isinstance(params, list) else params
-    _eta_param_name = 'eta' if 'eta' in _params else 'learning_rate'
-    if 'num_parallel_tree' in _params:
-        num_parallel_tree = _params['num_parallel_tree']
-        nboost //= num_parallel_tree
-    if 'num_class' in _params:
-        nboost //= _params['num_class']
-
-    if evals_result is not None:
-        if not isinstance(evals_result, dict):
-            raise TypeError('evals_result has to be a dictionary')
-        else:
-            evals_name = [d[1] for d in evals]
-            evals_result.clear()
- evals_result.update(dict([(key, {}) for key in evals_name])) - - # early stopping if early_stopping_rounds is not None: - if len(evals) < 1: - raise ValueError('For early stopping you need at least one set in evals.') - - if verbose_eval: - rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format( - evals[-1][1], early_stopping_rounds)) - - # is params a list of tuples? are we using multiple eval metrics? - if isinstance(params, list): - if len(params) != len(dict(params).items()): - params = dict(params) - msg = ("Multiple eval metrics have been passed: " - "'{0}' will be used for early stopping.\n\n") - rabit.tracker_print(msg.format(params['eval_metric'])) - else: - params = dict(params) + callbacks.append(callback.early_stop(early_stopping_rounds, + maximize=maximize, + verbose=bool(verbose_eval))) + if learning_rates is not None: + callbacks.append(callback.reset_learning_rate(learning_rates)) - # either minimize loss or maximize AUC/MAP/NDCG - maximize_score = False - if 'eval_metric' in params: - maximize_metrics = ('auc', 'map', 'ndcg') - if any(params['eval_metric'].startswith(x) for x in maximize_metrics): - maximize_score = True - if feval is not None: - maximize_score = maximize - - if maximize_score: - bst.set_attr(best_score='0.0') - else: - bst.set_attr(best_score='inf') - bst.set_attr(best_iteration='0') - - if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round: - raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.") - - # Distributed code: Load the checkpoint from rabit. - version = bst.load_rabit_checkpoint() - assert(rabit.get_world_size() != 1 or version == 0) - start_iteration = int(version / 2) - nboost += start_iteration - - for i in range(start_iteration, num_boost_round): - if learning_rates is not None: - if isinstance(learning_rates, list): - bst.set_param(_eta_param_name, learning_rates[i]) - else: - bst.set_param(_eta_param_name, learning_rates(i, num_boost_round)) - - # Distributed code: need to resume to this point. - # Skip the first update if it is a recovery step. - if version % 2 == 0: - bst.update(dtrain, i, obj) - bst.save_rabit_checkpoint() - version += 1 - - assert(rabit.get_world_size() == 1 or version == rabit.version_number()) - - nboost += 1 - # check evaluation result. - if len(evals) != 0: - bst_eval_set = bst.eval_set(evals, i, feval) - - if isinstance(bst_eval_set, STRING_TYPES): - msg = bst_eval_set - else: - msg = bst_eval_set.decode() - - if verbose_eval: - if verbose_eval_every_line: - if i % verbose_eval_every_line == 0 or i == num_boost_round - 1: - rabit.tracker_print(msg + '\n') - else: - rabit.tracker_print(msg + '\n') - - if evals_result is not None: - res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg) - for key in evals_name: - evals_idx = evals_name.index(key) - res_per_eval = len(res) // len(evals_name) - for r in range(res_per_eval): - res_item = res[(evals_idx * res_per_eval) + r] - res_key = res_item[0] - res_val = res_item[1] - if res_key in evals_result[key]: - evals_result[key][res_key].append(res_val) - else: - evals_result[key][res_key] = [res_val] - - if early_stopping_rounds: - score = float(msg.rsplit(':', 1)[1]) - best_score = float(bst.attr('best_score')) - best_iteration = int(bst.attr('best_iteration')) - if (maximize_score and score > best_score) or \ - (not maximize_score and score < best_score): - # save the property to attributes, so they will occur in checkpoint. 
-                    bst.set_attr(best_score=str(score),
-                                 best_iteration=str(nboost - 1),
-                                 best_msg=msg)
-                elif i - best_iteration >= early_stopping_rounds:
-                    best_msg = bst.attr('best_msg')
-                    if verbose_eval:
-                        msg = "Stopping. Best iteration:\n{}\n\n"
-                        rabit.tracker_print(msg.format(best_msg))
-                    break
-        # do checkpoint after evaluation, in case evaluation also updates booster.
-        bst.save_rabit_checkpoint()
-        version += 1
+    if evals_result is not None:
+        callbacks.append(callback.record_evaluation(evals_result))
 
-    if early_stopping_rounds:
-        bst.best_score = float(bst.attr('best_score'))
-        bst.best_iteration = int(bst.attr('best_iteration'))
-    else:
-        bst.best_iteration = nboost - 1
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-    return bst
+    return _train_internal(params, dtrain,
+                           num_boost_round=num_boost_round,
+                           evals=evals,
+                           obj=obj, feval=feval,
+                           xgb_model=xgb_model, callbacks=callbacks)
 
 
 class CVPack(object):
@@ -294,7 +257,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
     return ret
 
 
-def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
+def aggcv(rlist):
     # pylint: disable=invalid-name
     """
    Aggregate cross-validation results.
 
@@ -315,50 +278,21 @@
             if k not in cvmap:
                 cvmap[k] = []
             cvmap[k].append(float(v))
-        msg = idx
-
-    if show_stdv:
-        fmt = '\tcv-{0}:{1}+{2}'
-    else:
-        fmt = '\tcv-{0}:{1}'
-
-    index = []
     results = []
-    for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
+    for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
         v = np.array(v)
         if not isinstance(msg, STRING_TYPES):
             msg = msg.decode()
         mean, std = np.mean(v), np.std(v)
-        msg += fmt.format(k, mean, std)
-
-        index.extend([k + '-mean', k + '-std'])
-        results.extend([mean, std])
-
-    if as_pandas:
-        try:
-            import pandas as pd
-            results = pd.Series(results, index=index)
-        except ImportError:
-            if verbose_eval is None:
-                verbose_eval = True
-    else:
-        # if verbose_eval is default (None),
-        # result will be np.ndarray as it can't hold column name
-        if verbose_eval is None:
-            verbose_eval = True
-
-    if (isinstance(verbose_eval, int) and verbose_eval > 0 and trial % verbose_eval == 0) or \
-       (isinstance(verbose_eval, bool) and verbose_eval):
-        sys.stderr.write(msg + '\n')
-        sys.stderr.flush()
-
+        results.extend([(k, mean, std)])
     return results
 
 
 def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
        metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
-       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0):
+       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
+       callbacks=None):
     # pylint: disable = invalid-name
     """Cross-validation with given parameters.
 
@@ -404,6 +338,8 @@
         Results are not affected, and always contains std.
     seed : int
         Seed used to generate the folds (passed to numpy.random.seed).
+    callbacks : list of callback functions
+        List of callback functions that are applied at the end of each iteration.
Returns ------- @@ -431,59 +367,63 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None params.pop("eval_metric", None) + results = {} + cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds) + + # setup callbacks + callbacks = [] if callbacks is None else callbacks if early_stopping_rounds is not None: + callbacks.append(callback.early_stop(early_stopping_rounds, + maximize=maximize, + verbose=False)) + if isinstance(verbose_eval, bool) and verbose_eval: + callbacks.append(callback.print_evaluation(show_stdv=show_stdv)) + else: + if isinstance(verbose_eval, int): + callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv)) - if len(metrics) > 1: - msg = ('Check your params. ' - 'Early stopping works with single eval metric only.') - raise ValueError(msg) - if verbose_eval: - msg = "Will train until cv error hasn't decreased in {} rounds.\n" - sys.stderr.write(msg.format(early_stopping_rounds)) - - maximize_score = False - if len(metrics) == 1: - maximize_metrics = ('auc', 'map', 'ndcg') - if any(metrics[0].startswith(x) for x in maximize_metrics): - maximize_score = True - if feval is not None: - maximize_score = maximize - - if maximize_score: - best_score = 0.0 - else: - best_score = float('inf') + callbacks_before_iter = [ + cb for cb in callbacks if cb.__dict__.get('before_iteration', False)] + callbacks_after_iter = [ + cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)] - best_score_i = 0 - results = [] - cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds) for i in range(num_boost_round): + for cb in callbacks_before_iter: + cb(CallbackEnv(model=None, + cvfolds=cvfolds, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + rank=0, + evaluation_result_list=None)) for fold in cvfolds: fold.update(i, obj) - res = aggcv([f.eval(i, feval) for f in cvfolds], - show_stdv=show_stdv, verbose_eval=verbose_eval, - as_pandas=as_pandas, trial=i) - results.append(res) - - if early_stopping_rounds is not None: - score = res[0] - if (maximize_score and score > best_score) or \ - (not maximize_score and score < best_score): - best_score = score - best_score_i = i - elif i - best_score_i >= early_stopping_rounds: - results = results[:best_score_i + 1] - if verbose_eval: - msg = "Stopping. 
Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n" - sys.stderr.write(msg.format(best_score_i, results[-1][0], results[-1][1])) - break + res = aggcv([f.eval(i, feval) for f in cvfolds]) + + for key, mean, std in res: + if key + '-mean' not in results: + results[key + '-mean'] = [] + if key + '-std' not in results: + results[key + '-std'] = [] + results[key + '-mean'].append(mean) + results[key + '-std'].append(std) + try: + for cb in callbacks_after_iter: + cb(CallbackEnv(model=None, + cvfolds=cvfolds, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + rank=0, + evaluation_result_list=res)) + except EarlyStopException as e: + for k in results.keys(): + results[k] = results[k][:(e.best_iteration + 1)] + break if as_pandas: try: import pandas as pd - results = pd.DataFrame(results) + results = pd.DataFrame.from_dict(results) except ImportError: - results = np.array(results) - else: - results = np.array(results) - + pass return results diff --git a/rabit b/rabit index e19fced5cbd4..8f61535b83e6 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit e19fced5cbd4e41b10099facae7caa5cd3e6ada3 +Subproject commit 8f61535b83e650331459d7f33a1615fa7d27b7bd diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 38646209161a..710de987d35d 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -35,6 +35,22 @@ def test_basic(self): # assert they are the same assert np.sum(np.abs(preds2 - preds)) == 0 + def test_record_results(self): + dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') + dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') + param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'} + # specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 2 + result = {} + res2 = {} + xgb.train(param, dtrain, num_round, watchlist, + callbacks=[xgb.callback.record_evaluation(result)]) + xgb.train(param, dtrain, num_round, watchlist, + evals_result=res2) + assert result['train']['error'][0] < 0.1 + assert res2 == result + def test_multiclass(self): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') @@ -189,5 +205,5 @@ def test_cv(self): # return np.ndarray cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False) - assert isinstance(cv, np.ndarray) - assert cv.shape == (10, 4) + assert isinstance(cv, dict) + assert len(cv) == (4) diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py index b015547a1ae6..67e725b74c1d 100644 --- a/tests/python/test_early_stopping.py +++ b/tests/python/test_early_stopping.py @@ -1,5 +1,5 @@ import xgboost as xgb -import xgboost.testing as tm +import testing as tm import numpy as np import unittest diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 2391bfe28fd5..529ef698c4a8 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -1,5 +1,5 @@ import xgboost as xgb -import xgboost.testing as tm +import testing as tm import numpy as np import unittest diff --git a/tests/python/test_plotting.py b/tests/python/test_plotting.py index 7a70bd95e8be..fde98dcca5ca 100644 --- a/tests/python/test_plotting.py +++ b/tests/python/test_plotting.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np import xgboost as xgb -import xgboost.testing as tm +import testing as tm import unittest try: diff --git a/tests/python/test_training_continuation.py 
b/tests/python/test_training_continuation.py index 2cb93f9ac3f8..f7511f68524d 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -1,5 +1,5 @@ import xgboost as xgb -import xgboost.testing as tm +import testing as tm import numpy as np import unittest diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 9536c1e82693..0bef20ec2f3e 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np import xgboost as xgb -import xgboost.testing as tm +import testing as tm import unittest try: diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 72ae27948d1a..d079d99fe2df 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1,7 +1,7 @@ import numpy as np import random import xgboost as xgb -import xgboost.testing as tm +import testing as tm rng = np.random.RandomState(1994) diff --git a/python-package/xgboost/testing.py b/tests/python/testing.py similarity index 87% rename from python-package/xgboost/testing.py rename to tests/python/testing.py index 647a89fef0ef..fb368dedd62c 100644 --- a/python-package/xgboost/testing.py +++ b/tests/python/testing.py @@ -17,6 +17,6 @@ def _skip_if_no_pandas(): def _skip_if_no_matplotlib(): try: - import matplotlib.pyplot as plt # noqa + import matplotlib.pyplot as _ # noqa except ImportError: raise nose.SkipTest()
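As the updated test_basic.py above checks, cv() now accumulates its history in a dict keyed '<set>-<metric>-mean' and '<set>-<metric>-std', converting to a pandas DataFrame only when as_pandas=True and pandas is available. A minimal sketch of inspecting the new return value, assuming the demo's agaricus data and parameters:

    import xgboost as xgb

    dtrain = xgb.DMatrix('../data/agaricus.txt.train')  # assumed demo path
    param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}

    res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
                 metrics={'error'}, seed=0, as_pandas=False)
    # One metric yields four keys, as asserted in test_basic.py:
    # 'train-error-mean', 'train-error-std', 'test-error-mean', 'test-error-std'
    for k in sorted(res):
        print(k, res[k][-1])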