Skip to content

Commit

Permalink
Add Blake3 as the default hasher for Plonk (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
weikengchen authored Dec 16, 2024
1 parent f4804bb commit 69d72e1
Show file tree
Hide file tree
Showing 17 changed files with 669 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.1.1"
edition = "2021"

[workspace.dependencies]
blake2 = "0.10.6"
educe = "0.5.0"
hex = "0.4.3"
itertools = "0.12.0"
Expand All @@ -17,6 +16,7 @@ bytemuck = "1.14.3"
tracing = "0.1.40"
indexmap = "2.2.6"
sha2 = "0.10.8"
blake3 = "1.5.5"

[profile.bench]
codegen-units = 1
Expand Down
1 change: 1 addition & 0 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ rayon = { version = "1.10.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
sha2.workspace = true
indexmap.workspace = true
blake3.workspace = true

[dev-dependencies]
aligned = "0.4.2"
Expand Down
24 changes: 24 additions & 0 deletions crates/prover/src/core/backend/cpu/blake3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use itertools::Itertools;

use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake3_hash::Blake3Hash;
use crate::core::vcs::blake3_merkle::Blake3MerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};

impl MerkleOps<Blake3MerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake3Hash>>,
columns: &[&Vec<BaseField>],
) -> Vec<Blake3Hash> {
(0..(1 << log_size))
.map(|i| {
Blake3MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}
3 changes: 3 additions & 0 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod accumulation;
mod blake3;
mod circle;
mod fri;
mod grind;
Expand All @@ -15,6 +16,7 @@ use crate::core::fields::Field;
use crate::core::lookups::mle::Mle;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::utils::bit_reverse;
use crate::core::vcs::blake3_merkle::Blake3MerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;
Expand All @@ -24,6 +26,7 @@ pub struct CpuBackend;

impl Backend for CpuBackend {}
impl BackendForChannel<Sha256MerkleChannel> for CpuBackend {}
impl BackendForChannel<Blake3MerkleChannel> for CpuBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for CpuBackend {}

Expand Down
41 changes: 41 additions & 0 deletions crates/prover/src/core/backend/simd/blake3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::vcs::blake3_hash::Blake3Hash;
use crate::core::vcs::blake3_merkle::Blake3MerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};

impl ColumnOps<Blake3Hash> for SimdBackend {
type Column = Vec<Blake3Hash>;

fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}

// TODO(BWS): not simd at all
impl MerkleOps<Blake3MerkleHasher> for SimdBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake3Hash>>,
columns: &[&BaseColumn],
) -> Vec<Blake3Hash> {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;

#[cfg(feature = "parallel")]
let iter = (0..1 << log_size).into_par_iter();

iter.map(|i| {
Blake3MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect()
}
}
15 changes: 15 additions & 0 deletions crates/prover/src/core/backend/simd/grind.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::SimdBackend;
use crate::core::channel::blake3::Blake3Channel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::channel::Poseidon252Channel;
use crate::core::channel::{Channel, Sha256Channel};
Expand All @@ -18,6 +19,20 @@ impl GrindOps<Sha256Channel> for SimdBackend {
}
}

impl GrindOps<Blake3Channel> for SimdBackend {
fn grind(channel: &Blake3Channel, pow_bits: u32) -> u64 {
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
}
}

// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<Poseidon252Channel> for SimdBackend {
Expand Down
3 changes: 3 additions & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use serde::{Deserialize, Serialize};

use super::{Backend, BackendForChannel};
use crate::core::vcs::blake3_merkle::Blake3MerkleChannel;
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;

pub mod accumulation;
pub mod bit_reverse;
pub mod blake3;
pub mod circle;
pub mod cm31;
pub mod column;
Expand All @@ -26,3 +28,4 @@ pub struct SimdBackend;

impl Backend for SimdBackend {}
impl BackendForChannel<Sha256MerkleChannel> for SimdBackend {}
impl BackendForChannel<Blake3MerkleChannel> for SimdBackend {}
158 changes: 158 additions & 0 deletions crates/prover/src/core/channel/blake3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use crate::core::channel::{extract_common, Channel};
use crate::core::fields::cm31::CM31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::utils::sha256_qm31;
use crate::core::vcs::blake3_hash::{Blake3Hash, Blake3Hasher};

#[derive(Default, Clone)]
/// A channel.
pub struct Blake3Channel {
/// Current state of the channel
pub digest: Blake3Hash,
}

impl Blake3Channel {
pub fn digest(&self) -> Blake3Hash {
self.digest
}

pub fn update_digest(&mut self, digest: Blake3Hash) {
self.digest = digest;
}
}

impl Channel for Blake3Channel {
const BYTES_PER_HASH: usize = 32;

fn mix_felts(&mut self, felts: &[SecureField]) {
for felt in felts.iter() {
let mut hasher = blake3::Hasher::new();
hasher.update(&sha256_qm31(felt));
hasher.update(self.digest.as_ref());
self.update_digest(hasher.finalize().as_bytes().as_ref().into())
}
}

fn mix_nonce(&mut self, nonce: u64) {
// mix_nonce is called during PoW. However, later we plan to replace it by a Bitcoin block
// inclusion proof, then this function would never be called.

let mut hash = [0u8; 32];
hash[..8].copy_from_slice(&nonce.to_le_bytes());

self.digest = Blake3Hasher::concat_and_hash(&Blake3Hash(hash), &self.digest);
}

fn draw_felt(&mut self) -> SecureField {
let mut extract = [0u8; 32];

let mut hasher = blake3::Hasher::new();
hasher.update(self.digest.as_ref());
hasher.update(&[0u8]);
extract.copy_from_slice(hasher.finalize().as_bytes());

let mut hasher = blake3::Hasher::new();
hasher.update(self.digest.as_ref());
self.update_digest(hasher.finalize().as_bytes().as_ref().into());

let res_1 = extract_common(&extract);
let res_2 = extract_common(&extract[4..]);
let res_3 = extract_common(&extract[8..]);
let res_4 = extract_common(&extract[12..]);

QM31(CM31(res_1, res_2), CM31(res_3, res_4))
}

fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
let mut res = vec![];
for _ in 0..n_felts {
res.push(self.draw_felt());
}
res
}

fn draw_random_bytes(&mut self) -> Vec<u8> {
let mut extract = [0u8; 32];

let mut hasher = blake3::Hasher::new();
hasher.update(self.digest.as_ref());
hasher.update(&[0u8]);
extract.copy_from_slice(hasher.finalize().as_bytes());

let mut hasher = blake3::Hasher::new();
hasher.update(self.digest.as_ref());
self.update_digest(hasher.finalize().as_bytes().as_ref().into());

extract.to_vec()
}

fn trailing_zeros(&self) -> u32 {
let mut n_bits = 0;
for byte in self.digest.0.iter().rev() {
if *byte == 0 {
n_bits += 8;
} else {
n_bits += byte.leading_zeros();
break;
}
}
n_bits
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use crate::core::channel::blake3::Blake3Channel;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::m31;

#[test]
fn test_draw_random_bytes() {
let mut channel = Blake3Channel::default();

let first_random_bytes = channel.draw_random_bytes();

// Assert that next random bytes are different.
assert_ne!(first_random_bytes, channel.draw_random_bytes());
}

#[test]
pub fn test_draw_felt() {
let mut channel = Blake3Channel::default();

let first_random_felt = channel.draw_felt();

// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}

#[test]
pub fn test_draw_felts() {
let mut channel = Blake3Channel::default();

let mut random_felts = channel.draw_felts(5);
random_felts.extend(channel.draw_felts(4));

// Assert that all the random felts are unique.
assert_eq!(
random_felts.len(),
random_felts.iter().collect::<BTreeSet<_>>().len()
);
}

#[test]
pub fn test_mix_felts() {
let mut channel = Blake3Channel::default();
let initial_digest = channel.digest;
let felts: Vec<SecureField> = (0..2)
.map(|i| SecureField::from(m31!(i + 1923782)))
.collect();

channel.mix_felts(felts.as_slice());

assert_ne!(initial_digest, channel.digest);
}
}
15 changes: 15 additions & 0 deletions crates/prover/src/core/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ pub use poseidon252::Poseidon252Channel;
pub mod sha256;
pub use sha256::Sha256Channel;

use crate::core::fields::m31::M31;

pub mod blake3;

pub const EXTENSION_FELTS_PER_HASH: usize = 2;

#[derive(Clone, Default)]
Expand Down Expand Up @@ -52,3 +56,14 @@ pub trait MerkleChannel: Default {
type H: MerkleHasher;
fn mix_root(channel: &mut Self::C, root: <Self::H as MerkleHasher>::Hash);
}

pub(crate) fn extract_common(hash: &[u8]) -> M31 {
let mut bytes = [0u8; 4];
bytes.copy_from_slice(&hash[0..4]);

let mut res = u32::from_le_bytes(bytes);
res &= 0x7fffffff;
res %= (1 << 31) - 1;

M31::from(res)
}
24 changes: 5 additions & 19 deletions crates/prover/src/core/channel/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use sha2::{Digest, Sha256};

use crate::core::channel::Channel;
use crate::core::channel::{extract_common, Channel};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::utils::sha256_qm31;
use crate::core::vcs::sha256_hash::{Sha256Hash, Sha256Hasher};
Expand Down Expand Up @@ -62,10 +61,10 @@ impl Channel for Sha256Channel {
Digest::update(&mut hasher, self.digest);
self.digest.0.copy_from_slice(hasher.finalize().as_slice());

let res_1 = Self::extract_common(&extract);
let res_2 = Self::extract_common(&extract[4..]);
let res_3 = Self::extract_common(&extract[8..]);
let res_4 = Self::extract_common(&extract[12..]);
let res_1 = extract_common(&extract);
let res_2 = extract_common(&extract[4..]);
let res_3 = extract_common(&extract[8..]);
let res_4 = extract_common(&extract[12..]);

QM31(CM31(res_1, res_2), CM31(res_3, res_4))
}
Expand Down Expand Up @@ -107,19 +106,6 @@ impl Channel for Sha256Channel {
}
}

impl Sha256Channel {
fn extract_common(hash: &[u8]) -> M31 {
let mut bytes = [0u8; 4];
bytes.copy_from_slice(&hash[0..4]);

let mut res = u32::from_le_bytes(bytes);
res &= 0x7fffffff;
res %= (1 << 31) - 1;

M31::from(res)
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
Expand Down
Loading

0 comments on commit 69d72e1

Please sign in to comment.