diff --git a/benches/whir.rs b/benches/whir.rs index aee03f5c..10aee5c4 100644 --- a/benches/whir.rs +++ b/benches/whir.rs @@ -32,8 +32,7 @@ type MyChallenger = DuplexChallenger; #[allow(clippy::type_complexity)] fn prepare_inputs() -> ( - WhirConfig, - Radix2DFTSmallBatch, + WhirConfig>, EvaluationsList, EqStatement, MyChallenger, @@ -80,6 +79,7 @@ fn prepare_inputs() -> ( folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type, starting_log_inv_rate: starting_rate, rs_domain_initial_reduction_factor, @@ -115,30 +115,25 @@ fn prepare_inputs() -> ( let mut domainsep = DomainSeparator::new(vec![]); // Commit protocol parameters and proof type to the domain separator. - domainsep.commit_statement::<_, _, _, 32>(¶ms); - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, _, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, _, _, 32>(¶ms); // Instantiate the Fiat-Shamir challenger from an empty seed and Keccak. let challenger = MyChallenger::new(poseidon16); - // DFT backend setup - - // Construct a Radix-2 FFT backend that supports small batch DFTs over `F`. - let dft = Radix2DFTSmallBatch::::new(1 << params.max_fft_size()); - // Return all preprocessed components needed to run commit/prove/verify benchmarks. - (params, dft, polynomial, statement, challenger, domainsep) + (params, polynomial, statement, challenger, domainsep) } fn benchmark_commit_and_prove(c: &mut Criterion) { - let (params, dft, polynomial, statement, challenger, domainsep) = prepare_inputs(); + let (params, polynomial, statement, challenger, domainsep) = prepare_inputs(); c.bench_function("commit", |b| { b.iter(|| { let mut prover_state = domainsep.to_prover_state(challenger.clone()); let committer = CommitmentWriter::new(¶ms); let _witness = committer - .commit(&dft, &mut prover_state, polynomial.clone()) + .commit(&mut prover_state, polynomial.clone()) .unwrap(); }); }); @@ -148,12 +143,12 @@ fn benchmark_commit_and_prove(c: &mut Criterion) { let mut prover_state = domainsep.to_prover_state(challenger.clone()); let committer = CommitmentWriter::new(¶ms); let witness = committer - .commit(&dft, &mut prover_state, polynomial.clone()) + .commit(&mut prover_state, polynomial.clone()) .unwrap(); let prover = Prover(¶ms); prover - .prove(&dft, &mut prover_state, statement.clone(), witness) + .prove(&mut prover_state, statement.clone(), witness) .unwrap(); }); }); diff --git a/src/bin/main.rs b/src/bin/main.rs index 1c5e59b0..b975dffa 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -118,13 +118,14 @@ fn main() { folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type, starting_log_inv_rate: starting_rate, rs_domain_initial_reduction_factor, univariate_skip: false, }; - let params = WhirConfig::::new( + let params = WhirConfig::::new( num_variables, whir_params, ); @@ -147,8 +148,8 @@ fn main() { // Define the Fiat-Shamir domain separator pattern for committing and proving let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 32>(¶ms); - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, _, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, _, _, 32>(¶ms); println!("========================================="); println!("Whir (PCS) 🌪️"); @@ -164,12 +165,8 @@ fn main() { // Commit to the polynomial and produce a witness let committer = CommitmentWriter::new(¶ms); - let dft = Radix2DFTSmallBatch::::new(1 << params.max_fft_size()); - let time = Instant::now(); - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); let commit_time = time.elapsed(); // Generate a proof using the prover @@ -178,7 +175,7 @@ fn main() { // Generate a proof for the given statement and witness let time = Instant::now(); prover - .prove(&dft, &mut prover_state, statement.clone(), witness) + .prove(&mut prover_state, statement.clone(), witness) .unwrap(); let opening_time = time.elapsed(); diff --git a/src/fiat_shamir/domain_separator.rs b/src/fiat_shamir/domain_separator.rs index 9f496bb0..35af97df 100644 --- a/src/fiat_shamir/domain_separator.rs +++ b/src/fiat_shamir/domain_separator.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; use crate::{ @@ -129,11 +130,13 @@ where } } - pub fn commit_statement( + pub fn commit_statement( &mut self, - params: &WhirConfig, + params: &WhirConfig, ) where + F: TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { // TODO: Add params self.observe(DIGEST_ELEMS, Observe::MerkleDigest); @@ -143,13 +146,14 @@ where } } - pub fn add_whir_proof( + pub fn add_whir_proof( &mut self, - params: &WhirConfig, + params: &WhirConfig, ) where - Challenger: FieldChallenger + GrindingChallenger, - EF: TwoAdicField, F: TwoAdicField, + EF: TwoAdicField, + Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { // TODO: Add statement if params.initial_statement { diff --git a/src/parameters/mod.rs b/src/parameters/mod.rs index c033b7ed..3c921b04 100644 --- a/src/parameters/mod.rs +++ b/src/parameters/mod.rs @@ -144,7 +144,7 @@ impl FoldingFactor { /// Configuration parameters for WHIR proofs. #[derive(Clone, Debug)] -pub struct ProtocolParameters { +pub struct ProtocolParameters { /// Whether the initial statement is included in the proof. pub initial_statement: bool, /// The logarithmic inverse rate for sampling. @@ -167,11 +167,13 @@ pub struct ProtocolParameters { pub merkle_hash: H, /// Compression method used in the Merkle tree. pub merkle_compress: C, + /// DFT implementation for polynomial operations. + pub dft: Dft, /// Whether the univariate skip optimization is enabled for the sumcheck protocol. pub univariate_skip: bool, } -impl Display for ProtocolParameters { +impl Display for ProtocolParameters { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!( f, diff --git a/src/whir/committer/reader.rs b/src/whir/committer/reader.rs index b1314cde..42781e77 100644 --- a/src/whir/committer/reader.rs +++ b/src/whir/committer/reader.rs @@ -1,6 +1,7 @@ use std::{fmt::Debug, ops::Deref}; use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; use p3_symmetric::Hash; @@ -97,27 +98,29 @@ where /// The `CommitmentReader` wraps the WHIR configuration and provides a convenient /// method to extract a `ParsedCommitment` by reading values from the Fiat-Shamir transcript. #[derive(Debug)] -pub struct CommitmentReader<'a, EF, F, H, C, Challenger>( - /// Reference to the verifier’s configuration object. +pub struct CommitmentReader<'a, EF, F, H, C, Challenger, Dft>( + /// Reference to the verifier's configuration object. /// /// This contains all parameters needed to parse the commitment, /// including how many out-of-domain samples are expected. - &'a WhirConfig, + &'a WhirConfig, ) where - F: Field, - EF: ExtensionField; + F: TwoAdicField, + EF: ExtensionField, + Dft: TwoAdicSubgroupDft; -impl<'a, EF, F, H, C, Challenger> CommitmentReader<'a, EF, F, H, C, Challenger> +impl<'a, EF, F, H, C, Challenger, Dft> CommitmentReader<'a, EF, F, H, C, Challenger, Dft> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { /// Create a new commitment reader from a WHIR configuration. /// /// This allows the verifier to parse a commitment from the Fiat-Shamir transcript. - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } @@ -137,12 +140,13 @@ where } } -impl Deref for CommitmentReader<'_, EF, F, H, C, Challenger> +impl Deref for CommitmentReader<'_, EF, F, H, C, Challenger, Dft> where - F: Field, + F: TwoAdicField, EF: ExtensionField, + Dft: TwoAdicSubgroupDft, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 @@ -179,7 +183,14 @@ mod tests { num_variables: usize, ood_samples: usize, ) -> ( - WhirConfig, + WhirConfig< + BabyBear, + BabyBear, + MyHash, + MyCompress, + MyChallenger, + Radix2DFTSmallBatch, + >, rand::rngs::ThreadRng, ) { let mut rng = SmallRng::seed_from_u64(1); @@ -200,6 +211,7 @@ mod tests { folding_factor: FoldingFactor::ConstantFromSecondRound(4, 4), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: 1, univariate_skip: false, @@ -226,12 +238,9 @@ mod tests { // Instantiate the committer using the test config. let committer = CommitmentWriter::new(¶ms); - // Use a DFT engine to expand/fold the polynomial for evaluation. - let dft = Radix2DFTSmallBatch::::default(); - // Set up Fiat-Shamir transcript and commit the protocol parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, _, _, 8>(¶ms); // Create the prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); @@ -240,9 +249,7 @@ mod tests { let mut prover_state = ds.to_prover_state(challenger.clone()); // Commit the polynomial and obtain a witness (root, Merkle proof, OOD evaluations). - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); // Simulate verifier state using transcript view of prover’s nonce string. let mut verifier_state = @@ -269,11 +276,10 @@ mod tests { // Set up the committer and DFT engine. let committer = CommitmentWriter::new(¶ms); - let dft = Radix2DFTSmallBatch::::default(); // Begin the transcript and commit to the statement parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, _, _, 8>(¶ms); // Generate the prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); @@ -281,9 +287,7 @@ mod tests { let mut prover_state = ds.to_prover_state(challenger.clone()); // Commit the polynomial to obtain the witness. - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); // Initialize the verifier view of the transcript. let mut verifier_state = @@ -310,11 +314,10 @@ mod tests { // Initialize the committer and DFT engine. let committer = CommitmentWriter::new(¶ms); - let dft = Radix2DFTSmallBatch::::default(); // Start a new transcript and commit to the public parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, _, _, 8>(¶ms); // Create prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); @@ -323,9 +326,7 @@ mod tests { let mut prover_state = ds.to_prover_state(challenger.clone()); // Commit the polynomial and obtain the witness. - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); // Initialize verifier view from prover's transcript string. let mut verifier_state = @@ -350,20 +351,17 @@ mod tests { // Instantiate a committer and DFT backend. let committer = CommitmentWriter::new(¶ms); - let dft = Radix2DFTSmallBatch::::default(); // Set up Fiat-Shamir transcript and commit to the public parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, _, _, 8>(¶ms); // Generate prover and verifier transcript states. let mut rng = SmallRng::seed_from_u64(1); let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); let mut prover_state = ds.to_prover_state(challenger.clone()); - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); let mut verifier_state = ds.to_verifier_state(prover_state.proof_data().to_vec(), challenger); diff --git a/src/whir/committer/writer.rs b/src/whir/committer/writer.rs index bb5d3d39..be3381f2 100644 --- a/src/whir/committer/writer.rs +++ b/src/whir/committer/writer.rs @@ -3,7 +3,7 @@ use std::{ops::Deref, sync::Arc}; use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::Mmcs; use p3_dft::TwoAdicSubgroupDft; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, PseudoCompressionFunction}; @@ -24,22 +24,24 @@ use crate::{ /// /// It provides a commitment that can be used for proof generation and verification. #[derive(Debug)] -pub struct CommitmentWriter<'a, EF, F, H, C, Challenger>( +pub struct CommitmentWriter<'a, EF, F, H, C, Challenger, Dft>( /// Reference to the WHIR protocol configuration. - &'a WhirConfig, + &'a WhirConfig, ) where - F: Field, - EF: ExtensionField; + F: TwoAdicField, + EF: ExtensionField, + Dft: TwoAdicSubgroupDft; -impl<'a, EF, F, H, C, Challenger> CommitmentWriter<'a, EF, F, H, C, Challenger> +impl<'a, EF, F, H, C, Challenger, Dft> CommitmentWriter<'a, EF, F, H, C, Challenger, Dft> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { /// Create a new writer that borrows the WHIR protocol configuration. - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } @@ -53,9 +55,8 @@ where /// - Computes out-of-domain (OOD) challenge points and their evaluations. /// - Returns a `Witness` containing the commitment data. #[instrument(skip_all)] - pub fn commit, const DIGEST_ELEMS: usize>( + pub fn commit( &self, - dft: &Dft, prover_state: &mut ProverState, polynomial: EvaluationsList, ) -> Result, DIGEST_ELEMS>, FiatShamirError> @@ -78,7 +79,7 @@ where // Perform DFT on the padded evaluations matrix let folded_matrix = info_span!("dft", height = padded.height(), width = padded.width()) - .in_scope(|| dft.dft_batch(padded).to_row_major_matrix()); + .in_scope(|| self.0.dft.dft_batch(padded).to_row_major_matrix()); // Commit to the Merkle tree let merkle_tree = MerkleTreeMmcs::::new( @@ -109,12 +110,13 @@ where } } -impl Deref for CommitmentWriter<'_, EF, F, H, C, Challenger> +impl Deref for CommitmentWriter<'_, EF, F, H, C, Challenger, Dft> where - F: Field, + F: TwoAdicField, EF: ExtensionField, + Dft: TwoAdicSubgroupDft, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 @@ -168,14 +170,17 @@ mod tests { ), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: starting_rate, univariate_skip: false, }; // Define multivariate parameters for the polynomial. - let params = - WhirConfig::::new(num_variables, whir_params); + let params = WhirConfig::::new( + num_variables, + whir_params, + ); // Generate a random polynomial with 32 coefficients. let mut rng = rand::rng(); @@ -183,8 +188,8 @@ mod tests { // Set up the DomainSeparator and initialize a ProverState narg_string. let mut domainsep: DomainSeparator = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); - domainsep.add_whir_proof::<_, _, _, 8>(¶ms); + domainsep.commit_statement::<_, _, _, _, 8>(¶ms); + domainsep.add_whir_proof::<_, _, _, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); @@ -193,9 +198,8 @@ mod tests { // Run the Commitment Phase let committer = CommitmentWriter::new(¶ms); - let dft = Radix2DFTSmallBatch::::default(); let witness = committer - .commit(&dft, &mut prover_state, polynomial.clone()) + .commit(&mut prover_state, polynomial.clone()) .unwrap(); // Ensure OOD (out-of-domain) points are generated. @@ -247,30 +251,30 @@ mod tests { ), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: starting_rate, univariate_skip: false, }; - let params = - WhirConfig::::new(num_variables, whir_params); + let params = WhirConfig::::new( + num_variables, + whir_params, + ); let mut rng = rand::rng(); let polynomial = EvaluationsList::::new(vec![rng.random(); 1024]); let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); + domainsep.commit_statement::<_, _, _, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); let mut prover_state = domainsep.to_prover_state(challenger); - let dft = Radix2DFTSmallBatch::::default(); let committer = CommitmentWriter::new(¶ms); - let _ = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let _ = committer.commit(&mut prover_state, polynomial).unwrap(); } #[test] @@ -299,13 +303,16 @@ mod tests { ), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: starting_rate, univariate_skip: false, }; - let mut params = - WhirConfig::::new(num_variables, whir_params); + let mut params = WhirConfig::::new( + num_variables, + whir_params, + ); // Explicitly set OOD samples to 0 params.commitment_ood_samples = 0; @@ -314,18 +321,15 @@ mod tests { let polynomial = EvaluationsList::::new(vec![rng.random(); 32]); let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); + domainsep.commit_statement::<_, _, _, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); let mut prover_state = domainsep.to_prover_state(challenger); - let dft = Radix2DFTSmallBatch::::default(); let committer = CommitmentWriter::new(¶ms); - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); assert!( witness.ood_statement.is_empty(), diff --git a/src/whir/constraints/evaluator.rs b/src/whir/constraints/evaluator.rs index 6fcd9464..43e0390a 100644 --- a/src/whir/constraints/evaluator.rs +++ b/src/whir/constraints/evaluator.rs @@ -1,3 +1,4 @@ +use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; use crate::{ @@ -187,12 +188,14 @@ impl ConstraintPolyEvaluator { } } -impl From> for ConstraintPolyEvaluator +impl From> + for ConstraintPolyEvaluator where - F: Field, + F: TwoAdicField, EF: ExtensionField, + Dft: TwoAdicSubgroupDft, { - fn from(cfg: WhirConfig) -> Self { + fn from(cfg: WhirConfig) -> Self { Self { num_variables: cfg.num_variables, folding_factor: cfg.folding_factor, @@ -205,6 +208,7 @@ where mod tests { use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; use p3_challenger::DuplexChallenger; + use p3_dft::Radix2DFTSmallBatch; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_interpolation::interpolate_subgroup; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; @@ -253,6 +257,7 @@ mod tests { folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), univariate_skip: false, initial_statement: true, security_level: 90, @@ -262,7 +267,7 @@ mod tests { rs_domain_initial_reduction_factor: 1, }; let params = - WhirConfig::::new(num_vars, whir_params); + WhirConfig::::new(num_vars, whir_params); let evaluator: ConstraintPolyEvaluator = params.into(); // -- Random Constraints and Challenges -- @@ -380,6 +385,7 @@ mod tests { folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), // This test is for the standard, non-skip case. univariate_skip: false, initial_statement: true, @@ -391,7 +397,7 @@ mod tests { }; // Create the complete verifier configuration object. let params = - WhirConfig::::new(num_vars, whir_params); + WhirConfig::::new(num_vars, whir_params); let evaluator: ConstraintPolyEvaluator = params.into(); // -- Random Constraints and Challenges -- @@ -487,6 +493,7 @@ mod tests { folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), // This test is for the skip case. univariate_skip: true, initial_statement: true, @@ -497,7 +504,7 @@ mod tests { rs_domain_initial_reduction_factor: 1, }; let params = - WhirConfig::::new(num_vars, whir_params); + WhirConfig::::new(num_vars, whir_params); let evaluator: ConstraintPolyEvaluator = params.into(); // -- Random Constraints and Challenges -- @@ -637,6 +644,7 @@ mod tests { folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), // This test is for the skip case. univariate_skip: true, initial_statement: true, @@ -647,7 +655,7 @@ mod tests { rs_domain_initial_reduction_factor: 1, }; let params = - WhirConfig::::new(num_vars, whir_params); + WhirConfig::::new(num_vars, whir_params); let evaluator: ConstraintPolyEvaluator = params.into(); // -- Random Constraints and Challenges -- diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 9e5e6107..dcfdf0a2 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -66,6 +66,7 @@ pub fn make_whir_things( folding_factor, merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type, starting_log_inv_rate: 1, univariate_skip: use_univariate_skip, @@ -73,7 +74,7 @@ pub fn make_whir_things( // Create unified configuration combining protocol and polynomial parameters let params = - WhirConfig::::new(num_variables, whir_params); + WhirConfig::::new(num_variables, whir_params); // Define test polynomial: all coefficients = 1 for simple verification // @@ -98,9 +99,9 @@ pub fn make_whir_things( // Setup Fiat-Shamir transcript structure for non-interactive proof generation let mut domainsep = DomainSeparator::new(vec![]); // Add statement commitment to transcript - domainsep.commit_statement::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, _, _, 32>(¶ms); // Add proof structure to transcript - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, _, _, 32>(¶ms); // Create fresh RNG and challenger for transcript randomness let mut rng = SmallRng::seed_from_u64(1); @@ -111,20 +112,16 @@ pub fn make_whir_things( // Create polynomial commitment using Merkle tree over evaluation domain let committer = CommitmentWriter::new(¶ms); - // DFT evaluator for polynomial - let dft = Radix2DFTSmallBatch::::default(); // Commit to polynomial evaluations and generate cryptographic witness - let witness = committer - .commit(&dft, &mut prover_state, polynomial) - .unwrap(); + let witness = committer.commit(&mut prover_state, polynomial).unwrap(); // Initialize WHIR prover with the configured parameters let prover = Prover(¶ms); // Generate WHIR proof prover - .prove(&dft, &mut prover_state, statement.clone(), witness) + .prove(&mut prover_state, statement.clone(), witness) .unwrap(); // Sample final challenge to ensure transcript consistency between prover/verifier diff --git a/src/whir/parameters.rs b/src/whir/parameters.rs index 0108ae93..e1655e69 100644 --- a/src/whir/parameters.rs +++ b/src/whir/parameters.rs @@ -1,6 +1,7 @@ use std::{f64::consts::LOG2_10, marker::PhantomData}; use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; use crate::parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}; @@ -19,7 +20,7 @@ pub struct RoundConfig { } #[derive(Debug, Clone)] -pub struct WhirConfig +pub struct WhirConfig where F: Field, EF: ExtensionField, @@ -53,6 +54,9 @@ where pub merkle_hash: Hash, pub merkle_compress: C, + // DFT implementation for polynomial operations + pub dft: Dft, + // Univariate skip optimization pub univariate_skip: bool, @@ -61,14 +65,15 @@ where pub _challenger: PhantomData, } -impl WhirConfig +impl WhirConfig where F: TwoAdicField, EF: ExtensionField + TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { #[allow(clippy::too_many_lines)] - pub fn new(num_variables: usize, whir_parameters: ProtocolParameters) -> Self { + pub fn new(num_variables: usize, whir_parameters: ProtocolParameters) -> Self { // We need to store the initial number of variables for the final composition. let initial_num_variables = num_variables; whir_parameters @@ -239,6 +244,7 @@ where final_log_inv_rate: log_inv_rate, merkle_hash: whir_parameters.merkle_hash, merkle_compress: whir_parameters.merkle_compress, + dft: whir_parameters.dft, univariate_skip: whir_parameters.univariate_skip, _base_field: PhantomData, _extension_field: PhantomData, @@ -445,6 +451,7 @@ where mod tests { use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; use p3_challenger::DuplexChallenger; + use p3_dft::Radix2DFTSmallBatch; use p3_field::PrimeCharacteristicRing; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; @@ -457,8 +464,9 @@ mod tests { type MyChallenger = DuplexChallenger; /// Generates default WHIR parameters - const fn default_whir_params() - -> ProtocolParameters, Poseidon2Compression> { + fn default_whir_params() + -> ProtocolParameters, Poseidon2Compression, Radix2DFTSmallBatch> + { ProtocolParameters { initial_statement: true, security_level: 100, @@ -467,6 +475,7 @@ mod tests { folding_factor: FoldingFactor::ConstantFromSecondRound(4, 4), merkle_hash: Poseidon2Sponge::new(44), // Just a placeholder merkle_compress: Poseidon2Compression::new(55), // Just a placeholder + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: 1, univariate_skip: false, @@ -477,10 +486,14 @@ mod tests { fn test_whir_config_creation() { let params = default_whir_params(); - let config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); assert_eq!(config.security_level, 100); assert_eq!(config.max_pow_bits, 20); @@ -491,10 +504,14 @@ mod tests { #[test] fn test_n_rounds() { let params = default_whir_params(); - let config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); assert_eq!(config.n_rounds(), config.round_parameters.len()); } @@ -504,13 +521,14 @@ mod tests { let field_size_bits = 64; let soundness = SecurityAssumption::CapacityBound; - let pow_bits = WhirConfig::::folding_pow_bits( - 100, // Security level - soundness, - field_size_bits, - 10, // Number of variables - 5, // Log inverse rate - ); + let pow_bits = + WhirConfig::>::folding_pow_bits( + 100, // Security level + soundness, + field_size_bits, + 10, // Number of variables + 5, // Log inverse rate + ); // PoW bits should never be negative assert!(pow_bits >= 0.); @@ -519,10 +537,14 @@ mod tests { #[test] fn test_check_pow_bits_within_limits() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); // Set all values within limits config.max_pow_bits = 20; @@ -565,10 +587,14 @@ mod tests { #[test] fn test_check_pow_bits_starting_folding_exceeds() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 21; // Exceeds max_pow_bits @@ -584,10 +610,14 @@ mod tests { #[test] fn test_check_pow_bits_final_pow_exceeds() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -603,10 +633,14 @@ mod tests { #[test] fn test_check_pow_bits_round_pow_exceeds() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -635,10 +669,14 @@ mod tests { #[test] fn test_check_pow_bits_round_folding_pow_exceeds() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -667,10 +705,14 @@ mod tests { #[test] fn test_check_pow_bits_exactly_at_limit() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 20; @@ -698,10 +740,14 @@ mod tests { #[test] fn test_check_pow_bits_all_exceed() { let params = default_whir_params(); - let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + let mut config = WhirConfig::< + F, + F, + Poseidon2Sponge, + Poseidon2Compression, + MyChallenger, + _, + >::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 22; diff --git a/src/whir/prover/mod.rs b/src/whir/prover/mod.rs index 19ced4fe..29f50a19 100644 --- a/src/whir/prover/mod.rs +++ b/src/whir/prover/mod.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::{ExtensionMmcs, Mmcs}; use p3_dft::TwoAdicSubgroupDft; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use p3_interpolation::interpolate_subgroup; use p3_matrix::{ Matrix, @@ -32,31 +32,34 @@ pub type Proof = Vec>; pub type Leafs = Vec>; #[derive(Debug)] -pub struct Prover<'a, EF, F, H, C, Challenger>( +pub struct Prover<'a, EF, F, H, C, Challenger, Dft>( /// Reference to the protocol configuration shared across prover components. - pub &'a WhirConfig, + pub &'a WhirConfig, ) where - F: Field, - EF: ExtensionField; + F: TwoAdicField, + EF: ExtensionField, + Dft: TwoAdicSubgroupDft; -impl Deref for Prover<'_, EF, F, H, C, Challenger> +impl Deref for Prover<'_, EF, F, H, C, Challenger, Dft> where - F: Field, + F: TwoAdicField, EF: ExtensionField, + Dft: TwoAdicSubgroupDft, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 } } -impl Prover<'_, EF, F, H, C, Challenger> +impl Prover<'_, EF, F, H, C, Challenger, Dft> where F: TwoAdicField + Ord, EF: ExtensionField + TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { /// Validates that the total number of variables expected by the prover configuration /// matches the number implied by the folding schedule and the final rounds. @@ -136,9 +139,8 @@ where /// # Errors /// Returns an error if the witness or statement are invalid, or if a round fails. #[instrument(skip_all)] - pub fn prove, const DIGEST_ELEMS: usize>( + pub fn prove( &self, - dft: &Dft, prover_state: &mut ProverState, statement: EqStatement, witness: Witness, DIGEST_ELEMS>, @@ -166,7 +168,7 @@ where // Run the WHIR protocol round-by-round for round in 0..=self.n_rounds() { - self.round(dft, round, prover_state, &mut round_state)?; + self.round(round, prover_state, &mut round_state)?; } // Reverse the vector of verifier challenges (used as evaluation point) @@ -181,9 +183,8 @@ where #[instrument(skip_all, fields(round_number = round_index, log_size = self.num_variables - self.folding_factor.total_number(round_index)))] #[allow(clippy::too_many_lines)] - fn round>( + fn round( &self, - dft: &Dft, round_index: usize, prover_state: &mut ProverState, round_state: &mut RoundState, DIGEST_ELEMS>, @@ -226,7 +227,7 @@ where // Perform DFT on the padded evaluations matrix let folded_matrix = info_span!("dft", height = padded.height(), width = padded.width()) - .in_scope(|| dft.dft_algebra_batch(padded).to_row_major_matrix()); + .in_scope(|| self.dft.dft_algebra_batch(padded).to_row_major_matrix()); let mmcs = MerkleTreeMmcs::::new( self.merkle_hash.clone(), diff --git a/src/whir/prover/round_state/state.rs b/src/whir/prover/round_state/state.rs index f1d19c65..37d3ee4e 100644 --- a/src/whir/prover/round_state/state.rs +++ b/src/whir/prover/round_state/state.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, TwoAdicField}; use p3_matrix::dense::DenseMatrix; use p3_merkle_tree::MerkleTree; @@ -146,8 +147,8 @@ where /// /// Returns the complete `RoundState` ready for the first WHIR folding round. #[instrument(skip_all)] - pub fn initialize_first_round_state( - prover: &Prover<'_, EF, F, MyChallenger, C, Challenger>, + pub fn initialize_first_round_state( + prover: &Prover<'_, EF, F, MyChallenger, C, Challenger, Dft>, prover_state: &mut ProverState, mut statement: EqStatement, witness: Witness, DIGEST_ELEMS>, @@ -156,6 +157,7 @@ where Challenger: FieldChallenger + GrindingChallenger, MyChallenger: Clone, C: Clone, + Dft: TwoAdicSubgroupDft, { // Append OOD constraints to statement for Reed-Solomon proximity testing statement.concatenate(&witness.ood_statement); diff --git a/src/whir/prover/round_state/tests.rs b/src/whir/prover/round_state/tests.rs index 76e8d451..221aa28a 100644 --- a/src/whir/prover/round_state/tests.rs +++ b/src/whir/prover/round_state/tests.rs @@ -43,7 +43,7 @@ fn make_test_config( initial_statement: bool, folding_factor: usize, pow_bits: usize, -) -> WhirConfig { +) -> WhirConfig> { let mut rng = SmallRng::seed_from_u64(1); let perm = Perm::new_from_rng_128(&mut rng); @@ -60,6 +60,7 @@ fn make_test_config( folding_factor: FoldingFactor::Constant(folding_factor), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::CapacityBound, starting_log_inv_rate: 1, univariate_skip: false, @@ -80,7 +81,7 @@ fn make_test_config( /// This is used as a boilerplate step before running the first WHIR round. #[allow(clippy::type_complexity)] fn setup_domain_and_commitment( - params: &WhirConfig, + params: &WhirConfig>, poly: EvaluationsList, ) -> ( DomainSeparator, @@ -91,10 +92,10 @@ fn setup_domain_and_commitment( let mut domsep = DomainSeparator::new(vec![]); // Observe the public statement into the transcript for binding. - domsep.commit_statement::<_, _, _, 8>(params); + domsep.commit_statement::<_, _, _, _, 8>(params); // Reserve transcript space for WHIR proof messages. - domsep.add_whir_proof::<_, _, _, 8>(params); + domsep.add_whir_proof::<_, _, _, _, 8>(params); let mut rng = SmallRng::seed_from_u64(1); let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); @@ -107,13 +108,7 @@ fn setup_domain_and_commitment( // Perform DFT-based commitment to the polynomial, producing a witness // which includes the Merkle tree and polynomial values. - let witness = committer - .commit( - &Radix2DFTSmallBatch::::default(), - &mut prover_state, - poly, - ) - .unwrap(); + let witness = committer.commit(&mut prover_state, poly).unwrap(); // Return all initialized components needed for round state setup. (domsep, prover_state, witness) diff --git a/src/whir/verifier/mod.rs b/src/whir/verifier/mod.rs index aaa8eef1..c6259bb1 100644 --- a/src/whir/verifier/mod.rs +++ b/src/whir/verifier/mod.rs @@ -3,7 +3,8 @@ use std::{fmt::Debug, ops::Deref}; use errors::VerifierError; use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::{BatchOpeningRef, ExtensionMmcs, Mmcs}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{ExtensionField, TwoAdicField}; use p3_interpolation::interpolate_subgroup; use p3_matrix::Dimensions; use p3_merkle_tree::MerkleTreeMmcs; @@ -37,21 +38,23 @@ pub mod sumcheck; /// This type provides a lightweight, ergonomic interface to verification methods /// by wrapping a reference to the `WhirConfig`. #[derive(Debug)] -pub struct Verifier<'a, EF, F, H, C, Challenger>( - /// Reference to the verifier’s configuration containing all round parameters. - pub(crate) &'a WhirConfig, +pub struct Verifier<'a, EF, F, H, C, Challenger, Dft>( + /// Reference to the verifier's configuration containing all round parameters. + pub(crate) &'a WhirConfig, ) where - F: Field, - EF: ExtensionField; + F: TwoAdicField, + EF: ExtensionField, + Dft: TwoAdicSubgroupDft; -impl<'a, EF, F, H, C, Challenger> Verifier<'a, EF, F, H, C, Challenger> +impl<'a, EF, F, H, C, Challenger, Dft> Verifier<'a, EF, F, H, C, Challenger, Dft> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, Challenger: FieldChallenger + GrindingChallenger, + Dft: TwoAdicSubgroupDft, { - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } @@ -491,12 +494,13 @@ where } } -impl Deref for Verifier<'_, EF, F, H, C, Challenger> +impl Deref for Verifier<'_, EF, F, H, C, Challenger, Dft> where - F: Field, + F: TwoAdicField, EF: ExtensionField, + Dft: TwoAdicSubgroupDft, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 diff --git a/src/whir/verifier/sumcheck.rs b/src/whir/verifier/sumcheck.rs index 6ad84eb6..e9cd92ae 100644 --- a/src/whir/verifier/sumcheck.rs +++ b/src/whir/verifier/sumcheck.rs @@ -144,6 +144,7 @@ where mod tests { use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; use p3_challenger::DuplexChallenger; + use p3_dft::Radix2DFTSmallBatch; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{SeedableRng, rngs::SmallRng}; @@ -174,7 +175,7 @@ mod tests { /// Constructs a default WHIR configuration for testing fn default_whir_config( num_variables: usize, - ) -> WhirConfig { + ) -> WhirConfig> { // Create hash and compression functions for the Merkle tree let mut rng = SmallRng::seed_from_u64(1); let perm = Perm::new_from_rng_128(&mut rng); @@ -191,6 +192,7 @@ mod tests { folding_factor: FoldingFactor::Constant(2), merkle_hash, merkle_compress, + dft: Radix2DFTSmallBatch::::default(), soundness_type: SecurityAssumption::UniqueDecoding, starting_log_inv_rate: 1, univariate_skip: false,