From 26608289d9da4adeeede9ad08412510acc58e70f Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Thu, 9 Jan 2025 18:32:01 +0200 Subject: [PATCH] Remove padding examples --- crates/prover/src/examples/plonk/mod.rs | 17 ++---- .../src/examples/state_machine/components.rs | 11 +--- .../prover/src/examples/state_machine/gen.rs | 11 ++-- .../prover/src/examples/state_machine/mod.rs | 55 ++++++++----------- 4 files changed, 36 insertions(+), 58 deletions(-) diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index b2bcdbcb3..841079c8d 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use num_traits::One; use tracing::{span, Level}; -use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements}; +use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; use crate::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId}; use crate::constraint_framework::{ assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, @@ -32,7 +32,6 @@ relation!(PlonkLookupElements, 2); pub struct PlonkEval { pub log_n_rows: u32, pub lookup_elements: PlonkLookupElements, - pub claimed_sum: ClaimedPrefixSum, pub total_sum: SecureField, pub base_trace_location: TreeSubspan, pub interaction_trace_location: TreeSubspan, @@ -119,12 +118,11 @@ pub fn gen_trace( pub fn gen_interaction_trace( log_size: u32, - padding_offset: usize, circuit: &PlonkCircuitTrace, lookup_elements: &LookupElements<2>, ) -> ( ColumnVec>, - [SecureField; 2], + SecureField, ) { let _span = span!(Level::INFO, "Generate interaction trace").entered(); let mut logup_gen = LogupTraceGenerator::new(log_size); @@ -148,7 +146,7 @@ pub fn gen_interaction_trace( } col_gen.finalize_col(); - logup_gen.finalize_at([(1 << log_size) - 1, padding_offset]) + logup_gen.finalize_last() } #[allow(unused)] @@ -163,7 +161,6 @@ pub fn prove_fibonacci_plonk( for _ in 0..(1 << log_n_rows) { fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]); } - let padding_offset = 17; let range = 0..(1 << log_n_rows); let mut circuit = PlonkCircuitTrace { mult: range.clone().map(|_| 2.into()).collect(), @@ -228,8 +225,7 @@ pub fn prove_fibonacci_plonk( // Interaction trace. let span = span!(Level::INFO, "Interaction").entered(); - let (trace, [total_sum, claimed_sum]) = - gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements.0); + let (trace, total_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements.0); let mut tree_builder = commitment_scheme.tree_builder(); let interaction_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); @@ -240,13 +236,12 @@ pub fn prove_fibonacci_plonk( PlonkEval { log_n_rows, lookup_elements, - claimed_sum: (claimed_sum, padding_offset), total_sum, base_trace_location, interaction_trace_location, constants_trace_location, }, - (total_sum, Some((claimed_sum, padding_offset))), + (total_sum, None), ); // Sanity check. Remove for production. @@ -260,7 +255,7 @@ pub fn prove_fibonacci_plonk( |mut eval| { component.evaluate(eval); }, - (total_sum, Some((claimed_sum, padding_offset))), + (total_sum, None), ); let proof = prove(&[&component], channel, commitment_scheme).unwrap(); diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 23bcf2977..f7afc675d 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,6 +1,5 @@ use num_traits::{One, Zero}; -use crate::constraint_framework::logup::ClaimedPrefixSum; use crate::constraint_framework::relation_tracker::{ RelationTrackerComponent, RelationTrackerEntry, }; @@ -36,7 +35,6 @@ pub struct StateTransitionEval { pub log_n_rows: u32, pub lookup_elements: StateMachineElements, pub total_sum: QM31, - pub claimed_sum: ClaimedPrefixSum, } impl FrameworkEval for StateTransitionEval { @@ -109,7 +107,6 @@ fn state_transition_info() -> InfoEvaluator { log_n_rows: 1, lookup_elements: StateMachineElements::dummy(), total_sum: QM31::zero(), - claimed_sum: (QM31::zero(), 0), }; component.evaluate(InfoEvaluator::empty()) } @@ -139,8 +136,6 @@ pub fn track_state_machine_relations( trace: &TreeVec<&Vec>>, x_axis_log_n_rows: u32, y_axis_log_n_rows: u32, - n_rows_x: u32, - n_rows_y: u32, ) -> Vec { let tree_span_provider = &mut TraceLocationAllocator::default(); let mut entries = vec![]; @@ -151,9 +146,8 @@ pub fn track_state_machine_relations( log_n_rows: x_axis_log_n_rows, lookup_elements: StateMachineElements::dummy(), total_sum: QM31::zero(), - claimed_sum: (QM31::zero(), 0), }, - n_rows_x as usize, + 1 << x_axis_log_n_rows, ) .entries(&trace.into()), ); @@ -164,9 +158,8 @@ pub fn track_state_machine_relations( log_n_rows: y_axis_log_n_rows, lookup_elements: StateMachineElements::dummy(), total_sum: QM31::zero(), - claimed_sum: (QM31::zero(), 0), }, - n_rows_y as usize, + 1 << y_axis_log_n_rows, ) .entries(&trace.into()), ); diff --git a/crates/prover/src/examples/state_machine/gen.rs b/crates/prover/src/examples/state_machine/gen.rs index ec61388d2..f32db6616 100644 --- a/crates/prover/src/examples/state_machine/gen.rs +++ b/crates/prover/src/examples/state_machine/gen.rs @@ -52,16 +52,14 @@ pub fn gen_trace( } pub fn gen_interaction_trace( - n_rows: usize, trace: &ColumnVec>, inc_index: usize, lookup_elements: &StateMachineElements, ) -> ( ColumnVec>, - [QM31; 2], + QM31, ) { let log_size = trace[0].domain.log_size(); - assert!(n_rows <= 1 << log_size, "n_rows exceeds the trace size"); let ones = PackedM31::broadcast(M31::one()); let mut logup_gen = LogupTraceGenerator::new(log_size); @@ -85,7 +83,7 @@ pub fn gen_interaction_trace( } col_gen.finalize_col(); - logup_gen.finalize_at([(1 << log_size) - 1, n_rows]) + logup_gen.finalize_last() } #[cfg(test)] @@ -133,11 +131,10 @@ mod tests { let first_state_comb: QM31 = lookup_elements.combine(&first_state); let last_state_comb: QM31 = lookup_elements.combine(&last_state); - let (interaction_trace, [total_sum, claimed_sum]) = - gen_interaction_trace((1 << log_size) - 1, &trace, inc_index, &lookup_elements); + let (interaction_trace, total_sum) = + gen_interaction_trace(&trace, inc_index, &lookup_elements); assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column. - assert_eq!(claimed_sum, total_sum); assert_eq!( total_sum, first_state_comb.inverse() - last_state_comb.inverse() diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 8efc91246..d892621af 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -36,15 +36,12 @@ pub fn prove_state_machine( Option, ) { let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1); - let (x_row, y_row) = (34, 56); assert!(y_axis_log_rows >= LOG_N_LANES && x_axis_log_rows >= LOG_N_LANES); - assert!(x_row < 1 << x_axis_log_rows); - assert!(y_row < 1 << y_axis_log_rows); let mut intermediate_state = initial_state; - intermediate_state[0] += M31::from_u32_unchecked(x_row); + intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows); let mut final_state = intermediate_state; - final_state[1] += M31::from_u32_unchecked(y_row); + final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows); // Precompute twiddles. let twiddles = SimdBackend::precompute_twiddles( @@ -82,8 +79,6 @@ pub fn prove_state_machine( &TreeVec(vec![&preprocessed_trace, &trace]), x_axis_log_rows, y_axis_log_rows, - x_row, - y_row, ), )), }; @@ -107,14 +102,14 @@ pub fn prove_state_machine( let lookup_elements = StateMachineElements::draw(channel); // Interaction trace. - let (interaction_trace_op0, [total_sum_op0, claimed_sum_op0]) = - gen_interaction_trace(x_row as usize - 1, &trace_op0, 0, &lookup_elements); - let (interaction_trace_op1, [total_sum_op1, claimed_sum_op1]) = - gen_interaction_trace(y_row as usize - 1, &trace_op1, 1, &lookup_elements); + let (interaction_trace_op0, total_sum_op0) = + gen_interaction_trace(&trace_op0, 0, &lookup_elements); + let (interaction_trace_op1, total_sum_op1) = + gen_interaction_trace(&trace_op1, 1, &lookup_elements); let stmt1 = StateMachineStatement1 { - x_axis_claimed_sum: claimed_sum_op0, - y_axis_claimed_sum: claimed_sum_op1, + x_axis_claimed_sum: total_sum_op0, + y_axis_claimed_sum: total_sum_op1, }; stmt1.mix_into(channel); @@ -130,9 +125,8 @@ pub fn prove_state_machine( log_n_rows: x_axis_log_rows, lookup_elements: lookup_elements.clone(), total_sum: total_sum_op0, - claimed_sum: (claimed_sum_op0, x_row as usize - 1), }, - (total_sum_op0, Some((claimed_sum_op0, x_row as usize - 1))), + (total_sum_op0, None), ); let component1 = StateMachineOp1Component::new( tree_span_provider, @@ -140,9 +134,8 @@ pub fn prove_state_machine( log_n_rows: y_axis_log_rows, lookup_elements, total_sum: total_sum_op1, - claimed_sum: (claimed_sum_op1, y_row as usize - 1), }, - (total_sum_op1, Some((claimed_sum_op1, y_row as usize - 1))), + (total_sum_op1, None), ); tree_span_provider.validate_preprocessed_columns(&preprocessed_columns); @@ -231,19 +224,16 @@ mod tests { let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); // Interaction trace. - let (interaction_trace, [total_sum, claimed_sum]) = - gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements); + let (interaction_trace, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements); - assert_eq!(total_sum, claimed_sum); let component = StateMachineOp0Component::new( &mut TraceLocationAllocator::default(), StateTransitionEval { log_n_rows, lookup_elements, total_sum, - claimed_sum: (total_sum, (1 << log_n_rows) - 1), }, - (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + (total_sum, None), ); let trace = TreeVec::new(vec![ @@ -258,7 +248,7 @@ mod tests { |eval| { component.evaluate(eval); }, - (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + (total_sum, None), ); } @@ -269,7 +259,10 @@ mod tests { // Initial and last state. let initial_state = [M31::zero(); STATE_SIZE]; - let last_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; + let last_state = [ + M31::from_u32_unchecked(1 << log_n_rows), + M31::from_u32_unchecked(1 << (log_n_rows - 1)), + ]; // Setup protocol. let channel = &mut Blake2sChannel::default(); @@ -281,7 +274,7 @@ mod tests { let last_state_comb: QM31 = interaction_elements.combine(&last_state); assert_eq!( - component.component0.claimed_sum.0 + component.component1.claimed_sum.0, + component.component0.total_sum + component.component1.total_sum, initial_state_comb.inverse() - last_state_comb.inverse() ); } @@ -291,7 +284,10 @@ mod tests { let log_n_rows = 8; let config = PcsConfig::default(); let initial_state = [M31::zero(); STATE_SIZE]; - let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; + let final_state = [ + M31::from_u32_unchecked(1 << log_n_rows), + M31::from_u32_unchecked(1 << (log_n_rows - 1)), + ]; // Summarize `StateMachineElements`. let (_, _, summary) = prove_state_machine( @@ -344,19 +340,16 @@ mod tests { let trace = gen_trace(log_n_rows, initial_state, 0); let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); - let (_, [total_sum, claimed_sum]) = - gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements); + let (_, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements); - assert_eq!(total_sum, claimed_sum); let component = StateMachineOp0Component::new( &mut TraceLocationAllocator::default(), StateTransitionEval { log_n_rows, lookup_elements, total_sum, - claimed_sum: (total_sum, (1 << log_n_rows) - 1), }, - (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + (total_sum, None), ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));