Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions crates/spartan-frontend/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,6 @@ impl ConstraintSystemIR {
}
}

// Pad mul_constraints to the next power of two with dummy constraints
// The prover requires power-of-two sized constraint lists for multilinear extensions
let current_len = self.mul_constraints.len();
self.mul_constraints.resize(
current_len.next_power_of_two(),
MulConstraint {
a: one_operand.clone(),
b: one_operand.clone(),
c: one_operand.clone(),
},
);

// Create private_alive array from wire status (invert pruned logic)
let private_alive: Vec<bool> = self
.private_wires_status
Expand Down Expand Up @@ -178,13 +166,18 @@ impl ConstraintSystemIR {
})
.collect();

// Map one_wire to WitnessIndex
let one_wire_index = layout
.get(&one_wire)
.expect("one_wire constant should exist in layout");

let cs = ConstraintSystem::new(
constants,
layout.n_inout() as u32,
layout.n_private() as u32,
layout.log_public(),
layout.log_size(),
mul_constraints,
one_wire_index,
);

(cs, layout)
Expand Down
162 changes: 149 additions & 13 deletions crates/spartan-frontend/src/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,43 +126,36 @@ pub struct WitnessIndex(pub u32);
/// Contains multiplication constraints of the form `A * B = C` where A, B, C are operands
/// (XOR combinations of witness values). Constraints directly reference [`WitnessIndex`]
/// positions in the witness array.
///
/// This struct does not guarantee power-of-two constraint counts or witness size. Use
/// [`ConstraintSystemPadded`] for a version with blinding and power-of-two padding.
#[derive(Debug, Clone)]
pub struct ConstraintSystem<F: Field = B128> {
constants: Vec<F>,
n_inout: u32,
n_private: u32,
log_public: u32,
log_size: u32,
mul_constraints: Vec<MulConstraint<WitnessIndex>>,
one_wire: WitnessIndex,
}

impl<F: Field> ConstraintSystem<F> {
/// Create a new constraint system.
///
/// # Preconditions
///
/// * `mul_constraints.len()` must be a power of two. This is required by the prover's
/// multilinear extension protocol, which operates over power-of-two sized domains.
pub fn new(
constants: Vec<F>,
n_inout: u32,
n_private: u32,
log_public: u32,
log_size: u32,
mul_constraints: Vec<MulConstraint<WitnessIndex>>,
one_wire: WitnessIndex,
) -> Self {
assert!(
mul_constraints.len().is_power_of_two(),
"mul_constraints length must be a power of two, got {}",
mul_constraints.len()
);
Self {
constants,
n_inout,
n_private,
log_public,
log_size,
mul_constraints,
one_wire,
}
}

Expand All @@ -182,6 +175,125 @@ impl<F: Field> ConstraintSystem<F> {
self.log_public
}

pub fn n_public(&self) -> u32 {
1 << self.log_public
}

pub fn mul_constraints(&self) -> &[MulConstraint<WitnessIndex>] {
&self.mul_constraints
}

pub fn one_wire(&self) -> WitnessIndex {
self.one_wire
}

/// Validate that a witness satisfies all multiplication constraints.
pub fn validate(&self, witness: &[B128]) {
let operand_val = |operand: &Operand<WitnessIndex>| {
operand
.wires()
.iter()
.map(|idx| witness[idx.0 as usize])
.sum::<B128>()
};

for MulConstraint { a, b, c } in &self.mul_constraints {
assert_eq!(operand_val(a) * operand_val(b), operand_val(c));
}
}
}

/// A constraint system with blinding and power-of-two padding.
///
/// Wraps a [`ConstraintSystem`], adds dummy constraints for blinding, and pads the total
/// number of constraints to a power of two (required by the prover's multilinear extension
/// protocol).
#[derive(Debug, Clone)]
pub struct ConstraintSystemPadded<F: Field = B128> {
inner: ConstraintSystem<F>,
log_size: u32,
blinding_info: BlindingInfo,
mul_constraints: Vec<MulConstraint<WitnessIndex>>,
}

impl<F: Field> ConstraintSystemPadded<F> {
/// Create a new padded constraint system with blinding.
///
/// This:
/// 1. Adds dummy multiplication constraints for blinding (3 wires each: A * B = C)
/// 2. Pads the total constraint count to a power of two with `one * one = one` constraints
/// 3. Calculates the log_size based on witness requirements
pub fn new(cs: ConstraintSystem<F>, blinding_info: BlindingInfo) -> Self {
let mut mul_constraints = cs.mul_constraints.clone();

// Calculate witness size and log_size
let n_public = cs.n_public() as usize;
let n_private = cs.n_private as usize;
let total_witness_size = n_public
+ n_private
+ blinding_info.n_dummy_wires
+ 3 * blinding_info.n_dummy_constraints;
let log_size = log2_ceil_usize(total_witness_size) as u32;

// Add dummy constraints for blinding
// Each dummy constraint uses 3 consecutive wires starting after n_dummy_wires
let dummy_constraint_wire_base = n_public + n_private + blinding_info.n_dummy_wires;
for i in 0..blinding_info.n_dummy_constraints {
let a = WitnessIndex((dummy_constraint_wire_base + 3 * i) as u32);
let b = WitnessIndex((dummy_constraint_wire_base + 3 * i + 1) as u32);
let c = WitnessIndex((dummy_constraint_wire_base + 3 * i + 2) as u32);

mul_constraints.push(MulConstraint {
a: Operand::from(a),
b: Operand::from(b),
c: Operand::from(c),
});
}

// Pad to next power of two with `one * one = one` constraints
let one_operand = Operand::from(cs.one_wire);
let current_len = mul_constraints.len();
mul_constraints.resize(
current_len.next_power_of_two(),
MulConstraint {
a: one_operand.clone(),
b: one_operand.clone(),
c: one_operand.clone(),
},
);

Self {
inner: cs,
log_size,
blinding_info,
mul_constraints,
}
}

pub fn constants(&self) -> &[F] {
self.inner.constants()
}

pub fn n_inout(&self) -> u32 {
self.inner.n_inout()
}

pub fn n_private(&self) -> u32 {
self.inner.n_private()
}

pub fn log_public(&self) -> u32 {
self.inner.log_public()
}

pub fn n_public(&self) -> u32 {
self.inner.n_public()
}

pub fn one_wire(&self) -> WitnessIndex {
self.inner.one_wire()
}

pub fn log_size(&self) -> u32 {
self.log_size
}
Expand All @@ -190,6 +302,10 @@ impl<F: Field> ConstraintSystem<F> {
1 << self.log_size as usize
}

pub fn blinding_info(&self) -> &BlindingInfo {
&self.blinding_info
}

pub fn mul_constraints(&self) -> &[MulConstraint<WitnessIndex>] {
&self.mul_constraints
}
Expand Down Expand Up @@ -249,6 +365,18 @@ impl WitnessLayout {
}
}

pub fn with_blinding(self, info: BlindingInfo) -> Self {
let log_public = self.log_public;
let n_private = self.n_private as usize;

let private_offset = 1 << log_public as usize;
let total_size =
private_offset + n_private + info.n_dummy_wires + 3 * info.n_dummy_constraints;
let log_size = log2_ceil_usize(total_size) as u32;

Self { log_size, ..self }
}

pub fn size(&self) -> usize {
1 << self.log_size as usize
}
Expand Down Expand Up @@ -300,6 +428,14 @@ impl WitnessLayout {
}
}

#[derive(Debug, Clone)]
pub struct BlindingInfo {
/// The number of random dummy wires that must be added.
pub n_dummy_wires: usize,
/// The number of random dummy multiplication constraints that must be added.
pub n_dummy_constraints: usize,
}

#[cfg(test)]
mod tests {
use smallvec::smallvec;
Expand Down
1 change: 1 addition & 0 deletions crates/spartan-prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ binius-utils = { path = "../utils" }
binius-prover = { path = "../prover" }
binius-verifier = { path = "../verifier" }
digest.workspace = true
rand.workspace = true
thiserror.workspace = true
tracing.workspace = true

Expand Down
51 changes: 43 additions & 8 deletions crates/spartan-prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use binius_prover::{
merkle_tree::prover::BinaryMerkleTreeProver,
protocols::sumcheck::{prove_single_mlecheck, quadratic_mle::QuadraticMleCheckProver},
};
use binius_spartan_frontend::constraint_system::{MulConstraint, WitnessIndex};
use binius_spartan_frontend::constraint_system::{BlindingInfo, MulConstraint, WitnessIndex};
use binius_spartan_verifier::Verifier;
use binius_transcript::{
ProverTranscript,
Expand All @@ -26,6 +26,7 @@ use binius_transcript::{
use binius_utils::{SerializeBytes, checked_arithmetics::checked_log_2, rayon::prelude::*};
use digest::{Digest, FixedOutputReset, Output, core_api::BlockSizeUser};
pub use error::*;
use rand::CryptoRng;

use crate::wiring::WiringTranspose;

Expand Down Expand Up @@ -65,6 +66,7 @@ where
verifier: Verifier<F, MerkleHash, ParallelMerkleCompress::Compression>,
compression: ParallelMerkleCompress,
) -> Result<Self, Error> {
let cs = verifier.constraint_system();
let subspace = verifier.fri_params().rs_code().subspace();
let domain_context = GenericPreExpanded::generate_from_subspace(subspace);
let log_num_shares = binius_utils::rayon::current_num_threads().ilog2() as usize;
Expand All @@ -73,7 +75,6 @@ where
let merkle_prover = BinaryMerkleTreeProver::<_, ParallelMerkleHasher, _>::new(compression);

// Compute wiring transpose from constraint system
let cs = verifier.constraint_system();
let wiring_transpose = WiringTranspose::transpose(cs.size(), cs.mul_constraints());

Ok(Prover {
Expand All @@ -88,6 +89,7 @@ where
pub fn prove<Challenger_: Challenger>(
&self,
witness: &[F],
mut rng: impl CryptoRng,
transcript: &mut ProverTranscript<Challenger_>,
) -> Result<(), Error> {
let _prove_guard =
Expand All @@ -111,9 +113,16 @@ where
let public = &witness[..1 << cs.log_public()];
transcript.observe().write_slice(public);

// Pack witness into field elements
// TODO: Populate witness directly into a FieldBuffer
let witness_packed = pack_witness::<_, P>(cs.log_size() as usize, witness);
// Pack witness into field elements and add blinding
let blinding_info = cs.blinding_info();
let witness_packed = pack_and_blind_witness::<_, P>(
cs.log_size() as usize,
witness,
blinding_info,
cs.n_public() as usize,
cs.n_private() as usize,
&mut rng,
);

// Commit the witness
let CommitOutput {
Expand Down Expand Up @@ -201,9 +210,13 @@ where
}
}

fn pack_witness<F: Field, P: PackedField<Scalar = F>>(
fn pack_and_blind_witness<F: Field, P: PackedField<Scalar = F>>(
log_witness_elems: usize,
witness: &[F],
blinding_info: &BlindingInfo,
n_public: usize,
n_private: usize,
mut rng: impl CryptoRng,
) -> FieldBuffer<P> {
// Precondition: witness length must match expected size
let expected_size = 1 << log_witness_elems;
Expand Down Expand Up @@ -234,6 +247,28 @@ fn pack_witness<F: Field, P: PackedField<Scalar = F>>(
packed_witness.set_len(len);
};

FieldBuffer::new(log_witness_elems, packed_witness.into_boxed_slice())
.expect("FieldBuffer::new should succeed with correct log_witness_elems")
let mut witness_packed = FieldBuffer::new(log_witness_elems, packed_witness.into_boxed_slice())
.expect("FieldBuffer::new should succeed with correct log_witness_elems");

// Add blinding values
let base = n_public + n_private;

// Set random values for non-constraint dummy wires
for i in 0..blinding_info.n_dummy_wires {
witness_packed.set(base + i, F::random(&mut rng));
}

// Set random values for dummy constraint wires (A * B = C)
let constraint_wire_base = base + blinding_info.n_dummy_wires;
for i in 0..blinding_info.n_dummy_constraints {
let a = F::random(&mut rng);
let b = F::random(&mut rng);
let c = a * b;

witness_packed.set(constraint_wire_base + 3 * i, a);
witness_packed.set(constraint_wire_base + 3 * i + 1, b);
witness_packed.set(constraint_wire_base + 3 * i + 2, c);
}

witness_packed
}
Loading