Skip to content

Commit

Permalink
NCH studies (#338)
Browse files Browse the repository at this point in the history
* Create noplot_nch_study.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* remove unused imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Rename noplot_nch_study.py to noplot_nch_study.py

* Create noplot_ablation_study_nch.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update noplot_nch_study.py

* Update noplot_ablation_study_nch.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update noplot_ablation_study_nch.py

fix inverted condition (min/random) for display

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update noplot_ablation_study_nch.py

fix key name for pipelines

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update noplot_nch_study.py

remove dead code

* Update noplot_ablation_study_nch.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update noplot_nch_study.py

Change axis limits

* Update noplot_ablation_study_nch.py

correct keyname for randomhull

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gcattan and pre-commit-ci[bot] authored Jan 20, 2025
1 parent 65dc2b0 commit d5eb32a
Show file tree
Hide file tree
Showing 2 changed files with 393 additions and 0 deletions.
162 changes: 162 additions & 0 deletions examples/ERP/noplot_ablation_study_nch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
====================================================================
Ablation study for the NCH
====================================================================
This example is an ablation study of the NCH.
Two subsampling strategies (min and random) are benchmarked,
varying the number of hull and samples.
We used the dataset bi2012 for this study.
"""
# Author: Gregoire Cattan, Quentin Barthelemy
# License: BSD (3-clause)

import random
import warnings

import numpy as np
import qiskit_algorithms
import seaborn as sns
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
from moabb import set_log_level
from moabb.datasets import bi2012
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300
from pyriemann.estimation import XdawnCovariances
from sklearn.pipeline import make_pipeline

from pyriemann_qiskit.classification import QuanticNCH

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################
# We have to do this because the classes are called 'Target' and 'NonTarget'
# but the evaluation function uses a LabelEncoder, transforming them
# to 0 and 1
labels_dict = {"Target": 1, "NonTarget": 0}

events = ["on", "off"]
paradigm = P300()

datasets = [bi2012()]

for dataset in datasets:
dataset.subject_list = dataset.subject_list[0:-1]

overwrite = True # set to True if we want to overwrite cached results

pipelines = {}

##############################################################################
# Set seed

seed = 475751

random.seed(seed)
np.random.seed(seed)
qiskit_algorithms.utils.algorithm_globals.random_seed


##############################################################################
# Set NCH strategy

strategy = "random" # or "random"
pipe_name = strategy.upper()

max_hull_per_class = 1 if strategy == "min" else 6
max_samples_per_hull = 15 if strategy == "min" else 25
samples_step = 1 if strategy == "min" else 5

##############################################################################
# Define spatial filtering

sf = make_pipeline(XdawnCovariances())

##############################################################################
# Define pipelines

for n_hulls_per_class in range(1, max_hull_per_class + 1, 1):
for n_samples_per_hull in range(1, max_samples_per_hull + 1, samples_step):
key = f"NCH+{pipe_name}_HULL_{n_hulls_per_class}h_{n_samples_per_hull}samples"
print(key)
pipelines[key] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_hulls_per_class=n_hulls_per_class,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling=strategy,
quantum=False,
),
)


print("Total pipelines to evaluate: ", len(pipelines))
print(np.unique(pipelines.keys()))

##############################################################################
# Run evaluation

evaluation = WithinSessionEvaluation(
paradigm=paradigm,
datasets=datasets,
suffix="examples",
overwrite=overwrite,
n_splits=None,
random_state=seed,
)


results = evaluation.process(pipelines)

##############################################################################
# Print results


def get_hull(v):
return int(v.split(f"NCH+{pipe_name}_HULL_")[1].split("h_")[0])


def get_samples(v):
return int(v.split("h_")[1].split("samples")[0])


results["n_hull"] = results["pipeline"].apply(get_hull)
results["n_samples"] = results["pipeline"].apply(get_samples)
print(results)

means = results.groupby("pipeline").mean()

if strategy == "random":
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.plot_trisurf(
means.n_hull, means.n_samples, means.score, cmap=cm.jet, linewidth=0.2
)
ax.set_xlabel("n_hull")
ax.set_ylabel("n_samples")
ax.set_zlabel("score")
else:
sns.pointplot(means, x="n_samples", y="score")
plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x, _: int(x)))

plt.show()
231 changes: 231 additions & 0 deletions examples/resting_states/noplot_nch_study.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
====================================================================
Classification of P300 datasets from MOABB using NCH
====================================================================
Comparison of NCH with different optimization methods,
in a "hard" dataset (classical methods don't provide results).
"""
# Author: Gregoire Cattan, Quentin Barthelemy
# Modified from noplot_classify_P300_nch.py
# License: BSD (3-clause)

import random
import warnings

import numpy as np
import qiskit_algorithms
import seaborn as sns
from matplotlib import pyplot as plt
from moabb import set_log_level
from moabb.datasets import Cattan2019_PHMD
from moabb.evaluations import CrossSubjectEvaluation
from moabb.paradigms import RestingStateToP300Adapter
from pyriemann.classification import MDM
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from qiskit_algorithms.optimizers import SPSA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from pyriemann_qiskit.classification import QuanticNCH
from pyriemann_qiskit.utils.hyper_params_factory import create_mixer_rotational_X_gates

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Set global seed for better reproducibility
seed = 475751

random.seed(seed)
np.random.seed(seed)
qiskit_algorithms.utils.algorithm_globals.random_seed

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

events = ["on", "off"]
paradigm = RestingStateToP300Adapter(events=events)

datasets = [Cattan2019_PHMD()]

overwrite = True # set to True if we want to overwrite cached results

pipelines = {}

n_hulls_per_class = 3
n_samples_per_hull = 6

sf = make_pipeline(
Covariances(estimator="lwf"),
)

##############################################################################
# NCH without quantum optimization
pipelines["NCH+RANDOM_HULL"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_hulls_per_class=n_hulls_per_class,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="random",
quantum=False,
),
)

pipelines["NCH+MIN_HULL"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="min",
quantum=False,
),
)


##############################################################################
# NCH with quantum optimization
pipelines["NCH+RANDOM_HULL_QAOACV"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_hulls_per_class=n_hulls_per_class,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="random",
quantum=True,
create_mixer=create_mixer_rotational_X_gates(0),
shots=100,
qaoa_optimizer=SPSA(maxiter=100, blocking=False),
n_reps=2,
),
)

pipelines["NCH+RANDOM_HULL_NAIVEQAOA"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_hulls_per_class=n_hulls_per_class,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="random",
quantum=True,
),
)

pipelines["NCH+MIN_HULL_QAOACV"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="min",
quantum=True,
create_mixer=create_mixer_rotational_X_gates(0),
shots=100,
qaoa_optimizer=SPSA(maxiter=100, blocking=False),
n_reps=2,
),
)

pipelines["NCH+MIN_HULL_NAIVEQAOA"] = make_pipeline(
sf,
QuanticNCH(
seed=seed,
n_samples_per_hull=n_samples_per_hull,
n_jobs=12,
subsampling="min",
quantum=True,
),
)

##############################################################################
# SOTA classical methods for comparison
pipelines["MDM"] = make_pipeline(
sf,
MDM(),
)

pipelines["TS+LDA"] = make_pipeline(
sf,
TangentSpace(metric="riemann"),
LDA(),
)

print("Total pipelines to evaluate: ", len(pipelines))

evaluation = CrossSubjectEvaluation(
paradigm=paradigm,
datasets=datasets,
suffix="examples",
overwrite=overwrite,
n_splits=3,
random_state=seed,
)

results = evaluation.process(pipelines)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])

##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results to compare the two pipelines

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

order = [
"NCH+RANDOM_HULL",
"NCH+RANDOM_HULL_NAIVEQAOA",
"NCH+RANDOM_HULL_QAOACV",
"NCH+MIN_HULL",
"NCH+MIN_HULL_NAIVEQAOA",
"NCH+MIN_HULL_QAOACV",
"TS+LDA",
"MDM",
]

sns.stripplot(
data=results,
y="score",
x="pipeline",
ax=ax,
jitter=True,
alpha=0.5,
zorder=1,
palette="Set1",
order=order,
hue_order=order,
)
sns.pointplot(
data=results,
y="score",
x="pipeline",
ax=ax,
palette="Set1",
order=order,
hue_order=order,
)

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.35, 0.7)
plt.xticks(rotation=45)
plt.show()

0 comments on commit d5eb32a

Please sign in to comment.