diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb index 5b63e5839..5d19a1bf0 100644 --- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb +++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb @@ -29,12 +29,6 @@ "id": "31b80846", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:53:59.914472Z", - "iopub.status.busy": "2025-01-10T13:53:59.913668Z", - "iopub.status.idle": "2025-01-10T13:53:59.921214Z", - "shell.execute_reply": "2025-01-10T13:53:59.919815Z" - }, "slideshow": { "slide_type": "" }, @@ -60,12 +54,6 @@ "id": "6e337fd3-5a3e-4e0f-ab4f-e055cebfb7ff", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:53:59.930368Z", - "iopub.status.busy": "2025-01-10T13:53:59.924990Z", - "iopub.status.idle": "2025-01-10T13:54:02.194691Z", - "shell.execute_reply": "2025-01-10T13:54:02.193859Z" - }, "slideshow": { "slide_type": "" }, @@ -171,12 +159,6 @@ "id": "f0be17ad-5611-4c89-80a2-9e45b1ddffc4", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:02.197152Z", - "iopub.status.busy": "2025-01-10T13:54:02.196883Z", - "iopub.status.idle": "2025-01-10T13:54:02.200982Z", - "shell.execute_reply": "2025-01-10T13:54:02.200133Z" - }, "slideshow": { "slide_type": "" }, @@ -195,12 +177,6 @@ "id": "ba318975-ce19-4b47-9934-bd02c0130e10", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:02.203128Z", - "iopub.status.busy": "2025-01-10T13:54:02.202914Z", - "iopub.status.idle": "2025-01-10T13:54:02.823804Z", - "shell.execute_reply": "2025-01-10T13:54:02.822473Z" - }, "slideshow": { "slide_type": "" }, @@ -277,12 +253,6 @@ "id": "add927a3-4233-484d-ad01-74ad03b87b5b", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:02.826901Z", - "iopub.status.busy": "2025-01-10T13:54:02.826428Z", - "iopub.status.idle": "2025-01-10T13:54:03.406849Z", - "shell.execute_reply": "2025-01-10T13:54:03.406071Z" - }, "scrolled": true, "slideshow": { "slide_type": "" @@ -377,32 +347,30 @@ "tags": [] }, "source": [ - ":::{note} From sufficient statistics to distribution parameters and backwards\n", + "````{note} From sufficient statistics to distribution parameters and backwards\n", ":class: dropdown\n", "\n", "When using a 1-dimensional Gaussian distribution, Setting $\\xi = [0, \\frac{1}{8}]$ is equivalent to a mean $\\mu = 0.0$ and a variance $\\sigma^2 = \\frac{1}{8}$. You can convert between distribution parameters and expected sufficient statistics using the distribution classes from PyHGF (when implemented):\n", "\n", - "```python\n", + "```{code-cell} python\n", "from pyhgf.math import Normal\n", "\n", "# from an observation to sufficient statistics\n", "Normal.sufficient_statistics_from_observations(x=1.5)\n", "```\n", "> Array([1.5 , 2.25], dtype=float32)\n", - "\n", - "```python\n", + "```{code-cell} python\n", "# from distribution parameters to sufficient statistics\n", "Normal.sufficient_statistics_from_parameters(mean=0.0, variance=4.0)\n", "```\n", "> Array([0., 4.], dtype=float32)\n", - "\n", - "```python\n", + "```{code-cell} python\n", "# from sufficient statistics to distribution parameters\n", "Normal.parameters_from_sufficient_statistics(xis=[0.0, 4.0])\n", "```\n", "> (0.0, 4.0)\n", "\n", - ":::" + "````" ] }, { @@ -423,14 +391,7 @@ "cell_type": "code", "execution_count": 7, "id": "1798765e-3d65-4bfd-964b-7f9b6b0902be", - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-10T13:54:03.418432Z", - "iopub.status.busy": "2025-01-10T13:54:03.417963Z", - "iopub.status.idle": "2025-01-10T13:54:03.454413Z", - "shell.execute_reply": "2025-01-10T13:54:03.453224Z" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -456,7 +417,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -482,12 +443,6 @@ "id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:03.458489Z", - "iopub.status.busy": "2025-01-10T13:54:03.457995Z", - "iopub.status.idle": "2025-01-10T13:54:03.462879Z", - "shell.execute_reply": "2025-01-10T13:54:03.462074Z" - }, "slideshow": { "slide_type": "" }, @@ -543,12 +498,6 @@ "id": "0754380d-87d0-430e-a533-540c5f252091", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:03.672185Z", - "iopub.status.busy": "2025-01-10T13:54:03.671465Z", - "iopub.status.idle": "2025-01-10T13:54:03.946934Z", - "shell.execute_reply": "2025-01-10T13:54:03.946361Z" - }, "slideshow": { "slide_type": "" }, @@ -626,12 +575,6 @@ "id": "baf6e2fc-dc8d-46bb-896b-ff5c45a944ee", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:03.949485Z", - "iopub.status.busy": "2025-01-10T13:54:03.948882Z", - "iopub.status.idle": "2025-01-10T13:54:04.636554Z", - "shell.execute_reply": "2025-01-10T13:54:04.635183Z" - }, "slideshow": { "slide_type": "" }, @@ -722,7 +665,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 13, @@ -739,12 +682,6 @@ "execution_count": 14, "id": "2cc62622-0709-41f3-8ccd-d7024099de7a", "metadata": { - "execution": { - "iopub.execute_input": "2025-01-10T13:54:04.639360Z", - "iopub.status.busy": "2025-01-10T13:54:04.639144Z", - "iopub.status.idle": "2025-01-10T13:54:04.674265Z", - "shell.execute_reply": "2025-01-10T13:54:04.673450Z" - }, "scrolled": true }, "outputs": [], @@ -771,12 +708,6 @@ "id": "45cdf250-5a7c-4dc3-b5e7-46055cb35adf", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:04.676756Z", - "iopub.status.busy": "2025-01-10T13:54:04.676453Z", - "iopub.status.idle": "2025-01-10T13:54:04.897442Z", - "shell.execute_reply": "2025-01-10T13:54:04.896519Z" - }, "slideshow": { "slide_type": "" }, @@ -848,12 +779,6 @@ "id": "781b08fc-a2c7-4856-a103-beffd7787325", "metadata": { "editable": true, - "execution": { - "iopub.execute_input": "2025-01-10T13:54:04.900348Z", - "iopub.status.busy": "2025-01-10T13:54:04.899801Z", - "iopub.status.idle": "2025-01-10T13:54:04.911618Z", - "shell.execute_reply": "2025-01-10T13:54:04.911004Z" - }, "slideshow": { "slide_type": "" }, @@ -1314,7 +1239,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 21, @@ -1367,7 +1292,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Wed Jan 15 2025\n", + "Last updated: Tue Jan 14 2025\n", "\n", "Python implementation: CPython\n", "Python version : 3.12.3\n", @@ -1377,13 +1302,13 @@ "jax : 0.4.31\n", "jaxlib: 0.4.31\n", "\n", - "matplotlib: 3.10.0\n", - "jax : 0.4.31\n", - "seaborn : 0.13.2\n", - "IPython : 8.31.0\n", + "numpy : 1.26.0\n", "sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n", "pyhgf : 0.2.1.post4.dev0+d49aafe9\n", - "numpy : 1.26.0\n", + "jax : 0.4.31\n", + "IPython : 8.31.0\n", + "seaborn : 0.13.2\n", + "matplotlib: 3.10.0\n", "\n", "Watermark: 2.5.0\n", "\n" @@ -1398,7 +1323,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pyhgf_dev", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/tests/test_nodes/test_exponential_family.py b/tests/test_nodes/test_exponential_family.py index a1b0dfc97..488070bb7 100644 --- a/tests/test_nodes/test_exponential_family.py +++ b/tests/test_nodes/test_exponential_family.py @@ -46,10 +46,7 @@ def test_multivariate_gaussian(): + np.random.randn(N, 2) * 2 ) - # Python - # ---------------------------------------------------------------------------------- - - # generalised filtering + # Python --------------------------------------------------------------------------- bivariate_normal = ( PyNetwork() .add_nodes( @@ -67,30 +64,3 @@ def test_multivariate_gaussian(): dtype="float32", ), ).all() - - # hgf updates - bivariate_hgf = PyNetwork().add_nodes( - kind="ef-state", - learning="hgf-2", - distribution="multivariate-normal", - dimension=2, - ) - - # adapting prior parameter values to the sufficient statistics - # covariances statistics will have greater variability and amplitudes - for node_idx in [2, 5, 8, 11, 14]: - bivariate_hgf.attributes[node_idx]["tonic_volatility"] = -2.0 - for node_idx in [1, 4, 7, 10, 13]: - bivariate_hgf.attributes[node_idx]["precision"] = 0.01 - for node_idx in [9, 12, 15]: - bivariate_hgf.attributes[node_idx]["mean"] = 10.0 - - bivariate_hgf.input_data(input_data=spiral_data) - - assert jnp.isclose( - bivariate_normal.node_trajectories[0]["xis"][-1], - jnp.array( - [3.4652710e01, -1.0609777e00, 1.2103647e03, -3.6398651e01, 3.3951855e00], - dtype="float32", - ), - ).all()