diff --git a/backend/Cargo.lock b/backend/Cargo.lock index c5d76924..67b7408b 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -2972,9 +2972,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -2982,14 +2982,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] @@ -3795,6 +3793,7 @@ dependencies = [ "num_cpus", "poseidon-rs", "rand 0.8.5", + "rayon", "regex", "serde", "serde_json", diff --git a/zk_prover/Cargo.lock b/zk_prover/Cargo.lock index 014c93cc..1f4e72c3 100644 --- a/zk_prover/Cargo.lock +++ b/zk_prover/Cargo.lock @@ -3371,9 +3371,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -3381,14 +3381,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] @@ -4170,6 +4168,7 @@ dependencies = [ "plotters", "poseidon-rs", "rand 0.8.5", + "rayon", "regex", "serde", "serde_json", diff --git a/zk_prover/Cargo.toml b/zk_prover/Cargo.toml index e66502ac..00483a5a 100644 --- a/zk_prover/Cargo.toml +++ b/zk_prover/Cargo.toml @@ -32,6 +32,7 @@ nova-scotia = { git = "https://github.com/nalinbhardwaj/Nova-Scotia" } poseidon-rs = { git = "https://github.com/arnaucube/poseidon-rs" } ff = {package="ff_ce" , version="0.11", features = ["derive"]} num-traits = "0.2.16" +rayon = "1.8.0" [dev-dependencies] criterion= "0.3" diff --git a/zk_prover/src/merkle_sum_tree/mst.rs b/zk_prover/src/merkle_sum_tree/mst.rs index 0a24aef9..4145f7ed 100644 --- a/zk_prover/src/merkle_sum_tree/mst.rs +++ b/zk_prover/src/merkle_sum_tree/mst.rs @@ -26,7 +26,7 @@ pub struct MerkleSumTree { } impl MerkleSumTree { - pub const MAX_DEPTH: usize = 27; + pub const MAX_DEPTH: usize = 29; /// Builds a Merkle Sum Tree from a CSV file stored at `path`. The CSV file must be formatted as follows: /// diff --git a/zk_prover/src/merkle_sum_tree/utils/build_tree.rs b/zk_prover/src/merkle_sum_tree/utils/build_tree.rs index 66b5860b..c3e1741f 100644 --- a/zk_prover/src/merkle_sum_tree/utils/build_tree.rs +++ b/zk_prover/src/merkle_sum_tree/utils/build_tree.rs @@ -1,6 +1,6 @@ use crate::merkle_sum_tree::{Entry, Node}; use halo2_proofs::halo2curves::bn256::Fr as Fp; -use std::thread; +use rayon::prelude::*; pub fn build_merkle_tree_from_entries( entries: &[Entry], @@ -39,7 +39,7 @@ where build_leaves_level(entries, &mut tree); for level in 1..=depth { - build_middle_level(level, &mut tree, n) + build_middle_level(level, &mut tree) } let root = tree[depth][0].clone(); @@ -53,57 +53,27 @@ fn build_leaves_level( ) where [usize; N_ASSETS + 1]: Sized, { - // Compute the leaves in parallel - let mut handles = vec![]; - let chunk_size = (entries.len() + num_cpus::get() - 1) / num_cpus::get(); - for chunk in entries.chunks(chunk_size) { - let chunk = chunk.to_vec(); - handles.push(thread::spawn(move || { - chunk - .into_iter() - .map(|entry| entry.compute_leaf()) - .collect::>() - })); - } + let results = entries + .par_iter() + .map(|entry| entry.compute_leaf()) + .collect::>(); - let mut index = 0; - for handle in handles { - let result = handle.join().unwrap(); - for leaf in result { - tree[0][index] = leaf; - index += 1; - } + for (index, node) in results.iter().enumerate() { + tree[0][index] = node.clone(); } } -fn build_middle_level( - level: usize, - tree: &mut [Vec>], - n: usize, -) where +fn build_middle_level(level: usize, tree: &mut [Vec>]) +where [usize; 2 * (1 + N_ASSETS)]: Sized, { - let nodes_in_level = (n + (1 << level) - 1) / (1 << level); - - let mut handles = vec![]; - let chunk_size = (nodes_in_level + num_cpus::get() - 1) / num_cpus::get(); - - for chunk in tree[level - 1].chunks(chunk_size * 2) { - let chunk = chunk.to_vec(); - handles.push(thread::spawn(move || { - chunk - .chunks(2) - .map(|pair| Node::middle(&pair[0], &pair[1])) - .collect::>() - })); - } + let results: Vec> = (0..tree[level - 1].len()) + .into_par_iter() + .step_by(2) + .map(|index| Node::middle(&tree[level - 1][index], &tree[level - 1][index + 1])) + .collect(); - let mut index = 0; - for handle in handles { - let result = handle.join().unwrap(); - for node in result { - tree[level][index] = node; - index += 1; - } + for (index, new_node) in results.into_iter().enumerate() { + tree[level][index] = new_node; } }