Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions crates/prover/benches/pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,33 @@ fn bench_pcs(c: &mut Criterion) {
transcript.message().write_scalar(eval);

group.bench_function(format!("prove/log_len={log_len}"), |b| {
b.iter(|| {
let mut transcript = transcript.clone();
pcs_prover
.prove(
codeword.as_ref(),
&codeword_committed,
b.iter_batched(
|| {
(
transcript.clone(),
codeword.clone(),
packed_multilin.clone(),
eval_point.clone(),
&mut transcript,
)
.unwrap()
});
},
|(mut transcript, codeword, packed_multilin, eval_point)| {
pcs_prover
.prove(
codeword,
&codeword_committed,
packed_multilin,
eval_point,
&mut transcript,
)
.unwrap()
},
criterion::BatchSize::SmallInput,
);
});

pcs_prover
.prove(
codeword.as_ref(),
codeword.clone(),
&codeword_committed,
packed_multilin.clone(),
eval_point.clone(),
Expand Down
83 changes: 45 additions & 38 deletions crates/prover/src/fri/fold.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
// Copyright 2024-2025 Irreducible Inc.

use binius_field::{BinaryField, Field, PackedField, packed::len_packed_slice};
use binius_math::{multilinear::eq::eq_ind_partial_eval, ntt::AdditiveNTT};
use binius_field::{BinaryField, Field, PackedField};
use binius_math::{
FieldBuffer, FieldSlice, inner_product::inner_product_buffers,
multilinear::eq::eq_ind_partial_eval, ntt::AdditiveNTT,
};
use binius_transcript::{
ProverTranscript,
fiat_shamir::{CanSampleBits, Challenger},
};
use binius_utils::{SerializeBytes, checked_arithmetics::log2_strict_usize, rayon::prelude::*};
use binius_verifier::{
fri::{
FRIParams,
fold::{fold_chunk, fold_interleaved_chunk},
},
fri::{FRIParams, fold::fold_chunk},
merkle_tree::MerkleTreeScheme,
};
use tracing::instrument;
Expand All @@ -20,7 +20,7 @@ use super::{error::Error, query::FRIQueryProver};
use crate::merkle_tree::MerkleTreeProver;

/// The type of the termination round codeword in the FRI protocol.
pub type TerminateCodeword<F> = Vec<F>;
pub type TerminateCodeword<F> = FieldBuffer<F>;

pub enum FoldRoundOutput<VCSCommitment> {
NoCommitment,
Expand All @@ -37,9 +37,9 @@ where
params: &'a FRIParams<F>,
ntt: &'a NTT,
merkle_prover: &'a MerkleProver,
codeword: &'a [P],
codeword: FieldBuffer<P>,
codeword_committed: &'a MerkleProver::Committed,
round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
round_committed: Vec<(FieldBuffer<F>, MerkleProver::Committed)>,
curr_round: usize,
next_commit_round: Option<usize>,
unprocessed_challenges: Vec<F>,
Expand All @@ -58,10 +58,10 @@ where
params: &'a FRIParams<F>,
ntt: &'a NTT,
merkle_prover: &'a MerkleProver,
committed_codeword: &'a [P],
committed_codeword: FieldBuffer<P>,
committed: &'a MerkleProver::Committed,
) -> Result<Self, Error> {
if len_packed_slice(committed_codeword) < 1 << params.log_len() {
if committed_codeword.len() < 1 << params.log_len() {
return Err(Error::InvalidArgs(
"Reed-Solomon code length must match interleaved codeword length".to_string(),
));
Expand Down Expand Up @@ -95,7 +95,7 @@ where
pub fn current_codeword_len(&self) -> usize {
match self.round_committed.last() {
Some((codeword, _)) => codeword.len(),
None => len_packed_slice(self.codeword),
None => self.codeword.len(),
}
}

Expand Down Expand Up @@ -132,18 +132,13 @@ where
Some((prev_codeword, _)) => {
// Fold a full codeword committed in the previous FRI round into a codeword with
// reduced dimension and rate.
fold_codeword(
self.ntt,
prev_codeword,
&self.unprocessed_challenges,
log2_strict_usize(prev_codeword.len()),
)
fold_codeword(self.ntt, prev_codeword.to_ref(), &self.unprocessed_challenges)
}
None => {
// Fold the interleaved codeword that was originally committed into a single
// codeword with the same block length.
fold_interleaved(
self.codeword,
self.codeword.to_ref(),
&self.unprocessed_challenges,
self.params.rs_code().log_len(),
self.params.log_batch_size(),
Expand All @@ -162,7 +157,9 @@ where
let log_coset_size = next_arity.unwrap_or_else(|| self.params.n_final_challenges());
let coset_size = 1 << log_coset_size;
let merkle_tree_span = tracing::debug_span!("Merkle Tree").entered();
let (commitment, committed) = self.merkle_prover.commit(&folded_codeword, coset_size)?;
let (commitment, committed) = self
.merkle_prover
.commit(folded_codeword.as_ref(), coset_size)?;
drop(merkle_tree_span);

self.next_commit_round = self.next_commit_round.take().and_then(|next_commit_round| {
Expand Down Expand Up @@ -195,7 +192,10 @@ where
.round_committed
.last()
.map(|(codeword, _)| codeword.clone())
.unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());
.unwrap_or_else(|| {
let scalars: Vec<F> = self.codeword.iter_scalars().collect();
FieldBuffer::from_values(&scalars).expect("codeword has power-of-two length")
});

self.unprocessed_challenges.clear();

Expand Down Expand Up @@ -227,7 +227,7 @@ where
{
let (terminate_codeword, query_prover) = self.finalize()?;
let mut advice = transcript.decommitment();
advice.write_scalar_slice(&terminate_codeword);
advice.write_scalar_slice(terminate_codeword.as_ref());

let layers = query_prover.vcs_optimal_layers()?;
for layer in layers {
Expand Down Expand Up @@ -258,27 +258,29 @@ where
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[instrument(skip_all, level = "debug")]
fn fold_interleaved<F, P>(
codeword: &[P],
codeword: FieldSlice<P>,
challenges: &[F],
log_len: usize,
log_batch_size: usize,
) -> Vec<F>
) -> FieldBuffer<F>
where
F: Field,
P: PackedField<Scalar = F>,
{
assert_eq!(codeword.len(), 1 << (log_len + log_batch_size).saturating_sub(P::LOG_WIDTH));
assert_eq!(codeword.log_len(), log_len + log_batch_size);
assert_eq!(challenges.len(), log_batch_size);

let tensor = eq_ind_partial_eval(challenges);

// For each chunk of size `2^chunk_size` in the interleaved codeword, fold it with the folding
// challenges.
let chunk_size = 1 << (challenges.len() - P::LOG_WIDTH);
codeword
.par_chunks(chunk_size)
.map(|chunk| fold_interleaved_chunk(log_batch_size, chunk, tensor.as_ref()))
.collect()
let values = codeword
.chunks_par(log_batch_size)
.expect("log_batch_size <= codeword.log_len()")
.map(|chunk| inner_product_buffers(&chunk, &tensor))
.collect::<Vec<_>>();
FieldBuffer::new(log_len, values.into_boxed_slice())
.expect("codeword.log_len() == log_len + log_batch_size")
}

/// FRI-fold the codeword using the given challenges.
Expand All @@ -294,27 +296,32 @@ where
///
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[instrument(skip_all, level = "debug")]
fn fold_codeword<F, NTT>(ntt: &NTT, codeword: &[F], challenges: &[F], log_len: usize) -> Vec<F>
fn fold_codeword<F, NTT>(ntt: &NTT, codeword: FieldSlice<F>, challenges: &[F]) -> FieldBuffer<F>
where
F: BinaryField,
NTT: AdditiveNTT<Field = F> + Sync,
{
assert_eq!(codeword.len(), 1 << log_len);
let log_len = codeword.log_len();
assert!(challenges.len() <= log_len);

let folded_log_len = log_len - challenges.len();

// For each coset of size `2^chunk_size` in the codeword, fold it with the folding challenges.
let chunk_size = 1 << challenges.len();
codeword
.par_chunks(chunk_size)
let values: Vec<F> = codeword
.chunks_par(challenges.len())
.expect("precondition: challenges.len() <= log_len")
.enumerate()
.map_init(
|| vec![F::default(); chunk_size],
|scratch_buffer, (i, chunk)| {
scratch_buffer.copy_from_slice(chunk);
scratch_buffer.copy_from_slice(chunk.as_ref());
fold_chunk(ntt, log_len, i, scratch_buffer, challenges)
},
)
.collect()
.collect();
FieldBuffer::new(folded_log_len, values.into_boxed_slice())
.expect("codeword.len() == 1 << log_len")
}

#[cfg(test)]
Expand Down Expand Up @@ -358,12 +365,12 @@ mod tests {
ntt.forward_transform(codeword.to_mut(), 0, 0);

// Fold the encoded message using FRI folding.
let folded_codeword = fold_codeword(&ntt, codeword.as_ref(), &challenges, log_dim + arity);
let folded_codeword = fold_codeword(&ntt, codeword.to_ref(), &challenges);

// Encode the folded message.
ntt.forward_transform(folded_msg.to_mut(), 0, 0);

// Check that folding and encoding commute.
assert_eq!(folded_codeword, folded_msg.as_ref());
assert_eq!(folded_codeword, folded_msg);
}
}
22 changes: 13 additions & 9 deletions crates/prover/src/fri/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

use std::iter;

use binius_field::{BinaryField, PackedField, packed::iter_packed_slice_with_offset};
use binius_field::{BinaryField, PackedField};
use binius_math::{FieldBuffer, FieldSlice};
use binius_transcript::TranscriptWriter;
use binius_verifier::{
fri::{FRIParams, vcs_optimal_layers_depths_iter},
Expand All @@ -24,9 +25,9 @@ where
VCS: MerkleTreeScheme<F>,
{
pub(super) params: &'a FRIParams<F>,
pub(super) codeword: &'a [P],
pub(super) codeword: FieldBuffer<P>,
pub(super) codeword_committed: &'a MerkleProver::Committed,
pub(super) round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
pub(super) round_committed: Vec<(FieldBuffer<F>, MerkleProver::Committed)>,
pub(super) merkle_prover: &'a MerkleProver,
}

Expand Down Expand Up @@ -64,7 +65,7 @@ where

prove_coset_opening(
self.merkle_prover,
self.codeword,
self.codeword.to_ref(),
self.codeword_committed,
index,
self.params.log_batch_size(),
Expand All @@ -78,7 +79,7 @@ where
index >>= arity;
prove_coset_opening(
self.merkle_prover,
codeword,
codeword.to_ref(),
committed,
index,
arity,
Expand Down Expand Up @@ -111,7 +112,7 @@ where

fn prove_coset_opening<F, P, MTProver, B>(
merkle_prover: &MTProver,
codeword: &[P],
codeword: FieldSlice<P>,
committed: &MTProver::Committed,
coset_index: usize,
log_coset_size: usize,
Expand All @@ -124,9 +125,12 @@ where
MTProver: MerkleTreeProver<F>,
B: BufMut,
{
let values = iter_packed_slice_with_offset(codeword, coset_index << log_coset_size)
.take(1 << log_coset_size);
advice.write_scalar_iter(values);
assert!(coset_index < (1 << (codeword.log_len() - log_coset_size))); // precondition

let values = codeword
.chunk(log_coset_size, coset_index)
.expect("precondition: coset_index < 2^(codeword.log_len() - log_coset_size)");
advice.write_scalar_iter(values.iter_scalars());

merkle_prover.prove_opening(committed, optimal_layer_depth, coset_index, advice)?;

Expand Down
3 changes: 1 addition & 2 deletions crates/prover/src/fri/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ fn test_commit_prove_verify_success<F, P>(

// Run the prover to generate the proximity proof
let mut round_prover =
FRIFoldProver::new(&params, &ntt, &merkle_prover, codeword.as_ref(), &codeword_committed)
.unwrap();
FRIFoldProver::new(&params, &ntt, &merkle_prover, codeword, &codeword_committed).unwrap();

let mut prover_challenger = ProverTranscript::new(StdChallenger::default());
prover_challenger.message().write(&codeword_commitment);
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ where
/// * `transcript` - the transcript of the prover's proof
pub fn prove<P, Challenger_>(
&self,
committed_codeword: &'a [P],
committed_codeword: FieldBuffer<P>,
committed: &'a MerkleProver::Committed,
packed_multilin: FieldBuffer<P>,
evaluation_point: Vec<B128>,
Expand Down Expand Up @@ -261,7 +261,7 @@ mod test {
let mut prover_transcript = ProverTranscript::new(StdChallenger::default());
prover_transcript.message().write(&codeword_commitment);
ring_switch_pcs_prover.prove(
codeword.as_ref(),
codeword,
&codeword_committed,
packed_mle,
evaluation_point.clone(),
Expand Down
9 changes: 2 additions & 7 deletions crates/prover/src/protocols/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,8 @@ mod test {
let mut prover_transcript = ProverTranscript::new(StdChallenger::default());
prover_transcript.message().write(&codeword_commitment);

let fri_folder = FRIFoldProver::new(
&fri_params,
&ntt,
&merkle_prover,
codeword.as_ref(),
&codeword_committed,
)?;
let fri_folder =
FRIFoldProver::new(&fri_params, &ntt, &merkle_prover, codeword, &codeword_committed)?;

let prover = BaseFoldProver::new(multilinear, eval_point_eq, evaluation_claim, fri_folder);
prover.prove(&mut prover_transcript)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ where
)
.entered();
pcs_prover.prove(
trace_codeword.as_ref(),
trace_codeword,
&trace_committed,
witness_packed,
eval_point,
Expand Down
2 changes: 1 addition & 1 deletion crates/spartan-prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where
self.verifier.fri_params(),
&self.ntt,
&self.merkle_prover,
codeword.as_ref(),
codeword,
&codeword_committed,
)?;
wiring::prove(
Expand Down
Loading