Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Hashmaps, vector preallocation and avoid cloning [Rust] #250

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "rshgf"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
Expand Down
22 changes: 22 additions & 0 deletions examples/exponential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use rshgf::model::Network;

fn main() {

// initialize network
let mut network = Network::new();

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
None,
None,
None,
None
);

// belief propagation
let input_data = vec![1.0, 1.3, 1.5, 1.7];
network.set_update_sequence();
network.input_data(input_data);

}
78 changes: 26 additions & 52 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use crate::{updates::observations::observation_update, utils::function_pointer::FnType};
use crate::utils::function_pointer::FnType;
use crate::utils::set_sequence::set_update_sequence;
use crate::utils::beliefs_propagation::belief_propagation;
use crate::utils::function_pointer::get_func_map;
use pyo3::types::PyTuple;
use pyo3::{prelude::*, types::{PyList, PyDict}};
Expand Down Expand Up @@ -129,95 +130,68 @@ impl Network {
self.update_sequence = set_update_sequence(self);
}

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(&mut self, observations_set: Vec<f64>) {

let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(self, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = self.inputs[i];
observation_update(self, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(self, *idx);
}
}

/// Add a sequence of observations.
///
/// # Arguments
/// * `input_data` - A vector of vectors. Each vector is a time series of observations
/// associated with one node.
pub fn input_data(&mut self, input_data: Vec<f64>) {

let n_time = input_data.len();
let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// initialize the belief trajectories result struture
let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()};

// add empty vectors in the floats hashmap
// preallocate empty vectors in the floats hashmap
for (node_idx, node) in &self.attributes.floats {
let new_map: HashMap<String, Vec<f64>> = HashMap::new();
node_trajectories.floats.insert(*node_idx, new_map);
if let Some(attr) = node_trajectories.floats.get_mut(node_idx) {
for key in node.keys() {
attr.insert(key.clone(), Vec::new());
}
let attr = node_trajectories.floats.get_mut(node_idx).expect("New map not found.");
for key in node.keys() {
attr.insert(key.clone(), Vec::with_capacity(n_time));
}
}
// add empty vectors in the vectors hashmap
}

// preallocate empty vectors in the vectors hashmap
for (node_idx, node) in &self.attributes.vectors {
let new_map: HashMap<String, Vec<Vec<f64>>> = HashMap::new();
node_trajectories.vectors.insert(*node_idx, new_map);
if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) {
for key in node.keys() {
attr.insert(key.clone(), Vec::new());
}
let attr = node_trajectories.vectors.get_mut(node_idx).expect("New vector map not found.");
for key in node.keys() {
attr.insert(key.clone(), Vec::with_capacity(n_time));
}
}
}


// iterate over the observations
for observation in input_data {

// 1. belief propagation for one time slice
self.belief_propagation(vec![observation]);
belief_propagation(self, vec![observation], &predictions, &updates);

// 2. append the new beliefs in the trajectories structure
// iterate over the float hashmap
for (new_node_idx, new_node) in &self.attributes.floats {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
if let Some(old_node) = node_trajectories.floats.get_mut(&new_node_idx) {
if let Some(old_value) = old_node.get_mut(new_key) {
old_value.push(*new_value);
}
let old_node = node_trajectories.floats.get_mut(&new_node_idx).expect("Old node not found.");
let old_value = old_node.get_mut(new_key).expect("Old value not found");
old_value.push(*new_value);
}
}
}

// iterate over the vector hashmap
for (new_node_idx, new_node) in &self.attributes.vectors {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) {
if let Some(old_value) = old_node.get_mut(new_key) {
old_value.push(new_value.clone());
}
let old_node = node_trajectories.vectors.get_mut(&new_node_idx).expect("Old vector node not found.");
let old_value = old_node.get_mut(new_key).expect("Old vector value not found.");
old_value.push(new_value.clone());
}
}
}
}

self.node_trajectories = node_trajectories;
}

Expand Down
Empty file removed src/tests/exponential_family.rs
Empty file.
33 changes: 11 additions & 22 deletions src/updates/prediction_error/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,14 @@ use crate::math::sufficient_statistics;
/// * `network` - The network after message passing.
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {

if let Some(floats_attributes) = network.attributes.floats.get_mut(&node_idx) {
if let Some(vectors_attributes) = network.attributes.vectors.get_mut(&node_idx) {
let mean = floats_attributes.get("mean");
let nus = floats_attributes.get("nus");
let xis = vectors_attributes.get("xis");
let new_xis = match (mean, nus, xis) {
(Some(mean), Some(nus), Some(xis)) => {
let suf_stats = sufficient_statistics(mean);
let mut new_xis = xis.clone();
for i in 0..suf_stats.len() {
new_xis[i] = new_xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]);
}
new_xis
}
_ => Vec::new(),
};
if let Some(xis) = vectors_attributes.get_mut("xis") {
*xis = new_xis; // Modify the value directly
}
}
}
}
let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");
let mean = floats_attributes.get("mean").expect("Mean not found");
let nus = floats_attributes.get("nus").expect("Nus not found");
let xis = vectors_attributes.get_mut("xis").expect("Xis not found");

let suf_stats = sufficient_statistics(mean);
for i in 0..suf_stats.len() {
xis[i] = xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]);
}
}
25 changes: 25 additions & 0 deletions src/utils/beliefs_propagation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::{utils::function_pointer::FnType, model::Network, updates::observations::observation_update};

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(network: &mut Network, observations_set: Vec<f64>, predictions: & Vec<(usize, FnType)>, updates: & Vec<(usize, FnType)>) {

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(network, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = network.inputs[i];
observation_update(network, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(network, *idx);
}
}
3 changes: 2 additions & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod set_sequence;
pub mod function_pointer;
pub mod function_pointer;
pub mod beliefs_propagation;
Loading