diff --git a/src/poly/evals.rs b/src/poly/evals.rs index 3ba44940..75f82e09 100644 --- a/src/poly/evals.rs +++ b/src/poly/evals.rs @@ -919,6 +919,23 @@ where // REDUCTION: Merge all thread-local accumulators .par_fold_reduce(SvoAccumulators::new, |a, b| a + b, |a, b| a + b) } + + /// Compute g = coeff_a·f_a + coeff_b·f_b element-wise + pub fn linear_combination>( + f_a: &Self, + coeff_a: EF, + f_b: &Self, + coeff_b: EF, + ) -> EvaluationsList { + assert_eq!(f_a.num_evals(), f_b.num_evals()); + let evals = f_a + .as_slice() + .iter() + .zip(f_b.as_slice().iter()) + .map(|(&a, &b)| coeff_a * a + coeff_b * b) + .collect(); + EvaluationsList::new(evals) + } } impl EvaluationsList { @@ -1246,6 +1263,21 @@ mod tests { assert_eq!(evaluations_list.as_slice(), &evals); } + #[test] + fn test_linear_combination() { + let f_a = EvaluationsList::new(vec![F::ONE, F::TWO]); + let f_b = EvaluationsList::new(vec![F::TWO, F::ONE]); + let coeff_a = EF4::from_u64(3); + let coeff_b = EF4::from_u64(2); + + let result = EvaluationsList::::linear_combination(&f_a, coeff_a, &f_b, coeff_b); + + // result[0] = 3*1 + 2*2 = 7, result[1] = 3*2 + 2*1 = 8 + assert_eq!(result.as_slice().len(), 2); + assert_eq!(result.as_slice()[0], EF4::from_u64(7)); + assert_eq!(result.as_slice()[1], EF4::from_u64(8)); + } + #[test] #[should_panic] fn test_new_evaluations_list_invalid_length() { diff --git a/src/whir/batch_proof.rs b/src/whir/batch_proof.rs new file mode 100644 index 00000000..4ab2386e --- /dev/null +++ b/src/whir/batch_proof.rs @@ -0,0 +1,480 @@ +//! Batch opening proof for WHIR protocol using the selector variable approach. +//! +//! This module implements batch polynomial opening where two polynomials f_a and f_b +//! are opened simultaneously at their respective evaluation points z_a and z_b. +//! +//! # Selector Variable Approach +//! +//! Given two polynomials f_a and f_b (both with m variables), we construct a combined +//! polynomial f_c with m+1 variables: +//! +//! ```text +//! f_c(X, x_1, ..., x_m) = X·f_a(x_1, ..., x_m) + α(1-X)·f_b(x_1, ..., x_m) +//! ``` +//! +//! The selector variable X chooses between f_a (when X=1) and f_b (when X=0). +//! α is the folding randomness chosen by the Verifier. +//! +//! The combined weight polynomial is: +//! +//! ```text +//! w(X, b) = X·eq(b, z_a) + α(1-X)·eq(b, z_b) +//! ``` +//! +//! The first sumcheck round folds the selector variable, producing: +//! - Folded polynomial: g(b) = r_0·f_a(b) + (1-r_0)·f_b(b) +//! - Folded weights: w'(b) = r_0·eq(b, z_a) + α(1-r_0)·eq(b, z_b) +//! - Folded claim: σ' = r_0·v_a + α(1-r_0)·v_b +//! +//! The protocol then continues with standard WHIR on the folded polynomial. + +use alloc::{vec, vec::Vec}; +use hashbrown::Equivalent; +use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_matrix::dense::DenseMatrix; +use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction}; +use serde::{Deserialize, Serialize}; +use tracing::instrument; + +use super::proof::SumcheckData; +use crate::{ + fiat_shamir::errors::FiatShamirError, + poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, + sumcheck::{product_polynomial::ProductPolynomial, sumcheck_single::SumcheckSingle}, + whir::{ + WhirProof, + committer::{Witness, reader::ParsedCommitment}, + constraints::statement::EqStatement, + prover::{Prover, RoundState}, + verifier::{Verifier, errors::VerifierError}, + }, +}; + +/// Batch opening proof wrapper +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(bound( + serialize = "F: Serialize, EF: Serialize, [F; DIGEST_ELEMS]: Serialize", + deserialize = "F: Deserialize<'de>, EF: Deserialize<'de>, [F; DIGEST_ELEMS]: Deserialize<'de>" +))] +pub struct BatchWhirProof { + /// Commitment to first polynomial (f_a) + pub commitment_a: [F; DIGEST_ELEMS], + + /// Commitment to second polynomial (f_b) + pub commitment_b: [F; DIGEST_ELEMS], + + pub initial_ood_answers: [Vec; 2], + + /// Selector sumcheck data: stores [c0, c2] for h(X) = c0 + c1·X + c2·X² + /// c0 = h(0) = α·v_b + /// c2 = quadratic coefficient + /// Verifier derives c1 + pub selector_sumcheck: SumcheckData, + + /// Inner WHIR proof on the folded polynomial g = r_0·f_a + (1-r_0)·f_b + pub inner_proof: WhirProof, +} + +impl Prover<'_, EF, F, H, C, Challenger> +where + F: TwoAdicField + Ord, + EF: ExtensionField + TwoAdicField, + Challenger: FieldChallenger + GrindingChallenger, +{ + /// Performs batch opening of two polynomials using the selector variable approach. + /// + /// This function executes the batch WHIR protocol: + /// 1. Sample batching randomness α + /// 2. Run selector round (first round of sumcheck on selector variable) + /// 3. Fold to get combined polynomial g + /// 4. Continue with standard WHIR on g + /// + /// # Arguments + /// + /// * `dft` - DFT backend for polynomial operations + /// * `proof` - Mutable proof structure to fill in + /// * `challenger` - Fiat-Shamir transcript + /// * `statement_a` - Evaluation constraints for polynomial A (point z_a, value v_a) + /// * `witness_a` - Polynomial A with its Merkle commitment + /// * `statement_b` - Evaluation constraints for polynomial B (point z_b, value v_b) + /// * `witness_b` - Polynomial B with its Merkle commitment + /// + /// # Errors + /// + /// Returns an error if the protocol fails. + #[allow(clippy::too_many_arguments)] + pub fn batch_prove, const DIGEST_ELEMS: usize>( + &self, + dft: &Dft, + proof: &mut BatchWhirProof, + challenger: &mut Challenger, + statement_a: &EqStatement, + witness_a: &Witness, DIGEST_ELEMS>, + statement_b: &EqStatement, + witness_b: &Witness, DIGEST_ELEMS>, + ) -> Result<(), FiatShamirError> + where + H: CryptographicHasher + + CryptographicHasher + + Sync, + C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + + PseudoCompressionFunction<[F::Packing; DIGEST_ELEMS], 2> + + Sync, + [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + { + // Validate that both polynomials have the same number of variables + assert_eq!( + witness_a.polynomial.num_variables(), + witness_b.polynomial.num_variables(), + "Batch opening requires same-degree polynomials" + ); + + let num_variables = witness_a.polynomial.num_variables(); + + // Store commitments + proof.commitment_a = witness_a.prover_data.root().into(); + proof.commitment_b = witness_b.prover_data.root().into(); + + // Sample batching randomness α + let alpha: EF = challenger.sample_algebra_element(); + + // Extract claims from statements + // For now, we assume single-constraint statements + let (z_a, v_a) = extract_single_constraint(statement_a); + let (z_b, v_b) = extract_single_constraint(statement_b); + + // Run selector round + let (sumcheck_prover, r_0) = self.selector_round( + &mut proof.selector_sumcheck, + challenger, + &witness_a.polynomial, + &witness_b.polynomial, + &z_a, + &z_b, + v_a, + v_b, + alpha, + ); + + // Fold OOD constraints (assuming same OOD points for both, in same order) + let mut folded_ood = EqStatement::initialize(num_variables); + for ((point_a, &v_a), (point_b, &v_b)) in witness_a + .ood_statement + .iter() + .zip(witness_b.ood_statement.iter()) + { + debug_assert_eq!(point_a, point_b, "OOD points must match"); + let folded_value = r_0 * v_a + (EF::ONE - r_0) * v_b; + folded_ood.add_evaluated_constraint(point_a.clone(), folded_value); + } + + // Initialize selector rounds for batching + let mut round_state = RoundState { + // Starting domain H_0 with |H_0| = 2^m evaluation points + domain_size: self.starting_domain_size(), + // Compute next domain generator: ω_1 = ω_0^{2^k} for H_1 after folding + next_domain_gen: F::two_adic_generator( + self.starting_domain_size().ilog2() as usize - self.folding_factor.at_round(0), + ), + // Sumcheck prover configured for constraint verification + sumcheck_prover: sumcheck_prover.clone(), + // Current round's folding challenges (α_1, ..., α_k) + folding_randomness: MultilinearPoint::new(vec![r_0]), + // Merkle commitment from witness for base field polynomial + commitment_merkle_prover_data: vec![ + witness_a.prover_data.clone(), + witness_b.prover_data.clone(), + ], + merkle_prover_data: None, + // No extension field commitment yet (first round operates in base field) + // Constraint set augmented with OOD evaluations + statement: folded_ood, + }; + + // Run the WHIR protocol round-by-round + for round in 0..=self.n_rounds() { + self.round( + dft, + round, + &mut proof.inner_proof, + challenger, + &mut round_state, + )?; + } + + Ok(()) + } + + /// Executes the selector round of the batch opening protocol. + /// + /// This is the first round of sumcheck over the selector variable X, + /// which combines the two polynomials into one folded polynomial. + /// + /// # Arguments + /// + /// * `selector_data` - Sumcheck data structure to fill in + /// * `challenger` - Fiat-Shamir transcript + /// * `f_a` - Evaluations of the first polynomial + /// * `f_b` - Evaluations of the second polynomial + /// * `z_a` - Evaluation point for f_a + /// * `z_b` - Evaluation point for f_b + /// * `v_a` - Claimed evaluation f_a(z_a) + /// * `v_b` - Claimed evaluation f_b(z_b) + /// * `alpha` - Batching randomness + /// + /// # Returns + /// + /// A tuple containing: + /// * `SumcheckSingle` - The folded sumcheck state + /// * `EF` - The folding challenge r_0 + #[allow(clippy::too_many_arguments)] + fn selector_round( + &self, + selector_data: &mut SumcheckData, + challenger: &mut Challenger, + f_a: &EvaluationsList, + f_b: &EvaluationsList, + z_a: &MultilinearPoint, + z_b: &MultilinearPoint, + v_a: EF, + v_b: EF, + alpha: EF, + ) -> (SumcheckSingle, EF) { + // Create combined polynomial and weights for selector round: + // The combined polynomial f_c(X, b) = X·f_a(b) + (1-X)·f_b(b) over {0,1}^{m+1} + // has evaluations [f_b | f_a] (first half is f_b at X=0, second half is f_a at X=1) + // + // The combined weight w(X, b) = X·eq(b, z_a) + α(1-X)·eq(b, z_b) + // has evaluations [α·eq(·, z_b) | eq(·, z_a)] + + // Build combined polynomial: [f_b | f_a] + let combined_evals: Vec = f_b + .as_slice() + .iter() + .chain(f_a.as_slice().iter()) + .copied() + .collect(); + let combined_poly = EvaluationsList::new(combined_evals); + + // Build combined weights: [α·eq(·, z_b) | eq(·, z_a)] + let eq_z_b = EvaluationsList::new_from_point(z_b.as_slice(), alpha); + let eq_z_a = EvaluationsList::new_from_point(z_a.as_slice(), EF::ONE); + let combined_weights_vec: Vec = eq_z_b + .as_slice() + .iter() + .chain(eq_z_a.as_slice().iter()) + .copied() + .collect(); + let combined_weights = EvaluationsList::new(combined_weights_vec); + + // Compute sumcheck polynomial coefficients: + // h(X) = Σ_b f_c(X, b) · w(X, b) + // We compute c0 = h(0) and c2 (quadratic coefficient) + let (c0, c2) = combined_poly.sumcheck_coefficients(&combined_weights); + + // Sanity check: c0 should equal α·v_b + debug_assert_eq!(c0, alpha * v_b, "c0 = h(0) should equal α·v_b"); + + // Observe Fiat-Shamir + let pow_bits = self.starting_folding_pow_bits; + let r_0 = selector_data.observe_and_sample::(challenger, c0, c2, pow_bits); + + // Fold the polynomial and weights: + // g(b) = r_0·f_a(b) + (1-r_0)·f_b(b) + // w'(b) = r_0·eq(b, z_a) + α(1-r_0)·eq(b, z_b) + // σ' = h(r_0) + + // Folded polynomial in extension field + let g = EvaluationsList::linear_combination(f_a, r_0, f_b, EF::ONE - r_0); + + // Folded weights: r_0·eq(·, z_a) + (1-r_0)·α·eq(·, z_b) + let w_prime: Vec = eq_z_a + .iter() + .zip(eq_z_b.iter()) + .map(|(&a, &b)| r_0 * a + (EF::ONE - r_0) * b) + .collect(); + let w_prime = EvaluationsList::new(w_prime); + + // Folded sum: σ' = h(r_0) + let sigma = v_a + alpha * v_b; + let h_1 = sigma - c0; + let c1 = h_1 - c0 - c2; + let sigma_prime = c0 + c1 * r_0 + c2 * r_0.square(); + + // Create SumcheckSingle for continuation + let poly = ProductPolynomial::new(g, w_prime); + debug_assert_eq!(poly.dot_product(), sigma_prime); + + let sumcheck_prover = SumcheckSingle { + poly, + sum: sigma_prime, + }; + + (sumcheck_prover, r_0) + } +} + +/// Extracts a single constraint (point, evaluation) from an EqStatement. +/// +/// # Panics +/// +/// Panics if the statement is empty. +/// TODO: remove or generalize this function as batch opening needs to work for +/// in the general case with multiple constrants +fn extract_single_constraint(statement: &EqStatement) -> (MultilinearPoint, EF) { + assert!( + !statement.is_empty(), + "Statement must contain at least one constraint" + ); + let (point, &eval) = statement.iter().next().unwrap(); + (point.clone(), eval) +} + +impl<'a, EF, F, H, C, Challenger> Verifier<'a, EF, F, H, C, Challenger> +where + F: TwoAdicField, + EF: ExtensionField + TwoAdicField, + Challenger: FieldChallenger + GrindingChallenger, +{ + #[instrument(skip_all)] + #[allow(clippy::too_many_lines)] + pub fn batch_verify( + &self, + proof: &BatchWhirProof, + challenger: &mut Challenger, + parsed_commitments: Vec<&ParsedCommitment>>, + mut statement_a: EqStatement, + mut statement_b: EqStatement, + ) -> Result, VerifierError> + where + H: CryptographicHasher + Sync, + C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, + [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + { + todo!() + } +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; + use p3_challenger::DuplexChallenger; + use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; + use rand::{SeedableRng, rngs::SmallRng}; + + use super::*; + use crate::{ + parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, + whir::{parameters::InitialPhaseConfig, prover::Prover}, + }; + + type F = BabyBear; + type EF = BinomialExtensionField; + type Perm = Poseidon2BabyBear<16>; + type MyHash = PaddingFreeSponge; + type MyCompress = TruncatedPermutation; + type MyChallenger = DuplexChallenger; + type MyWhirConfig = + crate::whir::parameters::WhirConfig; + + /// Test selector_round produces correct folded polynomial and sumcheck state. + /// + /// Verifies: + /// 1. The returned SumcheckSingle has correct folded polynomial g = r_0·f_a + (1-r_0)·f_b + /// 2. h(r_0) has been computed correctly + #[test] + fn test_selector_round() { + let num_vars = 2; + + // Create test polynomials + let f_a = EvaluationsList::new(vec![ + F::from_u64(1), + F::from_u64(2), + F::from_u64(3), + F::from_u64(4), + ]); + let f_b = EvaluationsList::new(vec![ + F::from_u64(5), + F::from_u64(6), + F::from_u64(7), + F::from_u64(8), + ]); + + // Create evaluation points + let z_a = MultilinearPoint::new(vec![EF::from_u64(2), EF::from_u64(3)]); + let z_b = MultilinearPoint::new(vec![EF::from_u64(5), EF::from_u64(7)]); + + // Compute actual evaluations v_a = f_a(z_a), v_b = f_b(z_b) + let v_a = f_a.evaluate_hypercube_base::(&z_a); + let v_b = f_b.evaluate_hypercube_base::(&z_b); + + // Batching randomness + let alpha = EF::from_u64(11); + + // Set up minimal WhirConfig for creating a Prover + let mut rng = SmallRng::seed_from_u64(42); + let perm = Perm::new_from_rng_128(&mut rng); + let merkle_hash = MyHash::new(perm.clone()); + let merkle_compress = MyCompress::new(perm.clone()); + + let whir_params = ProtocolParameters { + initial_phase_config: InitialPhaseConfig::WithStatementClassic, + security_level: 32, + pow_bits: 0, + rs_domain_initial_reduction_factor: 1, + folding_factor: FoldingFactor::Constant(2), + merkle_hash, + merkle_compress, + soundness_type: SecurityAssumption::CapacityBound, + starting_log_inv_rate: 1, + }; + let config = MyWhirConfig::new(num_vars + 2, whir_params); + + let prover = Prover(&config); + + // Create challenger + let mut challenger = MyChallenger::new(perm); + + // Run selector_round + let mut selector_data = SumcheckData::default(); + let (sumcheck_prover, r_0) = prover.selector_round( + &mut selector_data, + &mut challenger, + &f_a, + &f_b, + &z_a, + &z_b, + v_a, + v_b, + alpha, + ); + + // Verify the folded polynomial g = r_0·f_a + (1-r_0)·f_b + let expected_g = EvaluationsList::linear_combination(&f_a, r_0, &f_b, EF::ONE - r_0); + assert_eq!( + sumcheck_prover.evals().as_slice(), + expected_g.as_slice(), + "Folded polynomial should be g = r_0·f_a + (1-r_0)·f_b" + ); + + // Verify selector_data was populated + assert_eq!( + selector_data.polynomial_evaluations.len(), + 1, + "Should have one sumcheck round recorded" + ); + + // Compute h(r_0) from the recorded coefficients + let [c0, c2] = selector_data.polynomial_evaluations[0]; + let sigma = v_a + alpha * v_b; // original claim + let c1 = sigma - c0 - c0 - c2; // c1 = σ - 2·c0 - c2 + let h_at_r0 = c0 + c1 * r_0 + c2 * r_0 * r_0; + + assert_eq!(sumcheck_prover.sum, h_at_r0, "sigma' should equal h(r_0)"); + } +} diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 98754b82..52a2d2d5 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -19,6 +19,7 @@ use crate::{ whir::proof::WhirProof, }; +pub mod batch_proof; pub mod committer; pub mod constraints; pub mod parameters; diff --git a/src/whir/proof.rs b/src/whir/proof.rs index 59f10a62..66163194 100644 --- a/src/whir/proof.rs +++ b/src/whir/proof.rs @@ -159,6 +159,12 @@ pub enum QueryOpening { /// Merkle authentication path proof: Vec<[F; DIGEST_ELEMS]>, }, + + #[serde(rename = "batch")] + Batch { + values: [Vec; 2], + proof: [Vec<[F; DIGEST_ELEMS]>; 2], + }, } /// Sumcheck polynomial data @@ -760,6 +766,7 @@ mod tests { assert_eq!(p.len(), 1); } QueryOpening::Extension { .. } => panic!("Expected Base variant"), + QueryOpening::Batch { .. } => panic!("Expected Base variant"), } // Test Extension variant @@ -788,6 +795,7 @@ mod tests { assert_eq!(p.len(), 1); } QueryOpening::Base { .. } => panic!("Expected Extension variant"), + QueryOpening::Batch { .. } => panic!("Expected Extension variant"), } } diff --git a/src/whir/prover/mod.rs b/src/whir/prover/mod.rs index 826465ad..938cfc62 100644 --- a/src/whir/prover/mod.rs +++ b/src/whir/prover/mod.rs @@ -12,7 +12,7 @@ use p3_matrix::{ }; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, PseudoCompressionFunction}; -use round_state::RoundState; +pub use round_state::RoundState; use serde::{Deserialize, Serialize}; use tracing::{info_span, instrument}; @@ -138,8 +138,7 @@ where /// /// /// # Errors - /// Returns an error if the witness or statement are invalid, or if a round fails. - #[instrument(skip_all)] + /// Returns an error if the witness or statement are invalid, or if a round f #[instrument(skip_all)] pub fn prove, const DIGEST_ELEMS: usize>( &self, dft: &Dft, @@ -179,7 +178,7 @@ where #[instrument(skip_all, fields(round_number = round_index, log_size = self.num_variables - self.folding_factor.total_number(round_index)))] #[allow(clippy::too_many_lines)] - fn round>( + pub fn round>( &self, dft: &Dft, round_index: usize, @@ -306,72 +305,121 @@ where match &round_state.merkle_prover_data { None => { let mut answers = Vec::with_capacity(stir_challenges_indexes.len()); - for challenge in &stir_challenges_indexes { - let commitment = - mmcs.open_batch(*challenge, &round_state.commitment_merkle_prover_data); - let answer = commitment.opened_values[0].clone(); - answers.push(answer.clone()); - - queries.push(QueryOpening::Base { - values: answer.clone(), - proof: commitment.opening_proof, - }); - } - - // Determine if this is the special first round where the univariate skip is applied. - let is_skip_round = round_index == 0 - && matches!( - self.initial_phase_config, - InitialPhaseConfig::WithStatementUnivariateSkip - ) - && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; + let is_batch = round_state.commitment_merkle_prover_data.len() > 1; + + if is_batch { + // Batch mode: open both trees at each query + for &challenge in &stir_challenges_indexes { + let commit_a = mmcs.open_batch( + challenge, + round_state.commitment_merkle_prover_data[0].as_ref(), + ); + let commit_b = mmcs.open_batch( + challenge, + round_state.commitment_merkle_prover_data[1].as_ref(), + ); + + queries.push(QueryOpening::Batch { + values: [ + commit_a.opened_values[0].clone(), + commit_b.opened_values[0].clone(), + ], + proof: [commit_a.opening_proof, commit_b.opening_proof], + }); + // Fold the opened values: g(b) = r_0·f_a(b) + (1-r_0)·f_b(b) + let r_0 = round_state.folding_randomness.as_slice()[0]; + let values_a = &commit_a.opened_values[0]; + let values_b = &commit_b.opened_values[0]; + + // Fold element-wise to get the combined answer + let folded_answer: Vec = values_a + .iter() + .zip(values_b.iter()) + .map(|(&a, &b)| r_0 * EF::from(a) + (EF::ONE - r_0) * EF::from(b)) + .collect(); + + answers.push(folded_answer); + + // Process folded answers into constraints + for (answer, var) in answers { + let evals = EvaluationsList::new(answer); + let eval = evals.evaluate_hypercube_ext::(&round_state.folding_randomness); + stir_statement.add_constraint(var, eval); + } + } + } else { + for challenge in &stir_challenges_indexes { + let commitment = mmcs.open_batch( + *challenge, + round_state.commitment_merkle_prover_data[0].as_ref(), + ); + let answer = commitment.opened_values[0].clone(); + answers.push(answer.clone()); + + queries.push(QueryOpening::Base { + values: answer.clone(), + proof: commitment.opening_proof, + }); + + + } - // Process each set of evaluations retrieved from the Merkle tree openings. - for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { - let evals = EvaluationsList::new(answer.clone()); - // Fold the polynomial represented by the `answer` evaluations using the verifier's challenge. - // The evaluation method depends on whether this is a "skip round" or a "standard round". - if is_skip_round { - // Case 1: Univariate Skip Round Evaluation - // - - // The challenges for the remaining (non-skipped) variables. - let num_remaining_vars = evals.num_variables() - K_SKIP_SUMCHECK; - - // The width of the matrix corresponds to the number of remaining variables. - let width = 1 << num_remaining_vars; - - // Reshape the `answer` evaluations into the `2^k x 2^(n-k)` matrix format. - let mat = evals.into_mat(width); - - // For a skip round, `folding_randomness` is the special `(n-k)+1` challenge object. - let r_all = round_state.folding_randomness.clone(); - - // Deconstruct the special challenge object `r_all`. - // - // The last element is the single challenge for the `k_skip` variables being folded. - let r_skip = *r_all - .last_variable() - .expect("skip challenge must be present"); - // The first `n - k_skip` elements are the challenges for the remaining variables. - let r_rest = r_all.get_subpoint_over_range(..num_remaining_vars); - - // Perform the two-stage skip-aware evaluation: - // - // "Fold" the skipped variables by interpolating the matrix at `r_skip`. - let folded_row = interpolate_subgroup(&mat, r_skip); - // Evaluate the resulting smaller polynomial at the remaining challenges `r_rest`. - let eval = - EvaluationsList::new(folded_row).evaluate_hypercube_ext::(&r_rest); - stir_statement.add_constraint(var, eval); - } else { - // Case 2: Standard Sumcheck Round - // - // The `answer` represents a standard multilinear polynomial. - - // Perform a standard multilinear evaluation at the full challenge point `r`. - let eval = evals.evaluate_hypercube_base(&round_state.folding_randomness); - stir_statement.add_constraint(var, eval); + // Determine if this is the special first round where the univariate skip is applied. + let is_skip_round = round_index == 0 + && matches!( + self.initial_phase_config, + InitialPhaseConfig::WithStatementUnivariateSkip + ) + && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; + + // Process each set of evaluations retrieved from the Merkle tree openings. + for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { + let evals = EvaluationsList::new(answer.clone()); + // Fold the polynomial represented by the `answer` evaluations using the verifier's challenge. + // The evaluation method depends on whether this is a "skip round" or a "standard round". + if is_skip_round { + // Case 1: Univariate Skip Round Evaluation + // + + // The challenges for the remaining (non-skipped) variables. + let num_remaining_vars = evals.num_variables() - K_SKIP_SUMCHECK; + + // The width of the matrix corresponds to the number of remaining variables. + let width = 1 << num_remaining_vars; + + // Reshape the `answer` evaluations into the `2^k x 2^(n-k)` matrix format. + let mat = evals.into_mat(width); + + // For a skip round, `folding_randomness` is the special `(n-k)+1` challenge object. + let r_all = round_state.folding_randomness.clone(); + + // Deconstruct the special challenge object `r_all`. + // + // The last element is the single challenge for the `k_skip` variables being folded. + let r_skip = *r_all + .last_variable() + .expect("skip challenge must be present"); + // The first `n - k_skip` elements are the challenges for the remaining variables. + let r_rest = r_all.get_subpoint_over_range(..num_remaining_vars); + + // Perform the two-stage skip-aware evaluation: + // + // "Fold" the skipped variables by interpolating the matrix at `r_skip`. + let folded_row = interpolate_subgroup(&mat, r_skip); + // Evaluate the resulting smaller polynomial at the remaining challenges `r_rest`. + let eval = EvaluationsList::new(folded_row) + .evaluate_hypercube_ext::(&r_rest); + stir_statement.add_constraint(var, eval); + } else { + // Case 2: Standard Sumcheck Round + // + // The `answer` represents a standard multilinear polynomial. + + // Perform a standard multilinear evaluation at the full challenge point `r`. + let eval = + evals.evaluate_hypercube_base(&round_state.folding_randomness); + stir_statement.add_constraint(var, eval); + } } } } @@ -488,8 +536,10 @@ where match &round_state.merkle_prover_data { None => { for challenge in final_challenge_indexes { - let commitment = - mmcs.open_batch(challenge, &round_state.commitment_merkle_prover_data); + let commitment = mmcs.open_batch( + challenge, + round_state.commitment_merkle_prover_data[0].as_ref(), + ); proof.final_queries.push(QueryOpening::Base { values: commitment.opened_values[0].clone(), diff --git a/src/whir/prover/round_state/state.rs b/src/whir/prover/round_state/state.rs index 6d86c82a..d5825b25 100644 --- a/src/whir/prover/round_state/state.rs +++ b/src/whir/prover/round_state/state.rs @@ -2,7 +2,7 @@ //! //! This module implements the core round state management for the WHIR protocol. -use alloc::{sync::Arc, vec::Vec}; +use alloc::{sync::Arc, vec, vec::Vec}; use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, TwoAdicField}; @@ -79,16 +79,23 @@ where /// cryptographic soundness. The length equals the folding factor k for this round. pub folding_randomness: MultilinearPoint, - /// Merkle tree commitment for the base field polynomial f: F^n → F. + /// Merkle tree commitments for base field polynomials f: F^n → F. /// - /// This commitment covers the initial polynomial evaluation table over the starting - /// domain H_0. The Merkle tree enables selective opening of polynomial values at + /// This vector contains commitments for one or more polynomials. For single-proof + /// scenarios, this contains exactly one tree. For batch opening, it contains + /// multiple trees (one per polynomial being batch-opened). + /// + /// The trees enable selective opening of polynomial values at /// verifier-chosen query points while maintaining cryptographic integrity. /// - /// In WHIR's proximity testing, this commitment proves the prover knows some - /// polynomial that is purportedly close to a Reed-Solomon codeword. The verifier + /// In WHIR's proximity testing, these commitments prove the prover knows + /// polynomials that are purportedly close to Reed-Solomon codewords. The verifier /// can later query specific positions to verify proximity claims. - pub commitment_merkle_prover_data: Arc>, + /// + /// Using `Vec>` allows sharing individual trees independently, + /// which is useful when passing single trees to functions without cloning + /// the entire collection. + pub commitment_merkle_prover_data: Vec>>, /// Merkle tree commitment for extension field polynomials f': (EF)^{n-k} → EF. /// @@ -277,8 +284,8 @@ where sumcheck_prover, // Current round's folding challenges (α_1, ..., α_k) folding_randomness, - // Merkle commitment from witness for base field polynomial - commitment_merkle_prover_data: witness.prover_data, + // Merkle commitment from witness for base field polynomial (single-element vector) + commitment_merkle_prover_data: vec![witness.prover_data], // No extension field commitment yet (first round operates in base field) merkle_prover_data: None, // Constraint set augmented with OOD evaluations diff --git a/src/whir/verifier/mod.rs b/src/whir/verifier/mod.rs index 470d7089..9ca09b58 100644 --- a/src/whir/verifier/mod.rs +++ b/src/whir/verifier/mod.rs @@ -476,6 +476,7 @@ where values.clone() } + &QueryOpening::Batch { .. } => todo!(), }; results.push(values_ef);