From f64bb8fdca42f639e56fe76f8bef26087e2ad953 Mon Sep 17 00:00:00 2001 From: Tony Bagnall Date: Mon, 3 Feb 2025 16:25:38 +0000 Subject: [PATCH] unexclude SFA --- .../_yield_estimator_checks.py | 56 ------------------- aeon/testing/testing_config.py | 2 +- 2 files changed, 1 insertion(+), 57 deletions(-) diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index 4bd2346c33..a35fb89667 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -9,7 +9,6 @@ import joblib import numpy as np -import pandas as pd from sklearn.exceptions import NotFittedError from aeon.anomaly_detection.base import BaseAnomalyDetector @@ -605,61 +604,6 @@ def check_raises_not_fitted_error(estimator, datatype): _run_estimator_method(estimator, method, datatype, "test") -def _equal_outputs(output1, output2): - """Test whether two outputs from an estimator are logically identical. - - Valid data structures are: - 1. float: returns a single value (e.g. forecasting) - 2. numpy array: - scalars: stores an equal length collection or series (default) - objects: an array of arrays stored as objects (e.g. SimilaritySearch) - 3. dict: a histogram of counts, usually of discretised sub-series (e.g. SFA) - 4. pd.DataFrame: series stored in dataframe (e.g. Dobin) - 5. list: stores possibly unequal length series in a format 2-4 - 6. tuple: stores two or more series/collections in a format 2-4 (e.g. imbalance - transformers) - - """ - if type(output1) is not type(output2): - return False - if np.issubdtype(type(output1), np.floating): - return np.isclose(output1, output2) - if np.issubdtype(type(output1), np.bool_): - return output1 == output2 - if isinstance(output1, np.ndarray): # 1. X an equal length collection or series - if np.isscalar(output1): - return np.allclose(output1, output2, equal_nan=True) - for i in range(len(output1)): - if not _equal_outputs(output1[i], output2[i]): - return False - return True - if isinstance(output1, dict): # 2. X a dictionary, dense collection or series - if output1.keys() != output2.keys(): - return False - for k in output1.keys(): - if not _equal_outputs(output1[k], output2[k]): - return False - return True - if isinstance(output1, pd.DataFrame) or isinstance(output1, pd.Series): - # 3. X a dataframe - return np.allclose(output1.values, output2.values, equal_nan=True) - if isinstance(output1, list): # X a possibly unequal length collection - if len(output1) != len(output2): - return False - for i in range(len(output1)): - if not _equal_outputs(output1[i], output2[i]): - return False - return True - if isinstance(output1, tuple): # returns (X,y) - if len(output1) != len(output2): - return False - for i in range(len(output1)): - if not _equal_outputs(output1[i], output2[i]): - return False - return True - return False - - def check_persistence_via_pickle(estimator, datatype): """Check that we can pickle all estimators.""" estimator = _clone_estimator(estimator, random_state=0) diff --git a/aeon/testing/testing_config.py b/aeon/testing/testing_config.py index bb099220d5..e352a232e7 100644 --- a/aeon/testing/testing_config.py +++ b/aeon/testing/testing_config.py @@ -50,7 +50,7 @@ "RSASTClassifier": ["check_fit_deterministic"], "SAST": ["check_fit_deterministic"], "RSAST": ["check_fit_deterministic"], - "SFA": ["check_persistence_via_pickle", "check_fit_deterministic"], + # "SFA": ["check_persistence_via_pickle", "check_fit_deterministic"], # missed in legacy testing, changes state in predict/transform "FLUSSSegmenter": ["check_non_state_changing_method"], "InformationGainSegmenter": ["check_non_state_changing_method"],