Skip to content

Commit

Permalink
attempt to fix regularization test
Browse files Browse the repository at this point in the history
  • Loading branch information
dengemann committed Aug 7, 2024
1 parent 06a23b1 commit 0b6a125
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 26 deletions.
8 changes: 8 additions & 0 deletions meeglet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _compute_spectral_features(data, wavelets, features, out, info,
out.cov[:, :, i_foi] = np.real(out.csd[:, :, i_foi])

if 'cov_oas' in features:
np.set_printoptions(precision=10, suppress=False, linewidth=100)

out.cov_oas[:, :, i_foi] = out.cov[:, :, i_foi]
# The following code is adapted from scikit-learn implementation of
# Oracle Approximating Shrinkage (OAS) for covariance regularization.
Expand All @@ -259,6 +261,12 @@ def _compute_spectral_features(data, wavelets, features, out, info,
shrunk_cov = (1.0 - shrinkage) * emp_cov
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
out.cov_oas[:, :, i_foi] = shrunk_cov
print(f"emp_cov: {emp_cov}")
print(f"mu: {mu}")
print(f"alpha: {alpha}")
print(f"num: {num}")
print(f"den: {den}")
print(f"shrinkage: {shrinkage}")

# coherence measures
if 'coh' in features or 'icoh' in features:
Expand Down
83 changes: 57 additions & 26 deletions nbs/api/wavelets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -66,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -175,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -322,6 +322,8 @@
" out.cov[:, :, i_foi] = np.real(out.csd[:, :, i_foi])\n",
"\n",
" if 'cov_oas' in features:\n",
" np.set_printoptions(precision=10, suppress=False, linewidth=100)\n",
"\n",
" out.cov_oas[:, :, i_foi] = out.cov[:, :, i_foi]\n",
" # The following code is adapted from scikit-learn implementation of\n",
" # Oracle Approximating Shrinkage (OAS) for covariance regularization.\n",
Expand All @@ -340,6 +342,12 @@
" shrunk_cov = (1.0 - shrinkage) * emp_cov\n",
" shrunk_cov.flat[:: n_features + 1] += shrinkage * mu\n",
" out.cov_oas[:, :, i_foi] = shrunk_cov\n",
" print(f\"emp_cov: {emp_cov}\")\n",
" print(f\"mu: {mu}\")\n",
" print(f\"alpha: {alpha}\")\n",
" print(f\"num: {num}\")\n",
" print(f\"den: {den}\")\n",
" print(f\"shrinkage: {shrinkage}\")\n",
"\n",
" # coherence measures\n",
" if 'coh' in features or 'icoh' in features:\n",
Expand Down Expand Up @@ -440,7 +448,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -594,7 +602,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -632,7 +640,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -666,7 +674,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -675,7 +683,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -686,7 +694,7 @@
" [ 1.03626232, -1.1151485 , -0.54523585]])"
]
},
"execution_count": null,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -721,7 +729,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -742,7 +750,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -771,7 +779,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -829,7 +837,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -868,7 +876,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -900,7 +908,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -926,7 +934,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -942,7 +950,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1069,7 +1077,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1195,7 +1203,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1265,7 +1273,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [
{
Expand All @@ -1280,7 +1288,19 @@
" Average EEG reference (1 x 60) idle\n",
" Range : 12900 ... 18906 = 42.956 ... 62.955 secs\n",
"Ready.\n",
"Reading 0 ... 6006 = 0.000 ... 20.000 secs...\n"
"Reading 0 ... 6006 = 0.000 ... 20.000 secs...\n",
"emp_cov: [[169.1829331299 165.5540268772 170.7685469001 ... 41.7958519615 44.5094084744 40.9753816606]\n",
" [165.5540268772 171.3411701362 172.4168265474 ... 42.5657271299 45.1879297352 41.5968435755]\n",
" [170.7685469001 172.4168265474 183.4567440404 ... 43.6711276469 46.2338549713 42.751764722 ]\n",
" ...\n",
" [ 41.7958519615 42.5657271299 43.6711276469 ... 22.4542800943 19.0473951014 16.4332660708]\n",
" [ 44.5094084744 45.1879297352 46.2338549713 ... 19.0473951014 26.4031441727 17.6140693698]\n",
" [ 40.9753816606 41.5968435755 42.751764722 ... 16.4332660708 17.6140693698 19.9843687476]]\n",
"mu: 46.791718525647354\n",
"alpha: 1871.2825531727729\n",
"num: 4060.747475756183\n",
"den: 42185.978498915836\n",
"shrinkage: 0.09625822655412729\n"
]
}
],
Expand All @@ -1297,7 +1317,6 @@
"\n",
"def test_regularized_covariance():\n",
" \"Test spectral features array-interface against Matlab implementation.\"\n",
" matlab_results = get_matlab_results()\n",
"\n",
" raw = read_testing_data()\n",
" dat = raw.get_data() * 1e6\n",
Expand All @@ -1308,7 +1327,7 @@
" sfreq=raw.info['sfreq'],\n",
" bw_oct=0.5,\n",
" foi_start=2,\n",
" foi_end=32,\n",
" foi_end=2,\n",
" window_shift=0.25,\n",
" kernel_width=5,\n",
" allow_fraction_nan=0,\n",
Expand All @@ -1332,11 +1351,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
" #| hide\n",
"#| hide\n",
"from nbdev.doclinks import nbdev_export\n",
"nbdev_export()"
]
Expand All @@ -1347,6 +1366,18 @@
"display_name": "meeglet310",
"language": "python",
"name": "meeglet310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 0b6a125

Please sign in to comment.