Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,20 @@ p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-maybe-rayon = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016", features = [
] }
p3-maybe-rayon = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016", features = [] }
p3-interpolation = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-multilinear-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "f742016" }

hashbrown = "0.16"
itertools = { version = "0.14.0", default-features = false, features = [
"use_alloc",
] }
libm = "0.2.15"
bincode = "1.3"
thiserror = { version = "2.0", default-features = false }
tracing = { version = "0.1.37", default-features = false, features = [
"attributes",
Expand Down
272 changes: 240 additions & 32 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,38 @@
use core::fmt::Debug;
use std::time::Instant;

use rand::distr::StandardUniform;
use rand::prelude::Distribution;

use clap::Parser;
use p3_baby_bear::BabyBear;
use p3_challenger::DuplexChallenger;
use p3_dft::Radix2DFTSmallBatch;
use p3_field::{PrimeField64, extension::BinomialExtensionField};
use p3_challenger::{DuplexChallenger, FieldChallenger};
use p3_commit::{BatchOpening, ExtensionMmcs, Pcs, PolynomialSpace};
use p3_dft::{Radix2DFTSmallBatch, Radix2DitParallel, TwoAdicSubgroupDft};
use p3_field::coset::TwoAdicMultiplicativeCoset;
use p3_field::{
extension::BinomialExtensionField, ExtensionField, Field, PrimeField32, PrimeField64,
TwoAdicField,
};
use p3_fri::{FriParameters, FriProof, TwoAdicFriPcs};
use p3_goldilocks::Goldilocks;
use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear};
use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
use p3_matrix::dense::RowMajorMatrix;
use p3_merkle_tree::MerkleTreeMmcs;
use p3_symmetric::{CryptographicPermutation, PaddingFreeSponge, TruncatedPermutation};
use p3_uni_stark::{
prove, verify, DebugConstraintBuilder, ProverConstraintFolder, StarkConfig, StarkGenericConfig,
SymbolicAirBuilder, VerifierConstraintFolder,
};
use rand::{
Rng, SeedableRng,
rngs::{SmallRng, StdRng},
Rng, SeedableRng,
};
use tracing_forest::{ForestLayer, util::LevelFilter};
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
use tracing_forest::{util::LevelFilter, ForestLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
use whir_p3::{
fiat_shamir::domain_separator::DomainSeparator,
parameters::{DEFAULT_MAX_POW, FoldingFactor, ProtocolParameters, errors::SecurityAssumption},
parameters::{errors::SecurityAssumption, FoldingFactor, ProtocolParameters, DEFAULT_MAX_POW},
poly::{evals::EvaluationsList, multilinear::MultilinearPoint},
whir::{
committer::{reader::CommitmentReader, writer::CommitmentWriter},
Expand All @@ -41,6 +57,7 @@ type Poseidon24 = Poseidon2KoalaBear<24>;
type MerkleHash = PaddingFreeSponge<Poseidon24, 24, 16, 8>; // leaf hashing
type MerkleCompress = TruncatedPermutation<Poseidon16, 2, 8, 16>; // 2-to-1 compression
type MyChallenger = DuplexChallenger<F, Poseidon16, 16, 8>;
type DFT = Radix2DFTSmallBatch<F>;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -70,16 +87,64 @@ struct Args {
rs_domain_initial_reduction_factor: usize,
}

#[allow(clippy::too_many_lines)]
fn main() {
// Types related to using Poseidon2 in the Merkle tree.
pub(crate) type Poseidon2Sponge<Perm24> = PaddingFreeSponge<Perm24, 24, 16, 8>;
pub(crate) type Poseidon2Compression<Perm16> = TruncatedPermutation<Perm16, 2, 8, 16>;
pub(crate) type Poseidon2MerkleMmcs<F, Perm16, Perm24> = MerkleTreeMmcs<
<F as Field>::Packing,
<F as Field>::Packing,
Poseidon2Sponge<Perm24>,
Poseidon2Compression<Perm16>,
8,
>;

/// General context handling that stores things we need in WHIR as well as FRI
struct Context {
rng: StdRng,
polynomial: EvaluationsList<F>,
poseidon16: Poseidon16,
poseidon24: Poseidon24,
num_coeffs: usize,
num_evaluations: usize,
challenger: MyChallenger,
}

/// Initialize the `Context` object storing things used in both WHIR and FRI
fn init_context() -> Context {
let mut args = Args::parse(); // we parse again in `prepare_config`, but well..
let num_coeffs = 1 << args.num_variables;
let num_evaluations = args.num_evaluations;

let mut rng = StdRng::seed_from_u64(0);
let polynomial = EvaluationsList::<F>::new((0..num_coeffs).map(|_| rng.random()).collect());
let poseidon16 = Poseidon16::new_from_rng_128(&mut rng);
let poseidon24 = Poseidon24::new_from_rng_128(&mut rng);

// IMPORTANT: We obviously need to *clone* this challenger for every prove / verify call,
// otherwise transcript state would persist
let challenger = MyChallenger::new(poseidon16.clone());

Context {
rng,
polynomial,
poseidon16,
poseidon24,
num_coeffs,
num_evaluations,
challenger,
}
}

/// Prepare the `WhirConfig` used for WHIR and a few fields in FRI
fn prepare_config(ctx: &Context) -> WhirConfig<EF, F, MerkleHash, MerkleCompress, MyChallenger> {
let env_filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy();

Registry::default()
let _ = Registry::default()
.with(env_filter)
.with(ForestLayer::default())
.init();
.try_init();

let mut args = Args::parse();

Expand All @@ -102,16 +167,12 @@ fn main() {

// Create hash and compression functions for the Merkle tree
let mut rng = SmallRng::seed_from_u64(1);
let poseidon16 = Poseidon16::new_from_rng_128(&mut rng);
let poseidon24 = Poseidon24::new_from_rng_128(&mut rng);

let merkle_hash = MerkleHash::new(poseidon24);
let merkle_compress = MerkleCompress::new(poseidon16.clone());
let merkle_hash = MerkleHash::new(ctx.poseidon24.clone());
let merkle_compress = MerkleCompress::new(ctx.poseidon16.clone());

let rs_domain_initial_reduction_factor = args.rs_domain_initial_reduction_factor;

let num_coeffs = 1 << num_variables;

// Construct WHIR protocol parameters
let whir_params = ProtocolParameters {
initial_statement: true,
Expand All @@ -131,20 +192,26 @@ fn main() {
whir_params,
);

let mut rng = StdRng::seed_from_u64(0);
let polynomial = EvaluationsList::<F>::new((0..num_coeffs).map(|_| rng.random()).collect());
params
}

fn run_whir(ctx: &mut Context) {
let args = Args::parse();

let params = prepare_config(ctx);
let dft = Radix2DFTSmallBatch::<F>::new(1 << params.max_fft_size());

// Sample `num_points` random multilinear points in the Boolean hypercube
let points: Vec<_> = (0..num_evaluations)
.map(|_| MultilinearPoint::rand(&mut rng, num_variables))
let points: Vec<_> = (0..ctx.num_evaluations)
.map(|_| MultilinearPoint::rand(&mut ctx.rng, params.num_variables))
.collect();

// Construct a new statement with the correct number of variables
let mut statement = EqStatement::<EF>::initialize(num_variables);
let mut statement = EqStatement::<EF>::initialize(params.num_variables);

// Add constraints for each sampled point (equality constraints)
for point in &points {
statement.add_unevaluated_constraint_hypercube(point.clone(), &polynomial);
statement.add_unevaluated_constraint_hypercube(point.clone(), &ctx.polynomial);
}

// Define the Fiat-Shamir domain separator pattern for committing and proving
Expand All @@ -158,11 +225,8 @@ fn main() {
println!("WARN: more PoW bits required than what specified.");
}

let challenger = MyChallenger::new(poseidon16);
let mut prover_challenger = challenger.clone();

// Initialize the Merlin transcript from the IOPattern
let mut prover_state = domainsep.to_prover_state(challenger.clone());
let mut prover_state = domainsep.to_prover_state(ctx.challenger.clone());

// Commit to the polynomial and produce a witness
let committer = CommitmentWriter::new(&params);
Expand All @@ -177,8 +241,8 @@ fn main() {
&dft,
&mut prover_state,
&mut proof,
&mut prover_challenger,
polynomial,
&mut ctx.challenger.clone(),
ctx.polynomial.clone(),
)
.unwrap();
let commit_time = time.elapsed();
Expand All @@ -200,10 +264,10 @@ fn main() {
let verifier = Verifier::new(&params);

// Reconstruct verifier's view of the transcript using the DomainSeparator and prover's data
let mut verifier_challenger = challenger.clone();
let mut verifier_state =
domainsep.to_verifier_state(prover_state.proof_data().to_vec(), challenger);
domainsep.to_verifier_state(prover_state.proof_data().to_vec(), ctx.challenger.clone());

let mut verifier_challenger = ctx.challenger.clone();
// Parse the commitment
let parsed_commitment = commitment_reader
.parse_commitment::<8>(&mut verifier_state, &proof, &mut verifier_challenger)
Expand Down Expand Up @@ -231,3 +295,147 @@ fn main() {
println!("proof size: {:.2} KiB", proof_size / 1024.0);
println!("Verification time: {} μs", verify_time.as_micros());
}

/// Creates a set of `FriParameters` suitable for benchmarking.
/// These parameters represents numbers used in Valida
pub const fn create_benchmark_fri_params<Mmcs>(mmcs: Mmcs) -> FriParameters<Mmcs> {
FriParameters {
log_blowup: 1,
log_final_poly_len: 0,
num_queries: 40,
proof_of_work_bits: 8,
mmcs,
}
}
Comment on lines +301 to +309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we should fix a security parameter and then calculate the number of queries based off that.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But with target_security = 90 for WHIR vs num_queries: 40 here for FRI we don't have the same level of security no? Or did I miss something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I should have mentioned that these are obviously placeholder at the moment. I just don't have a good intuition for how to choose parameters sensibly here to compare with WHIR.


/// Report the result of the proof.
///
/// Either print that the proof was successful or panic and return the error.
#[inline]
pub fn report_result(result: Result<(), impl Debug>) {
if let Err(e) = result {
panic!("{e:?}");
} else {
println!("Proof Verified Successfully")
}
}

/// Returns the size of the FRI proof in bytes
fn calc_fri_proof_size(
opened_values: Vec<Vec<Vec<Vec<EF>>>>,
proof: FriProof<
EF,
ExtensionMmcs<F, EF, Poseidon2MerkleMmcs<F, Poseidon16, Poseidon24>>,
F,
Vec<BatchOpening<F, Poseidon2MerkleMmcs<F, Poseidon16, Poseidon24>>>,
>,
) -> usize {
let opening_bytes = bincode::serialize(&opened_values).expect("serialize openings");
let proof_bytes = bincode::serialize(&proof).expect("serialize proof");
opening_bytes.len() + proof_bytes.len()
}

fn run_fri(ctx: &mut Context) {
// WHIR setup, to reuse parameters for FRI (that are applicable)
let params = prepare_config(ctx);

// TODO: The DFT size might be different for FRI and WHIR, no?
// Comment on WhirConfig for max_fft_size says:
// /// Returns the log2 size of the largest FFT
// /// (At commitment we perform 2^folding_factor FFT of size 2^max_fft_size)
// but folding factor will be different?
let dft = Radix2DFTSmallBatch::<F>::new(1 << params.max_fft_size());
// Set up MMCS and TwoAdicFriPcs
let val_mmcs = Poseidon2MerkleMmcs::<F, Poseidon16, Poseidon24>::new(
params.merkle_hash.clone(),
params.merkle_compress.clone(),
);
let challenge_mmcs = ExtensionMmcs::<F, EF, _>::new(val_mmcs.clone());
let fri_params = create_benchmark_fri_params(challenge_mmcs);
let pcs = TwoAdicFriPcs::<F, DFT, _, _>::new(dft, val_mmcs, fri_params);

println!("\n\n=========================================");
println!("FRI (PCS) 🍳️");

let log_height = params.num_variables;
let trace_height = 1 << log_height;

// Define the number of columns we split the evaluations into
// TODO: could make this a CL arg?
const LOG_NUM_COLS: usize = 5;
const NUM_COLS: usize = 1 << LOG_NUM_COLS; // 32
Comment on lines +363 to +366
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this doesn't really make sense from the FRI perspective (i.e. the low degree test).

I guess this is like doing several low degree tests in parallel?


// Construct a domain of the required size based on the height of the "trace". We split the evaluations
// `Context::polynomial` into a `trace_height x NUM_COLS` `RowMajorMatrix`.
// NOTE: In KoalaBear the F::TWO_ADICITY is 24. So the height can at most be 2^{23}.
let domain = TwoAdicMultiplicativeCoset::new(F::GENERATOR, log_height - LOG_NUM_COLS)
.expect("log height too large");

// Convert the polynomial evaluations into a RowMajorMatrix, as used by FRI with `NUM_COLS` columns.
// We need an iterator for the Pcs::commit, so wrap with `once`
let matrix_iter = std::iter::once((
domain,
RowMajorMatrix::new(ctx.polynomial.as_slice().to_vec(), NUM_COLS),
));

// Commit to the matrix
let commit_time = Instant::now();
let (commitment, prover_data) = Pcs::<EF, MyChallenger>::commit(&pcs, matrix_iter.clone());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cleaner

Suggested change
let (commitment, prover_data) = Pcs::<EF, MyChallenger>::commit(&pcs, matrix_iter.clone());
let (commitment, prover_data) = pcs.commit(matrix_iter.clone());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's what I wrote initially, but it leads to fun type annotations needed errors:

error[E0284]: type annotations needed
   --> src/bin/main.rs:384:41
    |
384 |     let (commitment, prover_data) = pcs.commit(matrix_iter.clone());
    |                                         ^^^^^^ cannot infer type for type parameter `Challenger`
    |
    = note: cannot satisfy `<_ as GrindingChallenger>::Witness == p3_monty_31::monty_31::MontyField31<KoalaBearParameters>`
    = note: required for `TwoAdicFriPcs<MontyField31<KoalaBearParameters>, Radix2DFTSmallBatch<MontyField31<KoalaBearParameters>>, MerkleTreeMmcs<..., ..., ..., ..., 8>, ...>` to implement `Pcs<BinomialExtensionField<p3_monty_31::monty_31::MontyField31<KoalaBearParameters>, 4>, _>`
    = note: the full name for the type has been written to '/home/basti/src/rust/whir-p3/target/debug/deps/main-6703cdc4c79f20a7.long-type-9852739075644608931.txt'
    = note: consider using `--verbose` to print the full type name to the console

error[E0283]: type annotations needed
   --> src/bin/main.rs:384:41
    |
384 |     let (commitment, prover_data) = pcs.commit(matrix_iter.clone());
    |                                         ^^^^^^ cannot infer type for type parameter `Challenger`
    |
    = note: multiple `impl`s satisfying `_: FieldChallenger<p3_monty_31::monty_31::MontyField31<KoalaBearParameters>>` found in the `p3_challenger` crate:
...................................

error[E0283]: type annotations needed
   --> src/bin/main.rs:384:41
    |
384 |     let (commitment, prover_data) = pcs.commit(matrix_iter.clone());
    |                                         ^^^^^^ cannot infer type for type parameter `Challenger`
    |
    = note: multiple `impl`s satisfying `_: CanObserve<p3_symmetric::Hash<p3_monty_31::monty_31::MontyField31<KoalaBearParameters>, p3_monty_31::monty_31::MontyField31<KoalaBearParameters>, 8>>` found in the `p3_challenger` crate:
..................................

at which point I just resigned to calling it as above...

let commitment_time = commit_time.elapsed();

// Randomly sample the correct number of opening points
let open_points: Vec<EF> = (0..ctx.num_evaluations)
.map(|_| ctx.rng.random::<EF>())
.collect();
let num_chunks = ctx.polynomial.num_evals() / trace_height;
let points = vec![open_points.clone(); num_chunks];
// Generate the opening proof
let open_time = Instant::now();
let (opened_values, proof) =
pcs.open(vec![(&prover_data, points)], &mut ctx.challenger.clone());
let opening_time = open_time.elapsed();

// Construct the points needed for the verifier
let verifier_points = matrix_iter
.zip(&opened_values[0]) // first and only commitment
.map(|((domain, _), mat_openings)| {
let openings = open_points
.iter()
.copied()
.zip(mat_openings.iter().cloned())
.collect();
(domain, openings)
})
.collect();

// Verify the opening proof
let verif_time = Instant::now();
let res = pcs.verify(
vec![(commitment, verifier_points)],
&proof,
&mut ctx.challenger.clone(),
);
let verify_time = verif_time.elapsed();

report_result(res);

println!(
"\nProving time: {} ms (commit: {} ms, opening: {} ms)",
commitment_time.as_millis() + opening_time.as_millis(),
commitment_time.as_millis(),
opening_time.as_millis()
);
let proof_size = calc_fri_proof_size(opened_values, proof) as f64;
println!("proof size: {:.2} KiB", proof_size / 1024.0);
println!("Verification time: {} μs", verify_time.as_micros());
}

fn main() {
let mut ctx = init_context();

// 1. First run WHIR
run_whir(&mut ctx);

// 2. Now run FRI
run_fri(&mut ctx);
}
2 changes: 1 addition & 1 deletion src/whir/prover/round_state/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
sumcheck::sumcheck_single::SumcheckSingle,
whir::{
committer::{RoundMerkleTree, Witness},
constraints::{Constraint, statement::EqStatement},
constraints::{statement::EqStatement, Constraint},
parameters::SumcheckOptimization,
prover::Prover,
},
Expand Down