diff --git a/crates/prover/benches/pcs.rs b/crates/prover/benches/pcs.rs index 0dd7f3a74..c9eca0782 100644 --- a/crates/prover/benches/pcs.rs +++ b/crates/prover/benches/pcs.rs @@ -87,23 +87,33 @@ fn bench_pcs(c: &mut Criterion) { transcript.message().write_scalar(eval); group.bench_function(format!("prove/log_len={log_len}"), |b| { - b.iter(|| { - let mut transcript = transcript.clone(); - pcs_prover - .prove( - codeword.as_ref(), - &codeword_committed, + b.iter_batched( + || { + ( + transcript.clone(), + codeword.clone(), packed_multilin.clone(), eval_point.clone(), - &mut transcript, ) - .unwrap() - }); + }, + |(mut transcript, codeword, packed_multilin, eval_point)| { + pcs_prover + .prove( + codeword, + &codeword_committed, + packed_multilin, + eval_point, + &mut transcript, + ) + .unwrap() + }, + criterion::BatchSize::SmallInput, + ); }); pcs_prover .prove( - codeword.as_ref(), + codeword.clone(), &codeword_committed, packed_multilin.clone(), eval_point.clone(), diff --git a/crates/prover/src/fri/fold.rs b/crates/prover/src/fri/fold.rs index cb467443e..7f57e78be 100644 --- a/crates/prover/src/fri/fold.rs +++ b/crates/prover/src/fri/fold.rs @@ -1,17 +1,17 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{BinaryField, Field, PackedField, packed::len_packed_slice}; -use binius_math::{multilinear::eq::eq_ind_partial_eval, ntt::AdditiveNTT}; +use binius_field::{BinaryField, Field, PackedField}; +use binius_math::{ + FieldBuffer, FieldSlice, inner_product::inner_product_buffers, + multilinear::eq::eq_ind_partial_eval, ntt::AdditiveNTT, +}; use binius_transcript::{ ProverTranscript, fiat_shamir::{CanSampleBits, Challenger}, }; use binius_utils::{SerializeBytes, checked_arithmetics::log2_strict_usize, rayon::prelude::*}; use binius_verifier::{ - fri::{ - FRIParams, - fold::{fold_chunk, fold_interleaved_chunk}, - }, + fri::{FRIParams, fold::fold_chunk}, merkle_tree::MerkleTreeScheme, }; use tracing::instrument; @@ -20,7 +20,7 @@ use super::{error::Error, query::FRIQueryProver}; use crate::merkle_tree::MerkleTreeProver; /// The type of the termination round codeword in the FRI protocol. -pub type TerminateCodeword = Vec; +pub type TerminateCodeword = FieldBuffer; pub enum FoldRoundOutput { NoCommitment, @@ -37,9 +37,9 @@ where params: &'a FRIParams, ntt: &'a NTT, merkle_prover: &'a MerkleProver, - codeword: &'a [P], + codeword: FieldBuffer

, codeword_committed: &'a MerkleProver::Committed, - round_committed: Vec<(Vec, MerkleProver::Committed)>, + round_committed: Vec<(FieldBuffer, MerkleProver::Committed)>, curr_round: usize, next_commit_round: Option, unprocessed_challenges: Vec, @@ -58,10 +58,10 @@ where params: &'a FRIParams, ntt: &'a NTT, merkle_prover: &'a MerkleProver, - committed_codeword: &'a [P], + committed_codeword: FieldBuffer

, committed: &'a MerkleProver::Committed, ) -> Result { - if len_packed_slice(committed_codeword) < 1 << params.log_len() { + if committed_codeword.len() < 1 << params.log_len() { return Err(Error::InvalidArgs( "Reed-Solomon code length must match interleaved codeword length".to_string(), )); @@ -95,7 +95,7 @@ where pub fn current_codeword_len(&self) -> usize { match self.round_committed.last() { Some((codeword, _)) => codeword.len(), - None => len_packed_slice(self.codeword), + None => self.codeword.len(), } } @@ -132,18 +132,13 @@ where Some((prev_codeword, _)) => { // Fold a full codeword committed in the previous FRI round into a codeword with // reduced dimension and rate. - fold_codeword( - self.ntt, - prev_codeword, - &self.unprocessed_challenges, - log2_strict_usize(prev_codeword.len()), - ) + fold_codeword(self.ntt, prev_codeword.to_ref(), &self.unprocessed_challenges) } None => { // Fold the interleaved codeword that was originally committed into a single // codeword with the same block length. fold_interleaved( - self.codeword, + self.codeword.to_ref(), &self.unprocessed_challenges, self.params.rs_code().log_len(), self.params.log_batch_size(), @@ -162,7 +157,9 @@ where let log_coset_size = next_arity.unwrap_or_else(|| self.params.n_final_challenges()); let coset_size = 1 << log_coset_size; let merkle_tree_span = tracing::debug_span!("Merkle Tree").entered(); - let (commitment, committed) = self.merkle_prover.commit(&folded_codeword, coset_size)?; + let (commitment, committed) = self + .merkle_prover + .commit(folded_codeword.as_ref(), coset_size)?; drop(merkle_tree_span); self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| { @@ -195,7 +192,10 @@ where .round_committed .last() .map(|(codeword, _)| codeword.clone()) - .unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect()); + .unwrap_or_else(|| { + let scalars: Vec = self.codeword.iter_scalars().collect(); + FieldBuffer::from_values(&scalars).expect("codeword has power-of-two length") + }); self.unprocessed_challenges.clear(); @@ -227,7 +227,7 @@ where { let (terminate_codeword, query_prover) = self.finalize()?; let mut advice = transcript.decommitment(); - advice.write_scalar_slice(&terminate_codeword); + advice.write_scalar_slice(terminate_codeword.as_ref()); let layers = query_prover.vcs_optimal_layers()?; for layer in layers { @@ -258,27 +258,29 @@ where /// [DP24]: #[instrument(skip_all, level = "debug")] fn fold_interleaved( - codeword: &[P], + codeword: FieldSlice

, challenges: &[F], log_len: usize, log_batch_size: usize, -) -> Vec +) -> FieldBuffer where F: Field, P: PackedField, { - assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH)); + assert_eq!(codeword.log_len(), log_len + log_batch_size); assert_eq!(challenges.len(), log_batch_size); let tensor = eq_ind_partial_eval(challenges); // For each chunk of size `2^chunk_size` in the interleaved codeword, fold it with the folding // challenges. - let chunk_size = 1 << (challenges.len() - P::LOG_WIDTH); - codeword - .par_chunks(chunk_size) - .map(|chunk| fold_interleaved_chunk(log_batch_size, chunk, tensor.as_ref())) - .collect() + let values = codeword + .chunks_par(log_batch_size) + .expect("log_batch_size <= codeword.log_len()") + .map(|chunk| inner_product_buffers(&chunk, &tensor)) + .collect::>(); + FieldBuffer::new(log_len, values.into_boxed_slice()) + .expect("codeword.log_len() == log_len + log_batch_size") } /// FRI-fold the codeword using the given challenges. @@ -294,27 +296,32 @@ where /// /// [DP24]: #[instrument(skip_all, level = "debug")] -fn fold_codeword(ntt: &NTT, codeword: &[F], challenges: &[F], log_len: usize) -> Vec +fn fold_codeword(ntt: &NTT, codeword: FieldSlice, challenges: &[F]) -> FieldBuffer where F: BinaryField, NTT: AdditiveNTT + Sync, { - assert_eq!(codeword.len(), 1 << log_len); + let log_len = codeword.log_len(); assert!(challenges.len() <= log_len); + let folded_log_len = log_len - challenges.len(); + // For each coset of size `2^chunk_size` in the codeword, fold it with the folding challenges. let chunk_size = 1 << challenges.len(); - codeword - .par_chunks(chunk_size) + let values: Vec = codeword + .chunks_par(challenges.len()) + .expect("precondition: challenges.len() <= log_len") .enumerate() .map_init( || vec![F::default(); chunk_size], |scratch_buffer, (i, chunk)| { - scratch_buffer.copy_from_slice(chunk); + scratch_buffer.copy_from_slice(chunk.as_ref()); fold_chunk(ntt, log_len, i, scratch_buffer, challenges) }, ) - .collect() + .collect(); + FieldBuffer::new(folded_log_len, values.into_boxed_slice()) + .expect("codeword.len() == 1 << log_len") } #[cfg(test)] @@ -358,12 +365,12 @@ mod tests { ntt.forward_transform(codeword.to_mut(), 0, 0); // Fold the encoded message using FRI folding. - let folded_codeword = fold_codeword(&ntt, codeword.as_ref(), &challenges, log_dim + arity); + let folded_codeword = fold_codeword(&ntt, codeword.to_ref(), &challenges); // Encode the folded message. ntt.forward_transform(folded_msg.to_mut(), 0, 0); // Check that folding and encoding commute. - assert_eq!(folded_codeword, folded_msg.as_ref()); + assert_eq!(folded_codeword, folded_msg); } } diff --git a/crates/prover/src/fri/query.rs b/crates/prover/src/fri/query.rs index d75dfa7c7..e5691d646 100644 --- a/crates/prover/src/fri/query.rs +++ b/crates/prover/src/fri/query.rs @@ -2,7 +2,8 @@ use std::iter; -use binius_field::{BinaryField, PackedField, packed::iter_packed_slice_with_offset}; +use binius_field::{BinaryField, PackedField}; +use binius_math::{FieldBuffer, FieldSlice}; use binius_transcript::TranscriptWriter; use binius_verifier::{ fri::{FRIParams, vcs_optimal_layers_depths_iter}, @@ -24,9 +25,9 @@ where VCS: MerkleTreeScheme, { pub(super) params: &'a FRIParams, - pub(super) codeword: &'a [P], + pub(super) codeword: FieldBuffer

, pub(super) codeword_committed: &'a MerkleProver::Committed, - pub(super) round_committed: Vec<(Vec, MerkleProver::Committed)>, + pub(super) round_committed: Vec<(FieldBuffer, MerkleProver::Committed)>, pub(super) merkle_prover: &'a MerkleProver, } @@ -64,7 +65,7 @@ where prove_coset_opening( self.merkle_prover, - self.codeword, + self.codeword.to_ref(), self.codeword_committed, index, self.params.log_batch_size(), @@ -78,7 +79,7 @@ where index >>= arity; prove_coset_opening( self.merkle_prover, - codeword, + codeword.to_ref(), committed, index, arity, @@ -111,7 +112,7 @@ where fn prove_coset_opening( merkle_prover: &MTProver, - codeword: &[P], + codeword: FieldSlice

, committed: &MTProver::Committed, coset_index: usize, log_coset_size: usize, @@ -124,9 +125,12 @@ where MTProver: MerkleTreeProver, B: BufMut, { - let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size) - .take(1 << log_coset_size); - advice.write_scalar_iter(values); + assert!(coset_index < (1 << (codeword.log_len() - log_coset_size))); // precondition + + let values = codeword + .chunk(log_coset_size, coset_index) + .expect("precondition: coset_index < 2^(codeword.log_len() - log_coset_size)"); + advice.write_scalar_iter(values.iter_scalars()); merkle_prover.prove_opening(committed, optimal_layer_depth, coset_index, advice)?; diff --git a/crates/prover/src/fri/tests.rs b/crates/prover/src/fri/tests.rs index 484b8e3fb..427c88fcc 100644 --- a/crates/prover/src/fri/tests.rs +++ b/crates/prover/src/fri/tests.rs @@ -63,8 +63,7 @@ fn test_commit_prove_verify_success( // Run the prover to generate the proximity proof let mut round_prover = - FRIFoldProver::new(¶ms, &ntt, &merkle_prover, codeword.as_ref(), &codeword_committed) - .unwrap(); + FRIFoldProver::new(¶ms, &ntt, &merkle_prover, codeword, &codeword_committed).unwrap(); let mut prover_challenger = ProverTranscript::new(StdChallenger::default()); prover_challenger.message().write(&codeword_commitment); diff --git a/crates/prover/src/pcs.rs b/crates/prover/src/pcs.rs index 2620c3628..04abc3164 100644 --- a/crates/prover/src/pcs.rs +++ b/crates/prover/src/pcs.rs @@ -106,7 +106,7 @@ where /// * `transcript` - the transcript of the prover's proof pub fn prove( &self, - committed_codeword: &'a [P], + committed_codeword: FieldBuffer

, committed: &'a MerkleProver::Committed, packed_multilin: FieldBuffer

, evaluation_point: Vec, @@ -261,7 +261,7 @@ mod test { let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); prover_transcript.message().write(&codeword_commitment); ring_switch_pcs_prover.prove( - codeword.as_ref(), + codeword, &codeword_committed, packed_mle, evaluation_point.clone(), diff --git a/crates/prover/src/protocols/basefold.rs b/crates/prover/src/protocols/basefold.rs index 9bd1b5355..8b082e761 100644 --- a/crates/prover/src/protocols/basefold.rs +++ b/crates/prover/src/protocols/basefold.rs @@ -226,13 +226,8 @@ mod test { let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); prover_transcript.message().write(&codeword_commitment); - let fri_folder = FRIFoldProver::new( - &fri_params, - &ntt, - &merkle_prover, - codeword.as_ref(), - &codeword_committed, - )?; + let fri_folder = + FRIFoldProver::new(&fri_params, &ntt, &merkle_prover, codeword, &codeword_committed)?; let prover = BaseFoldProver::new(multilinear, eval_point_eq, evaluation_claim, fri_folder); prover.prove(&mut prover_transcript)?; diff --git a/crates/prover/src/prove.rs b/crates/prover/src/prove.rs index 0ff0126e9..0ade286c7 100644 --- a/crates/prover/src/prove.rs +++ b/crates/prover/src/prove.rs @@ -274,7 +274,7 @@ where ) .entered(); pcs_prover.prove( - trace_codeword.as_ref(), + trace_codeword, &trace_committed, witness_packed, eval_point, diff --git a/crates/spartan-prover/src/lib.rs b/crates/spartan-prover/src/lib.rs index 71ee4392e..4f411073b 100644 --- a/crates/spartan-prover/src/lib.rs +++ b/crates/spartan-prover/src/lib.rs @@ -143,7 +143,7 @@ where self.verifier.fri_params(), &self.ntt, &self.merkle_prover, - codeword.as_ref(), + codeword, &codeword_committed, )?; wiring::prove( diff --git a/crates/spartan-prover/src/pcs.rs b/crates/spartan-prover/src/pcs.rs index 1dcac6d6f..d7d6e19d2 100644 --- a/crates/spartan-prover/src/pcs.rs +++ b/crates/spartan-prover/src/pcs.rs @@ -89,7 +89,7 @@ where /// * `transcript` - the prover's transcript pub fn prove( &self, - committed_codeword: &'a [P], + committed_codeword: FieldBuffer

, committed: &'a MerkleProver::Committed, multilinear: FieldBuffer

, evaluation_point: &[F], @@ -198,7 +198,7 @@ mod tests { let mut prover_transcript = ProverTranscript::new(StdChallenger::default()); prover_transcript.message().write(&codeword_commitment); pcs_prover.prove( - codeword.as_ref(), + codeword, &codeword_committed, multilinear, &evaluation_point, diff --git a/crates/spartan-prover/src/wiring.rs b/crates/spartan-prover/src/wiring.rs index 684664397..df2c5861a 100644 --- a/crates/spartan-prover/src/wiring.rs +++ b/crates/spartan-prover/src/wiring.rs @@ -534,14 +534,9 @@ mod tests { .expect("commit should succeed"); // Create FRI fold prover - let fri_prover = FRIFoldProver::new( - &fri_params, - &ntt, - &merkle_prover, - codeword.as_ref(), - &codeword_committed, - ) - .expect("FRI fold prover creation should succeed"); + let fri_prover = + FRIFoldProver::new(&fri_params, &ntt, &merkle_prover, codeword, &codeword_committed) + .expect("FRI fold prover creation should succeed"); // Prover side let mut prover_transcript = ProverTranscript::new(StdChallenger::default());