From 8c5b0caabc9791ed7c2a467eecece5df6554f208 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 12:14:41 +0100 Subject: [PATCH] exponential 1d node working --- Cargo.lock | 38 ++++++++++------- Cargo.toml | 2 +- docs/source/api.rst | 20 ++++----- .../notebooks/0.3-Generalised_filtering.ipynb | 4 +- pyhgf/model/network.py | 8 ++-- pyhgf/updates/prediction/dirichlet.py | 2 +- .../exponential.py | 2 +- pyhgf/utils.py | 31 +++++++------- src/model.rs | 42 +++++++++++++++++-- tests/test_exponential_family.py | 17 ++++++-- tests/test_utils.py | 2 +- 11 files changed, 112 insertions(+), 56 deletions(-) rename pyhgf/updates/{posterior => prediction_error}/exponential.py (97%) diff --git a/Cargo.lock b/Cargo.lock index a247d39c5..db1887429 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,16 +69,14 @@ dependencies = [ [[package]] name = "ndarray" -version = "0.16.1" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" dependencies = [ "matrixmultiply", "num-complex", "num-integer", "num-traits", - "portable-atomic", - "portable-atomic-util", "rawpointer", ] @@ -109,6 +107,21 @@ dependencies = [ "autocfg", ] +[[package]] +name = "numpy" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec170733ca37175f5d75a5bea5911d6ff45d2cd52849ce98b685394e4f2f37f4" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "rustc-hash", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -144,15 +157,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" -[[package]] -name = "portable-atomic-util" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" -dependencies = [ - "portable-atomic", -] - [[package]] name = "proc-macro2" version = "1.0.87" @@ -253,10 +257,16 @@ dependencies = [ name = "rshgf" version = "0.1.0" dependencies = [ - "ndarray", + "numpy", "pyo3", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index eb8e726cb..d4e4c9e87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,4 @@ path = "src/lib.rs" # The source file of the target. [dependencies] pyo3 = { version = "0.21.2", features = ["extension-module"] } -ndarray = "0.16.1" \ No newline at end of file +numpy = "0.21" \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst index 8f0e2584e..7e3d3cdf3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -46,16 +46,6 @@ Continuous nodes continuous_node_posterior_update continuous_node_posterior_update_ehgf -Exponential family ------------------- - -.. currentmodule:: pyhgf.updates.posterior.exponential - -.. autosummary:: - :toctree: generated/pyhgf.updates.posterior.exponential - - posterior_update_exponential_family - Prediction steps ================ @@ -146,6 +136,16 @@ Dirichlet state nodes likely_cluster_proposal clusters_likelihood +Exponential family +^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: pyhgf.updates.prediction_error.exponential + +.. autosummary:: + :toctree: generated/pyhgf.updates.prediction_error.exponential + + prediction_error_update_exponential_family + Distribution ************ diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb index fd3a19c5f..257839b20 100644 --- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb +++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb @@ -320,7 +320,7 @@ "\n", "### Using a fixed $\\nu$\n", "\n", - "This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `ef-` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution, therefore the kind is set to `\"ef-normal\"`). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:" + "This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `exponential-state` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:" ] }, { @@ -340,7 +340,7 @@ "generalised_filter = (\n", " Network()\n", " .add_nodes(kind=\"generic-state\")\n", - " .add_nodes(kind=\"ef-normal\", value_children=0, xis=np.array([0, 1 / 8]))\n", + " .add_nodes(kind=\"exponential-state\", value_children=0, xis=np.array([0, 1 / 8]))\n", ")" ] }, diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index b5096df9a..a6041341f 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -390,8 +390,8 @@ def add_nodes( raise ValueError( ( "Invalid node type. Should be one of the following: " - "'DP-state', 'continuous-state', 'binary-state', 'ef-normal'." - "'generic-state' or 'categorical-state'" + "'DP-state', 'continuous-state', 'binary-state', " + "'exponential-state', 'generic-state' or 'categorical-state'" ) ) @@ -473,7 +473,7 @@ def add_nodes( "nus": 3.0, "xis": jnp.array([0.0, 1.0]), "mean": 0.0, - "observed": 1.0, + "observed": 1, } elif kind == "categorical-state": if "n_categories" in node_parameters: @@ -562,7 +562,7 @@ def add_nodes( node_type = 1 elif kind == "continuous-state": node_type = 2 - elif kind == "ef-normal": + elif kind == "exponential-state": node_type = 3 elif kind == "DP-state": node_type = 4 diff --git a/pyhgf/updates/prediction/dirichlet.py b/pyhgf/updates/prediction/dirichlet.py index 03fb86e25..298fbd2b9 100644 --- a/pyhgf/updates/prediction/dirichlet.py +++ b/pyhgf/updates/prediction/dirichlet.py @@ -39,7 +39,7 @@ def dirichlet_node_prediction( Static parameters of the Dirichlet process node. """ - # get the parameter (mean and variance) from the EF-normal parent nodes + # get the parameter (mean and variance) from the exponential state parent nodes value_parent_idxs = edges[node_idx].value_parents if value_parent_idxs is not None: parameters = jnp.array( diff --git a/pyhgf/updates/posterior/exponential.py b/pyhgf/updates/prediction_error/exponential.py similarity index 97% rename from pyhgf/updates/posterior/exponential.py rename to pyhgf/updates/prediction_error/exponential.py index f4b9ed519..a84b31b53 100644 --- a/pyhgf/updates/posterior/exponential.py +++ b/pyhgf/updates/prediction_error/exponential.py @@ -10,7 +10,7 @@ @partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn")) -def posterior_update_exponential_family( +def prediction_error_update_exponential_family( attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args ) -> Attributes: r"""Update the parameters of an exponential family distribution. diff --git a/pyhgf/utils.py b/pyhgf/utils.py index 227baf7cd..7a713cffa 100644 --- a/pyhgf/utils.py +++ b/pyhgf/utils.py @@ -18,7 +18,6 @@ continuous_node_posterior_update, continuous_node_posterior_update_ehgf, ) -from pyhgf.updates.posterior.exponential import posterior_update_exponential_family from pyhgf.updates.prediction.binary import binary_state_node_prediction from pyhgf.updates.prediction.continuous import continuous_node_prediction from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction @@ -28,6 +27,9 @@ ) from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error +from pyhgf.updates.prediction_error.exponential import ( + prediction_error_update_exponential_family, +) from pyhgf.updates.prediction_error.generic import generic_state_prediction_error if TYPE_CHECKING: @@ -374,16 +376,6 @@ def get_update_sequence( elif update_type == "standard": update_fn = continuous_node_posterior_update - elif network.edges[idx].node_type == 3: - - # create the sufficient statistic function - # for the exponential family node - ef_update = Partial( - posterior_update_exponential_family, - sufficient_stats_fn=Normal().sufficient_statistics, - ) - update_fn = ef_update - elif network.edges[idx].node_type == 4: update_fn = None @@ -407,8 +399,21 @@ def get_update_sequence( ] # if this node has no parent, no need to compute prediction errors + # unless this is an exponential family state node if len(all_parents) == 0: - nodes_without_prediction_error.remove(idx) + if network.edges[idx].node_type == 3: + # create the sufficient statistic function + # for the exponential family node + ef_update = Partial( + prediction_error_update_exponential_family, + sufficient_stats_fn=Normal().sufficient_statistics, + ) + update_fn = ef_update + no_update = False + update_sequence.append((idx, update_fn)) + nodes_without_prediction_error.remove(idx) + else: + nodes_without_prediction_error.remove(idx) else: # if this node has been updated if idx not in nodes_without_posterior_update: @@ -419,8 +424,6 @@ def get_update_sequence( update_fn = binary_state_node_prediction_error elif network.edges[idx].node_type == 2: update_fn = continuous_node_prediction_error - elif network.edges[idx].node_type == 3: - update_fn = None elif network.edges[idx].node_type == 4: update_fn = dirichlet_node_prediction_error elif network.edges[idx].node_type == 5: diff --git a/src/model.rs b/src/model.rs index 931064346..223458da3 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,7 +4,7 @@ use crate::utils::set_sequence::set_update_sequence; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; -use ndarray::{Array2, Axis, stack}; +use numpy::{PyArray1, PyArray}; #[derive(Debug)] #[pyclass] @@ -165,6 +165,8 @@ impl Network { // initialize the belief trajectories result struture let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; + + // add empty vectors in the floats hashmap for (node_idx, node) in &self.attributes.floats { let new_map: HashMap> = HashMap::new(); node_trajectories.floats.insert(*node_idx, new_map); @@ -174,13 +176,25 @@ impl Network { } } } + // add empty vectors in the vectors hashmap + for (node_idx, node) in &self.attributes.vectors { + let new_map: HashMap>> = HashMap::new(); + node_trajectories.vectors.insert(*node_idx, new_map); + if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) { + for key in node.keys() { + attr.insert(key.clone(), Vec::new()); + } + } + } + // iterate over the observations for observation in input_data { // 1. belief propagation for one time slice self.belief_propagation(vec![observation]); - // 2. append the new states in the result vector + // 2. append the new beliefs in the trajectories structure + // iterate over the float hashmap for (new_node_idx, new_node) in &self.attributes.floats { for (new_key, new_value) in new_node { // If the key exists in map1, append the vector from map2 @@ -191,6 +205,17 @@ impl Network { } } } + // iterate over the vector hashmap + for (new_node_idx, new_node) in &self.attributes.vectors { + for (new_key, new_value) in new_node { + // If the key exists in map1, append the vector from map2 + if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) { + if let Some(old_value) = old_node.get_mut(new_key) { + old_value.push(new_value.clone()); + } + } + } + } } self.node_trajectories = node_trajectories; @@ -201,15 +226,24 @@ impl Network { let py_list = PyList::empty(py); - // Iterate over the Rust HashMap and insert key-value pairs into the PyDict + // Iterate over the float hashmap and insert key-value pairs into the list as PyDict for (node_idx, node) in &self.node_trajectories.floats { let py_dict = PyDict::new(py); for (key, value) in node { // Create a new Python dictionary - py_dict.set_item(key, value).expect("Failed to set item in PyDict"); + py_dict.set_item(key, PyArray1::from_vec(py, value.clone()).to_owned()).expect("Failed to set item in PyDict"); + } + + // Iterate over the vector hashmap if any and insert key-value pairs into the list as PyDict + if let Some(vector_node) = self.node_trajectories.vectors.get(node_idx) { + for (vector_key, vector_value) in vector_node { + // Create a new Python dictionary + py_dict.set_item(vector_key, PyArray::from_vec2_bound(py, &vector_value).unwrap()).expect("Failed to set item in PyDict"); + } } py_list.append(py_dict)?; } + // Create a PyList from Vec Ok(py_list) } diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py index 847200e27..1e5f834aa 100644 --- a/tests/test_exponential_family.py +++ b/tests/test_exponential_family.py @@ -1,5 +1,6 @@ # Author: Nicolas Legrand +import numpy as np from rshgf import Network as RsNetwork from pyhgf import load_data @@ -13,12 +14,20 @@ def test_1d_gaussain(): # Rust ----------------------------------------------------------------------------- rs_network = RsNetwork() rs_network.add_nodes(kind="exponential-state") - rs_network.inputs - rs_network.edges rs_network.set_update_sequence() - rs_network.input_data(timeseries) # Python --------------------------------------------------------------------------- py_network = PyNetwork().add_nodes(kind="exponential-state") - py_network.attributes + py_network.input_data(timeseries) + + # Ensure identical results + assert np.isclose( + py_network.node_trajectories[0]["xis"], rs_network.node_trajectories[0]["xis"] + ).all() + assert np.isclose( + py_network.node_trajectories[0]["mean"], rs_network.node_trajectories[0]["mean"] + ).all() + assert np.isclose( + py_network.node_trajectories[0]["nus"], rs_network.node_trajectories[0]["nus"] + ).all() diff --git a/tests/test_utils.py b/tests/test_utils.py index 93a37dad4..1dde30aec 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -74,7 +74,7 @@ def test_set_update_sequence(): network3 = ( Network() .add_nodes(kind="generic-state") - .add_nodes(kind="ef-normal", value_children=0) + .add_nodes(kind="exponential-state", value_children=0) .create_belief_propagation_fn() ) predictions, updates = network3.update_sequence