Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JGET HGF (volatility coupling for continuous input nodes) #125

Merged
merged 9 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/ambv/black
rev: 23.9.1
rev: 23.10.1
hooks:
- id: black
language_version: python3
Expand All @@ -22,10 +22,10 @@ repos:
hooks:
- id: pydocstyle
args: ['--ignore', 'D213,D100,D203,D104']
files: ^pyhgf/
files: ^src/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.0'
rev: 'v1.6.1'
hooks:
- id: mypy
files: ^pyhgf/
files: ^src/
args: [--ignore-missing-imports]
36 changes: 28 additions & 8 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,51 @@ Propagate prediction errors to the value and volatility parents of a given node.
Binary nodes
~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction_error.binary
.. currentmodule:: pyhgf.updates.prediction_error.inputs.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.binary
:toctree: generated/pyhgf.updates.prediction_error.inputs.binary

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent
prediction_error_input_value_parent
input_surprise_inf
input_surprise_reg

.. currentmodule:: pyhgf.updates.prediction_error.nodes.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.binary

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent

Continuous nodes
~~~~~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction_error.continuous
Updating continuous input nodes.

.. currentmodule:: pyhgf.updates.prediction_error.inputs.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.continuous
:toctree: generated/pyhgf.updates.prediction_error.inputs.continuous

prediction_error_input_precision_value_parent
prediction_error_input_precision_volatility_parent
prediction_error_input_mean_volatility_parent
prediction_error_input_mean_value_parent


Updating continuous state nodes.

.. currentmodule:: pyhgf.updates.prediction_error.nodes.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.continuous

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_precision_volatility_parent
prediction_error_mean_volatility_parent
prediction_error_input_mean_value_parent

Prediction steps
================
Expand Down
179 changes: 111 additions & 68 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

536 changes: 536 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb

Large diffs are not rendered by default.

202 changes: 202 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

(example_1)=
# Example 2: Estimating the mean and precision of an input node

```{code-cell} ipython3
%%capture
import sys
if 'google.colab' in sys.modules:
! pip install pyhgf
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

Where the standard continuous HGF assumes a known precision in the input node (usually set to something high), this assumption can be relaxed and the filter can also try to estimate this quantity from the data. In this notebook, we demonstrate how we can infer the value of the mean, of the precision, or both value at the same time, using the appropriate value and volatility coupling parents.

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## Unkown mean, known precision

+++ {"editable": true, "slideshow": {"slide_type": ""}}

```{hint}
The {ref}`continuous_hgf` is an example of a model assuming a continuous input with known precision and unknown mean. It is further assumed that the mean is changing overtime, and we want the model to track this rate of change by adding a volatility node on the top of the value parent (two-level continuous HGF), and event track the rate of change of this rate of change by adding another volatility parent (three-level continuous HGF).
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
mean_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 1})
.add_value_parent(children_idxs=[0], tonic_volatility=-8.0)
.init()
).input_data(input_data)
mean_hgf.plot_network()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

```{note}
We are setting the tonic volatility to something low for visualization purposes, but changing this value can make the model learn in fewer iterations.
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-input]
---
# get the nodes trajectories
df = mean_hgf.to_pandas()

fig, ax = plt.subplots(figsize=(12, 5))

x = np.linspace(-10, 10, 1000)
for i, color in zip([0, 2, 5, 10, 50, 500], plt.cm.Greys(np.linspace(.2, 1, 6))):

# extract the sufficient statistics from the input node (and parents)
mean = df.x_1_expected_mean.iloc[i]
std = np.sqrt(
1/(mean_hgf.attributes[0]["expected_precision"])
)

# the model expectations
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)


# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2)

ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## Kown mean, unknown precision

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## Unkown mean, unknown precision

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
mean_precision_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 0.01})
.add_value_parent(children_idxs=[0], tonic_volatility=-6.0)
.add_volatility_parent(children_idxs=[0], tonic_volatility=-6.0)
.init()
).input_data(input_data)
mean_precision_hgf.plot_network()
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-input]
---
# get the nodes trajectories
df = mean_precision_hgf.to_pandas()

fig, ax = plt.subplots(figsize=(12, 5))

x = np.linspace(-10, 10, 1000)
for i, color in zip(range(0, 150, 15), plt.cm.Greys(np.linspace(.2, 1, 10))):

# extract the sufficient statistics from the input node (and parents)
mean = df.x_1_expected_mean.iloc[i]
std = np.sqrt(
1/(mean_precision_hgf.attributes[0]["expected_precision"] * (1/np.exp(df.x_2_expected_mean.iloc[i])))
)

# the model expectations
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)


# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2)

ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## System configuration

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
```
1 change: 1 addition & 0 deletions docs/source/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ glob:
| Notebook | Colab |
| --- | ---|
| {ref}`example_1` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_1_Heart_rate_variability.ipynb)
| {ref}`example_2` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb)

## Exercises

Expand Down
2 changes: 1 addition & 1 deletion src/pyhgf/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def hgf_logp(
volatility_coupling_1: Union[np.ndarray, ArrayLike, float] = 1.0,
volatility_coupling_2: Union[np.ndarray, ArrayLike, float] = 1.0,
input_data: List[np.ndarray] = [np.nan],
response_function: Callable = None,
response_function: Optional[Callable] = None,
model_type: str = "continuous",
n_levels: int = 2,
response_function_parameters: List[Tuple] = [()],
Expand Down
Loading