Add support for generalised Bayesian filtering with dynamic learning rate in JAX (#266)

* clarify docs for math distribution module + fix error for 1d Gaussian
* ef-state node supporting hgf learning
* api docs
* notebook
* univariate gaussian working
* add math modules
* support for ef distributions
* fix error with dirichlet nodes
1 parent 52cfc0e, commit aa56974
Showing 13 changed files with 1,018 additions and 434 deletions.
1,029 changes: 637 additions & 392 deletions
docs/source/notebooks/0.3-Generalised_filtering.ipynb
Large diffs are not rendered by default.
@@ -0,0 +1,82 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Attributes, Edges


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def posterior_update_exponential_family_dynamic(
    attributes: Dict, edges: Edges, node_idx: int, **args
) -> Attributes:
    r"""Update the hyperparameters of an ef state node using HGF-implied learning rates.

    This posterior update step is usually placed at the end of the update sequence, as
    we have to wait until all parent nodes tracking the expected sufficient statistics
    have been updated before we can infer the implied learning rate used to update the
    :math:`\nu` vector. The new implied :math:`\nu` is given by a ratio:

    .. math::

        \nu \leftarrow \frac{\delta}{\Delta}

    Where :math:`\delta` is the prediction error (the new sufficient statistics
    compared to the expected sufficient statistics), and :math:`\Delta` is the
    differential of expectation (what was expected before compared to what is expected
    after). This ratio quantifies how much the model is learning from new observations.

    Parameters
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    edges :
        The edges of the probabilistic nodes as a tuple of
        :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
        number. For each node, the index lists the value and volatility parents and
        children.
    node_idx :
        Pointer to the exponential family state node that will be updated.

    Returns
    -------
    attributes :
        The updated attributes of the probabilistic nodes.

    References
    ----------
    .. [1] Mathys, C., & Weber, L. (2020). Hierarchical Gaussian Filtering of Sufficient
       Statistic Time Series for Active Inference. In Active Inference (pp. 52–58).
       Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7

    """
    # prediction errors (pe) and expectation differentials (ed), one per value parent
    pe, ed = [], []
    for parent_idx in edges[node_idx].value_parents or []:
        # prediction error: posterior mean of the parent minus its expected mean
        pe.append(
            attributes[parent_idx]["mean"] - attributes[parent_idx]["expected_mean"]
        )

        # expectation differential: same difference, one level up the hierarchy
        parent_parent_idx = edges[parent_idx].value_parents[0]
        ed.append(
            attributes[parent_parent_idx]["mean"]
            - attributes[parent_parent_idx]["expected_mean"]
        )

    # implied learning rate, averaged over the value parents
    attributes[node_idx]["nus"] = (jnp.array(pe) / jnp.array(ed)).mean()

    # apply the Bayesian update using the implied learning rate nus
    xis = attributes[node_idx]["xis"] + (1 / (1 + attributes[node_idx]["nus"])) * (
        attributes[node_idx]["observation_ss"] - attributes[node_idx]["xis"]
    )

    # blank update in the case of unobserved value
    attributes[node_idx]["xis"] = jnp.where(
        attributes[node_idx]["observed"], xis, attributes[node_idx]["xis"]
    )

    return attributes
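
The arithmetic of this update can be illustrated with a minimal sketch. It assumes plain Python floats in place of JAX arrays and of the pyhgf attributes dictionary. The helper names `implied_learning_rate` and `update_xis` and all the numbers are hypothetical and are not part of pyhgf; they only mirror the two steps of the function above (averaging the ratios of prediction error to expectation differential, then moving `xis` toward the observed sufficient statistics with weight `1 / (1 + nu)`).

```python
from statistics import fmean
from typing import List, Sequence


def implied_learning_rate(
    prediction_errors: Sequence[float], expectation_diffs: Sequence[float]
) -> float:
    """Average the ratios delta / Delta over the value parents (hypothetical helper)."""
    ratios = [pe / ed for pe, ed in zip(prediction_errors, expectation_diffs)]
    return fmean(ratios)


def update_xis(
    xis: Sequence[float],
    observation_ss: Sequence[float],
    nu: float,
    observed: bool = True,
) -> List[float]:
    """Move xis toward the observed sufficient statistics (hypothetical helper)."""
    if not observed:
        # blank update: keep the previous hyperparameters unchanged
        return list(xis)
    weight = 1.0 / (1.0 + nu)
    return [xi + weight * (ss - xi) for xi, ss in zip(xis, observation_ss)]


# Made-up values for a node with two value parents
pe = [1.0, 2.0]  # parent "mean" minus parent "expected_mean"
ed = [0.5, 0.5]  # grandparent "mean" minus grandparent "expected_mean"
nu = implied_learning_rate(pe, ed)  # mean of 2.0 and 4.0

new_xis = update_xis(xis=[0.0, 1.0], observation_ss=[4.0, 9.0], nu=nu)
print(nu, new_xis)  # 3.0 [1.0, 3.0]
```

With these numbers the implied `nu` is 3, so each component of `xis` moves a quarter of the way toward the corresponding observed sufficient statistic. An expectation differential of exactly zero would make the ratio undefined, which this sketch does not guard against.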