Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove padding examples #974

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;
use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
use crate::constraint_framework::{
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
Expand Down Expand Up @@ -32,7 +32,6 @@ relation!(PlonkLookupElements, 2);
pub struct PlonkEval {
pub log_n_rows: u32,
pub lookup_elements: PlonkLookupElements,
pub claimed_sum: ClaimedPrefixSum,
pub total_sum: SecureField,
pub base_trace_location: TreeSubspan,
pub interaction_trace_location: TreeSubspan,
Expand Down Expand Up @@ -119,12 +118,11 @@ pub fn gen_trace(

pub fn gen_interaction_trace(
log_size: u32,
padding_offset: usize,
circuit: &PlonkCircuitTrace,
lookup_elements: &LookupElements<2>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
[SecureField; 2],
SecureField,
) {
let _span = span!(Level::INFO, "Generate interaction trace").entered();
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -148,7 +146,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_at([(1 << log_size) - 1, padding_offset])
logup_gen.finalize_last()
}

#[allow(unused)]
Expand All @@ -163,7 +161,6 @@ pub fn prove_fibonacci_plonk(
for _ in 0..(1 << log_n_rows) {
fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]);
}
let padding_offset = 17;
let range = 0..(1 << log_n_rows);
let mut circuit = PlonkCircuitTrace {
mult: range.clone().map(|_| 2.into()).collect(),
Expand Down Expand Up @@ -228,8 +225,7 @@ pub fn prove_fibonacci_plonk(

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, [total_sum, claimed_sum]) =
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements.0);
let (trace, total_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements.0);
let mut tree_builder = commitment_scheme.tree_builder();
let interaction_trace_location = tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand All @@ -240,13 +236,12 @@ pub fn prove_fibonacci_plonk(
PlonkEval {
log_n_rows,
lookup_elements,
claimed_sum: (claimed_sum, padding_offset),
total_sum,
base_trace_location,
interaction_trace_location,
constants_trace_location,
},
(total_sum, Some((claimed_sum, padding_offset))),
(total_sum, None),
);

// Sanity check. Remove for production.
Expand All @@ -260,7 +255,7 @@ pub fn prove_fibonacci_plonk(
|mut eval| {
component.evaluate(eval);
},
(total_sum, Some((claimed_sum, padding_offset))),
(total_sum, None),
);

let proof = prove(&[&component], channel, commitment_scheme).unwrap();
Expand Down
11 changes: 2 additions & 9 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use num_traits::{One, Zero};

use crate::constraint_framework::logup::ClaimedPrefixSum;
use crate::constraint_framework::relation_tracker::{
RelationTrackerComponent, RelationTrackerEntry,
};
Expand Down Expand Up @@ -36,7 +35,6 @@ pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
pub claimed_sum: ClaimedPrefixSum,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
Expand Down Expand Up @@ -109,7 +107,6 @@ fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
};
component.evaluate(InfoEvaluator::empty())
}
Expand Down Expand Up @@ -139,8 +136,6 @@ pub fn track_state_machine_relations(
trace: &TreeVec<&Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
x_axis_log_n_rows: u32,
y_axis_log_n_rows: u32,
n_rows_x: u32,
n_rows_y: u32,
) -> Vec<RelationTrackerEntry> {
let tree_span_provider = &mut TraceLocationAllocator::default();
let mut entries = vec![];
Expand All @@ -151,9 +146,8 @@ pub fn track_state_machine_relations(
log_n_rows: x_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_x as usize,
1 << x_axis_log_n_rows,
)
.entries(&trace.into()),
);
Expand All @@ -164,9 +158,8 @@ pub fn track_state_machine_relations(
log_n_rows: y_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_y as usize,
1 << y_axis_log_n_rows,
)
.entries(&trace.into()),
);
Expand Down
11 changes: 4 additions & 7 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,14 @@ pub fn gen_trace(
}

pub fn gen_interaction_trace(
n_rows: usize,
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
inc_index: usize,
lookup_elements: &StateMachineElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
[QM31; 2],
QM31,
) {
let log_size = trace[0].domain.log_size();
assert!(n_rows <= 1 << log_size, "n_rows exceeds the trace size");

let ones = PackedM31::broadcast(M31::one());
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -85,7 +83,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_at([(1 << log_size) - 1, n_rows])
logup_gen.finalize_last()
}

#[cfg(test)]
Expand Down Expand Up @@ -133,11 +131,10 @@ mod tests {
let first_state_comb: QM31 = lookup_elements.combine(&first_state);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace((1 << log_size) - 1, &trace, inc_index, &lookup_elements);
let (interaction_trace, total_sum) =
gen_interaction_trace(&trace, inc_index, &lookup_elements);

assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column.
assert_eq!(claimed_sum, total_sum);
assert_eq!(
total_sum,
first_state_comb.inverse() - last_state_comb.inverse()
Expand Down
55 changes: 24 additions & 31 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,12 @@ pub fn prove_state_machine(
Option<RelationSummary>,
) {
let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1);
let (x_row, y_row) = (34, 56);
assert!(y_axis_log_rows >= LOG_N_LANES && x_axis_log_rows >= LOG_N_LANES);
assert!(x_row < 1 << x_axis_log_rows);
assert!(y_row < 1 << y_axis_log_rows);

let mut intermediate_state = initial_state;
intermediate_state[0] += M31::from_u32_unchecked(x_row);
intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows);
let mut final_state = intermediate_state;
final_state[1] += M31::from_u32_unchecked(y_row);
final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
Expand Down Expand Up @@ -82,8 +79,6 @@ pub fn prove_state_machine(
&TreeVec(vec![&preprocessed_trace, &trace]),
x_axis_log_rows,
y_axis_log_rows,
x_row,
y_row,
),
)),
};
Expand All @@ -107,14 +102,14 @@ pub fn prove_state_machine(
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, [total_sum_op0, claimed_sum_op0]) =
gen_interaction_trace(x_row as usize - 1, &trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, [total_sum_op1, claimed_sum_op1]) =
gen_interaction_trace(y_row as usize - 1, &trace_op1, 1, &lookup_elements);
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(&trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, total_sum_op1) =
gen_interaction_trace(&trace_op1, 1, &lookup_elements);

let stmt1 = StateMachineStatement1 {
x_axis_claimed_sum: claimed_sum_op0,
y_axis_claimed_sum: claimed_sum_op1,
x_axis_claimed_sum: total_sum_op0,
y_axis_claimed_sum: total_sum_op1,
};
stmt1.mix_into(channel);

Expand All @@ -130,19 +125,17 @@ pub fn prove_state_machine(
log_n_rows: x_axis_log_rows,
lookup_elements: lookup_elements.clone(),
total_sum: total_sum_op0,
claimed_sum: (claimed_sum_op0, x_row as usize - 1),
},
(total_sum_op0, Some((claimed_sum_op0, x_row as usize - 1))),
(total_sum_op0, None),
);
let component1 = StateMachineOp1Component::new(
tree_span_provider,
StateTransitionEval {
log_n_rows: y_axis_log_rows,
lookup_elements,
total_sum: total_sum_op1,
claimed_sum: (claimed_sum_op1, y_row as usize - 1),
},
(total_sum_op1, Some((claimed_sum_op1, y_row as usize - 1))),
(total_sum_op1, None),
);

tree_span_provider.validate_preprocessed_columns(&preprocessed_columns);
Expand Down Expand Up @@ -231,19 +224,16 @@ mod tests {
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

// Interaction trace.
let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);
let (interaction_trace, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);

let trace = TreeVec::new(vec![
Expand All @@ -258,7 +248,7 @@ mod tests {
|eval| {
component.evaluate(eval);
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);
}

Expand All @@ -269,7 +259,10 @@ mod tests {

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];
let last_state = [
M31::from_u32_unchecked(1 << log_n_rows),
M31::from_u32_unchecked(1 << (log_n_rows - 1)),
];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
Expand All @@ -281,7 +274,7 @@ mod tests {
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

assert_eq!(
component.component0.claimed_sum.0 + component.component1.claimed_sum.0,
component.component0.total_sum + component.component1.total_sum,
initial_state_comb.inverse() - last_state_comb.inverse()
);
}
Expand All @@ -291,7 +284,10 @@ mod tests {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];
let final_state = [
M31::from_u32_unchecked(1 << log_n_rows),
M31::from_u32_unchecked(1 << (log_n_rows - 1)),
];

// Summarize `StateMachineElements`.
let (_, _, summary) = prove_state_machine(
Expand Down Expand Up @@ -344,19 +340,16 @@ mod tests {
let trace = gen_trace(log_n_rows, initial_state, 0);
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

let (_, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);
let (_, total_sum) = gen_interaction_trace(&trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
(total_sum, None),
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
Expand Down
Loading