[ENH] ADD SAST transformer and SASTClassifier (#958)
* ADD SAST transformer and SASTClassifier
  - SAST: a subsequence based transformation for time series
  - SASTClassifier: a helper classifier to build a time series classifier based on the SAST transformation
* add myself in all-contributors
* update doc with SAST
* matplotlib as soft dependency
* fix SAST argument in example
* fix fit return in SAST
* fix fit return in SAST
* `random_state` renamed to `seed`
* fix attribute error in predict_proba
* rename private variables
* check if classifier has predict_proba in SASTClassifier
* fix arg error in demo notebook
* fix arg error in demo notebook
* fix attribute error in SASTClassifier
* passed unit tests
* move SAST into the shapelet based category
Showing 9 changed files with 973 additions and 12 deletions.
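For orientation, here is a minimal sketch of what the two additions enable, written against the constructor arguments visible in the diff below (length_list, stride, nb_inst_per_class, seed, n_jobs are assumed to be accepted as keywords by SAST); SASTClassifier is essentially a convenience wrapper around this SAST-plus-ridge pipeline:

# Hedged sketch, not part of the commit: SAST maps each series to its distances
# to reference subsequences, which a standard tabular classifier then consumes.
import numpy as np
from sklearn.linear_model import RidgeClassifierCV
from sklearn.pipeline import make_pipeline

from aeon.datasets import load_unit_test
from aeon.transformations.collection.shapelet_based import SAST

X_train, y_train = load_unit_test(split="train")
X_test, y_test = load_unit_test(split="test")

pipe = make_pipeline(
    SAST(nb_inst_per_class=1, seed=0),
    RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
)
pipe.fit(X_train, y_train)
print(pipe.score(X_test, y_test))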
aeon/classification/shapelet_based/__init__.py
@@ -1,7 +1,13 @@
"""Shapelet based time series classifiers.""" | ||
|
||
__all__ = ["MrSQMClassifier", "ShapeletTransformClassifier", "RDSTClassifier"] | ||
__all__ = [ | ||
"MrSQMClassifier", | ||
"ShapeletTransformClassifier", | ||
"RDSTClassifier", | ||
"SASTClassifier", | ||
] | ||
|
||
from aeon.classification.shapelet_based._mrsqm import MrSQMClassifier | ||
from aeon.classification.shapelet_based._rdst import RDSTClassifier | ||
from aeon.classification.shapelet_based._sast_classifier import SASTClassifier | ||
from aeon.classification.shapelet_based._stc import ShapeletTransformClassifier |
aeon/classification/shapelet_based/_sast_classifier.py
@@ -0,0 +1,203 @@
"""Scalable and Accurate Subsequence Transform (SAST). | ||
Pipeline classifier using the SAST transformer and an sklearn classifier. | ||
""" | ||
|
||
__author__ = ["MichaelMbouopda"] | ||
__all__ = ["SASTClassifier"] | ||
|
||
from operator import itemgetter | ||
|
||
import numpy as np | ||
from sklearn.linear_model import RidgeClassifierCV | ||
from sklearn.pipeline import make_pipeline | ||
|
||
from aeon.base._base import _clone_estimator | ||
from aeon.classification import BaseClassifier | ||
from aeon.transformations.collection.shapelet_based import SAST | ||
from aeon.utils.numba.general import z_normalise_series | ||
|
||
|
||
class SASTClassifier(BaseClassifier):
    """Classification pipeline using the SAST [1]_ transformer and an sklearn classifier.

    Parameters
    ----------
    length_list : array of int, default = None
        The lengths of the subsequences to be generated.
        If None, inferred during fit as np.arange(3, X.shape[1]).
    stride : int, default = 1
        The stride used when generating subsequences.
    nb_inst_per_class : int, default = 1
        The number of reference time series to select per class.
    seed : int, default = None
        The seed of the random generator.
    classifier : sklearn compatible classifier, default = None
        If None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)) is used.
    n_jobs : int, default = -1
        Number of threads to use for the transform.

    References
    ----------
    .. [1] Mbouopda, Michael Franklin, and Engelbert Mephu Nguifo.
       "Scalable and accurate subsequence transform for time series classification."
       Pattern Recognition 147 (2023): 110121.
       https://www.sciencedirect.com/science/article/abs/pii/S003132032300818X,
       https://uca.hal.science/hal-03087686/document

    Examples
    --------
    >>> from aeon.classification.shapelet_based import SASTClassifier
    >>> from aeon.datasets import load_unit_test
    >>> X_train, y_train = load_unit_test(split="train")
    >>> X_test, y_test = load_unit_test(split="test")
    >>> clf = SASTClassifier()
    >>> clf.fit(X_train, y_train)
    SASTClassifier(...)
    >>> y_pred = clf.predict(X_test)
    """

    _tags = {
        "capability:multithreading": True,
        "capability:multivariate": False,
        "algorithm_type": "subsequence",
    }

    def __init__(
        self,
        length_list=None,
        stride=1,
        nb_inst_per_class=1,
        seed=None,
        classifier=None,
        n_jobs=-1,
    ):
        super().__init__()
        self.length_list = length_list
        self.stride = stride
        self.nb_inst_per_class = nb_inst_per_class
        self.n_jobs = n_jobs
        self.seed = seed

        self.classifier = classifier

    def _fit(self, X, y):
        """Fit SASTClassifier to the training data.

        Parameters
        ----------
        X : np.ndarray of shape (n_time_series, n_channels, n_timepoints)
            The training input samples.
        y : array-like or list
            The class values for X.

        Returns
        -------
        self : SASTClassifier
            This fitted pipeline classifier.
        """
        self._transformer = SAST(
            self.length_list,
            self.stride,
            self.nb_inst_per_class,
            self.seed,
            self.n_jobs,
        )

        # Default to a ridge classifier with built-in cross-validation when no
        # classifier is supplied, cloning so the user's estimator is left untouched.
        self._classifier = _clone_estimator(
            RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
            if self.classifier is None
            else self.classifier,
            self.seed,
        )

        self._pipeline = make_pipeline(self._transformer, self._classifier)

        self._pipeline.fit(X, y)

        return self

    def _predict(self, X):
        """Predict labels for the input.

        Parameters
        ----------
        X : np.ndarray of shape (n_time_series, n_channels, n_timepoints)
            The input samples to predict.

        Returns
        -------
        array-like or list
            Predicted class labels.
        """
        return self._pipeline.predict(X)

    def _predict_proba(self, X):
        """Predict class probabilities for the input.

        Parameters
        ----------
        X : np.ndarray of shape (n_time_series, n_channels, n_timepoints)
            The input samples to predict.

        Returns
        -------
        dists : np.ndarray of shape (n_time_series, n_classes_)
            Predicted class probabilities.
        """
        m = getattr(self._classifier, "predict_proba", None)
        if callable(m):
            dists = self._pipeline.predict_proba(X)
        else:
            # Fall back to one-hot probabilities built from hard predictions
            # when the wrapped classifier does not implement predict_proba.
            dists = np.zeros((X.shape[0], self.n_classes_))
            preds = self._pipeline.predict(X)
            for i in range(0, X.shape[0]):
                dists[i, np.where(self.classes_ == preds[i])] = 1
        return dists

    def plot_most_important_feature_on_ts(self, ts, feature_importance, limit=5):
        """Plot the most important features on ts.

        Parameters
        ----------
        ts : float[:]
            The time series.
        feature_importance : float[:]
            The importance of each feature in the transformed data.
        limit : int, default = 5
            The maximum number of features to plot.

        Returns
        -------
        fig : plt.figure
            The figure.
        """
        import matplotlib.pyplot as plt

        features = zip(self._transformer._kernel_orig, feature_importance)
        sorted_features = sorted(features, key=itemgetter(1), reverse=True)

        max_ = min(limit, len(sorted_features))

        fig, axes = plt.subplots(
            1, max_, sharey=True, figsize=(3 * max_, 3), tight_layout=True
        )
        # Ensure axes is indexable even when a single subplot is created.
        axes = np.atleast_1d(axes)

        for f in range(max_):
            kernel, _ = sorted_features[f]
            znorm_kernel = z_normalise_series(kernel)
            d_best = np.inf
            start_pos = 0
            # Slide the subsequence over the series (including the last valid
            # alignment) and keep the best z-normalised squared distance.
            for i in range(ts.size - kernel.size + 1):
                s = z_normalise_series(ts[i : i + kernel.size])
                d = np.sum((s - znorm_kernel) ** 2)
                if d < d_best:
                    d_best = d
                    start_pos = i
            axes[f].plot(range(start_pos, start_pos + kernel.size), kernel, linewidth=5)
            axes[f].plot(range(ts.size), ts, linewidth=2)
            axes[f].set_title(f"feature: {f + 1}")

        return fig
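A possible way to call plot_most_important_feature_on_ts after fitting is sketched below. It assumes the default RidgeClassifierCV and relies on the private _classifier attribute shown in this diff, whose absolute coefficients serve here as feature importances; any importance vector of matching length would do.

# Hedged usage sketch, not part of the commit.
import numpy as np

from aeon.classification.shapelet_based import SASTClassifier
from aeon.datasets import load_unit_test

X_train, y_train = load_unit_test(split="train")
clf = SASTClassifier(seed=0).fit(X_train, y_train)

# For this binary problem, RidgeClassifierCV exposes coef_ of shape
# (1, n_features); the magnitudes are used as importances.
importance = np.abs(clf._classifier.coef_[0])
fig = clf.plot_most_important_feature_on_ts(X_train[0, 0], importance, limit=3)
fig.savefig("sast_most_important_features.png")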
aeon/transformations/collection/shapelet_based/__init__.py
@@ -1,13 +1,11 @@
"""Shapelet based transformers.""" | ||
|
||
__all__ = [ | ||
"RandomShapeletTransform", | ||
"RandomDilatedShapeletTransform", | ||
] | ||
__all__ = ["RandomShapeletTransform", "RandomDilatedShapeletTransform", "SAST"] | ||
|
||
from aeon.transformations.collection.shapelet_based._dilated_shapelet_transform import ( | ||
RandomDilatedShapeletTransform, | ||
) | ||
from aeon.transformations.collection.shapelet_based._sast import SAST | ||
from aeon.transformations.collection.shapelet_based._shapelet_transform import ( | ||
RandomShapeletTransform, | ||
) |
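With SAST now exported from the transformations module, it can also be used on its own. A hedged sketch follows (parameter names assumed from SASTClassifier above); the output is expected to be a tabular (n_cases, n_features) matrix of subsequence distances usable by any sklearn estimator.

# Hedged sketch, not part of the commit.
from aeon.datasets import load_unit_test
from aeon.transformations.collection.shapelet_based import SAST

X_train, y_train = load_unit_test(split="train")
sast = SAST(stride=1, nb_inst_per_class=1, seed=0)
X_features = sast.fit_transform(X_train, y_train)
print(X_features.shape)  # (n_cases, number of generated subsequences)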