Skip to content

Commit

Permalink
[XGBoost,MetaSchedule] Support setting the XGBoost tree method (apache#15133)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentccc authored Jul 28, 2023
1 parent 64ac43a commit d1f7ef4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
9 changes: 6 additions & 3 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def create(
if kind == "xgb":
return XGBModel(*args, **kwargs) # type: ignore

if "num_tuning_cores" in kwargs:
# num_tuning_cores is only relevant for XGBModel.
kwargs.pop("num_tuning_cores")
# params only relevant to XGBModel
_xgb_params = ["num_tuning_cores", "tree_method"]

for param in _xgb_params:
if param in kwargs:
kwargs.pop(param)

if kind == "random":
return RandomModel(*args, **kwargs) # type: ignore
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from itertools import chain as itertools_chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple

from typing_extensions import Literal

import numpy as np # type: ignore

from ...contrib.tar import tar, untar
Expand Down Expand Up @@ -202,6 +204,8 @@ def average_peak_score(
class XGBConfig(NamedTuple):
"""XGBoost model configuration
Reference: https://xgboost.readthedocs.io/en/stable/parameter.html
Parameters
----------
max_depth : int
Expand All @@ -217,6 +221,8 @@ class XGBConfig(NamedTuple):
nthread : Optional[int],
The number of threads to use.
Default is None, which means to use physical number of cores.
tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"]
The tree construction algorithm used in XGBoost.
"""

max_depth: int = 10
Expand All @@ -225,15 +231,19 @@ class XGBConfig(NamedTuple):
eta: float = 0.2
seed: int = 43
nthread: Optional[int] = None
tree_method: Literal["auto", "exact", "approx", "hist", "gpu_hist"] = "auto"

def to_dict(self):
    """Return the configuration as a plain ``dict`` of XGBoost parameters.

    Returns
    -------
    params : dict
        Mapping of every XGBConfig field name to its current value, in the
        form expected by ``xgb.train``-style parameter dicts.
    """
    field_names = (
        "max_depth",
        "gamma",
        "min_child_weight",
        "eta",
        "seed",
        "nthread",
        "tree_method",
    )
    return {name: getattr(self, name) for name in field_names}


Expand Down Expand Up @@ -334,6 +344,7 @@ def __init__(
average_peak_n: int = 32,
adaptive_training: bool = True,
num_tuning_cores: Optional[int] = None,
tree_method: Optional[Literal["auto", "exact", "approx", "hist", "gpu_hist"]] = None,
):
super().__init__()
if not isinstance(extractor, FeatureExtractor):
Expand All @@ -348,6 +359,9 @@ def __init__(
else:
config = config._replace(nthread=num_tuning_cores)

if tree_method is not None:
    # NamedTuple._replace returns a NEW tuple — it does not mutate in place.
    # The original code discarded the result, so a user-supplied tree_method
    # was silently ignored; assign it back (matching the nthread handling above).
    config = config._replace(tree_method=tree_method)
self.config = config
# behavior of randomness
self.num_warmup_samples = num_warmup_samples
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def tune_tasks(
elif not isinstance(database, Database):
database = Database.create(database, module_equality=module_equality)
if not isinstance(cost_model, CostModel):
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto")
if isinstance(measure_callbacks, MeasureCallback):
measure_callbacks = [measure_callbacks]
elif measure_callbacks == "default":
Expand Down

0 comments on commit d1f7ef4

Please sign in to comment.