Skip to content

Commit

Permalink
[ENH] ADD SAST transformer and SASTClassifier (#958)
Browse files Browse the repository at this point in the history
* ADD SAST transformer and SASTClassifier

- SAST: a subsequence based transformation for time series
- SASTClassifier: a helper classifier to build a time series classifier based on the SAST transformation

* add myself in all-contributors

* update doc with SAST

* matplotlib as soft dependency

* fix SAST argument in example

* fix fit return in SAST

* fix fit return in SAST

* `random_state` renamed to `seed`

* fix attribute error in predict_proba

* rename private variables

* check if classifier has predict_proba in SASTClassifier

* fix arg error in demo notebook

* fix arg error in demo notebook

* fix attribute error in SASTClassifier

* passed unit tests

* move SAST into the shapelet based category
  • Loading branch information
frankl1 authored Dec 19, 2023
1 parent 8d3daa7 commit 8e88d41
Show file tree
Hide file tree
Showing 9 changed files with 973 additions and 12 deletions.
11 changes: 11 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,17 @@
"contributions": [
"doc"
]
},
{
"login": "frankl1",
"name": "Michael F. Mbouopda",
"avatar_url": "https://avatars.githubusercontent.com/u/23366578?s=96&v=4",
"profile": "https://github.com/frankl1",
"contributions": [
"code",
"bug",
"doc"
]
}
],
"commitType": "docs"
Expand Down
8 changes: 7 additions & 1 deletion aeon/classification/shapelet_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
"""Shapelet based time series classifiers."""

# NOTE(review): `__all__` is assigned twice here; the second (multi-line)
# assignment overrides the first and additionally exports "SASTClassifier".
# This looks like the removed and added sides of a diff rendered together --
# confirm that only the second assignment exists in the real file.
__all__ = ["MrSQMClassifier", "ShapeletTransformClassifier", "RDSTClassifier"]
__all__ = [
"MrSQMClassifier",
"ShapeletTransformClassifier",
"RDSTClassifier",
"SASTClassifier",
]

from aeon.classification.shapelet_based._mrsqm import MrSQMClassifier
from aeon.classification.shapelet_based._rdst import RDSTClassifier
from aeon.classification.shapelet_based._sast_classifier import SASTClassifier
from aeon.classification.shapelet_based._stc import ShapeletTransformClassifier
203 changes: 203 additions & 0 deletions aeon/classification/shapelet_based/_sast_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Scalable and Accurate Subsequence Transform (SAST).
Pipeline classifier using the SAST transformer and an sklearn classifier.
"""

__author__ = ["MichaelMbouopda"]
__all__ = ["SASTClassifier"]

from operator import itemgetter

import numpy as np
from sklearn.linear_model import RidgeClassifierCV
from sklearn.pipeline import make_pipeline

from aeon.base._base import _clone_estimator
from aeon.classification import BaseClassifier
from aeon.transformations.collection.shapelet_based import SAST
from aeon.utils.numba.general import z_normalise_series


class SASTClassifier(BaseClassifier):
    """Classification pipeline using SAST [1]_ transformer and an sklearn classifier.

    Parameters
    ----------
    length_list : int[], default = None
        An array containing the lengths of the subsequences to be generated.
        If None, will be inferred during fit as np.arange(3, X.shape[1]).
    stride : int, default = 1
        The stride used when generating subsequences.
    nb_inst_per_class : int, default = 1
        The number of reference time series to select per class.
    seed : int, default = None
        The seed of the random generator.
    classifier : sklearn compatible classifier, default = None
        If None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)) is used.
    n_jobs : int, default = -1
        Number of threads to use for the transform.

    References
    ----------
    .. [1] Mbouopda, Michael Franklin, and Engelbert Mephu Nguifo.
        "Scalable and accurate subsequence transform for time series classification."
        Pattern Recognition 147 (2023): 110121.
        https://www.sciencedirect.com/science/article/abs/pii/S003132032300818X,
        https://uca.hal.science/hal-03087686/document

    Examples
    --------
    >>> from aeon.classification.shapelet_based import SASTClassifier
    >>> from aeon.datasets import load_unit_test
    >>> X_train, y_train = load_unit_test(split="train")
    >>> X_test, y_test = load_unit_test(split="test")
    >>> clf = SASTClassifier()
    >>> clf.fit(X_train, y_train)
    SASTClassifier(...)
    >>> y_pred = clf.predict(X_test)
    """

    _tags = {
        "capability:multithreading": True,
        "capability:multivariate": False,
        "algorithm_type": "subsequence",
    }

    def __init__(
        self,
        length_list=None,
        stride=1,
        nb_inst_per_class=1,
        seed=None,
        classifier=None,
        n_jobs=-1,
    ):
        super().__init__()
        self.length_list = length_list
        self.stride = stride
        self.nb_inst_per_class = nb_inst_per_class
        self.n_jobs = n_jobs
        self.seed = seed

        self.classifier = classifier

    def _fit(self, X, y):
        """Fit SASTClassifier to the training data.

        Parameters
        ----------
        X : np.ndarray shape (n_time_series, n_channels, n_timepoints)
            The training input samples.
        y : array-like or list
            The class values for X.

        Returns
        -------
        self : SASTClassifier
            This pipeline classifier.
        """
        # Arguments are passed positionally, so their order must match the
        # SAST constructor signature (not visible here -- verify when editing).
        self._transformer = SAST(
            self.length_list,
            self.stride,
            self.nb_inst_per_class,
            self.seed,
            self.n_jobs,
        )

        # Clone so that the estimator supplied by the user is never fitted
        # in place.
        base_classifier = (
            RidgeClassifierCV(alphas=np.logspace(-3, 3, 10))
            if self.classifier is None
            else self.classifier
        )
        self._classifier = _clone_estimator(base_classifier, self.seed)

        self._pipeline = make_pipeline(self._transformer, self._classifier)
        self._pipeline.fit(X, y)

        return self

    def _predict(self, X):
        """Predict labels for the input.

        Parameters
        ----------
        X : np.ndarray shape (n_time_series, n_channels, n_timepoints)
            The input samples to classify.

        Returns
        -------
        array-like or list
            Predicted class labels.
        """
        return self._pipeline.predict(X)

    def _predict_proba(self, X):
        """Predict label probabilities for the input.

        Parameters
        ----------
        X : np.ndarray shape (n_time_series, n_channels, n_timepoints)
            The input samples to classify.

        Returns
        -------
        dists : np.ndarray shape (n_time_series, n_classes)
            Predicted class probabilities. When the wrapped classifier has no
            ``predict_proba``, each row is a one-hot encoding of the predicted
            label.
        """
        if callable(getattr(self._classifier, "predict_proba", None)):
            return self._pipeline.predict_proba(X)

        # Fallback: put all the probability mass on the predicted class.
        preds = np.asarray(self._pipeline.predict(X))
        classes = np.asarray(self.classes_)
        return (preds[:, np.newaxis] == classes[np.newaxis, :]).astype(np.float64)

    def plot_most_important_feature_on_ts(self, ts, feature_importance, limit=5):
        """Plot the most important features on ts.

        Each selected kernel is drawn at the position of ``ts`` where the
        z-normalised squared Euclidean distance to the kernel is smallest.

        Parameters
        ----------
        ts : float[:]
            The time series (assumed to be a 1D numpy array: ``ts.size`` and
            slicing are used).
        feature_importance : float[:]
            The importance of each feature in the transformed data.
        limit : int, default = 5
            The maximum number of features to plot.

        Returns
        -------
        fig : plt.figure
            The figure.

        Raises
        ------
        ValueError
            If one of the kernels to plot is longer than ``ts``.
        """
        import matplotlib.pyplot as plt

        ranked = sorted(
            zip(self._transformer._kernel_orig, feature_importance),
            key=itemgetter(1),
            reverse=True,
        )
        n_plots = min(limit, len(ranked))

        # Validate before creating the figure so no half-built figure is left
        # behind when a kernel cannot be aligned on ts.
        for kernel, _ in ranked[: max(n_plots, 0)]:
            if kernel.size > ts.size:
                raise ValueError(
                    f"Cannot plot a feature of length {kernel.size} on a time "
                    f"series of length {ts.size}."
                )

        # squeeze=False keeps `axes` two-dimensional, so indexing also works
        # when a single feature is plotted.
        fig, axes = plt.subplots(
            1,
            n_plots,
            sharey=True,
            figsize=(3 * n_plots, 3),
            tight_layout=True,
            squeeze=False,
        )

        for f in range(n_plots):
            kernel, _ = ranked[f]
            znorm_kernel = z_normalise_series(kernel)
            d_best = np.inf
            start_pos = 0
            # +1 so the last window (ending on the final point of ts) is
            # also considered.
            for i in range(ts.size - kernel.size + 1):
                window = z_normalise_series(ts[i : i + kernel.size])
                d = np.sum((window - znorm_kernel) ** 2)
                if d < d_best:
                    d_best = d
                    start_pos = i
            ax = axes[0, f]
            ax.plot(range(start_pos, start_pos + kernel.size), kernel, linewidth=5)
            ax.plot(range(ts.size), ts, linewidth=2)
            ax.set_title(f"feature: {f+1}")

        return fig
10 changes: 4 additions & 6 deletions aeon/transformations/collection/convolution_based/_rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
__author__ = ["angus924"]
__all__ = ["Rocket"]

import multiprocessing

import numpy as np
from numba import get_num_threads, njit, prange, set_num_threads

from aeon.transformations.collection import BaseCollectionTransformer
from aeon.utils.validation import check_n_jobs


class Rocket(BaseCollectionTransformer):
Expand Down Expand Up @@ -119,10 +118,9 @@ def _transform(self, X, y=None):
X.std(axis=-1, keepdims=True) + 1e-8
)
prev_threads = get_num_threads()
if self.n_jobs < 1 or self.n_jobs > multiprocessing.cpu_count():
n_jobs = multiprocessing.cpu_count()
else:
n_jobs = self.n_jobs

n_jobs = check_n_jobs(self.n_jobs)

set_num_threads(n_jobs)
X_ = _apply_kernels(X.astype(np.float32), self.kernels)
set_num_threads(prev_threads)
Expand Down
6 changes: 2 additions & 4 deletions aeon/transformations/collection/shapelet_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Shapelet based transformers."""

# NOTE(review): `__all__` is assigned twice here; the second (single-line)
# assignment overrides the first and additionally exports "SAST". This looks
# like the removed and added sides of a diff rendered together -- confirm
# that only the second assignment exists in the real file.
__all__ = [
"RandomShapeletTransform",
"RandomDilatedShapeletTransform",
]
__all__ = ["RandomShapeletTransform", "RandomDilatedShapeletTransform", "SAST"]

from aeon.transformations.collection.shapelet_based._dilated_shapelet_transform import (
RandomDilatedShapeletTransform,
)
from aeon.transformations.collection.shapelet_based._sast import SAST
from aeon.transformations.collection.shapelet_based._shapelet_transform import (
RandomShapeletTransform,
)
Loading

0 comments on commit 8e88d41

Please sign in to comment.