10 changes: 5 additions & 5 deletions crates/prover/src/merkle_tree/binary_merkle_tree.rs
@@ -1,16 +1,17 @@
// Copyright 2024-2025 Irreducible Inc.

use std::{fmt::Debug, iter::repeat_with, mem::MaybeUninit};
use std::{fmt::Debug, mem::MaybeUninit};

use binius_field::Field;
use binius_utils::{
checked_arithmetics::log2_strict_usize,
mem::slice_assume_init_mut,
rand::par_rand,
rayon::{prelude::*, slice::ParallelSlice},
};
use binius_verifier::merkle_tree::Error;
use digest::{FixedOutputReset, Output, crypto_common::BlockSizeUser};
use rand::{CryptoRng, Rng};
use rand::{CryptoRng, Rng, rngs::StdRng};

use crate::hash::{ParallelDigest, parallel_compression::ParallelPseudoCompression};

@@ -78,9 +79,8 @@ where
let log_len = log2_strict_usize(iterated_chunks.len()); // precondition

// Generate salts if needed
let salts = repeat_with(|| F::random(&mut rng))
.take(salt_len << log_len)
.collect::<Vec<_>>();
let salts =
par_rand::<StdRng, _, _>(salt_len << log_len, &mut rng, F::random).collect::<Vec<_>>();

let total_length = (1 << (log_len + 1)) - 1;
let mut inner_nodes = Vec::with_capacity(total_length);
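For intuition, the `par_rand` helper replaces the sequential `repeat_with(|| F::random(&mut rng))` loop with parallel generation of the salts. A minimal sketch of that pattern, assuming a chunked design in which one seeded `StdRng` is derived per chunk (the chunk size and reseeding scheme here are illustrative, not `binius_utils`' actual implementation):

```rust
use rand::{RngCore, SeedableRng, rngs::StdRng};
use rayon::prelude::*;

// Sketch: derive one seeded child RNG per chunk from the parent RNG (sequentially,
// so the output is reproducible), then fill the chunks in parallel with rayon.
fn par_random_sketch<T, F>(len: usize, mut rng: impl RngCore, generate: F) -> Vec<T>
where
    T: Send,
    F: Fn(&mut StdRng) -> T + Sync,
{
    const CHUNK: usize = 1024;
    let seeds: Vec<[u8; 32]> = (0..len.div_ceil(CHUNK))
        .map(|_| {
            let mut seed = [0u8; 32];
            rng.fill_bytes(&mut seed);
            seed
        })
        .collect();

    seeds
        .into_par_iter()
        .enumerate()
        .flat_map_iter(|(i, seed)| {
            let mut child = StdRng::from_seed(seed);
            let n = CHUNK.min(len - i * CHUNK);
            (0..n).map(|_| generate(&mut child)).collect::<Vec<_>>()
        })
        .collect()
}
```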
187 changes: 181 additions & 6 deletions crates/prover/src/protocols/basefold.rs
@@ -1,7 +1,10 @@
// Copyright 2025 Irreducible Inc.

use binius_field::{BinaryField, PackedField};
use binius_math::{FieldBuffer, ntt::AdditiveNTT};
use binius_math::{
FieldBuffer, inner_product::inner_product_par, line::extrapolate_line_packed,
multilinear::fold::fold_highest_var_inplace, ntt::AdditiveNTT,
};
use binius_transcript::{
ProverTranscript,
fiat_shamir::{CanSample, Challenger},
@@ -66,7 +69,7 @@
fri_folder: FRIFoldProver<'a, F, P, NTT, MerkleProver>,
) -> Self {
assert_eq!(multilinear.log_len(), transparent_multilinear.log_len());
assert_eq!(multilinear.log_len(), fri_folder.n_rounds());
assert_eq!(multilinear.log_len(), fri_folder.n_rounds() - fri_folder.curr_round());

let sumcheck_prover =
BivariateProductSumcheckProver::new([multilinear, transparent_multilinear], claim)
@@ -154,6 +157,69 @@
}
}

/// Performs ZK batching setup and returns a BaseFoldProver for the remaining protocol.
///
/// This handles the zero-knowledge case where the witness is blinded with a random mask.
/// It performs the initial batching round (sampling the batch challenge and folding the highest
/// variable) and returns a prover configured for the remaining n rounds of sumcheck + FRI.
///
/// ## Arguments
///
/// * `multilinear` - batched (witness || mask) polynomial with log_len = n+1
/// * `transparent_multilinear` - l_poly with log_len = n
/// * `claim` - the original sumcheck claim (before ZK batching)
/// * `fri_folder` - FRI fold prover with n_rounds = n+1
/// * `transcript` - prover transcript
///
/// ## Returns
///
/// A `BaseFoldProver` configured for the remaining n rounds; the caller should then call
/// `.prove(transcript)` on it.
pub fn prove_zk<'a, F, P, NTT, MerkleScheme, MerkleProver, Challenger_>(
mut multilinear: FieldBuffer<P>,
transparent_multilinear: FieldBuffer<P>,
sum_claim: F,
mut fri_folder: FRIFoldProver<'a, F, P, NTT, MerkleProver>,
transcript: &mut ProverTranscript<Challenger_>,
) -> BaseFoldProver<'a, F, P, NTT, MerkleProver>
where
F: BinaryField,
P: PackedField<Scalar = F>,
NTT: AdditiveNTT<Field = F> + Sync,
MerkleScheme: MerkleTreeScheme<F, Digest: SerializeBytes>,
MerkleProver: MerkleTreeProver<F, Scheme = MerkleScheme>,
Challenger_: Challenger,
{
let _scope = tracing::debug_span!("Basefold ZK setup").entered();

assert_eq!(multilinear.log_len(), transparent_multilinear.log_len() + 1);
assert_eq!(multilinear.log_len(), fri_folder.n_rounds());

// Compute blinding_eval = sum_x[mask * l_poly]
// The verifier will compute sum = (1-r)*claim + r*blinding_eval using linear interpolation.
let (_witness, mask) = multilinear
.split_half_ref()
.expect("multilinear has log_len >= 1");
let mask_claim = inner_product_par(&mask, &transparent_multilinear);

// Write blinding_eval to transcript
transcript.message().write(&mask_claim);

// Sample batch challenge (before FRI fold round, matching verifier order)
let batch_challenge: F = transcript.sample();

// Receive batch challenge to advance to round 1 (no commitment at batch round)
fri_folder.receive_challenge(batch_challenge);

// Fold multilinear at its last variable.
fold_highest_var_inplace(&mut multilinear, batch_challenge)
.expect("multilinear has log_len >= 1");

// Compute the batched sum using linear interpolation.
let batched_sum = extrapolate_line_packed(sum_claim, mask_claim, batch_challenge);
BaseFoldProver::new(multilinear, transparent_multilinear, batched_sum, fri_folder)
}
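For reference, the batched sum computed above is the evaluation at the batch challenge of the line through `(0, sum_claim)` and `(1, mask_claim)`, which is why the verifier can recompute it as `(1 - r) * claim + r * blinding_eval`. A minimal sketch of that interpolation, assuming the usual by-value arithmetic operator bounds on `binius_field::Field` (the crate's `extrapolate_line_packed` is the real implementation):

```rust
use binius_field::Field;

// Illustrative only. The unique line through (0, a) and (1, b), evaluated at r,
// is a + r * (b - a), i.e. (1 - r) * a + r * b; over a binary field, subtraction
// coincides with addition.
fn extrapolate_line_sketch<F: Field>(a: F, b: F, r: F) -> F {
    a + r * (b - a)
}
```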

#[cfg(test)]
mod test {
use anyhow::{Result, bail};
@@ -177,7 +243,7 @@ mod test {
};
use rand::{SeedableRng, rngs::StdRng};

use super::BaseFoldProver;
use super::{BaseFoldProver, prove_zk};
use crate::{
fri::{self, CommitOutput, FRIFoldProver},
hash::parallel_compression::ParallelCompressionAdaptor,
@@ -267,13 +333,13 @@
{
let mut rng = StdRng::from_seed([0; 32]);

let multilinear = random_field_buffer::<P>(&mut rng, n_vars);
let witness = random_field_buffer::<P>(&mut rng, n_vars);
let evaluation_point = random_scalars::<F>(&mut rng, n_vars);

let eval_point_eq = eq_ind_partial_eval(&evaluation_point);
let evaluation_claim = inner_product_buffers(&multilinear, &eval_point_eq);
let evaluation_claim = inner_product_buffers(&witness, &eval_point_eq);

(multilinear, evaluation_point, evaluation_claim)
(witness, evaluation_point, evaluation_claim)
}

fn dubiously_modify_claim<F, P>(claim: &mut F)
@@ -284,6 +350,115 @@
*claim += P::Scalar::ONE
}

fn run_basefold_zk_prove_and_verify<F, P>(
witness_plus_mask: FieldBuffer<P>,
evaluation_point: Vec<F>,
evaluation_claim: F,
) -> Result<()>
where
F: BinaryField,
P: PackedField<Scalar = F> + PackedExtension<F>,
{
let n_vars = evaluation_point.len();
assert_eq!(witness_plus_mask.log_len(), n_vars + 1);

let eval_point_eq = eq_ind_partial_eval::<P>(&evaluation_point);

let merkle_prover = BinaryMerkleTreeProver::<F, StdDigest, _>::new(
ParallelCompressionAdaptor::new(StdCompression::default()),
);

// Setup NTT with subspace dimension = witness.log_len + LOG_INV_RATE
let subspace = BinarySubspace::with_dim(n_vars + LOG_INV_RATE).unwrap();
let domain_context = GenericOnTheFly::generate_from_subspace(&subspace);
let ntt = NeighborsLastSingleThread::new(domain_context);

// Create FRI params with log_batch_size = 1
let fri_params = FRIParams::with_strategy(
&ntt,
merkle_prover.scheme(),
witness_plus_mask.log_len(),
Some(1),
LOG_INV_RATE,
32,
&ConstantArityStrategy::new(2),
)?;

// Commit batched multilinear
let CommitOutput {
commitment: codeword_commitment,
committed: codeword_committed,
codeword,
} = fri::commit_interleaved(&fri_params, &ntt, &merkle_prover, witness_plus_mask.to_ref())?;

let mut prover_transcript = ProverTranscript::new(StdChallenger::default());
prover_transcript.message().write(&codeword_commitment);

let fri_folder =
FRIFoldProver::new(&fri_params, &ntt, &merkle_prover, codeword, &codeword_committed)?;

// Run prove_zk then continue with basefold prover
let prover = prove_zk(
witness_plus_mask,
eval_point_eq,
evaluation_claim,
fri_folder,
&mut prover_transcript,
);
prover.prove(&mut prover_transcript)?;

// Verify
let mut verifier_transcript = prover_transcript.into_verifier();
let retrieved_commitment = verifier_transcript.message().read()?;

let basefold::ReducedOutput {
final_fri_value,
final_sumcheck_value,
challenges,
} = basefold::verify_zk(
&fri_params,
merkle_prover.scheme(),
retrieved_commitment,
evaluation_claim,
&mut verifier_transcript,
)?;

// Check consistency - skip batch challenge (challenges[0])
let sumcheck_challenges = challenges[1..].to_vec();
if !basefold::sumcheck_fri_consistency(
final_fri_value,
final_sumcheck_value,
&evaluation_point,
sumcheck_challenges,
) {
bail!("Sumcheck and FRI are inconsistent");
}

Ok(())
}

#[test]
fn test_basefold_zk_valid_proof() {
type P = PackedBinaryGhash1x128b;

let n_vars = 8;
let mut rng = StdRng::seed_from_u64(0);

let witness_plus_mask = random_field_buffer::<P>(&mut rng, n_vars + 1);
let evaluation_point = random_scalars(&mut rng, n_vars);

let (witness, _mask) = witness_plus_mask.split_half_ref().unwrap();
let eval_point_eq = eq_ind_partial_eval::<P>(&evaluation_point);
let evaluation_claim = inner_product_buffers(&witness, &eval_point_eq);

run_basefold_zk_prove_and_verify::<_, P>(
witness_plus_mask,
evaluation_point,
evaluation_claim,
)
.unwrap();
}

#[test]
fn test_basefold_valid_proof() {
type P = PackedBinaryGhash1x128b;
1 change: 1 addition & 0 deletions crates/spartan-prover/Cargo.toml
@@ -18,6 +18,7 @@ binius-utils = { path = "../utils" }
binius-prover = { path = "../prover" }
binius-verifier = { path = "../verifier" }
digest.workspace = true
itertools.workspace = true
rand.workspace = true
thiserror.workspace = true
tracing.workspace = true
69 changes: 37 additions & 32 deletions crates/spartan-prover/src/lib.rs
@@ -4,7 +4,10 @@ mod error;
pub mod pcs;
mod wiring;

use std::marker::PhantomData;
use std::{
iter::{repeat_n, repeat_with},
marker::PhantomData,
};

use binius_field::{BinaryField, Field, PackedExtension, PackedField};
use binius_math::{
@@ -23,10 +26,16 @@ use binius_transcript::{
ProverTranscript,
fiat_shamir::{CanSample, Challenger},
};
use binius_utils::{SerializeBytes, checked_arithmetics::checked_log_2, rayon::prelude::*};
use binius_utils::{
SerializeBytes,
checked_arithmetics::checked_log_2,
rand::par_rand,
rayon::{self, prelude::*},
};
use digest::{Digest, FixedOutputReset, Output, core_api::BlockSizeUser};
pub use error::*;
use rand::CryptoRng;
use itertools::chain;
use rand::{CryptoRng, rngs::StdRng};

use crate::wiring::WiringTranspose;

@@ -218,37 +227,33 @@ fn pack_and_blind_witness<F: Field, P: PackedField<Scalar = F>>(
n_private: usize,
mut rng: impl CryptoRng,
) -> FieldBuffer<P> {
// Precondition: witness length must match expected size
let expected_size = 1 << log_witness_elems;
assert_eq!(
witness.len(),
expected_size,
"witness length {} does not match expected size {}",
witness.len(),
expected_size
);

let len = 1 << log_witness_elems.saturating_sub(P::LOG_WIDTH);
let mut packed_witness = Vec::<P>::with_capacity(len);

packed_witness
.spare_capacity_mut()
.into_par_iter()
.enumerate()
.for_each(|(i, dst)| {
let offset = i << P::LOG_WIDTH;
let value = P::from_fn(|j| witness[offset + j]);

dst.write(value);
});

// SAFETY: We just initialized all elements
unsafe {
packed_witness.set_len(len);
let packed_witness = if log_witness_elems < P::LOG_WIDTH {
let elems_iter = witness.iter().copied();
let zeros_iter = repeat_n(F::ZERO, (1 << log_witness_elems) - witness.len());
let mask_iter = repeat_with(|| F::random(&mut rng)).take(1 << log_witness_elems);

let elems = P::from_scalars(chain!(elems_iter, zeros_iter, mask_iter));
vec![elems]
} else {
let packed_len = 1 << (log_witness_elems - P::LOG_WIDTH);

let elems_iter = witness
.par_chunks(P::WIDTH)
.map(|chunk| P::from_scalars(chunk.iter().copied()));
let zeros_iter = rayon::iter::repeat_n(P::zero(), packed_len - elems_iter.len());

// Append a random mask to the end of the witness buffer, of equal length to the witness.
let mask_iter = par_rand::<StdRng, _, _>(packed_len, &mut rng, P::random);

elems_iter
.chain(zeros_iter)
.chain(mask_iter)
.collect::<Vec<_>>()
};

let mut witness_packed = FieldBuffer::new(log_witness_elems, packed_witness.into_boxed_slice())
.expect("FieldBuffer::new should succeed with correct log_witness_elems");
let mut witness_packed =
FieldBuffer::new(log_witness_elems + 1, packed_witness.into_boxed_slice())
.expect("FieldBuffer::new should succeed with correct log_witness_elems");

// Add blinding values
let base = n_public + n_private;
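For intuition, the reworked `pack_and_blind_witness` lays scalars out as the witness, then zero padding up to `2^log_witness_elems`, then a random mask of equal length, which is why the resulting `FieldBuffer` is built with `log_witness_elems + 1`. A scalar-level sketch of that layout, ignoring packing into `P` and the parallel mask generation (`blinded_layout_sketch` is a hypothetical helper, not the crate's API):

```rust
use binius_field::Field;
use rand::CryptoRng;

// Sketch of the blinded buffer layout: [witness | zero padding | random mask],
// 2^(log_witness_elems + 1) scalars in total.
fn blinded_layout_sketch<F: Field>(
    witness: &[F],
    log_witness_elems: usize,
    mut rng: impl CryptoRng,
) -> Vec<F> {
    let half = 1usize << log_witness_elems;
    assert!(witness.len() <= half, "witness must fit in 2^log_witness_elems scalars");

    let mut buf = Vec::with_capacity(2 * half);
    buf.extend_from_slice(witness);                      // witness scalars
    buf.resize(half, F::ZERO);                           // zero padding to the half boundary
    buf.extend((0..half).map(|_| F::random(&mut rng)));  // random mask of equal length
    buf
}
```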