Skip to content

Commit

Permalink
Merge pull request #12 from PriorLabs/add_ignore_pretraining_limits_phe
Browse files Browse the repository at this point in the history
add: support for ignore_pretraining_limits to PHE
  • Loading branch information
LeoGrin authored Jan 15, 2025
2 parents 6102c53 + 5b22d88 commit c0125a0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/tabpfn_extensions/post_hoc_ensembles/pfn_phe.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
n_folds: int | None = None,
holdout_fraction: float = 0.33,
ges_n_iterations: int = 25,
ignore_pretraining_limits: bool = False,
) -> None:
"""Builds a PostHocEnsembleConfig with default values for the given parameters.
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
n_folds: The number of folds to use for cross-validation. Pass None in the holdout setting.
holdout_fraction: The fraction of the data to use for holdout validation.
ges_n_iterations: The number of iterations to use for the greedy ensemble search.
ignore_pretraining_limits: Whether to ignore the pretraining limits of the TabPFN models.
"""
# Task Type and User Input
self.preset = preset
Expand All @@ -148,6 +150,7 @@ def __init__(
self.device = device
self.bm_random_state = bm_random_state
self.ges_random_state = ges_random_state
self.ignore_pretraining_limits = ignore_pretraining_limits

# Model Source
self.tabpfn_base_model_source = tabpfn_base_model_source
Expand Down Expand Up @@ -353,6 +356,7 @@ def _collect_base_models(
device=self.device,
random_state=self.bm_random_state,
categorical_indices=categorical_feature_indices,
ignore_pretraining_limits=self.ignore_pretraining_limits,
)
else:
raise ValueError(
Expand Down Expand Up @@ -419,6 +423,7 @@ def _get_base_models_from_random_search(
categorical_indices: list[int],
random_portfolio_size: int = 100,
start_with_default_pfn: bool = True,
ignore_pretraining_limits: bool = False,
) -> list[tuple[str, object]]:
# TODO: switch to config space to not depend on hyperopt
from hyperopt.pyll import stochastic
Expand Down Expand Up @@ -460,6 +465,7 @@ def _get_base_models_from_random_search(
param["device"] = device
param["random_state"] = model_seed
param["categorical_features_indices"] = categorical_indices
param["ignore_pretraining_limits"] = ignore_pretraining_limits
n_ensemble_repeats = param.pop("n_ensemble_repeats", None)
model_is_rf_pfn = param.pop("model_type", "no") == "dt_pfn"
if model_is_rf_pfn:
Expand Down
10 changes: 10 additions & 0 deletions src/tabpfn_extensions/post_hoc_ensembles/sklearn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class AutoTabPFNClassifier(ClassifierMixin, BaseEstimator):
Controls both the randomness of the base models and the post hoc ensembling method.
categorical_feature_indices: list[int] or None, default=None
The indices of the categorical features in the input data. Can also be passed to `fit()`.
ignore_pretraining_limits: bool, default=False
Whether to ignore the pretraining limits of the TabPFN base models.
phe_init_args : dict | None, default=None
The initialization arguments for the post hoc ensemble predictor.
See post_hoc_ensembles.pfn_phe.AutoPostHocEnsemblePredictor for more options and all details.
Expand All @@ -62,6 +64,7 @@ def __init__(
device: Literal["cpu", "cuda"] = "cpu",
random_state: int | None | np.random.RandomState = None,
categorical_feature_indices: list[int] | None = None,
ignore_pretraining_limits: bool = False,
phe_init_args: dict | None = None,
):
self.max_time = max_time
Expand All @@ -71,6 +74,7 @@ def __init__(
self.random_state = random_state
self.phe_init_args = phe_init_args
self.categorical_feature_indices = categorical_feature_indices
self.ignore_pretraining_limits = ignore_pretraining_limits

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
Expand Down Expand Up @@ -101,6 +105,7 @@ def fit(self, X, y, categorical_feature_indices: list[int] | None = None):
device=self.device,
bm_random_state=rnd.randint(0, MAX_INT),
ges_random_state=rnd.randint(0, MAX_INT),
ignore_pretraining_limits=self.ignore_pretraining_limits,
**self.phe_init_args_,
)

Expand Down Expand Up @@ -146,6 +151,8 @@ class AutoTabPFNRegressor(RegressorMixin, BaseEstimator):
Controls both the randomness of the base models and the post hoc ensembling method.
categorical_feature_indices: list[int] or None, default=None
The indices of the categorical features in the input data. Can also be passed to `fit()`.
ignore_pretraining_limits: bool, default=False
Whether to ignore the pretraining limits of the TabPFN base models.
phe_init_args : dict | None, default=None
The initialization arguments for the post hoc ensemble predictor.
See post_hoc_ensembles.pfn_phe.AutoPostHocEnsemblePredictor for more options and all details.
Expand All @@ -170,6 +177,7 @@ def __init__(
device: Literal["cpu", "cuda"] = "cpu",
random_state: int | None | np.random.RandomState = None,
categorical_feature_indices: list[int] | None = None,
ignore_pretraining_limits: bool = False,
phe_init_args: dict | None = None,
):
self.max_time = max_time
Expand All @@ -179,6 +187,7 @@ def __init__(
self.random_state = random_state
self.phe_init_args = phe_init_args
self.categorical_feature_indices = categorical_feature_indices
self.ignore_pretraining_limits = ignore_pretraining_limits

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
Expand Down Expand Up @@ -206,6 +215,7 @@ def fit(self, X, y, categorical_feature_indices: list[int] | None = None):
device=self.device,
bm_random_state=rnd.randint(0, MAX_INT),
ges_random_state=rnd.randint(0, MAX_INT),
ignore_pretraining_limits=self.ignore_pretraining_limits,
**self.phe_init_args_,
)

Expand Down

0 comments on commit c0125a0

Please sign in to comment.