Skip to content

Commit

Permalink
Add returns_A and returns_B parameters (#50)
Browse files Browse the repository at this point in the history
* Update get_covmat.py

* Update get_covmat.py

fix static method

* Update get_covmat.py

fix instantiation order

* Update README.md

* Update get_covmat.py

invert returns_A/B and seed

* Update test_get_covmat.py

update tests

* Update README.md

correct example 2

* 🎨 Format Python code with psf/black (#51)

Co-authored-by: gcattan <[email protected]>

* Update test_get_covmat.py

fix test_seed

* Update test_get_covmat.py

fix test returns A/B

* Update get_covmat.py

* Update get_covmat.py

fix all events returned even when no specified

* 🎨 Format Python code with psf/black (#52)

Co-authored-by: gcattan <[email protected]>

* Update get_covmat.py

parenthesis missing

* Update get_covmat.py

debug

* Update get_covmat.py

* Update get_covmat.py

fix wrong id used

* 🎨 Format Python code with psf/black (#53)

Co-authored-by: gcattan <[email protected]>

* Update get_covmat.py

* Update get_covmat.py

* Update test_get_covmat.py

fix instance not cleaned

* 🎨 Format Python code with psf/black (#54)

Co-authored-by: gcattan <[email protected]>

* Update test_get_covmat.py

* 🎨 Format Python code with psf/black (#55)

Co-authored-by: gcattan <[email protected]>

* Update test_get_covmat.py

missing import

* Update get_covmat.py

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: gcattan <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2024
1 parent dc38bd0 commit ee15bf8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python setup.py develop

## Usage

### Example 1
```
from covmatest import get_covmat
n_matrices = 3
Expand All @@ -28,6 +29,14 @@ covmat = get_covmat(n_matrices, n_channels)
print(covmat)
```

### Example 2
```
from covmatest import get_covmat
n_matrices, n_channels = 3, 2
classA = get_covmat(n_matrices, n_channels, returns_A=True, returns_B=False)
classB = get_covmat(n_matrices, n_channels, returns_A=False, returns_B=True)
```

## Environment

- Ubuntu, Windows, MacOs
Expand Down
30 changes: 26 additions & 4 deletions covmatest/get_covmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_instance = None


def get_covmat(n_trials, n_channels, seed=None):
def get_covmat(n_trials, n_channels, returns_A=True, returns_B=True, seed=None):
"""Get a set of covariance matrices.
Parameters
Expand All @@ -34,6 +34,10 @@ def get_covmat(n_trials, n_channels, seed=None):
n_channels: int
The number of channels (>= 1 and <= 16)
in a matrix.
returns_A: boolean (default: True)
Return the "closed" epochs from the Alphawaves dataset.
returns_B: boolean (default: True)
Return the "open" epochs from the Alphawaves dataset.
seed: int|None (default: None)
The seed for the random number generator.
Expand All @@ -44,7 +48,7 @@ def get_covmat(n_trials, n_channels, seed=None):
"""
global _instance
if _instance is None:
_instance = CovmatGen(seed)
_instance = CovmatGen(returns_A, returns_B, seed)
elif seed is not None:
random.seed(seed)
return _instance.get_covmat(n_trials, n_channels)
Expand All @@ -56,6 +60,10 @@ class CovmatGen:
Parameters
----------
returns_A: boolean (default: True)
Return the "closed" epochs from the Alphawaves dataset.
returns_B: boolean (default: True)
Return the "open" epochs from the Alphawaves dataset.
seed: int|None (default: None)
The seed for the random number generator.
Expand All @@ -71,9 +79,11 @@ class CovmatGen:
"""

def __init__(self, seed=None):
def __init__(self, returns_A=True, returns_B=True, seed=None):
if seed is not None:
random.seed(seed)
self._returns_A = returns_A
self._returns_B = returns_B
self._seed = seed
self._dataset = AlphaWaves()
subject = self._get_random_subject()
Expand All @@ -92,7 +102,19 @@ def _get_random_subject(self):

def _get_trials(self):
events = mne.find_events(raw=self._raw, shortest_event=1, verbose=False)
event_id = {"closed": 1, "open": 2}

events = [
e
for e in events
if (e[2] == 1 and self._returns_A) or (e[2] == 2 and self._returns_B)
]

event_id = {}
if self._returns_A:
event_id["closed"] = 1
if self._returns_B:
event_id["open"] = 2

epochs = mne.Epochs(
self._raw,
events,
Expand Down
17 changes: 14 additions & 3 deletions tests/test_get_covmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,19 @@ def test_is_spd(is_spd):


def test_seed():
covmat1 = get_covmat(1, 1, 42)
covmat1bis = get_covmat(1, 1, 42)
covmat2 = get_covmat(1, 1, 43)
covmat1 = get_covmat(1, 1, seed=42)
covmat1bis = get_covmat(1, 1, seed=42)
covmat2 = get_covmat(1, 1, seed=43)
assert not covmat1[0][0] == covmat2[0][0]
assert covmat1[0][0] == covmat1bis[0][0]


def test_returns_A_B():
n_matrices, n_channels = 1, 1
classA = CovmatGen(returns_A=True, returns_B=False).get_covmat(
n_matrices, n_channels
)
classB = CovmatGen(returns_A=False, returns_B=True).get_covmat(
n_matrices, n_channels
)
assert not classA[0][0] == classB[0][0]

0 comments on commit ee15bf8

Please sign in to comment.