diff --git a/src/sumcheck/product_polynomial.rs b/src/sumcheck/product_polynomial.rs index 516758c0..5a5ef466 100644 --- a/src/sumcheck/product_polynomial.rs +++ b/src/sumcheck/product_polynomial.rs @@ -20,7 +20,7 @@ //! over remaining variables. For quadratic sumcheck, `h(X)` is degree-2. use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field, PackedFieldExtension, PackedValue, dot_product}; +use p3_field::{ExtensionField, PackedFieldExtension, PackedValue, TwoAdicField, dot_product}; use p3_util::log2_strict_usize; use tracing::instrument; @@ -71,7 +71,7 @@ use crate::{ /// This occurs after sufficient rounds of folding reduce the polynomial size below the /// SIMD efficiency threshold. #[derive(Debug, Clone)] -pub(crate) enum ProductPolynomial> { +pub(crate) enum ProductPolynomial> { /// SIMD-packed representation for large polynomials. /// /// Each element in `evals` and `weights` is an `EF::ExtensionPacking`, which holds @@ -109,7 +109,7 @@ pub(crate) enum ProductPolynomial> { }, } -impl> ProductPolynomial { +impl> ProductPolynomial { /// Creates a new [`ProductPolynomial`] from extension field evaluations. /// /// Automatically selects the optimal representation (packed or scalar) based on @@ -464,7 +464,7 @@ impl> ProductPolynomial { /// /// * `sum` - Running sum to update with new constraint contributions. /// * `constraint` - The constraint to combine into weights. - pub(crate) fn combine(&mut self, sum: &mut EF, constraint: &Constraint) { + pub(crate) fn combine(&mut self, sum: &mut EF, constraint: &Constraint) { match self { Self::Packed { weights, .. } => { constraint.combine_packed(weights, sum); diff --git a/src/sumcheck/sumcheck_single.rs b/src/sumcheck/sumcheck_single.rs index 461dd2d8..ad64c2e4 100644 --- a/src/sumcheck/sumcheck_single.rs +++ b/src/sumcheck/sumcheck_single.rs @@ -38,7 +38,7 @@ use crate::{ /// /// The sumcheck protocol ensures that the claimed sum is correct. #[derive(Debug, Clone)] -pub struct SumcheckSingle> { +pub struct SumcheckSingle> { /// Paired evaluation and weight polynomials for the quadratic sumcheck. /// /// This holds both `f(x)` (the polynomial being sumchecked) and `w(x)` (the constraint @@ -54,7 +54,7 @@ pub struct SumcheckSingle> { impl SumcheckSingle where - F: Field + Ord, + F: TwoAdicField + Ord, EF: ExtensionField, { /// Constructs a new `SumcheckSingle` instance from evaluations in the extension field. @@ -110,7 +110,7 @@ where challenger: &mut Challenger, folding_factor: usize, pow_bits: usize, - constraint: &Constraint, + constraint: &Constraint, ) -> (Self, MultilinearPoint) where F: TwoAdicField, @@ -157,7 +157,7 @@ where folding_factor: usize, pow_bits: usize, k_skip: usize, - constraint: &Constraint, + constraint: &Constraint, ) -> (Self, MultilinearPoint) where F: TwoAdicField, @@ -291,7 +291,7 @@ where challenger: &mut Challenger, folding_factor: usize, pow_bits: usize, - constraint: Option>, + constraint: Option>, ) -> MultilinearPoint where F: TwoAdicField, @@ -337,11 +337,11 @@ where /// * The verifier's challenge `r` as an `EF` element. /// * [`ProductPolynomial`] with new compressed polynomial evaluations and weights in the extension field. /// * Updated sum. -fn initial_round, Challenger>( +fn initial_round, Challenger>( evals: &EvaluationsList, sumcheck_data: &mut SumcheckData, challenger: &mut Challenger, - constraint: &Constraint, + constraint: &Constraint, pow_bits: usize, ) -> (ProductPolynomial, EF, EF) where diff --git a/src/sumcheck/sumcheck_single_svo.rs b/src/sumcheck/sumcheck_single_svo.rs index de01e9f8..b7b156c3 100644 --- a/src/sumcheck/sumcheck_single_svo.rs +++ b/src/sumcheck/sumcheck_single_svo.rs @@ -1,7 +1,7 @@ use alloc::{vec, vec::Vec}; use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use crate::{ poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, @@ -21,7 +21,7 @@ pub(crate) const NUM_SVO_ROUNDS: usize = 3; impl SumcheckSingle where - F: Field + Ord, + F: TwoAdicField + Ord, EF: ExtensionField, { /// Compute a Sumcheck using the Small Value Optimization (SVO) for the first three rounds and @@ -33,7 +33,7 @@ where challenger: &mut Challenger, folding_factor: usize, pow_bits: usize, - constraint: &Constraint, + constraint: &Constraint, ) -> (Self, MultilinearPoint) where F: TwoAdicField, diff --git a/src/sumcheck/tests.rs b/src/sumcheck/tests.rs index 36753a8f..a1a4aca4 100644 --- a/src/sumcheck/tests.rs +++ b/src/sumcheck/tests.rs @@ -16,7 +16,7 @@ use crate::{ constraints::{ Constraint, evaluator::ConstraintPolyEvaluator, - statement::{EqStatement, SelectStatement}, + statement::{DomainStatement, EqStatement}, }, parameters::InitialPhaseConfig, proof::{InitialPhase, SumcheckData, WhirProof}, @@ -76,7 +76,7 @@ fn make_constraint( num_eqs: usize, num_sels: usize, poly: &EvaluationsList, -) -> Constraint +) -> Constraint where Challenger: FieldChallenger + GrindingChallenger, { @@ -85,7 +85,7 @@ where // Create a new empty eq and select statements of that arity let mut eq_statement = EqStatement::initialize(num_vars); - let mut sel_statement = SelectStatement::initialize(num_vars); + let mut sel_statement = DomainStatement::initialize(num_vars, num_vars); // - Sample `num_eqs` univariate challenge points. // - Evaluate the sumcheck polynomial on them. @@ -130,7 +130,7 @@ where challenger.observe_algebra_element(eval); // Add the evaluation constraint: poly(point) == eval. - sel_statement.add_constraint(var, eval); + sel_statement.add_constraint(index, eval); }); // Return the constructed constraint with the alpha used for linear combination. @@ -146,7 +146,7 @@ fn make_constraint_ext( num_eqs: usize, num_sels: usize, poly: &EvaluationsList, -) -> Constraint +) -> Constraint where Challenger: FieldChallenger + GrindingChallenger, { @@ -155,7 +155,7 @@ where // Create a new empty eq and select statements of that arity let mut eq_statement = EqStatement::initialize(num_vars); - let mut sel_statement = SelectStatement::initialize(num_vars); + let mut sel_statement = DomainStatement::initialize(num_vars, num_vars); // - Sample `num_eqs` univariate challenge points. // - Evaluate the sumcheck polynomial on them. @@ -201,7 +201,7 @@ where challenger.observe_algebra_element(eval); // Add the evaluation constraint: poly(point) == eval. - sel_statement.add_constraint(var, eval); + sel_statement.add_constraint(index, eval); }); // Return the constructed constraint with the alpha used for linear combination. @@ -216,7 +216,7 @@ fn read_constraint( num_vars: usize, num_eqs: usize, num_sels: usize, -) -> Constraint +) -> Constraint where Challenger: FieldChallenger + GrindingChallenger, { @@ -237,16 +237,12 @@ where } // Create a new statement that will hold all reconstructed constraints. - let mut sel_statement = SelectStatement::::initialize(num_vars); - - // To simulate stir point derivation derive domain generator - let omega = F::two_adic_generator(num_vars); + let mut sel_statement = DomainStatement::initialize(num_vars, num_vars); // For each point, sample a challenge and read its corresponding evaluation from the proof. for i in 0..num_sels { // Simulate stir point derivation let index: usize = challenger.sample_bits(num_vars); - let var = omega.exp_u64(index as u64); // Read the committed evaluation corresponding to this point from constraint_evals. // Sel evaluations are stored after eq evaluations. @@ -256,7 +252,7 @@ where challenger.observe_algebra_element(eval); // Add the constraint: poly(point) == eval. - sel_statement.add_constraint(var, eval); + sel_statement.add_constraint(index, eval); } Constraint::new( @@ -510,7 +506,8 @@ fn run_sumcheck_test( // // No skip optimization, so the first round is treated as a standard sumcheck round. let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, None); - let weights = evaluator.eval_constraints_poly(&constraints, &verifier_randomness.reversed()); + let weights = + evaluator.eval_constraints_poly::(&constraints, &verifier_randomness.reversed()); // CHECK SUM == f(r) * weights(z, r) assert_eq!(sum, final_folded_value * weights); @@ -777,7 +774,7 @@ fn run_sumcheck_test_skips( // // Evaluate eq(z, r) using the unified constraint evaluation function. let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, Some(K_SKIP_SUMCHECK)); - let weights = evaluator.eval_constraints_poly(&constraints, &verifier_randomness); + let weights = evaluator.eval_constraints_poly::(&constraints, &verifier_randomness); // FINAL SUMCHECK CHECK // @@ -1006,7 +1003,8 @@ fn run_sumcheck_test_svo( // // No skip optimization, so the first round is treated as a standard sumcheck round. let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, None); - let weights = evaluator.eval_constraints_poly(&constraints, &verifier_randomness.reversed()); + let weights = + evaluator.eval_constraints_poly::(&constraints, &verifier_randomness.reversed()); // CHECK SUM == f(r) * weights(z, r) assert_eq!(sum, final_folded_value * weights); diff --git a/src/whir/constraints/evaluator.rs b/src/whir/constraints/evaluator.rs index 76de60ea..ab9e691a 100644 --- a/src/whir/constraints/evaluator.rs +++ b/src/whir/constraints/evaluator.rs @@ -8,9 +8,9 @@ use crate::{ }; /// Evaluate a single round's constraint. -fn eval_round + TwoAdicField>( +fn eval_round + TwoAdicField>( round: usize, - constraint: &Constraint, + constraint: &Constraint, original_point: &MultilinearPoint, context: &PointContext, ) -> EF { @@ -60,7 +60,7 @@ fn eval_round + TwoAdicField>( let sel_contribution = constraint .iter_sels() - .map(|(&var, coeff)| { + .map(|(var, coeff)| { let expanded = MultilinearPoint::expand_from_univariate(var, constraint.num_variables()); coeff * expanded.select_poly(&eval_point) @@ -108,9 +108,9 @@ impl ConstraintPolyEvaluator { /// Constraint i needs evaluation point matching its polynomial's remaining variables. /// This means using challenges from prover round i onwards + final sumcheck. #[must_use] - pub fn eval_constraints_poly + TwoAdicField>( + pub fn eval_constraints_poly + TwoAdicField>( &self, - constraints: &[Constraint], + constraints: &[Constraint], point: &MultilinearPoint, ) -> EF { let using_skip = self.univariate_skip.is_some(); @@ -180,7 +180,7 @@ mod tests { use crate::{ parameters::FoldingFactor, poly::evals::EvaluationsList, - whir::constraints::statement::{EqStatement, SelectStatement}, + whir::constraints::statement::{DomainStatement, EqStatement}, }; type F = BabyBear; @@ -233,8 +233,16 @@ mod tests { }); // Create select statement for the current domain size (20, then 15, then 10). - let mut sel_statement = SelectStatement::::initialize(num_vars_at_round); - (0..num_sel).for_each(|_| sel_statement.add_constraint(rng.random(), rng.random())); + let (indicies, evals): (Vec<_>, Vec<_>) = (0..num_sel) + .map(|_| { + ( + rng.random_range::(..1 << num_vars_at_round), + rng.random::(), + ) + }) + .unzip(); + let sel_statement = + DomainStatement::::new(num_vars_at_round, num_vars_at_round, &indicies, &evals); constraints.push(Constraint::new(gamma, eq_statement, sel_statement)); // Shrink the number of variables for the next round. @@ -246,7 +254,8 @@ mod tests { // Calculate W(r) using the function under test let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, None); - let result_from_eval_poly = evaluator.eval_constraints_poly(&constraints, &final_point); + let result_from_eval_poly = + evaluator.eval_constraints_poly::(&constraints, &final_point); // Calculate W(r) by materializing and evaluating round-by-round // This simpler, more direct method serves as our ground truth. @@ -257,7 +266,7 @@ mod tests { let num_vars = constraint.num_variables(); let mut combined = EvaluationsList::zero(num_vars); let mut eval = EF::ZERO; - constraint.combine(&mut combined, &mut eval); + constraint.combine::(&mut combined, &mut eval); let point = final_point.get_subpoint_over_range(0..num_vars).reversed(); combined.evaluate_hypercube_ext::(&point) }) @@ -333,8 +342,13 @@ mod tests { }); // Create select statement for the current domain size (20, then 15, then 10). - let mut sel_statement = SelectStatement::::initialize(num_vars_current); - (0..num_sel).for_each(|_| sel_statement.add_constraint(rng.random(), rng.random())); + let (indicies, evals):(Vec<_>, Vec<_>) = (0..num_sel) + .map(|_| (rng.random_range::(..1 << num_vars_current), rng.random::())) + .unzip(); + let sel_statement = DomainStatement::::new(num_vars_current, num_vars_current, &indicies, &evals); + + + constraints.push(Constraint::new(gamma, eq_statement, sel_statement)); // Shrink the number of variables for the next round. @@ -350,7 +364,7 @@ mod tests { // This is the recursive method we want to validate. let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, None); let result_from_eval_poly = - evaluator.eval_constraints_poly(&constraints, &final_point); + evaluator.eval_constraints_poly::(&constraints, &final_point); // Calculate W(r) by materializing and evaluating round-by-round // @@ -364,7 +378,7 @@ mod tests { let point = final_point.get_subpoint_over_range(0..num_vars_at_round).reversed(); let mut combined = EvaluationsList::zero(constraint.num_variables()); let mut eval = EF::ZERO; - constraint.combine(&mut combined, &mut eval); + constraint.combine::(&mut combined, &mut eval); num_vars_at_round -= folding_factors_vec[round_idx]; combined.evaluate_hypercube_ext::(&point) }) @@ -416,8 +430,16 @@ mod tests { }); // Create select statement for the current domain size (20, then 15, then 10). - let mut sel_statement = SelectStatement::::initialize(num_vars_at_round); - (0..num_sel).for_each(|_| sel_statement.add_constraint(rng.random(), rng.random())); + let (indicies, evals): (Vec<_>, Vec<_>) = (0..num_sel) + .map(|_| { + ( + rng.random_range::(..1 << num_vars_at_round), + rng.random::(), + ) + }) + .unzip(); + let sel_statement = + DomainStatement::::new(num_vars_at_round, num_vars_at_round, &indicies, &evals); constraints.push(Constraint::new(rng.random(), eq_statement, sel_statement)); // Shrink the number of variables for the next round. @@ -431,7 +453,8 @@ mod tests { // Calculate W(r) using the function under test let evaluator = ConstraintPolyEvaluator::new(num_vars, folding_factor, Some(K_SKIP_SUMCHECK)); - let result_from_eval_poly = evaluator.eval_constraints_poly(&constraints, &final_point); + let result_from_eval_poly = + evaluator.eval_constraints_poly::(&constraints, &final_point); // Manually compute W(r) with explicit recursive evaluation let mut expected_result = EF::ZERO; @@ -476,7 +499,7 @@ mod tests { .map(|constraint| { let mut combined = EvaluationsList::zero(constraint.num_variables()); let mut eval = EF::ZERO; - constraint.combine(&mut combined, &mut eval); + constraint.combine::(&mut combined, &mut eval); let point = r_rest.get_subpoint_over_range(0..constraint.num_variables()); combined.evaluate_hypercube_ext::(&point.reversed()) }) @@ -553,8 +576,10 @@ mod tests { }); // Create select statement for the current domain size (20, then 15, then 10). - let mut sel_statement = SelectStatement::::initialize(num_vars_current); - (0..num_sel).for_each(|_| sel_statement.add_constraint(rng.random(), rng.random())); + let (indicies, evals):(Vec<_>, Vec<_>) = (0..num_sel) + .map(|_| (rng.random_range::(..1 << num_vars_current), rng.random::())) + .unzip(); + let sel_statement = DomainStatement::::new(num_vars_current, num_vars_current, &indicies, &evals); constraints.push(Constraint::new(rng.random(), eq_statement, sel_statement)); // Shrink the number of variables for the next round. @@ -569,7 +594,7 @@ mod tests { // Calculate W(r) using the function under test let result_from_eval_poly = - evaluator.eval_constraints_poly(&constraints, &final_point); + evaluator.eval_constraints_poly::(&constraints, &final_point); // Calculate W(r) by materializing and evaluating round-by-round @@ -608,7 +633,7 @@ mod tests { .map(|constraint| { let mut combined = EvaluationsList::zero(constraint.num_variables()); let mut eval = EF::ZERO; - constraint.combine(&mut combined, &mut eval); + constraint.combine::(&mut combined, &mut eval); let point = r_rest.get_subpoint_over_range(0..constraint.num_variables()); combined.evaluate_hypercube_ext::(&point.reversed()) }) diff --git a/src/whir/constraints/mod.rs b/src/whir/constraints/mod.rs index f9421073..6f9e380a 100644 --- a/src/whir/constraints/mod.rs +++ b/src/whir/constraints/mod.rs @@ -1,9 +1,9 @@ -use p3_field::{ExtensionField, Field, PackedValue}; +use p3_field::{ExtensionField, Field, PackedValue, TwoAdicField}; use p3_util::log2_strict_usize; use crate::{ poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::constraints::statement::{EqStatement, SelectStatement}, + whir::constraints::statement::{DomainStatement, EqStatement}, }; /// Constraint evaluation utilities. @@ -35,7 +35,7 @@ pub mod statement; /// S = Σ_{i=0}^{n_eq-1} γ^i · s_eq_i + Σ_{j=0}^{n_sel-1} γ^{n_eq+j} · s_sel_j /// ``` #[derive(Clone, Debug)] -pub struct Constraint> { +pub struct Constraint { /// Equality-based evaluation constraints of the form `p(z_i) = s_i`. /// /// Each constraint specifies a point `z_i` and expected evaluation `s_i`. @@ -45,7 +45,7 @@ pub struct Constraint> { /// /// Each constraint specifies a univariate value `z_j` that is expanded /// via the power map to create a multilinear evaluation point. - pub sel_statement: SelectStatement, + pub sel_statement: DomainStatement, /// Random challenge `γ` used for batching constraints. /// @@ -55,7 +55,7 @@ pub struct Constraint> { pub challenge: EF, } -impl> Constraint { +impl Constraint { /// Creates a new constraint combining equality and select statements. /// /// This constructor initializes a unified constraint system that batches both @@ -80,7 +80,7 @@ impl> Constraint { pub const fn new( challenge: EF, eq_statement: EqStatement, - sel_statement: SelectStatement, + sel_statement: DomainStatement, ) -> Self { // Verify that both statements have the same number of variables. // @@ -110,7 +110,7 @@ impl> Constraint { /// /// A `Constraint` with the given equality statement and an empty select statement. #[must_use] - pub const fn new_eq_only(challenge: EF, eq_statement: EqStatement) -> Self { + pub fn new_eq_only(challenge: EF, eq_statement: EqStatement) -> Self { // Extract the number of variables from the equality statement. let num_variables = eq_statement.num_variables(); @@ -119,7 +119,7 @@ impl> Constraint { Self::new( challenge, eq_statement, - SelectStatement::initialize(num_variables), + DomainStatement::initialize(num_variables, num_variables), ) } @@ -194,7 +194,10 @@ impl> Constraint { /// ```text /// eval += Σ_i γ^i · s_eq_i + Σ_j γ^{n_eq+j} · s_sel_j /// ``` - pub fn combine(&self, combined: &mut EvaluationsList, eval: &mut EF) { + pub fn combine(&self, combined: &mut EvaluationsList, eval: &mut EF) + where + EF: ExtensionField, + { // Combine equality constraints with accumulation enabled (INITIALIZED=true). // This adds the equality portion of W(X) to the existing values in `combined`. self.eq_statement @@ -231,11 +234,13 @@ impl> Constraint { /// ```text /// eval += Σ_i γ^i · s_eq_i + Σ_j γ^{n_eq+j} · s_sel_j /// ``` - pub fn combine_packed( + pub fn combine_packed( &self, combined: &mut EvaluationsList, eval: &mut EF, - ) { + ) where + EF: ExtensionField, + { // Combine equality constraints with accumulation enabled (INITIALIZED=true). // This adds the equality portion of W(X) to the existing values in `combined`. self.eq_statement @@ -262,7 +267,10 @@ impl> Constraint { /// /// Use this method when starting a new constraint combination. /// Use [`combine`](Self::combine) when accumulating multiple constraints. - pub fn combine_new(&self) -> (EvaluationsList, EF) { + pub fn combine_new(&self) -> (EvaluationsList, EF) + where + EF: ExtensionField, + { // Initialize fresh accumulators for the weight polynomial and expected evaluation. // The weight polynomial needs 2^k entries for the full Boolean hypercube. let mut combined = EvaluationsList::zero(self.num_variables()); @@ -301,7 +309,10 @@ impl> Constraint { /// /// Use this method when starting a new constraint combination. /// Use [`combine_packed`](Self::combine_packed) when accumulating multiple constraints. - pub fn combine_new_packed(&self) -> (EvaluationsList, EF) { + pub fn combine_new_packed(&self) -> (EvaluationsList, EF) + where + EF: ExtensionField, + { let k_pack = log2_strict_usize(F::Packing::WIDTH); let k = self.num_variables(); @@ -382,12 +393,15 @@ impl> Constraint { /// /// Challenge powers are skipped by `n_eq` to ensure select constraints /// use distinct powers from equality constraints. - pub fn iter_sels(&self) -> impl Iterator { + pub fn iter_sels(&self) -> impl Iterator + where + EF: ExtensionField, + { // Pair each select variable with its corresponding challenge power. // Powers start at γ^{n_eq} to avoid overlap with equality constraints. self.sel_statement - .vars .iter() + .map(|(var, _)| var) .zip(self.challenge.powers().skip(self.eq_statement.len())) } } @@ -428,12 +442,12 @@ mod tests { EqStatement::new_hypercube(vec![eq_point_0, eq_point_1], vec![eq_eval_0, eq_eval_1]); // Create a select statement with 1 constraint - let sel_var = F::from_u64(7); + let idx = 1; let sel_eval = EF::from_u64(30); - let sel_statement = SelectStatement::new(num_variables, vec![sel_var], vec![sel_eval]); + let sel_statement = DomainStatement::new(num_variables, num_variables, &[idx], &[sel_eval]); // Construct the combined constraint - let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); + let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); // Verify that the constraint was constructed with correct fields assert_eq!(constraint.challenge, challenge); @@ -455,9 +469,10 @@ mod tests { // Select statement with 2 variables (different!) let num_variables_sel = 2; - let sel_var = F::from_u64(7); + let idx = 1; let sel_eval = EF::from_u64(30); - let sel_statement = SelectStatement::new(num_variables_sel, vec![sel_var], vec![sel_eval]); + let sel_statement = + DomainStatement::new(num_variables_sel, num_variables_sel, &[idx], &[sel_eval]); // Random challenge let challenge = EF::from_u64(42); @@ -489,7 +504,7 @@ mod tests { ); // Create constraint with only equality constraints - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // Verify that the select statement is empty assert_eq!(constraint.sel_statement.len(), 0); @@ -515,10 +530,10 @@ mod tests { // Create empty statements with the specified number of variables let eq_statement = EqStatement::initialize(num_variables); - let sel_statement = SelectStatement::initialize(num_variables); + let sel_statement = DomainStatement::initialize(num_variables, num_variables); // Create constraint - let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); + let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); // Verify that num_variables returns the correct value assert_eq!(constraint.num_variables(), num_variables); @@ -549,12 +564,12 @@ mod tests { // Create select statement with 1 constraint // Constraint 2: p(z_2) = 11, weighted by γ^2 = 4 - let sel_var = F::from_u64(3); + let idx = 1; let sel_eval = EF::from_u64(11); - let sel_statement = SelectStatement::new(num_variables, vec![sel_var], vec![sel_eval]); + let sel_statement = DomainStatement::new(num_variables, num_variables, &[idx], &[sel_eval]); // Create constraint - let constraint: Constraint = Constraint::new(gamma, eq_statement, sel_statement); + let constraint: Constraint = Constraint::new(gamma, eq_statement, sel_statement); // Initialize accumulator let mut eval = EF::ZERO; @@ -582,7 +597,7 @@ mod tests { let eq_point = MultilinearPoint::new(vec![EF::from_u64(1), EF::from_u64(1)]); let eq_eval = EF::from_u64(10); let eq_statement = EqStatement::new_hypercube(vec![eq_point], vec![eq_eval]); - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // Start with a non-zero accumulator let initial_value = EF::from_u64(100); @@ -611,10 +626,10 @@ mod tests { let eq_statement = EqStatement::new_hypercube(vec![eq_point], vec![eq_eval]); // Create constraint (eq-only for simplicity) - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // Combine into fresh accumulators - let (combined, eval) = constraint.combine_new(); + let (combined, eval) = constraint.combine_new::(); // Verify that the combined weight polynomial has the correct size // Should have 2^num_variables = 4 entries @@ -643,15 +658,15 @@ mod tests { let eq_point = MultilinearPoint::new(vec![EF::ZERO, EF::ONE]); let eq_eval = EF::from_u64(15); let eq_statement = EqStatement::new_hypercube(vec![eq_point], vec![eq_eval]); - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // Method 1: Use combine_new - let (combined_new, eval_new) = constraint.combine_new(); + let (combined_new, eval_new) = constraint.combine_new::(); // Method 2: Use combine with fresh accumulators let mut combined_manual = EvaluationsList::zero(num_variables); let mut eval_manual = EF::ZERO; - constraint.combine(&mut combined_manual, &mut eval_manual); + constraint.combine::(&mut combined_manual, &mut eval_manual); // Verify that both methods produce identical results assert_eq!(combined_new.0.len(), combined_manual.0.len()); @@ -673,7 +688,7 @@ mod tests { MultilinearPoint::new(vec![EF::from_u64(1), EF::from_u64(2), EF::from_u64(3)]); let eq_eval = EF::from_u64(99); let eq_statement = EqStatement::new_hypercube(vec![eq_point], vec![eq_eval]); - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // This should not panic because select statement is empty constraint.validate_for_skip_case(); @@ -694,11 +709,11 @@ mod tests { let eq_statement = EqStatement::initialize(num_variables); // Add a select constraint (this makes it invalid for skip) - let sel_var = F::from_u64(5); + let idx = 1; let sel_eval = EF::from_u64(25); - let sel_statement = SelectStatement::new(num_variables, vec![sel_var], vec![sel_eval]); + let sel_statement = DomainStatement::new(num_variables, num_variables, &[idx], &[sel_eval]); - let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); + let constraint: Constraint = Constraint::new(challenge, eq_statement, sel_statement); // This should panic because select statement is not empty constraint.validate_for_skip_case(); @@ -730,7 +745,7 @@ mod tests { ); // Create constraint - let constraint: Constraint = Constraint::new_eq_only(gamma, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(gamma, eq_statement); // Collect iterator results let results: Vec<_> = constraint.iter_eqs().collect(); @@ -771,21 +786,22 @@ mod tests { // Create select statement with 2 constraints // These should use challenge powers γ^2 and γ^3 - let sel_var_0 = F::from_u64(5); + let sel_var_0 = 0; let sel_eval_0 = EF::from_u64(30); - let sel_var_1 = F::from_u64(6); + let sel_var_1 = 1; let sel_eval_1 = EF::from_u64(40); - let sel_statement = SelectStatement::new( + let sel_statement = DomainStatement::new( num_variables, - vec![sel_var_0, sel_var_1], - vec![sel_eval_0, sel_eval_1], + num_variables, + &[sel_var_0, sel_var_1], + &[sel_eval_0, sel_eval_1], ); // Create constraint - let constraint: Constraint = Constraint::new(gamma, eq_statement, sel_statement); + let constraint: Constraint = Constraint::new(gamma, eq_statement, sel_statement); // Collect iterator results - let results: Vec<_> = constraint.iter_sels().collect(); + let results: Vec<_> = constraint.iter_sels::().collect(); // Verify that we have 2 pairs assert_eq!(results.len(), 2); @@ -795,12 +811,13 @@ mod tests { let expected_weights = [EF::from_u64(4), EF::from_u64(8)]; let expected_vars = [sel_var_0, sel_var_1]; + let omega = F::two_adic_generator(num_variables); for (i, (var, coeff)) in results.iter().enumerate() { // Verify challenge weight assert_eq!(*coeff, expected_weights[i]); // Verify variable reference matches original - assert_eq!(**var, expected_vars[i]); + assert_eq!(*var, omega.exp_u64(expected_vars[i] as u64)); } } @@ -815,9 +832,9 @@ mod tests { let eq_point = MultilinearPoint::new(vec![EF::from_u64(1), EF::from_u64(2)]); let eq_eval = EF::from_u64(10); let eq_statement = EqStatement::new_hypercube(vec![eq_point], vec![eq_eval]); - let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); + let constraint: Constraint = Constraint::new_eq_only(challenge, eq_statement); // Verify that the iterator is empty - assert_eq!(constraint.iter_sels().count(), 0); + assert_eq!(constraint.iter_sels::().count(), 0); } } diff --git a/src/whir/constraints/statement/domain.rs b/src/whir/constraints/statement/domain.rs new file mode 100644 index 00000000..bf2ac725 --- /dev/null +++ b/src/whir/constraints/statement/domain.rs @@ -0,0 +1,1021 @@ +use alloc::vec::Vec; + +use itertools::Itertools; +use p3_dft::{Radix2DFTSmallBatch, TwoAdicSubgroupDft}; +use p3_field::{ + ExtensionField, Field, PackedFieldExtension, PackedValue, TwoAdicField, dot_product, +}; +use p3_maybe_rayon::prelude::*; +use p3_util::log2_strict_usize; +use tracing::instrument; + +use crate::poly::evals::EvaluationsList; + +/// A batched system of domain-indexed evaluation constraints for multilinear polynomials. +/// +/// This struct represents a collection of evaluation constraints of the form +/// `p(w^{idx_i}) = s_i` for a multilinear polynomial `p` over the Boolean hypercube +/// `{0,1}^k`, where `w` is a two-adic generator of a domain of size `2^{k_domain}` +/// and each constraint is specified by a domain index `idx_i`. The domain size is chosen +/// so that `k_domain >= k`, matching the sumcheck polynomial's variable count. +/// +/// # Domain Specialization +/// +/// The constraints are a special case of `select`-based statements where the points are +/// restricted to the multiplicative subgroup generated by `w`. For `b in {0,1}^k`, +/// `int(b)` interprets the Boolean vector as an integer, and: +/// +/// ```text +/// select(pow(w^{idx}), b) = (w^{idx})^{int(b)} = w^{idx * int(b)}. +/// ``` +/// +/// Because all evaluation points lie on the same subgroup, the batched weights can be +/// computed via a DFT over the domain instead of combining per-point powers. +/// +/// # Verification Claims +/// +/// Each constraint `(idx_i, s_i)` in this statement asserts: +/// +/// ```text +/// sum_{b in {0,1}^k} P(b) * w^{idx_i * int(b)} = s_i +/// ``` +/// +/// where `P(b)` are the evaluations of the polynomial over the Boolean hypercube. +/// +/// # Batching +/// +/// Multiple constraints are batched using random challenge `alpha` to produce: +/// +/// - **Weight polynomial**: `W(b) = sum_i alpha^i * w^{idx_i * int(b)}` +/// - **Target sum**: `S = sum_i alpha^i * s_i` +/// +/// This reduces `n` separate verification claims to a single sumcheck: +/// +/// ```text +/// sum_{b in {0,1}^k} P(b) * W(b) = S +/// ``` +#[derive(Clone, Debug)] +pub struct DomainStatement { + /// Number of variables `k` defining the Boolean hypercube `{0,1}^k`. + /// + /// This determines the dimension of the multilinear polynomial space and the size + /// of the evaluation domain (2^k points). + num_variables: usize, + + /// `k_domain` is log2 of the domain size `2^k_domain`, and must satisfy `k_domain >= k`. + k_domain: usize, + + /// Evaluation points `[w^{idx_0}, w^{idx_1}, ..., w^{idx_{n-1}}]` encoded by domain indices. + /// + /// Each index `idx_i` maps to the subgroup element `w^{idx_i}` where `w` is the + /// two-adic generator for the `2^k_domain` domain. + indicies: Vec, + + /// Expected evaluation values `[s_1, s_2, ..., s_n]` corresponding to each constraint. + /// + /// Each `s_i` in `EF` is an extension field element representing the claimed evaluation + /// of the polynomial at point `w^{idx_i}`. + evaluations: Vec, +} + +impl DomainStatement { + /// Creates an empty domain statement for polynomials over `{0,1}^k`. + /// + /// # Parameters + /// + /// - `num_variables`: The dimension `k` of the Boolean hypercube + /// - `k_domain`: log2 of the domain size for domain indices + /// + /// # Returns + /// + /// An initialized statement with no constraints, ready to accept constraints. + #[must_use] + pub fn initialize(num_variables: usize, k_domain: usize) -> Self { + assert!(num_variables <= k_domain); + Self { + k_domain, + num_variables, + indicies: Vec::new(), + evaluations: Vec::new(), + } + } + + /// Creates a domain statement pre-populated with constraints. + /// + /// # Parameters + /// + /// - `num_variables`: The dimension `k` of the Boolean hypercube + /// - `k_domain`: log2 of the domain size for domain indices + /// - `indicies`: Domain indices `[idx_0, ..., idx_{n-1}]` selecting points `w^{idx_i}` + /// - `evaluations`: Expected values `[s_1, ..., s_n]` + /// + /// # Panics + /// + /// Panics if the number of variables and evaluations do not match. + /// Panics if an index is larger than domain. + #[must_use] + pub fn new( + num_variables: usize, + k_domain: usize, + indicies: &[usize], + evaluations: &[EF], + ) -> Self { + assert!(num_variables <= k_domain); + assert!(indicies.len() == evaluations.len()); + assert!(indicies.iter().all(|&index| index < (1 << k_domain))); + // Remove duplicates + let (indicies, evaluations): (Vec<_>, _) = indicies + .iter() + .copied() + .zip(evaluations.iter().copied()) + .sorted_by(|(i0, _), (i1, _)| i0.cmp(i1)) + .dedup_by(|(i0, _), (i1, _)| i0 == i1) + .unzip(); + + Self { + num_variables, + k_domain, + indicies, + evaluations, + } + } + + /// Returns the number of variables `k` defining the polynomial space dimension. + /// + /// This is the dimension of the Boolean hypercube `{0,1}^k` over which polynomials + /// are defined, containing `2^k` evaluation points. + #[must_use] + pub const fn num_variables(&self) -> usize { + self.num_variables + } + + /// Returns `true` if no constraints have been added to this statement. + #[must_use] + pub const fn is_empty(&self) -> bool { + debug_assert!(self.indicies.is_empty() == self.evaluations.is_empty()); + self.indicies.is_empty() + } + + /// Returns an iterator over constraint pairs `(w^{idx_i}, s_i)`. + /// + /// Each pair represents one evaluation constraint: `p(w^{idx_i}) = s_i`. + pub fn iter(&self) -> impl Iterator + where + EF: ExtensionField, + { + if self.is_empty() { + itertools::Either::Left(core::iter::empty()) + } else { + let w = F::two_adic_generator(self.k_domain); + itertools::Either::Right( + self.indicies + .iter() + .map(move |&index| w.exp_u64(index as u64)) + .zip(self.evaluations.iter()), + ) + } + } + + /// Returns the number of evaluation constraints `n` in this statement. + #[must_use] + pub const fn len(&self) -> usize { + debug_assert!(self.indicies.len() == self.evaluations.len()); + self.indicies.len() + } + + /// Verifies that a given polynomial satisfies all constraints in the statement. + /// + /// For each constraint `(w^{idx_i}, s_i)`, this method interprets the evaluation table as + /// coefficients of a univariate polynomial, evaluates it at `w^{idx_i}` using Horner's method, + /// and checks if the result equals the expected value `s_i`. + /// + /// For a polynomial represented by evaluations `[c_0, c_1, ..., c_{2^k-1}]`: + /// + /// ```text + /// p(z) = c_0 + z(c_1 + z(c_2 + z(...))) + /// ``` + /// + /// This is computed right-to-left as: + /// ```text + /// acc = 0 + /// for i = 2^k-1 down to 0: + /// acc = acc * z + c_i + /// ``` + /// + /// # Parameters + /// + /// - `poly`: Evaluation table treated as univariate polynomial coefficients + /// + /// # Returns + /// + /// `true` if all constraints are satisfied, `false` otherwise. + #[must_use] + pub fn verify(&self, poly: &EvaluationsList) -> bool + where + EF: ExtensionField, + { + self.iter::().all(|(var, &expected_eval)| { + // Evaluate the polynomial at `var` using Horner's method. + // This computes: p(var) = c_0 + var(c_1 + var(c_2 + ...)) + poly.iter() + .rfold(EF::ZERO, |result, coeff| result * var + *coeff) + == expected_eval + }) + } + + /// Adds a single evaluation constraint `p(w^{idx}) = s` to the statement. + /// + /// # Parameters + /// + /// - `index`: Domain index `idx` selecting `w^{idx}` + /// - `eval`: Expected evaluation value `s` in `EF` + pub fn add_constraint(&mut self, index: usize, eval: EF) { + assert!(index < (1 << self.k_domain)); + if !self.indicies.contains(&index) { + self.evaluations.push(eval); + self.indicies.push(index); + } + } + + /// Batches all constraints into a single weighted polynomial and target sum for sumcheck. + /// + /// Given constraints `p(w^{idx_0}) = s_0, ..., p(w^{idx_{n-1}}) = s_{n-1}`, + /// this method transforms them into a single sumcheck claim using random challenge `alpha`: + /// + /// ```text + /// sum_{b in {0,1}^k} P(b) * W(b) = S + /// ``` + /// + /// where: + /// - **Weight polynomial**: `W(b) = sum_i alpha^{i+shift} * w^{idx_i * int(b)}` + /// - **Target sum**: `S = sum_i alpha^{i+shift} * s_i` + /// + /// The method computes `W(b)` for all `b in {0,1}^k` via a DFT over the `2^k_domain` + /// subgroup and adds the result, along with `S`, to the provided accumulators. + /// + /// # Parameters + /// + /// - `weights`: Accumulator for the weight polynomial `W(b)`. Must have `2^k` entries. + /// This method **adds** the batched weights to existing values. + /// + /// - `eval`: Accumulator for the target sum `S`. This method **adds** the batched + /// evaluations to the existing value. + /// + /// - `alpha`: Random challenge used for batching. + /// + /// - `shift`: Power offset for the challenge. Constraint `i` uses weight `alpha^{i+shift}`. + /// Allows multiple statement types to use non-overlapping challenge powers. + #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))] + pub fn combine( + &self, + weights: &mut EvaluationsList, + eval: &mut EF, + alpha: EF, + shift: usize, + ) where + EF: ExtensionField, + { + if self.indicies.is_empty() { + return; + } + + self.combine_evals(eval, alpha, shift); + + let dft = Radix2DFTSmallBatch::::default(); + let mut sparse = EF::zero_vec(1 << self.k_domain); + self.indicies + .iter() + .zip(alpha.powers().skip(shift)) + .for_each(|(&index, challenge)| sparse[index] = challenge); + let pows_combined = dft.dft_algebra(sparse); + + weights + .0 + .par_iter_mut() + .zip(pows_combined.par_iter()) + .for_each(|(out, &val)| *out += val); + } + + #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))] + pub fn combine_packed( + &self, + weights: &mut EvaluationsList, + eval: &mut EF, + challenge: EF, + shift: usize, + ) where + EF: ExtensionField, + { + if self.indicies.is_empty() { + return; + } + + let k = self.num_variables(); + let k_pack = log2_strict_usize(F::Packing::WIDTH); + assert!(k >= k_pack); + assert_eq!(weights.num_variables() + k_pack, k); + + self.combine_evals(eval, challenge, shift); + + let dft = Radix2DFTSmallBatch::::default(); + let mut sparse = EF::zero_vec(1 << self.k_domain); + self.indicies + .iter() + .zip(challenge.powers().skip(shift)) + .for_each(|(&index, challenge)| sparse[index] = challenge); + let pows_combined = dft.dft_algebra(sparse); + + weights + .0 + .par_iter_mut() + .zip(pows_combined.par_chunks(F::Packing::WIDTH)) + .for_each(|(out, chunk)| *out += EF::ExtensionPacking::from_ext_slice(chunk)); + } + + /// Batches expected evaluation values into a single target sum using challenge powers. + /// + /// Computes and adds to `claimed_eval`: + /// + /// ```text + /// S = sum_i alpha^{i+shift} * s_i + /// ``` + /// + /// where `s_i` are the expected evaluation values in `self.evaluations`. + /// + /// # Parameters + /// + /// - `claimed_eval`: Accumulator for the target sum. This method **adds** the batched + /// evaluations to the existing value. + /// + /// - `challenge`: Random challenge `alpha` used for batching. + /// + /// - `shift`: Power offset. Constraint `i` uses weight `alpha^{i+shift}`. + pub fn combine_evals(&self, claimed_eval: &mut EF, challenge: EF, shift: usize) { + // Compute: Σ_i γ^{i+shift} · s_i + // This is equivalent to dot_product(evaluations, [γ^shift, γ^{shift+1}, ...]) + *claimed_eval += dot_product::( + self.evaluations.iter().copied(), + challenge.powers().skip(shift).take(self.len()), + ); + } +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use p3_baby_bear::BabyBear; + use p3_field::{ + PackedFieldExtension, PrimeCharacteristicRing, extension::BinomialExtensionField, + }; + use proptest::prelude::*; + use rand::{SeedableRng, rngs::SmallRng}; + + use super::*; + + type F = BabyBear; + type EF = BinomialExtensionField; + + #[test] + fn test_domain_statement_initialize() { + // Test that initialize creates an empty statement with correct num_variables. + let statement = DomainStatement::::initialize(3, 3); + + // The statement should have 3 variables. + assert_eq!(statement.num_variables(), 3); + // The statement should be empty (no constraints). + assert!(statement.is_empty()); + // The length should be 0. + assert_eq!(statement.len(), 0); + } + + #[test] + fn test_domain_statement_new() { + // Test that new creates a statement with pre-populated constraints. + let indicies = vec![0, 1]; + let evaluations = vec![F::from_u64(10), F::from_u64(20)]; + + let statement = DomainStatement::new(2, 2, &indicies, &evaluations); + + // The statement should have 2 variables. + assert_eq!(statement.num_variables(), 2); + // The statement should not be empty. + assert!(!statement.is_empty()); + // The statement should have 2 constraints. + assert_eq!(statement.len(), 2); + // The vars and evaluations should match. + assert_eq!(statement.indicies, indicies); + assert_eq!(statement.evaluations, evaluations); + } + + #[test] + #[should_panic(expected = "assertion")] + fn test_domain_statement_new_mismatched_lengths() { + // Test that new panics when vars.len() != evaluations.len(). + let indicies = vec![5]; + let evaluations = vec![F::from_u64(10), F::from_u64(20)]; + + // This should panic due to length mismatch. + let _ = DomainStatement::new(2, 2, &indicies, &evaluations); + } + + #[test] + fn test_domain_statement_add_constraint() { + // Test adding constraints one at a time. + let num_variables = 10; + let w = F::two_adic_generator(num_variables); + let mut statement = DomainStatement::::initialize(num_variables, num_variables); + + // Initially empty. + assert!(statement.is_empty()); + assert_eq!(statement.len(), 0); + + // Add first constraint: p(5) = 10. + statement.add_constraint(5, F::from_u64(10)); + assert!(!statement.is_empty()); + assert_eq!(statement.len(), 1); + + // Add second constraint: p(7) = 20. + statement.add_constraint(7, F::from_u64(20)); + assert_eq!(statement.len(), 2); + + // Verify the constraints were added correctly. + let constraints: Vec<_> = statement.iter().collect(); + assert_eq!(constraints.len(), 2); + assert_eq!(constraints[0].0, w.exp_u64(5)); + assert_eq!(*constraints[0].1, F::from_u64(10)); + assert_eq!(constraints[1].0, w.exp_u64(7)); + assert_eq!(*constraints[1].1, F::from_u64(20)); + } + + #[test] + fn test_domain_statement_verify_basic() { + // Test the verify method with a simple polynomial. + // + // Create a polynomial with evaluations [c0, c1, c2, c3] over {0,1}^2. + let c0 = F::from_u64(1); + let c1 = F::from_u64(2); + let c2 = F::from_u64(3); + let c3 = F::from_u64(4); + let poly = EvaluationsList::new(vec![c0, c1, c2, c3]); + + // Create a statement with k=2 variables. + let k = 2; + let w = F::two_adic_generator(k); + let mut statement = DomainStatement::::initialize(k, k); + + // The polynomial evaluations [c0, c1, c2, c3] can be interpreted as a univariate polynomial: + // p(z) = c0 + c1*z + c2*z^2 + c3*z^3 + // + // Test p(0) = c0 = 1. + let eval0 = c0 + c1 + c2 + c3; + statement.add_constraint(0, eval0); + assert!(statement.verify(&poly)); + + // Test p(1) = c0 + c1 + c2 + c3 + let mut statement2 = DomainStatement::::initialize(k, k); + let z1 = w; + let eval1 = c0 + c1 * z1 + c2 * z1 * z1 + c3 * z1 * z1 * z1; + statement2.add_constraint(1, eval1); + assert!(statement2.verify(&poly)); + + // Test p(2) = c0 + c1*2 + c2*4 + c3*8 + let mut statement3 = DomainStatement::::initialize(k, k); + let z2 = w * w; + let eval2 = c0 + c1 * z2 + c2 * z2 * z2 + c3 * z2 * z2 * z2; + statement3.add_constraint(2, eval2); + assert!(statement3.verify(&poly)); + + // Test a failing verification: p(1) = wrong_eval + let mut statement4 = DomainStatement::::initialize(k, k); + let wrong_eval = F::from_u64(56765); + statement4.add_constraint(3, wrong_eval); + assert!(!statement4.verify(&poly)); + } + + #[test] + fn test_domain_statement_combine_single_constraint() { + // Test combining a single constraint. + // + // For k=2 variables, we have a 2^2 = 4-point domain. + let k = 2; + let domain_size = 1 << k; + let w = F::two_adic_generator(k); + + // Create a statement with one constraint: p(z) = s. + let mut statement = DomainStatement::::initialize(k, k); + let i = 2; + let z = w.exp_u64(i as u64); + let s = F::from_u64(100); + statement.add_constraint(i, s); + + // The challenge γ is unused for a single constraint (it would multiply by γ^0 = 1). + let gamma = F::from_u64(2); + let shift = 0; + + // Initialize accumulators. + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + + // Combine the constraints. + statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The target sum should be S = γ^0 · s = 1 · s = s. + let expected_sum = s; + assert_eq!(acc_sum, expected_sum); + + // The weight polynomial should be W(b) = select(pow(z), b) for all b ∈ {0,1}^k. + // + // Verify each entry manually using the property: select(pow(z), b) = z^b. + for b in 0..domain_size { + let expected_weight = z.exp_u64(b as u64); + assert_eq!( + acc_weights.as_slice()[b], + expected_weight, + "Weight mismatch at index {b}" + ); + } + } + + #[test] + fn test_domain_statement_combine_multiple_constraints() { + // Test combining multiple constraints with batching. + // + // For k=2 variables, we have a 2^2 = 4-point domain. + let k = 2; + let domain_size = 1 << k; + let w = F::two_adic_generator(k); + + // Create a statement with two constraints: + // - Constraint 0: p(z0) = s0 + // - Constraint 1: p(z1) = s1 + let mut statement = DomainStatement::::initialize(k, k); + let i0 = 1; + let z0 = w.exp_u64(i0 as u64); + let s0 = F::from_u64(10); + let i1 = 2; + let z1 = w.exp_u64(i1 as u64); + let s1 = F::from_u64(20); + statement.add_constraint(i0, s0); + statement.add_constraint(i1, s1); + + // Use challenge γ for batching. + let gamma = F::from_u64(2); + let shift = 0; + + // Initialize accumulators. + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + + // Combine the constraints. + statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The target sum should be: + // S = γ^0 · s0 + γ^1 · s1 = 1·s0 + γ·s1 = s0 + gamma*s1. + let expected_sum = s0 + gamma * s1; + assert_eq!(acc_sum, expected_sum); + + // The weight polynomial should be: + // W(b) = γ^0 · select(pow(z0), b) + γ^1 · select(pow(z1), b) + // = select(pow(z0), b) + gamma · select(pow(z1), b) + // Using the property: select(pow(z), b) = z^b. + for b in 0..domain_size { + let weight0 = z0.exp_u64(b as u64); + let weight1 = z1.exp_u64(b as u64); + let expected_weight = weight0 + gamma * weight1; + assert_eq!( + acc_weights.as_slice()[b], + expected_weight, + "Weight mismatch at index {b}" + ); + } + } + + #[test] + fn test_domain_statement_combine_with_shift() { + // Test combining constraints with a non-zero shift parameter. + // + // The shift parameter allows multiple statement types to use non-overlapping + // challenge powers for batching. + let k = 1; + let domain_size = 1 << k; + let w = F::two_adic_generator(k); + + // Create a statement with one constraint: p(z) = s. + let mut statement = DomainStatement::::initialize(k, k); + let i = 1; + let z = w.exp_u64(i as u64); + let s = F::from_u64(100); + statement.add_constraint(i, s); + + // Use challenge γ with shift. + // This means the constraint will be weighted by γ^{0+shift} = γ^shift. + let gamma = F::from_u64(2); + let shift = 3; + + // Initialize accumulators. + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + + // Combine the constraints. + statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The target sum should be S = γ^shift · s. + let gamma_to_shift = gamma.exp_u64(shift as u64); + let expected_sum = gamma_to_shift * s; + assert_eq!(acc_sum, expected_sum); + + // The weight polynomial should be W(b) = γ^shift · select(pow(z), b). + // Using the property: select(pow(z), b) = z^b. + for b in 0..domain_size { + let select_val = z.exp_u64(b as u64); + let expected_weight = gamma_to_shift * select_val; + assert_eq!( + acc_weights.as_slice()[b], + expected_weight, + "Weight mismatch at index {b}" + ); + } + } + + #[test] + fn test_domain_statement_combine_empty() { + // Test that combining an empty statement does nothing. + let k = 2; + let statement = DomainStatement::::initialize(k, k); + + // Initialize accumulators with non-zero values. + let w0 = F::from_u64(1); + let w1 = F::from_u64(2); + let w2 = F::from_u64(3); + let w3 = F::from_u64(4); + let mut acc_weights = EvaluationsList::new(vec![w0, w1, w2, w3]); + let initial_sum = F::from_u64(99); + let mut acc_sum = initial_sum; + + // Store original values. + let original_weights = acc_weights.clone(); + let original_sum = acc_sum; + + // Combine the empty statement. + let gamma = F::from_u64(2); + let shift = 0; + statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The accumulators should remain unchanged. + assert_eq!(acc_weights, original_weights); + assert_eq!(acc_sum, original_sum); + } + + #[test] + fn test_domain_statement_combine_accumulation() { + // Test that combine properly accumulates (adds to) existing values. + // + // This is important for batching multiple statements together. + let k = 1; + let domain_size = 1 << k; + let w = F::two_adic_generator(k); + + // Create first statement with constraint p(z1) = s1. + let mut statement1 = DomainStatement::::initialize(k, k); + let i1 = 0; + let s1 = F::from_u64(5); + statement1.add_constraint(i1, s1); + + // Create second statement with constraint p(z2) = s2. + let mut statement2 = DomainStatement::::initialize(k, k); + let i2 = 1; + let z2 = w.exp_u64(i2 as u64); + let s2 = F::from_u64(7); + statement2.add_constraint(i2, s2); + + let gamma = F::from_u64(2); + let shift = 0; + + // Initialize accumulators. + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + + // Combine first statement. + statement1.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // Store intermediate values. + let intermediate_weights = acc_weights.clone(); + let intermediate_sum = acc_sum; + + // Combine second statement (should add to existing values). + statement2.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The accumulated sum should be intermediate_sum + s2. + let expected_sum = intermediate_sum + s2; + assert_eq!(acc_sum, expected_sum); + + // The accumulated weights should be the sum of both select functions. + // Using the property: select(pow(z), b) = z^b. + for b in 0..domain_size { + let weight2 = z2.exp_u64(b as u64); + let expected_weight = intermediate_weights.as_slice()[b] + weight2; + assert_eq!( + acc_weights.as_slice()[b], + expected_weight, + "Accumulated weight mismatch at index {b}" + ); + } + } + + #[test] + fn test_domain_statement_combine_evals() { + // Test the combine_evals method. + let k = 2; + + // Create a statement with two constraints. + let mut statement = DomainStatement::::initialize(k, k); + let s0 = F::from_u64(10); + let s1 = F::from_u64(20); + statement.add_constraint(0, s0); + statement.add_constraint(1, s1); + + let gamma = F::from_u64(2); + let shift = 1; + + // Test combine_evals. + let mut claimed_eval = F::ZERO; + statement.combine_evals(&mut claimed_eval, gamma, shift); + + // Expected: S = γ^{shift} · s0 + γ^{shift+1} · s1 = γ^1·s0 + γ^2·s1. + let gamma_1 = gamma.exp_u64(shift as u64); + let gamma_2 = gamma.exp_u64((shift + 1) as u64); + let expected = gamma_1 * s0 + gamma_2 * s1; + assert_eq!(claimed_eval, expected); + } + + #[test] + fn test_domain_statement_combine_evals_accumulation() { + // Test that combine_evals properly accumulates. + let k = 1; + + let mut statement = DomainStatement::::initialize(k, k); + let s = F::from_u64(10); + statement.add_constraint(0, s); + + let gamma = F::from_u64(3); + let shift = 0; + + // Start with a non-zero claimed_eval. + let initial_eval = F::from_u64(42); + let mut claimed_eval = initial_eval; + + // Combine evals should add to the existing value. + statement.combine_evals(&mut claimed_eval, gamma, shift); + + // Expected: initial_eval + γ^0 · s = initial_eval + 1·s = initial_eval + s. + let expected = initial_eval + s; + assert_eq!(claimed_eval, expected); + } + + #[test] + fn test_domain_combine_consistency_with_verify() { + // Test that combine and verify are consistent. + // + // If we create a polynomial that satisfies the constraints, then: + // 1. verify() should return true + // 2. The combined weights should correctly compute the polynomial evaluations + let k = 2; + let domain_size = 1 << k; + let w = F::two_adic_generator(k); + + // Create a simple polynomial: evaluations [c0, c1, c2, c3]. + let c0 = F::from_u64(1); + let c1 = F::from_u64(2); + let c2 = F::from_u64(3); + let c3 = F::from_u64(4); + let poly = EvaluationsList::new(vec![c0, c1, c2, c3]); + + // Create constraints that match the polynomial. + // Using Horner evaluation: p(z) = c0 + c1*z + c2*z^2 + c3*z^3. + let mut statement = DomainStatement::::initialize(k, k); + + // Evaluate p(z) at z using Horner's method. + // let z = F::from_u64(2); + let i = 1; + let z = w.exp_u64(i as u64); + let expected_eval = poly + .iter() + .rfold(F::ZERO, |result, &coeff| result * z + coeff); + statement.add_constraint(i, expected_eval); + + // Verify should pass. + assert!(statement.verify(&poly)); + + // Now combine and check that the weight polynomial correctly represents + // the select function. + let gamma = F::from_u64(3); + let shift = 0; + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift); + + // The sum should match the expected evaluation. + assert_eq!(acc_sum, expected_eval); + + // The weight polynomial should satisfy: + // Σ_{b ∈ {0,1}^k} poly(b) · W(b) = expected_eval + let mut computed_sum = F::ZERO; + for b in 0..domain_size { + computed_sum += poly.as_slice()[b] * acc_weights.as_slice()[b]; + } + assert_eq!(computed_sum, expected_eval); + } + + proptest! { + #[test] + fn prop_select_statement_combine_sum( + // Number of variables (1 to 4 for reasonable test size). + k in 1usize..=4, + // Number of constraints (1 to 5). + num_constraints in 1usize..=5, + // Random evaluation points (avoiding 0 for better coverage). + // Generate exactly num_constraints values. + z_indicies in prop::collection::vec(1u32..100, 1..=5), + // Random expected evaluations. + s_values in prop::collection::vec(0u32..100, 1..=5), + // Random challenge. + challenge in 1u32..50, + ) { + // Ensure we have enough values for the test. + let actual_num_constraints = num_constraints.min(z_indicies.len()).min(s_values.len()); + if actual_num_constraints == 0 { + return Ok(()); + } + + let z_indicies = &z_indicies[..actual_num_constraints]; + let z_indicies = z_indicies.iter().map(|&z| z as usize % (1 << k)).sorted().dedup().collect::>(); + let s_values = &s_values[..z_indicies.len()]; + + // Create statement with random constraints. + let mut statement = DomainStatement::::initialize(k, k); + for (&z, &s) in z_indicies.iter().zip(s_values.iter()) { + statement.add_constraint(z, F::from_u32(s)); + } + + let gamma = F::from_u32(challenge); + + // Combine with shift=0. + let mut acc_weights = EvaluationsList::zero(k); + let mut acc_sum = F::ZERO; + statement.combine(&mut acc_weights, &mut acc_sum, gamma, 0); + + // Compute expected sum manually: S = Σ_i γ^i · s_i. + let mut expected_sum = F::ZERO; + for (i, &s) in s_values.iter().enumerate() { + expected_sum += gamma.exp_u64(i as u64) * F::from_u32(s); + } + + prop_assert_eq!(acc_sum, expected_sum); + } + } + + proptest! { + #[test] + fn prop_select_statement_verify( + // Polynomial evaluations (2^k values for k=3). + poly_evals in prop::collection::vec(0u32..100, 8), + // Evaluation point (avoiding 0 for better coverage). + i in 1u32..50, + ) { + let k = 3; // Fixed k=3 gives 2^3 = 8 evaluations. + let poly = EvaluationsList::new(poly_evals.into_iter().map(F::from_u32).collect()); + + // Compute expected evaluation using Horner's method. + let i = i as usize % (1 << k); + let w = F::two_adic_generator(k); + let z = w.exp_u64(i as u64); + let expected_eval = poly + .iter() + .rfold(F::ZERO, |result, &coeff| result * z + coeff); + + // Create statement with correct constraint. + let mut statement = DomainStatement::::initialize(k, k); + statement.add_constraint(i, expected_eval); + + // Verify should pass. + prop_assert!(statement.verify(&poly)); + + // Add a wrong constraint (off by 1, unless it wraps to same value). + let wrong_eval = expected_eval + F::ONE; + if wrong_eval != expected_eval { + statement.add_constraint((i+1) % (1 << k), wrong_eval); + // Verify should fail now. + prop_assert!(!statement.verify(&poly)); + } + } + } + + proptest! { + #[test] + fn prop_combine_evals_consistency( + // Number of constraints. + num_constraints in 1usize..=5, + // Random evaluations. + s_values in prop::collection::vec(0u32..100, 1..=5), + // Random challenge. + challenge in 1u32..50, + // Random shift. + shift in 0usize..3, + ) { + let num_variables = 2; + let s_values = &s_values[..num_constraints.min(s_values.len()).min(1<::initialize(num_variables, num_variables); + for (i, &s) in s_values.iter().enumerate() { + statement.add_constraint(i, F::from_u32(s)); + } + + let gamma = F::from_u32(challenge); + + // Method 1: Use combine_evals. + let mut claimed_eval1 = F::ZERO; + statement.combine_evals(&mut claimed_eval1, gamma, shift); + + // Method 2: Compute manually. + let mut claimed_eval2 = F::ZERO; + for (i, &s) in s_values.iter().enumerate() { + claimed_eval2 += gamma.exp_u64((i + shift) as u64) * F::from_u32(s); + } + + prop_assert_eq!(claimed_eval1, claimed_eval2); + } + } + + fn combine_ref>( + out: &mut EvaluationsList, + statement: &DomainStatement, + alpha: EF, + shift: usize, + ) { + let k = statement.num_variables(); + statement + .iter::() + .zip(alpha.powers().skip(shift)) + .for_each(|((var, _), alpha)| { + EF::from(var) + .shifted_powers(alpha) + .take(1 << k) + .zip(out.0.iter_mut()) + .for_each(|(el, out)| *out += el); + }); + } + + #[test] + fn test_packed_combine() { + type PackedExt = >::ExtensionPacking; + + let mut rng = SmallRng::seed_from_u64(1); + let alpha: EF = rng.random(); + let k_pack = log2_strict_usize(::Packing::WIDTH); + + let mut shift = 0; + for k in k_pack..10 { + for rate in 0..3 { + let k_domain = k + rate; + let mut out0 = EvaluationsList::zero(k); + let mut out1 = EvaluationsList::zero(k); + let mut out_packed = EvaluationsList::::zero(k - k_pack); + let mut sum0 = EF::ZERO; + let mut sum1 = EF::ZERO; + for n in [1, 2, 10, 11] { + let indicies = (0..n) + .map(|_| rng.random_range::(0..1 << (k_domain))) + .sorted() + .dedup() + .collect::>(); + + let evals = (0..indicies.len()) + .map(|_| rng.random()) + .collect::>(); + + let statement = DomainStatement::::new(k, k_domain, &indicies, &evals); + combine_ref::(&mut out0, &statement, alpha, shift); + statement.combine::(&mut out1, &mut sum0, alpha, shift); + assert_eq!(out0, out1); + statement.combine_packed::(&mut out_packed, &mut sum1, alpha, shift); + + assert_eq!(sum0, sum1); + assert_eq!( + out0.0, + <>::ExtensionPacking as PackedFieldExtension< + F, + EF, + >>::to_ext_iter( + out_packed.as_slice().iter().copied(), + ) + .collect::>() + ); + + shift += statement.len(); + } + } + } + } +} diff --git a/src/whir/constraints/statement/mod.rs b/src/whir/constraints/statement/mod.rs index 1bcca17d..9acecaa4 100644 --- a/src/whir/constraints/statement/mod.rs +++ b/src/whir/constraints/statement/mod.rs @@ -1,9 +1,8 @@ /// Equality statement for polynomial evaluation constraints. pub mod eq; -/// Selection statement for conditional constraints. -pub mod select; +pub mod domain; // Re-export main types for convenient access. +pub use domain::DomainStatement; pub use eq::EqStatement; -pub use select::SelectStatement; diff --git a/src/whir/prover/mod.rs b/src/whir/prover/mod.rs index 826465ad..604543d5 100644 --- a/src/whir/prover/mod.rs +++ b/src/whir/prover/mod.rs @@ -12,6 +12,7 @@ use p3_matrix::{ }; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, PseudoCompressionFunction}; +use p3_util::log2_strict_usize; use round_state::RoundState; use serde::{Deserialize, Serialize}; use tracing::{info_span, instrument}; @@ -26,7 +27,7 @@ use crate::{ fiat_shamir::errors::FiatShamirError, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ - constraints::{Constraint, statement::SelectStatement}, + constraints::{Constraint, statement::DomainStatement}, proof::{QueryOpening, SumcheckData, WhirProof}, utils::get_challenge_stir_queries, }, @@ -292,12 +293,10 @@ where challenger, )?; - let stir_vars = stir_challenges_indexes - .iter() - .map(|&i| round_state.next_domain_gen.exp_u64(i as u64)) - .collect::>(); - - let mut stir_statement = SelectStatement::initialize(num_variables); + let mut stir_statement = DomainStatement::initialize( + num_variables, + log2_strict_usize(round_state.domain_size >> self.folding_factor.at_round(round_index)), + ); // Initialize vector of queries let mut queries = Vec::with_capacity(stir_challenges_indexes.len()); @@ -327,7 +326,7 @@ where && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; // Process each set of evaluations retrieved from the Merkle tree openings. - for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { + for (answer, idx) in answers.iter().zip(stir_challenges_indexes.into_iter()) { let evals = EvaluationsList::new(answer.clone()); // Fold the polynomial represented by the `answer` evaluations using the verifier's challenge. // The evaluation method depends on whether this is a "skip round" or a "standard round". @@ -363,7 +362,7 @@ where // Evaluate the resulting smaller polynomial at the remaining challenges `r_rest`. let eval = EvaluationsList::new(folded_row).evaluate_hypercube_ext::(&r_rest); - stir_statement.add_constraint(var, eval); + stir_statement.add_constraint(idx, eval); } else { // Case 2: Standard Sumcheck Round // @@ -371,7 +370,7 @@ where // Perform a standard multilinear evaluation at the full challenge point `r`. let eval = evals.evaluate_hypercube_base(&round_state.folding_randomness); - stir_statement.add_constraint(var, eval); + stir_statement.add_constraint(idx, eval); } } } @@ -388,12 +387,12 @@ where } // Process each set of evaluations retrieved from the Merkle tree openings. - for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { + for (answer, idx) in answers.iter().zip(stir_challenges_indexes.into_iter()) { // Wrap the evaluations to represent the polynomial. let evals = EvaluationsList::new(answer.clone()); // Perform a standard multilinear evaluation at the full challenge point `r`. let eval = evals.evaluate_hypercube_ext::(&round_state.folding_randomness); - stir_statement.add_constraint(var, eval); + stir_statement.add_constraint(idx, eval); } } } diff --git a/src/whir/verifier/mod.rs b/src/whir/verifier/mod.rs index 470d7089..c1010d43 100644 --- a/src/whir/verifier/mod.rs +++ b/src/whir/verifier/mod.rs @@ -9,6 +9,7 @@ use p3_interpolation::interpolate_subgroup; use p3_matrix::Dimensions; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction}; +use p3_util::log2_strict_usize; use serde::{Deserialize, Serialize}; use tracing::instrument; @@ -21,7 +22,7 @@ use crate::{ poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ EqStatement, - constraints::{Constraint, evaluator::ConstraintPolyEvaluator, statement::SelectStatement}, + constraints::{Constraint, evaluator::ConstraintPolyEvaluator, statement::DomainStatement}, parameters::{InitialPhaseConfig, WhirConfig}, proof::{QueryOpening, WhirProof}, verifier::sumcheck::{ @@ -81,11 +82,8 @@ where if self.initial_phase_config.has_initial_statement() { statement.concatenate(&prev_commitment.ood_statement); - let constraint = Constraint::new( - challenger.sample_algebra_element(), - statement, - SelectStatement::initialize(self.num_variables), - ); + let constraint = + Constraint::new_eq_only(challenger.sample_algebra_element(), statement); // Combine claimed evals with combination randomness constraint.combine_evals(&mut claimed_eval); constraints.push(constraint); @@ -263,7 +261,7 @@ where commitment: &ParsedCommitment>, folding_randomness: &MultilinearPoint, round_index: usize, - ) -> Result, VerifierError> + ) -> Result, VerifierError> where H: CryptographicHasher + Sync, C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, @@ -370,16 +368,11 @@ where } }) .collect(); - - let stir_constraints = stir_challenges_indexes - .iter() - .map(|&index| params.folded_domain_gen.exp_u64(index as u64)) - .collect(); - - Ok(SelectStatement::new( + Ok(DomainStatement::new( params.num_variables, - stir_constraints, - folds, + log2_strict_usize(params.domain_size >> params.folding_factor), + &stir_challenges_indexes, + &folds, )) } diff --git a/src/whir/verifier/sumcheck.rs b/src/whir/verifier/sumcheck.rs index c247dfd9..472651ed 100644 --- a/src/whir/verifier/sumcheck.rs +++ b/src/whir/verifier/sumcheck.rs @@ -749,7 +749,7 @@ mod tests { // Save a fresh copy for verify_initial_sumcheck_rounds let mut verifier_challenger_for_verify = verifier_challenger.clone(); - let (_, mut expected_initial_sum) = constraint.combine_new(); + let (_, mut expected_initial_sum) = constraint.combine_new::(); // Start with the claimed sum before folding let mut current_sum = expected_initial_sum;