[Feature] change metric and set default metric (#5)
* deduplicated key

* change default metric

* change example name

* change optimizer and rektgbm

* change metric add params
RektPunk authored Aug 7, 2024
1 parent 760892f commit 1d03e14
Showing 7 changed files with 83 additions and 24 deletions.
File renamed without changes.
22 changes: 22 additions & 0 deletions rektgbm/addparams.py
@@ -0,0 +1,22 @@
from typing import Any, Dict

from rektgbm.base import MethodName
from rektgbm.objective import ObjectiveName


def set_additional_params(
    objective: ObjectiveName,
    method: MethodName,
    params: Dict[str, Any],
) -> Dict[str, Any]:
    if objective == ObjectiveName.quantile:
        if method == MethodName.lightgbm and "quantile_alpha" in params.keys():
            params["alpha"] = params.pop("quantile_alpha")
        elif method == MethodName.xgboost and "alpha" in params.keys():
            params["quantile_alpha"] = params.pop("alpha")
    elif objective == ObjectiveName.huber:
        if method == MethodName.lightgbm and "huber_slope" in params.keys():
params["alpha"] = params.pop("quantile_alpha")
        elif method == MethodName.xgboost and "alpha" in params.keys():
            params["huber_slope"] = params.pop("alpha")
    return params
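
A minimal usage sketch of set_additional_params, built only from the names visible in this diff; the parameter values are illustrative:

# Sketch: the helper renames engine-specific keys in the params dict (in place).
from rektgbm.addparams import set_additional_params
from rektgbm.base import MethodName
from rektgbm.objective import ObjectiveName

# LightGBM's quantile objective expects "alpha" ...
lgb_params = set_additional_params(
    objective=ObjectiveName.quantile,
    method=MethodName.lightgbm,
    params={"quantile_alpha": 0.9, "learning_rate": 0.05},
)
# lgb_params == {"learning_rate": 0.05, "alpha": 0.9}

# ... while XGBoost's quantile objective expects "quantile_alpha".
xgb_params = set_additional_params(
    objective=ObjectiveName.quantile,
    method=MethodName.xgboost,
    params={"alpha": 0.9, "learning_rate": 0.05},
)
# xgb_params == {"learning_rate": 0.05, "quantile_alpha": 0.9}

Note that the helper mutates and returns the same dict, so pass a copy if the original mapping should stay untouched.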
2 changes: 0 additions & 2 deletions rektgbm/base.py
@@ -36,8 +36,6 @@ def predict(self):
class MethodName(BaseEnum):
    lightgbm: str = "lightgbm"
    xgboost: str = "xgboost"
    lgb: str = "lightgbm"
    xgb: str = "xgboost"


XdataLike = Union[pd.DataFrame, pd.Series, np.ndarray]
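
For context on the removed lgb/xgb entries (the "deduplicated key" change): in Python's Enum, members that share a value become aliases of the first member rather than distinct members. An illustrative snippet (plain enum.Enum used here; rektgbm's BaseEnum is assumed to behave the same way):

from enum import Enum

class MethodNameOld(Enum):  # hypothetical; mirrors the pre-commit definition
    lightgbm: str = "lightgbm"
    xgboost: str = "xgboost"
    lgb: str = "lightgbm"   # alias of `lightgbm`, not a separate member
    xgb: str = "xgboost"    # alias of `xgboost`, not a separate member

assert MethodNameOld.lgb is MethodNameOld.lightgbm
assert list(MethodNameOld) == [MethodNameOld.lightgbm, MethodNameOld.xgboost]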
57 changes: 41 additions & 16 deletions rektgbm/metric.py
@@ -2,18 +2,20 @@
from typing import Dict, List, Optional

from rektgbm.base import BaseEnum, MethodName
from rektgbm.objective import ObjectiveName
from rektgbm.task import TaskType


class MetricName(BaseEnum):
    rmse: str = "rmse"
    mae: str = "mae"
    mse: str = "mse"
    mape: str = "mape"
    huber: str = "huber"
    gamma: str = "gamma"
    gamma_deviance: str = "gamma_deviance"
    poisson: str = "poisson"
    tweedie: str = "tweedie"
    quantile: str = "quantile"
    logloss: str = "logloss"
    auc: str = "auc"
    mlogloss: str = "mlogloss"
@@ -28,6 +30,7 @@ class XgbMetricName(BaseEnum):
    mae: str = "mae"
    mape: str = "mape"
    mphe: str = "mphe"
    quantile: str = "quantile"
    logloss: str = "logloss"
    error: str = "error"
    merror: str = "merror"
@@ -48,42 +51,41 @@ class XgbMetricName(BaseEnum):

class LgbMetricName(BaseEnum):
    # https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric
    mae: str = "mae"
    mse: str = "mse"
    l1: str = "l1"
    l2: str = "l2"
    rmse: str = "rmse"
    quantile: str = "quantile"
    mape: str = "mape"
    binary_logloss: str = "binary_logloss"
    binary_error: str = "binary_error"
    multi_logloss: str = "multi_logloss"
    multi_error: str = "multi_error"
    huber: str = "huber"
    fair: str = "fair"
    poisson: str = "poisson"
    gamma: str = "gamma"
    gamma_deviance: str = "gamma_deviance"
    tweedie: str = "tweedie"
    ndcg: str = "ndcg"
    lambdarank: str = "ndcg"
    cross_entropy: str = "cross_entropy"
    cross_entropy_lambda: str = "cross_entropy_lambda"
    kullback_leibler: str = "kullback_leibler"
    map: str = "map"
    mean_average_precision: str = "mean_average_precision"
    auc: str = "auc"
    average_precision: str = "average_precision"
    binary_logloss: str = "binary_logloss"
    binary_error: str = "binary_error"
    auc_mu: str = "auc_mu"
    multi_logloss: str = "multi_logloss"
    multi_error: str = "multi_error"
    cross_entropy: str = "cross_entropy"
    cross_entropy_lambda: str = "cross_entropy_lambda"
    kullback_leibler: str = "kullback_leibler"


TASK_METRIC_MAPPER: Dict[TaskType, List[MetricName]] = {
    TaskType.regression: [
        MetricName.rmse,
        MetricName.mae,
        MetricName.mse,
        MetricName.huber,
        MetricName.mape,
        MetricName.gamma,
        MetricName.gamma_deviance,
        MetricName.poisson,
        MetricName.quantile,
        MetricName.tweedie,
    ],
    TaskType.binary: [
@@ -100,6 +102,21 @@ class LgbMetricName(BaseEnum):
}


OBJECTIVE_METRIC_MAPPER: Dict[ObjectiveName, MetricName] = {
    ObjectiveName.rmse: MetricName.rmse,
    ObjectiveName.mae: MetricName.mae,
    ObjectiveName.huber: MetricName.huber,
    ObjectiveName.poisson: MetricName.poisson,
    ObjectiveName.quantile: MetricName.quantile,
    ObjectiveName.gamma: MetricName.gamma,
    ObjectiveName.tweedie: MetricName.tweedie,
    ObjectiveName.binary: MetricName.logloss,
    ObjectiveName.multiclass: MetricName.mlogloss,
    ObjectiveName.lambdarank: MetricName.ndcg,
    ObjectiveName.ndcg: MetricName.map,
}


METRIC_DICT_KEY_MAPPER: Dict[MethodName, str] = {
    MethodName.lightgbm: "metric",
    MethodName.xgboost: "eval_metric",
@@ -111,9 +128,13 @@ class LgbMetricName(BaseEnum):
        MethodName.xgboost: XgbMetricName.rmse.value,
    },
    MetricName.mae: {
        MethodName.lightgbm: LgbMetricName.mae.value,
        MethodName.lightgbm: LgbMetricName.l1.value,
        MethodName.xgboost: XgbMetricName.mae.value,
    },
    MetricName.huber: {
        MethodName.lightgbm: LgbMetricName.huber.value,
        MethodName.xgboost: XgbMetricName.mphe.value,
    },
    MetricName.logloss: {
        MethodName.lightgbm: LgbMetricName.binary_logloss.value,
        MethodName.xgboost: XgbMetricName.logloss.value,
@@ -146,6 +167,10 @@ class LgbMetricName(BaseEnum):
        MethodName.lightgbm: LgbMetricName.tweedie.value,
        MethodName.xgboost: XgbMetricName.tweedie_nloglik.value,
    },
    MetricName.quantile: {
        MethodName.lightgbm: LgbMetricName.quantile.value,
        MethodName.xgboost: XgbMetricName.quantile.value,
    },
    MetricName.ndcg: {
        MethodName.lightgbm: LgbMetricName.ndcg.value,
        MethodName.xgboost: XgbMetricName.ndcg.value,
@@ -160,15 +185,15 @@ class LgbMetricName(BaseEnum):
@dataclass
class RektMetric:
    task_type: TaskType
    objective: ObjectiveName
    metric: Optional[str]

    def __post_init__(self) -> None:
        if self.metric:
            self.metric = MetricName.get(self.metric)
            self.__validate_metric()
        else:
            _metrics = TASK_METRIC_MAPPER.get(self.task_type)
            self.metric = _metrics[0]
            self.metric = OBJECTIVE_METRIC_MAPPER.get(self.objective)

        self._metric_engine_mapper = METRIC_ENGINE_MAPPER.get(self.metric)

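The default-metric change in RektMetric, sketched with the fields visible above (assuming the three fields shown are the only required constructor arguments): when no metric is supplied, the default now follows the objective via OBJECTIVE_METRIC_MAPPER rather than the first entry of TASK_METRIC_MAPPER for the task.

from rektgbm.metric import MetricName, RektMetric
from rektgbm.objective import ObjectiveName
from rektgbm.task import TaskType

# No explicit metric: the default is now derived from the objective.
metric = RektMetric(
    task_type=TaskType.regression,
    objective=ObjectiveName.huber,
    metric=None,
)
assert metric.metric == MetricName.huber  # previously defaulted to the first task metric (rmse)
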
3 changes: 1 addition & 2 deletions rektgbm/objective.py
@@ -25,7 +25,7 @@ class XgbObjectiveName(BaseEnum):
    squarederror: str = "reg:squarederror"
    squaredlogerror: str = "reg:squaredlogerror"
    pseudohubererror: str = "reg:pseudohubererror"
    absoluteerror: str = "reg:reg:absoluteerror"
    absoluteerror: str = "reg:absoluteerror"
    quantileerror: str = "reg:quantileerror"
    logistic: str = "binary:logistic"
    logitraw: str = "binary:logitraw"
@@ -38,7 +38,6 @@ class XgbObjectiveName(BaseEnum):
    pairwise: str = "rank:pairwise"
    ndcg: str = "rank:ndcg"
    map: str = "rank:map"
    pairwise: str = "rank:pairwise"
    gamma: str = "reg:gamma"
    tweedie: str = "reg:tweedie"

22 changes: 18 additions & 4 deletions rektgbm/optimizer.py
@@ -2,6 +2,7 @@

import optuna

from rektgbm.addparams import set_additional_params
from rektgbm.base import BaseEnum, MethodName, StateException
from rektgbm.dataset import RektDataset
from rektgbm.engine import RektEngine
@@ -14,9 +15,7 @@
class _RektMethods(BaseEnum):
    both: int = 1
    lightgbm: int = 2
    lgb: int = 2
    xgboost: int = 3
    xgb: int = 3


class RektOptimizer:
@@ -27,6 +26,7 @@ def __init__(
        objective: Optional[str] = None,
        metric: Optional[str] = None,
        params: Optional[Union[List[Callable], Callable]] = None,
        additional_params: Dict[str, Any] = {},
    ) -> None:
        if _RektMethods.both == _RektMethods.get(method):
            self.method = [MethodName.lightgbm, MethodName.xgboost]
@@ -45,6 +45,7 @@ def __init__(
        self.objective = objective
        self._task_type = task_type
        self.metric = metric
        self.additional_params = additional_params

    def optimize_params(
        self,
@@ -61,6 +62,7 @@ def optimize_params(
        )
        self.rekt_metric = RektMetric(
            task_type=self.task_type,
            objective=self.rekt_objective.objective,
            metric=self.metric,
        )
        self.studies: Dict[MethodName, optuna.Study] = {}
@@ -72,7 +74,12 @@ def _study_func(trial: optuna.Trial) -> float:
                _param = param(trial=trial)
                _objective = self.rekt_objective.get_objective(method=method)
                _metric = self.rekt_metric.get_metric(method=method)
                _param.update({**_objective, **_metric})
                _additional_params = set_additional_params(
                    objective=self.rekt_objective.objective,
                    method=method,
                    params=self.additional_params,
                )
                _param.update({**_objective, **_metric, **_additional_params})

                _engine = RektEngine(
                    params=_param,
@@ -95,9 +102,16 @@ def best_params(self) -> Dict[str, Any]:
        self.__check_optimized()
        best_method = min(self.studies, key=lambda k: self.studies[k].best_value)
        best_study = self.studies.get(best_method)
        _best_params = best_study.best_params
        _additional_params = set_additional_params(
            objective=self.rekt_objective.objective,
            method=best_method,
            params=self.additional_params,
        )
        _best_params.update({**_additional_params})
        return {
            "method": best_method.value,
            "params": best_study.best_params,
            "params": _best_params,
            "task_type": self.task_type.value,
            "objective": self.rekt_objective.objective.value,
            "metric": self.rekt_metric.metric.value,
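Putting the optimizer changes together, a hedged end-to-end sketch. Only the pieces shown in this diff are certain; the RektDataset construction, the optimize_params argument names, and relying on the constructor defaults for method/task_type are assumptions:

from rektgbm.dataset import RektDataset
from rektgbm.optimizer import RektOptimizer

dataset = RektDataset(data=X_train, label=y_train)  # assumed constructor; X_train/y_train are placeholders

optimizer = RektOptimizer(
    objective="quantile",
    additional_params={"quantile_alpha": 0.9},  # remapped per engine by set_additional_params
)
optimizer.optimize_params(dataset=dataset, n_trials=20)  # assumed argument names
print(optimizer.best_params)  # best_params is assumed to be a property here
# e.g. {"method": "lightgbm", "params": {..., "alpha": 0.9},
#       "task_type": "regression", "objective": "quantile", "metric": "quantile"}
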
1 change: 1 addition & 0 deletions rektgbm/rektgbm.py
@@ -38,6 +38,7 @@ def fit(
        )
        self.rekt_metric = RektMetric(
            task_type=self._task_type,
            objective=self.rekt_objective.objective,
            metric=self.metric,
        )

