Skip to content

Commit

Permalink
[BUG] bugfix on #387 - changed paramset 3 to use `ConditionUncensored…
Browse files Browse the repository at this point in the history
…` instead of `CoxPH` (#388)

#### Reference Issues/PRs

Fixes #387 . Changed paramset3 to use `ConditionUncensored` instead of
`CoxPH` since it doesn't seem stable on smaller datasets.

Discussion thread on #291
  • Loading branch information
julian-fong authored Jun 14, 2024
1 parent e723655 commit 70252b5
Showing 1 changed file with 18 additions and 28 deletions.
46 changes: 18 additions & 28 deletions skpro/model_selection/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,7 @@ def get_test_params(cls, parameter_set="default"):

from skpro.metrics import CRPS, PinballLoss
from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
from skpro.utils.validation._dependencies import _check_estimator_deps
from skpro.survival.compose._reduce_cond_unc import ConditionUncensored

linreg1 = LinearRegression()
linreg2 = LinearRegression(fit_intercept=False)
Expand All @@ -510,18 +509,14 @@ def get_test_params(cls, parameter_set="default"):
"error_score": "raise",
}

params = [param1, param2]

# testing with survival predictor
if _check_estimator_deps(CoxPH, severity="none"):
param3 = {
"estimator": CoxPH(alpha=0.05),
"cv": KFold(n_splits=4),
"param_grid": {"method": ["lpl", "elastic_net"]},
"scoring": PinballLoss(),
"error_score": "raise",
}
params.append(param3)
params3 = {
"estimator": ConditionUncensored(ResidualDouble(LinearRegression())),
"cv": KFold(n_splits=4),
"param_grid": {"estimator__fit_intercept": [True, False]},
"scoring": PinballLoss(),
"error_score": "raise",
}
params = [param1, param2, params3]

return params

Expand Down Expand Up @@ -747,8 +742,7 @@ def get_test_params(cls, parameter_set="default"):

from skpro.metrics import CRPS, PinballLoss
from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
from skpro.utils.validation._dependencies import _check_estimator_deps
from skpro.survival.compose._reduce_cond_unc import ConditionUncensored

linreg1 = LinearRegression()
linreg2 = LinearRegression(fit_intercept=False)
Expand All @@ -769,17 +763,13 @@ def get_test_params(cls, parameter_set="default"):
"error_score": "raise",
}

params = [param1, param2]

# testing with survival predictor
if _check_estimator_deps(CoxPH, severity="none"):
param3 = {
"estimator": CoxPH(alpha=0.05),
"cv": KFold(n_splits=4),
"param_distributions": {"method": ["lpl", "elastic_net"]},
"scoring": PinballLoss(),
"error_score": "raise",
}
params += [param3]
params3 = {
"estimator": ConditionUncensored(ResidualDouble(LinearRegression())),
"cv": KFold(n_splits=4),
"param_distributions": {"estimator__fit_intercept": [True, False]},
"scoring": PinballLoss(),
"error_score": "raise",
}
params = [param1, param2, params3]

return params

0 comments on commit 70252b5

Please sign in to comment.