Skip to content

Commit

Permalink
clean up ProofStream
Browse files Browse the repository at this point in the history
- remove unused methods
- slightly cleaner API for `enqueue`
- derive `Default` instead of manually implementing it
- add some tests for expected failures
  • Loading branch information
jan-ferdinand committed Sep 29, 2023
2 parents 3bb462c + 23408ed commit 3f0c320
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 61 deletions.
6 changes: 3 additions & 3 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
fn commit_to_round(&mut self, round: &ProverRound<H>) {
let merkle_root = round.merkle_tree.get_root();
let proof_item = ProofItem::MerkleRoot(merkle_root);
self.proof_stream.enqueue(&proof_item);
self.proof_stream.enqueue(proof_item);
}

fn store_round(&mut self, round: ProverRound<H>) {
Expand All @@ -102,7 +102,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
fn send_last_codeword(&mut self) {
let last_codeword = self.rounds.last().unwrap().codeword.clone();
let proof_item = ProofItem::FriCodeword(last_codeword);
self.proof_stream.enqueue(&proof_item);
self.proof_stream.enqueue(proof_item);
}

fn query(&mut self) {
Expand Down Expand Up @@ -155,7 +155,7 @@ impl<'stream, H: AlgebraicHasher> FriProver<'stream, H> {
revealed_leaves,
};
let proof_item = ProofItem::FriResponse(fri_response);
self.proof_stream.enqueue(&proof_item)
self.proof_stream.enqueue(proof_item)
}
}

Expand Down
8 changes: 4 additions & 4 deletions triton-vm/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub mod test_claim_proof {
#[test]
fn proof_with_no_log_2_padded_height_gives_err() {
let mut proof_stream = ProofStream::<StarkHasher>::new();
proof_stream.enqueue(&ProofItem::MerkleRoot(random()));
proof_stream.enqueue(ProofItem::MerkleRoot(random()));
let proof: Proof = proof_stream.into();
let maybe_padded_height = proof.padded_height();
assert!(maybe_padded_height.is_err());
Expand All @@ -114,9 +114,9 @@ pub mod test_claim_proof {
#[test]
fn proof_with_multiple_log_2_padded_height_gives_err() {
let mut proof_stream = ProofStream::<StarkHasher>::new();
proof_stream.enqueue(&ProofItem::Log2PaddedHeight(8));
proof_stream.enqueue(&ProofItem::MerkleRoot(random()));
proof_stream.enqueue(&ProofItem::Log2PaddedHeight(7));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(8));
proof_stream.enqueue(ProofItem::MerkleRoot(random()));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(7));
let proof: Proof = proof_stream.into();
let maybe_padded_height = proof.padded_height();
assert!(maybe_padded_height.is_err());
Expand Down
4 changes: 2 additions & 2 deletions triton-vm/src/proof_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ mod proof_item_typed_tests {

// test encoding and decoding in a stream
let mut proof_stream = ProofStream::<H>::new();
proof_stream.enqueue(&ProofItem::FriResponse(fri_response.clone()));
proof_stream.enqueue(ProofItem::FriResponse(fri_response.clone()));
let proof: Proof = proof_stream.into();
let mut proof_stream = ProofStream::<H>::try_from(&proof).unwrap();
let fri_response_ = proof_stream.dequeue().unwrap();
Expand Down Expand Up @@ -331,7 +331,7 @@ mod proof_item_typed_tests {

// test encoding and decoding in a stream
let mut proof_stream = ProofStream::<H>::new();
proof_stream.enqueue(&ProofItem::AuthenticationStructure(auth_structure.clone()));
proof_stream.enqueue(ProofItem::AuthenticationStructure(auth_structure.clone()));
let proof: Proof = proof_stream.into();
let mut proof_stream = ProofStream::<H>::try_from(&proof).unwrap();
let auth_structure_ = proof_stream.dequeue().unwrap();
Expand Down
78 changes: 41 additions & 37 deletions triton-vm/src/proof_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use twenty_first::util_types::algebraic_hasher::AlgebraicHasher;
use crate::proof::Proof;
use crate::proof_item::ProofItem;

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ProofStream<H>
where
H: AlgebraicHasher,
Expand All @@ -33,15 +33,6 @@ where
}
}

pub fn is_empty(&self) -> bool {
self.items.is_empty()
}

/// The number of items in the proof stream.
pub fn len(&self) -> usize {
self.items.len()
}

/// The number of field elements required to encode the proof.
pub fn transcript_length(&self) -> usize {
let Proof(b_field_elements) = self.into();
Expand Down Expand Up @@ -86,11 +77,11 @@ where
/// in question was hashed previously.
/// - If the proof stream is not used to sample any more randomness, _i.e._, after the last
/// round of interaction, no further items need to be hashed.
pub fn enqueue(&mut self, item: &ProofItem) {
pub fn enqueue(&mut self, item: ProofItem) {
if item.include_in_fiat_shamir_heuristic() {
self.alter_fiat_shamir_state_with(item);
self.alter_fiat_shamir_state_with(&item);
}
self.items.push(item.clone());
self.items.push(item);
}

/// Receive a proof item from prover as verifier.
Expand Down Expand Up @@ -127,25 +118,15 @@ where
}
}

impl<H> Default for ProofStream<H>
where
H: AlgebraicHasher,
{
fn default() -> Self {
Self::new()
}
}

impl<H> BFieldCodec for ProofStream<H>
where
H: AlgebraicHasher,
{
fn decode(sequence: &[BFieldElement]) -> Result<Box<Self>> {
let items = *Vec::<ProofItem>::decode(sequence)?;
let proof_stream = ProofStream {
let items = *Vec::decode(sequence)?;
let proof_stream = Self {
items,
items_index: 0,
sponge_state: H::init(),
..Self::new()
};
Ok(Box::new(proof_stream))
}
Expand Down Expand Up @@ -190,7 +171,7 @@ where
}

#[cfg(test)]
mod proof_stream_typed_tests {
mod tests {
use itertools::Itertools;
use rand::distributions::Standard;
use rand::prelude::Distribution;
Expand Down Expand Up @@ -269,23 +250,23 @@ mod proof_stream_typed_tests {
let mut proof_stream = ProofStream::<H>::new();

sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::AuthenticationStructure(auth_structure.clone()));
proof_stream.enqueue(ProofItem::AuthenticationStructure(auth_structure.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::MasterBaseTableRows(base_rows.clone()));
proof_stream.enqueue(ProofItem::MasterBaseTableRows(base_rows.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::MasterExtTableRows(ext_rows.clone()));
proof_stream.enqueue(ProofItem::MasterExtTableRows(ext_rows.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::OutOfDomainBaseRow(ood_base_row.clone()));
proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(ood_base_row.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::OutOfDomainExtRow(ood_ext_row.clone()));
proof_stream.enqueue(ProofItem::OutOfDomainExtRow(ood_ext_row.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::MerkleRoot(root));
proof_stream.enqueue(ProofItem::MerkleRoot(root));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::QuotientSegmentsElements(quot_elements.clone()));
proof_stream.enqueue(ProofItem::QuotientSegmentsElements(quot_elements.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::FriCodeword(fri_codeword.clone()));
proof_stream.enqueue(ProofItem::FriCodeword(fri_codeword.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);
proof_stream.enqueue(&ProofItem::FriResponse(fri_response.clone()));
proof_stream.enqueue(ProofItem::FriResponse(fri_response.clone()));
sponge_states.push_back(proof_stream.sponge_state.state);

let proof = proof_stream.into();
Expand Down Expand Up @@ -405,7 +386,7 @@ mod proof_stream_typed_tests {
};

let mut proof_stream = ProofStream::<H>::new();
proof_stream.enqueue(&ProofItem::FriResponse(fri_response));
proof_stream.enqueue(ProofItem::FriResponse(fri_response));

// TODO: Also check that deserializing from Proof works here.

Expand All @@ -424,4 +405,27 @@ mod proof_stream_typed_tests {
);
assert!(verdict);
}

#[test]
#[should_panic(expected = "Queue must be non-empty")]
fn dequeuing_from_empty_stream_fails() {
let mut proof_stream = ProofStream::<Tip5>::new();
proof_stream.dequeue().unwrap();
}

#[test]
#[should_panic(expected = "Queue must be non-empty")]
fn dequeuing_more_items_than_have_been_enqueued_fails() {
let mut proof_stream = ProofStream::<Tip5>::new();
proof_stream.enqueue(ProofItem::FriCodeword(vec![]));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(7));
proof_stream.dequeue().unwrap();
proof_stream.dequeue().unwrap();
proof_stream.dequeue().unwrap();
}

#[test]
fn encoded_length_of_prove_stream_is_not_known_at_compile_time() {
assert!(ProofStream::<Tip5>::static_length().is_none());
}
}
30 changes: 15 additions & 15 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl Stark {
let max_degree = Self::derive_max_degree(padded_height, parameters.num_trace_randomizers);
let fri = Self::derive_fri(parameters, padded_height);
let quotient_domain = Self::quotient_domain(fri.domain, max_degree);
proof_stream.enqueue(&ProofItem::Log2PaddedHeight(padded_height.ilog2()));
proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2()));
prof_stop!(maybe_profiler, "derive additional parameters");

prof_start!(maybe_profiler, "base tables");
Expand Down Expand Up @@ -174,7 +174,7 @@ impl Stark {
prof_stop!(maybe_profiler, "Merkle tree");

prof_start!(maybe_profiler, "Fiat-Shamir", "hash");
proof_stream.enqueue(&ProofItem::MerkleRoot(base_merkle_tree.get_root()));
proof_stream.enqueue(ProofItem::MerkleRoot(base_merkle_tree.get_root()));
let challenges = proof_stream.sample_scalars(Challenges::num_challenges_to_sample());
let challenges = Challenges::new(challenges, claim);
prof_stop!(maybe_profiler, "Fiat-Shamir");
Expand All @@ -199,7 +199,7 @@ impl Stark {
prof_stop!(maybe_profiler, "Merkle tree");

prof_start!(maybe_profiler, "Fiat-Shamir", "hash");
proof_stream.enqueue(&ProofItem::MerkleRoot(ext_merkle_tree.get_root()));
proof_stream.enqueue(ProofItem::MerkleRoot(ext_merkle_tree.get_root()));
prof_stop!(maybe_profiler, "Fiat-Shamir");
prof_stop!(maybe_profiler, "ext tables");

Expand Down Expand Up @@ -270,7 +270,7 @@ impl Stark {
let quot_merkle_tree: MerkleTree<StarkHasher> =
MTMaker::from_digests(&fri_domain_quotient_segment_codewords_digests);
let quot_merkle_tree_root = quot_merkle_tree.get_root();
proof_stream.enqueue(&ProofItem::MerkleRoot(quot_merkle_tree_root));
proof_stream.enqueue(ProofItem::MerkleRoot(quot_merkle_tree_root));
prof_stop!(maybe_profiler, "Merkle tree");
prof_stop!(maybe_profiler, "commit to quotient codeword segments");
debug_assert_eq!(fri.domain.length, quot_merkle_tree.get_leaf_count());
Expand All @@ -280,16 +280,16 @@ impl Stark {
let out_of_domain_point_curr_row = proof_stream.sample_scalars(1)[0];
let out_of_domain_point_next_row = trace_domain_generator * out_of_domain_point_curr_row;

proof_stream.enqueue(&ProofItem::OutOfDomainBaseRow(
proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(
master_base_table.row(out_of_domain_point_curr_row).to_vec(),
));
proof_stream.enqueue(&ProofItem::OutOfDomainExtRow(
proof_stream.enqueue(ProofItem::OutOfDomainExtRow(
master_ext_table.row(out_of_domain_point_curr_row).to_vec(),
));
proof_stream.enqueue(&ProofItem::OutOfDomainBaseRow(
proof_stream.enqueue(ProofItem::OutOfDomainBaseRow(
master_base_table.row(out_of_domain_point_next_row).to_vec(),
));
proof_stream.enqueue(&ProofItem::OutOfDomainExtRow(
proof_stream.enqueue(ProofItem::OutOfDomainExtRow(
master_ext_table.row(out_of_domain_point_next_row).to_vec(),
));

Expand All @@ -300,7 +300,7 @@ impl Stark {
.to_vec()
.try_into()
.unwrap();
proof_stream.enqueue(&ProofItem::OutOfDomainQuotientSegments(
proof_stream.enqueue(ProofItem::OutOfDomainQuotientSegments(
out_of_domain_curr_row_quot_segments,
));
prof_stop!(maybe_profiler, "out-of-domain rows");
Expand Down Expand Up @@ -455,8 +455,8 @@ impl Stark {
);
let base_authentication_structure =
base_merkle_tree.get_authentication_structure(&revealed_current_row_indices);
proof_stream.enqueue(&ProofItem::MasterBaseTableRows(revealed_base_elems));
proof_stream.enqueue(&ProofItem::AuthenticationStructure(
proof_stream.enqueue(ProofItem::MasterBaseTableRows(revealed_base_elems));
proof_stream.enqueue(ProofItem::AuthenticationStructure(
base_authentication_structure,
));

Expand All @@ -466,8 +466,8 @@ impl Stark {
);
let ext_authentication_structure =
ext_merkle_tree.get_authentication_structure(&revealed_current_row_indices);
proof_stream.enqueue(&ProofItem::MasterExtTableRows(revealed_ext_elems));
proof_stream.enqueue(&ProofItem::AuthenticationStructure(
proof_stream.enqueue(ProofItem::MasterExtTableRows(revealed_ext_elems));
proof_stream.enqueue(ProofItem::AuthenticationStructure(
ext_authentication_structure,
));

Expand All @@ -481,10 +481,10 @@ impl Stark {
.collect_vec();
let revealed_quotient_authentication_structure =
quot_merkle_tree.get_authentication_structure(&revealed_current_row_indices);
proof_stream.enqueue(&ProofItem::QuotientSegmentsElements(
proof_stream.enqueue(ProofItem::QuotientSegmentsElements(
revealed_quotient_segments_rows,
));
proof_stream.enqueue(&ProofItem::AuthenticationStructure(
proof_stream.enqueue(ProofItem::AuthenticationStructure(
revealed_quotient_authentication_structure,
));
prof_stop!(maybe_profiler, "open trace leafs");
Expand Down

0 comments on commit 3f0c320

Please sign in to comment.