From ee15bf881bb24a6516a292d69e0828c6aa16fa4c Mon Sep 17 00:00:00 2001 From: gcattan Date: Mon, 22 Jan 2024 23:06:44 +0100 Subject: [PATCH] Add returns_A and returns_B parameters (#50) * Update get_covmat.py * Update get_covmat.py fix static method * Update get_covmat.py fix instantiation order * Update README.md * Update get_covmat.py invert returns_A/B and seed * Update test_get_covmat.py update tests * Update README.md correct example 2 * :art: Format Python code with psf/black (#51) Co-authored-by: gcattan * Update test_get_covmat.py fix test_seed * Update test_get_covmat.py fix test returns A/B * Update get_covmat.py * Update get_covmat.py fix all events returned even when no specified * :art: Format Python code with psf/black (#52) Co-authored-by: gcattan * Update get_covmat.py parenthesis missing * Update get_covmat.py debug * Update get_covmat.py * Update get_covmat.py fix wrong id used * :art: Format Python code with psf/black (#53) Co-authored-by: gcattan * Update get_covmat.py * Update get_covmat.py * Update test_get_covmat.py fix instance not cleaned * :art: Format Python code with psf/black (#54) Co-authored-by: gcattan * Update test_get_covmat.py * :art: Format Python code with psf/black (#55) Co-authored-by: gcattan * Update test_get_covmat.py missing import * Update get_covmat.py --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gcattan --- README.md | 9 +++++++++ covmatest/get_covmat.py | 30 ++++++++++++++++++++++++++---- tests/test_get_covmat.py | 17 ++++++++++++++--- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index affa4b7..da7b293 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ python setup.py develop ## Usage +### Example 1 ``` from covmatest import get_covmat n_matrices = 3 @@ -28,6 +29,14 @@ covmat = get_covmat(n_matrices, n_channels) print(covmat) ``` +### Example 2 +``` +from covmatest import get_covmat +n_matrices, n_channels = 3, 2 +classA = get_covmat(n_matrices, n_channels, returns_A=True, returns_B=False) +classB = get_covmat(n_matrices, n_channels, returns_A=False, returns_B=True) +``` + ## Environment - Ubuntu, Windows, MacOs diff --git a/covmatest/get_covmat.py b/covmatest/get_covmat.py index 4263906..4b1b230 100644 --- a/covmatest/get_covmat.py +++ b/covmatest/get_covmat.py @@ -24,7 +24,7 @@ _instance = None -def get_covmat(n_trials, n_channels, seed=None): +def get_covmat(n_trials, n_channels, returns_A=True, returns_B=True, seed=None): """Get a set of covariance matrices. Parameters @@ -34,6 +34,10 @@ def get_covmat(n_trials, n_channels, seed=None): n_channels: int The number of channels (>= 1 and <= 16) in a matrix. + returns_A: boolean (default: True) + Return the "closed" epochs from the Alphawaves dataset. + returns_B: boolean (default: True) + Return the "open" epochs from the Alphawaves dataset. seed: int|None (default: None) The seed for the random number generator. @@ -44,7 +48,7 @@ def get_covmat(n_trials, n_channels, seed=None): """ global _instance if _instance is None: - _instance = CovmatGen(seed) + _instance = CovmatGen(returns_A, returns_B, seed) elif seed is not None: random.seed(seed) return _instance.get_covmat(n_trials, n_channels) @@ -56,6 +60,10 @@ class CovmatGen: Parameters ---------- + returns_A: boolean (default: True) + Return the "closed" epochs from the Alphawaves dataset. + returns_B: boolean (default: True) + Return the "open" epochs from the Alphawaves dataset. seed: int|None (default: None) The seed for the random number generator. @@ -71,9 +79,11 @@ class CovmatGen: """ - def __init__(self, seed=None): + def __init__(self, returns_A=True, returns_B=True, seed=None): if seed is not None: random.seed(seed) + self._returns_A = returns_A + self._returns_B = returns_B self._seed = seed self._dataset = AlphaWaves() subject = self._get_random_subject() @@ -92,7 +102,19 @@ def _get_random_subject(self): def _get_trials(self): events = mne.find_events(raw=self._raw, shortest_event=1, verbose=False) - event_id = {"closed": 1, "open": 2} + + events = [ + e + for e in events + if (e[2] == 1 and self._returns_A) or (e[2] == 2 and self._returns_B) + ] + + event_id = {} + if self._returns_A: + event_id["closed"] = 1 + if self._returns_B: + event_id["open"] = 2 + epochs = mne.Epochs( self._raw, events, diff --git a/tests/test_get_covmat.py b/tests/test_get_covmat.py index 3824155..bbdd356 100644 --- a/tests/test_get_covmat.py +++ b/tests/test_get_covmat.py @@ -27,8 +27,19 @@ def test_is_spd(is_spd): def test_seed(): - covmat1 = get_covmat(1, 1, 42) - covmat1bis = get_covmat(1, 1, 42) - covmat2 = get_covmat(1, 1, 43) + covmat1 = get_covmat(1, 1, seed=42) + covmat1bis = get_covmat(1, 1, seed=42) + covmat2 = get_covmat(1, 1, seed=43) assert not covmat1[0][0] == covmat2[0][0] assert covmat1[0][0] == covmat1bis[0][0] + + +def test_returns_A_B(): + n_matrices, n_channels = 1, 1 + classA = CovmatGen(returns_A=True, returns_B=False).get_covmat( + n_matrices, n_channels + ) + classB = CovmatGen(returns_A=False, returns_B=True).get_covmat( + n_matrices, n_channels + ) + assert not classA[0][0] == classB[0][0]