Skip to content

Commit

Permalink
dissociate time steps and input data when passed to beliefs_propagation (#146)
Browse files Browse the repository at this point in the history

* docstrings

* dissociate time steps and input data when passed to the beliefs propagation

* notebooks
  • Loading branch information
LegrandNico authored Nov 29, 2023
1 parent 37781ef commit 19ed44a
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 121 deletions.
370 changes: 310 additions & 60 deletions docs/source/notebooks/Example_3_Multi_armed_bandit.ipynb

Large diffs are not rendered by default.

23 changes: 5 additions & 18 deletions docs/source/notebooks/Example_3_Multi_armed_bandit.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,6 @@ two_armed_bandit_missing_inputs_hgf = (
two_armed_bandit_hgf.plot_network()
```

```{code-cell} ipython3
two_armed_bandit_missing_inputs_hgf.update_sequence
```

```{code-cell} ipython3
two_armed_bandit_missing_inputs_hgf.input_data(input_data=missing_inputs_u.T);
```
Expand Down Expand Up @@ -287,9 +283,6 @@ two_armed_bandit_missing_inputs_hgf = (
```

```{code-cell} ipython3
# add a time step vector
input_data = np.c_[u.T, np.ones(u.shape[1])]
# get the network variables from the HGF class
attributes = two_armed_bandit_missing_inputs_hgf.attributes
update_sequence = two_armed_bandit_missing_inputs_hgf.update_sequence
Expand All @@ -301,6 +294,7 @@ beta = 1.0
```

```{code-cell} ipython3
input_data = u.astype(float).T
responses = [] # 1: arm A - 0: arm B
for i in range(input_data.shape[0]):
Expand All @@ -320,11 +314,12 @@ for i in range(input_data.shape[0]):
else:
input_data[i, 2:4] = np.nan
responses.append(0)
time_steps = np.ones(1)
# update the probabilistic network
attributes, _ = beliefs_propagation(
attributes=attributes,
data=input_data[i],
input_data=(input_data[i], time_steps),
update_sequence=update_sequence,
edges=edges,
input_nodes_idx=input_nodes_idx
Expand All @@ -336,10 +331,6 @@ responses = jnp.asarray(responses)

First, we start by creating the response function we want to optimize (see also {ref}`custom_response_functions` on how to create such functions).

```{code-cell} ipython3
jnp.isnan(input_data[:, 0])
```

```{code-cell} ipython3
from pyhgf.math import binary_surprise
from jax.tree_util import Partial
Expand Down Expand Up @@ -390,7 +381,7 @@ def two_bandits_logp(tonic_volatility, hgf, input_data, responses):
logp_fn = Partial(
two_bandits_logp,
hgf=two_armed_bandit_missing_inputs_hgf,
input_data=input_data[:, :-1],
input_data=input_data,
responses=responses
)
```
Expand All @@ -411,10 +402,6 @@ def vjp_custom_op_jax(x, gz):
jitted_vjp_custom_op_jax = jit(vjp_custom_op_jax)
```

```{code-cell} ipython3
```

```{code-cell} ipython3
---
editable: true
Expand Down
4 changes: 2 additions & 2 deletions src/pyhgf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def load_data(dataset: str) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
data : np.ndarray
The data (a 1d timeseries).
Notes
-----
Note
----
The continuous time series is the standard USD-CHF conversion rates over time used
in the Matlab examples.
Expand Down
28 changes: 16 additions & 12 deletions src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,10 @@ def init(self) -> "HGF":
_ = scan(
self.scan_fn,
self.attributes,
jnp.array([jnp.ones(len(self.input_nodes_idx.idx) + 1)]),
(
jnp.ones((1, len(self.input_nodes_idx.idx))),
jnp.ones((1, 1)),
),
)
if self.verbose:
print("... Cache the belief propagation function.")
Expand All @@ -319,19 +322,18 @@ def input_data(
if self.verbose:
print((f"Adding {len(input_data)} new observations."))
if time_steps is None:
time_steps = np.ones(len(input_data)) # time steps vector

# concatenate data and time
time_steps = time_steps[..., jnp.newaxis]
time_steps = np.ones((len(input_data), 1)) # time steps vector
else:
time_steps = time_steps[..., jnp.newaxis]
if input_data.ndim == 1:
input_data = input_data[..., jnp.newaxis]

data = jnp.concatenate((input_data, time_steps), dtype=float, axis=1)

# this is where the model loop over the whole input time series
# at each time point, the node structure is traversed and beliefs are updated
# using precision-weighted prediction errors
_, node_trajectories = scan(self.scan_fn, self.attributes, data)
_, node_trajectories = scan(
self.scan_fn, self.attributes, (input_data, time_steps)
)

# trajectories of the network attributes at each time point
self.node_trajectories = node_trajectories
Expand Down Expand Up @@ -377,10 +379,12 @@ def input_custom_sequence(
time_steps = np.ones(len(input_data)) # time steps vector

# concatenate data and time
time_steps = time_steps[..., jnp.newaxis]
if time_steps is None:
time_steps = np.ones((len(input_data), 1)) # time steps vector
else:
time_steps = time_steps[..., jnp.newaxis]
if input_data.ndim == 1:
input_data = input_data[..., jnp.newaxis]
data = jnp.concatenate((input_data, time_steps), dtype=float, axis=1)

# create the update functions that will be scanned
branches_fn = [
Expand All @@ -399,7 +403,7 @@ def switching_propagation(attributes, scan_input):
return switch(idx, branches_fn, attributes, data)

# wrap the inputs
scan_input = data, branches_idx
scan_input = (input_data, time_steps), branches_idx

# scan over the input data and apply the switching belief propagation functions
_, node_trajectories = scan(switching_propagation, self.attributes, scan_input)
Expand All @@ -410,7 +414,7 @@ def switching_propagation(attributes, scan_input):
# because some of the input node might not have been updated, here we manually
# insert the input data to the input node (without triggering updates)
for idx, inp in zip(self.input_nodes_idx.idx, range(input_data.shape[1])):
self.node_trajectories[idx]["value"] = input_data[:, inp]
self.node_trajectories[idx]["value"] = input_data[inp]

return self

Expand Down
11 changes: 5 additions & 6 deletions src/pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@partial(jit, static_argnames=("update_sequence", "edges", "input_nodes_idx"))
def beliefs_propagation(
attributes: Dict,
data: ArrayLike,
input_data: ArrayLike,
update_sequence: UpdateSequence,
edges: Tuple,
input_nodes_idx: Tuple = (0,),
Expand All @@ -58,7 +58,7 @@ def beliefs_propagation(
attributes :
The dictionaries of nodes' parameters. This variable is updated and returned
after the beliefs propagation step.
data :
input_data :
An array containing the new observation(s) as well as the time steps. The new
observations can be a single value or a vector of observation with a length
matching the length `input_nodes_idx`. `input_nodes_idx` is used to index the
Expand All @@ -80,9 +80,8 @@ def beliefs_propagation(
"""
# extract value(s) and time steps from the data
values = data[:-1]
time_step = data[-1]
# extract value(s) and time steps
values, time_step = input_data

input_nodes_idx = jnp.asarray(input_nodes_idx)
# Fit the model with the current time and value variables, given the model structure
Expand All @@ -95,7 +94,7 @@ def beliefs_propagation(

attributes = update_fn(
attributes=attributes,
time_step=time_step,
time_step=time_step[0],
node_idx=node_idx,
edges=edges,
value=value,
Expand Down
4 changes: 2 additions & 2 deletions src/pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def plot_network(hgf: "HGF") -> "Source":
hgf :
An instance of the HGF model containing a node structure.
Notes
-----
Note
----
This function requires [Graphviz](https://github.com/xflr6/graphviz) to be
installed to work correctly.
Expand Down
10 changes: 5 additions & 5 deletions src/pyhgf/updates/posterior/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def posterior_update_mean_continuous_node(
posterior_mean :
The new posterior mean.
Notes
-----
Note
----
This update step is similar to the one used for the state node, except that it uses
the observed value instead of the mean of the child node, and the expected mean of
the parent node instead of the expected mean of the child node.
Expand Down Expand Up @@ -131,7 +131,7 @@ def posterior_update_mean_continuous_node(
# sum the precision weighted prediction errors over all children
precision_weigthed_prediction_error += (
(value_coupling * attributes[value_child_idx]["expected_precision"])
/ attributes[node_idx]["precision"]
/ node_precision
) * value_prediction_error

# Volatility coupling updates - update the mean of a volatility parent
Expand Down Expand Up @@ -247,8 +247,8 @@ def posterior_update_precision_continuous_node(
posterior_precision :
The new posterior precision.
Notes
-----
Note
----
This update step is similar to the one used for the state node, except that it uses
the observed value instead of the mean of the child node, and the expected mean of
the parent node instead of the expected mean of the child node.
Expand Down
4 changes: 2 additions & 2 deletions src/pyhgf/updates/prediction_error/inputs/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def continuous_input_value_prediction_error(
attributes :
The attributes of the probabilistic nodes.
Notes
-----
Note
----
This update step is similar to the one used for the state node, except that it uses
the observed value instead of the mean of the child node, and the expected mean of
the parent node instead of the expected mean of the child node.
Expand Down
10 changes: 6 additions & 4 deletions tests/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,15 @@ def test_update_binary_input_parents(self):
sequence8,
sequence9,
)
data = jnp.array([1.0, 1.0])
data = jnp.ones(1)
time_steps = jnp.ones(1)

# apply sequence
new_attributes, _ = beliefs_propagation(
edges=edges,
attributes=attributes,
update_sequence=update_sequence,
data=data,
input_data=(data, time_steps),
)
for idx, val in zip(
["mean", "expected_mean", "binary_expected_precision"],
Expand All @@ -180,7 +181,8 @@ def test_update_binary_input_parents(self):

# Create the data (value and time steps vectors) - only use the 5 first trials
# as the priors are ill defined here
data = jnp.array([u, jnp.ones(len(u), dtype=int)]).T[:5]
data = jnp.array([u[:5]]).T
time_steps = jnp.ones((len(u[:5]), 1))

# create the function that will be scaned
scan_fn = Partial(
Expand All @@ -190,7 +192,7 @@ def test_update_binary_input_parents(self):
)

# Run the entire for loop
last, _ = scan(scan_fn, attributes, data)
last, _ = scan(scan_fn, attributes, (data, time_steps))
for idx, val in zip(
["mean", "expected_mean", "binary_expected_precision"],
[0.0, 0.95616907, 23.860779],
Expand Down
18 changes: 10 additions & 8 deletions tests/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def test_continuous_node_update(self):
Indexes(None, None, None, None),
Indexes(None, None, None, None),
)
data = jnp.array([0.2, 1.0])
data = jnp.array([0.2])
time_steps = jnp.ones(1)

###########################################
# No value parent - no volatility parents #
Expand All @@ -100,7 +101,7 @@ def test_continuous_node_update(self):
attributes=attributes,
edges=edges,
update_sequence=update_sequence,
data=data,
input_data=(data, time_steps),
)

assert attributes[1] == new_attributes[1]
Expand Down Expand Up @@ -197,14 +198,15 @@ def test_continuous_input_update(self):
sequence5,
sequence6,
)
data = jnp.array([0.2, 1.0])
data = jnp.array([0.2])
time_steps = jnp.ones(1)

# apply beliefs propagation updates
new_attributes, _ = beliefs_propagation(
edges=edges,
attributes=attributes,
update_sequence=update_sequence,
data=data,
input_data=(data, time_steps),
)

for idx, val in zip(["time_step", "value"], [1.0, 0.2]):
Expand All @@ -223,9 +225,6 @@ def test_continuous_input_update(self):
def test_scan_loop(self):
timeserie = load_data("continuous")

# Create the data (value and time steps vectors)
data = jnp.array([timeserie, jnp.ones(len(timeserie), dtype=int)]).T

###############################################
# one value parent with one volatility parent #
###############################################
Expand Down Expand Up @@ -316,8 +315,11 @@ def test_scan_loop(self):
edges=edges,
)

# Create the data (value and time steps vectors)
time_steps = jnp.ones((len(timeserie), 1))

# Run the entire for loop
last, _ = scan(scan_fn, attributes, data)
last, _ = scan(scan_fn, attributes, (timeserie, time_steps))
for idx, val in zip(["time_step", "value"], [1.0, 0.8241]):
assert jnp.isclose(last[0][idx], val)
for idx, val in zip(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,13 @@ def test_beliefs_propagation(self):
update_sequence = (sequence1, sequence2, sequence3)

# one batch of new observations with time step
data = jnp.array([0.2, 1.0])
data = jnp.array([0.2])
time_steps = jnp.ones(1)

# apply sequence
new_attributes, _ = beliefs_propagation(
attributes=attributes,
data=data,
input_data=(data, time_steps),
update_sequence=update_sequence,
edges=edges,
)
Expand Down

0 comments on commit 19ed44a

Please sign in to comment.