From 9a6da7f845bd7576e1d8ec895ecfdac90e2de52b Mon Sep 17 00:00:00 2001 From: Nicolas Legrand Date: Wed, 15 Jan 2025 12:17:38 +0100 Subject: [PATCH] add more tests and improve code coverage (#273) * improve tests and code coverage * notebook * more tests --- .../notebooks/0.3-Generalised_filtering.ipynb | 30 +++++++++-------- tests/test_nodes/test_exponential_family.py | 32 ++++++++++++++++++- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb index 5d19a1bf0..959326490 100644 --- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb +++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb @@ -347,30 +347,32 @@ "tags": [] }, "source": [ - "````{note} From sufficient statistics to distribution parameters and backwards\n", + ":::{note} From sufficient statistics to distribution parameters and backwards\n", ":class: dropdown\n", "\n", "When using a 1-dimensional Gaussian distribution, Setting $\\xi = [0, \\frac{1}{8}]$ is equivalent to a mean $\\mu = 0.0$ and a variance $\\sigma^2 = \\frac{1}{8}$. You can convert between distribution parameters and expected sufficient statistics using the distribution classes from PyHGF (when implemented):\n", "\n", - "```{code-cell} python\n", + "```python\n", "from pyhgf.math import Normal\n", "\n", "# from an observation to sufficient statistics\n", "Normal.sufficient_statistics_from_observations(x=1.5)\n", "```\n", "> Array([1.5 , 2.25], dtype=float32)\n", - "```{code-cell} python\n", + "\n", + "```python\n", "# from distribution parameters to sufficient statistics\n", "Normal.sufficient_statistics_from_parameters(mean=0.0, variance=4.0)\n", "```\n", "> Array([0., 4.], dtype=float32)\n", - "```{code-cell} python\n", + "\n", + "```python\n", "# from sufficient statistics to distribution parameters\n", "Normal.parameters_from_sufficient_statistics(xis=[0.0, 4.0])\n", "```\n", "> (0.0, 4.0)\n", "\n", - "````" + ":::" ] }, { @@ -417,7 +419,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -665,7 +667,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 13, @@ -1239,7 +1241,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 21, @@ -1292,7 +1294,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Tue Jan 14 2025\n", + "Last updated: Wed Jan 15 2025\n", "\n", "Python implementation: CPython\n", "Python version : 3.12.3\n", @@ -1302,13 +1304,13 @@ "jax : 0.4.31\n", "jaxlib: 0.4.31\n", "\n", - "numpy : 1.26.0\n", - "sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n", - "pyhgf : 0.2.1.post4.dev0+d49aafe9\n", + "matplotlib: 3.10.0\n", "jax : 0.4.31\n", - "IPython : 8.31.0\n", "seaborn : 0.13.2\n", - "matplotlib: 3.10.0\n", + "IPython : 8.31.0\n", + "sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n", + "pyhgf : 0.2.1.post4.dev0+d49aafe9\n", + "numpy : 1.26.0\n", "\n", "Watermark: 2.5.0\n", "\n" diff --git a/tests/test_nodes/test_exponential_family.py b/tests/test_nodes/test_exponential_family.py index 488070bb7..a1b0dfc97 100644 --- a/tests/test_nodes/test_exponential_family.py +++ b/tests/test_nodes/test_exponential_family.py @@ -46,7 +46,10 @@ def test_multivariate_gaussian(): + np.random.randn(N, 2) * 2 ) - # Python --------------------------------------------------------------------------- + # Python + # ---------------------------------------------------------------------------------- + + # generalised filtering bivariate_normal = ( PyNetwork() .add_nodes( @@ -64,3 +67,30 @@ def test_multivariate_gaussian(): dtype="float32", ), ).all() + + # hgf updates + bivariate_hgf = PyNetwork().add_nodes( + kind="ef-state", + learning="hgf-2", + distribution="multivariate-normal", + dimension=2, + ) + + # adapting prior parameter values to the sufficient statistics + # covariances statistics will have greater variability and amplitudes + for node_idx in [2, 5, 8, 11, 14]: + bivariate_hgf.attributes[node_idx]["tonic_volatility"] = -2.0 + for node_idx in [1, 4, 7, 10, 13]: + bivariate_hgf.attributes[node_idx]["precision"] = 0.01 + for node_idx in [9, 12, 15]: + bivariate_hgf.attributes[node_idx]["mean"] = 10.0 + + bivariate_hgf.input_data(input_data=spiral_data) + + assert jnp.isclose( + bivariate_normal.node_trajectories[0]["xis"][-1], + jnp.array( + [3.4652710e01, -1.0609777e00, 1.2103647e03, -3.6398651e01, 3.3951855e00], + dtype="float32", + ), + ).all()