From a9e56dae757bdf9c1453f11d6a1a3dd08cc3ae53 Mon Sep 17 00:00:00 2001 From: Luis Catala Date: Wed, 17 Jan 2024 14:51:08 +0100 Subject: [PATCH 1/5] Select between ._fit() of .fit() depending on the sklearn version used --- src/m5py/main.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/m5py/main.py b/src/m5py/main.py index 4716809..958629a 100644 --- a/src/m5py/main.py +++ b/src/m5py/main.py @@ -17,6 +17,7 @@ from sklearn.tree._tree import DOUBLE from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted +from sklearn import __version__ as sklearn_version from m5py.linreg_utils import linreg_model_to_text, DeNormalizableMixIn, DeNormalizableLinearRegression @@ -210,8 +211,14 @@ def fit(self, X, y: np.ndarray, sample_weight=None, check_input=True, X_idx_sort if self.use_smoothing not in [False, np.bool_(False), "installed", "on_prediction"]: raise ValueError("use_smoothing: Unexpected value: %s, please report it as issue." % self.use_smoothing) + # Get the correct fit method name based on the sklearn version used + sklearn_version_tuple = tuple(map(int, sklearn_version.split('.'))) + fit_method_name = "fit" if sklearn_version_tuple <= (1, 3) else "_fit" + # (1) Build the initial tree as usual - super(M5Base, self).fit(X, y, sample_weight=sample_weight, check_input=check_input) + fit_method = getattr(super(M5Base, self), fit_method_name) + fit_method(X, y, sample_weight=sample_weight, check_input=check_input) + if self.debug_prints: logger.debug("(debug_prints) Initial tree:") From 7680f1e13f450346543fd15d13235c8a354c4dec Mon Sep 17 00:00:00 2001 From: Luis Catala Date: Mon, 22 Jan 2024 09:50:12 +0100 Subject: [PATCH 2/5] Check sklearn version with packaging.version.Version --- src/m5py/main.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/m5py/main.py b/src/m5py/main.py index 958629a..5b24166 100644 --- a/src/m5py/main.py +++ b/src/m5py/main.py @@ -17,10 +17,15 @@ from sklearn.tree._tree import DOUBLE from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted -from sklearn import __version__ as sklearn_version +import sklearn.__version__ from m5py.linreg_utils import linreg_model_to_text, DeNormalizableMixIn, DeNormalizableLinearRegression +from packaging.version import Version + +SKLEARN_VERSION = Version(sklearn.__version__) +SKLEARN13_OR_GREATER = SKLEARN_VERSION >= Version("1.3.0") + __all__ = ["M5Base", "M5Prime"] @@ -211,11 +216,12 @@ def fit(self, X, y: np.ndarray, sample_weight=None, check_input=True, X_idx_sort if self.use_smoothing not in [False, np.bool_(False), "installed", "on_prediction"]: raise ValueError("use_smoothing: Unexpected value: %s, please report it as issue." % self.use_smoothing) - # Get the correct fit method name based on the sklearn version used - sklearn_version_tuple = tuple(map(int, sklearn_version.split('.'))) - fit_method_name = "fit" if sklearn_version_tuple <= (1, 3) else "_fit" # (1) Build the initial tree as usual + + # Get the correct fit method name based on the sklearn version used + fit_method_name = "_fit" if SKLEARN13_OR_GREATER else "fit" + fit_method = getattr(super(M5Base, self), fit_method_name) fit_method(X, y, sample_weight=sample_weight, check_input=check_input) From 6e4d77f3be4208da9e119236bcd1e2687528dd75 Mon Sep 17 00:00:00 2001 From: Luis Catala Date: Mon, 22 Jan 2024 09:50:26 +0100 Subject: [PATCH 3/5] Update changelog --- docs/changelog.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 8a2964b..a65dea7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +### 0.3.2 - Fixed compliance with sklearn 1.3.0 + + * Fixed `AttributeError: 'super' object has no attribute 'fit' `. + PR [#16](https://github.com/smarie/python-m5p/pull/16) by [lccatala](https://github.com/lccatala) + ### 0.3.1 - Fixed compliance with sklearn 1.1.0 * Fixed `TypeError: fit() got an unexpected keyword argument 'X_idx_sorted'`. From 6aa54c2bdf7c62feccf79ddf58e29b9cff1bbbb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sylvain=20Mari=C3=A9?= Date: Mon, 22 Jan 2024 12:18:24 +0100 Subject: [PATCH 4/5] Update src/m5py/main.py --- src/m5py/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/m5py/main.py b/src/m5py/main.py index 5b24166..6b92d24 100644 --- a/src/m5py/main.py +++ b/src/m5py/main.py @@ -17,7 +17,7 @@ from sklearn.tree._tree import DOUBLE from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted -import sklearn.__version__ +from sklearn import __version__ as sklearn_version from m5py.linreg_utils import linreg_model_to_text, DeNormalizableMixIn, DeNormalizableLinearRegression From 10e132a232284b433ef2cc6ad1b521b93c6edfa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sylvain=20Mari=C3=A9?= Date: Mon, 22 Jan 2024 12:18:30 +0100 Subject: [PATCH 5/5] Update src/m5py/main.py --- src/m5py/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/m5py/main.py b/src/m5py/main.py index 6b92d24..2eda296 100644 --- a/src/m5py/main.py +++ b/src/m5py/main.py @@ -23,7 +23,7 @@ from packaging.version import Version -SKLEARN_VERSION = Version(sklearn.__version__) +SKLEARN_VERSION = Version(sklearn_version) SKLEARN13_OR_GREATER = SKLEARN_VERSION >= Version("1.3.0")