Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] ADD SAST transformer and SASTClassifier #958

Merged
merged 23 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
acd6c01
ADD SAST transformer and SASTClassifier
frankl1 Dec 3, 2023
419d70b
add myself in all-contributors
frankl1 Dec 3, 2023
5ebd04d
update doc with SAST
frankl1 Dec 3, 2023
1d26841
matplotlib as soft dependency
frankl1 Dec 3, 2023
2ef58bc
fix SAST argument in example
frankl1 Dec 3, 2023
099fc16
fix fit return in SAST
frankl1 Dec 3, 2023
6ff65dc
fix fit return in SAST
frankl1 Dec 3, 2023
014f1c2
Merge branch 'sast' of github.com:frankl1/aeon into sast
frankl1 Dec 3, 2023
99cf8cc
`random_state` renamed to `seed`
frankl1 Dec 3, 2023
095a6b9
fix attribute error in predict_proba
frankl1 Dec 3, 2023
317a2b0
rename private variables
frankl1 Dec 3, 2023
10d2268
check if classifier has predict_proba is SASTClassifier
frankl1 Dec 3, 2023
8c76f44
fix arg error in demo notebook
frankl1 Dec 3, 2023
04d4e21
fix arg error in demo notebook
frankl1 Dec 3, 2023
a292a67
Merge branch 'sast' of github.com:frankl1/aeon into sast
frankl1 Dec 3, 2023
5769d3d
fix attribute error in SASTClassifier
frankl1 Dec 4, 2023
c7c5f2f
passed unit tests
frankl1 Dec 4, 2023
734a8a9
move SAST into the shapelet based category
frankl1 Dec 4, 2023
47ba290
Merge branch 'sast' of github.com:frankl1/aeon into sast
frankl1 Dec 8, 2023
1856c1e
taking @baraline comments into account
frankl1 Dec 8, 2023
51120de
Merge branch 'sast' of github.com:frankl1/aeon into sast
frankl1 Dec 8, 2023
c3d1902
update docstring
frankl1 Dec 8, 2023
17c7dd7
Merge branch 'sast' of github.com:frankl1/aeon into sast
frankl1 Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -2227,6 +2227,17 @@
"contributions": [
"doc"
]
},
{
"login": "frankl1",
"name": "Michael F. Mbouopda",
"avatar_url": "https://avatars.githubusercontent.com/u/23366578?s=96&v=4",
"profile": "https://github.com/frankl1",
"contributions": [
"code",
"bug",
"doc"
]
}
],
"commitType": "docs"
Expand Down
5 changes: 5 additions & 0 deletions aeon/classification/subsequence_based/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Subsequence based time series classifiers."""

# Public API of this package: only the names listed here are exported by a
# star-import. The class itself is imported right below.
__all__ = ["SASTClassifier"]

from aeon.classification.subsequence_based._sast_classifier import SASTClassifier
204 changes: 204 additions & 0 deletions aeon/classification/subsequence_based/_sast_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Scalable and Accurate Subsequence Transform (SAST).

Pipeline classifier using the SAST transformer and an sklearn classifier.
"""

# Module metadata: author credit and the list of names exported by a star-import.
__author__ = ["MichaelMbouopda"]
__all__ = ["SASTClassifier"]

from operator import itemgetter

import numpy as np
from sklearn.linear_model import RidgeClassifierCV
from sklearn.pipeline import make_pipeline

from aeon.base._base import _clone_estimator
from aeon.classification import BaseClassifier
from aeon.transformations.collection.subsequence_based import SAST


class SASTClassifier(BaseClassifier):
    """Classification pipeline using SAST [1]_ transformer and an sklearn classifier.

    Parameters
    ----------
    length_list : int[], default = None
        an array containing the lengths of the subsequences to be generated.
        If None, will be inferred during fit as np.arange(3, X.shape[1])
    stride : int, default = 1
        the stride used when generating subsequences
    nb_inst_per_class : int, default = 1
        the number of reference time series to select per class
    seed : int, default = None
        the seed of the random generator
    classifier : sklearn compatible classifier, default = None
        if None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)) is used.
    n_jobs : int, default = -1
        Number of threads to use for the transform.

    References
    ----------
    .. [1] Mbouopda, Michael Franklin, and Engelbert Mephu Nguifo.
        "Scalable and accurate subsequence transform for time series classification."
        Pattern Recognition 147 (2023): 110121.
        https://www.sciencedirect.com/science/article/abs/pii/S003132032300818X,
        https://uca.hal.science/hal-03087686/document

    Examples
    --------
    >>> from aeon.classification.subsequence_based import SASTClassifier
    >>> from aeon.datasets import load_unit_test
    >>> X_train, y_train = load_unit_test(split="train")
    >>> X_test, y_test = load_unit_test(split="test")
    >>> clf = SASTClassifier()
    >>> clf.fit(X_train, y_train)
    SASTClassifier(...)
    >>> y_pred = clf.predict(X_test)
    """

    _tags = {
        "capability:multithreading": True,
        "capability:multivariate": False,
        "algorithm_type": "subsequence",
    }

    def __init__(
        self,
        length_list=None,
        stride=1,
        nb_inst_per_class=1,
        seed=None,
        classifier=None,
        n_jobs=-1,
    ):
        super().__init__()
        # Hyper-parameters are stored untouched; the fitted objects derived
        # from them are created in ``_fit``.
        self.length_list = length_list
        self.stride = stride
        self.nb_inst_per_class = nb_inst_per_class
        self.n_jobs = n_jobs
        self.seed = seed

        self.classifier = classifier

    def _fit(self, X, y):
        """Fit SASTClassifier to the training data.

        Builds a SAST transformer and a clone of the chosen classifier, chains
        them in an sklearn pipeline and fits that pipeline on ``(X, y)``.

        Parameters
        ----------
        X : float[:,:,:]
            an array of shape (n_time_series, n_channels,
            time_series_length) containing the time series
        y : Any[:]
            an array of shape (n_time_series,), containing
            the class label of each time series in X

        Returns
        -------
        self

        """
        self._transformer = SAST(
            self.length_list,
            self.stride,
            self.nb_inst_per_class,
            self.seed,
            self.n_jobs,
        )

        # Work on a clone so the estimator passed by the user is never fitted
        # in place; fall back to the documented ridge default.
        self._classifier = _clone_estimator(
            RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
            if self.classifier is None
            else self.classifier,
            self.seed,
        )

        self._pipeline = make_pipeline(self._transformer, self._classifier)

        self._pipeline.fit(X, y)

        return self

    def _predict(self, X):
        """Predict labels for the input.

        Parameters
        ----------
        X : float[:,:,:]
            an array of shape (n_time_series, n_channels,
            time_series_length) containing the time series

        Returns
        -------
        y : array-like, shape = [n_instances]
            Predicted class labels.
        """
        return self._pipeline.predict(X)

    def _predict_proba(self, X):
        """Predict labels probabilities for the input.

        When the fitted classifier has no callable ``predict_proba``, the
        predicted labels are turned into one-hot rows instead.

        Parameters
        ----------
        X : float[:,:,:]
            an array of shape (n_time_series, n_channels,
            time_series_length) containing the time series

        Returns
        -------
        y : array-like, shape = [n_instances, n_classes]
            Predicted class probabilities.
        """
        m = getattr(self._classifier, "predict_proba", None)
        if callable(m):
            return self._pipeline.predict_proba(X)

        # Fallback: probability 1 for the predicted class, 0 elsewhere.
        dists = np.zeros((X.shape[0], self.n_classes_))
        preds = self._pipeline.predict(X)
        for i in range(0, X.shape[0]):
            dists[i, np.where(self.classes_ == preds[i])] = 1
        return dists

    @staticmethod
    def _best_match_start(ts, kernel):
        """Return the start index of the window of ts closest to kernel.

        Both the kernel and every candidate window are z-normalised before the
        squared Euclidean distance is computed. Returns 0 when no window can be
        compared (kernel longer than ts, or all distances are NaN).
        """
        znorm_kernel = (kernel - kernel.mean()) / (kernel.std() + 1e-8)
        start_pos = 0
        d_best = np.inf
        # ``+ 1`` so that the last full-length window is also considered.
        for i in range(ts.size - kernel.size + 1):
            s = ts[i : i + kernel.size]
            s = (s - s.mean()) / (s.std() + 1e-8)
            d = np.sum((s - znorm_kernel) ** 2)
            if d < d_best:
                d_best = d
                start_pos = i
        return start_pos

    def plot_most_important_feature_on_ts(self, ts, feature_importance, limit=5):
        """Plot the most important features on ts.

        Each of the top-ranked kernels is drawn at the position of ts where it
        matches best, on top of the full series, one subplot per kernel.

        Parameters
        ----------
        ts : float[:]
            The time series
        feature_importance : float[:]
            The importance of each feature in the transformed data
        limit : int, default = 5
            The maximum number of features to plot

        Returns
        -------
        fig : matplotlib.figure.Figure
            The figure holding one subplot per plotted feature.

        Raises
        ------
        ValueError
            If there is no feature to plot (``limit`` < 1 or no kernel).
        """
        import matplotlib.pyplot as plt

        ts = np.asarray(ts)

        features = zip(self._transformer._kernel_orig, feature_importance)
        sorted_features = sorted(features, key=itemgetter(1), reverse=True)

        max_ = min(limit, len(sorted_features))
        if max_ < 1:
            raise ValueError(
                "Nothing to plot: `limit` must be >= 1 and at least one "
                "feature must be available."
            )

        # squeeze=False keeps ``axes`` two-dimensional even for a single
        # subplot, so indexing below works for max_ == 1 as well.
        fig, axes = plt.subplots(
            1,
            max_,
            sharey=True,
            figsize=(3 * max_, 3),
            tight_layout=True,
            squeeze=False,
        )

        for f in range(max_):
            kernel, _ = sorted_features[f]
            start_pos = self._best_match_start(ts, kernel)
            ax = axes[0, f]
            ax.plot(range(start_pos, start_pos + kernel.size), kernel, linewidth=5)
            ax.plot(range(ts.size), ts, linewidth=2)
            ax.set_title(f"feature: {f+1}")

        return fig
6 changes: 6 additions & 0 deletions aeon/transformations/collection/subsequence_based/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""SAST transformers."""
# Public API of this package: only the names listed here are exported by a
# star-import.
__all__ = [
"SAST",
]

from ._sast import SAST
Loading