From c7232807b98ab8b38315ae8b58a8a61d9e0f5b5c Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 5 Feb 2024 10:34:49 +0100 Subject: [PATCH 01/20] rust package --- src/rs-hgf/.github/workflows/CI.yml | 120 ++++++++++++ src/rs-hgf/.gitignore | 72 ++++++++ src/rs-hgf/Cargo.lock | 273 ++++++++++++++++++++++++++++ src/rs-hgf/Cargo.toml | 12 ++ src/rs-hgf/pyproject.toml | 16 ++ src/rs-hgf/src/lib.rs | 14 ++ 6 files changed, 507 insertions(+) create mode 100644 src/rs-hgf/.github/workflows/CI.yml create mode 100644 src/rs-hgf/.gitignore create mode 100644 src/rs-hgf/Cargo.lock create mode 100644 src/rs-hgf/Cargo.toml create mode 100644 src/rs-hgf/pyproject.toml create mode 100644 src/rs-hgf/src/lib.rs diff --git a/src/rs-hgf/.github/workflows/CI.yml b/src/rs-hgf/.github/workflows/CI.yml new file mode 100644 index 000000000..1bae4be43 --- /dev/null +++ b/src/rs-hgf/.github/workflows/CI.yml @@ -0,0 +1,120 @@ +# This file is autogenerated by maturin v1.4.0 +# To update, run +# +# maturin generate-ci github +# +name: CI + +on: + push: + branches: + - main + - master + tags: + - '*' + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + windows: + runs-on: windows-latest + strategy: + matrix: + target: [x64, x86] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + macos: + runs-on: macos-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: "startsWith(github.ref, 'refs/tags/')" + needs: [linux, windows, macos, sdist] + steps: + - uses: actions/download-artifact@v3 + with: + name: wheels + - name: Publish to PyPI + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing * diff --git a/src/rs-hgf/.gitignore b/src/rs-hgf/.gitignore new file mode 100644 index 000000000..c8f044299 --- /dev/null +++ b/src/rs-hgf/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/src/rs-hgf/Cargo.lock b/src/rs-hgf/Cargo.lock new file mode 100644 index 000000000..06994a31a --- /dev/null +++ b/src/rs-hgf/Cargo.lock @@ -0,0 +1,273 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rs-hgf" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/src/rs-hgf/Cargo.toml b/src/rs-hgf/Cargo.toml new file mode 100644 index 000000000..19da1aa16 --- /dev/null +++ b/src/rs-hgf/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "rs-hgf" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "rs_hgf" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.19.0" diff --git a/src/rs-hgf/pyproject.toml b/src/rs-hgf/pyproject.toml new file mode 100644 index 000000000..c370be96a --- /dev/null +++ b/src/rs-hgf/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[project] +name = "rs-hgf" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] + +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/src/rs-hgf/src/lib.rs b/src/rs-hgf/src/lib.rs new file mode 100644 index 000000000..2b7e57db8 --- /dev/null +++ b/src/rs-hgf/src/lib.rs @@ -0,0 +1,14 @@ +use pyo3::prelude::*; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. +#[pymodule] +fn rs_hgf(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + Ok(()) +} From 384168fde68d426f8fcd38be14d95fbaf47757e2 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 3 Jun 2024 10:05:02 +0200 Subject: [PATCH 02/20] rust --- src/rs-hgf/src/lib.rs | 24 ++-- src/rs-hgf/src/main.rs | 242 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 12 deletions(-) create mode 100644 src/rs-hgf/src/main.rs diff --git a/src/rs-hgf/src/lib.rs b/src/rs-hgf/src/lib.rs index 2b7e57db8..53f27b16d 100644 --- a/src/rs-hgf/src/lib.rs +++ b/src/rs-hgf/src/lib.rs @@ -1,14 +1,14 @@ -use pyo3::prelude::*; +// use pyo3::prelude::*; -/// Formats the sum of two numbers as string. -#[pyfunction] -fn sum_as_string(a: usize, b: usize) -> PyResult { - Ok((a + b).to_string()) -} +// /// Formats the sum of two numbers as string. +// #[pyfunction] +// fn sum_as_string(a: usize, b: usize) -> PyResult { +// Ok((a + b).to_string()) +// } -/// A Python module implemented in Rust. -#[pymodule] -fn rs_hgf(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; - Ok(()) -} +// /// A Python module implemented in Rust. +// #[pymodule] +// fn rs_hgf(_py: Python, m: &PyModule) -> PyResult<()> { +// m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; +// Ok(()) +// } diff --git a/src/rs-hgf/src/main.rs b/src/rs-hgf/src/main.rs new file mode 100644 index 000000000..344d014aa --- /dev/null +++ b/src/rs-hgf/src/main.rs @@ -0,0 +1,242 @@ + +struct AdjacencyLists{ + value_parents: Option, + value_children: Option, +} + +struct GenericInputNode{ + observation: f64, + time_step: f64, + edges: AdjacencyLists, +} + +struct ExponentialNode { + observation: f64, + nus: f64, + xis: f64, + edges: AdjacencyLists, +} +enum Nodes { + Generic(GenericInputNode), + Exponential(ExponentialNode), +} + +struct Network{ + nodes: Vec, + update_sequence: Vec Network>, +} + +pub trait NodeControls { + fn prediction_error(&self, network: Network) -> Network; + fn get_observation(&mut self, observation: f64); +} + + +impl GenericInputNode { + + fn get_observation(&mut self, observation: f64) { + self.observation = observation + } + + fn prediction_error(&mut self, observation: f64) { + self.observation = observation + } + +} + +impl ExponentialNode { + + fn get_observation(&mut self, observation: f64) { + self.observation = observation + } + + fn prediction_error(&mut self, network: Network) -> Network { + self.observation = observation + } + +} + + +impl NodeControls for Nodes { + + fn get_observation(&mut self, observation: f64) { + + match self { + Nodes::Generic(node) => node.get_observation(observation), + Nodes::Exponential(node) => node.get_observation(observation), + } + } + + fn prediction_error(&self, network: Network) -> Network { + + match self { + Nodes::Generic(node) => node.prediction_error(network), + Nodes::Exponential(node) => node.prediction_error(network), + } + } +} + + + + + + + + + + + + + + + + + + + +impl NodeControls for Nodes { + + fn prediction_error(&self, network: Network) -> Network { + + // get value parents + let value_parent_idx = self.edges.value_parents; + + if let Some(idx) = value_parent_idx { + network.nodes[idx].get_observation(self.observation); + } + else { + println!("No index provided"); + } + network + } + + fn get_observation(&mut self, observation: f64) { + self.observation = observation + } +} + + +impl NodeControls for ExponentialNode { + + fn get_observation(&mut self, observation: f64) { + self.observation = observation + } + + fn prediction_error(&self, network: Network) -> Network { + + // get value parents + let value_parent_idx = self.edges.value_parents; + + if let Some(idx) = value_parent_idx { + network.nodes[idx].get_observation(self.observation); + } + else { + println!("No index provided"); + } + network + } + +} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +pub trait PredictionError { + fn prediction_error(&self, network: Network) -> Network { + network + } +} + +impl Nodes { + fn prediction_error(&self, network: Network) -> Network { + + match network.nodes[idx] { + Nodes::GenericInputNode | Nodes::ExponentialNode(ref mut observation) => { + // get value parent + let value_parent_idx = self.edges.value_parents; + // store the input value in the value parent, if any + if let Some(idx) = value_parent_idx { + *observation = self.observation; + }, + _ => { + println!("This node cannot receive observations.") + } else { + println!("No index provided"); + } + } + } + + } + network +} + +fn main() { + + // create input data + let values = [1.1, 1.2, 1.0, 4.0, 1.1]; + let time_steps = [1.0, 1.0, 1.0, 1.0, 1.0]; + + // define the update sequence + let update_sequence: [fn(); 2] = [generic_input_update, exponential_node_update]; + + // define the attributes + let mut generic_input: GenericInput = GenericInput{ + observation: 0.0, + time_step: 0.0, + edges: AdjacencyLists{ + value_children: None, + value_parents: Some(1), + } + }; + + let mut exponential_node: ExponentialNode = ExponentialNode{ + observation: 0.0, + nus: 0.0, + xis: 0.0, + edges: AdjacencyLists{ + value_children: Some(0), + value_parents: None, + } + }; + + // define edges + let edges: (AdjacencyLists, AdjacencyLists) = (edge_1, edge_2); + + // Iterate over the update sequence and call each function with node pointer + for func in update_sequence.iter() { + func(); + } + +} \ No newline at end of file From 5e4956202b96d90e39a6b7ffd7c26f658eca15b7 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Fri, 7 Jun 2024 07:59:55 +0200 Subject: [PATCH 03/20] simple working version --- src/rs-hgf/src/main.rs | 348 +++++++++++++++++++---------------------- 1 file changed, 163 insertions(+), 185 deletions(-) diff --git a/src/rs-hgf/src/main.rs b/src/rs-hgf/src/main.rs index 344d014aa..7809c5dc4 100644 --- a/src/rs-hgf/src/main.rs +++ b/src/rs-hgf/src/main.rs @@ -1,242 +1,220 @@ +use std::collections::HashMap; +#[derive(Debug)] struct AdjacencyLists{ value_parents: Option, value_children: Option, } - +#[derive(Debug, Clone)] struct GenericInputNode{ observation: f64, time_step: f64, - edges: AdjacencyLists, } - +#[derive(Debug, Clone)] struct ExponentialNode { observation: f64, nus: f64, - xis: f64, - edges: AdjacencyLists, + xis: [f64; 2], } -enum Nodes { + +#[derive(Debug, Clone)] +enum Node { Generic(GenericInputNode), Exponential(ExponentialNode), } +#[derive(Debug)] struct Network{ - nodes: Vec, - update_sequence: Vec Network>, + nodes: HashMap, + edges: Vec, + inputs: Vec, } -pub trait NodeControls { - fn prediction_error(&self, network: Network) -> Network; - fn get_observation(&mut self, observation: f64); +fn sufficient_statistics(x: &f64) -> [f64; 2] { + [*x, x.powf(2.0)] } - -impl GenericInputNode { - - fn get_observation(&mut self, observation: f64) { - self.observation = observation - } - - fn prediction_error(&mut self, observation: f64) { - self.observation = observation - } - -} - -impl ExponentialNode { - - fn get_observation(&mut self, observation: f64) { - self.observation = observation +impl Network { + // Create a new graph + fn new() -> Self { + Network { + nodes: HashMap::new(), + edges: Vec::new(), + inputs: Vec::new(), + } } - fn prediction_error(&mut self, network: Network) -> Network { - self.observation = observation + // Add a node to the graph + fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { + + // the node ID is equal to the number of nodes already in the network + let node_id: usize = self.nodes.len(); + + let edges = AdjacencyLists{ + value_children: value_parents, + value_parents: value_childrens, + }; + + // add edges and attributes + if kind == "generic-input" { + let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; + let node = Node::Generic(generic_input); + self.nodes.insert(node_id, node); + self.edges.push(edges); + self.inputs.push(node_id); + } else if kind == "exponential-node" { + let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; + let node = Node::Exponential(exponential_node); + self.nodes.insert(node_id, node); + self.edges.push(edges); + } else { + println!("Invalid type of node provided ({}).", kind); + } } -} - - -impl NodeControls for Nodes { + fn posterior_update(&mut self, node_idx: &usize, observation: f64) { - fn get_observation(&mut self, observation: f64) { - - match self { - Nodes::Generic(node) => node.get_observation(observation), - Nodes::Exponential(node) => node.get_observation(observation), + match self.nodes.get_mut(node_idx) { + Some(Node::Generic(ref mut node)) => { + node.observation = observation + } + Some(Node::Exponential(ref mut node)) => { + let suf_stats = sufficient_statistics(&node.observation); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } + } + None => println!("The value is None") } } - fn prediction_error(&self, network: Network) -> Network { + fn prediction_error(&mut self, node_idx: usize) { - match self { - Nodes::Generic(node) => node.prediction_error(network), - Nodes::Exponential(node) => node.prediction_error(network), + // get the observation value + let observation; + match self.nodes[&node_idx] { + Node::Generic(ref node) => { + observation = node.observation; + } + Node::Exponential(ref node) => { + observation = node.observation; + } } - } -} - - - - - - - - - - - - - - - - - - -impl NodeControls for Nodes { - - fn prediction_error(&self, network: Network) -> Network { - - // get value parents - let value_parent_idx = self.edges.value_parents; - - if let Some(idx) = value_parent_idx { - network.nodes[idx].get_observation(self.observation); + let value_parent_idx = &self.edges[node_idx].value_parents; + match value_parent_idx { + Some(idx) => { + match self.nodes.get_mut(idx) { + Some(Node::Generic(ref mut parent)) => { + parent.observation = observation + } + Some(Node::Exponential(ref mut parent)) => { + parent.observation = observation + } + None => println!("The value is None"), + } + } + None => println!("The value is None"), } - else { - println!("No index provided"); } - network - } - - fn get_observation(&mut self, observation: f64) { - self.observation = observation - } -} - + + fn belief_propagation(&mut self, observations: Vec) { -impl NodeControls for ExponentialNode { + // 1. prediction propagation - fn get_observation(&mut self, observation: f64) { - self.observation = observation - } + // 2. inject the observations into the input nodes + for i in 0..observations.len() { - fn prediction_error(&self, network: Network) -> Network { + let input_node_idx = self.inputs[i]; + self.posterior_update(&input_node_idx, observations[i]); + self.prediction_error(input_node_idx); + } - // get value parents - let value_parent_idx = self.edges.value_parents; + // 3. posterior update - prediction errors propagation + } - if let Some(idx) = value_parent_idx { - network.nodes[idx].get_observation(self.observation); + fn input_data(&mut self, input_data: Vec>) { + for observation in input_data { + self.belief_propagation(observation); } - else { - println!("No index provided"); - } - network } -} - - - - - - - - - - - - - - - - - - - - - - + fn get_update_order(self) -> Vec { + + let mut update_list = Vec::new(); + // list all nodes availables in the network + let mut nodes_idxs: Vec = self.nodes.keys().cloned().collect(); + // remove the input nodes + nodes_idxs.retain(|x| !self.inputs.contains(x)); + // start with the value parents of input nodes + for input_idx in self.inputs { + let value_parent_idxs = self.edges[input_idx].value_parents; + match value_parent_idxs { + Some(idx) => { + // if this parent is still in the list, update it now + if nodes_idxs.contains(&idx) { + // add the node in the update list + update_list.push(idx); + // remove the parent from the availables nodes list + nodes_idxs.retain(|&x| x != idx); - - - - - - - - -pub trait PredictionError { - fn prediction_error(&self, network: Network) -> Network { - network - } -} - -impl Nodes { - fn prediction_error(&self, network: Network) -> Network { - - match network.nodes[idx] { - Nodes::GenericInputNode | Nodes::ExponentialNode(ref mut observation) => { - // get value parent - let value_parent_idx = self.edges.value_parents; - // store the input value in the value parent, if any - if let Some(idx) = value_parent_idx { - *observation = self.observation; - }, - _ => { - println!("This node cannot receive observations.") - } else { - println!("No index provided"); - } + } } + None => println!("The value is None") } + } + nodes_idxs + } - } - network -} - -fn main() { - - // create input data - let values = [1.1, 1.2, 1.0, 4.0, 1.1]; - let time_steps = [1.0, 1.0, 1.0, 1.0, 1.0]; + } - // define the update sequence - let update_sequence: [fn(); 2] = [generic_input_update, exponential_node_update]; - // define the attributes - let mut generic_input: GenericInput = GenericInput{ - observation: 0.0, - time_step: 0.0, - edges: AdjacencyLists{ - value_children: None, - value_parents: Some(1), - } - }; - - let mut exponential_node: ExponentialNode = ExponentialNode{ - observation: 0.0, - nus: 0.0, - xis: 0.0, - edges: AdjacencyLists{ - value_children: Some(0), - value_parents: None, - } - }; - - // define edges - let edges: (AdjacencyLists, AdjacencyLists) = (edge_1, edge_2); +fn main() { - // Iterate over the update sequence and call each function with node pointer - for func in update_sequence.iter() { - func(); - } + // initialize network + let mut network = Network::new(); + + // create a network + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("exponential-node"), + None, + Some(0), + ); + network.add_node( + String::from("exponential-node"), + None, + Some(1), + ); + + println!("Graph before belief propagation: {:?}", network); + + // belief propagation + let input_data = vec![ + vec![1.1, 2.2], + vec![1.2, 2.1], + vec![1.0, 2.0], + vec![1.3, 2.2], + vec![1.1, 2.5], + vec![1.0, 2.6], + ]; + + network.input_data(input_data); + println!("Graph after belief propagation: {:?}", network); + } \ No newline at end of file From 1a318bd0a83d5633de0e9b43738577de6879fe74 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 2 Sep 2024 11:27:01 +0200 Subject: [PATCH 04/20] package working with tests --- src/{rs-hgf => hgf}/.gitignore | 0 src/hgf/Cargo.lock | 171 +++++++++++ src/{rs-hgf => hgf}/Cargo.toml | 7 +- src/{rs-hgf => hgf}/pyproject.toml | 2 +- src/hgf/src/lib.rs | 4 + src/hgf/src/math.rs | 3 + src/hgf/src/network.rs | 189 ++++++++++++ src/hgf/src/updates/mod.rs | 3 + src/hgf/src/updates/posterior/continuous.rs | 9 + src/hgf/src/updates/posterior/mod.rs | 1 + src/hgf/src/updates/prediction/mod.rs | 0 .../updates/prediction_error/inputs/mod.rs | 0 src/hgf/src/updates/prediction_error/mod.rs | 2 + .../prediction_error/nodes/continuous.rs | 0 .../src/updates/prediction_error/nodes/mod.rs | 1 + src/hgf/src/utils.rs | 34 +++ src/hgf/tests/exponential_family.rs | 0 src/rs-hgf/.github/workflows/CI.yml | 120 -------- src/rs-hgf/Cargo.lock | 273 ------------------ src/rs-hgf/src/lib.rs | 14 - src/rs-hgf/src/main.rs | 220 -------------- 21 files changed, 422 insertions(+), 631 deletions(-) rename src/{rs-hgf => hgf}/.gitignore (100%) create mode 100644 src/hgf/Cargo.lock rename src/{rs-hgf => hgf}/Cargo.toml (65%) rename src/{rs-hgf => hgf}/pyproject.toml (95%) create mode 100644 src/hgf/src/lib.rs create mode 100644 src/hgf/src/math.rs create mode 100644 src/hgf/src/network.rs create mode 100644 src/hgf/src/updates/mod.rs create mode 100644 src/hgf/src/updates/posterior/continuous.rs create mode 100644 src/hgf/src/updates/posterior/mod.rs create mode 100644 src/hgf/src/updates/prediction/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/inputs/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/nodes/continuous.rs create mode 100644 src/hgf/src/updates/prediction_error/nodes/mod.rs create mode 100644 src/hgf/src/utils.rs create mode 100644 src/hgf/tests/exponential_family.rs delete mode 100644 src/rs-hgf/.github/workflows/CI.yml delete mode 100644 src/rs-hgf/Cargo.lock delete mode 100644 src/rs-hgf/src/lib.rs delete mode 100644 src/rs-hgf/src/main.rs diff --git a/src/rs-hgf/.gitignore b/src/hgf/.gitignore similarity index 100% rename from src/rs-hgf/.gitignore rename to src/hgf/.gitignore diff --git a/src/hgf/Cargo.lock b/src/hgf/Cargo.lock new file mode 100644 index 000000000..a63f242ae --- /dev/null +++ b/src/hgf/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hgf" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" diff --git a/src/rs-hgf/Cargo.toml b/src/hgf/Cargo.toml similarity index 65% rename from src/rs-hgf/Cargo.toml rename to src/hgf/Cargo.toml index 19da1aa16..9e430d5e9 100644 --- a/src/rs-hgf/Cargo.toml +++ b/src/hgf/Cargo.toml @@ -1,12 +1,13 @@ [package] -name = "rs-hgf" +name = "hgf" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "rs_hgf" +name = "hgf" crate-type = ["cdylib"] +path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = "0.19.0" +pyo3 = "0.22.2" diff --git a/src/rs-hgf/pyproject.toml b/src/hgf/pyproject.toml similarity index 95% rename from src/rs-hgf/pyproject.toml rename to src/hgf/pyproject.toml index c370be96a..399c9df9e 100644 --- a/src/rs-hgf/pyproject.toml +++ b/src/hgf/pyproject.toml @@ -3,7 +3,7 @@ requires = ["maturin>=1.4,<2.0"] build-backend = "maturin" [project] -name = "rs-hgf" +name = "hgf" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/src/hgf/src/lib.rs b/src/hgf/src/lib.rs new file mode 100644 index 000000000..504294410 --- /dev/null +++ b/src/hgf/src/lib.rs @@ -0,0 +1,4 @@ +pub mod network; +pub mod utils; +pub mod math; +pub mod updates; \ No newline at end of file diff --git a/src/hgf/src/math.rs b/src/hgf/src/math.rs new file mode 100644 index 000000000..20a0e568b --- /dev/null +++ b/src/hgf/src/math.rs @@ -0,0 +1,3 @@ +pub fn sufficient_statistics(x: &f64) -> [f64; 2] { + [*x, x.powf(2.0)] +} \ No newline at end of file diff --git a/src/hgf/src/network.rs b/src/hgf/src/network.rs new file mode 100644 index 000000000..1d79a952e --- /dev/null +++ b/src/hgf/src/network.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; +use crate::updates::posterior; + +#[derive(Debug)] +pub struct AdjacencyLists{ + pub value_parents: Option, + pub value_children: Option, +} +#[derive(Debug, Clone)] +pub struct GenericInputNode{ + pub observation: f64, + pub time_step: f64, +} +#[derive(Debug, Clone)] +pub struct ExponentialNode { + pub observation: f64, + pub nus: f64, + pub xis: [f64; 2], +} + +#[derive(Debug, Clone)] +pub enum Node { + Generic(GenericInputNode), + Exponential(ExponentialNode), +} + +#[derive(Debug)] +pub struct Network{ + pub nodes: HashMap, + pub edges: Vec, + pub inputs: Vec, +} + +impl Network { + // Create a new graph + pub fn new() -> Self { + Network { + nodes: HashMap::new(), + edges: Vec::new(), + inputs: Vec::new(), + } + } + + // Add a node to the graph + pub fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { + + // the node ID is equal to the number of nodes already in the network + let node_id: usize = self.nodes.len(); + + let edges = AdjacencyLists{ + value_children: value_parents, + value_parents: value_childrens, + }; + + // add edges and attributes + if kind == "generic-input" { + let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; + let node = Node::Generic(generic_input); + self.nodes.insert(node_id, node); + self.edges.push(edges); + self.inputs.push(node_id); + } else if kind == "exponential-node" { + let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; + let node = Node::Exponential(exponential_node); + self.nodes.insert(node_id, node); + self.edges.push(edges); + } else { + println!("Invalid type of node provided ({}).", kind); + } + } + + pub fn prediction_error(&mut self, node_idx: usize) { + + // get the observation value + let observation; + match self.nodes[&node_idx] { + Node::Generic(ref node) => { + observation = node.observation; + } + Node::Exponential(ref node) => { + observation = node.observation; + } + } + + let value_parent_idx = &self.edges[node_idx].value_parents; + match value_parent_idx { + Some(idx) => { + match self.nodes.get_mut(idx) { + Some(Node::Generic(ref mut parent)) => { + parent.observation = observation + } + Some(Node::Exponential(ref mut parent)) => { + parent.observation = observation + } + None => println!("No prediction error for this type of node."), + } + } + None => println!("No value parent"), + } + } + + pub fn posterior_update(&mut self, node_idx: &usize, observation: f64) { + + match self.nodes.get_mut(node_idx) { + Some(Node::Generic(ref mut node)) => { + node.observation = observation + } + Some(Node::Exponential(ref mut node)) => { + posterior::continuous::posterior_update_exponential(node) + } + None => println!("No posterior update for this type of node.") + } + } + + pub fn belief_propagation(&mut self, observations: Vec) { + + // 1. prediction propagation + + // 2. inject the observations into the input nodes + for i in 0..observations.len() { + + let input_node_idx = self.inputs[i]; + self.posterior_update(&input_node_idx, observations[i]); + self.prediction_error(input_node_idx); + } + + // 3. posterior update - prediction errors propagation + } + + pub fn input_data(&mut self, input_data: Vec>) { + for observation in input_data { + self.belief_propagation(observation); + } + } + } + + +// Tests module for unit tests +#[cfg(test)] // Only compile and include this module when running tests +mod tests { + use super::*; // Import the parent module's items to test them + + + #[test] + fn test_exponential_family_gaussian() { + + // initialize network + let mut network = Network::new(); + + // create a network + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("exponential-node"), + None, + Some(0), + ); + network.add_node( + String::from("exponential-node"), + None, + Some(1), + ); + + println!("Graph before belief propagation: {:?}", network); + + // belief propagation + let input_data = vec![ + vec![1.1, 2.2], + vec![1.2, 2.1], + vec![1.0, 2.0], + vec![1.3, 2.2], + vec![1.1, 2.5], + vec![1.0, 2.6], + ]; + + network.input_data(input_data); + + println!("Graph after belief propagation: {:?}", network); + + } +} diff --git a/src/hgf/src/updates/mod.rs b/src/hgf/src/updates/mod.rs new file mode 100644 index 000000000..dd8404066 --- /dev/null +++ b/src/hgf/src/updates/mod.rs @@ -0,0 +1,3 @@ +pub mod posterior; +pub mod prediction; +pub mod prediction_error; \ No newline at end of file diff --git a/src/hgf/src/updates/posterior/continuous.rs b/src/hgf/src/updates/posterior/continuous.rs new file mode 100644 index 000000000..7fb5b78e6 --- /dev/null +++ b/src/hgf/src/updates/posterior/continuous.rs @@ -0,0 +1,9 @@ +use crate::network::ExponentialNode; +use crate::math::sufficient_statistics; + +pub fn posterior_update_exponential(node: &mut ExponentialNode) { + let suf_stats = sufficient_statistics(&node.observation); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } +} \ No newline at end of file diff --git a/src/hgf/src/updates/posterior/mod.rs b/src/hgf/src/updates/posterior/mod.rs new file mode 100644 index 000000000..6817d49e7 --- /dev/null +++ b/src/hgf/src/updates/posterior/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/hgf/src/updates/prediction/mod.rs b/src/hgf/src/updates/prediction/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/inputs/mod.rs b/src/hgf/src/updates/prediction_error/inputs/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/mod.rs b/src/hgf/src/updates/prediction_error/mod.rs new file mode 100644 index 000000000..264f047ac --- /dev/null +++ b/src/hgf/src/updates/prediction_error/mod.rs @@ -0,0 +1,2 @@ +pub mod inputs; +pub mod nodes; diff --git a/src/hgf/src/updates/prediction_error/nodes/continuous.rs b/src/hgf/src/updates/prediction_error/nodes/continuous.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/nodes/mod.rs b/src/hgf/src/updates/prediction_error/nodes/mod.rs new file mode 100644 index 000000000..6817d49e7 --- /dev/null +++ b/src/hgf/src/updates/prediction_error/nodes/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/hgf/src/utils.rs b/src/hgf/src/utils.rs new file mode 100644 index 000000000..354166a33 --- /dev/null +++ b/src/hgf/src/utils.rs @@ -0,0 +1,34 @@ +use crate::network::Network; + + +pub fn get_update_order(network: Network) -> Vec { + + let mut update_list = Vec::new(); + + // list all nodes availables in the network + let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + + // remove the input nodes + nodes_idxs.retain(|x| !network.inputs.contains(x)); + + // start with the value parents of input nodes + for input_idx in network.inputs { + let value_parent_idxs = network.edges[input_idx].value_parents; + match value_parent_idxs { + Some(idx) => { + // if this parent is still in the list, update it now + if nodes_idxs.contains(&idx) { + + // add the node in the update list + update_list.push(idx); + + // remove the parent from the availables nodes list + nodes_idxs.retain(|&x| x != idx); + + } + } + None => println!("The value is None") + } + } + nodes_idxs +} diff --git a/src/hgf/tests/exponential_family.rs b/src/hgf/tests/exponential_family.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/rs-hgf/.github/workflows/CI.yml b/src/rs-hgf/.github/workflows/CI.yml deleted file mode 100644 index 1bae4be43..000000000 --- a/src/rs-hgf/.github/workflows/CI.yml +++ /dev/null @@ -1,120 +0,0 @@ -# This file is autogenerated by maturin v1.4.0 -# To update, run -# -# maturin generate-ci github -# -name: CI - -on: - push: - branches: - - main - - master - tags: - - '*' - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - linux: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: auto - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - strategy: - matrix: - target: [x64, x86] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: ${{ matrix.target }} - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - strategy: - matrix: - target: [x86_64, aarch64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build sdist - uses: PyO3/maturin-action@v1 - with: - command: sdist - args: --out dist - - name: Upload sdist - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [linux, windows, macos, sdist] - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --non-interactive --skip-existing * diff --git a/src/rs-hgf/Cargo.lock b/src/rs-hgf/Cargo.lock deleted file mode 100644 index 06994a31a..000000000 --- a/src/rs-hgf/Cargo.lock +++ /dev/null @@ -1,273 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "indoc" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" - -[[package]] -name = "libc" -version = "0.2.153" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" - -[[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "memoffset" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "proc-macro2" -version = "1.0.78" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags", -] - -[[package]] -name = "rs-hgf" -version = "0.1.0" -dependencies = [ - "pyo3", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "smallvec" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unindent" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/src/rs-hgf/src/lib.rs b/src/rs-hgf/src/lib.rs deleted file mode 100644 index 53f27b16d..000000000 --- a/src/rs-hgf/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -// use pyo3::prelude::*; - -// /// Formats the sum of two numbers as string. -// #[pyfunction] -// fn sum_as_string(a: usize, b: usize) -> PyResult { -// Ok((a + b).to_string()) -// } - -// /// A Python module implemented in Rust. -// #[pymodule] -// fn rs_hgf(_py: Python, m: &PyModule) -> PyResult<()> { -// m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; -// Ok(()) -// } diff --git a/src/rs-hgf/src/main.rs b/src/rs-hgf/src/main.rs deleted file mode 100644 index 7809c5dc4..000000000 --- a/src/rs-hgf/src/main.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug)] -struct AdjacencyLists{ - value_parents: Option, - value_children: Option, -} -#[derive(Debug, Clone)] -struct GenericInputNode{ - observation: f64, - time_step: f64, -} -#[derive(Debug, Clone)] -struct ExponentialNode { - observation: f64, - nus: f64, - xis: [f64; 2], -} - -#[derive(Debug, Clone)] -enum Node { - Generic(GenericInputNode), - Exponential(ExponentialNode), -} - -#[derive(Debug)] -struct Network{ - nodes: HashMap, - edges: Vec, - inputs: Vec, -} - -fn sufficient_statistics(x: &f64) -> [f64; 2] { - [*x, x.powf(2.0)] -} - -impl Network { - // Create a new graph - fn new() -> Self { - Network { - nodes: HashMap::new(), - edges: Vec::new(), - inputs: Vec::new(), - } - } - - // Add a node to the graph - fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { - - // the node ID is equal to the number of nodes already in the network - let node_id: usize = self.nodes.len(); - - let edges = AdjacencyLists{ - value_children: value_parents, - value_parents: value_childrens, - }; - - // add edges and attributes - if kind == "generic-input" { - let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; - let node = Node::Generic(generic_input); - self.nodes.insert(node_id, node); - self.edges.push(edges); - self.inputs.push(node_id); - } else if kind == "exponential-node" { - let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; - let node = Node::Exponential(exponential_node); - self.nodes.insert(node_id, node); - self.edges.push(edges); - } else { - println!("Invalid type of node provided ({}).", kind); - } - } - - fn posterior_update(&mut self, node_idx: &usize, observation: f64) { - - match self.nodes.get_mut(node_idx) { - Some(Node::Generic(ref mut node)) => { - node.observation = observation - } - Some(Node::Exponential(ref mut node)) => { - let suf_stats = sufficient_statistics(&node.observation); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); - } - } - None => println!("The value is None") - } - } - - fn prediction_error(&mut self, node_idx: usize) { - - // get the observation value - let observation; - match self.nodes[&node_idx] { - Node::Generic(ref node) => { - observation = node.observation; - } - Node::Exponential(ref node) => { - observation = node.observation; - } - } - - let value_parent_idx = &self.edges[node_idx].value_parents; - match value_parent_idx { - Some(idx) => { - match self.nodes.get_mut(idx) { - Some(Node::Generic(ref mut parent)) => { - parent.observation = observation - } - Some(Node::Exponential(ref mut parent)) => { - parent.observation = observation - } - None => println!("The value is None"), - } - } - None => println!("The value is None"), - } - } - - fn belief_propagation(&mut self, observations: Vec) { - - // 1. prediction propagation - - // 2. inject the observations into the input nodes - for i in 0..observations.len() { - - let input_node_idx = self.inputs[i]; - self.posterior_update(&input_node_idx, observations[i]); - self.prediction_error(input_node_idx); - } - - // 3. posterior update - prediction errors propagation - } - - fn input_data(&mut self, input_data: Vec>) { - for observation in input_data { - self.belief_propagation(observation); - } - } - - fn get_update_order(self) -> Vec { - - let mut update_list = Vec::new(); - - // list all nodes availables in the network - let mut nodes_idxs: Vec = self.nodes.keys().cloned().collect(); - - // remove the input nodes - nodes_idxs.retain(|x| !self.inputs.contains(x)); - - // start with the value parents of input nodes - for input_idx in self.inputs { - let value_parent_idxs = self.edges[input_idx].value_parents; - match value_parent_idxs { - Some(idx) => { - // if this parent is still in the list, update it now - if nodes_idxs.contains(&idx) { - - // add the node in the update list - update_list.push(idx); - - // remove the parent from the availables nodes list - nodes_idxs.retain(|&x| x != idx); - - } - } - None => println!("The value is None") - } - } - nodes_idxs - } - - } - - -fn main() { - - // initialize network - let mut network = Network::new(); - - // create a network - network.add_node( - String::from("generic-input"), - None, - None, - ); - network.add_node( - String::from("generic-input"), - None, - None, - ); - network.add_node( - String::from("exponential-node"), - None, - Some(0), - ); - network.add_node( - String::from("exponential-node"), - None, - Some(1), - ); - - println!("Graph before belief propagation: {:?}", network); - - // belief propagation - let input_data = vec![ - vec![1.1, 2.2], - vec![1.2, 2.1], - vec![1.0, 2.0], - vec![1.3, 2.2], - vec![1.1, 2.5], - vec![1.0, 2.6], - ]; - - network.input_data(input_data); - - println!("Graph after belief propagation: {:?}", network); - -} \ No newline at end of file From 8dcaf96cfc7f5176157d2212bba35167ad510ab2 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 9 Sep 2024 14:09:24 +0200 Subject: [PATCH 05/20] rust library --- src/hgf/.github/workflows/CI.yml | 120 +++++++++++++++++++++++++++++++ src/hgf/src/network.rs | 10 +-- src/hgf/src/utils.rs | 82 +++++++++++++++++---- 3 files changed, 193 insertions(+), 19 deletions(-) create mode 100644 src/hgf/.github/workflows/CI.yml diff --git a/src/hgf/.github/workflows/CI.yml b/src/hgf/.github/workflows/CI.yml new file mode 100644 index 000000000..1bae4be43 --- /dev/null +++ b/src/hgf/.github/workflows/CI.yml @@ -0,0 +1,120 @@ +# This file is autogenerated by maturin v1.4.0 +# To update, run +# +# maturin generate-ci github +# +name: CI + +on: + push: + branches: + - main + - master + tags: + - '*' + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + windows: + runs-on: windows-latest + strategy: + matrix: + target: [x64, x86] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + macos: + runs-on: macos-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: "startsWith(github.ref, 'refs/tags/')" + needs: [linux, windows, macos, sdist] + steps: + - uses: actions/download-artifact@v3 + with: + name: wheels + - name: Publish to PyPI + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing * diff --git a/src/hgf/src/network.rs b/src/hgf/src/network.rs index 1d79a952e..07100a3df 100644 --- a/src/hgf/src/network.rs +++ b/src/hgf/src/network.rs @@ -116,15 +116,15 @@ impl Network { // 1. prediction propagation - // 2. inject the observations into the input nodes for i in 0..observations.len() { - + let input_node_idx = self.inputs[i]; + // 2. inject the observations into the input nodes self.posterior_update(&input_node_idx, observations[i]); + // 3. posterior update - prediction errors propagation self.prediction_error(input_node_idx); } - // 3. posterior update - prediction errors propagation } pub fn input_data(&mut self, input_data: Vec>) { @@ -169,7 +169,7 @@ mod tests { Some(1), ); - println!("Graph before belief propagation: {:?}", network); + // println!("Graph before belief propagation: {:?}", network); // belief propagation let input_data = vec![ @@ -183,7 +183,7 @@ mod tests { network.input_data(input_data); - println!("Graph after belief propagation: {:?}", network); + // println!("Graph after belief propagation: {:?}", network); } } diff --git a/src/hgf/src/utils.rs b/src/hgf/src/utils.rs index 354166a33..d0dbecb0c 100644 --- a/src/hgf/src/utils.rs +++ b/src/hgf/src/utils.rs @@ -1,4 +1,5 @@ use crate::network::Network; +use std::collections::HashSet; pub fn get_update_order(network: Network) -> Vec { @@ -11,24 +12,77 @@ pub fn get_update_order(network: Network) -> Vec { // remove the input nodes nodes_idxs.retain(|x| !network.inputs.contains(x)); - // start with the value parents of input nodes - for input_idx in network.inputs { - let value_parent_idxs = network.edges[input_idx].value_parents; - match value_parent_idxs { - Some(idx) => { - // if this parent is still in the list, update it now - if nodes_idxs.contains(&idx) { + let mut remaining = nodes_idxs.len(); - // add the node in the update list - update_list.push(idx); + while remaining > 0 { - // remove the parent from the availables nodes list - nodes_idxs.retain(|&x| x != idx); + let mut has_update = false; + // loop over all available + for i in 0..nodes_idxs.len() { - } + let idx = nodes_idxs[i]; + let value_children_idxs = network.edges[idx].value_children; + + // check if there is any element in value children + // that is found in the to-be-updated list of nodes + let contains_common = value_children_idxs.iter().any(|&item| nodes_idxs.contains(&item)); + + if !(contains_common) { + + // add the node in the update list + update_list.push(idx); + + // remove the parent from the availables nodes list + nodes_idxs.retain(|&x| x != idx); + + remaining -= 1; + has_update = true; + break; + } } - None => println!("The value is None") + if !(has_update) { + break; } } - nodes_idxs + update_list +} + + +// Tests module for unit tests +#[cfg(test)] // Only compile and include this module when running tests +mod tests { + use super::*; // Import the parent module's items to test them + + #[test] + fn test_get_update_order() { + + // initialize network + let mut network = Network::new(); + + // create a network + network.add_node( + String::from("generic-input"), + Some(1), + None, + ); + network.add_node( + String::from("exponential-node"), + Some(2), + Some(0), + ); + network.add_node( + String::from("exponential-node"), + Some(3), + Some(1), + ); + network.add_node( + String::from("exponential-node"), + None, + Some(2), + ); + + println!("Network: {:?}", network); + println!("Update order: {:?}", get_update_order(network)); + + } } From 467deb87016250dd77e8bceba0235756a13c8437 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Thu, 3 Oct 2024 13:45:23 +0200 Subject: [PATCH 06/20] misc --- src/hgf/Cargo.lock | 20 ++++++++++---------- src/hgf/Cargo.toml | 2 +- src/hgf/src/network.rs | 20 +++++++++++++++++--- src/hgf/src/utils.rs | 1 - 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/hgf/Cargo.lock b/src/hgf/Cargo.lock index a63f242ae..fab37caf4 100644 --- a/src/hgf/Cargo.lock +++ b/src/hgf/Cargo.lock @@ -71,9 +71,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" dependencies = [ "cfg-if", "indoc", @@ -89,9 +89,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" dependencies = [ "once_cell", "target-lexicon", @@ -99,9 +99,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" dependencies = [ "libc", "pyo3-build-config", @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -121,9 +121,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" dependencies = [ "heck", "proc-macro2", diff --git a/src/hgf/Cargo.toml b/src/hgf/Cargo.toml index 9e430d5e9..c59d39794 100644 --- a/src/hgf/Cargo.toml +++ b/src/hgf/Cargo.toml @@ -10,4 +10,4 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = "0.22.2" +pyo3 = { version = "0.22.3", features = ["extension-module"] } \ No newline at end of file diff --git a/src/hgf/src/network.rs b/src/hgf/src/network.rs index 07100a3df..9dcfd12b5 100644 --- a/src/hgf/src/network.rs +++ b/src/hgf/src/network.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; use crate::updates::posterior; +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; #[derive(Debug)] pub struct AdjacencyLists{ @@ -25,14 +27,18 @@ pub enum Node { } #[derive(Debug)] +#[pyclass] pub struct Network{ pub nodes: HashMap, pub edges: Vec, pub inputs: Vec, } +#[pymethods] impl Network { + // Create a new graph + #[new] // Define the constructor accessible from Python pub fn new() -> Self { Network { nodes: HashMap::new(), @@ -42,6 +48,7 @@ impl Network { } // Add a node to the graph + #[pyo3(signature = (kind, value_parents=None, value_childrens=None))] pub fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { // the node ID is equal to the number of nodes already in the network @@ -99,9 +106,9 @@ impl Network { } } - pub fn posterior_update(&mut self, node_idx: &usize, observation: f64) { + pub fn posterior_update(&mut self, node_idx: usize, observation: f64) { - match self.nodes.get_mut(node_idx) { + match self.nodes.get_mut(&node_idx) { Some(Node::Generic(ref mut node)) => { node.observation = observation } @@ -120,7 +127,7 @@ impl Network { let input_node_idx = self.inputs[i]; // 2. inject the observations into the input nodes - self.posterior_update(&input_node_idx, observations[i]); + self.posterior_update(input_node_idx, observations[i]); // 3. posterior update - prediction errors propagation self.prediction_error(input_node_idx); } @@ -135,6 +142,13 @@ impl Network { } +// Create a module to expose the class to Python +#[pymodule] +fn my_rust_library(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; // Add the class to the Python module + Ok(()) +} + // Tests module for unit tests #[cfg(test)] // Only compile and include this module when running tests mod tests { diff --git a/src/hgf/src/utils.rs b/src/hgf/src/utils.rs index d0dbecb0c..6d4505930 100644 --- a/src/hgf/src/utils.rs +++ b/src/hgf/src/utils.rs @@ -1,5 +1,4 @@ use crate::network::Network; -use std::collections::HashSet; pub fn get_update_order(network: Network) -> Vec { From c84ed0fafdb954804342266690d594d59b78dc77 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 14 Oct 2024 12:21:23 +0200 Subject: [PATCH 07/20] pyo3 --- src/hgf/Cargo.lock | 56 +++++++++++++++++++++--------------------- src/hgf/src/network.rs | 5 ++-- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/hgf/Cargo.lock b/src/hgf/Cargo.lock index fab37caf4..0c069cc90 100644 --- a/src/hgf/Cargo.lock +++ b/src/hgf/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "autocfg" -version = "1.1.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "cfg-if" @@ -35,45 +35,45 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "memoffset" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" dependencies = [ "autocfg", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" dependencies = [ "unicode-ident", ] [[package]] name = "pyo3" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +checksum = "00e89ce2565d6044ca31a3eb79a334c3a79a841120a98f64eea9f579564cb691" dependencies = [ "cfg-if", "indoc", @@ -89,9 +89,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +checksum = "d8afbaf3abd7325e08f35ffb8deb5892046fcb2608b703db6a583a5ba4cea01e" dependencies = [ "once_cell", "target-lexicon", @@ -99,9 +99,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +checksum = "ec15a5ba277339d04763f4c23d85987a5b08cbb494860be141e6a10a8eb88022" dependencies = [ "libc", "pyo3-build-config", @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +checksum = "15e0f01b5364bcfbb686a52fc4181d412b708a68ed20c330db9fc8d2c2bf5a43" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -121,9 +121,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +checksum = "a09b550200e1e5ed9176976d0060cbc2ea82dc8515da07885e7b8153a85caacb" dependencies = [ "heck", "proc-macro2", @@ -134,18 +134,18 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.35" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] [[package]] name = "syn" -version = "2.0.77" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -160,9 +160,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unindent" diff --git a/src/hgf/src/network.rs b/src/hgf/src/network.rs index 9dcfd12b5..3378fe54d 100644 --- a/src/hgf/src/network.rs +++ b/src/hgf/src/network.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use crate::updates::posterior; use pyo3::prelude::*; -use pyo3::wrap_pyfunction; #[derive(Debug)] pub struct AdjacencyLists{ @@ -144,8 +143,8 @@ impl Network { // Create a module to expose the class to Python #[pymodule] -fn my_rust_library(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; // Add the class to the Python module +fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; Ok(()) } From e823bff56e835f1f6bc782bc1cf858f584970291 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 14 Oct 2024 17:00:07 +0200 Subject: [PATCH 08/20] reorganize packages structure --- .github/workflows/ci.yml | 182 ++++++++++++++++++ .gitignore | 3 +- src/hgf/Cargo.lock => Cargo.lock | 0 src/hgf/Cargo.toml => Cargo.toml | 2 +- {src/pyhgf => pyhgf}/__init__.py | 0 {src/pyhgf => pyhgf}/data/binary_input.txt | 0 {src/pyhgf => pyhgf}/data/binary_response.txt | 0 {src/pyhgf => pyhgf}/data/usdchf.txt | 0 {src/pyhgf => pyhgf}/distribution.py | 0 {src/pyhgf => pyhgf}/math.py | 0 {src/pyhgf => pyhgf}/model/__init__.py | 0 {src/pyhgf => pyhgf}/model/hgf.py | 0 {src/pyhgf => pyhgf}/model/network.py | 0 {src/pyhgf => pyhgf}/plots.py | 0 {src/pyhgf => pyhgf}/response.py | 0 {src/pyhgf => pyhgf}/typing.py | 0 {src/pyhgf => pyhgf}/updates/__init__.py | 0 {src/pyhgf => pyhgf}/updates/observation.py | 0 .../updates/posterior/__init__.py | 0 .../updates/posterior/categorical.py | 0 .../updates/posterior/continuous.py | 0 .../updates/posterior/exponential.py | 0 .../updates/prediction/__init__.py | 0 .../updates/prediction/binary.py | 0 .../updates/prediction/continuous.py | 0 .../updates/prediction/dirichlet.py | 0 .../updates/prediction_error/__init__.py | 0 .../updates/prediction_error/binary.py | 0 .../updates/prediction_error/categorical.py | 0 .../updates/prediction_error/continuous.py | 0 .../updates/prediction_error/dirichlet.py | 0 .../updates/prediction_error/generic.py | 0 {src/pyhgf => pyhgf}/utils.py | 0 pyproject.toml | 15 +- src/hgf/.github/workflows/CI.yml | 120 ------------ src/hgf/.gitignore | 72 ------- src/hgf/pyproject.toml | 16 -- .../updates/prediction_error/inputs/mod.rs | 0 src/hgf/src/updates/prediction_error/mod.rs | 2 - src/{hgf/src => }/lib.rs | 0 src/{hgf/src => }/math.rs | 0 src/{hgf/src => }/network.rs | 0 src/{hgf => }/tests/exponential_family.rs | 0 src/{hgf/src => }/updates/mod.rs | 0 .../src => }/updates/posterior/continuous.rs | 0 src/{hgf/src => }/updates/posterior/mod.rs | 0 src/{hgf/src => }/updates/prediction/mod.rs | 0 src/updates/prediction_error/mod.rs | 1 + .../prediction_error/nodes/continuous.rs | 0 .../updates/prediction_error/nodes/mod.rs | 0 src/{hgf/src => }/utils.rs | 0 51 files changed, 198 insertions(+), 215 deletions(-) create mode 100644 .github/workflows/ci.yml rename src/hgf/Cargo.lock => Cargo.lock (100%) rename src/hgf/Cargo.toml => Cargo.toml (81%) rename {src/pyhgf => pyhgf}/__init__.py (100%) rename {src/pyhgf => pyhgf}/data/binary_input.txt (100%) rename {src/pyhgf => pyhgf}/data/binary_response.txt (100%) rename {src/pyhgf => pyhgf}/data/usdchf.txt (100%) rename {src/pyhgf => pyhgf}/distribution.py (100%) rename {src/pyhgf => pyhgf}/math.py (100%) rename {src/pyhgf => pyhgf}/model/__init__.py (100%) rename {src/pyhgf => pyhgf}/model/hgf.py (100%) rename {src/pyhgf => pyhgf}/model/network.py (100%) rename {src/pyhgf => pyhgf}/plots.py (100%) rename {src/pyhgf => pyhgf}/response.py (100%) rename {src/pyhgf => pyhgf}/typing.py (100%) rename {src/pyhgf => pyhgf}/updates/__init__.py (100%) rename {src/pyhgf => pyhgf}/updates/observation.py (100%) rename {src/pyhgf => pyhgf}/updates/posterior/__init__.py (100%) rename {src/pyhgf => pyhgf}/updates/posterior/categorical.py (100%) rename {src/pyhgf => pyhgf}/updates/posterior/continuous.py (100%) rename {src/pyhgf => pyhgf}/updates/posterior/exponential.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction/__init__.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction/binary.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction/continuous.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction/dirichlet.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/__init__.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/binary.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/categorical.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/continuous.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/dirichlet.py (100%) rename {src/pyhgf => pyhgf}/updates/prediction_error/generic.py (100%) rename {src/pyhgf => pyhgf}/utils.py (100%) delete mode 100644 src/hgf/.github/workflows/CI.yml delete mode 100644 src/hgf/.gitignore delete mode 100644 src/hgf/pyproject.toml delete mode 100644 src/hgf/src/updates/prediction_error/inputs/mod.rs delete mode 100644 src/hgf/src/updates/prediction_error/mod.rs rename src/{hgf/src => }/lib.rs (100%) rename src/{hgf/src => }/math.rs (100%) rename src/{hgf/src => }/network.rs (100%) rename src/{hgf => }/tests/exponential_family.rs (100%) rename src/{hgf/src => }/updates/mod.rs (100%) rename src/{hgf/src => }/updates/posterior/continuous.rs (100%) rename src/{hgf/src => }/updates/posterior/mod.rs (100%) rename src/{hgf/src => }/updates/prediction/mod.rs (100%) create mode 100644 src/updates/prediction_error/mod.rs rename src/{hgf/src => }/updates/prediction_error/nodes/continuous.rs (100%) rename src/{hgf/src => }/updates/prediction_error/nodes/mod.rs (100%) rename src/{hgf/src => }/utils.rs (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..8762e87df --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,182 @@ +# This file is autogenerated by maturin v1.7.4 +# To update, run +# +# maturin generate-ci github +# + +name: CI + +on: + push: + branches: + - main + - master + tags: + - '*' + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + - runner: ubuntu-latest + target: s390x + - runner: ubuntu-latest + target: ppc64le + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-linux-${{ matrix.platform.target }} + path: dist + + musllinux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: musllinux_1_2 + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-musllinux-${{ matrix.platform.target }} + path: dist + + windows: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: windows-latest + target: x64 + - runner: windows-latest + target: x86 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + architecture: ${{ matrix.platform.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-windows-${{ matrix.platform.target }} + path: dist + + macos: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: macos-12 + target: x86_64 + - runner: macos-14 + target: aarch64 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos-${{ matrix.platform.target }} + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: wheels-sdist + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} + needs: [linux, musllinux, windows, macos, sdist] + permissions: + # Use to sign the release artifacts + id-token: write + # Used to upload release artifacts + contents: write + # Used to generate artifact attestation + attestations: write + steps: + - uses: actions/download-artifact@v4 + - name: Generate artifact attestation + uses: actions/attest-build-provenance@v1 + with: + subject-path: 'wheels-*/*' + - name: Publish to PyPI + if: "startsWith(github.ref, 'refs/tags/')" + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing wheels-*/* \ No newline at end of file diff --git a/.gitignore b/.gitignore index 221b5eb7a..6cff95f36 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ coverage.xm .mypy_cache .pytest_cache dist/* -src/hgf/target* \ No newline at end of file +src/hgf/target* +target/* \ No newline at end of file diff --git a/src/hgf/Cargo.lock b/Cargo.lock similarity index 100% rename from src/hgf/Cargo.lock rename to Cargo.lock diff --git a/src/hgf/Cargo.toml b/Cargo.toml similarity index 81% rename from src/hgf/Cargo.toml rename to Cargo.toml index c59d39794..fd3efa442 100644 --- a/src/hgf/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,4 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.22.3", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.22.4", features = ["extension-module"] } \ No newline at end of file diff --git a/src/pyhgf/__init__.py b/pyhgf/__init__.py similarity index 100% rename from src/pyhgf/__init__.py rename to pyhgf/__init__.py diff --git a/src/pyhgf/data/binary_input.txt b/pyhgf/data/binary_input.txt similarity index 100% rename from src/pyhgf/data/binary_input.txt rename to pyhgf/data/binary_input.txt diff --git a/src/pyhgf/data/binary_response.txt b/pyhgf/data/binary_response.txt similarity index 100% rename from src/pyhgf/data/binary_response.txt rename to pyhgf/data/binary_response.txt diff --git a/src/pyhgf/data/usdchf.txt b/pyhgf/data/usdchf.txt similarity index 100% rename from src/pyhgf/data/usdchf.txt rename to pyhgf/data/usdchf.txt diff --git a/src/pyhgf/distribution.py b/pyhgf/distribution.py similarity index 100% rename from src/pyhgf/distribution.py rename to pyhgf/distribution.py diff --git a/src/pyhgf/math.py b/pyhgf/math.py similarity index 100% rename from src/pyhgf/math.py rename to pyhgf/math.py diff --git a/src/pyhgf/model/__init__.py b/pyhgf/model/__init__.py similarity index 100% rename from src/pyhgf/model/__init__.py rename to pyhgf/model/__init__.py diff --git a/src/pyhgf/model/hgf.py b/pyhgf/model/hgf.py similarity index 100% rename from src/pyhgf/model/hgf.py rename to pyhgf/model/hgf.py diff --git a/src/pyhgf/model/network.py b/pyhgf/model/network.py similarity index 100% rename from src/pyhgf/model/network.py rename to pyhgf/model/network.py diff --git a/src/pyhgf/plots.py b/pyhgf/plots.py similarity index 100% rename from src/pyhgf/plots.py rename to pyhgf/plots.py diff --git a/src/pyhgf/response.py b/pyhgf/response.py similarity index 100% rename from src/pyhgf/response.py rename to pyhgf/response.py diff --git a/src/pyhgf/typing.py b/pyhgf/typing.py similarity index 100% rename from src/pyhgf/typing.py rename to pyhgf/typing.py diff --git a/src/pyhgf/updates/__init__.py b/pyhgf/updates/__init__.py similarity index 100% rename from src/pyhgf/updates/__init__.py rename to pyhgf/updates/__init__.py diff --git a/src/pyhgf/updates/observation.py b/pyhgf/updates/observation.py similarity index 100% rename from src/pyhgf/updates/observation.py rename to pyhgf/updates/observation.py diff --git a/src/pyhgf/updates/posterior/__init__.py b/pyhgf/updates/posterior/__init__.py similarity index 100% rename from src/pyhgf/updates/posterior/__init__.py rename to pyhgf/updates/posterior/__init__.py diff --git a/src/pyhgf/updates/posterior/categorical.py b/pyhgf/updates/posterior/categorical.py similarity index 100% rename from src/pyhgf/updates/posterior/categorical.py rename to pyhgf/updates/posterior/categorical.py diff --git a/src/pyhgf/updates/posterior/continuous.py b/pyhgf/updates/posterior/continuous.py similarity index 100% rename from src/pyhgf/updates/posterior/continuous.py rename to pyhgf/updates/posterior/continuous.py diff --git a/src/pyhgf/updates/posterior/exponential.py b/pyhgf/updates/posterior/exponential.py similarity index 100% rename from src/pyhgf/updates/posterior/exponential.py rename to pyhgf/updates/posterior/exponential.py diff --git a/src/pyhgf/updates/prediction/__init__.py b/pyhgf/updates/prediction/__init__.py similarity index 100% rename from src/pyhgf/updates/prediction/__init__.py rename to pyhgf/updates/prediction/__init__.py diff --git a/src/pyhgf/updates/prediction/binary.py b/pyhgf/updates/prediction/binary.py similarity index 100% rename from src/pyhgf/updates/prediction/binary.py rename to pyhgf/updates/prediction/binary.py diff --git a/src/pyhgf/updates/prediction/continuous.py b/pyhgf/updates/prediction/continuous.py similarity index 100% rename from src/pyhgf/updates/prediction/continuous.py rename to pyhgf/updates/prediction/continuous.py diff --git a/src/pyhgf/updates/prediction/dirichlet.py b/pyhgf/updates/prediction/dirichlet.py similarity index 100% rename from src/pyhgf/updates/prediction/dirichlet.py rename to pyhgf/updates/prediction/dirichlet.py diff --git a/src/pyhgf/updates/prediction_error/__init__.py b/pyhgf/updates/prediction_error/__init__.py similarity index 100% rename from src/pyhgf/updates/prediction_error/__init__.py rename to pyhgf/updates/prediction_error/__init__.py diff --git a/src/pyhgf/updates/prediction_error/binary.py b/pyhgf/updates/prediction_error/binary.py similarity index 100% rename from src/pyhgf/updates/prediction_error/binary.py rename to pyhgf/updates/prediction_error/binary.py diff --git a/src/pyhgf/updates/prediction_error/categorical.py b/pyhgf/updates/prediction_error/categorical.py similarity index 100% rename from src/pyhgf/updates/prediction_error/categorical.py rename to pyhgf/updates/prediction_error/categorical.py diff --git a/src/pyhgf/updates/prediction_error/continuous.py b/pyhgf/updates/prediction_error/continuous.py similarity index 100% rename from src/pyhgf/updates/prediction_error/continuous.py rename to pyhgf/updates/prediction_error/continuous.py diff --git a/src/pyhgf/updates/prediction_error/dirichlet.py b/pyhgf/updates/prediction_error/dirichlet.py similarity index 100% rename from src/pyhgf/updates/prediction_error/dirichlet.py rename to pyhgf/updates/prediction_error/dirichlet.py diff --git a/src/pyhgf/updates/prediction_error/generic.py b/pyhgf/updates/prediction_error/generic.py similarity index 100% rename from src/pyhgf/updates/prediction_error/generic.py rename to pyhgf/updates/prediction_error/generic.py diff --git a/src/pyhgf/utils.py b/pyhgf/utils.py similarity index 100% rename from src/pyhgf/utils.py rename to pyhgf/utils.py diff --git a/pyproject.toml b/pyproject.toml index 7f622b736..377edd01f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,5 @@ [tool.poetry] name = "pyhgf" -version = "0.0.0" description = "Dynamic neural networks for predictive coding" authors = ["Nicolas Legrand "] license = "GPL-3.0" @@ -13,6 +12,12 @@ include = [ "rc/pyhgf/data/binary_input.txt", "src/pyhgf/data/binary_response.txt", ] +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] [tool.poetry-dynamic-versioning] # Enable dynamic versioning @@ -51,5 +56,9 @@ pytest = "^8.3.3" pytest-cov = "^5.0.0" [build-system] -requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] -build-backend = "poetry_dynamic_versioning.backend" +requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0", "maturin>=1.4,<2.0"] +build-backend = "maturin" + +[tool.maturin] +# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) +features = ["pyo3/extension-module"] \ No newline at end of file diff --git a/src/hgf/.github/workflows/CI.yml b/src/hgf/.github/workflows/CI.yml deleted file mode 100644 index 1bae4be43..000000000 --- a/src/hgf/.github/workflows/CI.yml +++ /dev/null @@ -1,120 +0,0 @@ -# This file is autogenerated by maturin v1.4.0 -# To update, run -# -# maturin generate-ci github -# -name: CI - -on: - push: - branches: - - main - - master - tags: - - '*' - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - linux: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: auto - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - strategy: - matrix: - target: [x64, x86] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: ${{ matrix.target }} - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - strategy: - matrix: - target: [x86_64, aarch64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build sdist - uses: PyO3/maturin-action@v1 - with: - command: sdist - args: --out dist - - name: Upload sdist - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [linux, windows, macos, sdist] - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --non-interactive --skip-existing * diff --git a/src/hgf/.gitignore b/src/hgf/.gitignore deleted file mode 100644 index c8f044299..000000000 --- a/src/hgf/.gitignore +++ /dev/null @@ -1,72 +0,0 @@ -/target - -# Byte-compiled / optimized / DLL files -__pycache__/ -.pytest_cache/ -*.py[cod] - -# C extensions -*.so - -# Distribution / packaging -.Python -.venv/ -env/ -bin/ -build/ -develop-eggs/ -dist/ -eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -include/ -man/ -venv/ -*.egg-info/ -.installed.cfg -*.egg - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt -pip-selfcheck.json - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.cache -nosetests.xml -coverage.xml - -# Translations -*.mo - -# Mr Developer -.mr.developer.cfg -.project -.pydevproject - -# Rope -.ropeproject - -# Django stuff: -*.log -*.pot - -.DS_Store - -# Sphinx documentation -docs/_build/ - -# PyCharm -.idea/ - -# VSCode -.vscode/ - -# Pyenv -.python-version diff --git a/src/hgf/pyproject.toml b/src/hgf/pyproject.toml deleted file mode 100644 index 399c9df9e..000000000 --- a/src/hgf/pyproject.toml +++ /dev/null @@ -1,16 +0,0 @@ -[build-system] -requires = ["maturin>=1.4,<2.0"] -build-backend = "maturin" - -[project] -name = "hgf" -requires-python = ">=3.8" -classifiers = [ - "Programming Language :: Rust", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", -] -dynamic = ["version"] - -[tool.maturin] -features = ["pyo3/extension-module"] diff --git a/src/hgf/src/updates/prediction_error/inputs/mod.rs b/src/hgf/src/updates/prediction_error/inputs/mod.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/hgf/src/updates/prediction_error/mod.rs b/src/hgf/src/updates/prediction_error/mod.rs deleted file mode 100644 index 264f047ac..000000000 --- a/src/hgf/src/updates/prediction_error/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod inputs; -pub mod nodes; diff --git a/src/hgf/src/lib.rs b/src/lib.rs similarity index 100% rename from src/hgf/src/lib.rs rename to src/lib.rs diff --git a/src/hgf/src/math.rs b/src/math.rs similarity index 100% rename from src/hgf/src/math.rs rename to src/math.rs diff --git a/src/hgf/src/network.rs b/src/network.rs similarity index 100% rename from src/hgf/src/network.rs rename to src/network.rs diff --git a/src/hgf/tests/exponential_family.rs b/src/tests/exponential_family.rs similarity index 100% rename from src/hgf/tests/exponential_family.rs rename to src/tests/exponential_family.rs diff --git a/src/hgf/src/updates/mod.rs b/src/updates/mod.rs similarity index 100% rename from src/hgf/src/updates/mod.rs rename to src/updates/mod.rs diff --git a/src/hgf/src/updates/posterior/continuous.rs b/src/updates/posterior/continuous.rs similarity index 100% rename from src/hgf/src/updates/posterior/continuous.rs rename to src/updates/posterior/continuous.rs diff --git a/src/hgf/src/updates/posterior/mod.rs b/src/updates/posterior/mod.rs similarity index 100% rename from src/hgf/src/updates/posterior/mod.rs rename to src/updates/posterior/mod.rs diff --git a/src/hgf/src/updates/prediction/mod.rs b/src/updates/prediction/mod.rs similarity index 100% rename from src/hgf/src/updates/prediction/mod.rs rename to src/updates/prediction/mod.rs diff --git a/src/updates/prediction_error/mod.rs b/src/updates/prediction_error/mod.rs new file mode 100644 index 000000000..69e3fea51 --- /dev/null +++ b/src/updates/prediction_error/mod.rs @@ -0,0 +1 @@ +pub mod nodes; diff --git a/src/hgf/src/updates/prediction_error/nodes/continuous.rs b/src/updates/prediction_error/nodes/continuous.rs similarity index 100% rename from src/hgf/src/updates/prediction_error/nodes/continuous.rs rename to src/updates/prediction_error/nodes/continuous.rs diff --git a/src/hgf/src/updates/prediction_error/nodes/mod.rs b/src/updates/prediction_error/nodes/mod.rs similarity index 100% rename from src/hgf/src/updates/prediction_error/nodes/mod.rs rename to src/updates/prediction_error/nodes/mod.rs diff --git a/src/hgf/src/utils.rs b/src/utils.rs similarity index 100% rename from src/hgf/src/utils.rs rename to src/utils.rs From d35ba139c9f09597c1a75946c3117f91490fde4d Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 15 Oct 2024 12:42:48 +0200 Subject: [PATCH 09/20] ci --- .github/workflows/test-rs.yml | 23 +++++++++++++++++++++++ src/network.rs | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test-rs.yml diff --git a/.github/workflows/test-rs.yml b/.github/workflows/test-rs.yml new file mode 100644 index 000000000..823906d80 --- /dev/null +++ b/.github/workflows/test-rs.yml @@ -0,0 +1,23 @@ +name: Rust Crate CI + +on: + release: + types: [published] + pull_request: + types: [opened, synchronize, reopened] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + + - name: Run tests + run: cargo test \ No newline at end of file diff --git a/src/network.rs b/src/network.rs index 3378fe54d..3ebcd4915 100644 --- a/src/network.rs +++ b/src/network.rs @@ -143,7 +143,7 @@ impl Network { // Create a module to expose the class to Python #[pymodule] -fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> { +fn rshgf(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) } From 069337911c3911066491aa781dfefd19ab02f097 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 15 Oct 2024 12:46:07 +0200 Subject: [PATCH 10/20] fix pyproject bug --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 377edd01f..eae00e95f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = ["version"] +version = "0.0.0" [tool.poetry-dynamic-versioning] # Enable dynamic versioning From d2e3a55798ccfc428f27959f6905e6d47370cb36 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 21 Oct 2024 21:04:51 +0200 Subject: [PATCH 11/20] remove input nodes and add volatility couplings --- Cargo.lock | 34 +++++----- Cargo.toml | 6 +- src/network.rs | 95 +++++++++++++++++----------- src/updates/posterior/continuous.rs | 9 --- src/updates/posterior/exponential.rs | 9 +++ src/updates/posterior/mod.rs | 3 +- src/utils.rs | 22 +++++-- 7 files changed, 103 insertions(+), 75 deletions(-) create mode 100644 src/updates/posterior/exponential.rs diff --git a/Cargo.lock b/Cargo.lock index 0c069cc90..1d4b58a05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,13 +20,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hgf" -version = "0.1.0" -dependencies = [ - "pyo3", -] - [[package]] name = "indoc" version = "2.0.5" @@ -71,9 +64,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00e89ce2565d6044ca31a3eb79a334c3a79a841120a98f64eea9f579564cb691" +checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" dependencies = [ "cfg-if", "indoc", @@ -89,9 +82,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8afbaf3abd7325e08f35ffb8deb5892046fcb2608b703db6a583a5ba4cea01e" +checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" dependencies = [ "once_cell", "target-lexicon", @@ -99,9 +92,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec15a5ba277339d04763f4c23d85987a5b08cbb494860be141e6a10a8eb88022" +checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" dependencies = [ "libc", "pyo3-build-config", @@ -109,9 +102,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e0f01b5364bcfbb686a52fc4181d412b708a68ed20c330db9fc8d2c2bf5a43" +checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -121,9 +114,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a09b550200e1e5ed9176976d0060cbc2ea82dc8515da07885e7b8153a85caacb" +checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" dependencies = [ "heck", "proc-macro2", @@ -141,6 +134,13 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rshgf" +version = "0.1.0" +dependencies = [ + "pyo3", +] + [[package]] name = "syn" version = "2.0.79" diff --git a/Cargo.toml b/Cargo.toml index fd3efa442..9a30c4362 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] -name = "hgf" +name = "rshgf" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "hgf" +name = "rshgf" crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.22.4", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.22.5", features = ["extension-module"] } \ No newline at end of file diff --git a/src/network.rs b/src/network.rs index 3ebcd4915..2822d5e9d 100644 --- a/src/network.rs +++ b/src/network.rs @@ -6,23 +6,31 @@ use pyo3::prelude::*; pub struct AdjacencyLists{ pub value_parents: Option, pub value_children: Option, + pub volatility_parents: Option, + pub volatility_children: Option, } #[derive(Debug, Clone)] -pub struct GenericInputNode{ - pub observation: f64, - pub time_step: f64, +pub struct ContinuousStateNode{ + pub mean: f64, + pub expected_mean: f64, + pub precision: f64, + pub expected_precision: f64, + pub tonic_volatility: f64, + pub tonic_drift: f64, + pub autoconnection_strength: f64, } #[derive(Debug, Clone)] -pub struct ExponentialNode { - pub observation: f64, +pub struct ExponentialFamiliyStateNode { + pub mean: f64, + pub expected_mean: f64, pub nus: f64, pub xis: [f64; 2], } #[derive(Debug, Clone)] pub enum Node { - Generic(GenericInputNode), - Exponential(ExponentialNode), + Continuous(ContinuousStateNode), + Exponential(ExponentialFamiliyStateNode), } #[derive(Debug)] @@ -47,29 +55,46 @@ impl Network { } // Add a node to the graph - #[pyo3(signature = (kind, value_parents=None, value_childrens=None))] - pub fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { + #[pyo3(signature = (kind, value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] + pub fn add_node(&mut self, kind: String, value_parents: Option, value_children: Option, volatility_children: Option, volatility_parents: Option) { // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); let edges = AdjacencyLists{ + value_parents: value_children, value_children: value_parents, - value_parents: value_childrens, + volatility_parents: volatility_parents, + volatility_children: volatility_children, }; // add edges and attributes - if kind == "generic-input" { - let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; - let node = Node::Generic(generic_input); + if kind == "continuous-state" { + let continuous_state = ContinuousStateNode{ + mean: 0.0, expected_mean: 0.0, precision: 1.0, expected_precision: 1.0, + tonic_drift: 0.0, tonic_volatility: -4.0, autoconnection_strength: 1.0 + }; + let node = Node::Continuous(continuous_state); self.nodes.insert(node_id, node); self.edges.push(edges); - self.inputs.push(node_id); - } else if kind == "exponential-node" { - let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; + + // if this node has no children, this is an input node + if (value_children == None) & (volatility_children == None) { + self.inputs.push(node_id); + } + } else if kind == "exponential-state" { + let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{ + mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0] + }; let node = Node::Exponential(exponential_node); self.nodes.insert(node_id, node); self.edges.push(edges); + + // if this node has no children, this is an input node + if (value_children == None) & (volatility_children == None) { + self.inputs.push(node_id); + } + } else { println!("Invalid type of node provided ({}).", kind); } @@ -78,13 +103,13 @@ impl Network { pub fn prediction_error(&mut self, node_idx: usize) { // get the observation value - let observation; + let mean; match self.nodes[&node_idx] { - Node::Generic(ref node) => { - observation = node.observation; + Node::Continuous(ref node) => { + mean = node.mean; } Node::Exponential(ref node) => { - observation = node.observation; + mean = node.mean; } } @@ -92,11 +117,11 @@ impl Network { match value_parent_idx { Some(idx) => { match self.nodes.get_mut(idx) { - Some(Node::Generic(ref mut parent)) => { - parent.observation = observation + Some(Node::Continuous(ref mut parent)) => { + parent.mean = mean } Some(Node::Exponential(ref mut parent)) => { - parent.observation = observation + parent.mean = mean } None => println!("No prediction error for this type of node."), } @@ -108,11 +133,11 @@ impl Network { pub fn posterior_update(&mut self, node_idx: usize, observation: f64) { match self.nodes.get_mut(&node_idx) { - Some(Node::Generic(ref mut node)) => { - node.observation = observation + Some(Node::Continuous(ref mut node)) => { + node.mean = observation } Some(Node::Exponential(ref mut node)) => { - posterior::continuous::posterior_update_exponential(node) + posterior::exponential::posterior_update_exponential(node) } None => println!("No posterior update for this type of node.") } @@ -160,26 +185,20 @@ mod tests { // initialize network let mut network = Network::new(); - // create a network + // create a network with two exponential family state nodes network.add_node( - String::from("generic-input"), + String::from("exponential-state"), None, None, - ); - network.add_node( - String::from("generic-input"), - None, None, + None ); network.add_node( - String::from("exponential-node"), + String::from("exponential-state"), + None, None, - Some(0), - ); - network.add_node( - String::from("exponential-node"), None, - Some(1), + None ); // println!("Graph before belief propagation: {:?}", network); diff --git a/src/updates/posterior/continuous.rs b/src/updates/posterior/continuous.rs index 7fb5b78e6..e69de29bb 100644 --- a/src/updates/posterior/continuous.rs +++ b/src/updates/posterior/continuous.rs @@ -1,9 +0,0 @@ -use crate::network::ExponentialNode; -use crate::math::sufficient_statistics; - -pub fn posterior_update_exponential(node: &mut ExponentialNode) { - let suf_stats = sufficient_statistics(&node.observation); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); - } -} \ No newline at end of file diff --git a/src/updates/posterior/exponential.rs b/src/updates/posterior/exponential.rs new file mode 100644 index 000000000..97b07e8cf --- /dev/null +++ b/src/updates/posterior/exponential.rs @@ -0,0 +1,9 @@ +use crate::network::ExponentialFamiliyStateNode; +use crate::math::sufficient_statistics; + +pub fn posterior_update_exponential(node: &mut ExponentialFamiliyStateNode) { + let suf_stats = sufficient_statistics(&node.mean); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } +} \ No newline at end of file diff --git a/src/updates/posterior/mod.rs b/src/updates/posterior/mod.rs index 6817d49e7..d92816181 100644 --- a/src/updates/posterior/mod.rs +++ b/src/updates/posterior/mod.rs @@ -1 +1,2 @@ -pub mod continuous; \ No newline at end of file +pub mod continuous; +pub mod exponential; \ No newline at end of file diff --git a/src/utils.rs b/src/utils.rs index 6d4505930..87acda9e6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -60,24 +60,32 @@ mod tests { // create a network network.add_node( - String::from("generic-input"), + String::from("continuous-state"), Some(1), None, + None, + Some(2), ); network.add_node( - String::from("exponential-node"), - Some(2), + String::from("continuous-state"), + None, Some(0), + None, + None, ); network.add_node( - String::from("exponential-node"), - Some(3), - Some(1), + String::from("continuous-state"), + None, + None, + Some(0), + None, ); network.add_node( String::from("exponential-node"), None, - Some(2), + None, + None, + None, ); println!("Network: {:?}", network); From 198245fa6d7f6e2e44671857e07637728b46cd9a Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 22 Oct 2024 20:53:18 +0200 Subject: [PATCH 12/20] returns attributes --- Cargo.lock | 150 ++++++++++++++++++++++++++++--- Cargo.toml | 2 +- src/network.rs | 66 ++++++++++++-- src/utils.rs | 16 ++-- tests/test_exponential_family.py | 25 ++++++ 5 files changed, 229 insertions(+), 30 deletions(-) create mode 100644 tests/test_exponential_family.py diff --git a/Cargo.lock b/Cargo.lock index 1d4b58a05..e4c2a367f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "cfg-if" version = "1.0.0" @@ -16,9 +22,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "heck" -version = "0.5.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "indoc" @@ -32,6 +38,16 @@ version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -47,6 +63,29 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "portable-atomic" version = "1.9.0" @@ -64,15 +103,15 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "once_cell", + "parking_lot", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -82,9 +121,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -92,9 +131,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -102,9 +141,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -114,9 +153,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck", "proc-macro2", @@ -134,6 +173,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + [[package]] name = "rshgf" version = "0.1.0" @@ -141,6 +189,18 @@ dependencies = [ "pyo3", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "syn" version = "2.0.79" @@ -169,3 +229,67 @@ name = "unindent" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml index 9a30c4362..0673e99a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,4 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.22.5", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.21.2", features = ["extension-module"] } \ No newline at end of file diff --git a/src/network.rs b/src/network.rs index 2822d5e9d..61408539e 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,12 +1,17 @@ use std::collections::HashMap; use crate::updates::posterior; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] +#[pyclass] pub struct AdjacencyLists{ + #[pyo3(get, set)] pub value_parents: Option, + #[pyo3(get, set)] pub value_children: Option, + #[pyo3(get, set)] pub volatility_parents: Option, + #[pyo3(get, set)] pub volatility_children: Option, } #[derive(Debug, Clone)] @@ -54,9 +59,18 @@ impl Network { } } - // Add a node to the graph - #[pyo3(signature = (kind, value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] - pub fn add_node(&mut self, kind: String, value_parents: Option, value_children: Option, volatility_children: Option, volatility_parents: Option) { + /// Add nodes to the network. + /// + /// # Arguments + /// * `kind` - The type of node that should be added. + /// * `value_parents` - The indexes of the node's value parents. + /// * `value_children` - The indexes of the node's value children. + /// * `volatility_children` - The indexes of the node's volatility children. + /// * `volatility_parents` - The indexes of the node's volatility parents. + #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] + pub fn add_nodes(&mut self, kind: &str, value_parents: Option, + value_children: Option, volatility_children: Option, + volatility_parents: Option) { // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); @@ -143,6 +157,11 @@ impl Network { } } + /// One time step belief propagation. + /// + /// # Arguments + /// * `observations` - A vector of values, each value is one new observation associated + /// with one node. pub fn belief_propagation(&mut self, observations: Vec) { // 1. prediction propagation @@ -158,13 +177,44 @@ impl Network { } + /// Add a sequence of observations. + /// + /// # Arguments + /// * `input_data` - A vector of vectors. Each vector is a time series of observations + /// associated with one node. pub fn input_data(&mut self, input_data: Vec>) { for observation in input_data { self.belief_propagation(observation); } } + + #[getter] + pub fn get_inputs<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + let py_list = PyList::new(py, &self.inputs); // Create a PyList from Vec + Ok(py_list) + } + + #[getter] + pub fn get_edges<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + // Create a new Python list + let py_list = PyList::empty(py); + + // Convert each struct in the Vec to a Python object and add to PyList + for s in &self.edges { + // Create a new Python dictionary for each MyStruct + let py_dict = PyDict::new(py); + py_dict.set_item("value_parents", s.value_parents)?; + py_dict.set_item("value_children", s.value_children)?; + py_dict.set_item("volatility_parents", s.volatility_parents)?; + py_dict.set_item("volatility_children", s.volatility_children)?; + + // Add the dictionary to the list + py_list.append(py_dict)?; + } + Ok(py_list) } +} // Create a module to expose the class to Python #[pymodule] @@ -186,15 +236,15 @@ mod tests { let mut network = Network::new(); // create a network with two exponential family state nodes - network.add_node( - String::from("exponential-state"), + network.add_nodes( + "exponential-state", None, None, None, None ); - network.add_node( - String::from("exponential-state"), + network.add_nodes( + "exponential-state", None, None, None, diff --git a/src/utils.rs b/src/utils.rs index 87acda9e6..ab5da35a3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -59,29 +59,29 @@ mod tests { let mut network = Network::new(); // create a network - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", Some(1), None, None, Some(2), ); - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", None, Some(0), None, None, ); - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", None, None, Some(0), None, ); - network.add_node( - String::from("exponential-node"), + network.add_nodes( + "exponential-node", None, None, None, diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py new file mode 100644 index 000000000..837b685f8 --- /dev/null +++ b/tests/test_exponential_family.py @@ -0,0 +1,25 @@ +# Author: Nicolas Legrand + +from rshgf import Network as RsNetwork + +from pyhgf import load_data +from pyhgf.model import Network as PyNetwork + + +def test_1d_gaussain(): + + timeseries = load_data("continuous") + + # Rust ----------------------------------------------------------------------------- + rs_network = RsNetwork() + rs_network.add_nodes(kind="exponential-state") + rs_network.add_nodes(kind="exponential-state") + rs_network.inputs + rs_network.edges + + rs_network.input_data([[0, 0], [1, 1]]) + rs_network.input_data([timeseries, timeseries]) + + # Python --------------------------------------------------------------------------- + py_network = PyNetwork().add_nodes(kind="continuous-state") + py_network.attributes From db5910ca62e5032d01baa337605c90e79aff20ce Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Thu, 24 Oct 2024 12:56:14 +0200 Subject: [PATCH 13/20] add a proper get_update_sequence function --- src/lib.rs | 2 +- src/{network.rs => model.rs} | 125 +++++------ src/updates/mod.rs | 3 +- src/updates/observations.rs | 21 ++ src/updates/posterior/continuous.rs | 13 ++ src/updates/posterior/exponential.rs | 24 ++- src/updates/prediction/continuous.rs | 13 ++ src/updates/prediction/mod.rs | 1 + src/updates/prediction_error/continuous.rs | 13 ++ src/updates/prediction_error/mod.rs | 2 +- .../prediction_error/nodes/continuous.rs | 0 src/updates/prediction_error/nodes/mod.rs | 1 - src/utils.rs | 201 +++++++++++++++--- 13 files changed, 308 insertions(+), 111 deletions(-) rename src/{network.rs => model.rs} (66%) create mode 100644 src/updates/observations.rs create mode 100644 src/updates/prediction/continuous.rs create mode 100644 src/updates/prediction_error/continuous.rs delete mode 100644 src/updates/prediction_error/nodes/continuous.rs delete mode 100644 src/updates/prediction_error/nodes/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 504294410..55ad98cd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -pub mod network; +pub mod model; pub mod utils; pub mod math; pub mod updates; \ No newline at end of file diff --git a/src/network.rs b/src/model.rs similarity index 66% rename from src/network.rs rename to src/model.rs index 61408539e..2b5845f80 100644 --- a/src/network.rs +++ b/src/model.rs @@ -1,18 +1,19 @@ use std::collections::HashMap; -use crate::updates::posterior; +use crate::updates::observations::observation_update; +use crate::utils::get_update_sequence; use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] #[pyclass] pub struct AdjacencyLists{ #[pyo3(get, set)] - pub value_parents: Option, + pub value_parents: Option>, #[pyo3(get, set)] - pub value_children: Option, + pub value_children: Option>, #[pyo3(get, set)] - pub volatility_parents: Option, + pub volatility_parents: Option>, #[pyo3(get, set)] - pub volatility_children: Option, + pub volatility_children: Option>, } #[derive(Debug, Clone)] pub struct ContinuousStateNode{ @@ -38,12 +39,22 @@ pub enum Node { Exponential(ExponentialFamiliyStateNode), } +// Create a default signature for update functions +pub type FnType = fn(&mut Network, usize); + +#[derive(Debug)] +pub struct UpdateSequence { + pub predictions: Vec<(usize, FnType)>, + pub updates: Vec<(usize, FnType)>, +} + #[derive(Debug)] #[pyclass] pub struct Network{ pub nodes: HashMap, pub edges: Vec, pub inputs: Vec, + pub update_sequence: UpdateSequence, } #[pymethods] @@ -56,6 +67,7 @@ impl Network { nodes: HashMap::new(), edges: Vec::new(), inputs: Vec::new(), + update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()} } } @@ -68,12 +80,17 @@ impl Network { /// * `volatility_children` - The indexes of the node's volatility children. /// * `volatility_parents` - The indexes of the node's volatility parents. #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] - pub fn add_nodes(&mut self, kind: &str, value_parents: Option, - value_children: Option, volatility_children: Option, - volatility_parents: Option) { - + pub fn add_nodes(&mut self, kind: &str, value_parents: Option>, + value_children: Option>, volatility_children: Option>, + volatility_parents: Option>) { + // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); + + // if this node has no children, this is an input node + if (value_children == None) & (volatility_children == None) { + self.inputs.push(node_id); + } let edges = AdjacencyLists{ value_parents: value_children, @@ -92,10 +109,6 @@ impl Network { self.nodes.insert(node_id, node); self.edges.push(edges); - // if this node has no children, this is an input node - if (value_children == None) & (volatility_children == None) { - self.inputs.push(node_id); - } } else if kind == "exponential-state" { let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{ mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0] @@ -103,78 +116,41 @@ impl Network { let node = Node::Exponential(exponential_node); self.nodes.insert(node_id, node); self.edges.push(edges); - - // if this node has no children, this is an input node - if (value_children == None) & (volatility_children == None) { - self.inputs.push(node_id); - } } else { println!("Invalid type of node provided ({}).", kind); } } - pub fn prediction_error(&mut self, node_idx: usize) { - - // get the observation value - let mean; - match self.nodes[&node_idx] { - Node::Continuous(ref node) => { - mean = node.mean; - } - Node::Exponential(ref node) => { - mean = node.mean; - } - } - - let value_parent_idx = &self.edges[node_idx].value_parents; - match value_parent_idx { - Some(idx) => { - match self.nodes.get_mut(idx) { - Some(Node::Continuous(ref mut parent)) => { - parent.mean = mean - } - Some(Node::Exponential(ref mut parent)) => { - parent.mean = mean - } - None => println!("No prediction error for this type of node."), - } - } - None => println!("No value parent"), - } - } - - pub fn posterior_update(&mut self, node_idx: usize, observation: f64) { - - match self.nodes.get_mut(&node_idx) { - Some(Node::Continuous(ref mut node)) => { - node.mean = observation - } - Some(Node::Exponential(ref mut node)) => { - posterior::exponential::posterior_update_exponential(node) - } - None => println!("No posterior update for this type of node.") - } + pub fn get_update_sequence(&mut self) { + self.update_sequence = get_update_sequence(self); } - /// One time step belief propagation. + /// Single time slice belief propagation. /// /// # Arguments /// * `observations` - A vector of values, each value is one new observation associated /// with one node. - pub fn belief_propagation(&mut self, observations: Vec) { + pub fn belief_propagation(&mut self, observations_set: Vec) { - // 1. prediction propagation + let predictions = self.update_sequence.predictions.clone(); + let updates = self.update_sequence.updates.clone(); - for i in 0..observations.len() { - - let input_node_idx = self.inputs[i]; - // 2. inject the observations into the input nodes - self.posterior_update(input_node_idx, observations[i]); - // 3. posterior update - prediction errors propagation - self.prediction_error(input_node_idx); + // 1. prediction steps + for (idx, step) in predictions.iter() { + step(self, *idx); + } + + // 2. observation steps + for (i, observations) in observations_set.iter().enumerate() { + let idx = self.inputs[i]; + observation_update(self, idx, *observations); + } + + // 3. update steps + for (idx, step) in updates.iter() { + step(self, *idx); } - } /// Add a sequence of observations. @@ -203,17 +179,16 @@ impl Network { for s in &self.edges { // Create a new Python dictionary for each MyStruct let py_dict = PyDict::new(py); - py_dict.set_item("value_parents", s.value_parents)?; - py_dict.set_item("value_children", s.value_children)?; - py_dict.set_item("volatility_parents", s.volatility_parents)?; - py_dict.set_item("volatility_children", s.volatility_children)?; + py_dict.set_item("value_parents", &s.value_parents)?; + py_dict.set_item("value_children", &s.value_children)?; + py_dict.set_item("volatility_parents", &s.volatility_parents)?; + py_dict.set_item("volatility_children", &s.volatility_children)?; // Add the dictionary to the list py_list.append(py_dict)?; } Ok(py_list) } - } // Create a module to expose the class to Python diff --git a/src/updates/mod.rs b/src/updates/mod.rs index dd8404066..28479cc79 100644 --- a/src/updates/mod.rs +++ b/src/updates/mod.rs @@ -1,3 +1,4 @@ pub mod posterior; pub mod prediction; -pub mod prediction_error; \ No newline at end of file +pub mod prediction_error; +pub mod observations; \ No newline at end of file diff --git a/src/updates/observations.rs b/src/updates/observations.rs new file mode 100644 index 000000000..1d64ec7a9 --- /dev/null +++ b/src/updates/observations.rs @@ -0,0 +1,21 @@ +use crate::model::{Network, Node}; + + +/// Inject new observations into an input node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The input node index. +/// * `observations` - The new observations. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn observation_update(network: &mut Network, node_idx: usize, observations: f64) { + + match network.nodes.get_mut(&node_idx) { + Some(Node::Exponential(ref mut node)) => { + node.mean = observations; + } + _ => (), + } +} \ No newline at end of file diff --git a/src/updates/posterior/continuous.rs b/src/updates/posterior/continuous.rs index e69de29bb..b1941915d 100644 --- a/src/updates/posterior/continuous.rs +++ b/src/updates/posterior/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Posterior update from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn posterior_update_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/posterior/exponential.rs b/src/updates/posterior/exponential.rs index 97b07e8cf..20dcf2d15 100644 --- a/src/updates/posterior/exponential.rs +++ b/src/updates/posterior/exponential.rs @@ -1,9 +1,23 @@ -use crate::network::ExponentialFamiliyStateNode; +use crate::model::{Network, Node}; use crate::math::sufficient_statistics; -pub fn posterior_update_exponential(node: &mut ExponentialFamiliyStateNode) { - let suf_stats = sufficient_statistics(&node.mean); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); +/// Updating an exponential family state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) { + + match network.nodes.get_mut(&node_idx) { + Some(Node::Exponential(ref mut node)) => { + let suf_stats = sufficient_statistics(&node.mean); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } + } + _ => (), } } \ No newline at end of file diff --git a/src/updates/prediction/continuous.rs b/src/updates/prediction/continuous.rs new file mode 100644 index 000000000..f40d2c35e --- /dev/null +++ b/src/updates/prediction/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Prediction from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn prediction_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/prediction/mod.rs b/src/updates/prediction/mod.rs index e69de29bb..6817d49e7 100644 --- a/src/updates/prediction/mod.rs +++ b/src/updates/prediction/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/updates/prediction_error/continuous.rs b/src/updates/prediction_error/continuous.rs new file mode 100644 index 000000000..96e8990ca --- /dev/null +++ b/src/updates/prediction_error/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Prediction error from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn prediction_error_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/prediction_error/mod.rs b/src/updates/prediction_error/mod.rs index 69e3fea51..c53f4e7ae 100644 --- a/src/updates/prediction_error/mod.rs +++ b/src/updates/prediction_error/mod.rs @@ -1 +1 @@ -pub mod nodes; +pub mod continuous; diff --git a/src/updates/prediction_error/nodes/continuous.rs b/src/updates/prediction_error/nodes/continuous.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/updates/prediction_error/nodes/mod.rs b/src/updates/prediction_error/nodes/mod.rs deleted file mode 100644 index 6817d49e7..000000000 --- a/src/updates/prediction_error/nodes/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod continuous; \ No newline at end of file diff --git a/src/utils.rs b/src/utils.rs index ab5da35a3..178d1ba6a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,51 +1,198 @@ -use crate::network::Network; +use crate::{model::{FnType, Network, Node, UpdateSequence}, updates::{posterior::{continuous::posterior_update_continuous_state_node, exponential::posterior_update_exponential_state_node}, prediction::continuous::prediction_continuous_state_node, prediction_error::continuous::prediction_error_continuous_state_node}}; +pub fn get_update_sequence(network: &Network) -> UpdateSequence { + let predictions = get_predictions_sequence(network); + let updates = get_updates_sequence(network); -pub fn get_update_order(network: Network) -> Vec { - - let mut update_list = Vec::new(); + // return the update sequence + let update_sequence = UpdateSequence {predictions: predictions, updates: updates}; + update_sequence +} - // list all nodes availables in the network - let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - // remove the input nodes - nodes_idxs.retain(|x| !network.inputs.contains(x)); +pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { - let mut remaining = nodes_idxs.len(); + let mut predictions : Vec<(usize, FnType)> = Vec::new(); + + // 1. get prediction sequence ------------------------------------------------------ + + // list all nodes availables in the network + let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - while remaining > 0 { + // iterate over all nodes and add the prediction step if all criteria are met + let mut n_remaining = nodes_idxs.len(); + while n_remaining > 0 { + + // were we able to add an update step in the list on that iteration? let mut has_update = false; - // loop over all available + + // loop over all the remaining nodes for i in 0..nodes_idxs.len() { let idx = nodes_idxs[i]; - let value_children_idxs = network.edges[idx].value_children; - - // check if there is any element in value children - // that is found in the to-be-updated list of nodes - let contains_common = value_children_idxs.iter().any(|&item| nodes_idxs.contains(&item)); - + + // list the node's parents + let value_parents_idxs = &network.edges[idx].value_parents; + let volatility_parents_idxs = &network.edges[idx].volatility_parents; + + let parents_idxs = match (value_parents_idxs, volatility_parents_idxs) { + // If both are Some, merge the vectors + (Some(ref vec1), Some(ref vec2)) => { + // Create a new vector by merging the two + let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(merged_vec) // Return the merged vector wrapped in Some + } + // If one is Some and the other is None, return the one that's Some + (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), + // If both are None, return None + (None, None) => None, + }; + + // check if there is any parent node that is still found in the to-be-updated list + let contains_common = match parents_idxs { + Some(vec) => vec.iter().any(|item| nodes_idxs.contains(item)), + None => false + }; + + // if all parents have processed their prediction, this one can be added if !(contains_common) { // add the node in the update list - update_list.push(idx); + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + predictions.push((idx, prediction_continuous_state_node)); + } + Some(Node::Exponential(_)) => (), + None => () + + } - // remove the parent from the availables nodes list + // remove the node from the to-be-updated list nodes_idxs.retain(|&x| x != idx); - - remaining -= 1; + n_remaining -= 1; has_update = true; break; } } + // 2. get update sequence ------------------------------------------------------ + if !(has_update) { break; } } - update_list + predictions + } +pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { + + let mut updates : Vec<(usize, FnType)> = Vec::new(); + + // 1. get update sequence ---------------------------------------------------------- + + // list all nodes availables in the network + let mut pe_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut po_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + + // remove the input nodes from the to-be-visited nodes for posterior updates + po_nodes_idxs.retain(|x| !network.inputs.contains(x)); + + // iterate over all nodes and add the prediction step if all criteria are met + let mut n_remaining = 2 * pe_nodes_idxs.len(); // posterior updates + prediction errors + + while n_remaining > 0 { + + // were we able to add an update step in the list on that iteration? + let mut has_update = false; + + // loop over all the remaining nodes for prediction errors --------------------- + for i in 0..pe_nodes_idxs.len() { + + let idx = pe_nodes_idxs[i]; + + // to send a prediction error, this node should have been updated first + if !(po_nodes_idxs.contains(&idx)) { + + // add the node in the update list + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + updates.push((idx, prediction_error_continuous_state_node)); + } + Some(Node::Exponential(_)) => (), + None => () + + } + + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + } + + // loop over all the remaining nodes for posterior updates --------------------- + for i in 0..po_nodes_idxs.len() { + + let idx = po_nodes_idxs[i]; + + // to start a posterior update, all children should have sent prediction errors + // 1. get a list of all children + let value_children_idxs = &network.edges[idx].value_children; + let volatility_children_idxs = &network.edges[idx].volatility_children; + + let children_idxs = match (value_children_idxs, volatility_children_idxs) { + // If both are Some, merge the vectors + (Some(ref vec1), Some(ref vec2)) => { + // Create a new vector by merging the two + let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(merged_vec) // Return the merged vector wrapped in Some + } + // If one is Some and the other is None, return the one that's Some + (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), + // If both are None, return None + (None, None) => None, + }; + + // 2. check if any of the children is still on the to-be-visited list for prediction errors + // check if there is any parent node that is still found in the to-be-updated list + let missing_pe = match children_idxs { + Some(vec) => vec.iter().any(|item| pe_nodes_idxs.contains(item)), + None => false + }; + + // 3. if false, add the posterior update to the list + if !(missing_pe) { + + // add the node in the update list + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + updates.push((idx, posterior_update_continuous_state_node)); + } + Some(Node::Exponential(_)) => { + updates.push((idx, posterior_update_exponential_state_node)); + } + None => () + + } + + // remove the node from the to-be-updated list + po_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + } + // 2. get update sequence ---------------------------------------------------------- + + if !(has_update) { + break; + } + } + updates + +} // Tests module for unit tests #[cfg(test)] // Only compile and include this module when running tests @@ -61,15 +208,15 @@ mod tests { // create a network network.add_nodes( "continuous-state", - Some(1), + Some(vec![1]), None, None, - Some(2), + Some(vec![2]), ); network.add_nodes( "continuous-state", None, - Some(0), + Some(vec![0]), None, None, ); @@ -77,7 +224,7 @@ mod tests { "continuous-state", None, None, - Some(0), + Some(vec![0]), None, ); network.add_nodes( @@ -89,7 +236,7 @@ mod tests { ); println!("Network: {:?}", network); - println!("Update order: {:?}", get_update_order(network)); + println!("Update order: {:?}", get_update_sequence(&network)); } } From c76d10818ce2122dad12b354980d7d315ef40213 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Fri, 25 Oct 2024 12:36:28 +0200 Subject: [PATCH 14/20] set_sequence working properly --- src/model.rs | 52 +++++++--- src/updates/mod.rs | 2 +- src/updates/posterior/mod.rs | 3 +- .../exponential.rs | 2 +- src/updates/prediction_error/mod.rs | 1 + src/utils/function_pointer.rs | 18 ++++ src/utils/mod.rs | 2 + src/{utils.rs => utils/set_sequence.rs} | 98 ++++++++++++------- 8 files changed, 128 insertions(+), 50 deletions(-) rename src/updates/{posterior => prediction_error}/exponential.rs (91%) create mode 100644 src/utils/function_pointer.rs create mode 100644 src/utils/mod.rs rename src/{utils.rs => utils/set_sequence.rs} (65%) diff --git a/src/model.rs b/src/model.rs index 2b5845f80..5f23e09e9 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; -use crate::updates::observations::observation_update; -use crate::utils::get_update_sequence; +use crate::{updates::observations::observation_update, utils::function_pointer::FnType}; +use crate::utils::set_sequence::set_update_sequence; +use crate::utils::function_pointer::get_func_map; +use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] @@ -39,9 +41,6 @@ pub enum Node { Exponential(ExponentialFamiliyStateNode), } -// Create a default signature for update functions -pub type FnType = fn(&mut Network, usize); - #[derive(Debug)] pub struct UpdateSequence { pub predictions: Vec<(usize, FnType)>, @@ -79,10 +78,10 @@ impl Network { /// * `value_children` - The indexes of the node's value children. /// * `volatility_children` - The indexes of the node's volatility children. /// * `volatility_parents` - The indexes of the node's volatility parents. - #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] + #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))] pub fn add_nodes(&mut self, kind: &str, value_parents: Option>, - value_children: Option>, volatility_children: Option>, - volatility_parents: Option>) { + value_children: Option>, + volatility_parents: Option>, volatility_children: Option>, ) { // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); @@ -93,8 +92,8 @@ impl Network { } let edges = AdjacencyLists{ - value_parents: value_children, - value_children: value_parents, + value_parents: value_parents, + value_children: value_children, volatility_parents: volatility_parents, volatility_children: volatility_children, }; @@ -122,8 +121,8 @@ impl Network { } } - pub fn get_update_sequence(&mut self) { - self.update_sequence = get_update_sequence(self); + pub fn set_update_sequence(&mut self) { + self.update_sequence = set_update_sequence(self); } /// Single time slice belief propagation. @@ -189,6 +188,35 @@ impl Network { } Ok(py_list) } + + #[getter] + pub fn get_update_sequence<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + + let func_map = get_func_map(); + let py_list = PyList::empty(py); + // Iterate over the Rust vector and convert each tuple + for &(num, func) in self.update_sequence.predictions.iter() { + // Retrieve the function name from the map + let func_name = func_map.get(&func).unwrap_or(&"unknown"); + + // Convert the Rust tuple to a Python tuple with the function name as a string + let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]); + + // Append the Python tuple to the Python list + py_list.append(py_tuple)?; + } + for &(num, func) in self.update_sequence.updates.iter() { + // Retrieve the function name from the map + let func_name = func_map.get(&func).unwrap_or(&"unknown"); + + // Convert the Rust tuple to a Python tuple with the function name as a string + let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]); + + // Append the Python tuple to the Python list + py_list.append(py_tuple)?; + } + Ok(py_list) + } } // Create a module to expose the class to Python diff --git a/src/updates/mod.rs b/src/updates/mod.rs index 28479cc79..ad43beacb 100644 --- a/src/updates/mod.rs +++ b/src/updates/mod.rs @@ -1,4 +1,4 @@ pub mod posterior; pub mod prediction; pub mod prediction_error; -pub mod observations; \ No newline at end of file +pub mod observations; diff --git a/src/updates/posterior/mod.rs b/src/updates/posterior/mod.rs index d92816181..6817d49e7 100644 --- a/src/updates/posterior/mod.rs +++ b/src/updates/posterior/mod.rs @@ -1,2 +1 @@ -pub mod continuous; -pub mod exponential; \ No newline at end of file +pub mod continuous; \ No newline at end of file diff --git a/src/updates/posterior/exponential.rs b/src/updates/prediction_error/exponential.rs similarity index 91% rename from src/updates/posterior/exponential.rs rename to src/updates/prediction_error/exponential.rs index 20dcf2d15..35695b9c3 100644 --- a/src/updates/posterior/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -9,7 +9,7 @@ use crate::math::sufficient_statistics; /// /// # Returns /// * `network` - The network after message passing. -pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) { +pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { match network.nodes.get_mut(&node_idx) { Some(Node::Exponential(ref mut node)) => { diff --git a/src/updates/prediction_error/mod.rs b/src/updates/prediction_error/mod.rs index c53f4e7ae..810699ed6 100644 --- a/src/updates/prediction_error/mod.rs +++ b/src/updates/prediction_error/mod.rs @@ -1 +1,2 @@ pub mod continuous; +pub mod exponential; diff --git a/src/utils/function_pointer.rs b/src/utils/function_pointer.rs new file mode 100644 index 000000000..163c12d11 --- /dev/null +++ b/src/utils/function_pointer.rs @@ -0,0 +1,18 @@ +use std::collections::HashMap; + +use crate::{model::Network, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; + +// Create a default signature for update functions +pub type FnType = for<'a> fn(&'a mut Network, usize); + +pub fn get_func_map() -> HashMap { + let function_map: HashMap = [ + (posterior_update_continuous_state_node as FnType, "posterior_update_continuous_state_node"), + (prediction_continuous_state_node as FnType, "prediction_continuous_state_node"), + (prediction_error_continuous_state_node as FnType, "prediction_error_continuous_state_node"), + (prediction_error_exponential_state_node as FnType, "prediction_error_exponential_state_node"), + ] + .into_iter() + .collect(); + function_map +} \ No newline at end of file diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 000000000..1918ac547 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod set_sequence; +pub mod function_pointer; \ No newline at end of file diff --git a/src/utils.rs b/src/utils/set_sequence.rs similarity index 65% rename from src/utils.rs rename to src/utils/set_sequence.rs index 178d1ba6a..2a130980e 100644 --- a/src/utils.rs +++ b/src/utils/set_sequence.rs @@ -1,6 +1,7 @@ -use crate::{model::{FnType, Network, Node, UpdateSequence}, updates::{posterior::{continuous::posterior_update_continuous_state_node, exponential::posterior_update_exponential_state_node}, prediction::continuous::prediction_continuous_state_node, prediction_error::continuous::prediction_error_continuous_state_node}}; +use crate::{model::{Network, Node, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; +use crate::utils::function_pointer::FnType; -pub fn get_update_sequence(network: &Network) -> UpdateSequence { +pub fn set_update_sequence(network: &Network) -> UpdateSequence { let predictions = get_predictions_sequence(network); let updates = get_updates_sequence(network); @@ -40,8 +41,8 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { // If both are Some, merge the vectors (Some(ref vec1), Some(ref vec2)) => { // Create a new vector by merging the two - let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); - Some(merged_vec) // Return the merged vector wrapped in Some + let vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(vec) // Return the merged vector wrapped in Some } // If one is Some and the other is None, return the one that's Some (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), @@ -99,7 +100,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { po_nodes_idxs.retain(|x| !network.inputs.contains(x)); // iterate over all nodes and add the prediction step if all criteria are met - let mut n_remaining = 2 * pe_nodes_idxs.len(); // posterior updates + prediction errors + let mut n_remaining = po_nodes_idxs.len() + pe_nodes_idxs.len(); // posterior updates + prediction errors while n_remaining > 0 { @@ -107,6 +108,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { let mut has_update = false; // loop over all the remaining nodes for prediction errors --------------------- + // ----------------------------------------------------------------------------- for i in 0..pe_nodes_idxs.len() { let idx = pe_nodes_idxs[i]; @@ -114,25 +116,42 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // to send a prediction error, this node should have been updated first if !(po_nodes_idxs.contains(&idx)) { + // only send a prediction error if this node has any parent + let value_parents_idxs = &network.edges[idx].value_parents; + let volatility_parents_idxs = &network.edges[idx].volatility_parents; + + let has_parents = match (value_parents_idxs, volatility_parents_idxs) { + // If both are None, return false + (None, None) => false, + _ => true, + }; + // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match (network.nodes.get(&idx), has_parents) { + (Some(Node::Continuous(_)), true) => { updates.push((idx, prediction_error_continuous_state_node)); + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; } - Some(Node::Exponential(_)) => (), - None => () + (Some(Node::Exponential(_)), _) => { + updates.push((idx, prediction_error_exponential_state_node)); + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + _ => () } - - // remove the node from the to-be-updated list - pe_nodes_idxs.retain(|&x| x != idx); - n_remaining -= 1; - has_update = true; - break; } } // loop over all the remaining nodes for posterior updates --------------------- + // ----------------------------------------------------------------------------- for i in 0..po_nodes_idxs.len() { let idx = po_nodes_idxs[i]; @@ -163,17 +182,14 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { }; // 3. if false, add the posterior update to the list - if !(missing_pe) { + if !missing_pe { // add the node in the update list match network.nodes.get(&idx) { Some(Node::Continuous(_)) => { updates.push((idx, posterior_update_continuous_state_node)); } - Some(Node::Exponential(_)) => { - updates.push((idx, posterior_update_exponential_state_node)); - } - None => () + _ => () } @@ -183,9 +199,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { has_update = true; break; } - } - // 2. get update sequence ---------------------------------------------------------- - + } if !(has_update) { break; } @@ -197,46 +211,62 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // Tests module for unit tests #[cfg(test)] // Only compile and include this module when running tests mod tests { + use crate::utils::function_pointer::get_func_map; + use super::*; // Import the parent module's items to test them #[test] fn test_get_update_order() { + let func_map = get_func_map(); + // initialize network - let mut network = Network::new(); + let mut hgf_network = Network::new(); // create a network - network.add_nodes( + hgf_network.add_nodes( "continuous-state", Some(vec![1]), None, - None, Some(vec![2]), + None, ); - network.add_nodes( + hgf_network.add_nodes( "continuous-state", None, Some(vec![0]), None, None, ); - network.add_nodes( + hgf_network.add_nodes( "continuous-state", None, None, - Some(vec![0]), None, + Some(vec![0]), ); - network.add_nodes( - "exponential-node", + hgf_network.set_update_sequence(); + + println!("Prediction sequence ----------"); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[0].0, func_map.get(&hgf_network.update_sequence.predictions[0].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[1].0, func_map.get(&hgf_network.update_sequence.predictions[1].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[2].0, func_map.get(&hgf_network.update_sequence.predictions[2].1).unwrap_or(&"unknown")); + println!("Update sequence ----------"); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[0].0, func_map.get(&hgf_network.update_sequence.updates[0].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[1].0, func_map.get(&hgf_network.update_sequence.updates[1].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[2].0, func_map.get(&hgf_network.update_sequence.updates[2].1).unwrap_or(&"unknown")); + + // initialize network + let mut exp_network = Network::new(); + exp_network.add_nodes( + "exponential-state", None, None, None, None, ); + exp_network.set_update_sequence(); + println!("Node: {} - Function name: {}", &exp_network.update_sequence.updates[0].0, func_map.get(&exp_network.update_sequence.updates[0].1).unwrap_or(&"unknown")); - println!("Network: {:?}", network); - println!("Update order: {:?}", get_update_sequence(&network)); - } } From d46529d7b9ef92f895e446de68d914a87f9ef2b4 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 28 Oct 2024 20:55:11 +0100 Subject: [PATCH 15/20] split attributes into floats and vectors --- Cargo.lock | 68 +++++++ Cargo.toml | 3 +- pyhgf/model/network.py | 23 +-- src/math.rs | 4 +- src/model.rs | 174 +++++++++++------- src/updates/observations.rs | 11 +- src/updates/prediction_error/exponential.rs | 32 +++- src/utils/set_sequence.rs | 40 ++-- tests/test_exponential_family.py | 7 +- .../prediction_errors/test_dirichlet.py | 2 +- tests/test_utils.py | 2 +- 11 files changed, 234 insertions(+), 132 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e4c2a367f..a247d39c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,16 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -57,6 +67,48 @@ dependencies = [ "autocfg", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -92,6 +144,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +[[package]] +name = "portable-atomic-util" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" +dependencies = [ + "portable-atomic", +] + [[package]] name = "proc-macro2" version = "1.0.87" @@ -173,6 +234,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.7" @@ -186,6 +253,7 @@ dependencies = [ name = "rshgf" version = "0.1.0" dependencies = [ + "ndarray", "pyo3", ] diff --git a/Cargo.toml b/Cargo.toml index 0673e99a1..eb8e726cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,5 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.21.2", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.21.2", features = ["extension-module"] } +ndarray = "0.16.1" \ No newline at end of file diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index f2b9b1e9c..b5096df9a 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -329,29 +329,14 @@ def add_nodes( be a regular state node that can have value and/or volatility parents/children. If `"binary-state"`, the node should be the value parent of a binary input. State nodes filtering distribution from the - exponential family can be created using the `"ef-"` prefix (e.g. - `"ef-normal"` for a univariate normal distribution). Note that only a few - distributions are implemented at the moment. - - In addition to state nodes, four types of input nodes are supported: - - `generic-input`: receive a value or an array and pass it to the parent - nodes. - - `continuous-input`: receive a continuous observation as input. - - `binary-input` receives a single boolean as observation. The parameters - provided to the binary input node contain: 1. `binary_precision`, the binary - input precision, which defaults to `jnp.inf`. 2. `eta0`, the lower bound of - the binary process, which defaults to `0.0`. 3. `eta1`, the higher bound of - the binary process, which defaults to `1.0`. - - `categorical-input` receives a boolean array as observation. The - parameters provided to the categorical input node contain: 1. - `n_categories`, the number of categories implied by the categorical state. + exponential family can be created using `"exponential-state"`. .. note:: When using a categorical state node, the `binary_parameters` can be used to parametrize the implied collection of binary HGFs. .. note: - When using `categorical-input`, the implied `n` binary HGFs are + When using `categorical-state`, the implied `n` binary HGFs are automatically created with a shared volatility parent at the third level, resulting in a network with `3n + 2` nodes in total. @@ -396,7 +381,7 @@ def add_nodes( """ if kind not in [ "DP-state", - "ef-normal", + "exponential-state", "categorical-state", "continuous-state", "binary-state", @@ -483,7 +468,7 @@ def add_nodes( "mean": 0.0, "observed": 1, } - elif "ef-normal" in kind: + elif "exponential-state" in kind: default_parameters = { "nus": 3.0, "xis": jnp.array([0.0, 1.0]), diff --git a/src/math.rs b/src/math.rs index 20a0e568b..80cd8d4dd 100644 --- a/src/math.rs +++ b/src/math.rs @@ -1,3 +1,3 @@ -pub fn sufficient_statistics(x: &f64) -> [f64; 2] { - [*x, x.powf(2.0)] +pub fn sufficient_statistics(x: &f64) -> Vec { + vec![*x, x.powf(2.0)] } \ No newline at end of file diff --git a/src/model.rs b/src/model.rs index 5f23e09e9..931064346 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,10 +4,13 @@ use crate::utils::set_sequence::set_update_sequence; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; +use ndarray::{Array2, Axis, stack}; #[derive(Debug)] #[pyclass] pub struct AdjacencyLists{ + #[pyo3(get, set)] + pub node_type: String, #[pyo3(get, set)] pub value_parents: Option>, #[pyo3(get, set)] @@ -17,29 +20,6 @@ pub struct AdjacencyLists{ #[pyo3(get, set)] pub volatility_children: Option>, } -#[derive(Debug, Clone)] -pub struct ContinuousStateNode{ - pub mean: f64, - pub expected_mean: f64, - pub precision: f64, - pub expected_precision: f64, - pub tonic_volatility: f64, - pub tonic_drift: f64, - pub autoconnection_strength: f64, -} -#[derive(Debug, Clone)] -pub struct ExponentialFamiliyStateNode { - pub mean: f64, - pub expected_mean: f64, - pub nus: f64, - pub xis: [f64; 2], -} - -#[derive(Debug, Clone)] -pub enum Node { - Continuous(ContinuousStateNode), - Exponential(ExponentialFamiliyStateNode), -} #[derive(Debug)] pub struct UpdateSequence { @@ -47,13 +27,26 @@ pub struct UpdateSequence { pub updates: Vec<(usize, FnType)>, } +#[derive(Debug)] +pub struct Attributes { + pub floats: HashMap>, + pub vectors: HashMap>>, +} + +#[derive(Debug)] +pub struct NodeTrajectories { + pub floats: HashMap>>, + pub vectors: HashMap>>>, +} + #[derive(Debug)] #[pyclass] pub struct Network{ - pub nodes: HashMap, - pub edges: Vec, + pub attributes: Attributes, + pub edges: HashMap, pub inputs: Vec, pub update_sequence: UpdateSequence, + pub node_trajectories: NodeTrajectories, } #[pymethods] @@ -63,12 +56,13 @@ impl Network { #[new] // Define the constructor accessible from Python pub fn new() -> Self { Network { - nodes: HashMap::new(), - edges: Vec::new(), + attributes: Attributes {floats: HashMap::new(), vectors: HashMap::new()}, + edges: HashMap::new(), inputs: Vec::new(), - update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()} + update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()}, + node_trajectories: NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()} + } } - } /// Add nodes to the network. /// @@ -84,7 +78,7 @@ impl Network { volatility_parents: Option>, volatility_children: Option>, ) { // the node ID is equal to the number of nodes already in the network - let node_id: usize = self.nodes.len(); + let node_id: usize = self.edges.len(); // if this node has no children, this is an input node if (value_children == None) & (volatility_children == None) { @@ -92,29 +86,39 @@ impl Network { } let edges = AdjacencyLists{ + node_type: String::from(kind), value_parents: value_parents, value_children: value_children, volatility_parents: volatility_parents, volatility_children: volatility_children, }; - + // add edges and attributes if kind == "continuous-state" { - let continuous_state = ContinuousStateNode{ - mean: 0.0, expected_mean: 0.0, precision: 1.0, expected_precision: 1.0, - tonic_drift: 0.0, tonic_volatility: -4.0, autoconnection_strength: 1.0 - }; - let node = Node::Continuous(continuous_state); - self.nodes.insert(node_id, node); - self.edges.push(edges); + + let attributes = [ + (String::from("mean"), 0.0), + (String::from("expected_mean"), 0.0), + (String::from("precision"), 1.0), + (String::from("expected_precision"), 1.0), + (String::from("tonic_volatility"), -4.0), + (String::from("tonic_drift"), 0.0), + (String::from("autoconnection_strength"), 1.0)].into_iter().collect(); + + self.attributes.floats.insert(node_id, attributes); + self.edges.insert(node_id, edges); } else if kind == "exponential-state" { - let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{ - mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0] - }; - let node = Node::Exponential(exponential_node); - self.nodes.insert(node_id, node); - self.edges.push(edges); + + let floats_attributes = [ + (String::from("mean"), 0.0), + (String::from("nus"), 3.0)].into_iter().collect(); + let vector_attributes = [ + (String::from("xis"), vec![0.0, 1.0])].into_iter().collect(); + + self.attributes.floats.insert(node_id, floats_attributes); + self.attributes.vectors.insert(node_id, vector_attributes); + self.edges.insert(node_id, edges); } else { println!("Invalid type of node provided ({}).", kind); @@ -157,10 +161,57 @@ impl Network { /// # Arguments /// * `input_data` - A vector of vectors. Each vector is a time series of observations /// associated with one node. - pub fn input_data(&mut self, input_data: Vec>) { + pub fn input_data(&mut self, input_data: Vec) { + + // initialize the belief trajectories result struture + let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; + for (node_idx, node) in &self.attributes.floats { + let new_map: HashMap> = HashMap::new(); + node_trajectories.floats.insert(*node_idx, new_map); + if let Some(attr) = node_trajectories.floats.get_mut(node_idx) { + for key in node.keys() { + attr.insert(key.clone(), Vec::new()); + } + } + } + // iterate over the observations for observation in input_data { - self.belief_propagation(observation); + + // 1. belief propagation for one time slice + self.belief_propagation(vec![observation]); + + // 2. append the new states in the result vector + for (new_node_idx, new_node) in &self.attributes.floats { + for (new_key, new_value) in new_node { + // If the key exists in map1, append the vector from map2 + if let Some(old_node) = node_trajectories.floats.get_mut(&new_node_idx) { + if let Some(old_value) = old_node.get_mut(new_key) { + old_value.push(*new_value); + } + } + } + } } + + self.node_trajectories = node_trajectories; + } + + #[getter] + pub fn get_node_trajectories<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + let py_list = PyList::empty(py); + + + // Iterate over the Rust HashMap and insert key-value pairs into the PyDict + for (node_idx, node) in &self.node_trajectories.floats { + let py_dict = PyDict::new(py); + for (key, value) in node { + // Create a new Python dictionary + py_dict.set_item(key, value).expect("Failed to set item in PyDict"); + } + py_list.append(py_dict)?; + } + // Create a PyList from Vec + Ok(py_list) } #[getter] @@ -175,13 +226,13 @@ impl Network { let py_list = PyList::empty(py); // Convert each struct in the Vec to a Python object and add to PyList - for s in &self.edges { + for i in 0..self.edges.len() { // Create a new Python dictionary for each MyStruct let py_dict = PyDict::new(py); - py_dict.set_item("value_parents", &s.value_parents)?; - py_dict.set_item("value_children", &s.value_children)?; - py_dict.set_item("volatility_parents", &s.volatility_parents)?; - py_dict.set_item("volatility_children", &s.volatility_children)?; + py_dict.set_item("value_parents", &self.edges[&i].value_parents)?; + py_dict.set_item("value_children", &self.edges[&i].value_children)?; + py_dict.set_item("volatility_parents", &self.edges[&i].volatility_parents)?; + py_dict.set_item("volatility_children", &self.edges[&i].volatility_children)?; // Add the dictionary to the list py_list.append(py_dict)?; @@ -246,29 +297,16 @@ mod tests { None, None ); - network.add_nodes( - "exponential-state", - None, - None, - None, - None - ); // println!("Graph before belief propagation: {:?}", network); // belief propagation - let input_data = vec![ - vec![1.1, 2.2], - vec![1.2, 2.1], - vec![1.0, 2.0], - vec![1.3, 2.2], - vec![1.1, 2.5], - vec![1.0, 2.6], - ]; - + let input_data = vec![1.0, 1.3, 1.5, 1.7]; + network.set_update_sequence(); network.input_data(input_data); - // println!("Graph after belief propagation: {:?}", network); + println!("Update sequence: {:?}", network.update_sequence); + println!("Node trajectories: {:?}", network.node_trajectories); } } diff --git a/src/updates/observations.rs b/src/updates/observations.rs index 1d64ec7a9..1037b83b2 100644 --- a/src/updates/observations.rs +++ b/src/updates/observations.rs @@ -1,4 +1,4 @@ -use crate::model::{Network, Node}; +use crate::model::Network; /// Inject new observations into an input node @@ -12,10 +12,9 @@ use crate::model::{Network, Node}; /// * `network` - The network after message passing. pub fn observation_update(network: &mut Network, node_idx: usize, observations: f64) { - match network.nodes.get_mut(&node_idx) { - Some(Node::Exponential(ref mut node)) => { - node.mean = observations; - } - _ => (), + if let Some(node) = network.attributes.floats.get_mut(&node_idx) { + if let Some(mean) = node.get_mut("mean") { + *mean = observations; } + } } \ No newline at end of file diff --git a/src/updates/prediction_error/exponential.rs b/src/updates/prediction_error/exponential.rs index 35695b9c3..e2cebca05 100644 --- a/src/updates/prediction_error/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -1,4 +1,4 @@ -use crate::model::{Network, Node}; +use crate::model::Network; use crate::math::sufficient_statistics; /// Updating an exponential family state node @@ -11,13 +11,25 @@ use crate::math::sufficient_statistics; /// * `network` - The network after message passing. pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { - match network.nodes.get_mut(&node_idx) { - Some(Node::Exponential(ref mut node)) => { - let suf_stats = sufficient_statistics(&node.mean); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); - } + if let Some(floats_attributes) = network.attributes.floats.get_mut(&node_idx) { + if let Some(vectors_attributes) = network.attributes.vectors.get_mut(&node_idx) { + let mean = floats_attributes.get("mean"); + let nus = floats_attributes.get("nus"); + let xis = vectors_attributes.get("xis"); + let new_xis = match (mean, nus, xis) { + (Some(mean), Some(nus), Some(xis)) => { + let suf_stats = sufficient_statistics(mean); + let mut new_xis = xis.clone(); + for i in 0..suf_stats.len() { + new_xis[i] = new_xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]); + } + new_xis + } + _ => Vec::new(), + }; + if let Some(xis) = vectors_attributes.get_mut("xis") { + *xis = new_xis; // Modify the value directly + } + } } - _ => (), - } -} \ No newline at end of file + } \ No newline at end of file diff --git a/src/utils/set_sequence.rs b/src/utils/set_sequence.rs index 2a130980e..191d734ab 100644 --- a/src/utils/set_sequence.rs +++ b/src/utils/set_sequence.rs @@ -1,4 +1,4 @@ -use crate::{model::{Network, Node, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; +use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; use crate::utils::function_pointer::FnType; pub fn set_update_sequence(network: &Network) -> UpdateSequence { @@ -18,7 +18,7 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { // 1. get prediction sequence ------------------------------------------------------ // list all nodes availables in the network - let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut nodes_idxs: Vec = network.edges.keys().cloned().collect(); // iterate over all nodes and add the prediction step if all criteria are met let mut n_remaining = nodes_idxs.len(); @@ -34,9 +34,9 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { let idx = nodes_idxs[i]; // list the node's parents - let value_parents_idxs = &network.edges[idx].value_parents; - let volatility_parents_idxs = &network.edges[idx].volatility_parents; - + let value_parents_idxs = &network.edges[&idx].value_parents; + let volatility_parents_idxs = &network.edges[&idx].volatility_parents; + let parents_idxs = match (value_parents_idxs, volatility_parents_idxs) { // If both are Some, merge the vectors (Some(ref vec1), Some(ref vec2)) => { @@ -50,6 +50,7 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { (None, None) => None, }; + // check if there is any parent node that is still found in the to-be-updated list let contains_common = match parents_idxs { Some(vec) => vec.iter().any(|item| nodes_idxs.contains(item)), @@ -60,12 +61,11 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { if !(contains_common) { // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match network.edges.get(&idx) { + Some(AdjacencyLists {node_type, ..}) if node_type == "continuous-state" => { predictions.push((idx, prediction_continuous_state_node)); } - Some(Node::Exponential(_)) => (), - None => () + _ => () } @@ -93,8 +93,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // 1. get update sequence ---------------------------------------------------------- // list all nodes availables in the network - let mut pe_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - let mut po_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut pe_nodes_idxs: Vec = network.edges.keys().cloned().collect(); + let mut po_nodes_idxs: Vec = network.edges.keys().cloned().collect(); // remove the input nodes from the to-be-visited nodes for posterior updates po_nodes_idxs.retain(|x| !network.inputs.contains(x)); @@ -117,8 +117,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { if !(po_nodes_idxs.contains(&idx)) { // only send a prediction error if this node has any parent - let value_parents_idxs = &network.edges[idx].value_parents; - let volatility_parents_idxs = &network.edges[idx].volatility_parents; + let value_parents_idxs = &network.edges[&idx].value_parents; + let volatility_parents_idxs = &network.edges[&idx].volatility_parents; let has_parents = match (value_parents_idxs, volatility_parents_idxs) { // If both are None, return false @@ -127,8 +127,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { }; // add the node in the update list - match (network.nodes.get(&idx), has_parents) { - (Some(Node::Continuous(_)), true) => { + match (network.edges.get(&idx), has_parents) { + (Some(AdjacencyLists {node_type, ..}), true) if node_type == "continuous-state" => { updates.push((idx, prediction_error_continuous_state_node)); // remove the node from the to-be-updated list pe_nodes_idxs.retain(|&x| x != idx); @@ -136,7 +136,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { has_update = true; break; } - (Some(Node::Exponential(_)), _) => { + (Some(AdjacencyLists {node_type, ..}), _) if node_type == "exponential-state" => { updates.push((idx, prediction_error_exponential_state_node)); // remove the node from the to-be-updated list pe_nodes_idxs.retain(|&x| x != idx); @@ -158,8 +158,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // to start a posterior update, all children should have sent prediction errors // 1. get a list of all children - let value_children_idxs = &network.edges[idx].value_children; - let volatility_children_idxs = &network.edges[idx].volatility_children; + let value_children_idxs = &network.edges[&idx].value_children; + let volatility_children_idxs = &network.edges[&idx].volatility_children; let children_idxs = match (value_children_idxs, volatility_children_idxs) { // If both are Some, merge the vectors @@ -185,8 +185,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { if !missing_pe { // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match network.edges.get(&idx) { + Some(AdjacencyLists {node_type, ..}) if node_type == "continuous-state" => { updates.push((idx, posterior_update_continuous_state_node)); } _ => () diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py index 837b685f8..847200e27 100644 --- a/tests/test_exponential_family.py +++ b/tests/test_exponential_family.py @@ -13,13 +13,12 @@ def test_1d_gaussain(): # Rust ----------------------------------------------------------------------------- rs_network = RsNetwork() rs_network.add_nodes(kind="exponential-state") - rs_network.add_nodes(kind="exponential-state") rs_network.inputs rs_network.edges + rs_network.set_update_sequence() - rs_network.input_data([[0, 0], [1, 1]]) - rs_network.input_data([timeseries, timeseries]) + rs_network.input_data(timeseries) # Python --------------------------------------------------------------------------- - py_network = PyNetwork().add_nodes(kind="continuous-state") + py_network = PyNetwork().add_nodes(kind="exponential-state") py_network.attributes diff --git a/tests/test_updates/prediction_errors/test_dirichlet.py b/tests/test_updates/prediction_errors/test_dirichlet.py index e6e12c3bd..d044ca80d 100644 --- a/tests/test_updates/prediction_errors/test_dirichlet.py +++ b/tests/test_updates/prediction_errors/test_dirichlet.py @@ -29,7 +29,7 @@ def test_dirichlet_node_prediction_error(): .add_nodes(kind="generic-state") .add_nodes(kind="DP-state", value_children=0, batch_size=2) .add_nodes( - kind="ef-normal", + kind="exponential-state", n_nodes=2, value_children=1, xis=jnp.array([0.0, 1 / 8]), diff --git a/tests/test_utils.py b/tests/test_utils.py index f3763ef4a..93a37dad4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -87,7 +87,7 @@ def test_set_update_sequence(): .add_nodes(kind="generic-state") .add_nodes(kind="DP-state", value_children=0, alpha=0.1, batch_size=2) .add_nodes( - kind="ef-normal", + kind="exponential-state", n_nodes=2, value_children=1, xis=jnp.array([0.0, 1 / 8]), From 8c5b0caabc9791ed7c2a467eecece5df6554f208 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 12:14:41 +0100 Subject: [PATCH 16/20] exponential 1d node working --- Cargo.lock | 38 ++++++++++------- Cargo.toml | 2 +- docs/source/api.rst | 20 ++++----- .../notebooks/0.3-Generalised_filtering.ipynb | 4 +- pyhgf/model/network.py | 8 ++-- pyhgf/updates/prediction/dirichlet.py | 2 +- .../exponential.py | 2 +- pyhgf/utils.py | 31 +++++++------- src/model.rs | 42 +++++++++++++++++-- tests/test_exponential_family.py | 17 ++++++-- tests/test_utils.py | 2 +- 11 files changed, 112 insertions(+), 56 deletions(-) rename pyhgf/updates/{posterior => prediction_error}/exponential.py (97%) diff --git a/Cargo.lock b/Cargo.lock index a247d39c5..db1887429 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,16 +69,14 @@ dependencies = [ [[package]] name = "ndarray" -version = "0.16.1" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" dependencies = [ "matrixmultiply", "num-complex", "num-integer", "num-traits", - "portable-atomic", - "portable-atomic-util", "rawpointer", ] @@ -109,6 +107,21 @@ dependencies = [ "autocfg", ] +[[package]] +name = "numpy" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec170733ca37175f5d75a5bea5911d6ff45d2cd52849ce98b685394e4f2f37f4" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "rustc-hash", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -144,15 +157,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" -[[package]] -name = "portable-atomic-util" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" -dependencies = [ - "portable-atomic", -] - [[package]] name = "proc-macro2" version = "1.0.87" @@ -253,10 +257,16 @@ dependencies = [ name = "rshgf" version = "0.1.0" dependencies = [ - "ndarray", + "numpy", "pyo3", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index eb8e726cb..d4e4c9e87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,4 @@ path = "src/lib.rs" # The source file of the target. [dependencies] pyo3 = { version = "0.21.2", features = ["extension-module"] } -ndarray = "0.16.1" \ No newline at end of file +numpy = "0.21" \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst index 8f0e2584e..7e3d3cdf3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -46,16 +46,6 @@ Continuous nodes continuous_node_posterior_update continuous_node_posterior_update_ehgf -Exponential family ------------------- - -.. currentmodule:: pyhgf.updates.posterior.exponential - -.. autosummary:: - :toctree: generated/pyhgf.updates.posterior.exponential - - posterior_update_exponential_family - Prediction steps ================ @@ -146,6 +136,16 @@ Dirichlet state nodes likely_cluster_proposal clusters_likelihood +Exponential family +^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: pyhgf.updates.prediction_error.exponential + +.. autosummary:: + :toctree: generated/pyhgf.updates.prediction_error.exponential + + prediction_error_update_exponential_family + Distribution ************ diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb index fd3a19c5f..257839b20 100644 --- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb +++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb @@ -320,7 +320,7 @@ "\n", "### Using a fixed $\\nu$\n", "\n", - "This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `ef-` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution, therefore the kind is set to `\"ef-normal\"`). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:" + "This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `exponential-state` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:" ] }, { @@ -340,7 +340,7 @@ "generalised_filter = (\n", " Network()\n", " .add_nodes(kind=\"generic-state\")\n", - " .add_nodes(kind=\"ef-normal\", value_children=0, xis=np.array([0, 1 / 8]))\n", + " .add_nodes(kind=\"exponential-state\", value_children=0, xis=np.array([0, 1 / 8]))\n", ")" ] }, diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index b5096df9a..a6041341f 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -390,8 +390,8 @@ def add_nodes( raise ValueError( ( "Invalid node type. Should be one of the following: " - "'DP-state', 'continuous-state', 'binary-state', 'ef-normal'." - "'generic-state' or 'categorical-state'" + "'DP-state', 'continuous-state', 'binary-state', " + "'exponential-state', 'generic-state' or 'categorical-state'" ) ) @@ -473,7 +473,7 @@ def add_nodes( "nus": 3.0, "xis": jnp.array([0.0, 1.0]), "mean": 0.0, - "observed": 1.0, + "observed": 1, } elif kind == "categorical-state": if "n_categories" in node_parameters: @@ -562,7 +562,7 @@ def add_nodes( node_type = 1 elif kind == "continuous-state": node_type = 2 - elif kind == "ef-normal": + elif kind == "exponential-state": node_type = 3 elif kind == "DP-state": node_type = 4 diff --git a/pyhgf/updates/prediction/dirichlet.py b/pyhgf/updates/prediction/dirichlet.py index 03fb86e25..298fbd2b9 100644 --- a/pyhgf/updates/prediction/dirichlet.py +++ b/pyhgf/updates/prediction/dirichlet.py @@ -39,7 +39,7 @@ def dirichlet_node_prediction( Static parameters of the Dirichlet process node. """ - # get the parameter (mean and variance) from the EF-normal parent nodes + # get the parameter (mean and variance) from the exponential state parent nodes value_parent_idxs = edges[node_idx].value_parents if value_parent_idxs is not None: parameters = jnp.array( diff --git a/pyhgf/updates/posterior/exponential.py b/pyhgf/updates/prediction_error/exponential.py similarity index 97% rename from pyhgf/updates/posterior/exponential.py rename to pyhgf/updates/prediction_error/exponential.py index f4b9ed519..a84b31b53 100644 --- a/pyhgf/updates/posterior/exponential.py +++ b/pyhgf/updates/prediction_error/exponential.py @@ -10,7 +10,7 @@ @partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn")) -def posterior_update_exponential_family( +def prediction_error_update_exponential_family( attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args ) -> Attributes: r"""Update the parameters of an exponential family distribution. diff --git a/pyhgf/utils.py b/pyhgf/utils.py index 227baf7cd..7a713cffa 100644 --- a/pyhgf/utils.py +++ b/pyhgf/utils.py @@ -18,7 +18,6 @@ continuous_node_posterior_update, continuous_node_posterior_update_ehgf, ) -from pyhgf.updates.posterior.exponential import posterior_update_exponential_family from pyhgf.updates.prediction.binary import binary_state_node_prediction from pyhgf.updates.prediction.continuous import continuous_node_prediction from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction @@ -28,6 +27,9 @@ ) from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error +from pyhgf.updates.prediction_error.exponential import ( + prediction_error_update_exponential_family, +) from pyhgf.updates.prediction_error.generic import generic_state_prediction_error if TYPE_CHECKING: @@ -374,16 +376,6 @@ def get_update_sequence( elif update_type == "standard": update_fn = continuous_node_posterior_update - elif network.edges[idx].node_type == 3: - - # create the sufficient statistic function - # for the exponential family node - ef_update = Partial( - posterior_update_exponential_family, - sufficient_stats_fn=Normal().sufficient_statistics, - ) - update_fn = ef_update - elif network.edges[idx].node_type == 4: update_fn = None @@ -407,8 +399,21 @@ def get_update_sequence( ] # if this node has no parent, no need to compute prediction errors + # unless this is an exponential family state node if len(all_parents) == 0: - nodes_without_prediction_error.remove(idx) + if network.edges[idx].node_type == 3: + # create the sufficient statistic function + # for the exponential family node + ef_update = Partial( + prediction_error_update_exponential_family, + sufficient_stats_fn=Normal().sufficient_statistics, + ) + update_fn = ef_update + no_update = False + update_sequence.append((idx, update_fn)) + nodes_without_prediction_error.remove(idx) + else: + nodes_without_prediction_error.remove(idx) else: # if this node has been updated if idx not in nodes_without_posterior_update: @@ -419,8 +424,6 @@ def get_update_sequence( update_fn = binary_state_node_prediction_error elif network.edges[idx].node_type == 2: update_fn = continuous_node_prediction_error - elif network.edges[idx].node_type == 3: - update_fn = None elif network.edges[idx].node_type == 4: update_fn = dirichlet_node_prediction_error elif network.edges[idx].node_type == 5: diff --git a/src/model.rs b/src/model.rs index 931064346..223458da3 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,7 +4,7 @@ use crate::utils::set_sequence::set_update_sequence; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; -use ndarray::{Array2, Axis, stack}; +use numpy::{PyArray1, PyArray}; #[derive(Debug)] #[pyclass] @@ -165,6 +165,8 @@ impl Network { // initialize the belief trajectories result struture let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; + + // add empty vectors in the floats hashmap for (node_idx, node) in &self.attributes.floats { let new_map: HashMap> = HashMap::new(); node_trajectories.floats.insert(*node_idx, new_map); @@ -174,13 +176,25 @@ impl Network { } } } + // add empty vectors in the vectors hashmap + for (node_idx, node) in &self.attributes.vectors { + let new_map: HashMap>> = HashMap::new(); + node_trajectories.vectors.insert(*node_idx, new_map); + if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) { + for key in node.keys() { + attr.insert(key.clone(), Vec::new()); + } + } + } + // iterate over the observations for observation in input_data { // 1. belief propagation for one time slice self.belief_propagation(vec![observation]); - // 2. append the new states in the result vector + // 2. append the new beliefs in the trajectories structure + // iterate over the float hashmap for (new_node_idx, new_node) in &self.attributes.floats { for (new_key, new_value) in new_node { // If the key exists in map1, append the vector from map2 @@ -191,6 +205,17 @@ impl Network { } } } + // iterate over the vector hashmap + for (new_node_idx, new_node) in &self.attributes.vectors { + for (new_key, new_value) in new_node { + // If the key exists in map1, append the vector from map2 + if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) { + if let Some(old_value) = old_node.get_mut(new_key) { + old_value.push(new_value.clone()); + } + } + } + } } self.node_trajectories = node_trajectories; @@ -201,15 +226,24 @@ impl Network { let py_list = PyList::empty(py); - // Iterate over the Rust HashMap and insert key-value pairs into the PyDict + // Iterate over the float hashmap and insert key-value pairs into the list as PyDict for (node_idx, node) in &self.node_trajectories.floats { let py_dict = PyDict::new(py); for (key, value) in node { // Create a new Python dictionary - py_dict.set_item(key, value).expect("Failed to set item in PyDict"); + py_dict.set_item(key, PyArray1::from_vec(py, value.clone()).to_owned()).expect("Failed to set item in PyDict"); + } + + // Iterate over the vector hashmap if any and insert key-value pairs into the list as PyDict + if let Some(vector_node) = self.node_trajectories.vectors.get(node_idx) { + for (vector_key, vector_value) in vector_node { + // Create a new Python dictionary + py_dict.set_item(vector_key, PyArray::from_vec2_bound(py, &vector_value).unwrap()).expect("Failed to set item in PyDict"); + } } py_list.append(py_dict)?; } + // Create a PyList from Vec Ok(py_list) } diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py index 847200e27..1e5f834aa 100644 --- a/tests/test_exponential_family.py +++ b/tests/test_exponential_family.py @@ -1,5 +1,6 @@ # Author: Nicolas Legrand +import numpy as np from rshgf import Network as RsNetwork from pyhgf import load_data @@ -13,12 +14,20 @@ def test_1d_gaussain(): # Rust ----------------------------------------------------------------------------- rs_network = RsNetwork() rs_network.add_nodes(kind="exponential-state") - rs_network.inputs - rs_network.edges rs_network.set_update_sequence() - rs_network.input_data(timeseries) # Python --------------------------------------------------------------------------- py_network = PyNetwork().add_nodes(kind="exponential-state") - py_network.attributes + py_network.input_data(timeseries) + + # Ensure identical results + assert np.isclose( + py_network.node_trajectories[0]["xis"], rs_network.node_trajectories[0]["xis"] + ).all() + assert np.isclose( + py_network.node_trajectories[0]["mean"], rs_network.node_trajectories[0]["mean"] + ).all() + assert np.isclose( + py_network.node_trajectories[0]["nus"], rs_network.node_trajectories[0]["nus"] + ).all() diff --git a/tests/test_utils.py b/tests/test_utils.py index 93a37dad4..1dde30aec 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -74,7 +74,7 @@ def test_set_update_sequence(): network3 = ( Network() .add_nodes(kind="generic-state") - .add_nodes(kind="ef-normal", value_children=0) + .add_nodes(kind="exponential-state", value_children=0) .create_belief_propagation_fn() ) predictions, updates = network3.update_sequence From 19ee4852bbfb46b85973baf455bd226e8966e270 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 12:35:50 +0100 Subject: [PATCH 17/20] github action --- .github/workflows/test.yml | 17 ++++++++++++++--- pyproject.toml | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2d7df66cb..8fefe136d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,18 +37,29 @@ jobs: # Step 4: Install System Dependencies - name: Install Graphviz run: sudo apt-get install -y graphviz + + # Step 5: Install Rust + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal - # Step 5: Install Python Dependencies using Poetry + # Step 6: Install Python Dependencies using Poetry - name: Install dependencies run: | poetry install --with dev + + # Step 7: Install the Rust package + - name: Build and Install the Package + run: maturin develop - # Step 6: Run Tests and Generate Coverage Report + # Step 8: Run Tests and Generate Coverage Report - name: Run tests and coverage run: | poetry run pytest ./tests/ --cov=./src/pyhgf/ --cov-report=xml - # Step 7: Upload Coverage Report to Codecov + # Step 9: Upload Coverage Report to Codecov - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: diff --git a/pyproject.toml b/pyproject.toml index eae00e95f..678d407fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ ipykernel = "^6.29.5 " coverage = "^7.6.3" pytest = "^8.3.3" pytest-cov = "^5.0.0" +maturin = "^1.7.4" [build-system] requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0", "maturin>=1.4,<2.0"] From 7af1945a79875e512430a7f2ebe74b82eb126031 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 12:38:13 +0100 Subject: [PATCH 18/20] lock file --- poetry.lock | 642 +++++++++++++++++++++++++++------------------------- 1 file changed, 338 insertions(+), 304 deletions(-) diff --git a/poetry.lock b/poetry.lock index a24410a45..953345830 100644 --- a/poetry.lock +++ b/poetry.lock @@ -528,73 +528,73 @@ test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist" [[package]] name = "coverage" -version = "7.6.3" +version = "7.6.4" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.9" files = [ - {file = "coverage-7.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6da42bbcec130b188169107ecb6ee7bd7b4c849d24c9370a0c884cf728d8e976"}, - {file = "coverage-7.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c222958f59b0ae091f4535851cbb24eb57fc0baea07ba675af718fb5302dddb2"}, - {file = "coverage-7.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab84a8b698ad5a6c365b08061920138e7a7dd9a04b6feb09ba1bfae68346ce6d"}, - {file = "coverage-7.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70a6756ce66cd6fe8486c775b30889f0dc4cb20c157aa8c35b45fd7868255c5c"}, - {file = "coverage-7.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c2e6fa98032fec8282f6b27e3f3986c6e05702828380618776ad794e938f53a"}, - {file = "coverage-7.6.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:921fbe13492caf6a69528f09d5d7c7d518c8d0e7b9f6701b7719715f29a71e6e"}, - {file = "coverage-7.6.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6d99198203f0b9cb0b5d1c0393859555bc26b548223a769baf7e321a627ed4fc"}, - {file = "coverage-7.6.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:87cd2e29067ea397a47e352efb13f976eb1b03e18c999270bb50589323294c6e"}, - {file = "coverage-7.6.3-cp310-cp310-win32.whl", hash = "sha256:a3328c3e64ea4ab12b85999eb0779e6139295bbf5485f69d42cf794309e3d007"}, - {file = "coverage-7.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:bca4c8abc50d38f9773c1ec80d43f3768df2e8576807d1656016b9d3eeaa96fd"}, - {file = "coverage-7.6.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c51ef82302386d686feea1c44dbeef744585da16fcf97deea2a8d6c1556f519b"}, - {file = "coverage-7.6.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0ca37993206402c6c35dc717f90d4c8f53568a8b80f0bf1a1b2b334f4d488fba"}, - {file = "coverage-7.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c77326300b839c44c3e5a8fe26c15b7e87b2f32dfd2fc9fee1d13604347c9b38"}, - {file = "coverage-7.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e484e479860e00da1f005cd19d1c5d4a813324e5951319ac3f3eefb497cc549"}, - {file = "coverage-7.6.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c6c0f4d53ef603397fc894a895b960ecd7d44c727df42a8d500031716d4e8d2"}, - {file = "coverage-7.6.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:37be7b5ea3ff5b7c4a9db16074dc94523b5f10dd1f3b362a827af66a55198175"}, - {file = "coverage-7.6.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:43b32a06c47539fe275106b376658638b418c7cfdfff0e0259fbf877e845f14b"}, - {file = "coverage-7.6.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ee77c7bef0724165e795b6b7bf9c4c22a9b8468a6bdb9c6b4281293c6b22a90f"}, - {file = "coverage-7.6.3-cp311-cp311-win32.whl", hash = "sha256:43517e1f6b19f610a93d8227e47790722c8bf7422e46b365e0469fc3d3563d97"}, - {file = "coverage-7.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:04f2189716e85ec9192df307f7c255f90e78b6e9863a03223c3b998d24a3c6c6"}, - {file = "coverage-7.6.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:27bd5f18d8f2879e45724b0ce74f61811639a846ff0e5c0395b7818fae87aec6"}, - {file = "coverage-7.6.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d546cfa78844b8b9c1c0533de1851569a13f87449897bbc95d698d1d3cb2a30f"}, - {file = "coverage-7.6.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9975442f2e7a5cfcf87299c26b5a45266ab0696348420049b9b94b2ad3d40234"}, - {file = "coverage-7.6.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:583049c63106c0555e3ae3931edab5669668bbef84c15861421b94e121878d3f"}, - {file = "coverage-7.6.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2341a78ae3a5ed454d524206a3fcb3cec408c2a0c7c2752cd78b606a2ff15af4"}, - {file = "coverage-7.6.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a4fb91d5f72b7e06a14ff4ae5be625a81cd7e5f869d7a54578fc271d08d58ae3"}, - {file = "coverage-7.6.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e279f3db904e3b55f520f11f983cc8dc8a4ce9b65f11692d4718ed021ec58b83"}, - {file = "coverage-7.6.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aa23ce39661a3e90eea5f99ec59b763b7d655c2cada10729ed920a38bfc2b167"}, - {file = "coverage-7.6.3-cp312-cp312-win32.whl", hash = "sha256:52ac29cc72ee7e25ace7807249638f94c9b6a862c56b1df015d2b2e388e51dbd"}, - {file = "coverage-7.6.3-cp312-cp312-win_amd64.whl", hash = "sha256:40e8b1983080439d4802d80b951f4a93d991ef3261f69e81095a66f86cf3c3c6"}, - {file = "coverage-7.6.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9134032f5aa445ae591c2ba6991d10136a1f533b1d2fa8f8c21126468c5025c6"}, - {file = "coverage-7.6.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:99670790f21a96665a35849990b1df447993880bb6463a0a1d757897f30da929"}, - {file = "coverage-7.6.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc7d6b380ca76f5e817ac9eef0c3686e7834c8346bef30b041a4ad286449990"}, - {file = "coverage-7.6.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7b26757b22faf88fcf232f5f0e62f6e0fd9e22a8a5d0d5016888cdfe1f6c1c4"}, - {file = "coverage-7.6.3-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c59d6a4a4633fad297f943c03d0d2569867bd5372eb5684befdff8df8522e39"}, - {file = "coverage-7.6.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f263b18692f8ed52c8de7f40a0751e79015983dbd77b16906e5b310a39d3ca21"}, - {file = "coverage-7.6.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:79644f68a6ff23b251cae1c82b01a0b51bc40c8468ca9585c6c4b1aeee570e0b"}, - {file = "coverage-7.6.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:71967c35828c9ff94e8c7d405469a1fb68257f686bca7c1ed85ed34e7c2529c4"}, - {file = "coverage-7.6.3-cp313-cp313-win32.whl", hash = "sha256:e266af4da2c1a4cbc6135a570c64577fd3e6eb204607eaff99d8e9b710003c6f"}, - {file = "coverage-7.6.3-cp313-cp313-win_amd64.whl", hash = "sha256:ea52bd218d4ba260399a8ae4bb6b577d82adfc4518b93566ce1fddd4a49d1dce"}, - {file = "coverage-7.6.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:8d4c6ea0f498c7c79111033a290d060c517853a7bcb2f46516f591dab628ddd3"}, - {file = "coverage-7.6.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:331b200ad03dbaa44151d74daeb7da2cf382db424ab923574f6ecca7d3b30de3"}, - {file = "coverage-7.6.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54356a76b67cf8a3085818026bb556545ebb8353951923b88292556dfa9f812d"}, - {file = "coverage-7.6.3-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ebec65f5068e7df2d49466aab9128510c4867e532e07cb6960075b27658dca38"}, - {file = "coverage-7.6.3-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d33a785ea8354c480515e781554d3be582a86297e41ccbea627a5c632647f2cd"}, - {file = "coverage-7.6.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f7ddb920106bbbbcaf2a274d56f46956bf56ecbde210d88061824a95bdd94e92"}, - {file = "coverage-7.6.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:70d24936ca6c15a3bbc91ee9c7fc661132c6f4c9d42a23b31b6686c05073bde5"}, - {file = "coverage-7.6.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c30e42ea11badb147f0d2e387115b15e2bd8205a5ad70d6ad79cf37f6ac08c91"}, - {file = "coverage-7.6.3-cp313-cp313t-win32.whl", hash = "sha256:365defc257c687ce3e7d275f39738dcd230777424117a6c76043459db131dd43"}, - {file = "coverage-7.6.3-cp313-cp313t-win_amd64.whl", hash = "sha256:23bb63ae3f4c645d2d82fa22697364b0046fbafb6261b258a58587441c5f7bd0"}, - {file = "coverage-7.6.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:da29ceabe3025a1e5a5aeeb331c5b1af686daab4ff0fb4f83df18b1180ea83e2"}, - {file = "coverage-7.6.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df8c05a0f574d480947cba11b947dc41b1265d721c3777881da2fb8d3a1ddfba"}, - {file = "coverage-7.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec1e3b40b82236d100d259854840555469fad4db64f669ab817279eb95cd535c"}, - {file = "coverage-7.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b4adeb878a374126f1e5cf03b87f66279f479e01af0e9a654cf6d1509af46c40"}, - {file = "coverage-7.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43d6a66e33b1455b98fc7312b124296dad97a2e191c80320587234a77b1b736e"}, - {file = "coverage-7.6.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1990b1f4e2c402beb317840030bb9f1b6a363f86e14e21b4212e618acdfce7f6"}, - {file = "coverage-7.6.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:12f9515d875859faedb4144fd38694a761cd2a61ef9603bf887b13956d0bbfbb"}, - {file = "coverage-7.6.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:99ded130555c021d99729fabd4ddb91a6f4cc0707df4b1daf912c7850c373b13"}, - {file = "coverage-7.6.3-cp39-cp39-win32.whl", hash = "sha256:c3a79f56dee9136084cf84a6c7c4341427ef36e05ae6415bf7d787c96ff5eaa3"}, - {file = "coverage-7.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:aac7501ae73d4a02f4b7ac8fcb9dc55342ca98ffb9ed9f2dfb8a25d53eda0e4d"}, - {file = "coverage-7.6.3-pp39.pp310-none-any.whl", hash = "sha256:b9853509b4bf57ba7b1f99b9d866c422c9c5248799ab20e652bbb8a184a38181"}, - {file = "coverage-7.6.3.tar.gz", hash = "sha256:bb7d5fe92bd0dc235f63ebe9f8c6e0884f7360f88f3411bfed1350c872ef2054"}, + {file = "coverage-7.6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f8ae553cba74085db385d489c7a792ad66f7f9ba2ee85bfa508aeb84cf0ba07"}, + {file = "coverage-7.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8165b796df0bd42e10527a3f493c592ba494f16ef3c8b531288e3d0d72c1f6f0"}, + {file = "coverage-7.6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c8b95bf47db6d19096a5e052ffca0a05f335bc63cef281a6e8fe864d450a72"}, + {file = "coverage-7.6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ed9281d1b52628e81393f5eaee24a45cbd64965f41857559c2b7ff19385df51"}, + {file = "coverage-7.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0809082ee480bb8f7416507538243c8863ac74fd8a5d2485c46f0f7499f2b491"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d541423cdd416b78626b55f123412fcf979d22a2c39fce251b350de38c15c15b"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58809e238a8a12a625c70450b48e8767cff9eb67c62e6154a642b21ddf79baea"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c9b8e184898ed014884ca84c70562b4a82cbc63b044d366fedc68bc2b2f3394a"}, + {file = "coverage-7.6.4-cp310-cp310-win32.whl", hash = "sha256:6bd818b7ea14bc6e1f06e241e8234508b21edf1b242d49831831a9450e2f35fa"}, + {file = "coverage-7.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:06babbb8f4e74b063dbaeb74ad68dfce9186c595a15f11f5d5683f748fa1d172"}, + {file = "coverage-7.6.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:73d2b73584446e66ee633eaad1a56aad577c077f46c35ca3283cd687b7715b0b"}, + {file = "coverage-7.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:51b44306032045b383a7a8a2c13878de375117946d68dcb54308111f39775a25"}, + {file = "coverage-7.6.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3fb02fe73bed561fa12d279a417b432e5b50fe03e8d663d61b3d5990f29546"}, + {file = "coverage-7.6.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed8fe9189d2beb6edc14d3ad19800626e1d9f2d975e436f84e19efb7fa19469b"}, + {file = "coverage-7.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b369ead6527d025a0fe7bd3864e46dbee3aa8f652d48df6174f8d0bac9e26e0e"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ade3ca1e5f0ff46b678b66201f7ff477e8fa11fb537f3b55c3f0568fbfe6e718"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:27fb4a050aaf18772db513091c9c13f6cb94ed40eacdef8dad8411d92d9992db"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f704f0998911abf728a7783799444fcbbe8261c4a6c166f667937ae6a8aa522"}, + {file = "coverage-7.6.4-cp311-cp311-win32.whl", hash = "sha256:29155cd511ee058e260db648b6182c419422a0d2e9a4fa44501898cf918866cf"}, + {file = "coverage-7.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:8902dd6a30173d4ef09954bfcb24b5d7b5190cf14a43170e386979651e09ba19"}, + {file = "coverage-7.6.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:12394842a3a8affa3ba62b0d4ab7e9e210c5e366fbac3e8b2a68636fb19892c2"}, + {file = "coverage-7.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b6b4c83d8e8ea79f27ab80778c19bc037759aea298da4b56621f4474ffeb117"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d5b8007f81b88696d06f7df0cb9af0d3b835fe0c8dbf489bad70b45f0e45613"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b57b768feb866f44eeed9f46975f3d6406380275c5ddfe22f531a2bf187eda27"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5915fcdec0e54ee229926868e9b08586376cae1f5faa9bbaf8faf3561b393d52"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b58c672d14f16ed92a48db984612f5ce3836ae7d72cdd161001cc54512571f2"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2fdef0d83a2d08d69b1f2210a93c416d54e14d9eb398f6ab2f0a209433db19e1"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cf717ee42012be8c0cb205dbbf18ffa9003c4cbf4ad078db47b95e10748eec5"}, + {file = "coverage-7.6.4-cp312-cp312-win32.whl", hash = "sha256:7bb92c539a624cf86296dd0c68cd5cc286c9eef2d0c3b8b192b604ce9de20a17"}, + {file = "coverage-7.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:1032e178b76a4e2b5b32e19d0fd0abbce4b58e77a1ca695820d10e491fa32b08"}, + {file = "coverage-7.6.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:023bf8ee3ec6d35af9c1c6ccc1d18fa69afa1cb29eaac57cb064dbb262a517f9"}, + {file = "coverage-7.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b0ac3d42cb51c4b12df9c5f0dd2f13a4f24f01943627120ec4d293c9181219ba"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8fe4984b431f8621ca53d9380901f62bfb54ff759a1348cd140490ada7b693c"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5fbd612f8a091954a0c8dd4c0b571b973487277d26476f8480bfa4b2a65b5d06"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dacbc52de979f2823a819571f2e3a350a7e36b8cb7484cdb1e289bceaf35305f"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:dab4d16dfef34b185032580e2f2f89253d302facba093d5fa9dbe04f569c4f4b"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:862264b12ebb65ad8d863d51f17758b1684560b66ab02770d4f0baf2ff75da21"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5beb1ee382ad32afe424097de57134175fea3faf847b9af002cc7895be4e2a5a"}, + {file = "coverage-7.6.4-cp313-cp313-win32.whl", hash = "sha256:bf20494da9653f6410213424f5f8ad0ed885e01f7e8e59811f572bdb20b8972e"}, + {file = "coverage-7.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:182e6cd5c040cec0a1c8d415a87b67ed01193ed9ad458ee427741c7d8513d963"}, + {file = "coverage-7.6.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a181e99301a0ae128493a24cfe5cfb5b488c4e0bf2f8702091473d033494d04f"}, + {file = "coverage-7.6.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:df57bdbeffe694e7842092c5e2e0bc80fff7f43379d465f932ef36f027179806"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bcd1069e710600e8e4cf27f65c90c7843fa8edfb4520fb0ccb88894cad08b11"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99b41d18e6b2a48ba949418db48159d7a2e81c5cc290fc934b7d2380515bd0e3"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b1e54712ba3474f34b7ef7a41e65bd9037ad47916ccb1cc78769bae324c01a"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:53d202fd109416ce011578f321460795abfe10bb901b883cafd9b3ef851bacfc"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:c48167910a8f644671de9f2083a23630fbf7a1cb70ce939440cd3328e0919f70"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cc8ff50b50ce532de2fa7a7daae9dd12f0a699bfcd47f20945364e5c31799fef"}, + {file = "coverage-7.6.4-cp313-cp313t-win32.whl", hash = "sha256:b8d3a03d9bfcaf5b0141d07a88456bb6a4c3ce55c080712fec8418ef3610230e"}, + {file = "coverage-7.6.4-cp313-cp313t-win_amd64.whl", hash = "sha256:f3ddf056d3ebcf6ce47bdaf56142af51bb7fad09e4af310241e9db7a3a8022e1"}, + {file = "coverage-7.6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cb7fa111d21a6b55cbf633039f7bc2749e74932e3aa7cb7333f675a58a58bf3"}, + {file = "coverage-7.6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11a223a14e91a4693d2d0755c7a043db43d96a7450b4f356d506c2562c48642c"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a413a096c4cbac202433c850ee43fa326d2e871b24554da8327b01632673a076"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00a1d69c112ff5149cabe60d2e2ee948752c975d95f1e1096742e6077affd376"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f76846299ba5c54d12c91d776d9605ae33f8ae2b9d1d3c3703cf2db1a67f2c0"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:fe439416eb6380de434886b00c859304338f8b19f6f54811984f3420a2e03858"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0294ca37f1ba500667b1aef631e48d875ced93ad5e06fa665a3295bdd1d95111"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6f01ba56b1c0e9d149f9ac85a2f999724895229eb36bd997b61e62999e9b0901"}, + {file = "coverage-7.6.4-cp39-cp39-win32.whl", hash = "sha256:bc66f0bf1d7730a17430a50163bb264ba9ded56739112368ba985ddaa9c3bd09"}, + {file = "coverage-7.6.4-cp39-cp39-win_amd64.whl", hash = "sha256:c481b47f6b5845064c65a7bc78bc0860e635a9b055af0df46fdf1c58cebf8e8f"}, + {file = "coverage-7.6.4-pp39.pp310-none-any.whl", hash = "sha256:3c65d37f3a9ebb703e710befdc489a38683a5b152242664b973a7b7b22348a4e"}, + {file = "coverage-7.6.4.tar.gz", hash = "sha256:29fc0f17b1d3fea332f8001d4558f8214af7f1d87a345f3a133c901d60347c73"}, ] [package.dependencies] @@ -1070,13 +1070,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio [[package]] name = "ipython" -version = "8.28.0" +version = "8.29.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.28.0-py3-none-any.whl", hash = "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35"}, - {file = "ipython-8.28.0.tar.gz", hash = "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a"}, + {file = "ipython-8.29.0-py3-none-any.whl", hash = "sha256:0188a1bd83267192123ccea7f4a8ed0a78910535dbaa3f37671dca76ebd429c8"}, + {file = "ipython-8.29.0.tar.gz", hash = "sha256:40b60e15b22591450eef73e40a027cf77bd652e757523eebc5bd7c7c498290eb"}, ] [package.dependencies] @@ -1558,72 +1558,72 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "markupsafe" -version = "3.0.1" +version = "3.0.2" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.9" files = [ - {file = "MarkupSafe-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:db842712984e91707437461930e6011e60b39136c7331e971952bb30465bc1a1"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3ffb4a8e7d46ed96ae48805746755fadd0909fea2306f93d5d8233ba23dda12a"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67c519635a4f64e495c50e3107d9b4075aec33634272b5db1cde839e07367589"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48488d999ed50ba8d38c581d67e496f955821dc183883550a6fbc7f1aefdc170"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f31ae06f1328595d762c9a2bf29dafd8621c7d3adc130cbb46278079758779ca"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:80fcbf3add8790caddfab6764bde258b5d09aefbe9169c183f88a7410f0f6dea"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3341c043c37d78cc5ae6e3e305e988532b072329639007fd408a476642a89fd6"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cb53e2a99df28eee3b5f4fea166020d3ef9116fdc5764bc5117486e6d1211b25"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-win32.whl", hash = "sha256:db15ce28e1e127a0013dfb8ac243a8e392db8c61eae113337536edb28bdc1f97"}, - {file = "MarkupSafe-3.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:4ffaaac913c3f7345579db4f33b0020db693f302ca5137f106060316761beea9"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:26627785a54a947f6d7336ce5963569b5d75614619e75193bdb4e06e21d447ad"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b954093679d5750495725ea6f88409946d69cfb25ea7b4c846eef5044194f583"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:973a371a55ce9ed333a3a0f8e0bcfae9e0d637711534bcb11e130af2ab9334e7"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:244dbe463d5fb6d7ce161301a03a6fe744dac9072328ba9fc82289238582697b"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d98e66a24497637dd31ccab090b34392dddb1f2f811c4b4cd80c230205c074a3"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ad91738f14eb8da0ff82f2acd0098b6257621410dcbd4df20aaa5b4233d75a50"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7044312a928a66a4c2a22644147bc61a199c1709712069a344a3fb5cfcf16915"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a4792d3b3a6dfafefdf8e937f14906a51bd27025a36f4b188728a73382231d91"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-win32.whl", hash = "sha256:fa7d686ed9883f3d664d39d5a8e74d3c5f63e603c2e3ff0abcba23eac6542635"}, - {file = "MarkupSafe-3.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:9ba25a71ebf05b9bb0e2ae99f8bc08a07ee8e98c612175087112656ca0f5c8bf"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:8ae369e84466aa70f3154ee23c1451fda10a8ee1b63923ce76667e3077f2b0c4"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40f1e10d51c92859765522cbd79c5c8989f40f0419614bcdc5015e7b6bf97fc5"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a4cb365cb49b750bdb60b846b0c0bc49ed62e59a76635095a179d440540c346"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee3941769bd2522fe39222206f6dd97ae83c442a94c90f2b7a25d847d40f4729"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62fada2c942702ef8952754abfc1a9f7658a4d5460fabe95ac7ec2cbe0d02abc"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c2d64fdba74ad16138300815cfdc6ab2f4647e23ced81f59e940d7d4a1469d9"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fb532dd9900381d2e8f48172ddc5a59db4c445a11b9fab40b3b786da40d3b56b"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0f84af7e813784feb4d5e4ff7db633aba6c8ca64a833f61d8e4eade234ef0c38"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-win32.whl", hash = "sha256:cbf445eb5628981a80f54087f9acdbf84f9b7d862756110d172993b9a5ae81aa"}, - {file = "MarkupSafe-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:a10860e00ded1dd0a65b83e717af28845bb7bd16d8ace40fe5531491de76b79f"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e81c52638315ff4ac1b533d427f50bc0afc746deb949210bc85f05d4f15fd772"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:312387403cd40699ab91d50735ea7a507b788091c416dd007eac54434aee51da"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ae99f31f47d849758a687102afdd05bd3d3ff7dbab0a8f1587981b58a76152a"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c97ff7fedf56d86bae92fa0a646ce1a0ec7509a7578e1ed238731ba13aabcd1c"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7420ceda262dbb4b8d839a4ec63d61c261e4e77677ed7c66c99f4e7cb5030dd"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45d42d132cff577c92bfba536aefcfea7e26efb975bd455db4e6602f5c9f45e7"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c8817557d0de9349109acb38b9dd570b03cc5014e8aabf1cbddc6e81005becd"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a54c43d3ec4cf2a39f4387ad044221c66a376e58c0d0e971d47c475ba79c6b5"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-win32.whl", hash = "sha256:c91b394f7601438ff79a4b93d16be92f216adb57d813a78be4446fe0f6bc2d8c"}, - {file = "MarkupSafe-3.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:fe32482b37b4b00c7a52a07211b479653b7fe4f22b2e481b9a9b099d8a430f2f"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:17b2aea42a7280db02ac644db1d634ad47dcc96faf38ab304fe26ba2680d359a"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:852dc840f6d7c985603e60b5deaae1d89c56cb038b577f6b5b8c808c97580f1d"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0778de17cff1acaeccc3ff30cd99a3fd5c50fc58ad3d6c0e0c4c58092b859396"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:800100d45176652ded796134277ecb13640c1a537cad3b8b53da45aa96330453"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d06b24c686a34c86c8c1fba923181eae6b10565e4d80bdd7bc1c8e2f11247aa4"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:33d1c36b90e570ba7785dacd1faaf091203d9942bc036118fab8110a401eb1a8"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:beeebf760a9c1f4c07ef6a53465e8cfa776ea6a2021eda0d0417ec41043fe984"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bbde71a705f8e9e4c3e9e33db69341d040c827c7afa6789b14c6e16776074f5a"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-win32.whl", hash = "sha256:82b5dba6eb1bcc29cc305a18a3c5365d2af06ee71b123216416f7e20d2a84e5b"}, - {file = "MarkupSafe-3.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:730d86af59e0e43ce277bb83970530dd223bf7f2a838e086b50affa6ec5f9295"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4935dd7883f1d50e2ffecca0aa33dc1946a94c8f3fdafb8df5c330e48f71b132"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e9393357f19954248b00bed7c56f29a25c930593a77630c719653d51e7669c2a"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40621d60d0e58aa573b68ac5e2d6b20d44392878e0bfc159012a5787c4e35bc8"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f94190df587738280d544971500b9cafc9b950d32efcb1fba9ac10d84e6aa4e6"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6a387d61fe41cdf7ea95b38e9af11cfb1a63499af2759444b99185c4ab33f5b"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8ad4ad1429cd4f315f32ef263c1342166695fad76c100c5d979c45d5570ed58b"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e24bfe89c6ac4c31792793ad9f861b8f6dc4546ac6dc8f1c9083c7c4f2b335cd"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2a4b34a8d14649315c4bc26bbfa352663eb51d146e35eef231dd739d54a5430a"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-win32.whl", hash = "sha256:242d6860f1fd9191aef5fae22b51c5c19767f93fb9ead4d21924e0bcb17619d8"}, - {file = "MarkupSafe-3.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:93e8248d650e7e9d49e8251f883eed60ecbc0e8ffd6349e18550925e31bd029b"}, - {file = "markupsafe-3.0.1.tar.gz", hash = "sha256:3e683ee4f5d0fa2dde4db77ed8dd8a876686e3fc417655c2ece9a90576905344"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a"}, + {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] [[package]] @@ -1703,6 +1703,35 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "maturin" +version = "1.7.4" +description = "Build and publish crates with pyo3, cffi and uniffi bindings as well as rust binaries as python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "maturin-1.7.4-py3-none-linux_armv6l.whl", hash = "sha256:eb7b7753b733ae302c08f80bca7b0c3fda1eea665c2b1922c58795f35a54c833"}, + {file = "maturin-1.7.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0182a9638399c8835afd39d2aeacf56908e37cba3f7abb15816b9df6774fab81"}, + {file = "maturin-1.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:41a29c5b23f3ebdfe7633637e3de256579a1b2700c04cd68c16ed46934440c5a"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:23fae44e345a2da5cb391ae878726fb793394826e2f97febe41710bd4099460e"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:8b441521c151f0dbe70ed06fb1feb29b855d787bda038ff4330ca962e5d56641"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:7ccb66d0c5297cf06652c5f72cb398f447d3a332eccf5d1e73b3fe14dbc9498c"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:71f668f19e719048605dbca6a1f4d0dc03b987c922ad9c4bf5be03b9b278e4c3"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:c179fcb2b494f19186781b667320e43d95b3e71fcb1c98fffad9ef6bd6e276b3"}, + {file = "maturin-1.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd5b4b95286f2f376437340f8a4908f4761587212170263084455be8099099a7"}, + {file = "maturin-1.7.4-py3-none-win32.whl", hash = "sha256:35487a424467d1fda4567cbb02d21f09febb10eda22f5fd647b130bc0767dc61"}, + {file = "maturin-1.7.4-py3-none-win_amd64.whl", hash = "sha256:f70c1c8ec9bd4749a53c0f3ae8fdbb326ce45be4f1c5551985ee25a6d7150328"}, + {file = "maturin-1.7.4-py3-none-win_arm64.whl", hash = "sha256:f3d38a6d0c7fd7b04bec30dd470b2173cf9bd184ab6220c1acaf49df6b48faf5"}, + {file = "maturin-1.7.4.tar.gz", hash = "sha256:2b349d742a07527d236f0b4b6cab26f53ebecad0ceabfc09ec4c6a396e3176f9"}, +] + +[package.dependencies] +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + +[package.extras] +patchelf = ["patchelf"] +zig = ["ziglang (>=0.10.0,<0.13.0)"] + [[package]] name = "mdit-py-plugins" version = "0.4.2" @@ -2144,95 +2173,90 @@ ptyprocess = ">=0.5" [[package]] name = "pillow" -version = "10.4.0" +version = "11.0.0" description = "Python Imaging Library (Fork)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e"}, - {file = "pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46"}, - {file = "pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984"}, - {file = "pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141"}, - {file = "pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696"}, - {file = "pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496"}, - {file = "pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91"}, - {file = "pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, - {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, - {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, - {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, - {file = "pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, - {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, - {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0"}, - {file = "pillow-10.4.0-cp38-cp38-win32.whl", hash = "sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e"}, - {file = "pillow-10.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df"}, - {file = "pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef"}, - {file = "pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5"}, - {file = "pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3"}, - {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, -] - -[package.extras] -docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] + {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, + {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, + {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, + {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, + {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, + {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = "sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, + {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, + {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, + {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, + {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, + {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] @@ -2286,32 +2310,33 @@ wcwidth = "*" [[package]] name = "psutil" -version = "6.0.0" +version = "6.1.0" description = "Cross-platform lib for process and system monitoring in Python." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, - {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, - {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, - {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, - {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, - {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, - {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, - {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, - {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, - {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = "sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, ] [package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] [[package]] name = "ptyprocess" @@ -2833,13 +2858,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rich" -version = "13.9.2" +version = "13.9.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1"}, - {file = "rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c"}, + {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, + {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, ] [package.dependencies] @@ -3035,23 +3060,23 @@ stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] [[package]] name = "setuptools" -version = "75.1.0" +version = "75.3.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2"}, - {file = "setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538"}, + {file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"}, + {file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"}, ] [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"] [[package]] name = "six" @@ -3363,60 +3388,68 @@ test = ["pytest"] [[package]] name = "sqlalchemy" -version = "2.0.35" +version = "2.0.36" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.35-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:67219632be22f14750f0d1c70e62f204ba69d28f62fd6432ba05ab295853de9b"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4668bd8faf7e5b71c0319407b608f278f279668f358857dbfd10ef1954ac9f90"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb8bea573863762bbf45d1e13f87c2d2fd32cee2dbd50d050f83f87429c9e1ea"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f552023710d4b93d8fb29a91fadf97de89c5926c6bd758897875435f2a939f33"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:016b2e665f778f13d3c438651dd4de244214b527a275e0acf1d44c05bc6026a9"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7befc148de64b6060937231cbff8d01ccf0bfd75aa26383ffdf8d82b12ec04ff"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-win32.whl", hash = "sha256:22b83aed390e3099584b839b93f80a0f4a95ee7f48270c97c90acd40ee646f0b"}, - {file = "SQLAlchemy-2.0.35-cp310-cp310-win_amd64.whl", hash = "sha256:a29762cd3d116585278ffb2e5b8cc311fb095ea278b96feef28d0b423154858e"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e21f66748ab725ade40fa7af8ec8b5019c68ab00b929f6643e1b1af461eddb60"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8a6219108a15fc6d24de499d0d515c7235c617b2540d97116b663dade1a54d62"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042622a5306c23b972192283f4e22372da3b8ddf5f7aac1cc5d9c9b222ab3ff6"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:627dee0c280eea91aed87b20a1f849e9ae2fe719d52cbf847c0e0ea34464b3f7"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4fdcd72a789c1c31ed242fd8c1bcd9ea186a98ee8e5408a50e610edfef980d71"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:89b64cd8898a3a6f642db4eb7b26d1b28a497d4022eccd7717ca066823e9fb01"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-win32.whl", hash = "sha256:6a93c5a0dfe8d34951e8a6f499a9479ffb9258123551fa007fc708ae2ac2bc5e"}, - {file = "SQLAlchemy-2.0.35-cp311-cp311-win_amd64.whl", hash = "sha256:c68fe3fcde03920c46697585620135b4ecfdfc1ed23e75cc2c2ae9f8502c10b8"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eb60b026d8ad0c97917cb81d3662d0b39b8ff1335e3fabb24984c6acd0c900a2"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6921ee01caf375363be5e9ae70d08ce7ca9d7e0e8983183080211a062d299468"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8cdf1a0dbe5ced887a9b127da4ffd7354e9c1a3b9bb330dce84df6b70ccb3a8d"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93a71c8601e823236ac0e5d087e4f397874a421017b3318fd92c0b14acf2b6db"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e04b622bb8a88f10e439084486f2f6349bf4d50605ac3e445869c7ea5cf0fa8c"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1b56961e2d31389aaadf4906d453859f35302b4eb818d34a26fab72596076bb8"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-win32.whl", hash = "sha256:0f9f3f9a3763b9c4deb8c5d09c4cc52ffe49f9876af41cc1b2ad0138878453cf"}, - {file = "SQLAlchemy-2.0.35-cp312-cp312-win_amd64.whl", hash = "sha256:25b0f63e7fcc2a6290cb5f7f5b4fc4047843504983a28856ce9b35d8f7de03cc"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f021d334f2ca692523aaf7bbf7592ceff70c8594fad853416a81d66b35e3abf9"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05c3f58cf91683102f2f0265c0db3bd3892e9eedabe059720492dbaa4f922da1"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:032d979ce77a6c2432653322ba4cbeabf5a6837f704d16fa38b5a05d8e21fa00"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:2e795c2f7d7249b75bb5f479b432a51b59041580d20599d4e112b5f2046437a3"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:cc32b2990fc34380ec2f6195f33a76b6cdaa9eecf09f0c9404b74fc120aef36f"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-win32.whl", hash = "sha256:9509c4123491d0e63fb5e16199e09f8e262066e58903e84615c301dde8fa2e87"}, - {file = "SQLAlchemy-2.0.35-cp37-cp37m-win_amd64.whl", hash = "sha256:3655af10ebcc0f1e4e06c5900bb33e080d6a1fa4228f502121f28a3b1753cde5"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4c31943b61ed8fdd63dfd12ccc919f2bf95eefca133767db6fbbd15da62078ec"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a62dd5d7cc8626a3634208df458c5fe4f21200d96a74d122c83bc2015b333bc1"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0630774b0977804fba4b6bbea6852ab56c14965a2b0c7fc7282c5f7d90a1ae72"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d625eddf7efeba2abfd9c014a22c0f6b3796e0ffb48f5d5ab106568ef01ff5a"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ada603db10bb865bbe591939de854faf2c60f43c9b763e90f653224138f910d9"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c41411e192f8d3ea39ea70e0fae48762cd11a2244e03751a98bd3c0ca9a4e936"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-win32.whl", hash = "sha256:d299797d75cd747e7797b1b41817111406b8b10a4f88b6e8fe5b5e59598b43b0"}, - {file = "SQLAlchemy-2.0.35-cp38-cp38-win_amd64.whl", hash = "sha256:0375a141e1c0878103eb3d719eb6d5aa444b490c96f3fedab8471c7f6ffe70ee"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccae5de2a0140d8be6838c331604f91d6fafd0735dbdcee1ac78fc8fbaba76b4"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2a275a806f73e849e1c309ac11108ea1a14cd7058577aba962cd7190e27c9e3c"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:732e026240cdd1c1b2e3ac515c7a23820430ed94292ce33806a95869c46bd139"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:890da8cd1941fa3dab28c5bac3b9da8502e7e366f895b3b8e500896f12f94d11"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c0d8326269dbf944b9201911b0d9f3dc524d64779a07518199a58384c3d37a44"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b76d63495b0508ab9fc23f8152bac63205d2a704cd009a2b0722f4c8e0cba8e0"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-win32.whl", hash = "sha256:69683e02e8a9de37f17985905a5eca18ad651bf592314b4d3d799029797d0eb3"}, - {file = "SQLAlchemy-2.0.35-cp39-cp39-win_amd64.whl", hash = "sha256:aee110e4ef3c528f3abbc3c2018c121e708938adeeff9006428dd7c8555e9b3f"}, - {file = "SQLAlchemy-2.0.35-py3-none-any.whl", hash = "sha256:2ab3f0336c0387662ce6221ad30ab3a5e6499aab01b9790879b6578fd9b8faa1"}, - {file = "sqlalchemy-2.0.35.tar.gz", hash = "sha256:e11d7ea4d24f0a262bccf9a7cd6284c976c5369dac21db237cff59586045ab9f"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59b8f3adb3971929a3e660337f5dacc5942c2cdb760afcabb2614ffbda9f9f72"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37350015056a553e442ff672c2d20e6f4b6d0b2495691fa239d8aa18bb3bc908"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8318f4776c85abc3f40ab185e388bee7a6ea99e7fa3a30686580b209eaa35c08"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c245b1fbade9c35e5bd3b64270ab49ce990369018289ecfde3f9c318411aaa07"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:69f93723edbca7342624d09f6704e7126b152eaed3cdbb634cb657a54332a3c5"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f9511d8dd4a6e9271d07d150fb2f81874a3c8c95e11ff9af3a2dfc35fe42ee44"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win32.whl", hash = "sha256:c3f3631693003d8e585d4200730616b78fafd5a01ef8b698f6967da5c605b3fa"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win_amd64.whl", hash = "sha256:a86bfab2ef46d63300c0f06936bd6e6c0105faa11d509083ba8f2f9d237fb5b5"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fd3a55deef00f689ce931d4d1b23fa9f04c880a48ee97af488fd215cf24e2a6c"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f5e9cd989b45b73bd359f693b935364f7e1f79486e29015813c338450aa5a71"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ddd9db6e59c44875211bc4c7953a9f6638b937b0a88ae6d09eb46cced54eff"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2519f3a5d0517fc159afab1015e54bb81b4406c278749779be57a569d8d1bb0d"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59b1ee96617135f6e1d6f275bbe988f419c5178016f3d41d3c0abb0c819f75bb"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:39769a115f730d683b0eb7b694db9789267bcd027326cccc3125e862eb03bfd8"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win32.whl", hash = "sha256:66bffbad8d6271bb1cc2f9a4ea4f86f80fe5e2e3e501a5ae2a3dc6a76e604e6f"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win_amd64.whl", hash = "sha256:23623166bfefe1487d81b698c423f8678e80df8b54614c2bf4b4cfcd7c711959"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b64e6ec3f02c35647be6b4851008b26cff592a95ecb13b6788a54ef80bbdd4"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46331b00096a6db1fdc052d55b101dbbfc99155a548e20a0e4a8e5e4d1362855"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdf3386a801ea5aba17c6410dd1dc8d39cf454ca2565541b5ac42a84e1e28f53"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9dfa18ff2a67b09b372d5db8743c27966abf0e5344c555d86cc7199f7ad83a"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:90812a8933df713fdf748b355527e3af257a11e415b613dd794512461eb8a686"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bc330d9d29c7f06f003ab10e1eaced295e87940405afe1b110f2eb93a233588"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win32.whl", hash = "sha256:79d2e78abc26d871875b419e1fd3c0bca31a1cb0043277d0d850014599626c2e"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win_amd64.whl", hash = "sha256:b544ad1935a8541d177cb402948b94e871067656b3a0b9e91dbec136b06a2ff5"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5cc79df7f4bc3d11e4b542596c03826063092611e481fcf1c9dfee3c94355ef"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c01117dd36800f2ecaa238c65365b7b16497adc1522bf84906e5710ee9ba0e8"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bc633f4ee4b4c46e7adcb3a9b5ec083bf1d9a97c1d3854b92749d935de40b9b"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e46ed38affdfc95d2c958de328d037d87801cfcbea6d421000859e9789e61c2"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b2985c0b06e989c043f1dc09d4fe89e1616aadd35392aea2844f0458a989eacf"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a121d62ebe7d26fec9155f83f8be5189ef1405f5973ea4874a26fab9f1e262c"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win32.whl", hash = "sha256:0572f4bd6f94752167adfd7c1bed84f4b240ee6203a95e05d1e208d488d0d436"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win_amd64.whl", hash = "sha256:8c78ac40bde930c60e0f78b3cd184c580f89456dd87fc08f9e3ee3ce8765ce88"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:be9812b766cad94a25bc63bec11f88c4ad3629a0cec1cd5d4ba48dc23860486b"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aae840ebbd6cdd41af1c14590e5741665e5272d2fee999306673a1bb1fdb4d"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4557e1f11c5f653ebfdd924f3f9d5ebfc718283b0b9beebaa5dd6b77ec290971"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07b441f7d03b9a66299ce7ccf3ef2900abc81c0db434f42a5694a37bd73870f2"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:28120ef39c92c2dd60f2721af9328479516844c6b550b077ca450c7d7dc68575"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win32.whl", hash = "sha256:b81ee3d84803fd42d0b154cb6892ae57ea6b7c55d8359a02379965706c7efe6c"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:f942a799516184c855e1a32fbc7b29d7e571b52612647866d4ec1c3242578fcb"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3d6718667da04294d7df1670d70eeddd414f313738d20a6f1d1f379e3139a545"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72c28b84b174ce8af8504ca28ae9347d317f9dba3999e5981a3cd441f3712e24"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11d0cfdd2b095e7b0686cf5fabeb9c67fae5b06d265d8180715b8cfa86522e3"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e32092c47011d113dc01ab3e1d3ce9f006a47223b18422c5c0d150af13a00687"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6a440293d802d3011028e14e4226da1434b373cbaf4a4bbb63f845761a708346"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c54a1e53a0c308a8e8a7dffb59097bff7facda27c70c286f005327f21b2bd6b1"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win32.whl", hash = "sha256:1e0d612a17581b6616ff03c8e3d5eff7452f34655c901f75d62bd86449d9750e"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:8958b10490125124463095bbdadda5aa22ec799f91958e410438ad6c97a7b793"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc022184d3e5cacc9579e41805a681187650e170eb2fd70e28b86192a479dcaa"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b817d41d692bf286abc181f8af476c4fbef3fd05e798777492618378448ee689"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e46a888b54be23d03a89be510f24a7652fe6ff660787b96cd0e57a4ebcb46d"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ae3005ed83f5967f961fd091f2f8c5329161f69ce8480aa8168b2d7fe37f06"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03e08af7a5f9386a43919eda9de33ffda16b44eb11f3b313e6822243770e9763"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3dbb986bad3ed5ceaf090200eba750b5245150bd97d3e67343a3cfed06feecf7"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win32.whl", hash = "sha256:9fe53b404f24789b5ea9003fc25b9a3988feddebd7e7b369c8fac27ad6f52f28"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win_amd64.whl", hash = "sha256:af148a33ff0349f53512a049c6406923e4e02bf2f26c5fb285f143faf4f0e46a"}, + {file = "SQLAlchemy-2.0.36-py3-none-any.whl", hash = "sha256:fddbe92b4760c6f5d48162aef14824add991aeda8ddadb3c31d56eb15ca69f8e"}, + {file = "sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5"}, ] [package.dependencies] @@ -3429,7 +3462,7 @@ aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] -mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] mssql = ["pyodbc"] mssql-pymssql = ["pymssql"] mssql-pyodbc = ["pyodbc"] @@ -3564,13 +3597,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.5" +version = "4.66.6" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, - {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, + {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, + {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, ] [package.dependencies] @@ -3693,13 +3726,13 @@ files = [ [[package]] name = "xarray" -version = "2024.9.0" +version = "2024.10.0" description = "N-D labeled arrays and datasets in Python" optional = false python-versions = ">=3.10" files = [ - {file = "xarray-2024.9.0-py3-none-any.whl", hash = "sha256:4fd534abdf12d5fa75dd566c56483d5081f77864462cf3d6ad53e13f9db48222"}, - {file = "xarray-2024.9.0.tar.gz", hash = "sha256:e796a6b3eaec11da24f33e4bb14af41897011660a0516fa4037d3ae4bbd1d378"}, + {file = "xarray-2024.10.0-py3-none-any.whl", hash = "sha256:ae1d38cb44a0324dfb61e492394158ae22389bf7de9f3c174309c17376df63a0"}, + {file = "xarray-2024.10.0.tar.gz", hash = "sha256:e369e2bac430e418c2448e5b96f07da4635f98c1319aa23cfeb3fbcb9a01d2e0"}, ] [package.dependencies] @@ -3708,12 +3741,13 @@ packaging = ">=23.1" pandas = ">=2.1" [package.extras] -accel = ["bottleneck", "flox", "numbagg", "opt-einsum", "scipy"] -complete = ["xarray[accel,dev,io,parallel,viz]"] -dev = ["hypothesis", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-env", "pytest-timeout", "pytest-xdist", "ruff", "xarray[complete]"] +accel = ["bottleneck", "flox", "numba (>=0.54)", "numbagg", "opt-einsum", "scipy"] +complete = ["xarray[accel,etc,io,parallel,viz]"] +dev = ["hypothesis", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-env", "pytest-timeout", "pytest-xdist", "ruff", "sphinx", "sphinx-autosummary-accessors", "xarray[complete]"] +etc = ["sparse"] io = ["cftime", "fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zarr"] parallel = ["dask[complete]"] -viz = ["matplotlib", "nc-time-axis", "seaborn"] +viz = ["cartopy", "matplotlib", "nc-time-axis", "seaborn"] [[package]] name = "xarray-einstats" @@ -3770,4 +3804,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "57f2c5fa5800793a25d89b3e07908b0b594841c110dacda450fda9aac66a5fd7" +content-hash = "8ad3b512644f5552c6f4c5577959a9fd770b111f55e41ecd54f876f311869d67" From ee5ddea4ae3b49f7bbc519f4ddfe63813ed556a8 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 12:41:56 +0100 Subject: [PATCH 19/20] fix github action --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8fefe136d..526f50ed2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,7 @@ jobs: # Step 7: Install the Rust package - name: Build and Install the Package - run: maturin develop + run: poetry run maturin develop # Step 8: Run Tests and Generate Coverage Report - name: Run tests and coverage From b7c28ed78711a2887ade779833924521cf4618cf Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 29 Oct 2024 13:16:47 +0100 Subject: [PATCH 20/20] publish package on release --- .github/workflows/pypi.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 3baf212c5..9fc89e407 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -18,6 +18,12 @@ jobs: curl -sSL https://install.python-poetry.org | python3 - echo "$HOME/.local/bin" >> $GITHUB_PATH poetry self add "poetry-dynamic-versioning[plugin]" + - name: Install the package + run: | + poetry install --with dev + - name: Build wheels + run: | + poetry run maturin build --release - name: Build a binary wheel and a source tarball run: >- poetry build