# callbacks.py
"""
Demo for using and defining callback functions
==============================================
.. versionadded:: 1.3.0
"""
import argparse
import os
import tempfile
from typing import Dict
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import xgboost as xgb
class Plotting(xgb.callback.TrainingCallback):
    """Plot evaluation results live during training.

    Only for demonstration purposes, as redrawing with matplotlib after every
    boosting round is quite slow.

    Parameters
    ----------
    rounds :
        Total number of boosting rounds. It fixes the length of the x axis; the
        part of each curve that has not been trained yet is drawn as zeros.
    """

    def __init__(self, rounds: int) -> None:
        super().__init__()
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        # One line per "<data>-<metric>" key, created the first time it is seen.
        self.lines: Dict[str, plt.Line2D] = {}
        self.fig.show()
        # One x position per boosting round: 0, 1, ..., rounds - 1, so that the
        # value logged at round ``i`` is drawn exactly at x == i.
        self.x = np.arange(self.rounds)
        plt.ion()

    def _get_key(self, data: str, metric: str) -> str:
        """Return the legend label / lookup key for a data-metric pair."""
        return f"{data}-{metric}"

    def _expand(self, log: list) -> list:
        """Return ``log`` zero-padded (or truncated) to exactly ``rounds`` items.

        The y data must always have the same length as ``self.x``.
        """
        padded = list(log) + [0] * (self.rounds - len(log))
        return padded[: self.rounds]

    def after_iteration(
        self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
    ) -> bool:
        """Update the plot with the evaluation history collected so far.

        Parameters
        ----------
        model :
            The booster being trained (unused here).
        epoch :
            Index of the round that has just finished (unused here).
        evals_log :
            Mapping of data name -> metric name -> list of scores per round.

        Returns
        -------
        Always ``False``, which tells XGBoost that training should not stop.
        """
        added_line = False
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                key = self._get_key(data, metric_name)
                expanded = self._expand(log)
                line = self.lines.get(key)
                if line is None:
                    # First time this key shows up: create its line.
                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
                    added_line = True
                else:
                    # https://pythonspot.com/matplotlib-update-plot/
                    line.set_ydata(expanded)

        if added_line:
            self.ax.legend()

        # set_ydata() does not touch the axes limits; recompute them so scores
        # that leave the range of the first round are not clipped.
        self.ax.relim()
        self.ax.autoscale_view()
        self.fig.canvas.draw()
        # Let the GUI event loop process the redraw while training continues.
        self.fig.canvas.flush_events()
        # False to indicate training should not stop.
        return False
def custom_callback(num_boost_round: int = 100, device: str = "cuda") -> None:
    """Demo for defining a custom callback function that plots evaluation result
    during training.

    Parameters
    ----------
    num_boost_round :
        Number of boosting rounds to train for. The same value is handed to the
        plotting callback so the x axis covers the whole run.
    device :
        Value for XGBoost's ``device`` training parameter. Defaults to ``"cuda"``
        (the previous hard-coded value); pass ``"cpu"`` to run the demo on a
        machine without a GPU.
    """
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            "objective": "binary:logistic",
            "eval_metric": ["error", "rmse"],
            "tree_method": "hist",
            "device": device,
        },
        D_train,
        evals=[(D_train, "Train"), (D_valid, "Valid")],
        num_boost_round=num_boost_round,
        callbacks=[plotting],
    )
def check_point_callback() -> None:
    """Demonstrate the builtin checkpoint callback.

    Real applications usually need their own logic for handling the saved
    output, so users are encouraged to write a custom checkpointing callback;
    the builtin one serves as a starting point.
    """
    # Checkpoint interval. Only for demo: checkpointing is quite slow, so use a
    # larger value (like 100) in practice.
    rounds = 2

    def check(as_pickle: bool) -> None:
        """Assert that a checkpoint file exists for every expected round."""
        if as_pickle:
            suffix = "pkl"
        else:
            suffix = xgb.callback.TrainingCheckPoint.default_format
        # Checkpoints are expected at every multiple of `rounds` below 10,
        # excluding round 0.
        for i in range(rounds, 10, rounds):
            checkpoint_path = os.path.join(tmpdir, f"model_{i}.{suffix}")
            assert os.path.exists(checkpoint_path)

    def fit(callback: xgb.callback.TrainingCallback) -> None:
        """Run a short, silent training session with ``callback`` attached."""
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[callback],
        )

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)

    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        fit(
            xgb.callback.TrainingCheckPoint(
                directory=tmpdir, interval=rounds, name="model"
            )
        )
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model. See: doc/tutorials/saving_model.rst
        fit(
            xgb.callback.TrainingCheckPoint(
                directory=tmpdir, interval=rounds, as_pickle=True, name="model"
            )
        )
        check(True)
if __name__ == "__main__":
    # Script entry point: always run the checkpoint demo, and run the plotting
    # demo unless it is switched off with `--plot 0`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--plot", type=int, default=1)
    options = cli.parse_args()

    check_point_callback()

    if options.plot:
        custom_callback()