Implementation of Parallel Merkle Tree #125

Merged on Dec 29, 2023 (24 commits)

Commits
f76ac13
template for parallel implementation
Dec 14, 2023
63dc5fd
added a maybe useful snippet of rayon mut_par_iter as comment in the …
Dec 14, 2023
6c745c7
to do: try build and run, then test
Dec 14, 2023
fcf8843
solved build issues
Dec 18, 2023
3d11411
solved bug where in new_with_leaf_digest_parallel the index of non le…
Dec 18, 2023
4fec1c5
modified mod.rs to build function only if feature is enabled rather t…
Dec 19, 2023
571b651
adapted unit tests for parallel merkle_tree in merkle_tree/tests/mod.rs
Dec 20, 2023
f44806f
implemented benches for merkle_tree
Dec 20, 2023
4ebc098
modified field merkle tree bench
Dec 20, 2023
5eecd5a
fmt
WizardOfMenlo Dec 21, 2023
0056600
cleanup
WizardOfMenlo Dec 21, 2023
f888171
Fix compilation issue
WizardOfMenlo Dec 21, 2023
4b6be6c
Move to SHA256
WizardOfMenlo Dec 21, 2023
a37041b
More realistic benches
WizardOfMenlo Dec 21, 2023
278445e
spacing in new (parallel)
intx4 Dec 21, 2023
efc0d61
refactored code inside merkle_tree/mod.rs
intx4 Dec 23, 2023
007aeef
refactored merkle_tree/mod.rs with cfg_into_iter macro. Refactored ac…
intx4 Dec 24, 2023
f9ae45a
refactored merkle_tree benches and tests with cfg_iter macro
intx4 Dec 24, 2023
a7684e9
modified merkle_tree/mod.rs new_with_leaf_digest to use macro cfg_ite…
intx4 Dec 24, 2023
8f0a6d4
Make error send
Pratyush Dec 28, 2023
a781861
Parallel iteration fixes
Pratyush Dec 28, 2023
c17337f
Fixes
Pratyush Dec 28, 2023
93e0e8d
Change signature to enable correct iteration
Pratyush Dec 28, 2023
bf076db
Change back to taking `Vec<LeafDigest>`
Pratyush Dec 28, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -9,4 +9,4 @@ Cargo.lock
params
*.swp
*.swo

.vscode
6 changes: 6 additions & 0 deletions Cargo.toml
@@ -108,6 +108,12 @@ path = "benches/signature.rs"
harness = false
required-features = [ "signature" ]

[[bench]]
name = "merkle_tree"
path = "benches/merkle_tree.rs"
harness = false
required-features = [ "merkle_tree" ]

[patch.crates-io]
ark-r1cs-std = { git = "https://github.com/arkworks-rs/r1cs-std/" }
ark-ff = { git = "https://github.com/arkworks-rs/algebra/" }
66 changes: 66 additions & 0 deletions benches/merkle_tree.rs
@@ -0,0 +1,66 @@
#[macro_use]
extern crate criterion;

static NUM_LEAVES: i32 = 1 << 20;

mod bytes_mt_benches {
use ark_crypto_primitives::crh::*;
use ark_crypto_primitives::merkle_tree::*;
use ark_crypto_primitives::to_uncompressed_bytes;
use ark_ff::BigInteger256;
use ark_serialize::CanonicalSerialize;
use ark_std::{test_rng, UniformRand};
use criterion::Criterion;
use std::borrow::Borrow;

use crate::NUM_LEAVES;

type LeafH = sha2::Sha256;
type CompressH = sha2::Sha256;

struct Sha256MerkleTreeParams;

impl Config for Sha256MerkleTreeParams {
type Leaf = [u8];

type LeafDigest = <LeafH as CRHScheme>::Output;
type LeafInnerDigestConverter = ByteDigestConverter<Self::LeafDigest>;
type InnerDigest = <CompressH as TwoToOneCRHScheme>::Output;

type LeafHash = LeafH;
type TwoToOneHash = CompressH;
}
type Sha256MerkleTree = MerkleTree<Sha256MerkleTreeParams>;

pub fn merkle_tree_create(c: &mut Criterion) {
let mut rng = test_rng();
let leaves: Vec<_> = (0..NUM_LEAVES)
.map(|_| {
let rnd = BigInteger256::rand(&mut rng);
to_uncompressed_bytes!(rnd).unwrap()
})
.collect();
let leaf_crh_params = <LeafH as CRHScheme>::setup(&mut rng).unwrap();
let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(&mut rng)
.unwrap()
.clone();
c.bench_function("Merkle Tree Create (Leaves as [u8])", move |b| {
b.iter(|| {
Sha256MerkleTree::new(
&leaf_crh_params.clone(),
&two_to_one_params.clone(),
&leaves,
)
.unwrap();
})
});
}

criterion_group! {
name = mt_create;
config = Criterion::default().sample_size(10);
targets = merkle_tree_create
}
}

criterion_main!(crate::bytes_mt_benches::mt_create,);
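For orientation, here is a minimal sketch (not part of the diff) of how the benchmarked tree type is used outside the Criterion harness. It assumes the crate's existing `generate_proof`, `Path::verify`, and `root` methods; the `parallel` feature only changes how `new` hashes leaves internally, not the call site.

```rust
use ark_crypto_primitives::crh::{CRHScheme, TwoToOneCRHScheme};
use ark_crypto_primitives::merkle_tree::{ByteDigestConverter, Config, MerkleTree};
use ark_std::test_rng;

type LeafH = sha2::Sha256;
type CompressH = sha2::Sha256;

struct Sha256MerkleTreeParams;
impl Config for Sha256MerkleTreeParams {
    type Leaf = [u8];
    type LeafDigest = <LeafH as CRHScheme>::Output;
    type LeafInnerDigestConverter = ByteDigestConverter<Self::LeafDigest>;
    type InnerDigest = <CompressH as TwoToOneCRHScheme>::Output;
    type LeafHash = LeafH;
    type TwoToOneHash = CompressH;
}

fn main() -> Result<(), ark_crypto_primitives::Error> {
    let mut rng = test_rng();
    let leaf_crh_params = <LeafH as CRHScheme>::setup(&mut rng)?;
    let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(&mut rng)?;

    // Four toy leaves; `Vec<u8>: AsRef<[u8]> + Send`, so the vector can be
    // passed to `new` directly under both the serial and parallel builds.
    let leaves: Vec<Vec<u8>> = (0u8..4).map(|i| vec![i; 32]).collect();

    let tree = MerkleTree::<Sha256MerkleTreeParams>::new(
        &leaf_crh_params,
        &two_to_one_params,
        leaves.clone(),
    )?;
    let proof = tree.generate_proof(2)?;
    assert!(proof.verify(
        &leaf_crh_params,
        &two_to_one_params,
        &tree.root(),
        leaves[2].as_slice()
    )?);
    Ok(())
}
```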
4 changes: 2 additions & 2 deletions src/crh/bowe_hopwood/mod.rs
@@ -17,7 +17,7 @@ use ark_ec::{
twisted_edwards::Projective as TEProjective, twisted_edwards::TECurveConfig, AdditiveGroup,
CurveGroup,
};
use ark_ff::{biginteger::BigInteger, fields::PrimeField};
use ark_ff::fields::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::borrow::Borrow;
use ark_std::cfg_chunks;
@@ -82,7 +82,7 @@ impl<P: TECurveConfig, W: pedersen::Window> CRHScheme for CRH<P, W> {
let mut c = 0;
let mut range = F::BigInt::from(2_u64);
while range < upper_limit {
range.muln(4);
range <<= 4;
c += 1;
}

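The `range.muln(4)` to `range <<= 4` change is purely a notation update: both multiply the big integer by 2^4. A plain-integer illustration (not arkworks code):

```rust
fn main() {
    // `muln(4)` multiplied by 2^4; a left shift by 4 bits does the same.
    let mut range: u128 = 2;
    range <<= 4;
    assert_eq!(range, 2 * 16);
}
```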
6 changes: 3 additions & 3 deletions src/crh/injective_map/mod.rs
@@ -1,4 +1,4 @@
use crate::{CryptoError, Error};
use crate::Error;
use ark_std::rand::Rng;
use ark_std::{fmt::Debug, hash::Hash, marker::PhantomData};

@@ -16,15 +16,15 @@ pub mod constraints;
pub trait InjectiveMap<C: CurveGroup> {
type Output: Clone + Eq + Hash + Default + Debug + CanonicalSerialize + CanonicalDeserialize;

fn injective_map(ge: &C::Affine) -> Result<Self::Output, CryptoError>;
fn injective_map(ge: &C::Affine) -> Result<Self::Output, Error>;
}

pub struct TECompressor;

impl<P: TECurveConfig> InjectiveMap<TEProjective<P>> for TECompressor {
type Output = <P as CurveConfig>::BaseField;

fn injective_map(ge: &TEAffine<P>) -> Result<Self::Output, CryptoError> {
fn injective_map(ge: &TEAffine<P>) -> Result<Self::Output, Error> {
debug_assert!(ge.is_in_correct_subgroup_assuming_on_curve());
Ok(ge.x)
}
2 changes: 1 addition & 1 deletion src/crh/mod.rs
@@ -21,7 +21,7 @@ pub use constraints::*;
/// Interface to CRH. Note that in this release, while all implementations of `CRH` have fixed length,
/// variable length CRH may also implement this trait in future.
pub trait CRHScheme {
type Input: ?Sized;
type Input: ?Sized + Send;
type Output: Clone
+ Eq
+ core::fmt::Debug
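The new `Send` bound on `Input` (and, later in this diff, on the digest types) is what lets leaves and digests be moved across rayon worker threads. A standalone sketch of the requirement, assuming only the `rayon` crate; the generic names here are illustrative, not the crate's API:

```rust
use rayon::prelude::*;

// `Vec<L>: IntoParallelIterator` requires `L: Send`, and collecting the
// results requires the output type to be `Send` as well, mirroring the
// bounds added to `CRHScheme::Input` and the digest types in this PR.
fn digest_all<L, D, F>(leaves: Vec<L>, hash: F) -> Vec<D>
where
    L: Send,
    D: Send,
    F: Fn(L) -> D + Sync + Send,
{
    leaves.into_par_iter().map(hash).collect()
}

fn main() {
    let digests = digest_all(vec![1u64, 2, 3], |x| x.wrapping_mul(31));
    assert_eq!(digests, vec![31, 62, 93]);
}
```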
27 changes: 17 additions & 10 deletions src/lib.rs
@@ -43,22 +43,29 @@ pub mod snark;
#[cfg(feature = "sponge")]
pub mod sponge;

pub type Error = Box<dyn ark_std::error::Error>;

#[derive(Debug)]
pub enum CryptoError {
pub enum Error {
IncorrectInputLength(usize),
NotPrimeOrder,
GenericError(Box<dyn ark_std::error::Error + Send>),
SerializationError(ark_serialize::SerializationError),
}

impl core::fmt::Display for CryptoError {
impl core::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let msg = match self {
CryptoError::IncorrectInputLength(len) => format!("input length is wrong: {}", len),
CryptoError::NotPrimeOrder => "element is not prime order".to_owned(),
};
write!(f, "{}", msg)
match self {
Self::IncorrectInputLength(len) => write!(f, "incorrect input length: {len}"),
Self::NotPrimeOrder => write!(f, "element is not prime order"),
Self::GenericError(e) => write!(f, "{e}"),
Self::SerializationError(e) => write!(f, "{e}"),
}
}
}

impl ark_std::error::Error for CryptoError {}
impl ark_std::error::Error for Error {}

impl From<ark_serialize::SerializationError> for Error {
fn from(e: ark_serialize::SerializationError) -> Self {
Self::SerializationError(e)
}
}
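Context for why `CryptoError` was folded into a concrete `Error` enum (see the "Make error send" commit): collecting a `Result` out of a rayon parallel iterator, as the new tree construction does, requires the error type to be `Send`, which the old `Box<dyn Error>` alias is not. A standalone sketch of that constraint, assuming only the `rayon` crate:

```rust
use rayon::prelude::*;

// A concrete error enum is automatically `Send`, so it can cross rayon's
// worker-thread boundary; a plain `Box<dyn Error>` cannot.
#[derive(Debug)]
enum Error {
    IncorrectInputLength(usize),
}

fn hash_all(inputs: &[Vec<u8>]) -> Result<Vec<u64>, Error> {
    inputs
        .par_iter()
        .map(|x| {
            if x.len() > 32 {
                return Err(Error::IncorrectInputLength(x.len()));
            }
            Ok(x.iter().map(|&b| u64::from(b)).sum())
        })
        .collect() // collecting into `Result` needs `Error: Send`
}

fn main() {
    assert_eq!(hash_all(&[vec![1, 2, 3], vec![4]]).unwrap(), vec![6, 4]);
}
```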
115 changes: 73 additions & 42 deletions src/merkle_tree/mod.rs
@@ -14,6 +14,9 @@ mod tests;
#[cfg(feature = "r1cs")]
pub mod constraints;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Convert the hash digest in different layers by converting previous layer's output to
/// `TargetType`, which is a `Borrow` to next layer's input.
pub trait DigestConverter<From, To: ?Sized> {
@@ -52,15 +55,16 @@ impl<T: CanonicalSerialize> DigestConverter<T, [u8]> for ByteDigestConverter<T>
/// * `LeafHash`: Convert leaf to leaf digest
/// * `TwoToOneHash`: Compress two inner digests to one inner digest
pub trait Config {
type Leaf: ?Sized; // merkle tree does not store the leaf
// leaf layer
type Leaf: ?Sized + Send; // merkle tree does not store the leaf
// leaf layer
type LeafDigest: Clone
+ Eq
+ core::fmt::Debug
+ Hash
+ Default
+ CanonicalSerialize
+ CanonicalDeserialize;
+ CanonicalDeserialize
+ Send;
// transition between leaf layer to inner layer
type LeafInnerDigestConverter: DigestConverter<
Self::LeafDigest,
@@ -73,7 +77,8 @@ pub trait Config {
+ Hash
+ Default
+ CanonicalSerialize
+ CanonicalDeserialize;
+ CanonicalDeserialize
+ Send;

// Tom's Note: in the future, if we want different hash function, we can simply add more
// types of digest here and specify a digest converter. Same for constraints.
@@ -229,32 +234,30 @@ impl<P: Config> MerkleTree<P> {
height: usize,
) -> Result<Self, crate::Error> {
// use empty leaf digest
let leaves_digest = vec![P::LeafDigest::default(); 1 << (height - 1)];
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digest)
let leaf_digests = vec![P::LeafDigest::default(); 1 << (height - 1)];
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests)
}

/// Returns a new merkle tree. `leaves.len()` should be power of two.
pub fn new<L: Borrow<P::Leaf>>(
pub fn new<L: AsRef<P::Leaf> + Send>(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves: impl IntoIterator<Item = L>,
#[cfg(not(feature = "parallel"))] leaves: impl IntoIterator<Item = L>,
#[cfg(feature = "parallel")] leaves: impl IntoParallelIterator<Item = L>,
) -> Result<Self, crate::Error> {
let mut leaves_digests = Vec::new();

// compute and store hash values for each leaf
for leaf in leaves.into_iter() {
leaves_digests.push(P::LeafHash::evaluate(leaf_hash_param, leaf)?)
}
let leaf_digests: Vec<_> = cfg_into_iter!(leaves)
.map(|input| P::LeafHash::evaluate(leaf_hash_param, input.as_ref()))
.collect::<Result<Vec<_>, _>>()?;

Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digests)
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests)
}

pub fn new_with_leaf_digest(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves_digest: Vec<P::LeafDigest>,
leaf_digests: Vec<P::LeafDigest>,
) -> Result<Self, crate::Error> {
let leaf_nodes_size = leaves_digest.len();
let leaf_nodes_size = leaf_digests.len();
assert!(
leaf_nodes_size.is_power_of_two() && leaf_nodes_size > 1,
"`leaves.len() should be power of two and greater than one"
@@ -266,7 +269,7 @@
let hash_of_empty: P::InnerDigest = P::InnerDigest::default();

// initialize the merkle tree as array of nodes in level order
let mut non_leaf_nodes: Vec<P::InnerDigest> = (0..non_leaf_nodes_size)
let mut non_leaf_nodes: Vec<P::InnerDigest> = cfg_into_iter!(0..non_leaf_nodes_size)
.map(|_| hash_of_empty.clone())
.collect();

@@ -282,39 +285,67 @@
{
let start_index = level_indices.pop().unwrap();
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
// `left_child(current_index)` and `right_child(current_index) returns the position of
// leaf in the whole tree (represented as a list in level order). We need to shift it
// by `-upper_bound` to get the index in `leaf_nodes` list.
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;
// compute hash
non_leaf_nodes[current_index] = P::TwoToOneHash::evaluate(
&two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(leaves_digest[left_leaf_index].clone())?,
P::LeafInnerDigestConverter::convert(leaves_digest[right_leaf_index].clone())?,
)?
}

cfg_iter_mut!(non_leaf_nodes[start_index..upper_bound])
.enumerate()
.try_for_each(|(i, n)| {
// `left_child(current_index)` and `right_child(current_index)` return the positions of the
// children in the whole tree (represented as a list in level order). We need to shift them
// by `-upper_bound` to get the indices into the list of leaf digests.

// Similarly, we need to offset `i` by `start_index` to get the index
// outside the slice, i.e. in the level-ordered list of nodes.

let current_index = i + start_index;
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;

*n = P::TwoToOneHash::evaluate(
two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(
leaf_digests[left_leaf_index].clone(),
)?,
P::LeafInnerDigestConverter::convert(
leaf_digests[right_leaf_index].clone(),
)?,
)?;
Ok::<(), crate::Error>(())
})?;
}

// compute the hash values for the nodes in the remaining (upper) levels of the tree
level_indices.reverse();
for &start_index in &level_indices {
// The layer beginning `start_index` ends at `upper_bound` (exclusive).
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
let left_index = left_child(current_index);
let right_index = right_child(current_index);
non_leaf_nodes[current_index] = P::TwoToOneHash::compress(
&two_to_one_hash_param,
non_leaf_nodes[left_index].clone(),
non_leaf_nodes[right_index].clone(),
)?
}
}

let (nodes_at_level, nodes_at_prev_level) =
non_leaf_nodes[..].split_at_mut(upper_bound);
// Iterate over the nodes at the current level, and compute the hash of each node
cfg_iter_mut!(nodes_at_level[start_index..])
.enumerate()
.try_for_each(|(i, n)| {
// `left_child(current_index)` and `right_child(current_index)` return the positions of the
// children in the whole tree (represented as a list in level order). We need to shift them
// by `-upper_bound` to get the indices into `nodes_at_prev_level`.

// Similarly, we need to offset `i` by `start_index` to get the index
// outside the slice, i.e. in the level-ordered list of nodes.
let current_index = i + start_index;
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;

// `crate::Error` is `Send`, so it can be propagated out of the parallel closure with `?`
*n = P::TwoToOneHash::compress(
two_to_one_hash_param,
nodes_at_prev_level[left_leaf_index].clone(),
nodes_at_prev_level[right_leaf_index].clone(),
)?;
Ok::<_, crate::Error>(())
})?;
}
Ok(MerkleTree {
leaf_nodes: leaves_digest,
leaf_nodes: leaf_digests,
non_leaf_nodes,
height: tree_height,
leaf_hash_param: leaf_hash_param.clone(),
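For readers unfamiliar with the `cfg_into_iter!`/`cfg_iter_mut!` macros used above: they come from `ark-std` and expand to a rayon parallel iterator when the `parallel` feature is enabled, and to a plain iterator otherwise, so the hashing code is written once. A simplified standalone stand-in of the idea (the real macros live in `ark_std`; this sketch assumes a `parallel` cargo feature and the `rayon` crate):

```rust
#[cfg(feature = "parallel")]
use rayon::prelude::*;

// Simplified stand-in for ark_std's `cfg_into_iter!`: one expression at the
// call site, serial or parallel iteration depending on the feature flag.
macro_rules! cfg_into_iter {
    ($e:expr) => {{
        #[cfg(feature = "parallel")]
        let iter = $e.into_par_iter();
        #[cfg(not(feature = "parallel"))]
        let iter = $e.into_iter();
        iter
    }};
}

fn leaf_digests(leaves: Vec<Vec<u8>>) -> Vec<u64> {
    // In the PR this maps `P::LeafHash::evaluate`; a toy "hash" keeps the
    // sketch self-contained.
    cfg_into_iter!(leaves)
        .map(|leaf| leaf.iter().map(|&b| u64::from(b)).sum())
        .collect()
}

fn main() {
    assert_eq!(leaf_digests(vec![vec![1, 2], vec![3, 4, 5]]), vec![3, 12]);
}
```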
15 changes: 3 additions & 12 deletions src/merkle_tree/tests/constraints.rs
@@ -61,12 +61,8 @@ mod byte_mt_tests {

let leaf_crh_params = <LeafH as CRHScheme>::setup(&mut rng).unwrap();
let two_to_one_crh_params = <CompressH as TwoToOneCRHScheme>::setup(&mut rng).unwrap();
let mut tree = JubJubMerkleTree::new(
&leaf_crh_params,
&two_to_one_crh_params,
leaves.iter().map(|v| v.as_slice()),
)
.unwrap();
let mut tree =
JubJubMerkleTree::new(&leaf_crh_params, &two_to_one_crh_params, leaves).unwrap();
let root = tree.root();
for (i, leaf) in leaves.iter().enumerate() {
let cs = ConstraintSystem::<Fq>::new_ref();
@@ -288,12 +284,7 @@ mod field_mt_tests {
) {
let leaf_crh_params = poseidon_parameters();
let two_to_one_params = leaf_crh_params.clone();
let mut tree = FieldMT::new(
&leaf_crh_params,
&two_to_one_params,
leaves.iter().map(|x| x.as_slice()),
)
.unwrap();
let mut tree = FieldMT::new(&leaf_crh_params, &two_to_one_params, leaves).unwrap();
let root = tree.root();
for (i, leaf) in leaves.iter().enumerate() {
let cs = ConstraintSystem::<F>::new_ref();
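The test simplification above follows directly from the constructor's new `L: AsRef<P::Leaf> + Send` bound: a `Vec<Vec<u8>>` (or `Vec<Vec<F>>`) can now be handed to `new` as-is, without the per-leaf `.as_slice()` mapping. A small standalone sketch of why that works; the function name is illustrative, not the crate's API:

```rust
// With an `AsRef<[u8]>` bound, owned byte vectors coerce to the unsized leaf
// type `[u8]` inside the constructor rather than at every call site.
fn new_from_leaves<L: AsRef<[u8]> + Send>(leaves: Vec<L>) -> usize {
    leaves.iter().map(|leaf| leaf.as_ref().len()).sum()
}

fn main() {
    let leaves: Vec<Vec<u8>> = vec![vec![1, 2, 3], vec![4, 5]];
    // No `leaves.iter().map(|v| v.as_slice())` needed anymore.
    assert_eq!(new_from_leaves(leaves), 5);
}
```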