Skip to content

Commit

Permalink
fix: applied 'rayon' instead of 'thread' for OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
sifnoc committed Oct 26, 2023
1 parent 4187d53 commit 4207b72
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 60 deletions.
11 changes: 5 additions & 6 deletions backend/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions zk_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions zk_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ nova-scotia = { git = "https://github.com/nalinbhardwaj/Nova-Scotia" }
poseidon-rs = { git = "https://github.com/arnaucube/poseidon-rs" }
ff = {package="ff_ce" , version="0.11", features = ["derive"]}
num-traits = "0.2.16"
rayon = "1.8.0"

[dev-dependencies]
criterion= "0.3"
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct MerkleSumTree<const N_ASSETS: usize, const N_BYTES: usize> {
}

impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTES> {
pub const MAX_DEPTH: usize = 27;
pub const MAX_DEPTH: usize = 29;

/// Builds a Merkle Sum Tree from a CSV file stored at `path`. The CSV file must be formatted as follows:
///
Expand Down
64 changes: 17 additions & 47 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::merkle_sum_tree::{Entry, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use std::thread;
use rayon::prelude::*;

pub fn build_merkle_tree_from_entries<const N_ASSETS: usize>(
entries: &[Entry<N_ASSETS>],
Expand Down Expand Up @@ -39,7 +39,7 @@ where
build_leaves_level(entries, &mut tree);

for level in 1..=depth {
build_middle_level(level, &mut tree, n)
build_middle_level(level, &mut tree)
}

let root = tree[depth][0].clone();
Expand All @@ -53,57 +53,27 @@ fn build_leaves_level<const N_ASSETS: usize>(
) where
[usize; N_ASSETS + 1]: Sized,
{
// Compute the leaves in parallel
let mut handles = vec![];
let chunk_size = (entries.len() + num_cpus::get() - 1) / num_cpus::get();
for chunk in entries.chunks(chunk_size) {
let chunk = chunk.to_vec();
handles.push(thread::spawn(move || {
chunk
.into_iter()
.map(|entry| entry.compute_leaf())
.collect::<Vec<_>>()
}));
}
let results = entries
.par_iter()
.map(|entry| entry.compute_leaf())
.collect::<Vec<_>>();

let mut index = 0;
for handle in handles {
let result = handle.join().unwrap();
for leaf in result {
tree[0][index] = leaf;
index += 1;
}
for (index, node) in results.iter().enumerate() {
tree[0][index] = node.clone();
}
}

fn build_middle_level<const N_ASSETS: usize>(
level: usize,
tree: &mut [Vec<Node<N_ASSETS>>],
n: usize,
) where
fn build_middle_level<const N_ASSETS: usize>(level: usize, tree: &mut [Vec<Node<N_ASSETS>>])
where
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let nodes_in_level = (n + (1 << level) - 1) / (1 << level);

let mut handles = vec![];
let chunk_size = (nodes_in_level + num_cpus::get() - 1) / num_cpus::get();

for chunk in tree[level - 1].chunks(chunk_size * 2) {
let chunk = chunk.to_vec();
handles.push(thread::spawn(move || {
chunk
.chunks(2)
.map(|pair| Node::middle(&pair[0], &pair[1]))
.collect::<Vec<_>>()
}));
}
let results: Vec<Node<N_ASSETS>> = (0..tree[level - 1].len())
.into_par_iter()
.step_by(2)
.map(|index| Node::middle(&tree[level - 1][index], &tree[level - 1][index + 1]))
.collect();

let mut index = 0;
for handle in handles {
let result = handle.join().unwrap();
for node in result {
tree[level][index] = node;
index += 1;
}
for (index, new_node) in results.into_iter().enumerate() {
tree[level][index] = new_node;
}
}

0 comments on commit 4207b72

Please sign in to comment.