* Create noplot_nch_study.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove unused imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Rename noplot_nch_study.py to noplot_nch_study.py
* Create noplot_ablation_study_nch.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update noplot_nch_study.py
* Update noplot_ablation_study_nch.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update noplot_ablation_study_nch.py: fix inverted condition (min/random) for display
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update noplot_ablation_study_nch.py: fix key name for pipelines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update noplot_nch_study.py: remove dead code
* Update noplot_ablation_study_nch.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update noplot_nch_study.py: change axis limits
* Update noplot_ablation_study_nch.py: correct keyname for randomhull
* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 65dc2b0
Commit: d5eb32a

Showing 2 changed files with 393 additions and 0 deletions.
noplot_ablation_study_nch.py
@@ -0,0 +1,162 @@
""" | ||
==================================================================== | ||
Ablation study for the NCH | ||
==================================================================== | ||
This example is an ablation study of the NCH. | ||
Two subsampling strategies (min and random) are benchmarked, | ||
varying the number of hull and samples. | ||
We used the dataset bi2012 for this study. | ||
""" | ||
# Author: Gregoire Cattan, Quentin Barthelemy | ||
# License: BSD (3-clause) | ||

import random
import warnings

import numpy as np
import qiskit_algorithms
import seaborn as sns
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
from moabb import set_log_level
from moabb.datasets import bi2012
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300
from pyriemann.estimation import XdawnCovariances
from sklearn.pipeline import make_pipeline

from pyriemann_qiskit.classification import QuanticNCH

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################
# We have to do this because the classes are called 'Target' and 'NonTarget'
# but the evaluation function uses a LabelEncoder, transforming them
# to 0 and 1
labels_dict = {"Target": 1, "NonTarget": 0}

events = ["on", "off"]
paradigm = P300()

datasets = [bi2012()]

for dataset in datasets:
    dataset.subject_list = dataset.subject_list[0:-1]

overwrite = True  # set to True if we want to overwrite cached results

pipelines = {}

##############################################################################
# Set seed

seed = 475751

random.seed(seed)
np.random.seed(seed)
qiskit_algorithms.utils.algorithm_globals.random_seed = seed


##############################################################################
# Set NCH strategy

strategy = "random"  # "min" or "random"
pipe_name = strategy.upper()

max_hull_per_class = 1 if strategy == "min" else 6
max_samples_per_hull = 15 if strategy == "min" else 25
samples_step = 1 if strategy == "min" else 5
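# With the values above, the "random" strategy sweeps hulls 1..6 and samples
# 1, 6, 11, 16, 21 (30 pipelines), while the "min" strategy sweeps a single
# hull and samples 1..15 (15 pipelines).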

##############################################################################
# Define spatial filtering

sf = make_pipeline(XdawnCovariances())
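# The XdawnCovariances step above applies Xdawn spatial filtering and
# estimates the resulting covariance matrices, a common front-end for
# ERP/P300 data.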

##############################################################################
# Define pipelines

for n_hulls_per_class in range(1, max_hull_per_class + 1, 1):
    for n_samples_per_hull in range(1, max_samples_per_hull + 1, samples_step):
        key = f"NCH+{pipe_name}_HULL_{n_hulls_per_class}h_{n_samples_per_hull}samples"
        print(key)
        pipelines[key] = make_pipeline(
            sf,
            QuanticNCH(
                seed=seed,
                n_hulls_per_class=n_hulls_per_class,
                n_samples_per_hull=n_samples_per_hull,
                n_jobs=12,
                subsampling=strategy,
                quantum=False,
            ),
        )


print("Total pipelines to evaluate: ", len(pipelines))
print(np.unique(list(pipelines.keys())))

##############################################################################
# Run evaluation

evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    suffix="examples",
    overwrite=overwrite,
    n_splits=None,
    random_state=seed,
)


results = evaluation.process(pipelines)

##############################################################################
# Print results


def get_hull(v):
    return int(v.split(f"NCH+{pipe_name}_HULL_")[1].split("h_")[0])


def get_samples(v):
    return int(v.split("h_")[1].split("samples")[0])
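# For illustration (hypothetical key): get_hull("NCH+RANDOM_HULL_3h_10samples")
# returns 3 and get_samples("NCH+RANDOM_HULL_3h_10samples") returns 10.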


results["n_hull"] = results["pipeline"].apply(get_hull)
results["n_samples"] = results["pipeline"].apply(get_samples)
print(results)

means = results.groupby("pipeline").mean()

if strategy == "random":
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    ax.plot_trisurf(
        means.n_hull, means.n_samples, means.score, cmap=cm.jet, linewidth=0.2
    )
    ax.set_xlabel("n_hull")
    ax.set_ylabel("n_samples")
    ax.set_zlabel("score")
else:
    sns.pointplot(means, x="n_samples", y="score")
    # display x-axis tick labels as integers
    plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x, _: int(x)))

plt.show()
noplot_nch_study.py
@@ -0,0 +1,231 @@
""" | ||
==================================================================== | ||
Classification of P300 datasets from MOABB using NCH | ||
==================================================================== | ||
Comparison of NCH with different optimization methods, | ||
in a "hard" dataset (classical methods don't provide results). | ||
""" | ||
# Author: Gregoire Cattan, Quentin Barthelemy | ||
# Modified from noplot_classify_P300_nch.py | ||
# License: BSD (3-clause) | ||

import random
import warnings

import numpy as np
import qiskit_algorithms
import seaborn as sns
from matplotlib import pyplot as plt
from moabb import set_log_level
from moabb.datasets import Cattan2019_PHMD
from moabb.evaluations import CrossSubjectEvaluation
from moabb.paradigms import RestingStateToP300Adapter
from pyriemann.classification import MDM
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from qiskit_algorithms.optimizers import SPSA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from pyriemann_qiskit.classification import QuanticNCH
from pyriemann_qiskit.utils.hyper_params_factory import create_mixer_rotational_X_gates

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Set global seed for better reproducibility
seed = 475751

random.seed(seed)
np.random.seed(seed)
qiskit_algorithms.utils.algorithm_globals.random_seed = seed

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

events = ["on", "off"]
paradigm = RestingStateToP300Adapter(events=events)

datasets = [Cattan2019_PHMD()]

overwrite = True  # set to True if we want to overwrite cached results

pipelines = {}

n_hulls_per_class = 3
n_samples_per_hull = 6

sf = make_pipeline(
    Covariances(estimator="lwf"),
)
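# Here the front-end is simply covariance estimation with Ledoit-Wolf
# shrinkage ("lwf"), unlike the ablation script which uses XdawnCovariances.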

##############################################################################
# NCH without quantum optimization
pipelines["NCH+RANDOM_HULL"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_hulls_per_class=n_hulls_per_class,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="random",
        quantum=False,
    ),
)

pipelines["NCH+MIN_HULL"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="min",
        quantum=False,
    ),
)
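# Note: n_hulls_per_class is not set for the "min" pipelines, as the "min"
# subsampling appears to build a single hull per class (see the ablation
# script, where max_hull_per_class is 1 for the "min" strategy).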


##############################################################################
# NCH with quantum optimization
pipelines["NCH+RANDOM_HULL_QAOACV"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_hulls_per_class=n_hulls_per_class,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="random",
        quantum=True,
        create_mixer=create_mixer_rotational_X_gates(0),
        shots=100,
        qaoa_optimizer=SPSA(maxiter=100, blocking=False),
        n_reps=2,
    ),
)

pipelines["NCH+RANDOM_HULL_NAIVEQAOA"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_hulls_per_class=n_hulls_per_class,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="random",
        quantum=True,
    ),
)

pipelines["NCH+MIN_HULL_QAOACV"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="min",
        quantum=True,
        create_mixer=create_mixer_rotational_X_gates(0),
        shots=100,
        qaoa_optimizer=SPSA(maxiter=100, blocking=False),
        n_reps=2,
    ),
)

pipelines["NCH+MIN_HULL_NAIVEQAOA"] = make_pipeline(
    sf,
    QuanticNCH(
        seed=seed,
        n_samples_per_hull=n_samples_per_hull,
        n_jobs=12,
        subsampling="min",
        quantum=True,
    ),
)
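# In this example, the *_QAOACV pipelines pass an explicit mixer, shot count,
# SPSA optimizer and number of repetitions to the quantum solver, while the
# *_NAIVEQAOA pipelines only set quantum=True and rely on QuanticNCH defaults.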

##############################################################################
# SOTA classical methods for comparison
pipelines["MDM"] = make_pipeline(
    sf,
    MDM(),
)

pipelines["TS+LDA"] = make_pipeline(
    sf,
    TangentSpace(metric="riemann"),
    LDA(),
)

print("Total pipelines to evaluate: ", len(pipelines))

evaluation = CrossSubjectEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    suffix="examples",
    overwrite=overwrite,
    n_splits=3,
    random_state=seed,
)

results = evaluation.process(pipelines)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])

##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results to compare the pipelines

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

order = [
    "NCH+RANDOM_HULL",
    "NCH+RANDOM_HULL_NAIVEQAOA",
    "NCH+RANDOM_HULL_QAOACV",
    "NCH+MIN_HULL",
    "NCH+MIN_HULL_NAIVEQAOA",
    "NCH+MIN_HULL_QAOACV",
    "TS+LDA",
    "MDM",
]

sns.stripplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    jitter=True,
    alpha=0.5,
    zorder=1,
    palette="Set1",
    order=order,
    hue_order=order,
)
sns.pointplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    palette="Set1",
    order=order,
    hue_order=order,
)

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.35, 0.7)
plt.xticks(rotation=45)
plt.show()