Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
YJ Shi committed Jul 27, 2022
1 parent bbb7f87 commit 2f42667
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 115 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
"""
from .cost_model import CostModel, PyCostModel
from .random_model import RandomModel
from .xgb_model import XGBModel
from .xgb_model import XGBModel, XGBoostCustomCallback, PackSum
124 changes: 11 additions & 113 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@
from ..utils import cpu_count, derived_object, shash2hex
from .metric import max_curve

try:
from xgboost.callback import TrainingCallback # type: ignore
except ImportError:

class TrainingCallback: # type: ignore
pass
if TYPE_CHECKING:
try:
from xgboost.callback import TrainingCallback # type: ignore
except ImportError:

class TrainingCallback: # type: ignore
pass

if TYPE_CHECKING:
import xgboost as xgb # type: ignore

from ..tune_context import TuneContext
Expand Down Expand Up @@ -674,114 +674,8 @@ def init(env: "xgb.core.CallbackEnv"):
booster.set_attr(best_iteration=str(state["best_iteration"]))
booster.set_attr(best_score=str(state["best_score"]))

def callback(env: "xgb.core.CallbackEnv"):
    """Evaluate every custom metric for the current round and apply early stopping.

    The function relies on names from the enclosing scope: ``state`` (mutable
    bookkeeping dict, filled by ``init`` on first use), ``fevals``, ``evals``,
    ``sort_key``, ``verbose_eval``, ``focused_metric`` and
    ``early_stopping_rounds``.

    Parameters
    ----------
    env : xgb.core.CallbackEnv
        The per-iteration environment; ``model``, ``iteration``, ``cvfolds``
        and ``rank`` are read from it.

    Raises
    ------
    EarlyStopException
        When the focused metric has not improved for at least
        ``early_stopping_rounds`` iterations.
    """
    # pylint:disable = import-outside-toplevel
    import xgboost as xgb
    from xgboost.callback import _fmt_metric  # type: ignore
    from xgboost.core import EarlyStopException  # type: ignore

    try:
        from xgboost.training import aggcv  # type: ignore
    except ImportError:
        from xgboost.callback import _aggcv as aggcv  # type: ignore
    # pylint:enable = import-outside-toplevel

    # Lazily populate the shared state the first time the callback fires.
    if not state:
        init(env)
    booster: xgb.Booster = env.model
    iteration: int = env.iteration
    cvfolds: List[xgb.training.CVPack] = env.cvfolds

    ##### Evaluation #####
    # `eval_result` is a list of (key, score), gathered one `feval` at a time.
    eval_result: List[Tuple[str, float]] = []
    if cvfolds is None:
        for feval in fevals:
            report = booster.eval_set(
                evals=evals,
                iteration=iteration,
                feval=feval,
            )
            # The first whitespace-separated token is skipped; the remaining
            # ones are split into "key:value" pairs.
            for entry in report.split()[1:]:
                metric_name, metric_value = entry.split(":")
                eval_result.append((metric_name, float(metric_value)))
    else:
        for feval in fevals:
            fold_reports = (
                fold.eval(
                    iteration=iteration,
                    feval=feval,
                )
                for fold in cvfolds
            )
            for metric_name, metric_value, _std in aggcv(fold_reports):
                eval_result.append((metric_name, metric_value))
    eval_result = sorted(eval_result, key=sort_key)

    ##### Print eval result #####
    if verbose_eval and iteration % verbose_eval == 0:
        info = [f"{key}: {value:.6f}" for key, value in eval_result if "null" not in key]
        logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))

    ##### Choose score and do early stopping #####
    # The first entry whose key equals the focused metric drives early stopping.
    score = next((value for key, value in eval_result if key == focused_metric), None)
    assert score is not None

    best_score = state["best_score"]
    best_iteration = state["best_iteration"]
    if score < best_score:
        formatted_metrics = "\t".join([_fmt_metric(x) for x in eval_result])
        state["best_msg"] = f"[{env.iteration}] {formatted_metrics}"
        state["best_score"] = score
        state["best_iteration"] = env.iteration
        # save the property to attributes, so they will occur in checkpoint.
        if env.model is not None:
            env.model.set_attr(
                best_score=str(state["best_score"]),
                best_iteration=str(state["best_iteration"]),
                best_msg=state["best_msg"],
            )
    elif env.iteration - best_iteration >= early_stopping_rounds:
        best_msg = state["best_msg"]
        if verbose_eval and env.rank == 0:
            logger.debug("XGB stopped. Best iteration: %s ", best_msg)
        raise EarlyStopException(best_iteration)

return callback


class XGBoostCallback(TrainingCallback):
    """Base class for XGBoost callbacks."""

    def __call__(self, env: "xgb.core.CallbackEnv"):
        """Forward a legacy ``env``-style invocation to ``after_iteration``."""
        # Compatibility with xgboost < 1.3
        model = env.model
        epoch = env.iteration
        evals_log = env.evaluation_result_list
        return self.after_iteration(model, epoch, evals_log)

    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
        """Hook run after each iteration; subclasses must override it."""
        raise NotImplementedError


class XGBoostCustomCallback(XGBoostCallback):
class XGBoostCustomCallback(TrainingCallback):
"""Custom callback class for xgboost to support multiple custom evaluation functions"""

def __init__(
Expand All @@ -804,6 +698,10 @@ def __init__(
if cvfolds is not None:
self.aggregated_cv = None

def __call__(self, env: "xgb.core.CallbackEnv"):
    """Forward a legacy ``env``-style invocation to ``after_iteration``."""
    # Compatibility with xgboost < 1.3
    model = env.model
    epoch = env.iteration
    evals_log = env.evaluation_result_list
    return self.after_iteration(model, epoch, evals_log)

def init(self, model: "xgb.Booster"):
"""Internal function for initialization"""
booster: "xgb.Booster" = model
Expand Down
92 changes: 91 additions & 1 deletion tests/python/unittest/test_meta_schedule_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
import pytest
import tvm
import tvm.testing
from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel
from tvm.meta_schedule.cost_model import (
PyCostModel,
RandomModel,
XGBModel,
XGBoostCustomCallback,
PackSum,
)
from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import MeasureCandidate
Expand Down Expand Up @@ -228,5 +234,89 @@ def test_meta_schedule_xgb_model_reupdate():
model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])


def test_meta_schedule_xgb_model_callback():
    """Retrain through ``XGBoostCustomCallback`` directly and compare predictions.

    The model's own booster and a booster trained here via ``xgb.train`` with
    the custom callback are evaluated on the same features (the extractor's
    random state is saved and restored in between) and must agree within
    tolerance.
    """
    import xgboost as xgb
    from itertools import chain as itertools_chain
    from functools import partial

    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=10)
    update_sample_count = 20
    predict_sample_count = 30

    def _extract_test_pack():
        # Pull a fresh batch of features from the extractor and pack it for prediction.
        features = [
            x.numpy().astype("float32")
            for x in extractor.extract_from(
                TuneContext(),
                [_dummy_candidate() for _ in range(predict_sample_count)],
            )
        ]
        return PackSum(xs=features, ys=None)

    model.update(
        TuneContext(),
        [_dummy_candidate() for _ in range(update_sample_count)],
        [_dummy_result() for _ in range(update_sample_count)],
    )
    model.predict(TuneContext(), [_dummy_candidate() for _ in range(predict_sample_count)])
    with tempfile.NamedTemporaryFile() as path:
        # Snapshot the extractor's random state and the model before predicting
        # with the booster trained by the model itself.
        random_state = model.extractor.random_state  # save feature extractor's random state

        model.save(path.name)

        old_booster = model.booster
        pred1 = old_booster.predict(_extract_test_pack().dmatrix)

        # Restore the snapshot, then retrain by invoking the callback directly.
        model.extractor.random_state = random_state  # load feature extractor's random state
        model.load(path.name)
        train_xs = list(itertools_chain.from_iterable([g.features for g in model.data.values()]))
        train_ys = np.concatenate(
            [g.min_cost / g.costs for g in model.data.values()],
            axis=0,
        )
        d_train = PackSum(xs=train_xs, ys=train_ys)

        def obj(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return d_train.obj_square_error(ys_pred)

        def rmse(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return d_train.rmse(ys_pred)

        def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return d_train.average_peak_score(ys_pred, model.average_peak_n)

        # NOTE(review): wrapping the callback in `partial` presumably hides its
        # TrainingCallback type so that xgboost drives it through the deprecated
        # `__call__(env)` path - confirm against the xgboost version in use.
        legacy_style_callback = partial(
            XGBoostCustomCallback(
                early_stopping_rounds=model.early_stopping_rounds,
                verbose_eval=model.verbose_eval,
                fevals=[rmse, avg_peak_score],
                evals=[(d_train.dmatrix, "tr")],
                cvfolds=None,
            )
        )
        new_booster = xgb.train(
            model.config.to_dict(),
            d_train.dmatrix,
            num_boost_round=10000,
            obj=obj,
            callbacks=[legacy_style_callback],
        )

        pred2 = new_booster.predict(_extract_test_pack().dmatrix)

    assert np.allclose(pred1, pred2, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 2f42667

Please sign in to comment.