diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index b61bbf51d..ce37a0cb3 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -1,7 +1,7 @@ use num_traits::{One, Zero}; -use super::logup::LogupAtRow; -use super::EvalAtRow; +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::core::backend::{Backend, Column}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -19,12 +19,17 @@ pub struct AssertEvaluator<'a> { pub logup: LogupAtRow, } impl<'a> AssertEvaluator<'a> { - pub fn new(trace: &'a TreeVec>>, row: usize) -> Self { + pub fn new( + trace: &'a TreeVec>>, + row: usize, + log_size: u32, + logup_sums: LogupSums, + ) -> Self { Self { trace, col_index: TreeVec::new(vec![0; trace.len()]), row, - logup: LogupAtRow::dummy(), + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), } } } @@ -69,6 +74,7 @@ pub fn assert_constraints( trace_polys: &TreeVec>>, trace_domain: CanonicCoset, assert_func: impl Fn(AssertEvaluator<'_>), + logup_sums: LogupSums, ) { let traces = trace_polys.as_ref().map(|tree| { tree.iter() @@ -84,7 +90,8 @@ pub fn assert_constraints( .collect() }); for row in 0..trace_domain.size() { - let eval = AssertEvaluator::new(&traces, row); + let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), logup_sums); + assert_func(eval); } } diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 8c55f951a..ec61d7615 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -10,6 +10,7 @@ use rayon::prelude::*; use tracing::{span, Level}; use super::cpu_domain::CpuDomainEvaluator; +use super::logup::LogupSums; use super::preprocessed_columns::PreprocessedColumn; use super::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX, @@ -113,11 +114,16 @@ pub struct FrameworkComponent { trace_locations: TreeVec, info: InfoEvaluator, preprocessed_column_indices: Vec, + logup_sums: LogupSums, } impl FrameworkComponent { - pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self { - let info = eval.evaluate(InfoEvaluator::default()); + pub fn new( + location_allocator: &mut TraceLocationAllocator, + eval: E, + logup_sums: LogupSums, + ) -> Self { + let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], logup_sums)); let trace_locations = location_allocator.next_for_structure(&info.mask_offsets); let preprocessed_column_indices = info @@ -148,6 +154,7 @@ impl FrameworkComponent { trace_locations, info, preprocessed_column_indices, + logup_sums, } } @@ -217,6 +224,8 @@ impl Component for FrameworkComponent { mask_points, evaluation_accumulator, coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), + self.eval.log_size(), + self.logup_sums, )); } } @@ -296,6 +305,8 @@ impl ComponentProver for FrameworkComponen &accum.random_coeff_powers, trace_domain.log_size(), eval_domain.log_size(), + self.eval.log_size(), + self.logup_sums, ); let row_res = self.eval.evaluate(eval).row_res; @@ -333,6 +344,8 @@ impl ComponentProver for FrameworkComponen &accum.random_coeff_powers, trace_domain.log_size(), eval_domain.log_size(), + self.eval.log_size(), + self.logup_sums, ); let row_res = self.eval.evaluate(eval).row_res; diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index 9acdc5e40..72d285aeb 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -2,8 +2,8 @@ use std::ops::Mul; use num_traits::Zero; -use super::logup::LogupAtRow; -use super::EvalAtRow; +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::core::backend::CpuBackend; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -35,6 +35,8 @@ impl<'a> CpuDomainEvaluator<'a> { random_coeff_powers: &'a [SecureField], domain_log_size: u32, eval_log_size: u32, + log_size: u32, + logup_sums: LogupSums, ) -> Self { Self { trace_eval, @@ -45,7 +47,7 @@ impl<'a> CpuDomainEvaluator<'a> { constraint_index: 0, domain_log_size, eval_domain_log_size: eval_log_size, - logup: LogupAtRow::dummy(), + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), } } } diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index f0b5eb775..d5724b125 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -2,8 +2,8 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; -use super::logup::LogupAtRow; -use super::EvalAtRow; +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; @@ -156,6 +156,16 @@ struct ExprEvaluator { pub logup: LogupAtRow, } +impl ExprEvaluator { + pub fn _new(log_size: u32, logup_sums: LogupSums) -> Self { + Self { + cur_var_index: Default::default(), + constraints: Default::default(), + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), + } + } +} + impl EvalAtRow for ExprEvaluator { // TODO(alont): Should there be a version of this that disallows Secure fields for F? type F = Expr; diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 6d428178d..2d2831b44 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -2,9 +2,9 @@ use std::ops::Mul; use num_traits::One; -use super::logup::LogupAtRow; +use super::logup::{LogupAtRow, LogupSums}; use super::preprocessed_columns::PreprocessedColumn; -use super::EvalAtRow; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::constraint_framework::PREPROCESSED_TRACE_IDX; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -21,8 +21,23 @@ pub struct InfoEvaluator { pub logup: LogupAtRow, } impl InfoEvaluator { - pub fn new() -> Self { - Self::default() + pub fn new( + log_size: u32, + preprocessed_columns: Vec, + logup_sums: LogupSums, + ) -> Self { + Self { + mask_offsets: Default::default(), + n_constraints: Default::default(), + preprocessed_columns, + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), + } + } + + /// Create an empty `InfoEvaluator`, to measure components before their size and logup sums are + /// available. + pub fn empty() -> Self { + Self::new(16, vec![], (SecureField::default(), None)) } } impl EvalAtRow for InfoEvaluator { diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index dcee6a6aa..650607c1d 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -24,6 +24,8 @@ use crate::core::ColumnVec; /// Represents the value of the prefix sum column at some index. /// Should be used to eliminate padded rows for the logup sum. pub type ClaimedPrefixSum = (SecureField, usize); +// (total_sum, claimed_sum) +pub type LogupSums = (SecureField, Option); /// Evaluates constraints for batched logups. /// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. @@ -38,11 +40,12 @@ pub struct LogupAtRow { pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. pub prev_col_cumsum: E::EF, - cur_frac: Option>, - is_finalized: bool, + pub cur_frac: Option>, + pub is_finalized: bool, /// The value of the `is_first` constant column at current row. /// See [`super::preprocessed_columns::gen_is_first()`]. pub is_first: E::F, + pub log_size: u32, } impl Default for LogupAtRow { @@ -55,7 +58,7 @@ impl LogupAtRow { interaction: usize, total_sum: SecureField, claimed_sum: Option, - is_first: E::F, + log_size: u32, ) -> Self { Self { interaction, @@ -63,8 +66,9 @@ impl LogupAtRow { claimed_sum, prev_col_cumsum: E::EF::zero(), cur_frac: None, - is_finalized: false, - is_first, + is_finalized: true, + is_first: E::F::zero(), + log_size, } } @@ -78,53 +82,9 @@ impl LogupAtRow { cur_frac: None, is_finalized: true, is_first: E::F::zero(), + log_size: 10, } } - - pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction) { - // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.cur_frac.clone() { - let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]); - let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone(); - self.prev_col_cumsum = cur_cumsum; - eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); - } - self.cur_frac = Some(fraction); - } - - pub fn finalize(&mut self, eval: &mut E) { - assert!(!self.is_finalized, "LogupAtRow was already finalized"); - - let frac = self.cur_frac.clone().unwrap(); - - // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset - // from the is_first column when constant columns are supported. - let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum { - Some((claimed_sum, claimed_row_index)) => { - let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval - .next_extension_interaction_mask( - self.interaction, - [0, -1, claimed_row_index as isize], - ); - - // Constrain that the claimed_sum in case that it is not equal to the total_sum. - eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone()); - (cur_cumsum, prev_row_cumsum) - } - None => { - let [cur_cumsum, prev_row_cumsum] = - eval.next_extension_interaction_mask(self.interaction, [0, -1]); - (cur_cumsum, prev_row_cumsum) - } - }; - // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. - let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum; - let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone(); - - eval.add_constraint(diff * frac.denominator - frac.numerator); - - self.is_finalized = true; - } } /// Ensures that the LogupAtRow is finalized. @@ -314,30 +274,11 @@ impl<'a> LogupColGenerator<'a> { #[cfg(test)] mod tests { - use num_traits::One; - - use super::{LogupAtRow, LookupElements}; - use crate::constraint_framework::{InfoEvaluator, INTERACTION_TRACE_IDX}; + use super::LookupElements; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; - use crate::core::lookups::utils::Fraction; - - #[test] - #[should_panic] - fn test_logup_not_finalized_panic() { - let mut logup = LogupAtRow::::new( - INTERACTION_TRACE_IDX, - SecureField::one(), - None, - BaseField::one(), - ); - logup.write_frac( - &mut InfoEvaluator::default(), - Fraction::new(SecureField::one(), SecureField::one()), - ); - } #[test] fn test_lookup_elements_combine() { diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 7d3454be0..107dfd531 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -131,19 +131,11 @@ pub trait EvalAtRow { }, ) .collect(); - self.write_frac(fracs.into_iter().sum()); + self.write_logup_frac(fracs.into_iter().sum()); } // TODO(alont): Remove these once LogupAtRow is no longer used. - fn init_logup( - &mut self, - _total_sum: SecureField, - _claimed_sum: Option, - _log_size: u32, - ) { - unimplemented!() - } - fn write_frac(&mut self, _fraction: Fraction) { + fn write_logup_frac(&mut self, _fraction: Fraction) { unimplemented!() } fn finalize_logup(&mut self) { @@ -156,35 +148,58 @@ pub trait EvalAtRow { /// TODO(alont): Remove once LogupAtRow is no longer used. macro_rules! logup_proxy { () => { - fn init_logup( - &mut self, - total_sum: SecureField, - claimed_sum: Option, - log_size: u32, - ) { - let is_first = self.get_preprocessed_column( - crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst( - log_size, - ), - ); - self.logup = crate::constraint_framework::logup::LogupAtRow::new( - crate::constraint_framework::INTERACTION_TRACE_IDX, - total_sum, - claimed_sum, - is_first, - ); - } - - fn write_frac(&mut self, fraction: Fraction) { - let mut logup = std::mem::take(&mut self.logup); - logup.write_frac(self, fraction); - self.logup = logup; + fn write_logup_frac(&mut self, fraction: Fraction) { + // Add a constraint that num / denom = diff. + if let Some(cur_frac) = self.logup.cur_frac.clone() { + let [cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0]); + let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); + self.logup.prev_col_cumsum = cur_cumsum; + self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } else { + self.logup.is_first = self.get_preprocessed_column( + super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), + ); + self.logup.is_finalized = false; + } + self.logup.cur_frac = Some(fraction); } fn finalize_logup(&mut self) { - let mut logup = std::mem::take(&mut self.logup); - logup.finalize(self); - self.logup = logup; + assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + + let frac = self.logup.cur_frac.clone().unwrap(); + + // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted + // offset from the is_first column when constant columns are supported. + let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum { + Some((claimed_sum, claimed_row_index)) => { + let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = self + .next_extension_interaction_mask( + self.logup.interaction, + [0, -1, claimed_row_index as isize], + ); + + // Constrain that the claimed_sum in case that it is not equal to the total_sum. + self.add_constraint( + (claimed_cumsum - claimed_sum) * self.logup.is_first.clone(), + ); + (cur_cumsum, prev_row_cumsum) + } + None => { + let [cur_cumsum, prev_row_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0, -1]); + (cur_cumsum, prev_row_cumsum) + } + }; + // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. + let fixed_prev_row_cumsum = + prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum; + let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); + + self.add_constraint(diff * frac.denominator - frac.numerator); + + self.logup.is_finalized = true; } }; } diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 309d5dd56..3fc2ad510 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -1,7 +1,7 @@ use std::ops::Mul; -use super::logup::LogupAtRow; -use super::EvalAtRow; +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::core::air::accumulation::PointEvaluationAccumulator; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; @@ -22,6 +22,8 @@ impl<'a> PointEvaluator<'a> { mask: TreeVec>>, evaluation_accumulator: &'a mut PointEvaluationAccumulator, denom_inverse: SecureField, + log_size: u32, + logup_sums: LogupSums, ) -> Self { let col_index = vec![0; mask.len()]; Self { @@ -29,7 +31,7 @@ impl<'a> PointEvaluator<'a> { evaluation_accumulator, col_index, denom_inverse, - logup: LogupAtRow::dummy(), + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), } } } diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index dc71ebf82..b18e8e265 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -2,8 +2,8 @@ use std::ops::Mul; use num_traits::Zero; -use super::logup::LogupAtRow; -use super::EvalAtRow; +use super::logup::{LogupAtRow, LogupSums}; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::core::backend::simd::column::VeryPackedBaseColumn; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{ @@ -41,6 +41,8 @@ impl<'a> SimdDomainEvaluator<'a> { random_coeff_powers: &'a [SecureField], domain_log_size: u32, eval_log_size: u32, + log_size: u32, + logup_sums: LogupSums, ) -> Self { Self { trace_eval, @@ -51,7 +53,7 @@ impl<'a> SimdDomainEvaluator<'a> { constraint_index: 0, domain_log_size, eval_domain_log_size: eval_log_size, - logup: LogupAtRow::dummy(), + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), } } } diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 5f854b35b..96db39bad 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -187,6 +187,7 @@ impl BlakeComponents { round_lookup_elements: all_elements.round_elements.clone(), total_sum: stmt1.scheduler_claimed_sum, }, + (stmt1.scheduler_claimed_sum, None), ), round_components: ROUND_LOG_SPLIT .iter() @@ -200,6 +201,7 @@ impl BlakeComponents { round_lookup_elements: all_elements.round_elements.clone(), total_sum: claimed_sum, }, + (claimed_sum, None), ) }) .collect(), @@ -209,6 +211,7 @@ impl BlakeComponents { lookup_elements: all_elements.xor_elements.xor12.clone(), claimed_sum: stmt1.xor12_claimed_sum, }, + (stmt1.xor12_claimed_sum, None), ), xor9: XorTableComponent::new( tree_span_provider, @@ -216,6 +219,7 @@ impl BlakeComponents { lookup_elements: all_elements.xor_elements.xor9.clone(), claimed_sum: stmt1.xor9_claimed_sum, }, + (stmt1.xor9_claimed_sum, None), ), xor8: XorTableComponent::new( tree_span_provider, @@ -223,6 +227,7 @@ impl BlakeComponents { lookup_elements: all_elements.xor_elements.xor8.clone(), claimed_sum: stmt1.xor8_claimed_sum, }, + (stmt1.xor8_claimed_sum, None), ), xor7: XorTableComponent::new( tree_span_provider, @@ -230,6 +235,7 @@ impl BlakeComponents { lookup_elements: all_elements.xor_elements.xor7.clone(), claimed_sum: stmt1.xor7_claimed_sum, }, + (stmt1.xor7_claimed_sum, None), ), xor4: XorTableComponent::new( tree_span_provider, @@ -237,6 +243,7 @@ impl BlakeComponents { lookup_elements: all_elements.xor_elements.xor4.clone(), claimed_sum: stmt1.xor4_claimed_sum, }, + (stmt1.xor4_claimed_sum, None), ), } } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index bac37015d..964080b7e 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -20,7 +20,6 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { } impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { pub fn eval(mut self) -> E { - self.eval.init_logup(self.total_sum, None, self.log_size); let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); let input_v = v.clone(); let m: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); @@ -197,7 +196,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { lookup_elements.combine::(&[a[1].clone(), b[1].clone(), c[1].clone()]); self.eval - .write_frac(Reciprocal::new(comb0) + Reciprocal::new(comb1)); + .write_logup_frac(Reciprocal::new(comb0) + Reciprocal::new(comb1)); c } } diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index d19496c75..8fa238b26 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -47,7 +47,7 @@ pub fn blake_round_info() -> InfoEvaluator { round_lookup_elements: RoundElements::dummy(), total_sum: SecureField::zero(), }; - component.evaluate(InfoEvaluator::default()) + component.evaluate(InfoEvaluator::empty()) } #[cfg(test)] @@ -107,6 +107,7 @@ mod tests { |eval| { component.evaluate(eval); }, + (total_sum, None), ) } } diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 9e4248ec1..2318d7ff0 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -3,7 +3,6 @@ use num_traits::{One, Zero}; use super::BlakeElements; use crate::constraint_framework::{EvalAtRow, RelationEntry, RelationType}; -use crate::core::fields::qm31::SecureField; use crate::core::lookups::utils::Fraction; use crate::core::vcs::blake2s_ref::SIGMA; use crate::examples::blake::round::RoundElements; @@ -13,10 +12,7 @@ pub fn eval_blake_scheduler_constraints( eval: &mut E, blake_lookup_elements: &BlakeElements, round_lookup_elements: &RoundElements, - total_sum: SecureField, - log_size: u32, ) { - eval.init_logup(total_sum, None, log_size); let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); @@ -45,7 +41,7 @@ pub fn eval_blake_scheduler_constraints( let output_state = &states[N_ROUNDS]; // TODO(alont): Remove blake interaction. - eval.write_frac(Fraction::new( + eval.write_logup_frac(Fraction::new( E::EF::zero(), blake_lookup_elements.combine( &chain![ diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index 7af8a5517..b69318ce4 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -34,8 +34,6 @@ impl FrameworkEval for BlakeSchedulerEval { &mut eval, &self.blake_lookup_elements, &self.round_lookup_elements, - self.total_sum, - self.log_size(), ); eval } @@ -48,7 +46,7 @@ pub fn blake_scheduler_info() -> InfoEvaluator { round_lookup_elements: RoundElements::dummy(), total_sum: SecureField::zero(), }; - component.evaluate(InfoEvaluator::default()) + component.evaluate(InfoEvaluator::empty()) } #[cfg(test)] @@ -104,6 +102,7 @@ mod tests { |eval| { component.evaluate(eval); }, + (total_sum, None), ) } } diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 8244ec99a..c49bf6ce3 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -19,7 +19,6 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS> { pub fn eval(mut self) -> E { - self.eval.init_logup(self.claimed_sum, None, self.log_size); // al, bl are the constant columns for the inputs: All pairs of elements in [0, // 2^LIMB_BITS). // cl is the constant column for the xor: al ^ bl. @@ -62,7 +61,7 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> for frac_chunk in frac_chunks.chunks(2) { let sum_frac: Fraction = frac_chunk.iter().cloned().sum(); - self.eval.write_frac(sum_frac); + self.eval.write_logup_frac(sum_frac); } self.eval.finalize_logup(); self.eval diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 0d5782f27..5959da22c 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -165,6 +165,6 @@ pub fn generate_constant_trace( ) }) .to_vec(); - constant_trace.insert(0, gen_is_first(column_bits::())); + constant_trace.push(gen_is_first(column_bits::())); constant_trace } diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index a5c2728f3..d988e15c2 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -35,7 +35,7 @@ pub fn trace_sizes() -> TreeVec::dummy(), claimed_sum: SecureField::zero(), }; - let info = component.evaluate(InfoEvaluator::default()); + let info = component.evaluate(InfoEvaluator::empty()); info.mask_offsets .as_cols_ref() .map_cols(|_| column_bits::()) @@ -158,6 +158,7 @@ mod tests { |eval| { component.evaluate(eval); }, + (claimed_sum, None), ) } } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 1a318fa38..8a6d0060e 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -49,8 +49,6 @@ impl FrameworkEval for PlonkEval { } fn evaluate(&self, mut eval: E) -> E { - eval.init_logup(self.total_sum, Some(self.claimed_sum), self.log_size()); - let a_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(0)); let b_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(1)); // Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x. @@ -206,7 +204,7 @@ pub fn prove_fibonacci_plonk( ) }) .collect_vec(); - constant_trace.insert(0, is_first); + constant_trace.push(is_first); let constants_trace_location = tree_builder.extend_evals(constant_trace); tree_builder.commit(channel); span.exit(); @@ -230,7 +228,6 @@ pub fn prove_fibonacci_plonk( let interaction_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); span.exit(); - // Prove constraints. let component = PlonkComponent::new( &mut TraceLocationAllocator::default(), @@ -243,6 +240,7 @@ pub fn prove_fibonacci_plonk( interaction_trace_location, constants_trace_location, }, + (total_sum, Some((claimed_sum, padding_offset))), ); // Sanity check. Remove for production. @@ -250,9 +248,14 @@ pub fn prove_fibonacci_plonk( .trees .as_ref() .map(|t| t.polynomials.iter().cloned().collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |mut eval| { - component.evaluate(eval); - }); + assert_constraints( + &trace_polys, + CanonicCoset::new(log_n_rows), + |mut eval| { + component.evaluate(eval); + }, + (total_sum, Some((claimed_sum, padding_offset))), + ); let proof = prove(&[&component], channel, commitment_scheme).unwrap(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index c8aa68c7d..e3f5a14b1 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -61,7 +61,6 @@ impl FrameworkEval for PoseidonEval { self.log_n_rows + LOG_EXPAND } fn evaluate(&self, mut eval: E) -> E { - eval.init_logup(self.total_sum, None, self.log_size()); eval_poseidon_constraints(&mut eval, &self.lookup_elements); eval } @@ -380,6 +379,7 @@ pub fn prove_poseidon( lookup_elements, total_sum, }, + (total_sum, None), ); info!("Poseidon component info:\n{}", component); let proof = prove(&[&component], channel, commitment_scheme).unwrap(); @@ -394,8 +394,8 @@ mod tests { use itertools::Itertools; use num_traits::One; + use crate::constraint_framework::assert_constraints; use crate::constraint_framework::preprocessed_columns::gen_is_first; - use crate::constraint_framework::{assert_constraints, EvalAtRow}; use crate::core::air::Component; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; @@ -473,10 +473,14 @@ mod tests { let traces = TreeVec::new(vec![vec![gen_is_first(LOG_N_ROWS)], trace0, trace1]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { - eval.init_logup(total_sum, None, LOG_N_ROWS); - eval_poseidon_constraints(&mut eval, &lookup_elements); - }); + assert_constraints( + &trace_polys, + CanonicCoset::new(LOG_N_ROWS), + |mut eval| { + eval_poseidon_constraints(&mut eval, &lookup_elements); + }, + (total_sum, None), + ); } #[test_log::test] diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 4dae23a5c..7d3f78893 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -43,8 +43,6 @@ impl FrameworkEval for StateTransitionEval self.log_n_rows + LOG_CONSTRAINT_DEGREE } fn evaluate(&self, mut eval: E) -> E { - eval.init_logup(self.total_sum, Some(self.claimed_sum), self.log_size()); - let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); let input_denom: E::EF = self.lookup_elements.combine(&input_state); @@ -52,7 +50,7 @@ impl FrameworkEval for StateTransitionEval output_state[COORDINATE] += E::F::one(); let output_denom: E::EF = self.lookup_elements.combine(&output_state); - eval.write_frac( + eval.write_logup_frac( Fraction::new(E::EF::one(), input_denom) + Fraction::new(-E::EF::one(), output_denom.clone()), ); @@ -105,7 +103,7 @@ fn state_transition_info() -> InfoEvaluator { total_sum: QM31::zero(), claimed_sum: (QM31::zero(), 0), }; - component.evaluate(InfoEvaluator::default()) + component.evaluate(InfoEvaluator::empty()) } pub struct StateMachineComponents { diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index feed745e9..f379a79f1 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -105,6 +105,7 @@ pub fn prove_state_machine( total_sum: total_sum_op0, claimed_sum: (claimed_sum_op0, x_row as usize - 1), }, + (total_sum_op0, Some((claimed_sum_op0, x_row as usize - 1))), ); let component1 = StateMachineOp1Component::new( tree_span_provider, @@ -114,6 +115,7 @@ pub fn prove_state_machine( total_sum: total_sum_op1, claimed_sum: (claimed_sum_op1, y_row as usize - 1), }, + (total_sum_op1, Some((claimed_sum_op1, y_row as usize - 1))), ); let components = StateMachineComponents { component0, @@ -210,6 +212,7 @@ mod tests { total_sum, claimed_sum: (total_sum, (1 << log_n_rows) - 1), }, + (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), ); let trace = TreeVec::new(vec![ @@ -218,9 +221,14 @@ mod tests { interaction_trace, ]); let trace_polys = trace.map_cols(|c| c.interpolate()); - assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { - component.evaluate(eval); - }); + assert_constraints( + &trace_polys, + CanonicCoset::new(log_n_rows), + |eval| { + component.evaluate(eval); + }, + (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + ); } #[test] diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 9b40a5bae..020a24c65 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -84,6 +84,7 @@ mod tests { #[cfg(not(target_arch = "wasm32"))] use crate::core::channel::Poseidon252Channel; use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -145,6 +146,7 @@ mod tests { &trace_polys, CanonicCoset::new(LOG_N_INSTANCES), fibonacci_constraint_evaluator::, + (SecureField::zero(), None), ); } @@ -164,6 +166,7 @@ mod tests { &trace_polys, CanonicCoset::new(LOG_N_INSTANCES), fibonacci_constraint_evaluator::, + (SecureField::zero(), None), ); } @@ -202,6 +205,7 @@ mod tests { WideFibonacciEval:: { log_n_rows: log_n_instances, }, + (SecureField::zero(), None), ); let proof = prove::( @@ -261,6 +265,7 @@ mod tests { WideFibonacciEval:: { log_n_rows: LOG_N_INSTANCES, }, + (SecureField::zero(), None), ); let proof = prove::( &[&component], diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 51cbebcda..46e608bb4 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -163,7 +163,13 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> Component let component_mask = mask.sub_tree(&self.trace_locations); let trace_coset = CanonicCoset::new(self.log_size()).coset; let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); - let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + let mut eval = PointEvaluator::new( + component_mask, + accumulator, + vanish_on_trace_eval_inv, + self.log_size(), + (SecureField::zero(), None), + ); let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); let is_first = eval_is_first(trace_coset, point); @@ -243,6 +249,8 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver &acc.random_coeff_powers, trace_domain.log_size(), eval_domain.log_size(), + self.log_size(), + (SecureField::zero(), None), ); let [mle_coeffs_col_eval] = eval.next_extension_interaction_mask(aux_interaction, [0]); let [carry_quotients_col_eval] = @@ -363,7 +371,13 @@ impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'o let component_mask = mask.sub_tree(&self.trace_location); let trace_coset = CanonicCoset::new(self.log_size()).coset; let vanish_on_trace_eval_inv = coset_vanishing(trace_coset, point).inverse(); - let mut eval = PointEvaluator::new(component_mask, accumulator, vanish_on_trace_eval_inv); + let mut eval = PointEvaluator::new( + component_mask, + accumulator, + vanish_on_trace_eval_inv, + self.log_size(), + (SecureField::zero(), None), + ); let mle_coeff_col_eval = self.mle_coeff_column_oracle.evaluate_at_point(point, mask); let carry_quotients_col_eval = eval_carry_quotient_col(&self.mle_eval_point, point); @@ -384,7 +398,7 @@ impl<'oracle, O: MleCoeffColumnOracle> Component for MleEvalVerifierComponent<'o } fn mle_eval_info(interaction: usize, n_variables: usize) -> InfoEvaluator { - let mut eval = InfoEvaluator::default(); + let mut eval = InfoEvaluator::empty(); let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); let mle_claim_shift = SecureField::zero(); let mle_coeffs_col_eval = SecureField::zero(); @@ -724,7 +738,7 @@ mod tests { use itertools::{chain, Itertools}; use mle_coeff_column::{MleCoeffColumnComponent, MleCoeffColumnEval}; - use num_traits::One; + use num_traits::{One, Zero}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -799,6 +813,7 @@ mod tests { let mle_coeffs_col_component = MleCoeffColumnComponent::new( trace_location_allocator, MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + (SecureField::zero(), None), ); let mle_eval_component = MleEvalProverComponent::generate( trace_location_allocator, @@ -875,6 +890,7 @@ mod tests { let mle_coeffs_col_component = MleCoeffColumnComponent::new( trace_location_allocator, MleCoeffColumnEval::new(COEFFS_COL_TRACE, mle.n_variables()), + (SecureField::zero(), None), ); let mle_eval_component = MleEvalProverComponent::generate( trace_location_allocator, @@ -895,6 +911,7 @@ mod tests { let mle_coeffs_col_component = MleCoeffColumnComponent::new( trace_location_allocator, MleCoeffColumnEval::new(COEFFS_COL_TRACE, N_VARIABLES), + (SecureField::zero(), None), ); let mle_eval_component = MleEvalVerifierComponent::new( trace_location_allocator, @@ -941,21 +958,29 @@ mod tests { let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(log_size); - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(COEFFS_COL_TRACE, [0]); - let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); - let [is_first_eval, is_second_eval] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); - eval_mle_eval_constraints( - MLE_EVAL_TRACE, - &mut eval, - mle_coeff_col_eval, - &mle_eval_point, - claim_shift, - carry_quotients_col_eval, - is_first_eval, - is_second_eval, - ) - }); + assert_constraints( + &trace_polys, + trace_domain, + |mut eval| { + let [mle_coeff_col_eval] = + eval.next_extension_interaction_mask(COEFFS_COL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first_eval, is_second_eval] = + eval.next_interaction_mask(AUX_TRACE, [0, -1]); + eval_mle_eval_constraints( + MLE_EVAL_TRACE, + &mut eval, + mle_coeff_col_eval, + &mle_eval_point, + claim_shift, + carry_quotients_col_eval, + is_first_eval, + is_second_eval, + ) + }, + (SecureField::zero(), None), + ) } #[test] @@ -975,18 +1000,24 @@ mod tests { let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(N_VARIABLES as u32); - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); - let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); - eval_eq_constraints( - EQ_EVAL_TRACE, - &mut eval, - &mle_eval_point, - carry_quotients_col_eval, - is_first, - is_second, - ); - }); + assert_constraints( + &trace_polys, + trace_domain, + |mut eval| { + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); + eval_eq_constraints( + EQ_EVAL_TRACE, + &mut eval, + &mle_eval_point, + carry_quotients_col_eval, + is_first, + is_second, + ); + }, + (SecureField::zero(), None), + ); } #[test] @@ -1006,18 +1037,24 @@ mod tests { let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(N_VARIABLES as u32); - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); - let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); - eval_eq_constraints( - EQ_EVAL_TRACE, - &mut eval, - &mle_eval_point, - carry_quotients_col_eval, - is_first, - is_second, - ); - }); + assert_constraints( + &trace_polys, + trace_domain, + |mut eval| { + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); + eval_eq_constraints( + EQ_EVAL_TRACE, + &mut eval, + &mle_eval_point, + carry_quotients_col_eval, + is_first, + is_second, + ); + }, + (SecureField::zero(), None), + ); } #[test] @@ -1037,18 +1074,24 @@ mod tests { let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(N_VARIABLES as u32); - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); - let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); - eval_eq_constraints( - EQ_EVAL_TRACE, - &mut eval, - &mle_eval_point, - carry_quotients_col_eval, - is_first, - is_second, - ); - }); + assert_constraints( + &trace_polys, + trace_domain, + |mut eval| { + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); + eval_eq_constraints( + EQ_EVAL_TRACE, + &mut eval, + &mle_eval_point, + carry_quotients_col_eval, + is_first, + is_second, + ); + }, + (SecureField::zero(), None), + ); } #[test] @@ -1062,10 +1105,15 @@ mod tests { let trace_polys = trace.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(LOG_SIZE); - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [row_diff] = eval.next_extension_interaction_mask(0, [0]); - eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift) - }); + assert_constraints( + &trace_polys, + trace_domain, + |mut eval| { + let [row_diff] = eval.next_extension_interaction_mask(0, [0]); + eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift) + }, + (SecureField::zero(), None), + ); } #[test] @@ -1140,7 +1188,7 @@ mod tests { } mod mle_coeff_column { - use num_traits::One; + use num_traits::{One, Zero}; use crate::constraint_framework::{ EvalAtRow, FrameworkComponent, FrameworkEval, PointEvaluator, @@ -1200,6 +1248,8 @@ mod tests { mask.sub_tree(self.trace_locations()), &mut accumulator, SecureField::one(), + self.log_size(), + (SecureField::zero(), None), ); eval_mle_coeff_col(self.interaction, &mut eval)