Skip to content

Commit

Permalink
Add faster STARK configuration for testing purposes (#739)
Browse files Browse the repository at this point in the history
Enables fast STARK configurations in integration tests. Achieved up to 82% test time savings. Detailed results are provided below. With the big test time savings for the two_to_one_block_aggregation test, I have removed [ignore] from it.
  • Loading branch information
sai-deng authored Nov 4, 2024
1 parent c38c0c6 commit 73fd6cb
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 50 deletions.
71 changes: 50 additions & 21 deletions evm_arithmetization/src/fixed_recursive_verifier.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::mem::{self, MaybeUninit};
use core::ops::Range;
use std::cmp::max;
use std::collections::BTreeMap;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
Expand All @@ -9,7 +10,6 @@ use hashbrown::HashMap;
use itertools::{zip_eq, Itertools};
use mpt_trie::partial_trie::{HashedPartialTrie, Node, PartialTrie};
use plonky2::field::extension::Extendable;
use plonky2::fri::FriParams;
use plonky2::gates::constant::ConstantGate;
use plonky2::gates::noop::NoopGate;
use plonky2::hash::hash_types::{MerkleCapTarget, RichField, NUM_HASH_OUT_ELTS};
Expand Down Expand Up @@ -57,6 +57,7 @@ use crate::recursive_verifier::{
recursive_stark_circuit, set_final_public_value_targets, set_public_value_targets,
PlonkWrapperCircuit, PublicInputs, StarkWrapperCircuit,
};
use crate::testing_utils::TWO_TO_ONE_BLOCK_CIRCUIT_TEST_THRESHOLD_DEGREE_BITS;
use crate::util::h256_limbs;
use crate::verifier::initial_memory_merkle_cap;

Expand Down Expand Up @@ -786,10 +787,19 @@ where
all_stark: &AllStark<F, D>,
degree_bits_ranges: &[Range<usize>; NUM_TABLES],
stark_config: &StarkConfig,
shrinking_circuit_config: Option<&CircuitConfig>,
recursion_circuit_config: Option<&CircuitConfig>,
threshold_degree_bits: Option<usize>,
) -> Self {
// Sanity check on the provided config
assert_eq!(DEFAULT_CAP_LEN, 1 << stark_config.fri_config.cap_height);

let shrinking_config = shrinking_config();
let shrinking_circuit_config = shrinking_circuit_config.unwrap_or(&shrinking_config);
let circuit_config = CircuitConfig::standard_recursion_config();
let recursion_circuit_config = recursion_circuit_config.unwrap_or(&circuit_config);
let threshold_degree_bits = threshold_degree_bits.unwrap_or(THRESHOLD_DEGREE_BITS);

macro_rules! create_recursive_circuit {
($table_enum:expr, $stark_field:ident) => {
RecursiveCircuitsForTable::new(
Expand All @@ -798,6 +808,8 @@ where
degree_bits_ranges[*$table_enum].clone(),
&all_stark.cross_table_lookups,
stark_config,
shrinking_circuit_config,
threshold_degree_bits,
)
};
}
Expand Down Expand Up @@ -829,7 +841,7 @@ where
poseidon,
];

let root = Self::create_segment_circuit(&by_table, stark_config);
let root = Self::create_segment_circuit(&by_table, stark_config, recursion_circuit_config);
let segment_aggregation = Self::create_segment_aggregation_circuit(&root);
let batch_aggregation =
Self::create_batch_aggregation_circuit(&segment_aggregation, stark_config);
Expand Down Expand Up @@ -897,11 +909,12 @@ where
fn create_segment_circuit(
by_table: &[RecursiveCircuitsForTable<F, C, D>; NUM_TABLES],
stark_config: &StarkConfig,
circuit_config: &CircuitConfig,
) -> RootCircuitData<F, C, D> {
let inner_common_data: [_; NUM_TABLES] =
core::array::from_fn(|i| &by_table[i].final_circuits()[0].common);

let mut builder = CircuitBuilder::new(CircuitConfig::standard_recursion_config());
let mut builder = CircuitBuilder::new(circuit_config.clone());

let table_in_use: [BoolTarget; NUM_TABLES] =
core::array::from_fn(|_| builder.add_virtual_bool_target_safe());
Expand Down Expand Up @@ -1481,18 +1494,9 @@ where
fn create_block_circuit(
agg: &BatchAggregationCircuitData<F, C, D>,
) -> BlockCircuitData<F, C, D> {
// Here, we have two block proofs and we aggregate them together.
// The block circuit is similar to the agg circuit; both verify two inner
// proofs.
let expected_common_data = CommonCircuitData {
fri_params: FriParams {
degree_bits: 14,
..agg.circuit.common.fri_params.clone()
},
..agg.circuit.common.clone()
};
let expected_common_data = agg.circuit.common.clone();

let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
let mut builder = CircuitBuilder::<F, D>::new(agg.circuit.common.config.clone());
let public_values = add_virtual_public_values_public_input(&mut builder);
let has_parent_block = builder.add_virtual_bool_target_safe();
let parent_block_proof = builder.add_virtual_proof_with_pis(&expected_common_data);
Expand Down Expand Up @@ -1567,6 +1571,10 @@ where
let agg_verifier_data = builder.constant_verifier_data(&agg.circuit.verifier_only);
builder.verify_proof::<C>(&agg_root_proof, &agg_verifier_data, &agg.circuit.common);

while log2_ceil(builder.num_gates()) < agg.circuit.common.degree_bits() {
builder.add_gate(NoopGate, vec![]);
}

let circuit = builder.build::<C>();
BlockCircuitData {
circuit,
Expand Down Expand Up @@ -1739,7 +1747,17 @@ where
// Pad to match the (non-existing yet!) 2-to-1 circuit's degree.
// We use the block circuit's degree as target reference here, as they end up
// having same degree.
while log2_ceil(builder.num_gates()) < block.circuit.common.degree_bits() {
let degree_bits_to_be_padded = block.circuit.common.degree_bits();

// When using test configurations, the block circuit's degree is less than the
// 2-to-1 circuit's degree. Therefore, we also need to ensure its size meets
// the 2-to-1 circuit's recursion threshold degree bits.
let degree_bits_to_be_padded = max(
degree_bits_to_be_padded,
TWO_TO_ONE_BLOCK_CIRCUIT_TEST_THRESHOLD_DEGREE_BITS,
);

while log2_ceil(builder.num_gates()) < degree_bits_to_be_padded {
builder.add_gate(NoopGate, vec![]);
}

Expand Down Expand Up @@ -1826,6 +1844,11 @@ where

builder.connect_hashes(mix_hash, mix_hash_virtual);

// Pad to match the block circuit's degree.
while log2_ceil(builder.num_gates()) < block_wrapper_circuit.circuit.common.degree_bits() {
builder.add_gate(NoopGate, vec![]);
}

let circuit = builder.build::<C>();
TwoToOneBlockCircuitData {
circuit,
Expand Down Expand Up @@ -2900,6 +2923,8 @@ where
degree_bits_range: Range<usize>,
all_ctls: &[CrossTableLookup<F>],
stark_config: &StarkConfig,
shrinking_circuit_config: &CircuitConfig,
threshold_degree_bits: usize,
) -> Self {
let by_stark_size = degree_bits_range
.map(|degree_bits| {
Expand All @@ -2911,6 +2936,8 @@ where
degree_bits,
all_ctls,
stark_config,
shrinking_circuit_config,
threshold_degree_bits,
),
)
})
Expand Down Expand Up @@ -3023,15 +3050,17 @@ where
degree_bits: usize,
all_ctls: &[CrossTableLookup<F>],
stark_config: &StarkConfig,
shrinking_config: &CircuitConfig,
threshold_degree_bits: usize,
) -> Self {
let initial_wrapper = recursive_stark_circuit(
table,
stark,
degree_bits,
all_ctls,
stark_config,
&shrinking_config(),
THRESHOLD_DEGREE_BITS,
shrinking_config,
threshold_degree_bits,
);
let mut shrinking_wrappers = vec![];

Expand All @@ -3042,12 +3071,12 @@ where
.map(|wrapper: &PlonkWrapperCircuit<F, C, D>| &wrapper.circuit)
.unwrap_or(&initial_wrapper.circuit);
let last_degree_bits = last.common.degree_bits();
assert!(last_degree_bits >= THRESHOLD_DEGREE_BITS);
if last_degree_bits == THRESHOLD_DEGREE_BITS {
assert!(last_degree_bits >= threshold_degree_bits);
if last_degree_bits == threshold_degree_bits {
break;
}

let mut builder = CircuitBuilder::new(shrinking_config());
let mut builder = CircuitBuilder::new(shrinking_config.clone());
let proof_with_pis_target = builder.add_virtual_proof_with_pis(&last.common);
let last_vk = builder.constant_verifier_data(&last.verifier_only);
builder.verify_proof::<C>(&proof_with_pis_target, &last_vk, &last.common);
Expand All @@ -3058,7 +3087,7 @@ where
assert!(
circuit.common.degree_bits() < last_degree_bits,
"Couldn't shrink to expected recursion threshold of 2^{}; stalled at 2^{}",
THRESHOLD_DEGREE_BITS,
threshold_degree_bits,
circuit.common.degree_bits()
);
shrinking_wrappers.push(PlonkWrapperCircuit {
Expand Down
42 changes: 42 additions & 0 deletions evm_arithmetization/src/testing_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ use mpt_trie::{
partial_trie::{HashedPartialTrie, Node, PartialTrie},
};
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::fri::reduction_strategies::FriReductionStrategy;
use plonky2::fri::FriConfig;
use plonky2::plonk::circuit_data::CircuitConfig;
use starky::config::StarkConfig;

pub use crate::cpu::kernel::cancun_constants::*;
pub use crate::cpu::kernel::constants::global_exit_root::*;
Expand All @@ -27,6 +31,44 @@ pub const EMPTY_NODE_HASH: H256 = H256(hex!(
"56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"
));

/// The recursion threshold when using test configurations
pub const TEST_THRESHOLD_DEGREE_BITS: usize = 10;

/// The recursion threshold for 2-to-1 block circuit.
pub const TWO_TO_ONE_BLOCK_CIRCUIT_TEST_THRESHOLD_DEGREE_BITS: usize = 13;

/// A fast STARK config for testing purposes only.
pub const TEST_STARK_CONFIG: StarkConfig = StarkConfig {
security_bits: 1,
num_challenges: 1,
fri_config: FriConfig {
rate_bits: 1,
cap_height: 4,
proof_of_work_bits: 1,
reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5),
num_query_rounds: 1,
},
};

/// A fast Circuit config for testing purposes only.
pub const TEST_RECURSION_CONFIG: CircuitConfig = CircuitConfig {
num_wires: 135,
num_routed_wires: 80,
num_constants: 2,
use_base_arithmetic_gate: true,
security_bits: 1,
num_challenges: 1,
zero_knowledge: false,
max_quotient_degree_factor: 8,
fri_config: FriConfig {
rate_bits: 3,
cap_height: 4,
proof_of_work_bits: 1,
reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5),
num_query_rounds: 1,
},
};

pub fn init_logger() {
let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info"));
}
Expand Down
8 changes: 3 additions & 5 deletions evm_arithmetization/tests/add11_yml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots};
use evm_arithmetization::prover::testing::prove_all_segments;
use evm_arithmetization::testing_utils::{
beacon_roots_account_nibbles, beacon_roots_contract_from_storage, init_logger,
preinitialized_state_and_storage_tries, update_beacon_roots_account_storage,
preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, TEST_STARK_CONFIG,
};
use evm_arithmetization::verifier::testing::verify_all_proofs;
use evm_arithmetization::{
AllStark, GenerationInputs, Node, StarkConfig, EMPTY_CONSOLIDATED_BLOCKHASH,
};
use evm_arithmetization::{AllStark, GenerationInputs, Node, EMPTY_CONSOLIDATED_BLOCKHASH};
use hex_literal::hex;
use keccak_hash::keccak;
use mpt_trie::nibbles::Nibbles;
Expand Down Expand Up @@ -208,7 +206,7 @@ fn add11_yml() -> anyhow::Result<()> {
init_logger();

let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let config = TEST_STARK_CONFIG;
let inputs = get_generation_inputs();

let max_cpu_len_log = 20;
Expand Down
11 changes: 8 additions & 3 deletions evm_arithmetization/tests/empty_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ use std::time::Duration;

use evm_arithmetization::fixed_recursive_verifier::AllRecursiveCircuits;
use evm_arithmetization::prover::prove;
use evm_arithmetization::testing_utils::{init_logger, segment_with_empty_tables};
use evm_arithmetization::testing_utils::{
init_logger, segment_with_empty_tables, TEST_RECURSION_CONFIG, TEST_STARK_CONFIG,
TEST_THRESHOLD_DEGREE_BITS,
};
use evm_arithmetization::verifier::testing::verify_all_proofs;
use evm_arithmetization::AllStark;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::timed;
use plonky2::util::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer};
use plonky2::util::timing::TimingTree;
use starky::config::StarkConfig;

/// This test focuses on testing zkVM proofs with some empty tables.
#[test]
Expand All @@ -24,7 +26,7 @@ fn empty_tables() -> anyhow::Result<()> {
init_logger();

let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let config = TEST_STARK_CONFIG;
let timing = &mut TimingTree::new("Empty Table Test", log::Level::Info);

// Generate segment data
Expand Down Expand Up @@ -59,6 +61,9 @@ fn empty_tables() -> anyhow::Result<()> {
&all_stark,
&[16..17, 8..9, 7..8, 4..6, 8..9, 4..5, 16..17, 16..17, 16..17],
&config,
Some(&TEST_RECURSION_CONFIG),
Some(&TEST_RECURSION_CONFIG),
Some(TEST_THRESHOLD_DEGREE_BITS),
)
);

Expand Down
5 changes: 3 additions & 2 deletions evm_arithmetization/tests/erc20.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use evm_arithmetization::prover::testing::prove_all_segments;
use evm_arithmetization::testing_utils::{
beacon_roots_account_nibbles, beacon_roots_contract_from_storage, create_account_storage,
init_logger, preinitialized_state_and_storage_tries, sd2u, update_beacon_roots_account_storage,
TEST_STARK_CONFIG,
};
use evm_arithmetization::verifier::testing::verify_all_proofs;
use evm_arithmetization::{AllStark, Node, StarkConfig, EMPTY_CONSOLIDATED_BLOCKHASH};
use evm_arithmetization::{AllStark, Node, EMPTY_CONSOLIDATED_BLOCKHASH};
use hex_literal::hex;
use keccak_hash::keccak;
use mpt_trie::nibbles::Nibbles;
Expand Down Expand Up @@ -52,7 +53,7 @@ fn test_erc20() -> anyhow::Result<()> {
init_logger();

let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let config = TEST_STARK_CONFIG;

let beneficiary = hex!("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef");
let sender = hex!("70997970C51812dc3A010C7d01b50e0d17dc79C8");
Expand Down
6 changes: 3 additions & 3 deletions evm_arithmetization/tests/erc721.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use evm_arithmetization::prover::testing::prove_all_segments;
use evm_arithmetization::testing_utils::{
beacon_roots_account_nibbles, beacon_roots_contract_from_storage, create_account_storage,
init_logger, preinitialized_state_and_storage_tries, sd2u, sh2u,
update_beacon_roots_account_storage,
update_beacon_roots_account_storage, TEST_STARK_CONFIG,
};
use evm_arithmetization::verifier::testing::verify_all_proofs;
use evm_arithmetization::{AllStark, Node, StarkConfig, EMPTY_CONSOLIDATED_BLOCKHASH};
use evm_arithmetization::{AllStark, Node, EMPTY_CONSOLIDATED_BLOCKHASH};
use hex_literal::hex;
use keccak_hash::keccak;
use mpt_trie::nibbles::Nibbles;
Expand Down Expand Up @@ -56,7 +56,7 @@ fn test_erc721() -> anyhow::Result<()> {
init_logger();

let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let config = TEST_STARK_CONFIG;

let beneficiary = hex!("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef");
let owner = hex!("5B38Da6a701c568545dCfcB03FcB875f56beddC4");
Expand Down
5 changes: 3 additions & 2 deletions evm_arithmetization/tests/global_exit_root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use evm_arithmetization::testing_utils::{
ger_account_nibbles, ger_contract_from_storage, init_logger, scalable_account_nibbles,
scalable_contract_from_storage, update_ger_account_storage, update_scalable_account_storage,
ADDRESS_SCALABLE_L2_ADDRESS_HASHED, GLOBAL_EXIT_ROOT_ACCOUNT, GLOBAL_EXIT_ROOT_ADDRESS_HASHED,
TEST_STARK_CONFIG,
};
use evm_arithmetization::verifier::testing::verify_all_proofs;
use evm_arithmetization::{AllStark, Node, StarkConfig, EMPTY_CONSOLIDATED_BLOCKHASH};
use evm_arithmetization::{AllStark, Node, EMPTY_CONSOLIDATED_BLOCKHASH};
use keccak_hash::keccak;
use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie};
use plonky2::field::goldilocks_field::GoldilocksField;
Expand All @@ -31,7 +32,7 @@ fn test_global_exit_root() -> anyhow::Result<()> {
init_logger();

let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let config = TEST_STARK_CONFIG;

let block_metadata = BlockMetadata {
block_timestamp: 1.into(),
Expand Down
Loading

0 comments on commit 73fd6cb

Please sign in to comment.