Skip to content

Commit

Permalink
Added assertion for node and root balance while generating MST
Browse files Browse the repository at this point in the history
  • Loading branch information
sifnoc committed Aug 14, 2024
1 parent 23587ad commit 780a4df
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 25 deletions.
1 change: 1 addition & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
run: |
cd zk_prover
cargo test --release --features dev-graph -- --nocapture
cargo test --package summa-solvency --lib --features "skip-node-balance-check" -- circuits::tests::test::test_balance_not_in_range
test-zk-prover-examples:
runs-on: ubuntu-latest
Expand Down
17 changes: 17 additions & 0 deletions csv/entry_16_max_node_balance.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
username,balance_ETH_ETH,balance_USDT_ETH
dxGaEAii,18446744073709551615,18446744073709551615
MBlfbBGI,0,0
lAhWlEWZ,18446744073709551615,18446744073709551615
nuZweYtO,0,0
gbdSwiuY,18446744073709551615,18446744073709551615
RZNneNuP,0,0
YsscHXkp,0,0
RkLzkDun,0,0
HlQlnEYI,18446744073709551615,18446744073709551615
RqkZOFYe,0,0
NjCSRAfD,0,0
pHniJMQY,0,0
dOGIMzKR,18446744073709551615,18446744073709551615
HfMDmNLp,0,0
xPLKzCBl,0,0
AtwIxZHo,0,2
2 changes: 1 addition & 1 deletion zk_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ edition = "2021"

[features]
dev-graph = ["halo2_proofs/dev-graph", "plotters"]

skip-node-balance-check = []

[dependencies]
halo2_proofs = { git = "https://github.com/summa-dev/halo2"}
Expand Down
6 changes: 6 additions & 0 deletions zk_prover/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ cargo build
cargo test --release --features dev-graph
```

For testing the overflow case that in range check, execute:

```
cargo test --package summa-solvency --lib --features "skip-node-balance-check" -- circuits::tests::test::test_balance_not_in_range
```

## Documentation

The documentation for the circuits can be generated by running
Expand Down
16 changes: 7 additions & 9 deletions zk_prover/examples/gen_inclusion_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ use num_traits::Num;
use prelude::*;

use halo2_solidity_verifier::{compile_solidity, BatchOpenScheme::Bdfg21, SolidityGenerator};
use summa_solvency::circuits::utils::generate_setup_artifacts;
use summa_solvency::circuits::{merkle_sum_tree::MstInclusionCircuit, WithInstances};
use summa_solvency::{
circuits::{
utils::generate_setup_artifacts,
{merkle_sum_tree::MstInclusionCircuit, WithInstances},
},
merkle_sum_tree::utils::calculate_max_root_balance,
};

const LEVELS: usize = 4;
const N_CURRENCIES: usize = 2;
Expand Down Expand Up @@ -52,13 +57,6 @@ fn save_solidity(name: impl AsRef<str>, solidity: &str) {
println!("Saved {path}");
}

// Calculate the maximum value that the Merkle Root can have, given N_BYTES and LEVELS
fn calculate_max_root_balance(n_bytes: usize, n_levels: usize) -> BigInt {
// The max value that can be stored in a leaf node or a sibling node, according to the constraint set in the circuit
let max_leaf_value = BigInt::from(2).pow(n_bytes as u32 * 8) - 1;
max_leaf_value * (n_levels + 1)
}

// Given a combination of `N_BYTES` and `LEVELS`, check if there is a risk of overflow in the Merkle Root
fn is_there_risk_of_overflow(n_bytes: usize, n_levels: usize) -> bool {
// Calculate the max root balance value
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where
{
pub fn init_empty() -> Self {
Self {
entry: Entry::zero_entry(),
entry: Entry::init_empty(),
path_indices: vec![Fp::zero(); LEVELS],
sibling_leaf_node_hash_preimage: [Fp::zero(); N_CURRENCIES + 1],
sibling_middle_node_hash_preimages: vec![[Fp::zero(); N_CURRENCIES + 2]; LEVELS],
Expand Down
1 change: 1 addition & 0 deletions zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ mod test {

// Building a proof using as input a csv file with an entry that is not in range [0, 2^N_BYTES*8 - 1] should fail the range check constraint on the leaf balance
#[test]
#[cfg(feature = "skip-node-balance-check")]
fn test_balance_not_in_range() {
let merkle_sum_tree =
MerkleSumTree::<N_CURRENCIES, N_BYTES>::from_csv("../csv/entry_16_overflow.csv")
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<const N_CURRENCIES: usize> Entry<N_CURRENCIES> {
}

/// Returns a zero entry where the username is 0 and the balances are all 0
pub fn zero_entry() -> Self {
pub fn init_empty() -> Self {
let empty_balances: [BigUint; N_CURRENCIES] = std::array::from_fn(|_| BigUint::from(0u32));

Entry {
Expand Down
4 changes: 2 additions & 2 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES
// Pad the entries with empty entries to make the number of entries equal to 2^depth
if entries.len() < 2usize.pow(depth as u32) {
entries.extend(vec![
Entry::zero_entry();
Entry::init_empty();
2usize.pow(depth as u32) - entries.len()
]);
}

let leaves = build_leaves_from_entries(&entries);

let (root, nodes) = build_merkle_tree_from_leaves(&leaves, depth)?;
let (root, nodes) = build_merkle_tree_from_leaves::<N_CURRENCIES, N_BYTES>(&leaves, depth)?;

Ok(MerkleSumTree {
root,
Expand Down
24 changes: 20 additions & 4 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#[cfg(test)]
mod test {

use crate::merkle_sum_tree::utils::big_uint_to_fp;
use crate::merkle_sum_tree::utils::{
big_uint_to_fp, build_leaves_from_entries, build_merkle_tree_from_leaves,
parse_csv_to_entries,
};
use crate::merkle_sum_tree::{Entry, MerkleSumTree, Node, Tree};
use num_bigint::{BigUint, ToBigUint};
use rand::Rng as _;
Expand Down Expand Up @@ -65,6 +67,20 @@ mod test {
assert!(!merkle_tree.verify_proof(&proof_invalid_2));
}

#[test]
#[should_panic(expected = "Node balance is exceed limit")]
fn test_node_and_root_balance_limit() {
let (_, entries) = parse_csv_to_entries::<&str, N_CURRENCIES, N_BYTES>(
"../csv/entry_16_max_node_balance.csv",
)
.unwrap();

let leaves = build_leaves_from_entries::<N_CURRENCIES>(&entries);
let depth = (entries.len() as f64).log2().ceil() as usize;

let _ = build_merkle_tree_from_leaves::<N_CURRENCIES, N_BYTES>(&leaves, depth);
}

#[test]
fn test_update_mst_leaf() {
let merkle_tree_1 =
Expand Down Expand Up @@ -210,7 +226,7 @@ mod test {
// The last 3 entries of the merkle tree should be zero entries
for i in 13..16 {
let entry = merkle_tree.entries()[i].clone();
assert_eq!(entry, Entry::<N_CURRENCIES>::zero_entry());
assert_eq!(entry, Entry::<N_CURRENCIES>::init_empty());
}

// expect root hash to be different than 0
Expand Down Expand Up @@ -242,7 +258,7 @@ mod test {
// The last 15 entries of the merkle tree should be zero entries
for i in 17..32 {
let entry = merkle_tree.entries()[i].clone();
assert_eq!(entry, Entry::<N_CURRENCIES>::zero_entry());
assert_eq!(entry, Entry::<N_CURRENCIES>::init_empty());
}

// expect root hash to be different than 0
Expand Down
50 changes: 44 additions & 6 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::merkle_sum_tree::{Entry, Node};
use crate::merkle_sum_tree::{utils::big_uint_to_fp, Entry, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use num_bigint::BigInt;
use rayon::prelude::*;

pub fn build_merkle_tree_from_leaves<const N_CURRENCIES: usize>(
pub fn build_merkle_tree_from_leaves<const N_CURRENCIES: usize, const N_BYTES: usize>(
leaves: &[Node<N_CURRENCIES>],
depth: usize,
) -> Result<(Node<N_CURRENCIES>, Vec<Vec<Node<N_CURRENCIES>>>), Box<dyn std::error::Error>>
Expand All @@ -18,8 +19,28 @@ where

tree.push(leaves.to_vec());

// The allowed_max_root_balance should be safe if set to half of `calculate_max_root_balance`.
// To achieve this, we use a depth one level lower here.
let allowed_max_node_balance = big_uint_to_fp(
&calculate_max_root_balance(N_BYTES, depth - 1)
.to_biguint()
.unwrap(),
);
let allowed_max_root_balance = big_uint_to_fp(
&calculate_max_root_balance(N_BYTES, depth)
.to_biguint()
.unwrap(),
);

for level in 1..=depth {
build_middle_level(level, &mut tree)
// Determine the maximum node balance based on the current level
let max_node_balance = if level == depth {
allowed_max_root_balance
} else {
allowed_max_node_balance
};

build_middle_level::<N_CURRENCIES, N_BYTES>(level, &mut tree, max_node_balance)
}

let root = tree[depth][0].clone();
Expand All @@ -33,14 +54,14 @@ where
[usize; N_CURRENCIES + 1]: Sized,
{
// Precompute the zero leaf (this will only be used if we encounter a zero entry)
let zero_leaf = Entry::<N_CURRENCIES>::zero_entry().compute_leaf();
let zero_leaf = Entry::<N_CURRENCIES>::init_empty().compute_leaf();

let leaves = entries
.par_iter()
.map(|entry| {
// If the entry is the zero entry then we return the precomputed zero leaf
// Otherwise, we compute the leaf as usual
if entry == &Entry::<N_CURRENCIES>::zero_entry() {
if entry == &Entry::<N_CURRENCIES>::init_empty() {
zero_leaf.clone()
} else {
entry.compute_leaf()
Expand All @@ -51,9 +72,10 @@ where
leaves
}

fn build_middle_level<const N_CURRENCIES: usize>(
fn build_middle_level<const N_CURRENCIES: usize, const N_BYTES: usize>(
level: usize,
tree: &mut Vec<Vec<Node<N_CURRENCIES>>>,
max_node_balance: Fp,
) where
[usize; N_CURRENCIES + 2]: Sized,
{
Expand All @@ -66,6 +88,15 @@ fn build_middle_level<const N_CURRENCIES: usize>(
for (i, balance) in hash_preimage.iter_mut().enumerate().take(N_CURRENCIES) {
*balance =
tree[level - 1][index].balances[i] + tree[level - 1][index + 1].balances[i];

// This conditional is for the test case `test_balance_not_in_range` that performs exceed case while generating proof
if !cfg!(feature = "skip-node-balance-check") {
assert!(
balance.to_owned() <= max_node_balance,
"{}",
format!("Node balance is exceed limit: {:#?}", max_node_balance),
);
}
}

hash_preimage[N_CURRENCIES] = tree[level - 1][index].hash;
Expand All @@ -76,3 +107,10 @@ fn build_middle_level<const N_CURRENCIES: usize>(

tree.push(results);
}

// Calculate the maximum value that the Merkle Root can have, given N_BYTES and LEVELS
pub fn calculate_max_root_balance(n_bytes: usize, n_levels: usize) -> BigInt {
// The max value that can be stored in a leaf node or a sibling node, according to the constraint set in the circuit
let max_leaf_value = BigInt::from(2).pow(n_bytes as u32 * 8) - 1;
max_leaf_value * (n_levels + 1)
}
4 changes: 3 additions & 1 deletion zk_prover/src/merkle_sum_tree/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod build_tree;
mod csv_parser;
mod operation_helpers;

pub use build_tree::{build_leaves_from_entries, build_merkle_tree_from_leaves};
pub use build_tree::{
build_leaves_from_entries, build_merkle_tree_from_leaves, calculate_max_root_balance,
};
pub use csv_parser::parse_csv_to_entries;
pub use operation_helpers::*;

0 comments on commit 780a4df

Please sign in to comment.