Skip to content

Commit

Permalink
Add support for generalised Bayesian filtering with dynamic learning …
Browse files Browse the repository at this point in the history
…rate in JAX (ComputationalPsychiatry#266)

* clarify docs for math distribution module + fix error for 1d Gaussian

* ef-state node supporting hgf learning

* api docs

* notebook

* univariate gaussian working

* add math modules

* support for ef distributions

* fix error with dirichlet nodes
  • Loading branch information
LegrandNico authored and SylvainEstebe committed Jan 17, 2025
1 parent 9edfb4d commit 6f55a36
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 122 deletions.
107 changes: 16 additions & 91 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@
"id": "31b80846",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:53:59.914472Z",
"iopub.status.busy": "2025-01-10T13:53:59.913668Z",
"iopub.status.idle": "2025-01-10T13:53:59.921214Z",
"shell.execute_reply": "2025-01-10T13:53:59.919815Z"
},
"slideshow": {
"slide_type": ""
},
Expand All @@ -60,12 +54,6 @@
"id": "6e337fd3-5a3e-4e0f-ab4f-e055cebfb7ff",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:53:59.930368Z",
"iopub.status.busy": "2025-01-10T13:53:59.924990Z",
"iopub.status.idle": "2025-01-10T13:54:02.194691Z",
"shell.execute_reply": "2025-01-10T13:54:02.193859Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -171,12 +159,6 @@
"id": "f0be17ad-5611-4c89-80a2-9e45b1ddffc4",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:02.197152Z",
"iopub.status.busy": "2025-01-10T13:54:02.196883Z",
"iopub.status.idle": "2025-01-10T13:54:02.200982Z",
"shell.execute_reply": "2025-01-10T13:54:02.200133Z"
},
"slideshow": {
"slide_type": ""
},
Expand All @@ -195,12 +177,6 @@
"id": "ba318975-ce19-4b47-9934-bd02c0130e10",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:02.203128Z",
"iopub.status.busy": "2025-01-10T13:54:02.202914Z",
"iopub.status.idle": "2025-01-10T13:54:02.823804Z",
"shell.execute_reply": "2025-01-10T13:54:02.822473Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -277,12 +253,6 @@
"id": "add927a3-4233-484d-ad01-74ad03b87b5b",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:02.826901Z",
"iopub.status.busy": "2025-01-10T13:54:02.826428Z",
"iopub.status.idle": "2025-01-10T13:54:03.406849Z",
"shell.execute_reply": "2025-01-10T13:54:03.406071Z"
},
"scrolled": true,
"slideshow": {
"slide_type": ""
Expand Down Expand Up @@ -377,32 +347,30 @@
"tags": []
},
"source": [
":::{note} From sufficient statistics to distribution parameters and backwards\n",
"````{note} From sufficient statistics to distribution parameters and backwards\n",
":class: dropdown\n",
"\n",
"When using a 1-dimensional Gaussian distribution, Setting $\\xi = [0, \\frac{1}{8}]$ is equivalent to a mean $\\mu = 0.0$ and a variance $\\sigma^2 = \\frac{1}{8}$. You can convert between distribution parameters and expected sufficient statistics using the distribution classes from PyHGF (when implemented):\n",
"\n",
"```python\n",
"```{code-cell} python\n",
"from pyhgf.math import Normal\n",
"\n",
"# from an observation to sufficient statistics\n",
"Normal.sufficient_statistics_from_observations(x=1.5)\n",
"```\n",
"> Array([1.5 , 2.25], dtype=float32)\n",
"\n",
"```python\n",
"```{code-cell} python\n",
"# from distribution parameters to sufficient statistics\n",
"Normal.sufficient_statistics_from_parameters(mean=0.0, variance=4.0)\n",
"```\n",
"> Array([0., 4.], dtype=float32)\n",
"\n",
"```python\n",
"```{code-cell} python\n",
"# from sufficient statistics to distribution parameters\n",
"Normal.parameters_from_sufficient_statistics(xis=[0.0, 4.0])\n",
"```\n",
"> (0.0, 4.0)\n",
"\n",
":::"
"````"
]
},
{
Expand All @@ -423,14 +391,7 @@
"cell_type": "code",
"execution_count": 7,
"id": "1798765e-3d65-4bfd-964b-7f9b6b0902be",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-10T13:54:03.418432Z",
"iopub.status.busy": "2025-01-10T13:54:03.417963Z",
"iopub.status.idle": "2025-01-10T13:54:03.454413Z",
"shell.execute_reply": "2025-01-10T13:54:03.453224Z"
}
},
"metadata": {},
"outputs": [
{
"data": {
Expand All @@ -456,7 +417,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7fec53b15520>"
"<graphviz.sources.Source at 0x7f69845ec530>"
]
},
"execution_count": 7,
Expand All @@ -482,12 +443,6 @@
"id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:03.458489Z",
"iopub.status.busy": "2025-01-10T13:54:03.457995Z",
"iopub.status.idle": "2025-01-10T13:54:03.462879Z",
"shell.execute_reply": "2025-01-10T13:54:03.462074Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -543,12 +498,6 @@
"id": "0754380d-87d0-430e-a533-540c5f252091",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:03.672185Z",
"iopub.status.busy": "2025-01-10T13:54:03.671465Z",
"iopub.status.idle": "2025-01-10T13:54:03.946934Z",
"shell.execute_reply": "2025-01-10T13:54:03.946361Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -626,12 +575,6 @@
"id": "baf6e2fc-dc8d-46bb-896b-ff5c45a944ee",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:03.949485Z",
"iopub.status.busy": "2025-01-10T13:54:03.948882Z",
"iopub.status.idle": "2025-01-10T13:54:04.636554Z",
"shell.execute_reply": "2025-01-10T13:54:04.635183Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -722,7 +665,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7fec4dea0950>"
"<graphviz.sources.Source at 0x7f6984591040>"
]
},
"execution_count": 13,
Expand All @@ -739,12 +682,6 @@
"execution_count": 14,
"id": "2cc62622-0709-41f3-8ccd-d7024099de7a",
"metadata": {
"execution": {
"iopub.execute_input": "2025-01-10T13:54:04.639360Z",
"iopub.status.busy": "2025-01-10T13:54:04.639144Z",
"iopub.status.idle": "2025-01-10T13:54:04.674265Z",
"shell.execute_reply": "2025-01-10T13:54:04.673450Z"
},
"scrolled": true
},
"outputs": [],
Expand All @@ -771,12 +708,6 @@
"id": "45cdf250-5a7c-4dc3-b5e7-46055cb35adf",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:04.676756Z",
"iopub.status.busy": "2025-01-10T13:54:04.676453Z",
"iopub.status.idle": "2025-01-10T13:54:04.897442Z",
"shell.execute_reply": "2025-01-10T13:54:04.896519Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -848,12 +779,6 @@
"id": "781b08fc-a2c7-4856-a103-beffd7787325",
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-01-10T13:54:04.900348Z",
"iopub.status.busy": "2025-01-10T13:54:04.899801Z",
"iopub.status.idle": "2025-01-10T13:54:04.911618Z",
"shell.execute_reply": "2025-01-10T13:54:04.911004Z"
},
"slideshow": {
"slide_type": ""
},
Expand Down Expand Up @@ -1314,7 +1239,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7fec3c292780>"
"<graphviz.sources.Source at 0x7f69784b16a0>"
]
},
"execution_count": 21,
Expand Down Expand Up @@ -1367,7 +1292,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Wed Jan 15 2025\n",
"Last updated: Tue Jan 14 2025\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.3\n",
Expand All @@ -1377,13 +1302,13 @@
"jax : 0.4.31\n",
"jaxlib: 0.4.31\n",
"\n",
"matplotlib: 3.10.0\n",
"jax : 0.4.31\n",
"seaborn : 0.13.2\n",
"IPython : 8.31.0\n",
"numpy : 1.26.0\n",
"sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n",
"pyhgf : 0.2.1.post4.dev0+d49aafe9\n",
"numpy : 1.26.0\n",
"jax : 0.4.31\n",
"IPython : 8.31.0\n",
"seaborn : 0.13.2\n",
"matplotlib: 3.10.0\n",
"\n",
"Watermark: 2.5.0\n",
"\n"
Expand All @@ -1398,7 +1323,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "pyhgf_dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
32 changes: 1 addition & 31 deletions tests/test_nodes/test_exponential_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ def test_multivariate_gaussian():
+ np.random.randn(N, 2) * 2
)

# Python
# ----------------------------------------------------------------------------------

# generalised filtering
# Python ---------------------------------------------------------------------------
bivariate_normal = (
PyNetwork()
.add_nodes(
Expand All @@ -67,30 +64,3 @@ def test_multivariate_gaussian():
dtype="float32",
),
).all()

# hgf updates
bivariate_hgf = PyNetwork().add_nodes(
kind="ef-state",
learning="hgf-2",
distribution="multivariate-normal",
dimension=2,
)

# adapting prior parameter values to the sufficient statistics
# covariances statistics will have greater variability and amplitudes
for node_idx in [2, 5, 8, 11, 14]:
bivariate_hgf.attributes[node_idx]["tonic_volatility"] = -2.0
for node_idx in [1, 4, 7, 10, 13]:
bivariate_hgf.attributes[node_idx]["precision"] = 0.01
for node_idx in [9, 12, 15]:
bivariate_hgf.attributes[node_idx]["mean"] = 10.0

bivariate_hgf.input_data(input_data=spiral_data)

assert jnp.isclose(
bivariate_normal.node_trajectories[0]["xis"][-1],
jnp.array(
[3.4652710e01, -1.0609777e00, 1.2103647e03, -3.6398651e01, 3.3951855e00],
dtype="float32",
),
).all()

0 comments on commit 6f55a36

Please sign in to comment.