Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compatibility for Python 3.12 #239

Merged
merged 23 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cronjob_unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- build: windows
os: windows-latest
SKIP_LIGHTGBM: False
python-version: [3.8, 3.9, "3.10", "3.11"]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@master

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- build: windows
os: windows-latest
SKIP_LIGHTGBM: False
python-version: [3.8, 3.9, "3.10", "3.11"]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@master

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ repos:
hooks:
- id: ruff-check
name: 'Ruff: Check for errors, styling issues and complexity, and fixes issues if possible (including import order)'
entry: ruff
entry: ruff check
language: system
args: [ --fix, --no-cache ]
- id: ruff-format
Expand Down
1 change: 0 additions & 1 deletion docs/discussion/contributing.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/discussion/vision.md

This file was deleted.

14 changes: 9 additions & 5 deletions probatus/utils/shap_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import pandas as pd
from shap import Explainer
from shap.explainers._tree import Tree
from shap.explainers import TreeExplainer
from shap.utils import sample
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -59,10 +59,10 @@ def shap_calc(
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).

approximate (boolean):
approximate (boolean):
if True uses shap approximations - less accurate, but very fast. It applies to tree-based explainers only.

check_additivity (boolean):
check_additivity (boolean):
if False SHAP will disable the additivity check for tree-based models.

**shap_kwargs: kwargs of the shap.Explainer
Expand Down Expand Up @@ -104,9 +104,13 @@ def shap_calc(
explainer = Explainer(model, masker=mask, **shap_kwargs)

# For tree-explainers allow for using check_additivity and approximate arguments
if isinstance(explainer, Tree):
# Calculate Shap values
if isinstance(explainer, TreeExplainer):
shap_values = explainer.shap_values(X, check_additivity=check_additivity, approximate=approximate)

# Since SHAP 0.43 (https://github.com/shap/shap/pull/3121) the returned SHAP values may be
# a 3-D array instead of a list; in that case keep only index 1 along the last axis.
if not isinstance(shap_values, list) and len(shap_values.shape) == 3:
shap_values = shap_values[:, :, 1]
else:
# Calculate Shap values
shap_values = explainer.shap_values(X)
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "probatus"
version = "3.0.0"
version = "3.0.1"
requires-python= ">=3.8"
description = "Validation of binary classifiers and data used to develop them"
readme = { file = "README.md", content-type = "text/markdown" }
Expand All @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
Expand All @@ -32,7 +33,8 @@ dependencies = [
"scipy>=1.4.0",
"joblib>=0.13.2",
"tqdm>=4.41.0",
"shap>=0.41.0,<0.43.0",
"shap==0.43.0 ; python_version == '3.8'",
"shap>=0.43.0 ; python_version != '3.8'",
"numpy>=1.23.2",
"numba>=0.57.0",
]
Expand Down
13 changes: 4 additions & 9 deletions tests/feature_elimination/test_feature_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedGroupKFold, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -314,7 +313,7 @@ def test_get_feature_shap_values_per_fold(X, y):
Test with ShapRFECV with features per fold.
"""
clf = DecisionTreeClassifier(max_depth=1)
shap_elimination = ShapRFECV(clf)
shap_elimination = ShapRFECV(clf, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -325,7 +324,6 @@ def test_get_feature_shap_values_per_fold(X, y):
clf,
train_index=[2, 3, 4, 5, 6, 7],
val_index=[0, 1],
scorer=get_scorer("roc_auc"),
)
assert test_score == 1
assert train_score > 0.9
Expand Down Expand Up @@ -545,7 +543,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data):
X, y = complex_data
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -556,7 +554,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data):
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0.6
assert train_score > 0.6
Expand All @@ -573,7 +570,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_CatBoost(complex_data,
X["f1_categorical"] = X["f1_categorical"].astype(str).astype("category")
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -584,7 +581,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_CatBoost(complex_data,
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0
assert train_score > 0.6
Expand All @@ -603,7 +599,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(complex_data):
X["f1_categorical"] = X["f1_categorical"].astype(float)
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -614,7 +610,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(complex_data):
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0
assert train_score > 0.6
Expand Down