[PYTHON] Refactor training API to use callback #1211

Merged 1 commit on May 20, 2016
7 changes: 5 additions & 2 deletions Makefile
@@ -73,7 +73,7 @@ endif


# specify tensor path
-.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
+.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint


all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
@@ -131,8 +131,11 @@ rcpplint:
	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src

lint: rcpplint
-	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
+	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package

+pylint:
+	flake8 --ignore E501 python-package
+	flake8 --ignore E501 tests/python
clean:
	$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost

12 changes: 7 additions & 5 deletions demo/guide-python/cross_validation.py
@@ -12,15 +12,18 @@
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0)
+       metrics={'error'}, seed = 0,
+       callbacks=[xgb.callback.print_evaluation(show_stdv=True)])

print ('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0, show_stdv = False)

+res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
+             metrics={'error'}, seed = 0,
+             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
+                        xgb.callback.early_stop(3)])
+print (res)
print ('running cross validation, with preprocessing function')
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ def evalerror(preds, dtrain):
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
       obj = logregobj, feval=evalerror)

4 changes: 2 additions & 2 deletions python-package/.pylintrc
@@ -2,8 +2,8 @@

ignore=tests

-unexpected-special-method-signature,too-many-nested-blocks
+disable=unexpected-special-method-signature,too-many-nested-blocks

dummy-variables-rgx=(unused|)_.*

reports=no
217 changes: 217 additions & 0 deletions python-package/xgboost/callback.py
@@ -0,0 +1,217 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import

from . import rabit
from .core import EarlyStopException


def _fmt_metric(value, show_stdv=True):
    """format metric string"""
    if len(value) == 2:
        return '%s:%g' % (value[0], value[1])
    elif len(value) == 3:
        if show_stdv:
            return '%s:%g+%g' % (value[0], value[1], value[2])
        else:
            return '%s:%g' % (value[0], value[1])
    else:
        raise ValueError("wrong metric value")


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int
        The period, in iterations, at which to log the evaluation results.

    show_stdv : bool, optional
        Whether to show the standard deviation if provided.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every `period` iterations.
    """
    def callback(env):
        """internal function"""
        if env.rank != 0 or len(env.evaluation_result_list) == 0:
            return
        i = env.iteration
        if i % period == 0 or i + 1 == env.begin_iteration:
            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
    return callback
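
# ---- Editorial usage sketch, not part of this diff. ----
# Assumes `params` is a parameter dict and `dtrain`/`dtest` are existing
# DMatrix objects; the helper name is hypothetical.
def _example_print_evaluation(params, dtrain, dtest):
    import xgboost as xgb
    # log evaluation results every 5 rounds instead of every round
    return xgb.train(params, dtrain, num_boost_round=20,
                     evals=[(dtest, 'test')],
                     callbacks=[xgb.callback.print_evaluation(period=5)])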


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into eval_result.

    Parameters
    ----------
    eval_result : dict
        A dictionary used to store the evaluation results.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def init(env):
        """internal function"""
        for k, _ in env.evaluation_result_list:
            key, metric = k.split('-')
            if key not in eval_result:
                eval_result[key] = {}
            if metric not in eval_result[key]:
                eval_result[key][metric] = []

    def callback(env):
        """internal function"""
        if len(eval_result) == 0:
            init(env)
        for k, v in env.evaluation_result_list:
            key, metric = k.split('-')
            eval_result[key][metric].append(v)
    return callback
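
# ---- Editorial usage sketch, not part of this diff. ----
# Assumes `params`, `dtrain`, and `dtest` exist as above; the helper
# name is hypothetical.
def _example_record_evaluation(params, dtrain, dtest):
    import xgboost as xgb
    history = {}
    xgb.train(params, dtrain, num_boost_round=20,
              evals=[(dtrain, 'train'), (dtest, 'eval')],
              callbacks=[xgb.callback.record_evaluation(history)])
    # history is keyed by data name, then metric name, e.g.
    # {'train': {'error': [0.25, 0.23, ...]}, 'eval': {'error': [...]}}
    return history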


def reset_learning_rate(learning_rates):
    """Reset the learning rate before each boosting round.

    NOTE: the initial learning rate still takes effect on the first iteration.

    Parameters
    ----------
    learning_rates : list or function
        A list giving the learning rate for each boosting round, or a
        custom function that computes eta from the current round number
        and the total number of boosting rounds (e.g. to implement
        learning rate decay):
        - list l: eta = l[boosting_round]
        - function f: eta = f(boosting_round, num_boost_round)

    Returns
    -------
    callback : function
        The requested callback function.
    """
    def callback(env):
        """internal function"""
        bst = env.model
        i = env.iteration
        if isinstance(learning_rates, list):
            if len(learning_rates) != env.end_iteration:
                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
            bst.set_param('learning_rate', learning_rates[i])
        else:
            bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
    callback.before_iteration = True
    return callback
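
# ---- Editorial usage sketch, not part of this diff. ----
# Assumes `params` and `dtrain` exist; the decay schedule is an arbitrary
# example and the helper name is hypothetical. An equal-length list such
# as [0.3] * 10 + [0.1] * 10 would also work in place of the function.
def _example_reset_learning_rate(params, dtrain):
    import xgboost as xgb

    def eta_decay(current_round, num_boost_round):
        """exponential decay from 0.3; num_boost_round is unused here"""
        return 0.3 * (0.995 ** current_round)

    return xgb.train(params, dtrain, num_boost_round=20,
                     callbacks=[xgb.callback.reset_learning_rate(eta_decay)])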


def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    The validation metric needs to improve at least once in
    every <stopping_rounds> round(s) to continue training.
    Requires at least one item in evals.
    If there's more than one, will use the last.
    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have three additional fields:
    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
    (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
    and/or num_class appears in the parameters.)

    Parameters
    ----------
    stopping_rounds : int
        Training stops if the validation metric fails to improve
        within this many rounds.

    maximize : bool
        Whether to maximize the evaluation metric.

    verbose : bool, optional
        Whether to print messages about early stopping.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    state = {}

    def init(env):
        """internal function"""
        bst = env.model

        if len(env.evaluation_result_list) == 0:
            raise ValueError('For early stopping you need at least one set in evals.')
        if len(env.evaluation_result_list) > 1 and verbose:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
        maximize_metrics = ('auc', 'map', 'ndcg')
        maximize_score = maximize
        metric = env.evaluation_result_list[-1][0]
        if any(metric.split('-')[1].startswith(x)
               for x in maximize_metrics):
            maximize_score = True

        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric, stopping_rounds))

        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        if maximize_score:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        score = env.evaluation_result_list[-1][1]
        if len(state) == 0:
            init(env)
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback
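
A minimal usage sketch for the new early_stop callback (editorial, not part of this diff; assumes `params`, `dtrain`, and `dtest` already exist):

import xgboost as xgb

# The last entry in evals ('eval' here) drives the stopping decision.
bst = xgb.train(params, dtrain, num_boost_round=1000,
                evals=[(dtrain, 'train'), (dtest, 'eval')],
                callbacks=[xgb.callback.early_stop(stopping_rounds=10)])
# With early stopping, the best_* attributes are saved on the booster:
print(bst.attr('best_iteration'), bst.attr('best_score'))
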
4 changes: 3 additions & 1 deletion python-package/xgboost/compat.py
@@ -1,5 +1,5 @@
# coding: utf-8
-# pylint: disable=unused-import, invalid-name, wrong-import-position
+# pylint: disable= invalid-name, unused-import
"""For compatibility"""

from __future__ import absolute_import
@@ -14,12 +14,14 @@
    STRING_TYPES = str,

    def py_str(x):
+        """convert c string back to python string"""
        return x.decode('utf-8')
else:
    # pylint: disable=invalid-name
    STRING_TYPES = basestring,

    def py_str(x):
+        """convert c string back to python string"""
        return x

try: