Skip to content

Commit

Permalink
[Chore] refactor gbm (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Sep 13, 2024
1 parent e8db7bc commit f54f29b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
24 changes: 7 additions & 17 deletions rektgbm/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ def fit(
)
if self._task_type in {TaskType.binary, TaskType.multiclass, TaskType.rank}:
self.label_encoder = dataset.fit_transform_label()
self._is_label_encoder_used = True

if valid_set is not None and self.__is_label_encoder_used:
valid_set.transform_label(label_encoder=self.label_encoder)
if valid_set:
valid_set.transform_label(label_encoder=self.label_encoder)

_objective = self.rekt_objective.get_objective_dict(method=self.method)
_metric = self.rekt_metric.get_metric_dict(method=self.method)
Expand All @@ -69,16 +67,8 @@ def predict(self, dataset: RektDataset):
if self._task_type in {TaskType.binary, TaskType.regression, TaskType.rank}:
return preds

if self._task_type == TaskType.multiclass:
if self.method == MethodName.lightgbm:
preds = np.argmax(preds, axis=1).astype(int)
else:
preds = np.around(preds).astype(int)

if self.__is_label_encoder_used:
preds = self.label_encoder.inverse_transform(series=preds)
return preds

@property
def __is_label_encoder_used(self) -> bool:
return getattr(self, "_is_label_encoder_used", False)
if self.method == MethodName.lightgbm:
preds = np.argmax(preds, axis=1).astype(int)
else:
preds = np.around(preds).astype(int)
return self.label_encoder.inverse_transform(series=preds)
3 changes: 3 additions & 0 deletions tests/test_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from rektgbm.base import MethodName
from rektgbm.dataset import RektDataset
from rektgbm.encoder import RektLabelEncoder
from rektgbm.engine import RektEngine
from rektgbm.gbm import RektGBM
from rektgbm.task import TaskType
Expand Down Expand Up @@ -90,6 +91,8 @@ def test_rektgbm_predict_multiclass(mock_dataset, mock_engine):
gbm.engine = mock_engine
gbm._task_type = TaskType.multiclass
gbm._is_fitted = True
gbm.label_encoder = RektLabelEncoder()
gbm.label_encoder.fit_label([0, 1, 2])

mock_engine.predict.return_value = np.array(
[[0.1, 0.7, 0.2], [0.3, 0.4, 0.3], [0.2, 0.2, 0.6]]
Expand Down

1 comment on commit f54f29b

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
91 0 💤 0 ❌ 0 🔥 5.813s ⏱️

Please sign in to comment.