Skip to content

Commit

Permalink
refactoring utils
Browse files Browse the repository at this point in the history
  • Loading branch information
SylvainEstebe authored and LegrandNico committed Nov 29, 2024
1 parent e119974 commit 4f5e35e
Show file tree
Hide file tree
Showing 9 changed files with 984 additions and 745 deletions.
745 changes: 0 additions & 745 deletions pyhgf/utils.py

This file was deleted.

23 changes: 23 additions & 0 deletions pyhgf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .add_edges import add_edges

from .beliefs_propagation import beliefs_propagation

from .fill_categorical_state_node import fill_categorical_state_node

from .get_input_idxs import get_input_idxs

from .get_update_sequence import get_update_sequence

from .list_branches import list_branches

from .to_pandas import to_pandas

__all__ = [
"add_edges",
"beliefs_propagation",
"fill_categorical_state_node",
"get_input_idxs",
"get_update_sequence",
"list_branches",
"to_pandas",
]
191 changes: 191 additions & 0 deletions pyhgf/utils/add_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network


def add_edges(
attributes: Dict,
edges: Edges,
kind="value",
parent_idxs=Union[int, List[int]],
children_idxs=Union[int, List[int]],
coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> Tuple:
"""Add a value or volatility coupling link between a set of nodes.
Parameters
----------
attributes :
Attributes of the neural network.
edges :
Edges of the neural network.
kind :
The kind of coupling can be `"value"` or `"volatility"`.
parent_idxs :
The index(es) of the parent node(s).
children_idxs :
The index(es) of the children node(s).
coupling_strengths :
The coupling strength between the parents and children.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order of the children.
Note: if a node has multiple parents nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
linear.
"""
if kind not in ["value", "volatility"]:
raise ValueError(
f"The kind of coupling should be value or volatility, got {kind}"
)
if isinstance(children_idxs, int):
children_idxs = [children_idxs]
assert isinstance(children_idxs, (list, tuple))

if isinstance(parent_idxs, int):
parent_idxs = [parent_idxs]
assert isinstance(parent_idxs, (list, tuple))

if isinstance(coupling_strengths, int):
coupling_strengths = [float(coupling_strengths)]
if isinstance(coupling_strengths, float):
coupling_strengths = [coupling_strengths]

assert isinstance(coupling_strengths, (list, tuple))

edges_as_list = list(edges)
# update the parent nodes
# -----------------------
for parent_idx in parent_idxs:
# unpack the parent's edges
(
node_type,
value_parents,
volatility_parents,
value_children,
volatility_children,
this_coupling_fn,
) = edges_as_list[parent_idx]

if kind == "value":
if value_children is None:
value_children = tuple(children_idxs)
attributes[parent_idx]["value_coupling_children"] = tuple(
coupling_strengths
)
else:
value_children = value_children + tuple(children_idxs)
attributes[parent_idx]["value_coupling_children"] += tuple(
coupling_strengths
)
this_coupling_fn = this_coupling_fn + coupling_fn
elif kind == "volatility":
if volatility_children is None:
volatility_children = tuple(children_idxs)
attributes[parent_idx]["volatility_coupling_children"] = tuple(
coupling_strengths
)
else:
volatility_children = volatility_children + tuple(children_idxs)
attributes[parent_idx]["volatility_coupling_children"] += tuple(
coupling_strengths
)

# save the updated edges back
edges_as_list[parent_idx] = AdjacencyLists(
node_type,
value_parents,
volatility_parents,
value_children,
volatility_children,
this_coupling_fn,
)

# update the children nodes
# -------------------------
for children_idx in children_idxs:
# unpack this node's edges
(
node_type,
value_parents,
volatility_parents,
value_children,
volatility_children,
coupling_fn,
) = edges_as_list[children_idx]

if kind == "value":
if value_parents is None:
value_parents = tuple(parent_idxs)
attributes[children_idx]["value_coupling_parents"] = tuple(
coupling_strengths
)
else:
value_parents = value_parents + tuple(parent_idxs)
attributes[children_idx]["value_coupling_parents"] += tuple(
coupling_strengths
)
elif kind == "volatility":
if volatility_parents is None:
volatility_parents = tuple(parent_idxs)
attributes[children_idx]["volatility_coupling_parents"] = tuple(
coupling_strengths
)
else:
volatility_parents = volatility_parents + tuple(parent_idxs)
attributes[children_idx]["volatility_coupling_parents"] += tuple(
coupling_strengths
)

# save the updated edges back
edges_as_list[children_idx] = AdjacencyLists(
node_type,
value_parents,
volatility_parents,
value_children,
volatility_children,
coupling_fn,
)

# convert the list back to a tuple
edges = tuple(edges_as_list)

return attributes, edges


129 changes: 129 additions & 0 deletions pyhgf/utils/beliefs_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network


@partial(jit, static_argnames=("update_sequence", "edges", "input_idxs"))
def beliefs_propagation(
attributes: Attributes,
inputs: Tuple[ArrayLike, ...],
update_sequence: UpdateSequence,
edges: Edges,
input_idxs: Tuple[int],
) -> Tuple[Dict, Dict]:
"""Update the network's parameters after observing new data point(s).
This function performs the beliefs propagation step. Belief propagation consists in:
1. A prediction sequence, from the leaves of the graph to the roots.
2. The assignation of new observations to target nodes (usually the roots of the
network)
3. An inference step alternating between prediction errors and posterior updates,
starting from the roots of the network to the leaves.
This function returns a tuple of two new `parameter_structure` (i.e. the carryover
and the accumulated in the context of :py:func:`jax.lax.scan`).
Parameters
----------
attributes :
The dictionaries of nodes' parameters. This variable is updated and returned
after the beliefs propagation step.
inputs :
A tuple of n by time steps arrays containing the new observation(s), the time
steps as well as a boolean mask for observed values. The new observations are a
tuple of array, with length equal to the number of input nodes. Each input node
can receive observations The time steps are the last
column of the array, the default is unit incrementation.
update_sequence :
The sequence of updates that will be applied to the node structure.
edges :
Information on the network's edges.
input_idxs :
List input indexes.
Returns
-------
attributes, attributes :
A tuple of parameters structure (carryover and accumulated).
"""
prediction_steps, update_steps = update_sequence

# unpack input data - input_values is a tuple of n x time steps arrays
(*input_data, time_step) = inputs

attributes[-1]["time_step"] = time_step

# Prediction sequence
# -------------------
for step in prediction_steps:

node_idx, update_fn = step

attributes = update_fn(
attributes=attributes,
node_idx=node_idx,
edges=edges,
)

# Observations
# ------------
for values, observed, node_idx in zip(
input_data[::2], input_data[1::2], input_idxs
):

attributes = set_observation(
attributes=attributes,
node_idx=node_idx,
values=values,
observed=observed,
)

# Update sequence
# ---------------
for step in update_steps:

node_idx, update_fn = step

attributes = update_fn(
attributes=attributes,
node_idx=node_idx,
edges=edges,
)

return (
attributes,
attributes,
) # ("carryover", "accumulated")

Loading

0 comments on commit 4f5e35e

Please sign in to comment.