diff --git a/crates/prover/src/merkle_tree/binary_merkle_tree.rs b/crates/prover/src/merkle_tree/binary_merkle_tree.rs index 9ffc17f81..654738ae4 100644 --- a/crates/prover/src/merkle_tree/binary_merkle_tree.rs +++ b/crates/prover/src/merkle_tree/binary_merkle_tree.rs @@ -1,16 +1,17 @@ // Copyright 2024-2025 Irreducible Inc. -use std::{fmt::Debug, iter::repeat_with, mem::MaybeUninit}; +use std::{fmt::Debug, mem::MaybeUninit}; use binius_field::Field; use binius_utils::{ checked_arithmetics::log2_strict_usize, mem::slice_assume_init_mut, + rand::par_rand, rayon::{prelude::*, slice::ParallelSlice}, }; use binius_verifier::merkle_tree::Error; use digest::{FixedOutputReset, Output, crypto_common::BlockSizeUser}; -use rand::{CryptoRng, Rng}; +use rand::{CryptoRng, Rng, rngs::StdRng}; use crate::hash::{ParallelDigest, parallel_compression::ParallelPseudoCompression}; @@ -78,9 +79,8 @@ where let log_len = log2_strict_usize(iterated_chunks.len()); // precondition // Generate salts if needed - let salts = repeat_with(|| F::random(&mut rng)) - .take(salt_len << log_len) - .collect::>(); + let salts = + par_rand::(salt_len << log_len, &mut rng, F::random).collect::>(); let total_length = (1 << (log_len + 1)) - 1; let mut inner_nodes = Vec::with_capacity(total_length); diff --git a/crates/prover/src/protocols/basefold.rs b/crates/prover/src/protocols/basefold.rs index 8b082e761..63d1c0fc9 100644 --- a/crates/prover/src/protocols/basefold.rs +++ b/crates/prover/src/protocols/basefold.rs @@ -1,7 +1,10 @@ // Copyright 2025 Irreducible Inc. use binius_field::{BinaryField, PackedField}; -use binius_math::{FieldBuffer, ntt::AdditiveNTT}; +use binius_math::{ + FieldBuffer, inner_product::inner_product_par, line::extrapolate_line_packed, + multilinear::fold::fold_highest_var_inplace, ntt::AdditiveNTT, +}; use binius_transcript::{ ProverTranscript, fiat_shamir::{CanSample, Challenger}, @@ -66,7 +69,7 @@ where fri_folder: FRIFoldProver<'a, F, P, NTT, MerkleProver>, ) -> Self { assert_eq!(multilinear.log_len(), transparent_multilinear.log_len()); - assert_eq!(multilinear.log_len(), fri_folder.n_rounds()); + assert_eq!(multilinear.log_len(), fri_folder.n_rounds() - fri_folder.curr_round()); let sumcheck_prover = BivariateProductSumcheckProver::new([multilinear, transparent_multilinear], claim) @@ -154,6 +157,69 @@ where } } +/// Performs ZK batching setup and returns a BaseFoldProver for the remaining protocol. +/// +/// This handles the zero-knowledge case where the witness is blinded with a random mask. +/// It performs the initial unbatch round and returns a prover configured for the remaining +/// n rounds of sumcheck + FRI. +/// +/// ## Arguments +/// +/// * `multilinear` - batched (witness || mask) polynomial with log_len = n+1 +/// * `transparent_multilinear` - l_poly with log_len = n +/// * `claim` - the original sumcheck claim (before ZK batching) +/// * `fri_folder` - FRI fold prover with n_rounds = n+1 +/// * `transcript` - prover transcript +/// +/// ## Returns +/// +/// A `BaseFoldProver` configured for the remaining n rounds. Caller should call +/// `.prove(transcript)`. +pub fn prove_zk<'a, F, P, NTT, MerkleScheme, MerkleProver, Challenger_>( + mut multilinear: FieldBuffer

, + transparent_multilinear: FieldBuffer

, + sum_claim: F, + mut fri_folder: FRIFoldProver<'a, F, P, NTT, MerkleProver>, + transcript: &mut ProverTranscript, +) -> BaseFoldProver<'a, F, P, NTT, MerkleProver> +where + F: BinaryField, + P: PackedField, + NTT: AdditiveNTT + Sync, + MerkleScheme: MerkleTreeScheme, + MerkleProver: MerkleTreeProver, + Challenger_: Challenger, +{ + let _scope = tracing::debug_span!("Basefold ZK setup").entered(); + + assert_eq!(multilinear.log_len(), transparent_multilinear.log_len() + 1); + assert_eq!(multilinear.log_len(), fri_folder.n_rounds()); + + // Compute blinding_eval = sum_x[mask * l_poly] + // The verifier will compute sum = (1-r)*claim + r*blinding_eval using linear interpolation. + let (_witness, mask) = multilinear + .split_half_ref() + .expect("multilinear has log_len >= 1"); + let mask_claim = inner_product_par(&mask, &transparent_multilinear); + + // Write blinding_eval to transcript + transcript.message().write(&mask_claim); + + // Sample batch challenge (before FRI fold round, matching verifier order) + let batch_challenge: F = transcript.sample(); + + // Receive batch challenge to advance to round 1 (no commitment at batch round) + fri_folder.receive_challenge(batch_challenge); + + // Fold multilinear at its last variable. + fold_highest_var_inplace(&mut multilinear, batch_challenge) + .expect("multilinear has log_len >= 1"); + + // Compute the batched sum using linear interpolation. + let batched_sum = extrapolate_line_packed(sum_claim, mask_claim, batch_challenge); + BaseFoldProver::new(multilinear, transparent_multilinear, batched_sum, fri_folder) +} + #[cfg(test)] mod test { use anyhow::{Result, bail}; @@ -177,7 +243,7 @@ mod test { }; use rand::{SeedableRng, rngs::StdRng}; - use super::BaseFoldProver; + use super::{BaseFoldProver, prove_zk}; use crate::{ fri::{self, CommitOutput, FRIFoldProver}, hash::parallel_compression::ParallelCompressionAdaptor, @@ -267,13 +333,13 @@ mod test { { let mut rng = StdRng::from_seed([0; 32]); - let multilinear = random_field_buffer::

(&mut rng, n_vars); + let witness = random_field_buffer::

(&mut rng, n_vars); let evaluation_point = random_scalars::(&mut rng, n_vars); let eval_point_eq = eq_ind_partial_eval(&evaluation_point); - let evaluation_claim = inner_product_buffers(&multilinear, &eval_point_eq); + let evaluation_claim = inner_product_buffers(&witness, &eval_point_eq); - (multilinear, evaluation_point, evaluation_claim) + (witness, evaluation_point, evaluation_claim) } fn dubiously_modify_claim(claim: &mut F) @@ -284,6 +350,115 @@ mod test { *claim += P::Scalar::ONE } + fn run_basefold_zk_prove_and_verify( + witness_plus_mask: FieldBuffer

, + evaluation_point: Vec, + evaluation_claim: F, + ) -> Result<()> + where + F: BinaryField, + P: PackedField + PackedExtension, + { + let n_vars = evaluation_point.len(); + assert_eq!(witness_plus_mask.log_len(), n_vars + 1); + + let eval_point_eq = eq_ind_partial_eval::

(&evaluation_point); + + let merkle_prover = BinaryMerkleTreeProver::::new( + ParallelCompressionAdaptor::new(StdCompression::default()), + ); + + // Setup NTT with subspace dimension = witness.log_len + LOG_INV_RATE + let subspace = BinarySubspace::with_dim(n_vars + LOG_INV_RATE).unwrap(); + let domain_context = GenericOnTheFly::generate_from_subspace(&subspace); + let ntt = NeighborsLastSingleThread::new(domain_context); + + // Create FRI params with log_batch_size = 1 + let fri_params = FRIParams::with_strategy( + &ntt, + merkle_prover.scheme(), + witness_plus_mask.log_len(), + Some(1), + LOG_INV_RATE, + 32, + &ConstantArityStrategy::new(2), + )?; + + // Commit batched multilinear + let CommitOutput { + commitment: codeword_commitment, + committed: codeword_committed, + codeword, + } = fri::commit_interleaved(&fri_params, &ntt, &merkle_prover, witness_plus_mask.to_ref())?; + + let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); + prover_transcript.message().write(&codeword_commitment); + + let fri_folder = + FRIFoldProver::new(&fri_params, &ntt, &merkle_prover, codeword, &codeword_committed)?; + + // Run prove_zk then continue with basefold prover + let prover = prove_zk( + witness_plus_mask, + eval_point_eq, + evaluation_claim, + fri_folder, + &mut prover_transcript, + ); + prover.prove(&mut prover_transcript)?; + + // Verify + let mut verifier_transcript = prover_transcript.into_verifier(); + let retrieved_commitment = verifier_transcript.message().read()?; + + let basefold::ReducedOutput { + final_fri_value, + final_sumcheck_value, + challenges, + } = basefold::verify_zk( + &fri_params, + merkle_prover.scheme(), + retrieved_commitment, + evaluation_claim, + &mut verifier_transcript, + )?; + + // Check consistency - skip batch challenge (challenges[0]) + let sumcheck_challenges = challenges[1..].to_vec(); + if !basefold::sumcheck_fri_consistency( + final_fri_value, + final_sumcheck_value, + &evaluation_point, + sumcheck_challenges, + ) { + bail!("Sumcheck and FRI are inconsistent"); + } + + Ok(()) + } + + #[test] + fn test_basefold_zk_valid_proof() { + type P = PackedBinaryGhash1x128b; + + let n_vars = 8; + let mut rng = StdRng::seed_from_u64(0); + + let witness_plus_mask = random_field_buffer::

(&mut rng, n_vars + 1); + let evaluation_point = random_scalars(&mut rng, n_vars); + + let (witness, _mask) = witness_plus_mask.split_half_ref().unwrap(); + let eval_point_eq = eq_ind_partial_eval::

(&evaluation_point); + let evaluation_claim = inner_product_buffers(&witness, &eval_point_eq); + + run_basefold_zk_prove_and_verify::<_, P>( + witness_plus_mask, + evaluation_point, + evaluation_claim, + ) + .unwrap(); + } + #[test] fn test_basefold_valid_proof() { type P = PackedBinaryGhash1x128b; diff --git a/crates/spartan-prover/Cargo.toml b/crates/spartan-prover/Cargo.toml index b48631037..f4b4f726b 100644 --- a/crates/spartan-prover/Cargo.toml +++ b/crates/spartan-prover/Cargo.toml @@ -18,6 +18,7 @@ binius-utils = { path = "../utils" } binius-prover = { path = "../prover" } binius-verifier = { path = "../verifier" } digest.workspace = true +itertools.workspace = true rand.workspace = true thiserror.workspace = true tracing.workspace = true diff --git a/crates/spartan-prover/src/lib.rs b/crates/spartan-prover/src/lib.rs index f190d9209..47b72feb0 100644 --- a/crates/spartan-prover/src/lib.rs +++ b/crates/spartan-prover/src/lib.rs @@ -4,7 +4,10 @@ mod error; pub mod pcs; mod wiring; -use std::marker::PhantomData; +use std::{ + iter::{repeat_n, repeat_with}, + marker::PhantomData, +}; use binius_field::{BinaryField, Field, PackedExtension, PackedField}; use binius_math::{ @@ -23,10 +26,16 @@ use binius_transcript::{ ProverTranscript, fiat_shamir::{CanSample, Challenger}, }; -use binius_utils::{SerializeBytes, checked_arithmetics::checked_log_2, rayon::prelude::*}; +use binius_utils::{ + SerializeBytes, + checked_arithmetics::checked_log_2, + rand::par_rand, + rayon::{self, prelude::*}, +}; use digest::{Digest, FixedOutputReset, Output, core_api::BlockSizeUser}; pub use error::*; -use rand::CryptoRng; +use itertools::chain; +use rand::{CryptoRng, rngs::StdRng}; use crate::wiring::WiringTranspose; @@ -218,37 +227,33 @@ fn pack_and_blind_witness>( n_private: usize, mut rng: impl CryptoRng, ) -> FieldBuffer

{ - // Precondition: witness length must match expected size - let expected_size = 1 << log_witness_elems; - assert_eq!( - witness.len(), - expected_size, - "witness length {} does not match expected size {}", - witness.len(), - expected_size - ); - - let len = 1 << log_witness_elems.saturating_sub(P::LOG_WIDTH); - let mut packed_witness = Vec::

::with_capacity(len); - - packed_witness - .spare_capacity_mut() - .into_par_iter() - .enumerate() - .for_each(|(i, dst)| { - let offset = i << P::LOG_WIDTH; - let value = P::from_fn(|j| witness[offset + j]); - - dst.write(value); - }); - - // SAFETY: We just initialized all elements - unsafe { - packed_witness.set_len(len); + let packed_witness = if log_witness_elems < P::LOG_WIDTH { + let elems_iter = witness.iter().copied(); + let zeros_iter = repeat_n(F::ZERO, (1 << log_witness_elems) - witness.len()); + let mask_iter = repeat_with(|| F::random(&mut rng)).take(1 << log_witness_elems); + + let elems = P::from_scalars(chain!(elems_iter, zeros_iter, mask_iter)); + vec![elems] + } else { + let packed_len = 1 << (log_witness_elems - P::LOG_WIDTH); + + let elems_iter = witness + .par_chunks(P::WIDTH) + .map(|chunk| P::from_scalars(chunk.iter().copied())); + let zeros_iter = rayon::iter::repeat_n(P::zero(), packed_len - elems_iter.len()); + + // Append a random mask to the end of the witness buffer, of equal length to the witness. + let mask_iter = par_rand::(packed_len, &mut rng, P::random); + + elems_iter + .chain(zeros_iter) + .chain(mask_iter) + .collect::>() }; - let mut witness_packed = FieldBuffer::new(log_witness_elems, packed_witness.into_boxed_slice()) - .expect("FieldBuffer::new should succeed with correct log_witness_elems"); + let mut witness_packed = + FieldBuffer::new(log_witness_elems + 1, packed_witness.into_boxed_slice()) + .expect("FieldBuffer::new should succeed with correct log_witness_elems"); // Add blinding values let base = n_public + n_private; diff --git a/crates/spartan-prover/src/wiring.rs b/crates/spartan-prover/src/wiring.rs index d614c8b22..25d5e8f8a 100644 --- a/crates/spartan-prover/src/wiring.rs +++ b/crates/spartan-prover/src/wiring.rs @@ -7,9 +7,7 @@ use binius_math::{ FieldBuffer, FieldSlice, multilinear, multilinear::eq::eq_ind_partial_eval, ntt::AdditiveNTT, univariate::evaluate_univariate, }; -use binius_prover::{ - fri::FRIFoldProver, merkle_tree::MerkleTreeProver, protocols::basefold::BaseFoldProver, -}; +use binius_prover::{fri::FRIFoldProver, merkle_tree::MerkleTreeProver, protocols::basefold}; use binius_spartan_frontend::constraint_system::{MulConstraint, Operand, WitnessIndex}; use binius_transcript::{ ProverTranscript, @@ -212,7 +210,7 @@ where // Run sumcheck on bivariate product let batched_sum = evaluate_univariate(mulcheck_evals, lambda) + batch_coeff * public_eval; - BaseFoldProver::new(witness, l_poly, batched_sum, fri_prover).prove(transcript)?; + basefold::prove_zk(witness, l_poly, batched_sum, fri_prover, transcript).prove(transcript)?; Ok(()) } @@ -338,7 +336,7 @@ mod tests { }; use binius_transcript::ProverTranscript; use binius_verifier::{ - fri::{ConstantArityStrategy, FRIParams, calculate_n_test_queries}, + fri::{ConstantArityStrategy, FRIParams}, hash::{StdCompression, StdDigest}, }; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -347,7 +345,6 @@ mod tests { use super::*; const LOG_INV_RATE: usize = 1; - const SECURITY_BITS: usize = 32; /// Generate random MulConstraints for testing. /// Each operand has 0-4 random wires. @@ -476,7 +473,8 @@ mod tests { let constraints = generate_random_constraints(&mut rng, n_constraints, witness_size); // Create random witness using random_field_buffer - let witness_packed = random_field_buffer::(&mut rng, log_witness_size); + let witness_mask_packed = random_field_buffer::(&mut rng, log_witness_size + 1); + let (witness_packed, _mask_packed) = witness_mask_packed.split_half_ref().unwrap(); // Compute mulcheck witness let mulcheck_witness = build_mulcheck_witness(&constraints, witness_packed.to_ref()); @@ -514,14 +512,13 @@ mod tests { let domain_context = GenericOnTheFly::generate_from_subspace(&subspace); let ntt = NeighborsLastSingleThread::new(domain_context); - let n_test_queries = calculate_n_test_queries(SECURITY_BITS, LOG_INV_RATE); let fri_params = FRIParams::with_strategy( &ntt, merkle_prover.scheme(), - log_witness_size, - None, + witness_mask_packed.log_len(), + Some(1), LOG_INV_RATE, - n_test_queries, + 32, &ConstantArityStrategy::new(2), ) .expect("FRI params creation should succeed"); @@ -531,7 +528,7 @@ mod tests { commitment: codeword_commitment, committed: codeword_committed, codeword, - } = fri::commit_interleaved(&fri_params, &ntt, &merkle_prover, witness_packed.to_ref()) + } = fri::commit_interleaved(&fri_params, &ntt, &merkle_prover, witness_mask_packed.to_ref()) .expect("commit should succeed"); // Create FRI fold prover @@ -548,7 +545,7 @@ mod tests { fri_prover, &r_public, &r_x, - witness_packed, + witness_mask_packed, &mulcheck_evals, &mut prover_transcript, ) diff --git a/crates/spartan-verifier/src/lib.rs b/crates/spartan-verifier/src/lib.rs index 2c2d0d2c2..b1e68be6b 100644 --- a/crates/spartan-verifier/src/lib.rs +++ b/crates/spartan-verifier/src/lib.rs @@ -19,10 +19,13 @@ use binius_transcript::{ }; use binius_utils::{DeserializeBytes, checked_arithmetics::checked_log_2}; use binius_verifier::{ - fri::{self, ConstantArityStrategy, FRIParams, calculate_n_test_queries}, + fri::{self, FRIParams, MinProofSizeStrategy, calculate_n_test_queries}, hash::PseudoCompressionFunction, merkle_tree::BinaryMerkleTreeScheme, - protocols::{mlecheck, sumcheck, sumcheck::SumcheckOutput}, + protocols::{ + mlecheck, + sumcheck::{self, SumcheckOutput}, + }, }; use digest::{Digest, Output, core_api::BlockSizeUser}; @@ -63,11 +66,13 @@ where }; let constraint_system = ConstraintSystemPadded::new(constraint_system, blinding_info); - let log_msg_len = constraint_system.log_size() as usize; - let log_code_len = log_msg_len + log_inv_rate; + // The message contains the witness and a random mask of equal size to the witness. + // For ZK mode, the batch size is 1 (witness and mask are the two interleaved elements). + let log_witness_size = constraint_system.log_size() as usize; + let log_batch_size = 1; + let log_dim = log_witness_size; // RS code dimension equals witness size + let log_code_len = log_dim + log_inv_rate; let merkle_scheme = BinaryMerkleTreeScheme::new(compression); - let fri_arity = - ConstantArityStrategy::with_optimal_arity::(&merkle_scheme, log_code_len).arity; let subspace = BinarySubspace::with_dim(log_code_len)?; let domain_context = GenericOnTheFly::generate_from_subspace(&subspace); @@ -76,11 +81,11 @@ where let fri_params = FRIParams::with_strategy( &ntt, &merkle_scheme, - log_msg_len, - None, + log_dim + log_batch_size, + Some(log_batch_size), log_inv_rate, n_test_queries, - &ConstantArityStrategy::new(fri_arity), + &MinProofSizeStrategy, )?; Ok(Self { @@ -190,6 +195,8 @@ pub enum Error { Sumcheck(#[from] sumcheck::Error), #[error("Math error: {0}")] Math(#[from] binius_math::Error), + #[error("Reed-Solomon error: {0}")] + ReedSolomon(#[source] binius_math::reed_solomon::Error), #[error("wiring error: {0}")] Wiring(#[from] wiring::Error), #[error("Transcript error: {0}")] diff --git a/crates/spartan-verifier/src/wiring.rs b/crates/spartan-verifier/src/wiring.rs index 56633eba3..7ab54f421 100644 --- a/crates/spartan-verifier/src/wiring.rs +++ b/crates/spartan-verifier/src/wiring.rs @@ -57,10 +57,20 @@ where final_fri_value: witness_eval, final_sumcheck_value: eval, challenges: mut r_y, - } = basefold::verify(fri_params, merkle_scheme, codeword_commitment, batched_claim, transcript)?; + } = basefold::verify_zk( + fri_params, + merkle_scheme, + codeword_commitment, + batched_claim, + transcript, + )?; r_y.reverse(); + // The challenges include the batch_challenge as the last element after reversal. + // For the wiring check, we only need the sumcheck challenges (first n elements). + r_y.pop(); + Ok(Output { lambda, batch_coeff, diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 0e3e05cbb..1d010ae8b 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -15,13 +15,13 @@ bytes.workspace = true cfg-if.workspace = true generic-array.workspace = true itertools.workspace = true +rand.workspace = true rayon = { optional = true, workspace = true} regex = { workspace = true, optional = true } thiserror.workspace = true trait-set.workspace = true [dev-dependencies] -rand.workspace = true [features] default = [] diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 5108534b3..68bb53204 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -10,6 +10,7 @@ pub mod iter; pub mod mem; #[cfg(feature = "platform-diagnostics")] pub mod platform_diagnostics; +pub mod rand; pub mod random_access_sequence; pub mod rayon; pub mod serialization; diff --git a/crates/utils/src/rand.rs b/crates/utils/src/rand.rs new file mode 100644 index 000000000..43cb686d9 --- /dev/null +++ b/crates/utils/src/rand.rs @@ -0,0 +1,39 @@ +// Copyright 2026 The Binius Developers + +//! Parallel random number generation utilities. + +use rand::{Rng, SeedableRng}; + +use crate::rayon::prelude::*; + +/// Generates random values in parallel using a deterministic per-index seeding scheme. +/// +/// Creates a base seed from the provided RNG, then for each index `i` in `0..n`, +/// XORs the index bytes into the seed to create a unique but deterministic seed +/// for that index's RNG. +pub fn par_rand( + n: usize, + mut rng: impl Rng, + f: F, +) -> impl IndexedParallelIterator +where + InnerR: SeedableRng, + InnerR::Seed: Send + Sync, + T: Send, + F: Fn(InnerR) -> T + Sync + Send, +{ + let mut base_seed = InnerR::Seed::default(); + rng.fill_bytes(base_seed.as_mut()); + + (0..n).into_par_iter().map(move |i| { + let mut seed = base_seed.clone(); + let seed_bytes = seed.as_mut(); + + let index_bytes = i.to_le_bytes(); + for (seed_byte, &index_byte) in seed_bytes.iter_mut().zip(index_bytes.iter()) { + *seed_byte ^= index_byte; + } + + f(InnerR::from_seed(seed)) + }) +} diff --git a/crates/verifier/src/protocols/basefold.rs b/crates/verifier/src/protocols/basefold.rs index c7bb3b445..17f86acad 100644 --- a/crates/verifier/src/protocols/basefold.rs +++ b/crates/verifier/src/protocols/basefold.rs @@ -20,7 +20,7 @@ //! [BCS16]: use binius_field::{BinaryField, Field}; -use binius_math::multilinear::eq::eq_ind; +use binius_math::{line::extrapolate_line_packed, multilinear::eq::eq_ind}; use binius_transcript::{ VerifierTranscript, fiat_shamir::{CanSample, Challenger}, @@ -102,6 +102,71 @@ where }) } +pub fn verify_zk( + fri_params: &FRIParams, + merkle_scheme: &MTScheme, + codeword_commitment: MTScheme::Digest, + sum_claim: F, + transcript: &mut VerifierTranscript, +) -> Result, Error> +where + F: BinaryField, + Challenger_: Challenger, + MTScheme: MerkleTreeScheme, +{ + // The multivariate polynomial evaluated is a degree-2 multilinear composite. + const DEGREE: usize = 2; + + assert_eq!(fri_params.log_batch_size(), 1); // precondition + + // Read the evaluation claim for the mask from the transcript. + let mask_claim = transcript.message().read::()?; + + let n_vars = fri_params.rs_code().log_dim(); + let mut challenges = Vec::with_capacity(n_vars + 1); + + let mut fri_fold_verifier = FRIFoldVerifier::new(fri_params); + + let batch_challenge = transcript.sample(); + + // Compute the batched sum using linear interpolation. + let mut sum = extrapolate_line_packed(sum_claim, mask_claim, batch_challenge); + + fri_fold_verifier.process_round(&mut transcript.message())?; + challenges.push(batch_challenge); + + for _ in 0..n_vars { + let round_proof = RoundProof(RoundCoeffs(transcript.message().read_vec(DEGREE)?)); + fri_fold_verifier.process_round(&mut transcript.message())?; + + let round_coeffs = round_proof.recover(sum); + let challenge = transcript.sample(); + sum = round_coeffs.evaluate(challenge); + challenges.push(challenge); + } + + // Finalize and get commitments + fri_fold_verifier.process_round(&mut transcript.message())?; + let round_commitments = fri_fold_verifier.finalize()?; + + // TODO: Make all commitments after the first non-hiding + let fri_verifier = FRIQueryVerifier::new( + fri_params, + merkle_scheme, + &codeword_commitment, + &round_commitments, + &challenges, + )?; + + let final_fri_value = fri_verifier.verify(transcript)?; + + Ok(ReducedOutput { + final_fri_value, + final_sumcheck_value: sum, + challenges, + }) +} + /// Output type of the [`verify`] function. pub struct ReducedOutput { pub final_fri_value: F,