From 76b1fea42b5bdb5a4a964bb209f5901f3b0a660d Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Wed, 30 Oct 2024 10:24:16 +0100 Subject: [PATCH 1/3] get hashmap without match statements --- Cargo.lock | 83 +++++++++++++++++++++ Cargo.toml | 3 +- src/lib.rs | 3 +- src/tests/exponential_family.rs | 0 src/updates/prediction_error/exponential.rs | 33 +++----- 5 files changed, 98 insertions(+), 24 deletions(-) delete mode 100644 src/tests/exponential_family.rs diff --git a/Cargo.lock b/Cargo.lock index db1887429..d10eeaf7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,42 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", +] + [[package]] name = "bitflags" version = "2.6.0" @@ -20,6 +50,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + [[package]] name = "heck" version = "0.4.1" @@ -58,6 +94,12 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + [[package]] name = "memoffset" version = "0.9.1" @@ -67,6 +109,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -122,6 +173,15 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "object" +version = "0.36.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -151,6 +211,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + [[package]] name = "portable-atomic" version = "1.9.0" @@ -259,8 +325,15 @@ version = "0.1.0" dependencies = [ "numpy", "pyo3", + "tokio", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -296,6 +369,16 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tokio" +version = "1.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +dependencies = [ + "backtrace", + "pin-project-lite", +] + [[package]] name = "unicode-ident" version = "1.0.13" diff --git a/Cargo.toml b/Cargo.toml index d4e4c9e87..7d1f7d6b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,5 @@ path = "src/lib.rs" # The source file of the target. [dependencies] pyo3 = { version = "0.21.2", features = ["extension-module"] } -numpy = "0.21" \ No newline at end of file +numpy = "0.21" +tokio = "1.41.0" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 55ad98cd0..4f5612ef0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod model; pub mod utils; pub mod math; -pub mod updates; \ No newline at end of file +pub mod updates; +pub mod reactive; \ No newline at end of file diff --git a/src/tests/exponential_family.rs b/src/tests/exponential_family.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/updates/prediction_error/exponential.rs b/src/updates/prediction_error/exponential.rs index e2cebca05..82239ed2a 100644 --- a/src/updates/prediction_error/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -11,25 +11,14 @@ use crate::math::sufficient_statistics; /// * `network` - The network after message passing. pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { - if let Some(floats_attributes) = network.attributes.floats.get_mut(&node_idx) { - if let Some(vectors_attributes) = network.attributes.vectors.get_mut(&node_idx) { - let mean = floats_attributes.get("mean"); - let nus = floats_attributes.get("nus"); - let xis = vectors_attributes.get("xis"); - let new_xis = match (mean, nus, xis) { - (Some(mean), Some(nus), Some(xis)) => { - let suf_stats = sufficient_statistics(mean); - let mut new_xis = xis.clone(); - for i in 0..suf_stats.len() { - new_xis[i] = new_xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]); - } - new_xis - } - _ => Vec::new(), - }; - if let Some(xis) = vectors_attributes.get_mut("xis") { - *xis = new_xis; // Modify the value directly - } - } - } - } \ No newline at end of file + let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes"); + let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes"); + let mean = floats_attributes.get("mean").expect("Mean not found"); + let nus = floats_attributes.get("nus").expect("Nus not found"); + let xis = vectors_attributes.get_mut("xis").expect("Xis not found"); + + let suf_stats = sufficient_statistics(mean); + for i in 0..suf_stats.len() { + xis[i] = xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]); + } +} \ No newline at end of file From 78c5f7fc0abacb345171cbc562da13a4bc624a59 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Thu, 31 Oct 2024 09:04:10 +0100 Subject: [PATCH 2/3] preallocate vectors --- Cargo.lock | 83 ----------------------------------------- Cargo.toml | 5 +-- examples/exponential.rs | 22 +++++++++++ src/lib.rs | 3 +- src/model.rs | 44 +++++++++++----------- 5 files changed, 46 insertions(+), 111 deletions(-) create mode 100644 examples/exponential.rs diff --git a/Cargo.lock b/Cargo.lock index d10eeaf7d..db1887429 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,42 +2,12 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" - [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" -[[package]] -name = "backtrace" -version = "0.3.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets", -] - [[package]] name = "bitflags" version = "2.6.0" @@ -50,12 +20,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "gimli" -version = "0.31.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" - [[package]] name = "heck" version = "0.4.1" @@ -94,12 +58,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "memchr" -version = "2.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" - [[package]] name = "memoffset" version = "0.9.1" @@ -109,15 +67,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "miniz_oxide" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" -dependencies = [ - "adler2", -] - [[package]] name = "ndarray" version = "0.15.6" @@ -173,15 +122,6 @@ dependencies = [ "rustc-hash", ] -[[package]] -name = "object" -version = "0.36.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.20.2" @@ -211,12 +151,6 @@ dependencies = [ "windows-targets", ] -[[package]] -name = "pin-project-lite" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" - [[package]] name = "portable-atomic" version = "1.9.0" @@ -325,15 +259,8 @@ version = "0.1.0" dependencies = [ "numpy", "pyo3", - "tokio", ] -[[package]] -name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -369,16 +296,6 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" -[[package]] -name = "tokio" -version = "1.41.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" -dependencies = [ - "backtrace", - "pin-project-lite", -] - [[package]] name = "unicode-ident" version = "1.0.13" diff --git a/Cargo.toml b/Cargo.toml index 7d1f7d6b6..2f5937bac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] name = "rshgf" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] path = "src/lib.rs" # The source file of the target. [dependencies] pyo3 = { version = "0.21.2", features = ["extension-module"] } -numpy = "0.21" -tokio = "1.41.0" \ No newline at end of file +numpy = "0.21" \ No newline at end of file diff --git a/examples/exponential.rs b/examples/exponential.rs new file mode 100644 index 000000000..dd58f6368 --- /dev/null +++ b/examples/exponential.rs @@ -0,0 +1,22 @@ +use rshgf::model::Network; + +fn main() { + + // initialize network + let mut network = Network::new(); + + // create a network with two exponential family state nodes + network.add_nodes( + "exponential-state", + None, + None, + None, + None + ); + + // belief propagation + let input_data = vec![1.0, 1.3, 1.5, 1.7]; + network.set_update_sequence(); + network.input_data(input_data); + +} diff --git a/src/lib.rs b/src/lib.rs index 4f5612ef0..55ad98cd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ pub mod model; pub mod utils; pub mod math; -pub mod updates; -pub mod reactive; \ No newline at end of file +pub mod updates; \ No newline at end of file diff --git a/src/model.rs b/src/model.rs index 223458da3..a2ad8a3cb 100644 --- a/src/model.rs +++ b/src/model.rs @@ -163,29 +163,31 @@ impl Network { /// associated with one node. pub fn input_data(&mut self, input_data: Vec) { + let n_time = input_data.len(); + // initialize the belief trajectories result struture let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; - // add empty vectors in the floats hashmap + // preallocate empty vectors in the floats hashmap for (node_idx, node) in &self.attributes.floats { let new_map: HashMap> = HashMap::new(); node_trajectories.floats.insert(*node_idx, new_map); - if let Some(attr) = node_trajectories.floats.get_mut(node_idx) { - for key in node.keys() { - attr.insert(key.clone(), Vec::new()); - } + let attr = node_trajectories.floats.get_mut(node_idx).expect("New map not found."); + for key in node.keys() { + attr.insert(key.clone(), Vec::with_capacity(n_time)); } - } - // add empty vectors in the vectors hashmap + } + + // preallocate empty vectors in the vectors hashmap for (node_idx, node) in &self.attributes.vectors { let new_map: HashMap>> = HashMap::new(); node_trajectories.vectors.insert(*node_idx, new_map); - if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) { - for key in node.keys() { - attr.insert(key.clone(), Vec::new()); - } + let attr = node_trajectories.vectors.get_mut(node_idx).expect("New vector map not found."); + for key in node.keys() { + attr.insert(key.clone(), Vec::with_capacity(n_time)); } - } + } + // iterate over the observations for observation in input_data { @@ -198,26 +200,22 @@ impl Network { for (new_node_idx, new_node) in &self.attributes.floats { for (new_key, new_value) in new_node { // If the key exists in map1, append the vector from map2 - if let Some(old_node) = node_trajectories.floats.get_mut(&new_node_idx) { - if let Some(old_value) = old_node.get_mut(new_key) { - old_value.push(*new_value); - } + let old_node = node_trajectories.floats.get_mut(&new_node_idx).expect("Old node not found."); + let old_value = old_node.get_mut(new_key).expect("Old value not found"); + old_value.push(*new_value); } } - } + // iterate over the vector hashmap for (new_node_idx, new_node) in &self.attributes.vectors { for (new_key, new_value) in new_node { // If the key exists in map1, append the vector from map2 - if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) { - if let Some(old_value) = old_node.get_mut(new_key) { - old_value.push(new_value.clone()); - } + let old_node = node_trajectories.vectors.get_mut(&new_node_idx).expect("Old vector node not found."); + let old_value = old_node.get_mut(new_key).expect("Old vector value not found."); + old_value.push(new_value.clone()); } } } - } - self.node_trajectories = node_trajectories; } From d98ffe80b6208d64d8e993fcc846e9001e09f43b Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 11 Nov 2024 12:24:41 +0100 Subject: [PATCH 3/3] perf --- src/model.rs | 34 +++++--------------------------- src/utils/beliefs_propagation.rs | 25 +++++++++++++++++++++++ src/utils/mod.rs | 3 ++- 3 files changed, 32 insertions(+), 30 deletions(-) create mode 100644 src/utils/beliefs_propagation.rs diff --git a/src/model.rs b/src/model.rs index a2ad8a3cb..1f1d4f7f8 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use crate::{updates::observations::observation_update, utils::function_pointer::FnType}; +use crate::utils::function_pointer::FnType; use crate::utils::set_sequence::set_update_sequence; +use crate::utils::beliefs_propagation::belief_propagation; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; @@ -129,33 +130,6 @@ impl Network { self.update_sequence = set_update_sequence(self); } - /// Single time slice belief propagation. - /// - /// # Arguments - /// * `observations` - A vector of values, each value is one new observation associated - /// with one node. - pub fn belief_propagation(&mut self, observations_set: Vec) { - - let predictions = self.update_sequence.predictions.clone(); - let updates = self.update_sequence.updates.clone(); - - // 1. prediction steps - for (idx, step) in predictions.iter() { - step(self, *idx); - } - - // 2. observation steps - for (i, observations) in observations_set.iter().enumerate() { - let idx = self.inputs[i]; - observation_update(self, idx, *observations); - } - - // 3. update steps - for (idx, step) in updates.iter() { - step(self, *idx); - } - } - /// Add a sequence of observations. /// /// # Arguments @@ -164,6 +138,8 @@ impl Network { pub fn input_data(&mut self, input_data: Vec) { let n_time = input_data.len(); + let predictions = self.update_sequence.predictions.clone(); + let updates = self.update_sequence.updates.clone(); // initialize the belief trajectories result struture let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; @@ -193,7 +169,7 @@ impl Network { for observation in input_data { // 1. belief propagation for one time slice - self.belief_propagation(vec![observation]); + belief_propagation(self, vec![observation], &predictions, &updates); // 2. append the new beliefs in the trajectories structure // iterate over the float hashmap diff --git a/src/utils/beliefs_propagation.rs b/src/utils/beliefs_propagation.rs new file mode 100644 index 000000000..789887643 --- /dev/null +++ b/src/utils/beliefs_propagation.rs @@ -0,0 +1,25 @@ +use crate::{utils::function_pointer::FnType, model::Network, updates::observations::observation_update}; + +/// Single time slice belief propagation. +/// +/// # Arguments +/// * `observations` - A vector of values, each value is one new observation associated +/// with one node. +pub fn belief_propagation(network: &mut Network, observations_set: Vec, predictions: & Vec<(usize, FnType)>, updates: & Vec<(usize, FnType)>) { + + // 1. prediction steps + for (idx, step) in predictions.iter() { + step(network, *idx); + } + + // 2. observation steps + for (i, observations) in observations_set.iter().enumerate() { + let idx = network.inputs[i]; + observation_update(network, idx, *observations); + } + + // 3. update steps + for (idx, step) in updates.iter() { + step(network, *idx); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1918ac547..6c65a8867 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,3 @@ pub mod set_sequence; -pub mod function_pointer; \ No newline at end of file +pub mod function_pointer; +pub mod beliefs_propagation; \ No newline at end of file