Implementation of NCH (Nearest Convex Hull) classifier (#253)
* Initial version of NearestConvexHull.

* Added script for testing.

* First version that runs.

* Improved code.

* Added support for parallel processing.
It gives an error: AttributeError: Pipeline has none of the following attributes: decision_function.

* renamed

* New version that uses a new class that implements a NCH classifier.

* small update

* Updated to the newest code: the new version of the distance function.
Added an example that runs on a small number of test samples, so that we can get results more quickly.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* reinforce constraint on weights

* - remove constraints on weights
- limit size of training set
- change to SLSQP optimizer
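
The optimizer bullets above concern the core NCH computation: the distance from a test matrix to the convex hull of a class's training matrices, obtained by optimizing the hull weights. Below is a minimal Euclidean sketch using SciPy's SLSQP; the function name `hull_distance` and the Euclidean metric are illustrative assumptions, while the actual classifier works on SPD matrices with a Riemannian distance.

```python
import numpy as np
from scipy.optimize import minimize


def hull_distance(x, A):
    """Distance from point x to the convex hull of the rows of A.

    Euclidean stand-in for the NCH distance: minimize ||x - sum_i w_i A_i||
    subject to w_i >= 0 and sum_i w_i = 1, using SLSQP.
    """
    k = A.shape[0]
    w0 = np.full(k, 1.0 / k)  # start from the barycenter weights

    def sq_dist(w):  # squared norm keeps the objective smooth at the optimum
        diff = x - w @ A
        return diff @ diff

    res = minimize(
        sq_dist,
        w0,
        method="SLSQP",
        bounds=[(0.0, 1.0)] * k,  # w_i >= 0 (and trivially <= 1)
        constraints=[{"type": "eq", "fun": lambda w: w.sum() - 1.0}],  # sum to 1
    )
    return np.sqrt(max(res.fun, 0.0))


rng = np.random.default_rng(42)
A = rng.normal(size=(5, 8))              # 5 "class samples" in R^8
print(hull_distance(A.mean(axis=0), A))  # ~0: the barycenter lies inside the hull
```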

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added n_max_hull parameter. MOABB support tested.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* added multiple hulls.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Code cleanups.
Added a second parameter that specifies the number of hulls.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Improved code.
Added support for transform().
Added a new pipeline [NCH+LDA]
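
The transform() support mentioned above is what makes the [NCH+LDA] pipeline possible: the distances to the class hulls become features for a downstream classifier. A toy sketch of that wiring, where a hypothetical `ClassDistanceTransformer` stands in for NCH and uses distance to the class mean instead of distance to the class convex hull:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline


class ClassDistanceTransformer(BaseEstimator, TransformerMixin):
    """Stand-in for NCH's transform(): outputs one distance per class.

    Uses distance to the class mean rather than to the class convex hull,
    purely to illustrate how the transform feeds LDA.
    """

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def transform(self, X):
        # shape (n_samples, n_classes): distance to each class "hull"
        return np.linalg.norm(X[:, None, :] - self.means_[None, :, :], axis=2)


rng = np.random.default_rng(7)
X = np.vstack([rng.normal(0, 1, (20, 4)), rng.normal(3, 1, (20, 4))])
y = np.array([0] * 20 + [1] * 20)

pipe = make_pipeline(ClassDistanceTransformer(), LDA())
pipe.fit(X, y)
print(pipe.score(X, y))  # near-perfect on this easy toy problem
```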

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* updated default parameters

* General improvements.
Improvements requested by GC.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* removed commented code

* Small adjustments.

* Better class separation.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added support for n_samples_per_hull = -1 which takes all the samples for a class.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update pyriemann_qiskit/classification.py

Set of SPD matrices.

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/classification.py

Added new lines before Parameters

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/classification.py

[y == c, :, :] => [y == c]

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/classification.py

NearestConvexHull text change

Co-authored-by: Quentin Barthélemy <[email protected]>

* Improvements proposed by Quentin.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added comment for the optimizer.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added some comments in classification.
Changed the global optimizer code so that it is more evident that a global optimizer is used.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Implemented min hull.
Added support for both "min-hull" and "random-hull" via the constructor parameter "hull_type".
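
The two strategies can be pictured with a small, hypothetical helper (the name and signature are illustrative, not pyriemann_qiskit's API): "random-hull" draws a hull's samples at random, while "min-hull" keeps the samples closest to the test point; `n_samples == -1` keeps the whole class, matching the n_samples_per_hull = -1 behavior mentioned earlier.

```python
import numpy as np


def select_hull_samples(X_class, x_test, n_samples, hull_type, dist, rng):
    """Pick the samples that span one hull for a class (illustrative only).

    'random-hull' draws n_samples at random; 'min-hull' keeps the
    n_samples closest to x_test under dist; n_samples == -1 uses all.
    """
    if n_samples == -1 or n_samples >= len(X_class):
        return X_class
    if hull_type == "random-hull":
        idx = rng.choice(len(X_class), size=n_samples, replace=False)
    elif hull_type == "min-hull":
        d = np.array([dist(x_test, x) for x in X_class])
        idx = np.argsort(d)[:n_samples]
    else:
        raise ValueError(f"unknown hull_type: {hull_type}")
    return X_class[idx]


# toy data: 10 one-dimensional "matrices" 0..9, test point at 0
rng = np.random.default_rng(0)
X = np.arange(10.0).reshape(10, 1)
sel = select_hull_samples(X, np.array([0.0]), 3, "min-hull",
                          lambda a, b: np.linalg.norm(a - b), rng)
print(sel.ravel())  # -> [0. 1. 2.]
```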

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Reverted to previous version as requested by Gregoire.

* fix lint issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: gcattan <[email protected]>
Co-authored-by: Gregoire Cattan <[email protected]>
Co-authored-by: Quentin Barthélemy <[email protected]>
5 people authored Mar 18, 2024
1 parent a89478b commit a0241ce
Showing 3 changed files with 484 additions and 2 deletions.
150 changes: 150 additions & 0 deletions examples/ERP/classify_P300_nch.py
@@ -0,0 +1,150 @@
"""
====================================================================
Classification of P300 datasets from MOABB using NCH
====================================================================
Demonstrates classification with QuanticNCH.
Evaluation is done using MOABB.
If the parameter "shots" is None, then a classical SVM similar to the one in
scikit-learn is used.
If "shots" is not None and an IBM Quantum token is provided with
"q_account_token", then a real quantum computer will be used.
You also need to adjust the "n_components" in the PCA procedure to the number
of qubits supported by the real quantum computer you are going to use.
A list of real quantum computers is available in your IBM quantum account.
"""
# Author: Anton Andreev
# Modified from plot_classify_EEG_tangentspace.py of pyRiemann
# License: BSD (3-clause)

from pyriemann.estimation import XdawnCovariances
from sklearn.pipeline import make_pipeline
from matplotlib import pyplot as plt
import warnings
import seaborn as sns
from moabb import set_log_level
from moabb.datasets import bi2013a
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300
from pyriemann_qiskit.classification import QuanticNCH
from pyriemann.classification import MDM

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

##############################################################################
# We have to do this because the classes are called 'Target' and 'NonTarget'
# but the evaluation function uses a LabelEncoder, transforming them
# to 0 and 1
labels_dict = {"Target": 1, "NonTarget": 0}

paradigm = P300(resample=128)

datasets = [bi2013a()] # MOABB provides several other P300 datasets

# reduce the number of subjects, the Quantum pipeline takes a lot of time
# if executed on the entire dataset
n_subjects = 1
for dataset in datasets:
    dataset.subject_list = dataset.subject_list[0:n_subjects]

overwrite = True # set to True if we want to overwrite cached results

pipelines = {}

pipelines["NCH+RANDOM_HULL"] = make_pipeline(
    # applies Xdawn and outputs the covariance matrices
    XdawnCovariances(
        nfilter=3,
        classes=[labels_dict["Target"]],
        estimator="lwf",
        xdawn_estimator="scm",
    ),
    QuanticNCH(
        n_hulls_per_class=1,
        n_samples_per_hull=3,
        n_jobs=12,
        hull_type="random-hull",
        quantum=False,
    ),
)

pipelines["NCH+MIN_HULL"] = make_pipeline(
    # applies Xdawn and outputs the covariance matrices
    XdawnCovariances(
        nfilter=3,
        classes=[labels_dict["Target"]],
        estimator="lwf",
        xdawn_estimator="scm",
    ),
    QuanticNCH(
        n_hulls_per_class=1,
        n_samples_per_hull=3,
        n_jobs=12,
        hull_type="min-hull",
        quantum=False,
    ),
)

# this is a non-quantum pipeline
pipelines["XD+MDM"] = make_pipeline(
    XdawnCovariances(
        nfilter=3,
        classes=[labels_dict["Target"]],
        estimator="lwf",
        xdawn_estimator="scm",
    ),
    MDM(),
)

print("Total pipelines to evaluate: ", len(pipelines))

evaluation = WithinSessionEvaluation(
    paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite
)

results = evaluation.process(pipelines)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])

##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results to compare the pipelines

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

sns.stripplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    jitter=True,
    alpha=0.5,
    zorder=1,
    palette="Set1",
)
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1")

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.3, 1)

plt.show()