Skip to content

Commit

Permalink
Add support for generalised Bayesian filtering with dynamic learning …
Browse files Browse the repository at this point in the history
…rate in JAX (#266)

* clarify docs for math distribution module + fix error for 1d Gaussian

* ef-state node supporting hgf learning

* api docs

* notebook

* univariate gaussian working

* add math modules

* support for ef distributions

* fix error with dirichlet nodes
  • Loading branch information
LegrandNico authored Jan 14, 2025
1 parent 52cfc0e commit aa56974
Show file tree
Hide file tree
Showing 13 changed files with 1,018 additions and 434 deletions.
13 changes: 12 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ Continuous nodes
continuous_node_posterior_update
continuous_node_posterior_update_ehgf

Exponential family
------------------

.. currentmodule:: pyhgf.updates.posterior.exponential

.. autosummary::
:toctree: generated/pyhgf.updates.posterior.exponential

posterior_update_exponential_family_dynamic

Prediction steps
================

Expand Down Expand Up @@ -144,7 +154,8 @@ Exponential family
.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.exponential

prediction_error_update_exponential_family
prediction_error_update_exponential_family_fixed
prediction_error_update_exponential_family_dynamic

Distribution
************
Expand Down
Binary file removed docs/source/images/multivariate_hgf.gif
Binary file not shown.
Binary file added docs/source/images/multivariate_normal.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,029 changes: 637 additions & 392 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb

Large diffs are not rendered by default.

87 changes: 77 additions & 10 deletions pyhgf/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,67 @@ class MultivariateNormal:
"""

@staticmethod
def sufficient_statistics(x: ArrayLike) -> Array:
"""Compute the sufficient statistics for the multivariate normal."""
def sufficient_statistics_from_observations(x: ArrayLike) -> Array:
"""Compute the expected sufficient statistics from a single observation."""
return jnp.hstack([x, jnp.outer(x, x)[jnp.tril_indices(x.shape[0])]])

@staticmethod
def sufficient_statistics_from_parameters(
mean: ArrayLike, covariance: ArrayLike
) -> Array:
"""Compute the expected sufficient statistics from distribution parameter.
Parameters
----------
mean :
Mean of the Gaussian distribution.
covariance :
Variance of the Gaussian distribution.
Returns
-------
xis :
The sufficient statistics.
"""
return jnp.append(
mean,
(covariance + jnp.outer(mean, mean))[jnp.tril_indices(covariance.shape[0])],
)

@staticmethod
def base_measure(k: int) -> float:
"""Compute the base measures for the multivariate normal."""
return (2 * jnp.pi) ** (-k / 2)

@staticmethod
def parameters_from_sufficient_statistics(
xis: ArrayLike, dimension: int
) -> Tuple[Array, Array]:
"""Compute the distribution parameters from the sufficient statistics.
Parameters
----------
xis :
The sufficient statistics.
dimension :
The dimension of the multivariate normal distribution.
Returns
-------
means, covariance :
The parameters of the distribution (mean and covariance).
"""
mean = xis[:dimension]
covariance = jnp.zeros((dimension, dimension))
covariance = covariance.at[jnp.tril_indices(dimension)].set(
xis[dimension:] - jnp.outer(mean, mean)[jnp.tril_indices(dimension)]
)
covariance += covariance.T - jnp.diag(covariance.diagonal())

return mean, covariance


class Normal:
"""The univariate normal as an exponential family distribution.
Expand All @@ -38,28 +90,42 @@ class Normal:
"""

@staticmethod
def sufficient_statistics(x: float) -> Array:
"""Sufficient statistics for the univariate normal."""
def sufficient_statistics_from_observations(x: float) -> Array:
"""Compute the expected sufficient statistics from a single observation."""
return jnp.array([x, x**2])

@staticmethod
def expected_sufficient_statistics(mu: float, sigma) -> Array:
"""Compute expected sufficient statistics from the mean and std."""
return jnp.array([mu, mu**2 + sigma**2])
def sufficient_statistics_from_parameters(mean: float, variance: float) -> Array:
"""Compute the expected sufficient statistics from distribution parameter.
Parameters
----------
mean :
Mean of the Gaussian distribution.
variance :
Variance of the Gaussian distribution.
Returns
-------
xis :
The sufficient statistics.
"""
return jnp.array([mean, mean**2 + variance])

@staticmethod
def base_measure() -> float:
"""Compute the base measure of the univariate normal."""
return 1 / (jnp.sqrt(2 * jnp.pi))

@staticmethod
def parameters(xis: ArrayLike) -> Tuple[float, float]:
"""Get parameters from the expected sufficient statistics.
def parameters_from_sufficient_statistics(xis: ArrayLike) -> Tuple[float, float]:
"""Compute the distribution parameters from the sufficient statistics.
Parameters
----------
xis :
The expected sufficient statistics.
The sufficient statistics.
Returns
-------
Expand All @@ -69,6 +135,7 @@ def parameters(xis: ArrayLike) -> Tuple[float, float]:
"""
mean = xis[0]
variance = xis[1] - (mean**2)

return mean, variance


Expand Down
62 changes: 47 additions & 15 deletions pyhgf/model/add_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,27 @@ def add_ef_state(
"learning": "generalised-filtering",
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
"mean": 0.0,
"observed": 1,
}

node_parameters = update_parameters(
node_parameters, default_parameters, additional_parameters
)

# the size of the sufficient statistics vector of a multivariate normal
# distribution is given by d + d(d+1) / 2, where d is the dimension
d = node_parameters["dimension"]
n_suff_stats = d + d * (d + 1) // 2
node_parameters["mean"] = jnp.zeros(d) if d > 1 else 0.0
node_parameters["observation_ss"] = jnp.zeros(n_suff_stats)
if node_parameters["distribution"] == "normal":
node_parameters["xis"] = jnp.array([0.0, 1.0])
elif node_parameters["distribution"] == "multivariate-normal":
node_parameters["xis"] = (
MultivariateNormal.sufficient_statistics_from_parameters(
mean=jnp.zeros(d), covariance=jnp.identity(d)
)
)
network = insert_nodes(
network=network,
n_nodes=n_nodes,
Expand All @@ -147,21 +160,40 @@ def add_ef_state(
# loop over the indexes of nodes created in the previous step
for node_idx in range(network.n_nodes - 1, network.n_nodes - n_nodes - 1, -1):

if network.attributes[node_idx]["learning"] == "generalised-filtering":

# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = MultivariateNormal().sufficient_statistics

network.attributes[node_idx].pop("distribution")
network.attributes[node_idx].pop("learning")
# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics_from_observations
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = (
MultivariateNormal().sufficient_statistics_from_observations
)
else:
raise ValueError(
"The distribution should be either 'normal' or 'multivariate-normal'."
)

# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn
# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn

if "hgf" in network.attributes[node_idx]["learning"]:

# create a collection of continuous state nodes
# to track the sufficient statistics of the implied distribution
for i in range(n_suff_stats):
network.add_nodes(value_children=node_idx)
network.add_nodes(value_children=network.n_nodes - 1)
if (
"-2" in network.attributes[node_idx]["learning"]
or "-3" in network.attributes[node_idx]["learning"]
):
network.add_nodes(volatility_children=network.n_nodes - 1)
if "-3" in network.attributes[node_idx]["learning"]:
network.add_nodes(volatility_children=network.n_nodes - 1)

network.attributes[node_idx].pop("distribution")
network.attributes[node_idx].pop("learning")

return network

Expand Down
82 changes: 82 additions & 0 deletions pyhgf/updates/posterior/exponential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Attributes, Edges


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def posterior_update_exponential_family_dynamic(
attributes: Dict, edges: Edges, node_idx: int, **args
) -> Attributes:
r"""Update the hyperparameters of an ef state node using HGF-implied learning rates.
This posterior update step is usually moved at the end of the update sequence as we
have to wait that all parent nodes tracking the expected sufficient statistics have
been updated, and therefore being able to infer the implied learning rate to update
the :math:`nu` vector. The new impled :math:`nu` is given by a ratio:
.. math::
\nu \leftarrow \frac{\delta}{\Delta}
Where :math:`delta` is the prediction error (the new sufficient statistics compared
to the expected sufficient statistic), and :math:`Delta` is the differential of
expectation (what was expected before compared to what is expected after). This
ratio quantifies how much the model is learning from new observations.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
number. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
Returns
-------
attributes :
The updated attributes of the probabilistic nodes.
References
----------
.. [1] Mathys, C., & Weber, L. (2020). Hierarchical Gaussian Filtering of Sufficient
Statistic Time Series for Active Inference. In Active Inference (pp. 52–58).
Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7
"""
# prediction error - expectation differential
pe, ed = [], []
for parent_idx in edges[node_idx].value_parents or []:

pe.append(
attributes[parent_idx]["mean"] - attributes[parent_idx]["expected_mean"]
)

parent_parent_idx = edges[parent_idx].value_parents[0]
ed.append(
attributes[parent_parent_idx]["mean"]
- attributes[parent_parent_idx]["expected_mean"]
)

# implied learning rate
attributes[node_idx]["nus"] = (jnp.array(pe) / jnp.array(ed)).mean()

# apply the Bayesian update using fixed learning rates nus
xis = attributes[node_idx]["xis"] + (1 / (1 + attributes[node_idx]["nus"])) * (
attributes[node_idx]["observation_ss"] - attributes[node_idx]["xis"]
)

# blank update in the case of unobserved value
attributes[node_idx]["xis"] = jnp.where(
attributes[node_idx]["observed"], xis, attributes[node_idx]["xis"]
)

return attributes
4 changes: 3 additions & 1 deletion pyhgf/updates/prediction/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def dirichlet_node_prediction(
if value_parent_idxs is not None:
parameters = jnp.array(
[
Normal().parameters(xis=attributes[parent_idx]["xis"])
Normal().parameters_from_sufficient_statistics(
xis=attributes[parent_idx]["xis"]
)
for parent_idx in value_parent_idxs
]
)
Expand Down
4 changes: 2 additions & 2 deletions pyhgf/updates/prediction_error/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def create_cluster(operands: Tuple, edges: Edges, node_idx: int) -> Attributes:
# initialize the new cluster using candidate values
attributes[value_parent_idx]["xis"] = jnp.where(
cluster_idx == i,
Normal().expected_sufficient_statistics(
mu=candidate_mean, sigma=candidate_sigma
Normal().sufficient_statistics_from_parameters(
mean=candidate_mean, variance=candidate_sigma**2
),
attributes[value_parent_idx]["xis"],
)
Expand Down
Loading

0 comments on commit aa56974

Please sign in to comment.