Skip to content

Commit

Permalink
Merge pull request #113 from ing-bank/cronjob_test_dependencies
Browse files Browse the repository at this point in the history
Add github action that will run unit tests everyday, closes #42
  • Loading branch information
Mateusz Garbacz authored Mar 26, 2021
2 parents 2e98264 + e526bcc commit d483011
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 39 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/cronjob_unit_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: Cron Test Dependencies

# Scheduled (cron) trigger only: run the whole unit-test suite once a day
# at 04:05 UTC to catch breakage caused by upstream dependency releases.
# See https://crontab.guru/#5_4_*_*_*
on:
  schedule:
    - cron: "5 4 * * *"

jobs:
  run:
    name: Run unit tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        # Cross-product of OS x Python version; `include` then attaches
        # per-OS settings (runner label and the SKIP_LIGHTGBM flag) to
        # each `build` value rather than expanding the matrix further.
        build: [macos, ubuntu, windows]
        python-version: [3.6, 3.7, 3.8]
        include:
          - build: macos
            os: macos-latest
            # LightGBM tests are disabled on macOS only.
            # NOTE(review): YAML booleans interpolate as "true"/"false";
            # the test suite appears to compare against the lowercase
            # string "true" — confirm the consumer side if this changes.
            SKIP_LIGHTGBM: True
          - build: ubuntu
            os: ubuntu-latest
            SKIP_LIGHTGBM: False
          - build: windows
            os: windows-latest
            SKIP_LIGHTGBM: False
    steps:
      - uses: actions/checkout@master
      - name: Setup Python
        uses: actions/setup-python@master
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          SKIP_LIGHTGBM: ${{ matrix.SKIP_LIGHTGBM }}
        run: |
          pip3 install --upgrade setuptools pip
          pip3 install ".[all]"
      - name: Run unit tests
        env:
          SKIP_LIGHTGBM: ${{ matrix.SKIP_LIGHTGBM }}
        run: |
          pytest
55 changes: 24 additions & 31 deletions tests/interpret/test_model_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import pandas as pd
from probatus.interpret import ShapModelInterpreter
from unittest.mock import patch
import os


Expand Down Expand Up @@ -116,16 +115,14 @@ def test_shap_interpret(fitted_tree, X_train, y_train, X_test, y_test, expected_
assert test_auc == pytest.approx(0.833, 0.01)

# Check if plots work for such dataset
with patch("matplotlib.pyplot.figure") as _:
with patch("shap.plots._waterfall.waterfall_legacy"):
ax1 = shap_interpret.plot("importance", target_set="test", show=False)
ax2 = shap_interpret.plot("summary", target_set="test", show=False)
ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
ax5 = shap_interpret.plot("importance", target_set="train")
ax6 = shap_interpret.plot("summary", target_set="train")
ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train")
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train")
ax1 = shap_interpret.plot("importance", target_set="test", show=False)
ax2 = shap_interpret.plot("summary", target_set="test", show=False)
ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
ax5 = shap_interpret.plot("importance", target_set="train", show=False)
ax6 = shap_interpret.plot("summary", target_set="train", show=False)
ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False)
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
assert not (isinstance(ax1, list))
assert not (isinstance(ax2, list))
assert isinstance(ax3, list) and len(ax4) == 2
Expand Down Expand Up @@ -167,16 +164,14 @@ def test_shap_interpret_lin_models(
assert test_auc == pytest.approx(0.833, 0.01)

# Check if plots work for such dataset
with patch("matplotlib.pyplot.figure") as _:
with patch("shap.plots._waterfall.waterfall_legacy"):
ax1 = shap_interpret.plot("importance", target_set="test", show=False)
ax2 = shap_interpret.plot("summary", target_set="test", show=False)
ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
ax5 = shap_interpret.plot("importance", target_set="train")
ax6 = shap_interpret.plot("summary", target_set="train")
ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train")
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train")
ax1 = shap_interpret.plot("importance", target_set="test", show=False)
ax2 = shap_interpret.plot("summary", target_set="test", show=False)
ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
ax5 = shap_interpret.plot("importance", target_set="train", show=False)
ax6 = shap_interpret.plot("summary", target_set="train", show=False)
ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False)
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
assert not (isinstance(ax1, list))
assert not (isinstance(ax2, list))
assert isinstance(ax3, list) and len(ax4) == 2
Expand Down Expand Up @@ -263,16 +258,14 @@ def test_shap_interpret_complex_data(complex_data_split, complex_fitted_lightgbm
assert importance_df.shape[0] == X_train.shape[1]

# Check if plots work for such dataset
with patch("matplotlib.pyplot.figure") as _:
with patch("shap.plots._waterfall.waterfall_legacy"):
ax1 = shap_interpret.plot("importance", target_set="test")
ax2 = shap_interpret.plot("summary", target_set="test")
ax3 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="test")
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test")
ax5 = shap_interpret.plot("importance", target_set="train")
ax6 = shap_interpret.plot("summary", target_set="train")
ax7 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="train")
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train")
ax1 = shap_interpret.plot("importance", target_set="test", show=False)
ax2 = shap_interpret.plot("summary", target_set="test", show=False)
ax3 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="test", show=False)
ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
ax5 = shap_interpret.plot("importance", target_set="train", show=False)
ax6 = shap_interpret.plot("summary", target_set="train", show=False)
ax7 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="train", show=False)
ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
assert not (isinstance(ax1, list))
assert not (isinstance(ax2, list))
assert isinstance(ax3, list) and len(ax4) == 2
Expand Down
6 changes: 2 additions & 4 deletions tests/interpret/test_shap_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from unittest.mock import patch
from probatus.interpret.shap_dependence import DependencePlotter
from probatus.utils.exceptions import NotFittedError
import os
Expand Down Expand Up @@ -121,9 +120,8 @@ def test_fit_complex(complex_data_split, complex_fitted_lightgbm):
assert plotter.fitted is True

# Check if plotting doesn't cause errors
with patch("matplotlib.pyplot.figure") as _:
for binning in ["simple", "agglomerative", "quantile"]:
_ = plotter.plot(feature="f2_missing", type_binning=binning)
for binning in ["simple", "agglomerative", "quantile"]:
_ = plotter.plot(feature="f2_missing", type_binning=binning, show=False)


def test_get_X_y_shap_with_q_cut_normal(X_y, clf):
Expand Down
6 changes: 2 additions & 4 deletions tests/metric_volatility/test_metric_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,7 @@ def test_fit_compute_full_process(X_df, y_series):
assert report.shape == (2, 6)

# Check if plot runs
with patch("matplotlib.pyplot.figure") as _:
vol.plot()
vol.plot(show=False)


@pytest.mark.skipif(os.environ.get("SKIP_LIGHTGBM") == "true", reason="LightGBM tests disabled")
Expand All @@ -420,5 +419,4 @@ def test_fit_compute_complex(complex_data, complex_lightgbm):
assert report.shape == (1, 6)

# Check if plot runs
with patch("matplotlib.pyplot.figure") as _:
vol.plot(show=False)
vol.plot(show=False)

0 comments on commit d483011

Please sign in to comment.