Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 14, 2024
1 parent 8dcaf96 commit 467deb8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
20 changes: 10 additions & 10 deletions src/hgf/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/hgf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ crate-type = ["cdylib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = "0.22.2"
pyo3 = { version = "0.22.3", features = ["extension-module"] }
20 changes: 17 additions & 3 deletions src/hgf/src/network.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashMap;
use crate::updates::posterior;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

#[derive(Debug)]
pub struct AdjacencyLists{
Expand All @@ -25,14 +27,18 @@ pub enum Node {
}

#[derive(Debug)]
#[pyclass]
pub struct Network{
pub nodes: HashMap<usize, Node>,
pub edges: Vec<AdjacencyLists>,
pub inputs: Vec<usize>,
}

#[pymethods]
impl Network {

// Create a new graph
#[new] // Define the constructor accessible from Python
pub fn new() -> Self {
Network {
nodes: HashMap::new(),
Expand All @@ -42,6 +48,7 @@ impl Network {
}

// Add a node to the graph
#[pyo3(signature = (kind, value_parents=None, value_childrens=None))]
pub fn add_node(&mut self, kind: String, value_parents: Option<usize>, value_childrens: Option<usize>) {

// the node ID is equal to the number of nodes already in the network
Expand Down Expand Up @@ -99,9 +106,9 @@ impl Network {
}
}

pub fn posterior_update(&mut self, node_idx: &usize, observation: f64) {
pub fn posterior_update(&mut self, node_idx: usize, observation: f64) {

match self.nodes.get_mut(node_idx) {
match self.nodes.get_mut(&node_idx) {
Some(Node::Generic(ref mut node)) => {
node.observation = observation
}
Expand All @@ -120,7 +127,7 @@ impl Network {

let input_node_idx = self.inputs[i];
// 2. inject the observations into the input nodes
self.posterior_update(&input_node_idx, observations[i]);
self.posterior_update(input_node_idx, observations[i]);
// 3. posterior update - prediction errors propagation
self.prediction_error(input_node_idx);
}
Expand All @@ -135,6 +142,13 @@ impl Network {
}


// Create a module to expose the class to Python
#[pymodule]
fn my_rust_library(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Network>()?; // Add the class to the Python module
Ok(())
}

// Tests module for unit tests
#[cfg(test)] // Only compile and include this module when running tests
mod tests {
Expand Down
1 change: 0 additions & 1 deletion src/hgf/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::network::Network;
use std::collections::HashSet;


pub fn get_update_order(network: Network) -> Vec<usize> {
Expand Down

0 comments on commit 467deb8

Please sign in to comment.