add feature importance
RektPunk committed Sep 15, 2024
1 parent 414aa07 commit fffb661
Showing 2 changed files with 85 additions and 3 deletions.
30 changes: 28 additions & 2 deletions rektgbm/gbm.py
@@ -1,6 +1,8 @@
from functools import cached_property

import numpy as np

from rektgbm.base import BaseGBM, MethodName, ParamsLike
from rektgbm.base import BaseGBM, MethodName, ParamsLike, StateException
from rektgbm.dataset import RektDataset
from rektgbm.engine import RektEngine
from rektgbm.metric import RektMetric
@@ -28,6 +30,7 @@ def fit(
dataset: RektDataset,
valid_set: RektDataset | None = None,
):
self._colnames = dataset.colnames
self._task_type = check_task_type(
target=dataset.label,
group=dataset.group,
@@ -61,14 +64,37 @@ def fit(
task_type=self._task_type,
)
self.engine.fit(dataset=dataset, valid_set=valid_set)
self._is_fitted = True

def predict(self, dataset: RektDataset):
preds = self.engine.predict(dataset=dataset)
if self._task_type in {TaskType.binary, TaskType.regression, TaskType.rank}:
return preds

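# multiclass: LightGBM returns per-class probabilities, so take the argmax;
# otherwise predictions arrive as class indices in float form, so round them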
if self.method == MethodName.lightgbm:
if self.__is_lgb:
preds = np.argmax(preds, axis=1).astype(int)
else:
preds = np.around(preds).astype(int)
return self.label_encoder.inverse_transform(series=preds)

@cached_property
def feature_importance(self) -> dict[str, float]:
self.__check_fitted()
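# pre-seed every column with 0.0: XGBoost's get_score() omits features
# that were never used in a split, so this keeps the key set complete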
importances = {str(k): 0.0 for k in self._colnames}
if self.__is_lgb:
_importance = self.engine.model.feature_importance(
importance_type="gain"
).tolist()
importances.update({str(k): v for k, v in zip(self._colnames, _importance)})
return importances
else:
importances.update(self.engine.model.get_score(importance_type="gain"))
return importances

@property
def __is_lgb(self) -> bool:
return self.method == MethodName.lightgbm

def __check_fitted(self):
if not getattr(self, "_is_fitted", False):
raise StateException("fit is not completed.")
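A minimal usage sketch of the new property, assuming RektGBM is importable from rektgbm.gbm (the patch target used in the tests below) and takes the same method/params/task_type arguments as the test fixtures; the data here is hypothetical:

import numpy as np
import pandas as pd

from rektgbm.dataset import RektDataset
from rektgbm.gbm import RektGBM

# toy data; the column names become the keys of feature_importance
X = pd.DataFrame(np.random.rand(100, 5), columns=[f"feature_{i}" for i in range(5)])
y = np.random.randint(0, 2, size=(100,))

model = RektGBM(
    method="lightgbm",
    params={"learning_rate": 0.1, "num_leaves": 31, "n_estimators": 10},
    task_type="binary",
)
model.fit(dataset=RektDataset(X, y))
print(model.feature_importance)  # e.g. {"feature_0": 12.3, ..., "feature_4": 0.0}

Calling the property before fit raises StateException, since the underlying booster does not exist yet.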
58 changes: 57 additions & 1 deletion tests/test_gbm.py
@@ -1,9 +1,10 @@
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest

from rektgbm.base import MethodName
from rektgbm.base import MethodName, StateException
from rektgbm.dataset import RektDataset
from rektgbm.encoder import RektLabelEncoder
from rektgbm.engine import RektEngine
@@ -40,6 +41,27 @@ def mock_engine():
return engine


@pytest.fixture
def dummy_dataset():
x_train = np.random.rand(100, 5)
y_train = np.random.randint(0, 2, size=(100,))
colnames = [f"feature_{i}" for i in range(x_train.shape[1])]
x_train = pd.DataFrame(x_train, columns=colnames)
dataset = RektDataset(x_train, y_train)
return dataset


@pytest.fixture
def dummy_gbm_model():
params = {
"learning_rate": 0.1,
"num_leaves": 31,
"n_estimators": 10,
}
model = RektGBM(method="lightgbm", params=params, task_type="binary")
return model


@patch("rektgbm.gbm.RektEngine", autospec=True)
def test_rektgbm_fit(mock_engine_class, mock_dataset, mock_valid_set, mock_engine):
mock_engine_class.return_value = mock_engine
@@ -135,3 +157,37 @@ def test_rektgbm_fit_rank_raises_value_error(mock_dataset):
)
with pytest.raises(ValueError):
gbm.fit(dataset=mock_dataset)


def test_feature_importance_before_fit_raises(dummy_gbm_model):
with pytest.raises(StateException, match="fit is not completed"):
_ = dummy_gbm_model.feature_importance


def test_feature_importance_after_fit(dummy_gbm_model, dummy_dataset):
dummy_gbm_model.fit(dataset=dummy_dataset)
feature_importances = dummy_gbm_model.feature_importance

assert isinstance(
feature_importances, dict
), "Feature importances should be a dictionary"
assert len(feature_importances) == len(
dummy_dataset.colnames
), "Feature importance length mismatch"
for feature in dummy_dataset.colnames:
assert (
feature in feature_importances
), f"Feature {feature} not found in importance"


def test_feature_importance_nonzero(dummy_gbm_model, dummy_dataset):
"""Test that at least some feature importances are non-zero after training"""
dummy_gbm_model.fit(dataset=dummy_dataset)
feature_importances = dummy_gbm_model.feature_importance

non_zero_importances = sum(
importance > 0 for importance in feature_importances.values()
)
assert (
non_zero_importances > 0
), "At least one feature should have non-zero importance"

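A natural companion test (not in this commit) would exercise the XGBoost branch, which is the reason feature_importance pre-seeds zeros. This sketch assumes "xgboost" is an accepted method string and that the fixture-style constructor arguments also work for it:

def test_feature_importance_xgboost_covers_all_columns(dummy_dataset):
    # assumption: MethodName also accepts "xgboost"; get_score() omits
    # features never used in a split, so every column must still appear
    model = RektGBM(
        method="xgboost",
        params={"learning_rate": 0.1, "n_estimators": 10},
        task_type="binary",
    )
    model.fit(dataset=dummy_dataset)
    importances = model.feature_importance
    assert set(importances) == {str(c) for c in dummy_dataset.colnames}
    assert all(v >= 0 for v in importances.values())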