Skip to content

Commit

Permalink
Remove padding examples
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Jan 12, 2025
1 parent 31e8dbc commit 1f7dbdd
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 58 deletions.
17 changes: 6 additions & 11 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;
use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
Expand Down Expand Up @@ -32,7 +32,6 @@ relation!(PlonkLookupElements, 2);
pub struct PlonkEval {
pub log_n_rows: u32,
pub lookup_elements: PlonkLookupElements,
pub claimed_sum: ClaimedPrefixSum,
pub total_sum: SecureField,
pub base_trace_location: TreeSubspan,
pub interaction_trace_location: TreeSubspan,
Expand Down Expand Up @@ -119,12 +118,11 @@ pub fn gen_trace(

pub fn gen_interaction_trace(
log_size: u32,
padding_offset: usize,
circuit: &PlonkCircuitTrace,
lookup_elements: &LookupElements<2>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
[SecureField; 2],
SecureField,
) {
let _span = span!(Level::INFO, "Generate interaction trace").entered();
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -148,7 +146,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_at([(1 << log_size) - 1, padding_offset])
logup_gen.finalize_last()
}

#[allow(unused)]
Expand All @@ -163,7 +161,6 @@ pub fn prove_fibonacci_plonk(
for _ in 0..(1 << log_n_rows) {
fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]);
}
let padding_offset = 17;
let range = 0..(1 << log_n_rows);
let mut circuit = PlonkCircuitTrace {
mult: range.clone().map(|_| 2.into()).collect(),
Expand Down Expand Up @@ -228,8 +225,7 @@ pub fn prove_fibonacci_plonk(

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, [total_sum, claimed_sum]) =
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements.0);
let (trace, total_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements.0);
let mut tree_builder = commitment_scheme.tree_builder();
let interaction_trace_location = tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand All @@ -240,13 +236,12 @@ pub fn prove_fibonacci_plonk(
PlonkEval {
log_n_rows,
lookup_elements,
claimed_sum: (claimed_sum, padding_offset),
total_sum,
base_trace_location,
interaction_trace_location,
constants_trace_location,
},
(total_sum, Some((claimed_sum, padding_offset))),
(total_sum, None),
);

// Sanity check. Remove for production.
Expand All @@ -260,7 +255,7 @@ pub fn prove_fibonacci_plonk(
|mut eval| {
component.evaluate(eval);
},
(total_sum, Some((claimed_sum, padding_offset))),
(total_sum, None),
);

let proof = prove(&[&component], channel, commitment_scheme).unwrap();
Expand Down
11 changes: 2 additions & 9 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use num_traits::{One, Zero};

use crate::constraint_framework::logup::ClaimedPrefixSum;
use crate::constraint_framework::relation_tracker::{
RelationTrackerComponent, RelationTrackerEntry,
};
Expand Down Expand Up @@ -36,7 +35,6 @@ pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
pub claimed_sum: ClaimedPrefixSum,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
Expand Down Expand Up @@ -109,7 +107,6 @@ fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
};
component.evaluate(InfoEvaluator::empty())
}
Expand Down Expand Up @@ -139,8 +136,6 @@ pub fn track_state_machine_relations(
trace: &TreeVec<&Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
x_axis_log_n_rows: u32,
y_axis_log_n_rows: u32,
n_rows_x: u32,
n_rows_y: u32,
) -> Vec<RelationTrackerEntry> {
let tree_span_provider = &mut TraceLocationAllocator::default();
let mut entries = vec![];
Expand All @@ -151,9 +146,8 @@ pub fn track_state_machine_relations(
log_n_rows: x_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_x as usize,
1 << x_axis_log_n_rows,
)
.entries(&trace.into()),
);
Expand All @@ -164,9 +158,8 @@ pub fn track_state_machine_relations(
log_n_rows: y_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_y as usize,
1 << y_axis_log_n_rows,
)
.entries(&trace.into()),
);
Expand Down
11 changes: 4 additions & 7 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,14 @@ pub fn gen_trace(
}

pub fn gen_interaction_trace(
n_rows: usize,
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
inc_index: usize,
lookup_elements: &StateMachineElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
[QM31; 2],
QM31,
) {
let log_size = trace[0].domain.log_size();
assert!(n_rows <= 1 << log_size, "n_rows exceeds the trace size");

let ones = PackedM31::broadcast(M31::one());
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -85,7 +83,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_at([(1 << log_size) - 1, n_rows])
logup_gen.finalize_last()
}

#[cfg(test)]
Expand Down Expand Up @@ -133,11 +131,10 @@ mod tests {
let first_state_comb: QM31 = lookup_elements.combine(&first_state);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace((1 << log_size) - 1, &trace, inc_index, &lookup_elements);
let (interaction_trace, total_sum) =
gen_interaction_trace(&trace, inc_index, &lookup_elements);

assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column.
assert_eq!(claimed_sum, total_sum);
assert_eq!(
total_sum,
first_state_comb.inverse() - last_state_comb.inverse()
Expand Down
55 changes: 24 additions & 31 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,12 @@ pub fn prove_state_machine(
Option<RelationSummary>,
) {
let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1);
let (x_row, y_row) = (34, 56);
assert!(y_axis_log_rows >= LOG_N_LANES && x_axis_log_rows >= LOG_N_LANES);
assert!(x_row < 1 << x_axis_log_rows);
assert!(y_row < 1 << y_axis_log_rows);

let mut intermediate_state = initial_state;
intermediate_state[0] += M31::from_u32_unchecked(x_row);
intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows);
let mut final_state = intermediate_state;
final_state[1] += M31::from_u32_unchecked(y_row);
final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
Expand Down Expand Up @@ -80,8 +77,6 @@ pub fn prove_state_machine(
&TreeVec(vec![&preprocessed_trace, &trace]),
x_axis_log_rows,
y_axis_log_rows,
x_row,
y_row,
),
)),
};
Expand All @@ -105,14 +100,14 @@ pub fn prove_state_machine(
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, [total_sum_op0, claimed_sum_op0]) =
gen_interaction_trace(x_row as usize - 1, &trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, [total_sum_op1, claimed_sum_op1]) =
gen_interaction_trace(y_row as usize - 1, &trace_op1, 1, &lookup_elements);
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(&trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, total_sum_op1) =
gen_interaction_trace(&trace_op1, 1, &lookup_elements);

let stmt1 = StateMachineStatement1 {
x_axis_claimed_sum: claimed_sum_op0,
y_axis_claimed_sum: claimed_sum_op1,
x_axis_claimed_sum: total_sum_op0,
y_axis_claimed_sum: total_sum_op1,
};
stmt1.mix_into(channel);

Expand All @@ -128,19 +123,17 @@ pub fn prove_state_machine(
log_n_rows: x_axis_log_rows,
lookup_elements: lookup_elements.clone(),
total_sum: total_sum_op0,
claimed_sum: (claimed_sum_op0, x_row as usize - 1),
},
(total_sum_op0, Some((claimed_sum_op0, x_row as usize - 1))),
(total_sum_op0, None),
);
let component1 = StateMachineOp1Component::new(
tree_span_provider,
StateTransitionEval {
log_n_rows: y_axis_log_rows,
lookup_elements,
total_sum: total_sum_op1,
claimed_sum: (claimed_sum_op1, y_row as usize - 1),
},
(total_sum_op1, Some((claimed_sum_op1, y_row as usize - 1))),
(total_sum_op1, None),
);

tree_span_provider.validate_preprocessed_columns(&preprocessed_columns);
Expand Down Expand Up @@ -229,19 +222,16 @@ mod tests {
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

// Interaction trace.
let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);
let (interaction_trace, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);

let trace = TreeVec::new(vec![
Expand All @@ -256,7 +246,7 @@ mod tests {
|eval| {
component.evaluate(eval);
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);
}

Expand All @@ -267,7 +257,10 @@ mod tests {

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];
let last_state = [
M31::from_u32_unchecked(1 << log_n_rows),
M31::from_u32_unchecked(1 << (log_n_rows - 1)),
];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
Expand All @@ -279,7 +272,7 @@ mod tests {
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

assert_eq!(
component.component0.claimed_sum.0 + component.component1.claimed_sum.0,
component.component0.total_sum + component.component1.total_sum,
initial_state_comb.inverse() - last_state_comb.inverse()
);
}
Expand All @@ -289,7 +282,10 @@ mod tests {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];
let final_state = [
M31::from_u32_unchecked(1 << log_n_rows),
M31::from_u32_unchecked(1 << (log_n_rows - 1)),
];

// Summarize `StateMachineElements`.
let (_, _, summary) = prove_state_machine(
Expand Down Expand Up @@ -342,19 +338,16 @@ mod tests {
let trace = gen_trace(log_n_rows, initial_state, 0);
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

let (_, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);
let (_, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
Expand Down

0 comments on commit 1f7dbdd

Please sign in to comment.