Skip to content

Commit

Permalink
Poseidon31 Channel and Merkle Hasher (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
weikengchen authored Dec 17, 2024
1 parent 69d72e1 commit d4c18cb
Show file tree
Hide file tree
Showing 11 changed files with 421 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ tracing = "0.1.40"
indexmap = "2.2.6"
sha2 = "0.10.8"
blake3 = "1.5.5"
poseidon2-m31 = { git = "https://github.com/Bitcoin-Wildlife-Sanctuary/poseidon2-m31" }

[profile.bench]
codegen-units = 1
Expand Down
1 change: 1 addition & 0 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = { version = "1.0", features = ["derive"] }
sha2.workspace = true
indexmap.workspace = true
blake3.workspace = true
poseidon2-m31.workspace = true

[dev-dependencies]
aligned = "0.4.2"
Expand Down
15 changes: 15 additions & 0 deletions crates/prover/src/core/backend/simd/grind.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::SimdBackend;
use crate::core::channel::blake3::Blake3Channel;
use crate::core::channel::poseidon31::Poseidon31Channel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::channel::Poseidon252Channel;
use crate::core::channel::{Channel, Sha256Channel};
Expand Down Expand Up @@ -33,6 +34,20 @@ impl GrindOps<Blake3Channel> for SimdBackend {
}
}

impl GrindOps<Poseidon31Channel> for SimdBackend {
fn grind(channel: &Poseidon31Channel, pow_bits: u32) -> u64 {
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
}
}

// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<Poseidon252Channel> for SimdBackend {
Expand Down
3 changes: 3 additions & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};

use super::{Backend, BackendForChannel};
use crate::core::vcs::blake3_merkle::Blake3MerkleChannel;
use crate::core::vcs::poseidon31_merkle::Poseidon31MerkleChannel;
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;

pub mod accumulation;
Expand All @@ -16,6 +17,7 @@ pub mod fri;
mod grind;
pub mod lookups;
pub mod m31;
pub mod poseidon31;
pub mod prefix_sum;
pub mod qm31;
pub mod quotients;
Expand All @@ -29,3 +31,4 @@ pub struct SimdBackend;
impl Backend for SimdBackend {}
impl BackendForChannel<Sha256MerkleChannel> for SimdBackend {}
impl BackendForChannel<Blake3MerkleChannel> for SimdBackend {}
impl BackendForChannel<Poseidon31MerkleChannel> for SimdBackend {}
40 changes: 40 additions & 0 deletions crates/prover/src/core/backend/simd/poseidon31.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon31_hash::Poseidon31Hash;
use crate::core::vcs::poseidon31_merkle::Poseidon31MerkleHasher;

impl ColumnOps<Poseidon31Hash> for SimdBackend {
type Column = Vec<Poseidon31Hash>;

fn bit_reverse_column(_: &mut Self::Column) {
unimplemented!()
}
}

impl MerkleOps<Poseidon31MerkleHasher> for SimdBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Poseidon31Hash>>,
columns: &[&BaseColumn],
) -> Vec<Poseidon31Hash> {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;

#[cfg(feature = "parallel")]
let iter = (0..1 << log_size).into_par_iter();

iter.map(|i| {
Poseidon31MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect()
}
}
3 changes: 3 additions & 0 deletions crates/prover/src/core/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub use sha256::Sha256Channel;
use crate::core::fields::m31::M31;

pub mod blake3;
pub use blake3::Blake3Channel;

pub mod poseidon31;

pub const EXTENSION_FELTS_PER_HASH: usize = 2;

Expand Down
151 changes: 151 additions & 0 deletions crates/prover/src/core/channel/poseidon31.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use poseidon2_m31::Poseidon31Sponge;

use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;

#[derive(Clone, Default)]
pub struct Poseidon31Channel {
sponge: Poseidon31Sponge,
}

impl Poseidon31Channel {
fn draw_base_felts(&mut self) -> [BaseField; 8] {
let u32s = self.sponge.squeeze(8);

[
BaseField::from(u32s[0]),
BaseField::from(u32s[1]),
BaseField::from(u32s[2]),
BaseField::from(u32s[3]),
BaseField::from(u32s[4]),
BaseField::from(u32s[5]),
BaseField::from(u32s[6]),
BaseField::from(u32s[7]),
]
}
}

impl Channel for Poseidon31Channel {
const BYTES_PER_HASH: usize = 32;

fn trailing_zeros(&self) -> u32 {
let res = &self.sponge.state[0..4];

let mut bytes = [0u8; 16];
bytes[0..4].copy_from_slice(&res[0].to_le_bytes());
bytes[4..8].copy_from_slice(&res[1].to_le_bytes());
bytes[8..12].copy_from_slice(&res[2].to_le_bytes());
bytes[12..16].copy_from_slice(&res[3].to_le_bytes());

u128::from_le_bytes(bytes).trailing_zeros()
}

fn mix_felts(&mut self, felts: &[SecureField]) {
let mut inputs = Vec::with_capacity(felts.len() * 4);

for felt in felts.iter() {
inputs.push(felt.0 .0 .0);
inputs.push(felt.0 .1 .0);
inputs.push(felt.1 .0 .0);
inputs.push(felt.1 .1 .0);
}

self.sponge.absorb(&inputs);
}

fn mix_nonce(&mut self, nonce: u64) {
let n1 = nonce % ((1 << 22) - 1); // 22 bytes
let n2 = (nonce >> 22) & ((1 << 21) - 1); // 21 bytes
let n3 = (nonce >> 43) & ((1 << 21) - 1); // 21 bytes

self.sponge.absorb(&[n1 as u32, n2 as u32, n3 as u32, 0]);
}

fn draw_felt(&mut self) -> SecureField {
let felts: [BaseField; 8] = self.draw_base_felts();
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
}

fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
let mut res = vec![];
for _ in 0..n_felts {
res.push(self.draw_felt());
}
res
}

fn draw_random_bytes(&mut self) -> Vec<u8> {
// the implementation here is based on the assumption that the only place draw_random_bytes
// will be used is in generating the queries, where only the lowest n bits of every 4 bytes
// slice would be used.
let elems = self.sponge.squeeze(8);

let mut res = Vec::with_capacity(32);
for elem in elems.iter() {
res.extend_from_slice(&elem.to_le_bytes());
}

res
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use crate::core::channel::poseidon31::Poseidon31Channel;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::m31;

#[test]
fn test_draw_random_bytes() {
let mut channel = Poseidon31Channel::default();

let first_random_bytes = channel.draw_random_bytes();

// Assert that next random bytes are different.
assert_ne!(first_random_bytes, channel.draw_random_bytes());
}

#[test]
pub fn test_draw_felt() {
let mut channel = Poseidon31Channel::default();

let first_random_felt = channel.draw_felt();

// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}

#[test]
pub fn test_draw_felts() {
let mut channel = Poseidon31Channel::default();

let mut random_felts = channel.draw_felts(5);
random_felts.extend(channel.draw_felts(4));

// Assert that all the random felts are unique.
assert_eq!(
random_felts.len(),
random_felts.iter().collect::<BTreeSet<_>>().len()
);
}

#[test]
pub fn test_mix_felts() {
let mut channel = Poseidon31Channel::default();
let initial_digest = channel.sponge.state;
let felts: Vec<SecureField> = (0..2)
.map(|i| SecureField::from(m31!(i + 1923782)))
.collect();

channel.mix_felts(felts.as_slice());
// this works because aftering mixing with 8 elements, the state should be updated

assert!(channel.sponge.buffer.is_empty());
assert_ne!(initial_digest, channel.sponge.state);
}
}
1 change: 1 addition & 0 deletions crates/prover/src/core/vcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod blake3_hash;
pub mod blake3_merkle;

pub mod poseidon31_hash;
pub mod poseidon31_merkle;

#[cfg(test)]
mod test_utils;
30 changes: 30 additions & 0 deletions crates/prover/src/core/vcs/poseidon31_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,33 @@ impl std::fmt::Display for Poseidon31Hash {
}

impl super::hash::Hash for Poseidon31Hash {}

impl Poseidon31Hash {
pub fn as_limbs(&self) -> [u32; 8] {
[
self.0[0].0,
self.0[1].0,
self.0[2].0,
self.0[3].0,
self.0[4].0,
self.0[5].0,
self.0[6].0,
self.0[7].0,
]
}
}

impl From<[u32; 8]> for Poseidon31Hash {
fn from(value: [u32; 8]) -> Self {
Self([
M31::from(value[0]),
M31::from(value[1]),
M31::from(value[2]),
M31::from(value[3]),
M31::from(value[4]),
M31::from(value[5]),
M31::from(value[6]),
M31::from(value[7]),
])
}
}
66 changes: 66 additions & 0 deletions crates/prover/src/core/vcs/poseidon31_merkle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use poseidon2_m31::Poseidon31CRH;
use serde::{Deserialize, Serialize};

use crate::core::channel::poseidon31::Poseidon31Channel;
use crate::core::channel::{Channel, MerkleChannel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::poseidon31_hash::Poseidon31Hash;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)]
pub struct Poseidon31MerkleHasher;
impl MerkleHasher for Poseidon31MerkleHasher {
type Hash = Poseidon31Hash;

fn hash_node(
children_hashes: Option<(Self::Hash, Self::Hash)>,
column_values: &[BaseField],
) -> Self::Hash {
let column_hash = if column_values.is_empty() {
None
} else {
let mut data = Vec::with_capacity(column_values.len());
for column_value in column_values.iter() {
data.push(column_value.0);
}
Some(Poseidon31CRH::hash_fixed_length(&data))
};

match (children_hashes, column_hash) {
(Some(children_hashes), Some(column_hash)) => {
let mut data = [0u32; 24];
data[0..8].copy_from_slice(&children_hashes.0.as_limbs());
data[8..16].copy_from_slice(&column_hash);
data[16..24].copy_from_slice(&children_hashes.1.as_limbs());
Poseidon31CRH::hash_fixed_length(&data).into()
}
(Some(children_hashes), None) => {
let mut data = [0u32; 16];
data[0..8].copy_from_slice(&children_hashes.0.as_limbs());
data[8..16].copy_from_slice(&children_hashes.1.as_limbs());
Poseidon31CRH::hash_fixed_length(&data).into()
}
(None, Some(column_hash)) => {
// omit this hash assuming that we always know a leaf is a leaf
// which is the case in FRI protocols, but not for the general usage
column_hash.into()
}
(None, None) => unreachable!(),
}
}
}

#[derive(Default)]
pub struct Poseidon31MerkleChannel;

impl MerkleChannel for Poseidon31MerkleChannel {
type C = Poseidon31Channel;
type H = Poseidon31MerkleHasher;

fn mix_root(channel: &mut Self::C, root: <Self::H as MerkleHasher>::Hash) {
let r1 = SecureField::from_m31(root.0[0], root.0[1], root.0[2], root.0[3]);
let r2 = SecureField::from_m31(root.0[4], root.0[5], root.0[6], root.0[7]);
channel.mix_felts(&[r1, r2]);
}
}
Loading

0 comments on commit d4c18cb

Please sign in to comment.