Skip to content

Commit

Permalink
[MNT] Add/rework transformation tests and remove from exclude list (#…
Browse files Browse the repository at this point in the history
…2360)

* excluded tests

* trying to fix things

* tidy up transform testing

* fixes

* fixes

* fixes

* still trying to make this work

* ignore index for pandas

* allclose

* Empty commit for CI

* correct

* rist

* rist

* fix

---------

Co-authored-by: MatthewMiddlehurst <[email protected]>
  • Loading branch information
MatthewMiddlehurst and MatthewMiddlehurst authored Nov 24, 2024
1 parent 224cdc1 commit 05097f5
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 239 deletions.
9 changes: 8 additions & 1 deletion aeon/testing/estimator_checking/_estimator_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from aeon.testing.estimator_checking._yield_estimator_checks import (
_yield_all_aeon_checks,
)
from aeon.testing.testing_config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
from aeon.testing.testing_config import (
EXCLUDE_ESTIMATORS,
EXCLUDED_TESTS,
EXCLUDED_TESTS_NO_NUMBA,
NUMBA_DISABLED,
)
from aeon.utils.validation._dependencies import (
_check_estimator_deps,
_check_soft_dependencies,
Expand Down Expand Up @@ -313,6 +318,8 @@ def _should_be_skipped(estimator, check, has_dependencies):
return True, "In aeon estimator exclude list", check_name
elif check_name in EXCLUDED_TESTS.get(est_name, []):
return True, "In aeon test exclude list for estimator", check_name
elif NUMBA_DISABLED and check_name in EXCLUDED_TESTS_NO_NUMBA.get(est_name, []):
return True, "In aeon no numba test exclude list for estimator", check_name

return False, "", check_name

Expand Down
21 changes: 8 additions & 13 deletions aeon/testing/estimator_checking/_yield_classification_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,14 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes
results_dict=unit_test_proba,
resample_seed=0,
)
# the test currently fails when numba is disabled. See issue #622
if (
estimator_class.__name__ != "HIVECOTEV2"
or os.environ.get("NUMBA_DISABLE_JIT") != "1"
):
yield partial(
check_classifier_against_expected_results,
estimator_class=estimator_class,
data_name="BasicMotions",
data_loader=load_basic_motions,
results_dict=basic_motions_proba,
resample_seed=4,
)
yield partial(
check_classifier_against_expected_results,
estimator_class=estimator_class,
data_name="BasicMotions",
data_loader=load_basic_motions,
results_dict=basic_motions_proba,
resample_seed=4,
)
yield partial(check_classifier_overrides_and_tags, estimator_class=estimator_class)

# data type irrelevant
Expand Down

This file was deleted.

32 changes: 5 additions & 27 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from aeon.regression import BaseRegressor
from aeon.regression.deep_learning.base import BaseDeepRegressor
from aeon.segmentation import BaseSegmenter
from aeon.similarity_search import BaseSimilaritySearch
from aeon.testing.estimator_checking._yield_anomaly_detection_checks import (
_yield_anomaly_detection_checks,
)
Expand All @@ -34,9 +33,6 @@
from aeon.testing.estimator_checking._yield_clustering_checks import (
_yield_clustering_checks,
)
from aeon.testing.estimator_checking._yield_collection_transformation_checks import (
_yield_collection_transformation_checks,
)
from aeon.testing.estimator_checking._yield_early_classification_checks import (
_yield_early_classification_checks,
)
Expand All @@ -49,12 +45,6 @@
from aeon.testing.estimator_checking._yield_segmentation_checks import (
_yield_segmentation_checks,
)
from aeon.testing.estimator_checking._yield_series_transformation_checks import (
_yield_series_transformation_checks,
)
from aeon.testing.estimator_checking._yield_similarity_search_checks import (
_yield_similarity_search_checks,
)
from aeon.testing.estimator_checking._yield_soft_dependency_checks import (
_yield_soft_dependency_checks,
)
Expand All @@ -69,8 +59,6 @@
from aeon.testing.utils.deep_equals import deep_equals
from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method
from aeon.transformations.base import BaseTransformer
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.series import BaseSeriesTransformer
from aeon.utils.base import VALID_ESTIMATOR_BASES
from aeon.utils.tags import check_valid_tags
from aeon.utils.validation._dependencies import _check_estimator_deps
Expand Down Expand Up @@ -153,26 +141,11 @@ def _yield_all_aeon_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseSimilaritySearch):
yield from _yield_similarity_search_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseTransformer):
yield from _yield_transformation_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseCollectionTransformer):
yield from _yield_collection_transformation_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseSeriesTransformer):
yield from _yield_series_transformation_checks(
estimator_class, estimator_instances, datatypes
)


def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
"""Yield all general checks for an aeon estimator."""
Expand Down Expand Up @@ -289,6 +262,11 @@ def check_has_common_interface(estimator_class):
"axis" not in estimator_class.__dict__
), "axis should not be a class parameter"

# Must have at least one set to True
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
assert multi or uni


def check_set_params(estimator_class):
"""Check that set_params works correctly."""
Expand Down

This file was deleted.

This file was deleted.

Loading

0 comments on commit 05097f5

Please sign in to comment.