[Chore] simplify annotation #25

Merged · 10 commits · Sep 7, 2024
13 changes: 7 additions & 6 deletions rektgbm/base.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable, Optional, Union
from typing import Callable

import lightgbm as lgb
import numpy as np
@@ -38,11 +38,12 @@ class MethodName(BaseEnum):
xgboost: str = "xgboost"


XdataLike = Union[pd.DataFrame, pd.Series, np.ndarray]
YdataLike = Union[pd.Series, np.ndarray]
ModelLike = Union[lgb.basic.Booster, xgb.Booster]
DataLike = Union[lgb.basic.Dataset, xgb.DMatrix]
DataFuncLike = Callable[[XdataLike, Optional[YdataLike]], Union[DataLike, XdataLike]]
XdataLike = pd.DataFrame | pd.Series | np.ndarray
YdataLike = pd.Series | np.ndarray
ModelLike = lgb.basic.Booster | xgb.Booster
DataLike = lgb.basic.Dataset | xgb.DMatrix
DataFuncLike = Callable[[XdataLike, YdataLike | None], DataLike | XdataLike]
ParamsLike = dict[str, float | int | str | bool]


class StateException(Exception):
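The new aliases behave exactly like the old typing versions, just spelled with built-in generics (PEP 585) and | unions (PEP 604); the unions are evaluated at import time, so this requires Python 3.10+. A minimal sketch of the aliases in use (the describe function below is hypothetical, not part of this PR):

import numpy as np
import pandas as pd

from rektgbm.base import ParamsLike, XdataLike, YdataLike


def describe(data: XdataLike, label: YdataLike | None, params: ParamsLike) -> str:
    # XdataLike covers DataFrame, Series, and ndarray; ParamsLike is a plain
    # dict of scalar hyperparameter values (float | int | str | bool).
    label_info = "no label" if label is None else f"{len(label)} labels"
    return f"{type(data).__name__} with {label_info} and {len(params)} params"


describe(
    data=pd.DataFrame({"x": [1.0, 2.0]}),
    label=np.array([0, 1]),
    params={"learning_rate": 0.1, "verbosity": -1},
)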
18 changes: 9 additions & 9 deletions rektgbm/dataset.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Optional

import lightgbm as lgb
import numpy as np
@@ -24,7 +24,7 @@ class _TypeName(BaseEnum):
predict_dtype: int = 2


_METHOD_FUNC_TYPE_MAPPER: Dict[MethodName, Dict[_TypeName, DataFuncLike]] = {
_METHOD_FUNC_TYPE_MAPPER: dict[MethodName, dict[_TypeName, DataFuncLike]] = {
MethodName.lightgbm: {
_TypeName.train_dtype: lgb.Dataset,
_TypeName.predict_dtype: lambda data: data,
@@ -43,7 +43,7 @@ def _get_dtype(method: MethodName, dtype: _TypeName):

def _train_valid_split(
data: XdataLike, label: YdataLike, task_type: TaskType
) -> Tuple[XdataLike, XdataLike, YdataLike, YdataLike]:
) -> tuple[XdataLike, XdataLike, YdataLike, YdataLike]:
if task_type == TaskType.regression:
for _bin in range(5, 0, -1):
try:
@@ -60,8 +60,8 @@ def _train_valid_split(
@dataclass
class RektDataset:
data: XdataLike
label: Optional[YdataLike] = None
group: Optional[YdataLike] = None
label: YdataLike | None = None
group: YdataLike | None = None
reference: Optional["RektDataset"] = None
skip_post_init: bool = False

@@ -74,7 +74,7 @@ def __post_init__(self) -> None:
self.data = pd.DataFrame(self.data)

if self.reference is None:
self.encoders: Dict[str, RektLabelEncoder] = {}
self.encoders: dict[str, RektLabelEncoder] = {}
for col in self.data.columns:
if self.data[col].dtype == "object":
_encoder = RektLabelEncoder()
@@ -103,14 +103,14 @@ def dtrain(self, method: MethodName) -> DataLike:
)
return train_dtype(data=self.data, label=self.label, group=self.group)

def dpredict(self, method: MethodName) -> Union[DataLike, XdataLike]:
def dpredict(self, method: MethodName) -> DataLike | XdataLike:
predict_dtype = _get_dtype(
method=method,
dtype=_TypeName.predict_dtype,
)
return predict_dtype(data=self.data)

def split(self, task_type: TaskType) -> Tuple["RektDataset", "RektDataset"]:
def split(self, task_type: TaskType) -> tuple["RektDataset", "RektDataset"]:
self.__check_label_available()
train_data, valid_data, train_label, valid_label = _train_valid_split(
data=self.data,
@@ -123,7 +123,7 @@ def split(self, task_type: TaskType) -> Tuple["RektDataset", "RektDataset"]:

def dsplit(
self, method: MethodName, task_type: TaskType
) -> Tuple[DataLike, DataLike]:
) -> tuple[DataLike, DataLike]:
self.__check_label_available()
train_data, valid_data, train_label, valid_label = _train_valid_split(
data=self.data,
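One Optional survives in this file: reference: Optional["RektDataset"] = None relies on a string forward reference, and a bare string does not support the | operator when the annotation is evaluated at class-definition time, so the PEP 604 spelling would need from __future__ import annotations. A standalone sketch of the failure mode (illustrative class names, not from this PR; assumes Python 3.10+ without the future import):

from typing import Optional


class Works:
    # typing.Optional keeps the string as a lazy ForwardRef, so referencing a
    # class that is defined later (or the class itself) is fine.
    reference: Optional["Node"] = None


# class Breaks:
#     # Evaluated eagerly at class-definition time:
#     # TypeError: unsupported operand type(s) for |: 'str' and 'NoneType'
#     reference: "Node" | None = None


class Node:
    pass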
8 changes: 3 additions & 5 deletions rektgbm/engine.py
@@ -1,10 +1,8 @@
from typing import Any, Dict, Optional

import lightgbm as lgb
import numpy as np
import xgboost as xgb

from rektgbm.base import BaseGBM, MethodName, StateException
from rektgbm.base import BaseGBM, MethodName, ParamsLike, StateException
from rektgbm.dataset import RektDataset
from rektgbm.metric import METRIC_DICT_KEY_MAPPER, LgbMetricName
from rektgbm.task import TaskType
@@ -16,7 +14,7 @@ class RektEngine(BaseGBM):
def __init__(
self,
method: MethodName,
params: Dict[str, Any],
params: ParamsLike,
task_type: TaskType,
) -> None:
self.method = method
@@ -26,7 +24,7 @@ def __init__(
def fit(
self,
dataset: RektDataset,
valid_set: Optional[RektDataset],
valid_set: RektDataset | None,
) -> None:
if valid_set is None:
dtrain, dvalid = dataset.dsplit(
14 changes: 6 additions & 8 deletions rektgbm/gbm.py
@@ -1,8 +1,6 @@
from typing import Any, Dict, Optional

import numpy as np

from rektgbm.base import BaseGBM, MethodName
from rektgbm.base import BaseGBM, MethodName, ParamsLike
from rektgbm.dataset import RektDataset
from rektgbm.engine import RektEngine
from rektgbm.metric import RektMetric
@@ -14,10 +12,10 @@ class RektGBM(BaseGBM):
def __init__(
self,
method: str,
params: Dict[str, Any],
task_type: Optional[str] = None,
objective: Optional[str] = None,
metric: Optional[str] = None,
params: ParamsLike,
task_type: str | None = None,
objective: str | None = None,
metric: str | None = None,
):
self.method = MethodName.get(method)
self.params = params
@@ -28,7 +26,7 @@ def __init__(
def fit(
self,
dataset: RektDataset,
valid_set: Optional[RektDataset] = None,
valid_set: RektDataset | None = None,
):
self._task_type = check_task_type(
target=dataset.label,
13 changes: 6 additions & 7 deletions rektgbm/metric.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from rektgbm.base import BaseEnum, MethodName
from rektgbm.objective import ObjectiveName
@@ -73,7 +72,7 @@ class LgbMetricName(BaseEnum):
kullback_leibler: str = "kullback_leibler"


TASK_METRIC_MAPPER: Dict[TaskType, List[MetricName]] = {
TASK_METRIC_MAPPER: dict[TaskType, list[MetricName]] = {
TaskType.regression: [
MetricName.rmse,
MetricName.mae,
@@ -97,7 +96,7 @@ class LgbMetricName(BaseEnum):
}


OBJECTIVE_METRIC_MAPPER: Dict[ObjectiveName, MetricName] = {
OBJECTIVE_METRIC_MAPPER: dict[ObjectiveName, MetricName] = {
ObjectiveName.rmse: MetricName.rmse,
ObjectiveName.mae: MetricName.mae,
ObjectiveName.huber: MetricName.huber,
@@ -110,12 +109,12 @@ class LgbMetricName(BaseEnum):
}


METRIC_DICT_KEY_MAPPER: Dict[MethodName, str] = {
METRIC_DICT_KEY_MAPPER: dict[MethodName, str] = {
MethodName.lightgbm: "metric",
MethodName.xgboost: "eval_metric",
}

METRIC_ENGINE_MAPPER: Dict[MetricName, Dict[MethodName, str]] = {
METRIC_ENGINE_MAPPER: dict[MetricName, dict[MethodName, str]] = {
MetricName.rmse: {
MethodName.lightgbm: LgbMetricName.rmse.value,
MethodName.xgboost: XgbMetricName.rmse.value,
@@ -171,7 +170,7 @@ class LgbMetricName(BaseEnum):
class RektMetric:
task_type: TaskType
objective: ObjectiveName
metric: Optional[str]
metric: str | None

def __post_init__(self) -> None:
if self.metric:
@@ -185,7 +184,7 @@ def __post_init__(self) -> None:
def get_metric_str(self, method: MethodName) -> str:
return self._metric_engine_mapper.get(method)

def get_metric_dict(self, method: MethodName) -> Dict[str, str]:
def get_metric_dict(self, method: MethodName) -> dict[str, str]:
return {METRIC_DICT_KEY_MAPPER.get(method): self.get_metric_str(method=method)}

def __validate_metric(self) -> None:
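For context on the new dict[...] annotations, the mapping tables above combine into engine-specific keyword dictionaries. A small sketch of the expected shape (it assumes __post_init__, whose body is elided here, falls back to OBJECTIVE_METRIC_MAPPER when metric is None):

from rektgbm.base import MethodName
from rektgbm.metric import RektMetric
from rektgbm.objective import ObjectiveName
from rektgbm.task import TaskType

rekt_metric = RektMetric(
    task_type=TaskType.regression,
    objective=ObjectiveName.rmse,
    metric=None,  # str | None: None falls back to the objective's default metric
)

# METRIC_DICT_KEY_MAPPER selects the engine-specific keyword:
rekt_metric.get_metric_dict(method=MethodName.lightgbm)  # e.g. {"metric": "rmse"}
rekt_metric.get_metric_dict(method=MethodName.xgboost)   # e.g. {"eval_metric": "rmse"}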
9 changes: 4 additions & 5 deletions rektgbm/objective.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from rektgbm.base import BaseEnum, MethodName
from rektgbm.task import TaskType
@@ -59,7 +58,7 @@ class LgbObjectiveName(BaseEnum):
rank_xendcg: str = "rank_xendcg"


TASK_OBJECTIVE_MAPPER: Dict[TaskType, List[ObjectiveName]] = {
TASK_OBJECTIVE_MAPPER: dict[TaskType, list[ObjectiveName]] = {
TaskType.regression: [
ObjectiveName.rmse,
ObjectiveName.mae,
@@ -80,7 +79,7 @@ class LgbObjectiveName(BaseEnum):
}


OBJECTIVE_ENGINE_MAPPER: Dict[ObjectiveName, Dict[MethodName, str]] = {
OBJECTIVE_ENGINE_MAPPER: dict[ObjectiveName, dict[MethodName, str]] = {
ObjectiveName.rmse: {
MethodName.lightgbm: LgbObjectiveName.rmse.value,
MethodName.xgboost: XgbObjectiveName.squarederror.value,
@@ -123,7 +122,7 @@ class LgbObjectiveName(BaseEnum):
@dataclass
class RektObjective:
task_type: TaskType
objective: Optional[str]
objective: str | None

def __post_init__(self) -> None:
if self.objective:
@@ -138,7 +137,7 @@ def __post_init__(self) -> None:
def get_objective_str(self, method: MethodName) -> str:
return self._objective_engine_mapper.get(method)

def get_objective_dict(self, method: MethodName) -> Dict[str, str]:
def get_objective_dict(self, method: MethodName) -> dict[str, str]:
return {OBJECTIVE_DICT_KEY: self.get_objective_str(method=method)}

def __validate_objective(self) -> None:
22 changes: 11 additions & 11 deletions rektgbm/optimizer.py
@@ -1,8 +1,8 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Callable

import optuna

from rektgbm.base import BaseEnum, MethodName, StateException
from rektgbm.base import BaseEnum, MethodName, ParamsLike, StateException
from rektgbm.dataset import RektDataset
from rektgbm.engine import RektEngine
from rektgbm.metric import RektMetric
@@ -21,11 +21,11 @@ class RektOptimizer:
def __init__(
self,
method: str = "both",
task_type: Optional[str] = None,
objective: Optional[str] = None,
metric: Optional[str] = None,
params: Optional[Union[List[Callable], Callable]] = None,
additional_params: Dict[str, Any] = {},
task_type: str | None = None,
objective: str | None = None,
metric: str | None = None,
params: list[Callable] | Callable | None = None,
additional_params: ParamsLike = {},
) -> None:
if _RektMethods.both == _RektMethods.get(method):
self.method = [MethodName.lightgbm, MethodName.xgboost]
@@ -50,8 +50,8 @@ def optimize_params(
self,
dataset: RektDataset,
n_trials: int,
valid_set: Optional[RektDataset] = None,
) -> Dict[str, Any]:
valid_set: RektDataset | None = None,
) -> None:
self._task_type: TaskType = check_task_type(
target=dataset.label,
group=dataset.group,
@@ -85,7 +85,7 @@ def optimize_params(
if self.__is_label_encoder_used:
valid_set.transform_label(label_encoder=_label_encoder)

self.studies: Dict[MethodName, optuna.Study] = {}
self.studies: dict[MethodName, optuna.Study] = {}
for method, param in zip(self.method, self.params):
_addtional_params = set_additional_params(
objective=self.rekt_objective.objective,
@@ -119,7 +119,7 @@ def _study_func(trial: optuna.Trial) -> float:
self._is_optimized = True

@property
def best_params(self) -> Dict[str, Any]:
def best_params(self) -> dict[str, str | int | float | ParamsLike | None]:
self.__check_optimized()
best_method = min(self.studies, key=lambda k: self.studies[k].best_value)
best_study = self.studies.get(best_method)
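Callers see no change beyond the return annotation: optimize_params now returns None and results are read from the best_params property. A usage sketch with synthetic data (hypothetical; it also assumes that passing the two search-space functions from rektgbm/param.py explicitly matches the default behaviour when params is None, which this diff does not show):

import numpy as np
import pandas as pd

from rektgbm.dataset import RektDataset
from rektgbm.optimizer import RektOptimizer
from rektgbm.param import get_lgb_params, get_xgb_params

rng = np.random.default_rng(42)
X = pd.DataFrame({"f0": rng.normal(size=200), "f1": rng.normal(size=200)})
y = pd.Series(rng.normal(size=200))

optimizer = RektOptimizer(
    method="both",                            # tune both lightgbm and xgboost
    params=[get_lgb_params, get_xgb_params],  # one search space per method
)
optimizer.optimize_params(dataset=RektDataset(data=X, label=y), n_trials=5)
best = optimizer.best_params  # dict[str, str | int | float | ParamsLike | None]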
14 changes: 6 additions & 8 deletions rektgbm/param.py
@@ -1,13 +1,11 @@
from typing import Any, Dict, Optional, Union

from optuna import Trial

from rektgbm.base import MethodName
from rektgbm.base import MethodName, ParamsLike
from rektgbm.metric import MetricName
from rektgbm.objective import ObjectiveName


def get_lgb_params(trial: Trial) -> Dict[str, Union[float, int]]:
def get_lgb_params(trial: Trial) -> ParamsLike:
# https://lightgbm.readthedocs.io/en/latest/Parameters.html#learning-control-parameters
return {
"verbosity": -1,
@@ -22,7 +20,7 @@ def get_lgb_params(trial: Trial) -> Dict[str, Union[float, int]]:
}


def get_xgb_params(trial: Trial) -> Dict[str, Union[float, int]]:
def get_xgb_params(trial: Trial) -> ParamsLike:
# https://xgboost.readthedocs.io/en/stable/parameter.html#parameters-for-tree-booster
return {
"verbosity": 0,
@@ -43,12 +41,12 @@


def set_additional_params(
params: Dict[str, Any],
params: ParamsLike,
objective: ObjectiveName,
metric: str,
method: MethodName,
num_class: Optional[int],
) -> Dict[str, Any]:
num_class: int | None,
) -> ParamsLike:
_params = params.copy()
if objective == ObjectiveName.quantile:
if method == MethodName.lightgbm and "quantile_alpha" in _params.keys():
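ParamsLike is wide enough for anything Optuna's suggest API returns. An illustrative search space (the keys below are hypothetical and are not the ones elided from get_lgb_params/get_xgb_params above):

from optuna import Trial

from rektgbm.base import ParamsLike


def get_example_params(trial: Trial) -> ParamsLike:
    # Every value is a float, int, str, or bool, matching
    # ParamsLike = dict[str, float | int | str | bool].
    return {
        "verbosity": -1,
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
        "max_depth": trial.suggest_int("max_depth", 3, 12),
        "boosting_type": trial.suggest_categorical("boosting_type", ["gbdt", "dart"]),
    }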
8 changes: 3 additions & 5 deletions rektgbm/task.py
@@ -1,5 +1,3 @@
from typing import Dict, List, Optional

from sklearn.utils.multiclass import type_of_target

from rektgbm.base import BaseEnum, YdataLike
@@ -18,7 +16,7 @@ class SklearnTaskType(BaseEnum):
multiclass: int = 3


SKLEARN_TASK_TYPE_MAPPER: Dict[SklearnTaskType, List[TaskType]] = {
SKLEARN_TASK_TYPE_MAPPER: dict[SklearnTaskType, list[TaskType]] = {
SklearnTaskType.continuous: [TaskType.regression],
SklearnTaskType.binary: [TaskType.binary],
SklearnTaskType.multiclass: [TaskType.multiclass, TaskType.rank],
@@ -27,8 +25,8 @@ class SklearnTaskType(BaseEnum):

def check_task_type(
target: YdataLike,
group: Optional[YdataLike],
task_type: Optional[str],
group: YdataLike | None,
task_type: str | None,
) -> TaskType:
if group is not None:
return TaskType.rank
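The ranking shortcut is visible above: any non-None group yields TaskType.rank. A brief call sketch covering only that grounded case (the non-ranking branch goes through type_of_target and SKLEARN_TASK_TYPE_MAPPER, but its body is elided from this diff):

import numpy as np

from rektgbm.task import TaskType, check_task_type

# group is not None, so the target values themselves do not matter here.
task = check_task_type(
    target=np.array([0.1, 0.7, 0.3, 0.9]),
    group=np.array([0, 0, 1, 1]),
    task_type=None,
)
assert task is TaskType.rank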