Skip to content

Commit

Permalink
Make feature validation immutable. (#9388)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jul 15, 2023
1 parent 0a07900 commit b342ef9
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ def __init__(
)
for d in cache:
# Validate feature only after the feature names are saved into booster.
self._validate_dmatrix_features(d)
self._assign_dmatrix_features(d)

if isinstance(model_file, Booster):
assert self.handle is not None
Expand Down Expand Up @@ -1746,6 +1746,11 @@ def __setstate__(self, state: Dict) -> None:
self.__dict__.update(state)

def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
"""Get a slice of the tree-based model.
.. versionadded:: 1.3.0
"""
if isinstance(val, int):
val = slice(val, val + 1)
if isinstance(val, tuple):
Expand Down Expand Up @@ -1784,6 +1789,11 @@ def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
return sliced

def __iter__(self) -> Generator["Booster", None, None]:
"""Iterator method for getting individual trees.
.. versionadded:: 2.0.0
"""
for i in range(0, self.num_boosted_rounds()):
yield self[i]

Expand Down Expand Up @@ -1994,7 +2004,7 @@ def update(
"""
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain)
self._assign_dmatrix_features(dtrain)

if fobj is None:
_check_call(
Expand Down Expand Up @@ -2026,7 +2036,7 @@ def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain)
self._assign_dmatrix_features(dtrain)

_check_call(
_LIB.XGBoosterBoostOneIter(
Expand Down Expand Up @@ -2067,7 +2077,7 @@ def eval_set(
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_dmatrix_features(d[0])
self._assign_dmatrix_features(d[0])

dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
Expand Down Expand Up @@ -2119,7 +2129,7 @@ def eval(self, data: DMatrix, name: str = "eval", iteration: int = 0) -> str:
result: str
Evaluation result string.
"""
self._validate_dmatrix_features(data)
self._assign_dmatrix_features(data)
return self.eval_set([(data, name)], iteration)

# pylint: disable=too-many-function-args
Expand Down Expand Up @@ -2218,7 +2228,8 @@ def predict(
if not isinstance(data, DMatrix):
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
if validate_features:
self._validate_dmatrix_features(data)
fn = data.feature_names
self._validate_features(fn)
args = {
"type": 0,
"training": training,
Expand Down Expand Up @@ -2843,14 +2854,13 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = "") -> DataFrame:
# pylint: disable=no-member
return df.sort(["Tree", "Node"]).reset_index(drop=True)

def _validate_dmatrix_features(self, data: DMatrix) -> None:
def _assign_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0:
return

fn = data.feature_names
ft = data.feature_types
# Be consistent with versions before 1.7, "validate" actually modifies the
# booster.

if self.feature_names is None:
self.feature_names = fn
if self.feature_types is None:
Expand Down

0 comments on commit b342ef9

Please sign in to comment.