diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 56ccc846d..76cf1e6d4 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -78,6 +78,10 @@ harness = false name = "prodcheck" harness = false +[[bench]] +name = "fracaddcheck" +harness = false + [features] default = ["rayon"] rayon = ["binius-utils/rayon"] diff --git a/crates/prover/benches/fracaddcheck.rs b/crates/prover/benches/fracaddcheck.rs new file mode 100644 index 000000000..63784cc2a --- /dev/null +++ b/crates/prover/benches/fracaddcheck.rs @@ -0,0 +1,88 @@ +// Copyright 2025-2026 The Binius Developers + +use binius_field::arch::OptimalPackedB128; +use binius_math::{multilinear::evaluate::evaluate, test_utils::random_field_buffer}; +use binius_prover::protocols::fracaddcheck::FracAddCheckProver; +use binius_transcript::ProverTranscript; +use binius_verifier::{config::StdChallenger, protocols::prodcheck::MultilinearEvalClaim}; +use criterion::{BatchSize, Criterion, Throughput, criterion_group, criterion_main}; + +type P = OptimalPackedB128; + +fn bench_fracaddcheck_new(c: &mut Criterion) { + let mut group = c.benchmark_group("fracaddcheck/new"); + + for n_vars in [12, 16, 20] { + // Full reduction: k = n_vars, so sums layer has log_len = 0. + let k = n_vars; + + // Consider each element to be one hypercube vertex. + group.throughput(Throughput::Elements(1 << n_vars)); + group.bench_function(format!("n_vars={n_vars}"), |b| { + let mut rng = rand::rng(); + let witness_num = random_field_buffer::

(&mut rng, n_vars); + let witness_den = random_field_buffer::

(&mut rng, n_vars); + + b.iter_batched( + || (witness_num.clone(), witness_den.clone()), + |(witness_num, witness_den)| { + FracAddCheckProver::

::new(k, (witness_num, witness_den)).unwrap() + }, + BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +fn bench_fracaddcheck_prove(c: &mut Criterion) { + let mut group = c.benchmark_group("fracaddcheck/prove"); + + for n_vars in [12, 16, 20] { + // Full reduction: k = n_vars, so sums layer has log_len = 0. + let k = n_vars; + + // Consider each element to be one hypercube vertex. + group.throughput(Throughput::Elements(1 << n_vars)); + group.bench_function(format!("n_vars={n_vars}"), |b| { + let mut rng = rand::rng(); + let witness_num = random_field_buffer::

(&mut rng, n_vars); + let witness_den = random_field_buffer::

(&mut rng, n_vars); + + // Pre-compute the claim (final sums layer evaluation at empty point). + let (_prover, sums) = + FracAddCheckProver::new(k, (witness_num.clone(), witness_den.clone())).unwrap(); + let sum_num_eval = evaluate(&sums.0, &[]).unwrap(); + let sum_den_eval = evaluate(&sums.1, &[]).unwrap(); + let claim = ( + MultilinearEvalClaim { + eval: sum_num_eval, + point: vec![], + }, + MultilinearEvalClaim { + eval: sum_den_eval, + point: vec![], + }, + ); + + let mut transcript = ProverTranscript::new(StdChallenger::default()); + + b.iter_batched( + || { + let (prover, _sums) = + FracAddCheckProver::new(k, (witness_num.clone(), witness_den.clone())) + .unwrap(); + (prover, claim.clone()) + }, + |(prover, claim)| prover.prove(claim, &mut transcript).unwrap(), + BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +criterion_group!(fracaddcheck, bench_fracaddcheck_new, bench_fracaddcheck_prove); +criterion_main!(fracaddcheck); diff --git a/crates/prover/src/protocols/fracaddcheck.rs b/crates/prover/src/protocols/fracaddcheck.rs new file mode 100644 index 000000000..8f50718fc --- /dev/null +++ b/crates/prover/src/protocols/fracaddcheck.rs @@ -0,0 +1,315 @@ +// Copyright 2025-2026 The Binius Developers + +use binius_field::{Field, PackedField}; +use binius_math::{FieldBuffer, line::extrapolate_line_packed}; +use binius_transcript::{ + ProverTranscript, + fiat_shamir::{CanSample, Challenger}, +}; +use binius_utils::rayon::iter::{IntoParallelIterator, ParallelIterator}; +use binius_verifier::protocols::prodcheck::MultilinearEvalClaim; + +use crate::protocols::sumcheck::{ + Error as SumcheckError, MleToSumCheckDecorator, + batch::batch_prove_and_write_evals, + common::SumcheckProver, + frac_add_mle::{FracAddMleCheckProver, FractionalBuffer}, +}; + +/// Prover for the fractional addition protocol. +/// +/// Each layer is a double of the numerator and denominator values of fractional terms. Each layer +/// represents the addition of siblings with respect to the fractional addition rule: +/// $$\frac{a_0}{b_0} + \frac{a_1}{b_1} = \frac{a_0b_1 + a_1b_0}{b_0b_1}$ +pub struct FracAddCheckProver { + layers: Vec<(FieldBuffer

, FieldBuffer

)>, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error( + "mismatched numerator/denominator lengths: numerator log_len {num_log_len}, denominator log_len {den_log_len}" + )] + MismatchedWitnessLengths { + num_log_len: usize, + den_log_len: usize, + }, + #[error("sumcheck error: {0}")] + Sumcheck(#[from] SumcheckError), +} + +impl FracAddCheckProver

+where + F: Field, + P: PackedField, +{ + /// Creates a new [`FracAddCheckProver`]. + /// + /// Returns `(prover, sums)` where `sums` is the final layer containing the + /// fractional additions over all `k` variables. + /// + /// # Arguments + /// * `k` - The number of variables over which the reduction is taken. Each reduction step + /// reduces one variable by computing fractional additions of sibling terms. + /// * `witness` - The witness numerator/denominator layers + /// + /// # Preconditions + /// * `witness.0.log_len() >= k` + pub fn new( + k: usize, + witness: FractionalBuffer

, + ) -> Result<(Self, FractionalBuffer

), Error> { + let (witness_num, witness_den) = witness; + if witness_num.log_len() != witness_den.log_len() { + return Err(Error::MismatchedWitnessLengths { + num_log_len: witness_num.log_len(), + den_log_len: witness_den.log_len(), + }); + } + assert!(witness_num.log_len() >= k); + + let mut layers = Vec::with_capacity(k + 1); + layers.push((witness_num, witness_den)); + + for _ in 0..k { + let prev_layer = layers.last().expect("layers is non-empty"); + + let (num, den) = prev_layer; + let (num_0, num_1) = num + .split_half_ref() + .expect("layer has at least one variable"); + + let (den_0, den_1) = den + .split_half_ref() + .expect("layer has at least one variable"); + + let (next_layer_num, next_layer_den) = + (num_0.as_ref(), den_0.as_ref(), num_1.as_ref(), den_1.as_ref()) + .into_par_iter() + .map(|(&a_0, &b_0, &a_1, &b_1)| (a_0 * b_1 + a_1 * b_0, b_0 * b_1)) + .collect::<(Vec<_>, Vec<_>)>(); + + let next_layer = ( + FieldBuffer::new(num.log_len() - 1, next_layer_num.into_boxed_slice()) + .expect("Should be half of previous layer"), + FieldBuffer::new(den.log_len() - 1, next_layer_den.into_boxed_slice()) + .expect("Should be half of previous layer"), + ); + + layers.push(next_layer); + } + + let sums = layers.pop().expect("layers has k+1 elements"); + Ok((Self { layers }, sums)) + } + + /// Returns the number of remaining layers to prove. + pub fn n_layers(&self) -> usize { + self.layers.len() + } + + /// Pops the last layer and returns a sumcheck prover for it. + /// + /// Returns `(layer_prover, remaining)` where: + /// - `layer_prover` is a sumcheck prover for the popped layer + /// - `remaining` is `Some(self)` if there are more layers, `None` otherwise + pub fn layer_prover( + mut self, + claim: (MultilinearEvalClaim, MultilinearEvalClaim), + ) -> Result<(impl SumcheckProver, Option), Error> { + let (num_claim, den_claim) = claim; + assert_eq!( + num_claim.point, den_claim.point, + "fractional claims must share the evaluation point" + ); + + let layer = self.layers.pop().expect("layers is non-empty"); + + let remaining = if self.layers.is_empty() { + None + } else { + Some(self) + }; + + let prover = + FracAddMleCheckProver::new(layer, &num_claim.point, [num_claim.eval, den_claim.eval])?; + + Ok((MleToSumCheckDecorator::new(prover), remaining)) + } + + /// Runs the fractional addition check protocol and returns the final evaluation claims. + /// + /// This consumes the prover and runs sumcheck reductions from the smallest layer back to + /// the largest. + /// + /// # Arguments + /// * `claim` - The initial multilinear evaluation claims (numerator, denominator) + /// * `transcript` - The prover transcript + /// + /// # Preconditions + /// * `claim.0.point.len() == witness.log_len() - k` (where k is the number of reduction layers) + pub fn prove( + self, + claim: (MultilinearEvalClaim, MultilinearEvalClaim), + transcript: &mut ProverTranscript, + ) -> Result<(MultilinearEvalClaim, MultilinearEvalClaim), Error> + where + Challenger_: Challenger, + { + let mut prover_opt = Some(self); + let mut claim = claim; + + while let Some(prover) = prover_opt { + let (sumcheck_prover, remaining) = prover.layer_prover(claim)?; + prover_opt = remaining; + + let output = batch_prove_and_write_evals(vec![sumcheck_prover], transcript)?; + + let mut multilinear_evals = output.multilinear_evals; + let evals = multilinear_evals.pop().expect("batch contains one prover"); + + let [num_0, num_1, den_0, den_1] = evals + .try_into() + .expect("prover evaluates four multilinears"); + + let r = transcript.sample(); + + let next_num = extrapolate_line_packed(num_0, num_1, r); + let next_den = extrapolate_line_packed(den_0, den_1, r); + + let mut next_point = output.challenges; + next_point.push(r); + + let num_claim = MultilinearEvalClaim { + eval: next_num, + point: next_point.clone(), + }; + let den_claim = MultilinearEvalClaim { + eval: next_den, + point: next_point, + }; + + claim = (num_claim, den_claim); + } + + Ok(claim) + } +} + +#[cfg(test)] +mod tests { + use binius_field::PackedField; + use binius_math::{ + multilinear::evaluate::evaluate, + test_utils::{Packed128b, random_field_buffer, random_scalars}, + }; + use binius_transcript::ProverTranscript; + use binius_verifier::{config::StdChallenger, protocols::fracaddcheck}; + use rand::{SeedableRng, rngs::StdRng}; + + use super::*; + + fn test_frac_add_check_prove_verify_helper(n: usize, k: usize) { + let mut rng = StdRng::seed_from_u64(0); + + // 1. Create random witness with log_len = n + k + let witness_num = random_field_buffer::

(&mut rng, n + k); + let witness_den = random_field_buffer::

(&mut rng, n + k); + + // 2. Create prover (computes fractional-add layers) + let (prover, sums) = + FracAddCheckProver::new(k, (witness_num.clone(), witness_den.clone())).unwrap(); + + // 3. Generate random n-dimensional challenge point + let eval_point = random_scalars::(&mut rng, n); + + // 4. Evaluate sums at challenge point to create claims + let sum_num_eval = evaluate(&sums.0, &eval_point).unwrap(); + let sum_den_eval = evaluate(&sums.1, &eval_point).unwrap(); + let prover_claim = ( + MultilinearEvalClaim { + eval: sum_num_eval, + point: eval_point.clone(), + }, + MultilinearEvalClaim { + eval: sum_den_eval, + point: eval_point.clone(), + }, + ); + let verifier_claim = fracaddcheck::FracAddEvalClaim { + num_eval: sum_num_eval, + den_eval: sum_den_eval, + point: eval_point, + }; + + // 5. Run prover + let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); + let prover_output = prover + .prove(prover_claim.clone(), &mut prover_transcript) + .unwrap(); + + // 6. Run verifier + let mut verifier_transcript = prover_transcript.into_verifier(); + let verifier_output = + fracaddcheck::verify(k, verifier_claim, &mut verifier_transcript).unwrap(); + + // 7. Check outputs match + assert_eq!(prover_output.0.point, prover_output.1.point); + assert_eq!(prover_output.0.point, verifier_output.point); + assert_eq!(prover_output.0.eval, verifier_output.num_eval); + assert_eq!(prover_output.1.eval, verifier_output.den_eval); + + // 8. Verify multilinear evaluation of original witness + let expected_num = evaluate(&witness_num, &verifier_output.point).unwrap(); + let expected_den = evaluate(&witness_den, &verifier_output.point).unwrap(); + assert_eq!(verifier_output.num_eval, expected_num); + assert_eq!(verifier_output.den_eval, expected_den); + } + + #[test] + fn test_frac_add_check_prove_verify() { + test_frac_add_check_prove_verify_helper::(4, 3); + } + + #[test] + fn test_frac_add_check_full_prove_verify() { + test_frac_add_check_prove_verify_helper::(0, 4); + } + + fn test_frac_add_check_layer_computation_helper(n: usize, k: usize) { + let mut rng = StdRng::seed_from_u64(0); + + // Create random witness with log_len = n + k + let witness_num = random_field_buffer::

(&mut rng, n + k); + let witness_den = random_field_buffer::

(&mut rng, n + k); + + // Create prover (computes fractional-add layers) + let (_prover, sums) = + FracAddCheckProver::new(k, (witness_num.clone(), witness_den.clone())).unwrap(); + + // For each index i in the sums layer, verify it equals the fractional sum of witness values + // at indices i + z * 2^n for z in 0..2^k (strided access, not contiguous) + let stride = 1 << n; + let num_terms = 1 << k; + for i in 0..(1 << n) { + let mut expected_num = witness_num.get_checked(i).unwrap(); + let mut expected_den = witness_den.get_checked(i).unwrap(); + for z in 1..num_terms { + let idx = i + z * stride; + let num_z = witness_num.get_checked(idx).unwrap(); + let den_z = witness_den.get_checked(idx).unwrap(); + expected_num = expected_num * den_z + num_z * expected_den; + expected_den *= den_z; + } + let actual_num = sums.0.get_checked(i).unwrap(); + let actual_den = sums.1.get_checked(i).unwrap(); + assert_eq!(actual_num, expected_num, "Numerator mismatch at index {i}"); + assert_eq!(actual_den, expected_den, "Denominator mismatch at index {i}"); + } + } + + #[test] + fn test_frac_add_check_layer_computation() { + test_frac_add_check_layer_computation_helper::(4, 3); + } +} diff --git a/crates/prover/src/protocols/mod.rs b/crates/prover/src/protocols/mod.rs index e9fb40e2b..4ca8fc672 100644 --- a/crates/prover/src/protocols/mod.rs +++ b/crates/prover/src/protocols/mod.rs @@ -1,6 +1,7 @@ // Copyright 2025 Irreducible Inc. pub mod basefold; +pub mod fracaddcheck; mod inout_check; pub mod intmul; pub mod prodcheck; diff --git a/crates/prover/src/protocols/sumcheck/frac_add_mle.rs b/crates/prover/src/protocols/sumcheck/frac_add_mle.rs new file mode 100644 index 000000000..1082f489f --- /dev/null +++ b/crates/prover/src/protocols/sumcheck/frac_add_mle.rs @@ -0,0 +1,447 @@ +// Copyright 2025-2026 The Binius Developers + +use binius_field::{Field, PackedField}; +use binius_math::{ + AsSlicesMut, FieldBuffer, FieldSliceMut, field_buffer::FieldBufferSplitMut, + multilinear::fold::fold_highest_var_inplace, +}; +use binius_utils::rayon::prelude::*; +use binius_verifier::protocols::sumcheck::RoundCoeffs; +use itertools::izip; + +use super::error::Error; +use crate::protocols::sumcheck::{ + common::{MleCheckProver, SumcheckProver}, + gruen32::Gruen32, + round_evals::RoundEvals2, +}; + +pub type FractionalBuffer

= (FieldBuffer

, FieldBuffer

); +#[derive(Debug, Clone)] +enum RoundCoeffsOrEvals { + Coeffs([RoundCoeffs; 2]), + Evals([F; 2]), +} + +// Prover for the fractional additional claims required in LogUp*. We keep numerators and +// denominators to be added in a single buffer respectively, with the assumption that the 2 +// collections to be added are in either half. +pub struct FracAddMleCheckProver { + // Parallel arrays: index 0 = numerator MLE evals, index 1 = denominator MLE evals. + fraction_pairs: [FieldBuffer

; 2], + // Alternates between the last round's polynomial coefficients and the folded evaluation + // values. + last_coeffs_or_evals: RoundCoeffsOrEvals, + gruen32: Gruen32

, +} + +impl> FracAddMleCheckProver

{ + /// Constructs a prover, given the multilinear polynomial evaluations (in pairs) and + /// evaluation claims on the shared evaluation point. + pub fn new( + fraction: (FieldBuffer

, FieldBuffer

), + eval_point: &[F], + eval_claims: [F; 2], + ) -> Result { + let n_vars = eval_point.len(); + + let (num, den) = fraction; + // One extra variable for the numerator/denominator selector bit. + if num.log_len() != n_vars + 1 || den.log_len() != n_vars + 1 { + return Err(Error::MultilinearSizeMismatch); + } + + let last_coeffs_or_evals = RoundCoeffsOrEvals::Evals(eval_claims); + + let gruen32 = Gruen32::new(eval_point); + + let fraction_pairs = [num, den]; + Ok(Self { + fraction_pairs, + last_coeffs_or_evals, + gruen32, + }) + } +} + +impl MleCheckProver for FracAddMleCheckProver

+where + F: Field, + P: PackedField, +{ + // Expose the evaluation point so wrappers can lift this MLE-check prover into sumcheck. + fn eval_point(&self) -> &[F] { + self.gruen32.eval_point() + } +} + +impl SumcheckProver for FracAddMleCheckProver

+where + F: Field, + P: PackedField, +{ + fn n_vars(&self) -> usize { + self.gruen32.n_vars_remaining() + } + + fn n_claims(&self) -> usize { + 2 + } + + fn execute(&mut self) -> Result>, Error> { + let RoundCoeffsOrEvals::Evals(sums) = &self.last_coeffs_or_evals else { + return Err(Error::ExpectedFold); + }; + + // We need at least one variable to produce a round polynomial. + assert!(self.n_vars() > 0); + let n_vars = self.n_vars(); + let [num, den] = &mut self.fraction_pairs; + + let mut num_split = num.split_half_mut()?; + let mut den_split = den.split_half_mut()?; + + // Fixed ordering expected by accumulate_chunk: num(0), num(1), den(0), den(1). + let slices = split_and_truncate(&mut num_split, &mut den_split, n_vars); + + // Perform chunked summation for benefits detailed in bivariate_product_multi_mle. + const MAX_CHUNK_VARS: usize = 8; + // Keep enough vars per chunk to amortize eq-eval overhead, but never exceed n_vars - 1 + // because the highest variable is folded by the round polynomial. + let chunk_vars = std::cmp::max(MAX_CHUNK_VARS, P::LOG_WIDTH).min(n_vars - 1); + + let packed_prime_evals: [RoundEvals2

; 2] = (0..1 << (n_vars - 1 - chunk_vars)) + .into_par_iter() + .try_fold( + || [RoundEvals2::default(); 2], + |mut packed_prime_evals: [RoundEvals2

; 2], chunk_index| -> Result<_, Error> { + accumulate_chunk( + &self.gruen32, + &slices, + chunk_vars, + &mut packed_prime_evals, + chunk_index, + )?; + Ok(packed_prime_evals) + }, + ) + .try_reduce( + || [RoundEvals2::default(); 2], + |lhs, rhs| Ok([lhs[0] + &rhs[0], lhs[1] + &rhs[1]]), + )?; + + // These are MLE-check "prime" round polynomials; sumcheck wrappers apply the eq factor. + let alpha = self.gruen32.next_coordinate(); + let round_coeffs = izip!(sums, packed_prime_evals) + .map(|(&sum, packed_evals)| { + let round_evals = packed_evals.sum_scalars(n_vars); + round_evals.interpolate_eq(sum, alpha) + }) + .collect::>(); + + self.last_coeffs_or_evals = RoundCoeffsOrEvals::Coeffs( + round_coeffs.clone().try_into().expect("Will have length 2"), + ); + Ok(round_coeffs) + } + + fn fold(&mut self, challenge: F) -> Result<(), Error> { + let RoundCoeffsOrEvals::Coeffs(prime_coeffs) = &self.last_coeffs_or_evals else { + return Err(Error::ExpectedExecute); + }; + + // Folding substitutes the newest challenge into the highest variable. + assert!(self.n_vars() > 0); + + let evals = [ + prime_coeffs[0].evaluate(challenge), + prime_coeffs[1].evaluate(challenge), + ]; + + let n_vars = self.n_vars(); + let [num, den] = &mut self.fraction_pairs; + + let mut num_split = num.split_half_mut()?; + let mut den_split = den.split_half_mut()?; + let mut multilinears = split_and_truncate(&mut num_split, &mut den_split, n_vars); + + for multilinear in &mut multilinears { + fold_highest_var_inplace(multilinear, challenge)? + } + + self.gruen32.fold(challenge)?; + // After folding, we keep only the new evaluations for the next round. + self.last_coeffs_or_evals = RoundCoeffsOrEvals::Evals(evals); + Ok(()) + } + + fn finish(self) -> Result, Error> { + if self.n_vars() > 0 { + let error = match self.last_coeffs_or_evals { + RoundCoeffsOrEvals::Coeffs(_) => Error::ExpectedFold, + RoundCoeffsOrEvals::Evals(_) => Error::ExpectedExecute, + }; + + return Err(error); + } + + let multilinear_evals = self + .fraction_pairs + .into_iter() + .flat_map(|multilinear| { + let (lo, hi) = multilinear.split_half_ref().expect("Should have 2 values"); + [lo.get(0), hi.get(0)] + }) + .collect(); + + Ok(multilinear_evals) + } +} + +fn accumulate_chunk( + gruen32: &Gruen32

, + fraction_pairs: &[FieldSliceMut

; 4], + chunk_vars: usize, + packed_prime_evals: &mut [RoundEvals2

; 2], + chunk_index: usize, +) -> Result<(), Error> { + let eq_chunk = gruen32.eq_expansion().chunk(chunk_vars, chunk_index)?; + + let splits = fraction_pairs + .iter() + .map(|slice| slice.split_half_ref()) + .collect::, _>>()?; + + let chunks = splits + .iter() + .flat_map(|(lo, hi)| { + [ + lo.chunk(chunk_vars, chunk_index), + hi.chunk(chunk_vars, chunk_index), + ] + }) + .collect::, _>>()?; + // Ordering: [num_a, den_a, num_b, den_b] × {low, high} chunks. + + let [ + evals_num_a_0_chunk, + evals_num_a_1_chunk, + evals_num_b_0_chunk, + evals_num_b_1_chunk, + evals_den_a_0_chunk, + evals_den_a_1_chunk, + evals_den_b_0_chunk, + evals_den_b_1_chunk, + ]: [FieldBuffer; 8] = chunks + .try_into() + .expect( + "The destructuring contains the high and low chunk slices for each of the 4 MLES, resulting in 8 slices total" + ); + + for ( + &eq_i, + &evals_num_a_0_i, + &evals_num_a_1_i, + &evals_den_a_0_i, + &evals_den_a_1_i, + &evals_num_b_0_i, + &evals_num_b_1_i, + &evals_den_b_0_i, + &evals_den_b_1_i, + ) in izip!( + eq_chunk.as_ref(), + evals_num_a_0_chunk.as_ref(), + evals_num_a_1_chunk.as_ref(), + evals_den_a_0_chunk.as_ref(), + evals_den_a_1_chunk.as_ref(), + evals_num_b_0_chunk.as_ref(), + evals_num_b_1_chunk.as_ref(), + evals_den_b_0_chunk.as_ref(), + evals_den_b_1_chunk.as_ref() + ) { + // Infinity evals are computed by M(∞) = M(0) + M(1) for each multilinear. + let evals_num_a_inf_i = evals_num_a_0_i + evals_num_a_1_i; + let evals_den_a_inf_i = evals_den_a_0_i + evals_den_a_1_i; + let evals_num_b_inf_i = evals_num_b_0_i + evals_num_b_1_i; + let evals_den_b_inf_i = evals_den_b_0_i + evals_den_b_1_i; + + // Numerator composition: a0/b0 + a1/b1 => a0*b1 + a1*b0. + let num_1_i = evals_num_a_1_i * evals_den_b_1_i + evals_num_b_1_i * evals_den_a_1_i; + let num_inf_i = + evals_num_a_inf_i * evals_den_b_inf_i + evals_num_b_inf_i * evals_den_a_inf_i; + + // Denominator composition: b0*b1. + let den_1_i = evals_den_a_1_i * evals_den_b_1_i; + let den_inf_i = evals_den_a_inf_i * evals_den_b_inf_i; + + // Accumulate eq-weighted round evals for numerator (0) and denominator (1). + packed_prime_evals[0].y_1 += eq_i * num_1_i; + packed_prime_evals[0].y_inf += eq_i * num_inf_i; + packed_prime_evals[1].y_1 += eq_i * den_1_i; + packed_prime_evals[1].y_inf += eq_i * den_inf_i; + } + + Ok(()) +} + +fn split_and_truncate<'a, P: PackedField>( + num_split: &'a mut FieldBufferSplitMut, + den_split: &'a mut FieldBufferSplitMut, + n_vars: usize, +) -> [FieldSliceMut<'a, P>; 4] { + let [mut num_a, mut num_b] = num_split.as_slices_mut(); + let [mut den_a, mut den_b] = den_split.as_slices_mut(); + + num_a.truncate(n_vars); + num_b.truncate(n_vars); + den_a.truncate(n_vars); + den_b.truncate(n_vars); + // Fixed ordering expected by accumulate_chunk: num(0), num(1), den(0), den(1). + [num_a, num_b, den_a, den_b] +} + +#[cfg(test)] +mod tests { + use binius_field::arch::{OptimalB128, OptimalPackedB128}; + use binius_math::{ + FieldBuffer, + multilinear::{eq::eq_ind, evaluate::evaluate}, + test_utils::{random_field_buffer, random_scalars}, + }; + use binius_transcript::ProverTranscript; + use binius_verifier::{config::StdChallenger, protocols::sumcheck::batch_verify}; + use itertools::{Itertools, izip}; + use rand::{SeedableRng, prelude::StdRng}; + + use super::*; + use crate::protocols::sumcheck::{MleToSumCheckDecorator, batch::batch_prove}; + + fn test_frac_add_sumcheck_prove_verify( + prover: MleToSumCheckDecorator>, + eval_claims: [F; 2], + eval_point: &[F], + num: FieldBuffer

, + den: FieldBuffer

, + ) where + F: Field, + P: PackedField, + { + let n_vars = prover.n_vars(); + let (num_a, num_b) = num.split_half_ref().unwrap(); + let (den_a, den_b) = den.split_half_ref().unwrap(); + // Run the proving protocol + let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); + let output = batch_prove(vec![prover], &mut prover_transcript).unwrap(); + + assert_eq!(output.multilinear_evals.len(), 1); + let prover_evals = output.multilinear_evals[0].clone(); + + // Write the multilinear evaluations to the transcript + prover_transcript + .message() + .write_scalar_slice(&prover_evals); + + // Convert to verifier transcript and run verification + let mut verifier_transcript = prover_transcript.into_verifier(); + let sumcheck_output = + // Degree 3 because quadratic prime polynomials are multiplied by a linear eq term. + batch_verify(n_vars, 3, &eval_claims, &mut verifier_transcript).unwrap(); + + // The prover binds variables from high to low, but evaluate expects them from low to high + let mut reduced_eval_point = sumcheck_output.challenges.clone(); + reduced_eval_point.reverse(); + + // Read the multilinear evaluations from the transcript + let multilinear_evals: Vec = verifier_transcript.message().read_vec(4).unwrap(); + + // Evaluate the equality indicator + let eq_ind_eval = eq_ind(eval_point, &reduced_eval_point); + + // Check that the original multilinears evaluate to the claimed values at the challenge + // point + let eval_num_a = evaluate(&num_a, &reduced_eval_point).unwrap(); + let eval_den_a = evaluate(&den_a, &reduced_eval_point).unwrap(); + let eval_num_b = evaluate(&num_b, &reduced_eval_point).unwrap(); + let eval_den_b = evaluate(&den_b, &reduced_eval_point).unwrap(); + + assert_eq!( + eval_num_a, multilinear_evals[0], + "Numerator A should evaluate to the first claimed evaluation" + ); + + assert_eq!( + eval_num_b, multilinear_evals[1], + "Numerator B should evaluate to the second claimed evaluation" + ); + assert_eq!( + eval_den_a, multilinear_evals[2], + "Denominator A should evaluate to the third claimed evaluation" + ); + + assert_eq!( + eval_den_b, multilinear_evals[3], + "Denominator B should evaluate to the fourth claimed evaluation" + ); + + // Check that the batched evaluation matches the sumcheck output + // Sumcheck wraps the prime polynomial with an eq factor, so include eq_ind_eval here. + let numerator_eval = (eval_num_a * eval_den_b + eval_num_b * eval_den_a) * eq_ind_eval; + let denominator_eval = (eval_den_a * eval_den_b) * eq_ind_eval; + let batched_eval = numerator_eval + denominator_eval * sumcheck_output.batch_coeff; + + assert_eq!( + batched_eval, sumcheck_output.eval, + "Batched evaluation should equal the reduced evaluation" + ); + + // Also verify the challenges match what the prover saw + let mut prover_challenges = output.challenges.clone(); + prover_challenges.reverse(); + assert_eq!( + prover_challenges, sumcheck_output.challenges, + "Prover and verifier challenges should match" + ); + } + + #[test] + fn test_frac_add_sumcheck() { + type F = OptimalB128; + type P = OptimalPackedB128; + + let n_vars = 8; + let mut rng = StdRng::seed_from_u64(0); + + let num = random_field_buffer::

(&mut rng, n_vars + 1); + let den = random_field_buffer::

(&mut rng, n_vars + 1); + let (num_a, num_b) = num.split_half_ref().unwrap(); + let (den_a, den_b) = den.split_half_ref().unwrap(); + + let numerator_values = + izip!(num_a.as_ref(), den_a.as_ref(), num_b.as_ref(), den_b.as_ref()) + .map(|(&num_a, &den_a, &num_b, &den_b)| num_a * den_b + num_b * den_a) + .collect_vec(); + + let denominator_values = izip!(den_a.as_ref(), den_b.as_ref()) + .map(|(&den_a, &den_b)| den_a * den_b) + .collect_vec(); + + let numerator_buffer = FieldBuffer::new(n_vars, numerator_values).unwrap(); + let denominator_buffer = FieldBuffer::new(n_vars, denominator_values).unwrap(); + + let eval_point = random_scalars::(&mut rng, n_vars); + // Claims are at the original eval_point; verifier handles challenge ordering separately. + let eval_claims = [ + evaluate(&numerator_buffer, &eval_point).unwrap(), + evaluate(&denominator_buffer, &eval_point).unwrap(), + ]; + + let frac_prover = + FracAddMleCheckProver::new((num.clone(), den.clone()), &eval_point, eval_claims) + .unwrap(); + + // Wrap the MLE-check prover so it emits sumcheck-compatible round polynomials. + let prover = MleToSumCheckDecorator::new(frac_prover); + + test_frac_add_sumcheck_prove_verify(prover, eval_claims, &eval_point, num, den); + } +} diff --git a/crates/prover/src/protocols/sumcheck/mod.rs b/crates/prover/src/protocols/sumcheck/mod.rs index a83925ea4..6f59d72a9 100644 --- a/crates/prover/src/protocols/sumcheck/mod.rs +++ b/crates/prover/src/protocols/sumcheck/mod.rs @@ -17,3 +17,4 @@ mod switchover; pub use error::*; pub use mle_to_sumcheck::*; pub use prove::*; +pub mod frac_add_mle; diff --git a/crates/prover/src/protocols/sumcheck/round_evals.rs b/crates/prover/src/protocols/sumcheck/round_evals.rs index 9f8267c2c..292c3aec3 100644 --- a/crates/prover/src/protocols/sumcheck/round_evals.rs +++ b/crates/prover/src/protocols/sumcheck/round_evals.rs @@ -55,7 +55,7 @@ impl Mul for RoundEvals1

{ // is defined as limit of P(X)/X^n as X approaches infinity, which equals the leading coefficient. // This is the Karatsuba trick. Take note that it may require removing lower-degree terms from the // composition polynomial. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Copy, Debug, Default)] pub struct RoundEvals2 { pub y_1: P, pub y_inf: P, diff --git a/crates/verifier/src/protocols/fracaddcheck.rs b/crates/verifier/src/protocols/fracaddcheck.rs new file mode 100644 index 000000000..7b0a68a56 --- /dev/null +++ b/crates/verifier/src/protocols/fracaddcheck.rs @@ -0,0 +1,124 @@ +// Copyright 2025-2026 The Binius Developers + +//! Reduction from fractional-addition layers to a multilinear evaluation claim. +//! +//! Each layer represents combining siblings with the fractional-addition rule: +//! (a0 / b0) + (a1 / b1) = (a0 * b1 + a1 * b0) / (b0 * b1). + +use binius_field::Field; +use binius_math::{line::extrapolate_line_packed, multilinear::eq::eq_ind}; +use binius_transcript::{ + Error as TranscriptError, VerifierTranscript, + fiat_shamir::{CanSample, Challenger}, +}; + +use crate::protocols::sumcheck::{self, BatchSumcheckOutput}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FracAddEvalClaim { + /// The evaluation of the numerator and denominator multilinears. + pub num_eval: F, + pub den_eval: F, + /// The evaluation point. + pub point: Vec, +} + +pub fn verify( + k: usize, + claim: FracAddEvalClaim, + transcript: &mut VerifierTranscript, +) -> Result, Error> { + if k == 0 { + return Ok(claim); + } + + let FracAddEvalClaim { + num_eval, + den_eval, + point, + } = claim; + + let n_vars = point.len(); + let sums = [num_eval, den_eval]; + + // Reduce numerator and denominator sum claims to evaluations at a challenge point. + let BatchSumcheckOutput { + batch_coeff, + eval, + mut challenges, + } = sumcheck::batch_verify(n_vars, 3, &sums, transcript)?; + + // Read evaluations of numerator/denominator halves at the reduced point. + let [num_0, num_1, den_0, den_1] = transcript.message().read()?; + + // Sumcheck binds variables high-to-low; reverse to low-to-high for point evaluation. + challenges.reverse(); + let reduced_eval_point = challenges; + + let eq_eval = eq_ind(&point, &reduced_eval_point); + let numerator_eval = (num_0 * den_1 + num_1 * den_0) * eq_eval; + let denominator_eval = (den_0 * den_1) * eq_eval; + let batched_eval = numerator_eval + denominator_eval * batch_coeff; + + if batched_eval != eval { + return Err(VerificationError::IncorrectLayerFractionSumEvaluation { round: k }.into()); + } + + // Reduce evaluations of the two halves to a single evaluation at the next point. + let r = transcript.sample(); + let next_num = extrapolate_line_packed(num_0, num_1, r); + let next_den = extrapolate_line_packed(den_0, den_1, r); + + let mut next_point = reduced_eval_point; + next_point.push(r); + + verify( + k - 1, + FracAddEvalClaim { + num_eval: next_num, + den_eval: next_den, + point: next_point, + }, + transcript, + ) +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("sumcheck error: {0}")] + Sumcheck(#[source] sumcheck::Error), + #[error("transcript error: {0}")] + Transcript(#[source] TranscriptError), + #[error("verification error: {0}")] + Verification(#[from] VerificationError), +} + +impl From for Error { + fn from(err: sumcheck::Error) -> Self { + match err { + sumcheck::Error::Verification(err) => VerificationError::Sumcheck(err).into(), + _ => Error::Sumcheck(err), + } + } +} + +impl From for Error { + fn from(err: TranscriptError) -> Self { + match err { + TranscriptError::NotEnoughBytes => VerificationError::TranscriptIsEmpty.into(), + _ => Error::Transcript(err), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum VerificationError { + #[error("sumcheck: {0}")] + Sumcheck(#[from] sumcheck::VerificationError), + #[error("incorrect layer fraction sum evaluation: {round}")] + IncorrectLayerFractionSumEvaluation { round: usize }, + #[error("incorrect round evaluation: {round}")] + IncorrectRoundEvaluation { round: usize }, + #[error("transcript is empty")] + TranscriptIsEmpty, +} diff --git a/crates/verifier/src/protocols/mod.rs b/crates/verifier/src/protocols/mod.rs index a301efff4..9adbb1544 100644 --- a/crates/verifier/src/protocols/mod.rs +++ b/crates/verifier/src/protocols/mod.rs @@ -1,6 +1,7 @@ // Copyright 2025 Irreducible Inc. pub mod basefold; +pub mod fracaddcheck; pub mod intmul; pub mod mlecheck; pub mod prodcheck;