Merge pull request #1211 from tqchen/master
[PYTHON] Refactor training API to use callback
Showing 18 changed files with 491 additions and 277 deletions.
@@ -0,0 +1,217 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import

from . import rabit
from .core import EarlyStopException


def _fmt_metric(value, show_stdv=True):
    """format metric string"""
    if len(value) == 2:
        return '%s:%g' % (value[0], value[1])
    elif len(value) == 3:
        if show_stdv:
            return '%s:%g+%g' % (value[0], value[1], value[2])
        else:
            return '%s:%g' % (value[0], value[1])
    else:
        raise ValueError("wrong metric value")

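# Illustrative behavior of the helper above (examples, not part of the module):
#   _fmt_metric(('train-error', 0.0238))          -> 'train-error:0.0238'
#   _fmt_metric(('test-auc', 0.98, 0.01))         -> 'test-auc:0.98+0.01'
#   _fmt_metric(('test-auc', 0.98, 0.01), False)  -> 'test-auc:0.98'
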
def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int
        The period, in iterations, at which to log the evaluation results.
    show_stdv : bool, optional
        Whether to show the standard deviation if provided.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every `period` iterations.
    """
    def callback(env):
        """internal function"""
        if env.rank != 0 or len(env.evaluation_result_list) == 0:
            return
        i = env.iteration
        if i % period == 0 or i + 1 == env.begin_iteration:
            msg = '\t'.join([_fmt_metric(x, show_stdv)
                             for x in env.evaluation_result_list])
            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
    return callback

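# A minimal usage sketch (illustrative, assuming the refactored train() from
# this PR, which accepts a `callbacks` list; `params`, `dtrain` and `dtest`
# are placeholders):
#
#   import xgboost as xgb
#   bst = xgb.train(params, dtrain, num_boost_round=100,
#                   evals=[(dtrain, 'train'), (dtest, 'test')],
#                   callbacks=[print_evaluation(period=10)])
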
def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into eval_result.

    Parameters
    ----------
    eval_result : dict
        A dictionary used to store the evaluation results.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def init(env):
        """internal function"""
        for k, _ in env.evaluation_result_list:
            key, metric = k.split('-')
            if key not in eval_result:
                eval_result[key] = {}
            if metric not in eval_result[key]:
                eval_result[key][metric] = []

    def callback(env):
        """internal function"""
        if len(eval_result) == 0:
            init(env)
        for k, v in env.evaluation_result_list:
            key, metric = k.split('-')
            eval_result[key][metric].append(v)
    return callback

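# A minimal usage sketch (illustrative): after training, the dict passed in
# holds the full history, keyed by data name and then by metric name:
#
#   history = {}
#   bst = xgb.train(params, dtrain, num_boost_round=10,
#                   evals=[(dtest, 'test')],
#                   callbacks=[record_evaluation(history)])
#   # history is now of the form {'test': {'error': [0.22, 0.20, ...]}}
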
def reset_learning_rate(learning_rates):
    """Reset the learning rate after each iteration.

    NOTE: the initial learning rate will still take effect on the first iteration.

    Parameters
    ----------
    learning_rates : list or function
        A list with one learning rate per boosting round, or a customized
        function that computes eta from the current round number and the
        total number of boosting rounds (e.g. to implement learning rate decay):
        - list l: eta = l[boosting round]
        - function f: eta = f(boosting round, num_boost_round)

    Returns
    -------
    callback : function
        The requested callback function.
    """
    def callback(env):
        """internal function"""
        bst = env.model
        i = env.iteration
        if isinstance(learning_rates, list):
            if len(learning_rates) != env.end_iteration:
                raise ValueError("Length of list 'learning_rates' has to "
                                 "equal 'num_boost_round'.")
            bst.set_param('learning_rate', learning_rates[i])
        else:
            bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
    callback.before_iteration = True
    return callback

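# A minimal usage sketch (illustrative): pass either one eta per round or a
# function of (current round, total rounds), e.g. an exponential decay:
#
#   callbacks=[reset_learning_rate([0.3, 0.2, 0.1])]               # list form
#   callbacks=[reset_learning_rate(lambda i, n: 0.3 * 0.99 ** i)]  # function form
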
def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    The validation error needs to decrease at least once every
    <stopping_rounds> round(s) to continue training.

    Requires at least one item in evals. If there's more than one,
    the last one will be used for early stopping.

    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have three additional fields:
    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
    (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
    and/or num_class appears in the parameters.)

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training stops.
    maximize : bool
        Whether to maximize the evaluation metric.
    verbose : bool, optional
        Whether to print a message when early stopping occurs.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    state = {}

    def init(env):
        """internal function"""
        bst = env.model

        if len(env.evaluation_result_list) == 0:
            raise ValueError('For early stopping you need at least one set in evals.')
        if len(env.evaluation_result_list) > 1 and verbose:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
        maximize_metrics = ('auc', 'map', 'ndcg')
        maximize_score = maximize
        metric = env.evaluation_result_list[-1][0]
        if any(env.evaluation_result_list[-1][0].split('-')[1].startswith(x)
               for x in maximize_metrics):
            maximize_score = True

        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric, stopping_rounds))

        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        if maximize_score:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')

        if bst is not None:
            if bst.attr('best_score') is not None:
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        score = env.evaluation_result_list[-1][1]
        if len(state) == 0:
            init(env)
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the best values to model attributes, so they show up in checkpoints.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback
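
# A minimal usage sketch (illustrative): stop when the last metric in evals
# has not improved for 10 rounds; per the docstring, the returned model then
# carries best_score, best_iteration and best_ntree_limit:
#
#   bst = xgb.train(params, dtrain, num_boost_round=1000,
#                   evals=[(dtrain, 'train'), (dtest, 'test')],
#                   callbacks=[early_stop(stopping_rounds=10)])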