Skip to content

Commit

Permalink
unexclude SFA
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyBagnall committed Feb 3, 2025
1 parent 7a27ede commit f64bb8f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 57 deletions.
56 changes: 0 additions & 56 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import joblib
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError

from aeon.anomaly_detection.base import BaseAnomalyDetector
Expand Down Expand Up @@ -605,61 +604,6 @@ def check_raises_not_fitted_error(estimator, datatype):
_run_estimator_method(estimator, method, datatype, "test")


def _equal_outputs(output1, output2):
    """Test whether two outputs from an estimator are logically identical.

    Both outputs must be of exactly the same type. Supported data structures are:

    1. scalar: a single float, int, bool or str value (e.g. forecasting), or
       ``None``.
    2. numpy array:
        numeric or string dtype: an equal length collection or series (default)
        object dtype: an array of arrays stored as objects (e.g. SimilaritySearch)
    3. dict: a histogram of counts, usually of discretised sub-series (e.g. SFA)
    4. pd.DataFrame or pd.Series: series stored in pandas containers (e.g. Dobin)
    5. list: stores possibly unequal length series in a format 2-4
    6. tuple: stores two or more series/collections in a format 2-4 (e.g.
       imbalance transformers)

    Floating point values are compared with the default tolerances of
    ``np.isclose``/``np.allclose`` and NaN is treated as equal to NaN. All other
    values are compared exactly.

    Parameters
    ----------
    output1 : object
        First output to compare.
    output2 : object
        Second output to compare.

    Returns
    -------
    bool
        True if both outputs have the same type, the same structure and equal
        contents, False otherwise (including for unsupported types).
    """
    if type(output1) is not type(output2):
        return False

    # 2. numpy arrays.
    if isinstance(output1, np.ndarray):
        # Differing shapes or element types can never be identical. Checking the
        # shape up front also avoids index errors and silent broadcasting.
        if (
            output1.shape != output2.shape
            or output1.dtype.type is not output2.dtype.type
        ):
            return False
        if output1.dtype == object:
            # Array of arbitrary objects (e.g. ragged arrays): compare item-wise.
            if output1.ndim == 0:
                return _equal_outputs(output1.item(), output2.item())
            return all(_equal_outputs(a, b) for a, b in zip(output1, output2))
        if np.issubdtype(output1.dtype, np.inexact):
            return bool(np.allclose(output1, output2, equal_nan=True))
        # Integer, boolean, string, ... arrays are compared exactly.
        return bool(np.array_equal(output1, output2))

    # 3. dictionaries: identical key sets and pairwise equal values.
    if isinstance(output1, dict):
        if output1.keys() != output2.keys():
            return False
        return all(_equal_outputs(v, output2[k]) for k, v in output1.items())

    # 4. pandas containers, compared on their underlying values.
    if isinstance(output1, (pd.DataFrame, pd.Series)):
        if output1.shape != output2.shape:
            return False
        return bool(np.allclose(output1.values, output2.values, equal_nan=True))

    # 5./6. lists (possibly unequal length collections) and tuples, e.g. (X, y).
    if isinstance(output1, (list, tuple)):
        if len(output1) != len(output2):
            return False
        return all(_equal_outputs(a, b) for a, b in zip(output1, output2))

    # 1. scalars. bool is tested before int because bool is a subclass of int.
    if isinstance(output1, (bool, np.bool_)):
        return bool(output1 == output2)
    if isinstance(output1, (float, np.floating)):
        return bool(np.isclose(output1, output2, equal_nan=True))
    if isinstance(output1, (int, np.integer, str, bytes)):
        return bool(output1 == output2)
    if output1 is None:
        # Types already match, so both outputs are None.
        return True

    # Unsupported type: be conservative and report a mismatch.
    return False


def check_persistence_via_pickle(estimator, datatype):
"""Check that we can pickle all estimators."""
estimator = _clone_estimator(estimator, random_state=0)
Expand Down
2 changes: 1 addition & 1 deletion aeon/testing/testing_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"RSASTClassifier": ["check_fit_deterministic"],
"SAST": ["check_fit_deterministic"],
"RSAST": ["check_fit_deterministic"],
"SFA": ["check_persistence_via_pickle", "check_fit_deterministic"],
# "SFA": ["check_persistence_via_pickle", "check_fit_deterministic"],
# missed in legacy testing, changes state in predict/transform
"FLUSSSegmenter": ["check_non_state_changing_method"],
"InformationGainSegmenter": ["check_non_state_changing_method"],
Expand Down

0 comments on commit f64bb8f

Please sign in to comment.