Skip to content

Commit

Permalink
Initialize logups automatically when first used.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 12, 2024
1 parent e6cd98e commit 0c64f92
Show file tree
Hide file tree
Showing 23 changed files with 313 additions and 216 deletions.
17 changes: 12 additions & 5 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -19,12 +19,17 @@ pub struct AssertEvaluator<'a> {
pub logup: LogupAtRow<Self>,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
pub fn new(
trace: &'a TreeVec<Vec<Vec<BaseField>>>,
row: usize,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down Expand Up @@ -69,6 +74,7 @@ pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
logup_sums: LogupSums,
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
Expand All @@ -84,7 +90,8 @@ pub fn assert_constraints<B: Backend>(
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row);
let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), logup_sums);

assert_func(eval);
}
}
24 changes: 22 additions & 2 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::ClaimedPrefixSum;
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
Expand Down Expand Up @@ -80,16 +81,28 @@ pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let info = eval.evaluate(InfoEvaluator::default());
pub fn new(
location_allocator: &mut TraceLocationAllocator,
eval: E,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
) -> Self {
let info = eval.evaluate(InfoEvaluator::new(
eval.log_size(),
(total_sum, claimed_sum),
));
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);
Self {
eval,
trace_locations,
info,
total_sum,
claimed_sum,
}
}

Expand Down Expand Up @@ -137,6 +150,8 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
mask.sub_tree(&self.trace_locations),
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
self.eval.log_size(),
(self.total_sum, self.claimed_sum),
));
}
}
Expand All @@ -154,6 +169,7 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

println!("n polys: {}", trace.polys[0].len());
let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);

Expand Down Expand Up @@ -205,6 +221,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
(self.total_sum, self.claimed_sum),
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down Expand Up @@ -242,6 +260,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
(self.total_sum, self.claimed_sum),
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -35,6 +35,8 @@ impl<'a> CpuDomainEvaluator<'a> {
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
log_size: u32,
logup_sums: LogupSums,
) -> Self {
Self {
trace_eval,
Expand All @@ -45,7 +47,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::dummy(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
Expand Down Expand Up @@ -156,6 +156,16 @@ struct ExprEvaluator {
pub logup: LogupAtRow<Self>,
}

impl ExprEvaluator {
pub fn _new(log_size: u32, logup_sums: LogupSums) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}
}

impl EvalAtRow for ExprEvaluator {
// TODO(alont): Should there be a version of this that disallows Secure fields for F?
type F = Expr;
Expand Down
18 changes: 14 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::ops::Mul;

use num_traits::One;

use super::logup::LogupAtRow;
use super::EvalAtRow;
use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
Expand All @@ -18,8 +18,18 @@ pub struct InfoEvaluator {
pub logup: LogupAtRow<Self>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Self::default()
pub fn new(log_size: u32, logup_sums: LogupSums) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
}
}

/// Create an empty `InfoEvaluator`, to measure components before their size and logup sums are
/// available.
pub fn empty() -> Self {
Self::new(16, (SecureField::default(), None))
}
}
impl EvalAtRow for InfoEvaluator {
Expand Down
83 changes: 12 additions & 71 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use crate::core::ColumnVec;
/// Represents the value of the prefix sum column at some index.
/// Should be used to eliminate padded rows for the logup sum.
pub type ClaimedPrefixSum = (SecureField, usize);
// (total_sum, claimed_sum)
pub type LogupSums = (SecureField, Option<ClaimedPrefixSum>);

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
Expand All @@ -38,11 +40,12 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
cur_frac: Option<Fraction<E::EF, E::EF>>,
is_finalized: bool,
pub cur_frac: Option<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
pub log_size: u32,
}

impl<E: EvalAtRow> Default for LogupAtRow<E> {
Expand All @@ -55,16 +58,17 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction: usize,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
is_first: E::F,
log_size: u32,
) -> Self {
Self {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: false,
is_first,
is_finalized: true,
is_first: E::F::zero(),
log_size,
}
}

Expand All @@ -78,60 +82,16 @@ impl<E: EvalAtRow> LogupAtRow<E> {
cur_frac: None,
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
}
}

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac.clone() {
let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]);
let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone();
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
self.cur_frac = Some(fraction);
}

pub fn finalize(&mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.clone().unwrap();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum {
Some((claimed_sum, claimed_row_index)) => {
let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval
.next_extension_interaction_mask(
self.interaction,
[0, -1, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone());
(cur_cumsum, prev_row_cumsum)
}
None => {
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone();

eval.add_constraint(diff * frac.denominator - frac.numerator);

self.is_finalized = true;
}
}

/// Ensures that the LogupAtRow is finalized.
/// LogupAtRow should be finalized exactly once.
impl<E: EvalAtRow> Drop for LogupAtRow<E> {
fn drop(&mut self) {
assert!(self.is_finalized, "LogupAtRow was not finalized");
// assert!(self.is_finalized, "LogupAtRow was not finalized");
}
}

Expand Down Expand Up @@ -314,30 +274,11 @@ impl<'a> LogupColGenerator<'a> {

#[cfg(test)]
mod tests {
use num_traits::One;

use super::{LogupAtRow, LookupElements};
use crate::constraint_framework::{InfoEvaluator, INTERACTION_TRACE_IDX};
use super::LookupElements;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<InfoEvaluator>::new(
INTERACTION_TRACE_IDX,
SecureField::one(),
None,
BaseField::one(),
);
logup.write_frac(
&mut InfoEvaluator::default(),
Fraction::new(SecureField::one(), SecureField::one()),
);
}

#[test]
fn test_lookup_elements_combine() {
Expand Down
Loading

0 comments on commit 0c64f92

Please sign in to comment.