diff --git a/Cargo.toml b/Cargo.toml index 3f144652..70aa374b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ resolver = "3" [workspace.lints] rust.missing_debug_implementations = "warn" -rust.unreachable_pub = "warn" rust.unused_must_use = "deny" rust.rust_2018_idioms = { level = "deny", priority = -1 } rust.dead_code = "allow" @@ -35,7 +34,6 @@ rustdoc.all = "warn" all = { level = "warn", priority = -1 } nursery = { level = "warn", priority = -1 } doc_markdown = "allow" -pedantic = { level = "warn", priority = -1 } cast_possible_truncation = "allow" cast_precision_loss = "allow" missing_errors_doc = "allow" @@ -45,6 +43,20 @@ should_panic_without_expect = "allow" similar_names = "allow" suboptimal_flops = "allow" cast_sign_loss = "allow" +redundant_pub_crate = "allow" +too_many_lines = "allow" +to_string_trait_impl = "allow" +transmute_ptr_to_ptr = "allow" +missing_transmute_annotations = "allow" +type_complexity = "allow" +needless_range_loop = "allow" +too_many_arguments = "allow" +result_large_err = "allow" +format_push_string = "allow" +option_if_let_else = "allow" +mismatching_type_param_order = "allow" +cognitive_complexity = "allow" +large_enum_variant = "allow" [workspace.dependencies] lean-isa = { path = "crates/leanIsa" } diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index fce8ba41..868cd247 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -76,21 +76,21 @@ pub fn prove_many_air_3< "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); } - for i in 0..tables_2.len() { + for witness in witnesses_2.iter().take(tables_2.len()) { assert!( - univariate_skips < witnesses_2[i].log_n_rows(), + univariate_skips < witness.log_n_rows(), "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); } - for i in 0..tables_3.len() { + for witness in witnesses_3.iter().take(tables_3.len()) { assert!( - univariate_skips < witnesses_3[i].log_n_rows(), + univariate_skips < witness.log_n_rows(), "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); } - let structured_air = if tables_1.len() > 0 { + let structured_air = if !tables_1.is_empty() { tables_1[0].air.structured() - } else if tables_2.len() > 0 { + } else if !tables_2.is_empty() { tables_2[0].air.structured() } else { tables_3[0].air.structured() @@ -113,15 +113,15 @@ pub fn prove_many_air_3< let log_lengths_1 = witnesses_1 .iter() - .map(|w| w.log_n_rows()) + .map(super::witness::AirWitness::log_n_rows) .collect::>(); let log_lengths_2 = witnesses_2 .iter() - .map(|w| w.log_n_rows()) + .map(super::witness::AirWitness::log_n_rows) .collect::>(); let log_lengths_3 = witnesses_3 .iter() - .map(|w| w.log_n_rows()) + .map(super::witness::AirWitness::log_n_rows) .collect::>(); let max_n_constraints = Iterator::max( @@ -186,10 +186,10 @@ pub fn prove_many_air_3< sumcheck::prove_in_parallel_3::( vec![univariate_skips; n_tables], columns_for_zero_check_packed, - tables_1.iter().map(|t| &t.air).collect::>(), - tables_2.iter().map(|t| &t.air).collect::>(), - tables_3.iter().map(|t| &t.air).collect::>(), - vec![&constraints_batching_scalars; n_tables], + &tables_1.iter().map(|t| &t.air).collect::>(), + &tables_2.iter().map(|t| &t.air).collect::>(), + &tables_3.iter().map(|t| &t.air).collect::>(), + &vec![constraints_batching_scalars.as_slice(); n_tables], all_zerocheck_challenges, vec![true; n_tables], prover_state, @@ -237,11 +237,11 @@ pub fn prove_many_air_3< impl>, A: MyAir> AirTable { #[instrument(name = "air: prove base", skip_all)] - pub fn prove_base<'a>( + pub fn prove_base( &self, prover_state: &mut FSProver>, univariate_skips: usize, - witness: AirWitness<'a, PF>, + witness: AirWitness<'_, PF>, ) -> Vec> { let mut res = prove_many_air_3::( prover_state, @@ -258,11 +258,11 @@ impl>, A: MyAir> AirTable { } #[instrument(name = "air: prove base", skip_all)] - pub fn prove_extension<'a>( + pub fn prove_extension( &self, prover_state: &mut FSProver>, univariate_skips: usize, - witness: AirWitness<'a, EF>, + witness: AirWitness<'_, EF>, ) -> Vec> { let mut res = prove_many_air_3::( prover_state, @@ -290,8 +290,8 @@ fn eval_unstructured_column_groups> + ExtensionField> + ExtensionField>>( witnesses_1 .iter() .chain(witnesses_2) - .map(|w| w.max_columns_per_group()) - .chain(witnesses_3.iter().map(|w| w.max_columns_per_group())), + .map(super::witness::AirWitness::max_columns_per_group) + .chain( + witnesses_3 + .iter() + .map(super::witness::AirWitness::max_columns_per_group), + ), ) .unwrap(); let columns_batching_scalars = prover_state.sample_vec(log2_ceil_usize(max_columns_per_group)); @@ -374,7 +378,7 @@ fn open_unstructured_columns<'a, EF: ExtensionField>>( [ from_end(&columns_batching_scalars, log2_ceil_usize(group.len())).to_vec(), epsilons.0.clone(), - outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(), + outer_sumcheck_challenge[1..=(log_n_rows - univariate_skips)].to_vec(), ] .concat(), ), @@ -387,10 +391,10 @@ fn open_unstructured_columns<'a, EF: ExtensionField>>( } #[instrument(skip_all)] -fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, IF: Field>( +fn open_structured_columns> + ExtensionField, IF: Field>( prover_state: &mut FSProver>, univariate_skips: usize, - witness: &AirWitness<'a, IF>, + witness: &AirWitness<'_, IF>, outer_sumcheck_challenge: &[EF], ) -> Vec> { let columns_batching_scalars = prover_state.sample_vec(witness.log_max_columns_per_group()); @@ -403,7 +407,7 @@ fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, for group in &witness.column_groups { let batched_column = multilinears_linear_combination( &witness.cols[group.clone()], - &eval_eq(&from_end( + &eval_eq(from_end( &columns_batching_scalars, log2_ceil_usize(group.len()), ))[..group.len()], @@ -419,7 +423,7 @@ fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, let sub_evals = fold_multilinear( &batched_column_mixed, &MultilinearPoint( - outer_sumcheck_challenge[1..witness.log_n_rows() - univariate_skips + 1].to_vec(), + outer_sumcheck_challenge[1..=(witness.log_n_rows() - univariate_skips)].to_vec(), ), ); @@ -434,7 +438,7 @@ fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, { let point = [ epsilons.clone(), - outer_sumcheck_challenge[1..witness.log_n_rows() - univariate_skips + 1].to_vec(), + outer_sumcheck_challenge[1..=(witness.log_n_rows() - univariate_skips)].to_vec(), ] .concat(); let mles_for_inner_sumcheck = vec![ @@ -456,8 +460,8 @@ fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, let (inner_challenges, all_inner_evals, _) = sumcheck::prove_in_parallel_1::( vec![1; n_groups], all_inner_mles, - vec![&ProductComputation; n_groups], - vec![&[]; n_groups], + &vec![&ProductComputation; n_groups], + &vec![[].as_slice(); n_groups], vec![None; n_groups], vec![false; n_groups], prover_state, diff --git a/crates/air/src/table.rs b/crates/air/src/table.rs index 6c620001..2cdc34f4 100644 --- a/crates/air/src/table.rs +++ b/crates/air/src/table.rs @@ -19,8 +19,12 @@ impl>, A: MyAir> AirTable { pub fn new(air: A) -> Self { let symbolic_constraints = get_symbolic_constraints(&air, 0, 0); let n_constraints = symbolic_constraints.len(); - let constraint_degree = - Iterator::max(symbolic_constraints.iter().map(|c| c.degree_multiple())).unwrap(); + let constraint_degree = Iterator::max( + symbolic_constraints + .iter() + .map(p3_uni_stark::SymbolicExpression::degree_multiple), + ) + .unwrap(); assert_eq!(constraint_degree, air.degree()); Self { air, @@ -42,17 +46,17 @@ impl>, A: MyAir> AirTable { EF: ExtensionField, { if witness.n_columns() != self.n_columns() { - return Err(format!("Invalid number of columns",)); + return Err("Invalid number of columns".to_string()); } let handle_errors = |row: usize, constraint_checker: &mut ConstraintChecker<'_, IF, EF>| { - if constraint_checker.errors.len() > 0 { + if !constraint_checker.errors.is_empty() { return Err(format!( "Trace is not valid at row {}: contraints not respected: {}", row, constraint_checker .errors .iter() - .map(|e| e.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", ") )); diff --git a/crates/air/src/test.rs b/crates/air/src/test.rs index 4ba39feb..b3e9f091 100644 --- a/crates/air/src/test.rs +++ b/crates/air/src/test.rs @@ -1,8 +1,7 @@ use std::borrow::Borrow; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeCharacteristicRing; -use p3_field::extension::BinomialExtensionField; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_koala_bear::KoalaBear; use p3_matrix::Matrix; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -107,9 +106,13 @@ fn generate_structured_trace(outer_challenges: &[F]) -> Vec { +pub fn matrix_up_folded(outer_challenges: &[F]) -> Vec { let n = outer_challenges.len(); let mut folded = eval_eq(outer_challenges); let outer_challenges_prod: F = outer_challenges.iter().copied().product(); @@ -13,7 +13,7 @@ pub(crate) fn matrix_up_folded(outer_challenges: &[F]) -> Vec { } #[instrument(name = "matrix_down_folded", skip_all)] -pub(crate) fn matrix_down_folded(outer_challenges: &[F]) -> Vec { +pub fn matrix_down_folded(outer_challenges: &[F]) -> Vec { let n = outer_challenges.len(); let mut folded = vec![F::ZERO; 1 << n]; for k in 0..n { diff --git a/crates/air/src/utils.rs b/crates/air/src/utils.rs index c78e92c2..5ed9df37 100644 --- a/crates/air/src/utils.rs +++ b/crates/air/src/utils.rs @@ -2,7 +2,7 @@ use p3_field::Field; use rayon::prelude::*; use whir_p3::poly::multilinear::MultilinearPoint; -pub(crate) fn matrix_up_lde(point: &[F]) -> F { +pub fn matrix_up_lde(point: &[F]) -> F { /* Matrix UP: @@ -29,7 +29,7 @@ pub(crate) fn matrix_up_lde(point: &[F]) -> F { * (F::ONE - point[point.len() - 1] * F::TWO) } -pub(crate) fn matrix_down_lde(point: &[F]) -> F { +pub fn matrix_down_lde(point: &[F]) -> F { /* Matrix DOWN: @@ -130,7 +130,7 @@ fn next_mle(point: &[F]) -> F { .sum() } -pub(crate) fn columns_up_and_down(columns: &[&[F]]) -> Vec> { +pub fn columns_up_and_down(columns: &[&[F]]) -> Vec> { columns .par_iter() .map(|c| column_up(c)) @@ -138,13 +138,13 @@ pub(crate) fn columns_up_and_down(columns: &[&[F]]) -> Vec> { .collect() } -pub(crate) fn column_up(column: &[F]) -> Vec { +pub fn column_up(column: &[F]) -> Vec { let mut up = column.to_vec(); up[column.len() - 1] = up[column.len() - 2]; up } -pub(crate) fn column_down(column: &[F]) -> Vec { +pub fn column_down(column: &[F]) -> Vec { let mut down = column[1..].to_vec(); down.push(*down.last().unwrap()); down diff --git a/crates/air/src/verify.rs b/crates/air/src/verify.rs index 2b0b53f8..76454ee5 100644 --- a/crates/air/src/verify.rs +++ b/crates/air/src/verify.rs @@ -1,25 +1,26 @@ +use std::ops::Range; + use p3_field::{ExtensionField, cyclic_subgroup_known_order, dot_product}; use p3_util::log2_ceil_usize; -use std::ops::Range; use sumcheck::SumcheckComputation; use tracing::instrument; -use utils::univariate_selectors; -use utils::{Evaluation, from_end}; -use utils::{FSVerifier, PF}; -use whir_p3::fiat_shamir::FSChallenger; -use whir_p3::poly::evals::eval_eq; +use utils::{Evaluation, FSVerifier, PF, from_end, univariate_selectors}; use whir_p3::{ - fiat_shamir::errors::ProofError, - poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, + fiat_shamir::{FSChallenger, errors::ProofError}, + poly::{ + evals::{EvaluationsList, eval_eq}, + multilinear::MultilinearPoint, + }, }; -use crate::MyAir; -use crate::utils::{matrix_down_lde, matrix_up_lde}; - use super::table::AirTable; +use crate::{ + MyAir, + utils::{matrix_down_lde, matrix_up_lde}, +}; #[instrument(name = "air table: verify many", skip_all)] -pub fn verify_many_air_2<'a, EF: ExtensionField>, A1: MyAir, A2: MyAir>( +pub fn verify_many_air_2>, A1: MyAir, A2: MyAir>( verifier_state: &mut FSVerifier>, tables_1: &[&AirTable], tables_2: &[&AirTable], @@ -40,7 +41,6 @@ pub fn verify_many_air_2<'a, EF: ExtensionField>, A1: MyAir, A2: MyAi #[instrument(name = "air table: verify many", skip_all)] pub fn verify_many_air_3< - 'a, EF: ExtensionField>, A1: MyAir, A2: MyAir, @@ -66,7 +66,7 @@ pub fn verify_many_air_3< sumcheck::verify_with_univariate_skip_in_parallel::( verifier_state, univariate_skips, - &log_lengths, + log_lengths, &tables_1 .iter() .map(|t| t.air.degree() + 1) @@ -155,7 +155,7 @@ pub fn verify_many_air_3< global_zerocheck_challenges[1..log_lengths[i] + 1 - univariate_skips].to_vec(), ) .eq_poly_outside(&MultilinearPoint( - outer_sumcheck_point[1..log_lengths[i] - univariate_skips + 1].to_vec(), + outer_sumcheck_point[1..=(log_lengths[i] - univariate_skips)].to_vec(), )) * global_constraint_evals[i] != outer_sumcheck_values[i] { @@ -199,7 +199,7 @@ pub fn verify_many_air_3< &column_groups[i], &Evaluation { point: MultilinearPoint( - outer_sumcheck_point[1..log_lengths[i] - univariate_skips + 1].to_vec(), + outer_sumcheck_point[1..=(log_lengths[i] - univariate_skips)].to_vec(), ), value: outer_sumcheck_values[i], }, @@ -212,11 +212,11 @@ pub fn verify_many_air_3< verify_many_unstructured_columns( verifier_state, univariate_skips, - all_inner_sums, - &column_groups, + &all_inner_sums, + column_groups, &outer_sumcheck_point, &outer_selector_evals, - &log_lengths, + log_lengths, ) } } @@ -247,7 +247,7 @@ impl>, A: MyAir> AirTable { fn verify_many_unstructured_columns>>( verifier_state: &mut FSVerifier>, univariate_skips: usize, - all_inner_sums: Vec>, + all_inner_sums: &[Vec], column_groups: &[Vec>], outer_sumcheck_point: &MultilinearPoint, outer_selector_evals: &[EF], @@ -256,7 +256,7 @@ fn verify_many_unstructured_columns>>( let max_columns_per_group = Iterator::max( column_groups .iter() - .map(|g| Iterator::max(g.iter().map(|r| r.len())).unwrap()), + .map(|g| Iterator::max(g.iter().map(std::iter::ExactSizeIterator::len)).unwrap()), ) .unwrap(); let log_max_columns_per_group = log2_ceil_usize(max_columns_per_group); @@ -273,7 +273,7 @@ fn verify_many_unstructured_columns>>( outer_selector_evals.iter().copied(), ) != dot_product::( all_inner_sums[i][group.clone()].iter().copied(), - eval_eq(&from_end( + eval_eq(from_end( &columns_batching_scalars, log2_ceil_usize(group.len()), ))[..group.len()] @@ -299,7 +299,7 @@ fn verify_many_unstructured_columns>>( [ from_end(&columns_batching_scalars, log2_ceil_usize(group.len())).to_vec(), epsilons.0.clone(), - outer_sumcheck_point[1..log_lengths[i] - univariate_skips + 1].to_vec(), + outer_sumcheck_point[1..=(log_lengths[i] - univariate_skips)].to_vec(), ] .concat(), ); @@ -314,6 +314,7 @@ fn verify_many_unstructured_columns>>( Ok(all_evaluations_remaining_to_verify) } +#[allow(clippy::too_many_arguments)] fn verify_structured_columns>>( verifier_state: &mut FSVerifier>, n_columns: usize, @@ -324,7 +325,8 @@ fn verify_structured_columns>>( outer_selector_evals: &[EF], log_n_rows: usize, ) -> Result>, ProofError> { - let max_columns_per_group = Iterator::max(column_groups.iter().map(|g| g.len())).unwrap(); + let max_columns_per_group = + Iterator::max(column_groups.iter().map(std::iter::ExactSizeIterator::len)).unwrap(); let log_max_columns_per_group = log2_ceil_usize(max_columns_per_group); let columns_batching_scalars = verifier_state.sample_vec(log_max_columns_per_group); @@ -346,7 +348,7 @@ fn verify_structured_columns>>( outer_selector_evals.iter().copied(), ) != dot_product::( witness_up.iter().copied(), - eval_eq(&from_end( + eval_eq(from_end( &columns_batching_scalars, log2_ceil_usize(group.len()), ))[..group.len()] @@ -354,7 +356,7 @@ fn verify_structured_columns>>( .copied(), ) + dot_product::( witness_down.iter().copied(), - eval_eq(&from_end( + eval_eq(from_end( &columns_batching_scalars, log2_ceil_usize(group.len()), ))[..group.len()] diff --git a/crates/air/src/witness.rs b/crates/air/src/witness.rs index 720fe921..649959e8 100644 --- a/crates/air/src/witness.rs +++ b/crates/air/src/witness.rs @@ -22,7 +22,10 @@ impl<'a, F> Deref for AirWitness<'a, F> { impl<'a, F> AirWitness<'a, F> { pub fn new(cols: &'a [impl Borrow<[F]>], column_groups: &[Range]) -> Self { - let cols = cols.iter().map(|col| col.borrow()).collect::>(); + let cols = cols + .iter() + .map(std::borrow::Borrow::borrow) + .collect::>(); assert!( cols.iter() .all(|col| col.len() == (1 << log2_strict_usize(cols[0].len()))), @@ -37,22 +40,31 @@ impl<'a, F> AirWitness<'a, F> { } } - pub fn n_columns(&self) -> usize { + #[must_use] + pub const fn n_columns(&self) -> usize { self.cols.len() } + #[must_use] pub fn n_rows(&self) -> usize { self.cols[0].len() } + #[must_use] pub fn log_n_rows(&self) -> usize { log2_strict_usize(self.n_rows()) } + #[must_use] pub fn max_columns_per_group(&self) -> usize { - self.column_groups.iter().map(|g| g.len()).max().unwrap() + self.column_groups + .iter() + .map(std::iter::ExactSizeIterator::len) + .max() + .unwrap() } + #[must_use] pub fn log_max_columns_per_group(&self) -> usize { log2_ceil_usize(self.max_columns_per_group()) } diff --git a/crates/compiler/src/a_simplify_lang.rs b/crates/compiler/src/a_simplify_lang.rs index 9d0a3314..4f119972 100644 --- a/crates/compiler/src/a_simplify_lang.rs +++ b/crates/compiler/src/a_simplify_lang.rs @@ -26,7 +26,7 @@ pub(crate) struct SimpleFunction { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum VarOrConstMallocAccess { +pub enum VarOrConstMallocAccess { Var(Var), ConstMallocAccess { malloc_label: ConstMallocLabel, @@ -37,11 +37,11 @@ pub(crate) enum VarOrConstMallocAccess { impl From for SimpleExpr { fn from(var_or_const: VarOrConstMallocAccess) -> Self { match var_or_const { - VarOrConstMallocAccess::Var(var) => SimpleExpr::Var(var), + VarOrConstMallocAccess::Var(var) => Self::Var(var), VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset, - } => SimpleExpr::ConstMallocAccess { + } => Self::ConstMallocAccess { malloc_label, offset, }, @@ -54,8 +54,8 @@ impl TryInto for SimpleExpr { fn try_into(self) -> Result { match self { - SimpleExpr::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), - SimpleExpr::ConstMallocAccess { + Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), + Self::ConstMallocAccess { malloc_label, offset, } => Ok(VarOrConstMallocAccess::ConstMallocAccess { @@ -74,7 +74,7 @@ impl From for VarOrConstMallocAccess { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum SimpleLine { +pub enum SimpleLine { Assignment { var: VarOrConstMallocAccess, operation: HighLevelOperation, @@ -128,7 +128,7 @@ pub(crate) enum SimpleLine { }, } -pub(crate) fn simplify_program(mut program: Program) -> SimpleProgram { +pub fn simplify_program(mut program: Program) -> SimpleProgram { handle_const_arguments(&mut program); let mut new_functions = BTreeMap::new(); let mut counters = Counters::default(); @@ -293,9 +293,9 @@ fn simplify_lines( unreachable!("Weird: {:?}, {:?}", left, right) }; res.push(SimpleLine::Assignment { - var: var, + var, operation: HighLevelOperation::Add, - arg0: other.into(), + arg0: other, arg1: SimpleExpr::zero(), }); } @@ -383,7 +383,7 @@ fn simplify_lines( unroll, } => { if *unroll { - let (internal_variables, _) = find_variable_usage(&body); + let (internal_variables, _) = find_variable_usage(body); let mut unrolled_lines = Vec::new(); let start_evaluated = start.naive_eval().unwrap().to_usize(); let end_evaluated = end.naive_eval().unwrap().to_usize(); @@ -413,8 +413,10 @@ fn simplify_lines( unimplemented!("Reverse for non-unrolled loops are not implemented yet"); } - let mut loop_const_malloc = ConstMalloc::default(); - loop_const_malloc.counter = const_malloc.counter; + let mut loop_const_malloc = ConstMalloc { + counter: const_malloc.counter, + ..Default::default() + }; let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); array_manager.valid.clear(); let simplified_body = simplify_lines( @@ -432,7 +434,7 @@ fn simplify_lines( counters.loops += 1; // Find variables used inside loop but defined outside - let (_, mut external_vars) = find_variable_usage(&body); + let (_, mut external_vars) = find_variable_usage(body); // Include variables in start/end for expr in [start, end] { @@ -586,7 +588,7 @@ fn simplify_lines( Line::DecomposeBits { var, to_decompose } => { assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); let simplified_to_decompose = simplify_expr( - &to_decompose, + to_decompose, &mut res, counters, array_manager, @@ -618,7 +620,7 @@ fn simplify_expr( const_malloc: &ConstMalloc, ) -> SimpleExpr { match expr { - Expression::Value(value) => return value.simplify_if_const(), + Expression::Value(value) => value.simplify_if_const(), Expression::ArrayAccess { array, index } => { if let Some(label) = const_malloc.map.get(array) { if let Ok(mut offset) = ConstExpression::try_from(*index.clone()) { @@ -633,7 +635,7 @@ fn simplify_expr( let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index] if !array_manager.valid.insert(aux_arr.clone()) { - return SimpleExpr::Var(aux_arr.clone()); + return SimpleExpr::Var(aux_arr); } handle_array_assignment( @@ -645,7 +647,7 @@ fn simplify_expr( array_manager, const_malloc, ); - return SimpleExpr::Var(aux_arr); + SimpleExpr::Var(aux_arr) } Expression::Binary { left, @@ -673,13 +675,13 @@ fn simplify_expr( arg0: left_var, arg1: right_var, }); - return SimpleExpr::Var(aux_var); + SimpleExpr::Var(aux_var) } } } /// Returns (internal_vars, external_vars) -pub(crate) fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { +pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { let mut internal_vars = BTreeSet::new(); let mut external_vars = BTreeSet::new(); @@ -761,7 +763,7 @@ pub(crate) fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet { - on_new_expr(&to_decompose, &internal_vars, &mut external_vars); + on_new_expr(to_decompose, &internal_vars, &mut external_vars); internal_vars.insert(var.clone()); } Line::ForLoop { @@ -844,7 +846,7 @@ fn handle_array_assignment( let arg1 = simplify_expr(&right, res, counters, array_manager, const_malloc); res.push(SimpleLine::Assignment { var: VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: label.clone(), + malloc_label: *label, offset, }, operation, @@ -857,7 +859,7 @@ fn handle_array_assignment( } let value_simplified = match access_type { - ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var.clone()), + ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), ArrayAccessType::ArrayIsAssigned(expr) => { simplify_expr(&expr, res, counters, array_manager, const_malloc) } @@ -865,25 +867,24 @@ fn handle_array_assignment( // TODO opti: in some case we could use ConstMallocAccess - let (index_var, shift) = match simplified_index { - SimpleExpr::Constant(c) => (array, c), - _ => { - // Create pointer variable: ptr = array + index - let ptr_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: ptr_var.clone().into(), - operation: HighLevelOperation::Add, - arg0: array.clone().into(), - arg1: simplified_index.into(), - }); - (ptr_var, ConstExpression::zero()) - } + let (index_var, shift) = if let SimpleExpr::Constant(c) = simplified_index { + (array, c) + } else { + // Create pointer variable: ptr = array + index + let ptr_var = format!("@aux_var_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: ptr_var.clone().into(), + operation: HighLevelOperation::Add, + arg0: array.into(), + arg1: simplified_index, + }); + (ptr_var, ConstExpression::zero()) }; res.push(SimpleLine::RawAccess { - res: value_simplified.into(), - index: index_var.into(), + res: value_simplified, + index: index_var, shift, }); } @@ -897,7 +898,7 @@ fn create_recursive_function( external_vars: &[Var], ) -> SimpleFunction { // Add iterator increment - let next_iter = format!("@incremented_{}", iterator); + let next_iter = format!("@incremented_{iterator}"); body.push(SimpleLine::Assignment { var: next_iter.clone().into(), operation: HighLevelOperation::Add, @@ -918,7 +919,7 @@ fn create_recursive_function( return_data: vec![], }); - let diff_var = format!("@diff_{}", iterator); + let diff_var = format!("@diff_{iterator}"); let instructions = vec![ SimpleLine::Assignment { @@ -956,7 +957,7 @@ fn replace_vars_for_unroll_in_expr( if var == iterator { *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); } else if internal_vars.contains(var) { - *var = format!("@unrolled_{}_{}", iterator_value, var).into(); + *var = format!("@unrolled_{iterator_value}_{var}"); } } SimpleExpr::Constant(_) | SimpleExpr::ConstMallocAccess { .. } => {} @@ -964,7 +965,7 @@ fn replace_vars_for_unroll_in_expr( Expression::ArrayAccess { array, index } => { assert!(array != iterator, "Weird"); if internal_vars.contains(array) { - *array = format!("@unrolled_{}_{}", iterator_value, array).into(); + *array = format!("@unrolled_{iterator_value}_{array}"); } replace_vars_for_unroll_in_expr(index, iterator, iterator_value, internal_vars); } @@ -985,7 +986,7 @@ fn replace_vars_for_unroll( match line { Line::Assignment { var, value } => { assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{}_{}", iterator_value, var).into(); + *var = format!("@unrolled_{iterator_value}_{var}"); replace_vars_for_unroll_in_expr(value, iterator, iterator_value, internal_vars); } Line::ArrayAssign { @@ -996,7 +997,7 @@ fn replace_vars_for_unroll( } => { assert!(array != iterator, "Weird"); if internal_vars.contains(array) { - *array = format!("@unrolled_{}_{}", iterator_value, array).into(); + *array = format!("@unrolled_{iterator_value}_{array}"); } replace_vars_for_unroll_in_expr(index, iterator, iterator_value, internal_vars); replace_vars_for_unroll_in_expr(value, iterator, iterator_value, internal_vars); @@ -1024,7 +1025,7 @@ fn replace_vars_for_unroll( unroll: _, } => { assert!(other_iterator != iterator); - *other_iterator = format!("@unrolled_{}_{}", iterator_value, other_iterator).into(); + *other_iterator = format!("@unrolled_{iterator_value}_{other_iterator}"); replace_vars_for_unroll_in_expr(start, iterator, iterator_value, internal_vars); replace_vars_for_unroll_in_expr(end, iterator, iterator_value, internal_vars); replace_vars_for_unroll(body, iterator, iterator_value, internal_vars); @@ -1039,7 +1040,7 @@ fn replace_vars_for_unroll( replace_vars_for_unroll_in_expr(arg, iterator, iterator_value, internal_vars); } for ret in return_data { - *ret = format!("@unrolled_{}_{}", iterator_value, ret).into(); + *ret = format!("@unrolled_{iterator_value}_{ret}"); } } Line::FunctionRet { return_data } => { @@ -1056,14 +1057,13 @@ fn replace_vars_for_unroll( replace_vars_for_unroll_in_expr(arg, iterator, iterator_value, internal_vars); } for ret in res { - *ret = format!("@unrolled_{}_{}", iterator_value, ret).into(); + *ret = format!("@unrolled_{iterator_value}_{ret}"); } } - Line::Break => {} - Line::Panic => {} + Line::Break | Line::Panic => {} Line::Print { line_info, content } => { // Print statements are not unrolled, so we don't need to change them - *line_info += &format!(" (unrolled {})", iterator_value); + *line_info += &format!(" (unrolled {iterator_value})"); for var in content { replace_vars_for_unroll_in_expr(var, iterator, iterator_value, internal_vars); } @@ -1074,13 +1074,13 @@ fn replace_vars_for_unroll( vectorized: _, } => { assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{}_{}", iterator_value, var).into(); + *var = format!("@unrolled_{iterator_value}_{var}"); replace_vars_for_unroll_in_expr(size, iterator, iterator_value, internal_vars); // vectorized is not changed } Line::DecomposeBits { var, to_decompose } => { assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{}_{}", iterator_value, var).into(); + *var = format!("@unrolled_{iterator_value}_{var}"); replace_vars_for_unroll_in_expr( to_decompose, iterator, @@ -1144,12 +1144,12 @@ fn handle_const_arguments_helper( "{function_name}_{}", const_evals .iter() - .map(|(arg_var, const_eval)| { format!("{}={}", arg_var, const_eval) }) + .map(|(arg_var, const_eval)| { format!("{arg_var}={const_eval}") }) .collect::>() .join("_") ); - *function_name = const_funct_name.clone(); // change the name of the function called + function_name.clone_from(&const_funct_name); // change the name of the function called // ... and remove constant arguments *args = args .iter() @@ -1217,7 +1217,7 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) SimpleExpr::Constant(_) => {} }, Expression::ArrayAccess { array, index } => { - assert!(!map.contains_key(array), "Array {} is a constant", array); + assert!(!map.contains_key(array), "Array {array} is a constant"); replace_vars_by_const_in_expr(index, map); } Expression::Binary { left, right, .. } => { @@ -1231,7 +1231,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { for line in lines { match line { Line::Assignment { var, value } => { - assert!(!map.contains_key(var), "Variable {} is a constant", var); + assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(value, map); } Line::ArrayAssign { @@ -1239,7 +1239,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { index, value, } => { - assert!(!map.contains_key(array), "Array {} is a constant", array); + assert!(!map.contains_key(array), "Array {array} is a constant"); replace_vars_by_const_in_expr(index, map); replace_vars_by_const_in_expr(value, map); } @@ -1252,8 +1252,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { for ret in return_data { assert!( !map.contains_key(ret), - "Return variable {} is a constant", - ret + "Return variable {ret} is a constant" ); } } @@ -1298,7 +1297,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_expr(arg, map); } for r in return_data { - assert!(!map.contains_key(r), "Return variable {} is a constant", r); + assert!(!map.contains_key(r), "Return variable {r} is a constant"); } } Line::Print { content, .. } => { @@ -1307,11 +1306,11 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { } } Line::DecomposeBits { var, to_decompose } => { - assert!(!map.contains_key(var), "Variable {} is a constant", var); + assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(to_decompose, map); } Line::MAlloc { var, size, .. } => { - assert!(!map.contains_key(var), "Variable {} is a constant", var); + assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(size, map); } Line::Panic | Line::Break => {} @@ -1328,8 +1327,8 @@ impl ToString for SimpleLine { impl ToString for VarOrConstMallocAccess { fn to_string(&self) -> String { match self { - VarOrConstMallocAccess::Var(var) => var.to_string(), - VarOrConstMallocAccess::ConstMallocAccess { + Self::Var(var) => var.to_string(), + Self::ConstMallocAccess { malloc_label, offset, } => { @@ -1347,7 +1346,7 @@ impl SimpleLine { fn to_string_with_indent(&self, indent: usize) -> String { let spaces = " ".repeat(indent); let line_str = match self { - SimpleLine::Assignment { + Self::Assignment { var, operation, arg0, @@ -1361,26 +1360,22 @@ impl SimpleLine { arg1.to_string() ) } - SimpleLine::DecomposeBits { + Self::DecomposeBits { var: result, to_decompose, label: _, } => { - format!( - "{} = decompose_bits({})", - result.to_string(), - to_decompose.to_string() - ) + format!("{} = decompose_bits({})", result, to_decompose.to_string()) } - SimpleLine::RawAccess { res, index, shift } => { + Self::RawAccess { res, index, shift } => { format!( "{} = memory[{} + {}]", res.to_string(), - index.to_string(), + index, shift.to_string() ) } - SimpleLine::IfNotZero { + Self::IfNotZero { condition, then_branch, else_branch, @@ -1415,37 +1410,37 @@ impl SimpleLine { ) } } - SimpleLine::FunctionCall { + Self::FunctionCall { function_name, args, return_data, } => { let args_str = args .iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); let return_data_str = return_data .iter() - .map(|var| var.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); if return_data.is_empty() { - format!("{}({})", function_name, args_str) + format!("{function_name}({args_str})") } else { - format!("{} = {}({})", return_data_str, function_name, args_str) + format!("{return_data_str} = {function_name}({args_str})") } } - SimpleLine::FunctionRet { return_data } => { + Self::FunctionRet { return_data } => { let return_data_str = return_data .iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); - format!("return {}", return_data_str) + format!("return {return_data_str}") } - SimpleLine::Precompile { + Self::Precompile { precompile, args, res: return_data, @@ -1455,23 +1450,23 @@ impl SimpleLine { return_data.join(", "), &precompile.name.to_string(), args.iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", ") ) } - SimpleLine::Print { + Self::Print { line_info: _, content, } => { let content_str = content .iter() - .map(|c| c.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); - format!("print({})", content_str) + format!("print({content_str})") } - SimpleLine::HintMAlloc { + Self::HintMAlloc { var, size, vectorized, @@ -1481,18 +1476,18 @@ impl SimpleLine { } else { "malloc" }; - format!("{} = {}({})", var.to_string(), alloc_type, size.to_string()) + format!("{} = {}({})", var, alloc_type, size.to_string()) } - SimpleLine::ConstMalloc { + Self::ConstMalloc { var, size, label: _, } => { - format!("{} = malloc({})", var.to_string(), size.to_string()) + format!("{} = malloc({})", var, size.to_string()) } - SimpleLine::Panic => "panic".to_string(), + Self::Panic => "panic".to_string(), }; - format!("{}{}", spaces, line_str) + format!("{spaces}{line_str}") } } @@ -1501,7 +1496,7 @@ impl ToString for SimpleFunction { let args_str = self .arguments .iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); diff --git a/crates/compiler/src/b_compile_intermediate.rs b/crates/compiler/src/b_compile_intermediate.rs index e538343c..24a90bb1 100644 --- a/crates/compiler/src/b_compile_intermediate.rs +++ b/crates/compiler/src/b_compile_intermediate.rs @@ -5,9 +5,18 @@ use std::{ use p3_field::Field; use utils::ToUsize; -use vm::*; - -use crate::{F, a_simplify_lang::*, intermediate_bytecode::*, lang::*, precompiles::*}; +use vm::{Label, Operation}; + +use crate::{ + F, + a_simplify_lang::{SimpleFunction, SimpleLine, SimpleProgram, VarOrConstMallocAccess}, + intermediate_bytecode::{ + HighLevelOperation, IntermediaryMemOrFpOrConstant, IntermediateBytecode, + IntermediateInstruction, IntermediateValue, + }, + lang::{ConstExpression, ConstMallocLabel, SimpleExpr, Var}, + precompiles::{Precompile, PrecompileName}, +}; struct Compiler { bytecode: BTreeMap>, @@ -21,7 +30,7 @@ struct Compiler { } impl Compiler { - fn new() -> Self { + const fn new() -> Self { Self { var_positions: BTreeMap::new(), stack_size: 0, @@ -39,7 +48,7 @@ impl Compiler { VarOrConstMallocAccess::Var(var) => (*self .var_positions .get(var) - .unwrap_or_else(|| panic!("Variable {} not in scope", var))) + .unwrap_or_else(|| panic!("Variable {var} not in scope"))) .into(), VarOrConstMallocAccess::ConstMallocAccess { malloc_label, @@ -49,7 +58,7 @@ impl Compiler { self.const_mallocs .get(malloc_label) .copied() - .unwrap_or_else(|| panic!("Const malloc {} not in scope", malloc_label)) + .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) .into(), ), operation: HighLevelOperation::Add, @@ -60,18 +69,19 @@ impl Compiler { } impl SimpleExpr { + #[allow(clippy::wrong_self_convention)] fn into_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant { match self { - SimpleExpr::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp { + Self::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), }, - SimpleExpr::Constant(c) => IntermediaryMemOrFpOrConstant::Constant(c.clone()), - SimpleExpr::ConstMallocAccess { + Self::Constant(c) => IntermediaryMemOrFpOrConstant::Constant(c.clone()), + Self::ConstMallocAccess { malloc_label, offset, } => IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: malloc_label.clone(), + malloc_label: *malloc_label, offset: offset.clone(), }), }, @@ -96,7 +106,7 @@ impl IntermediateValue { .const_mallocs .get(malloc_label) .copied() - .unwrap_or_else(|| panic!("Const malloc {} not in scope", malloc_label)) + .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) .into(), ), operation: HighLevelOperation::Add, @@ -119,7 +129,7 @@ impl IntermediateValue { } pub(crate) fn compile_to_intermediate_bytecode( - simple_program: SimpleProgram, + simple_program: &SimpleProgram, ) -> Result { let mut compiler = Compiler::new(); let mut memory_sizes = BTreeMap::new(); @@ -162,7 +172,7 @@ fn compile_function( } stack_pos += internal_vars.len(); - compiler.func_name = function.name.clone(); + compiler.func_name.clone_from(&function.name); compiler.var_positions = var_positions; compiler.stack_size = stack_pos; compiler.args_count = function.arguments.len(); @@ -211,9 +221,9 @@ fn compile_lines( compiler.if_counter += 1; let (if_label, else_label, end_label) = ( - format!("@if_{}", if_id), - format!("@else_{}", if_id), - format!("@if_else_end_{}", if_id), + format!("@if_{if_id}"), + format!("@else_{if_id}"), + format!("@if_else_end_{if_id}"), ); // c: condition @@ -261,7 +271,7 @@ fn compile_lines( arg_a: IntermediateValue::MemoryAfterFp { offset: one_minus_product_offset.into(), }, - arg_c: condition_simplified.clone(), + arg_c: condition_simplified, res: ConstExpression::zero().into(), }); @@ -281,7 +291,7 @@ fn compile_lines( let mut then_declared_vars = declared_vars.clone(); let then_instructions = compile_lines( - &then_branch, + then_branch, compiler, Some(end_label.to_string()), &mut then_declared_vars, @@ -291,7 +301,7 @@ fn compile_lines( compiler.stack_size = original_stack; let mut else_declared_vars = declared_vars.clone(); let else_instructions = compile_lines( - &else_branch, + else_branch, compiler, Some(end_label.to_string()), &mut else_declared_vars, @@ -333,7 +343,7 @@ fn compile_lines( } => { let call_id = compiler.call_counter; compiler.call_counter += 1; - let return_label = format!("@return_from_call_{}", call_id); + let return_label = format!("@return_from_call_{call_id}"); let new_fp_pos = compiler.stack_size; compiler.stack_size += 1; @@ -344,7 +354,7 @@ fn compile_lines( new_fp_pos, &return_label, compiler, - )?); + )); validate_vars_declared(args, declared_vars)?; declared_vars.extend(return_data.iter().cloned()); @@ -459,7 +469,14 @@ fn compile_lines( } SimpleLine::ConstMalloc { var, size, label } => { let size = size.naive_eval().unwrap().to_usize(); // TODO not very good; - handle_const_malloc(declared_vars, &mut instructions, compiler, var, size, label); + handle_const_malloc( + declared_vars, + &mut instructions, + compiler, + var, + size, + *label, + ); } SimpleLine::DecomposeBits { var, @@ -477,7 +494,7 @@ fn compile_lines( compiler, var, F::bits(), - label, + *label, ); } SimpleLine::Print { line_info, content } => { @@ -508,7 +525,7 @@ fn handle_const_malloc( compiler: &mut Compiler, var: &Var, size: usize, - label: &ConstMallocLabel, + label: ConstMallocLabel, ) { declared_vars.insert(var.clone()); instructions.push(IntermediateInstruction::Computation { @@ -519,9 +536,7 @@ fn handle_const_malloc( offset: compiler.get_offset(&var.clone().into()), }, }); - compiler - .const_mallocs - .insert(label.clone(), compiler.stack_size); + compiler.const_mallocs.insert(label, compiler.stack_size); compiler.stack_size += size; } @@ -541,7 +556,7 @@ fn validate_vars_declared>( for voc in vocs { if let SimpleExpr::Var(v) = voc.borrow() { if !declared.contains(v) { - return Err(format!("Variable {} not declared", v)); + return Err(format!("Variable {v} not declared")); } } } @@ -554,7 +569,7 @@ fn setup_function_call( new_fp_pos: usize, return_label: &str, compiler: &Compiler, -) -> Result, String> { +) -> Vec { let mut instructions = vec![ IntermediateInstruction::RequestMemory { offset: new_fp_pos.into(), @@ -585,20 +600,20 @@ fn setup_function_call( } instructions.push(IntermediateInstruction::Jump { - dest: IntermediateValue::label(format!("@function_{}", func_name)), + dest: IntermediateValue::label(format!("@function_{func_name}")), updated_fp: Some(IntermediateValue::MemoryAfterFp { offset: new_fp_pos.into(), }), }); - Ok(instructions) + instructions } fn compile_poseidon( instructions: &mut Vec, args: &[SimpleExpr], res: &[Var], - compiler: &mut Compiler, + compiler: &Compiler, declared_vars: &mut BTreeSet, over_16: bool, // otherwise over_24 ) -> Result<(), String> { @@ -674,10 +689,8 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { internal_vars.insert(var.clone()); } } - SimpleLine::FunctionCall { return_data, .. } => { - internal_vars.extend(return_data.iter().cloned()); - } - SimpleLine::Precompile { + SimpleLine::FunctionCall { return_data, .. } + | SimpleLine::Precompile { res: return_data, .. } => { internal_vars.extend(return_data.iter().cloned()); diff --git a/crates/compiler/src/c_compile_final.rs b/crates/compiler/src/c_compile_final.rs index 99f91e23..9161539e 100644 --- a/crates/compiler/src/c_compile_final.rs +++ b/crates/compiler/src/c_compile_final.rs @@ -2,12 +2,21 @@ use std::collections::BTreeMap; use p3_field::PrimeCharacteristicRing; use utils::ToUsize; -use vm::*; +use vm::{ + Bytecode, Hint, Instruction, Label, MemOrConstant, MemOrFp, MemOrFpOrConstant, Operation, +}; -use crate::{F, PUBLIC_INPUT_START, ZERO_VEC_PTR, intermediate_bytecode::*, lang::*}; +use crate::{ + F, PUBLIC_INPUT_START, ZERO_VEC_PTR, + intermediate_bytecode::{ + IntermediaryMemOrFpOrConstant, IntermediateBytecode, IntermediateInstruction, + IntermediateValue, + }, + lang::{ConstExpression, ConstantValue}, +}; impl IntermediateInstruction { - fn is_hint(&self) -> bool { + const fn is_hint(&self) -> bool { match self { Self::RequestMemory { .. } | Self::Print { .. } @@ -62,7 +71,7 @@ pub(crate) fn compile_to_low_level_bytecode( pc += instructions.iter().filter(|i| !i.is_hint()).count(); } - let ending_pc = label_to_pc.get("@end_program").cloned().unwrap(); + let ending_pc = label_to_pc.get("@end_program").copied().unwrap(); let mut low_level_bytecode = Vec::new(); let mut hints = BTreeMap::new(); @@ -81,7 +90,7 @@ pub(crate) fn compile_to_low_level_bytecode( offset: eval_const_expression_usize(offset, &compiler), }); } - return None; + None }; let try_as_mem_or_fp = |value: &IntermediateValue| match value { @@ -89,7 +98,7 @@ pub(crate) fn compile_to_low_level_bytecode( offset: eval_const_expression_usize(offset, &compiler), }), IntermediateValue::Fp => Some(MemOrFp::Fp), - _ => None, + IntermediateValue::Constant(_) => None, }; for (pc_start, chunk) in code_chunks { @@ -102,8 +111,6 @@ pub(crate) fn compile_to_low_level_bytecode( mut arg_c, res, } => { - let operation: Operation = operation.try_into().unwrap(); - if let Some(arg_a_cst) = try_as_constant(&arg_a, &compiler) { if let Some(arg_b_cst) = try_as_constant(&arg_c, &compiler) { // res = constant +/x constant @@ -169,9 +176,8 @@ pub(crate) fn compile_to_low_level_bytecode( dest, updated_fp, } => { - let updated_fp = updated_fp - .map(|fp| try_as_mem_or_fp(&fp).unwrap()) - .unwrap_or(MemOrFp::Fp); + let updated_fp = + updated_fp.map_or(MemOrFp::Fp, |fp| try_as_mem_or_fp(&fp).unwrap()); low_level_bytecode.push(Instruction::JumpIfNotZero { condition: try_as_mem_or_constant(&condition).unwrap(), dest: try_as_mem_or_constant(&dest).unwrap(), @@ -183,8 +189,7 @@ pub(crate) fn compile_to_low_level_bytecode( condition: MemOrConstant::one(), dest: try_as_mem_or_constant(&dest).unwrap(), updated_fp: updated_fp - .map(|fp| try_as_mem_or_fp(&fp).unwrap()) - .unwrap_or(MemOrFp::Fp), + .map_or(MemOrFp::Fp, |fp| try_as_mem_or_fp(&fp).unwrap()), }); } IntermediateInstruction::Poseidon2_16 { arg_a, arg_b, res } => { @@ -275,12 +280,12 @@ pub(crate) fn compile_to_low_level_bytecode( } } - return Ok(Bytecode { + Ok(Bytecode { instructions: low_level_bytecode, hints, starting_frame_memory, ending_pc, - }); + }) } fn eval_constant_value(constant: &ConstantValue, compiler: &Compiler) -> usize { @@ -291,11 +296,8 @@ fn eval_constant_value(constant: &ConstantValue, compiler: &Compiler) -> usize { ConstantValue::FunctionSize { function_name } => *compiler .memory_size_per_function .get(function_name) - .expect(&format!( - "Function {} not found in memory size map", - function_name - )), - ConstantValue::Label(label) => compiler.label_to_pc.get(label).cloned().unwrap(), + .unwrap_or_else(|| panic!("Function {function_name} not found in memory size map")), + ConstantValue::Label(label) => compiler.label_to_pc.get(label).copied().unwrap(), } } @@ -321,21 +323,21 @@ impl IntermediateValue { fn try_into_mem_or_fp(&self, compiler: &Compiler) -> Result { match self { Self::MemoryAfterFp { offset } => Ok(MemOrFp::MemoryAfterFp { - offset: eval_const_expression_usize(&offset, compiler), + offset: eval_const_expression_usize(offset, compiler), }), Self::Fp => Ok(MemOrFp::Fp), - _ => Err(format!("Cannot convert {:?} to MemOrFp", self)), + _ => Err(format!("Cannot convert {self:?} to MemOrFp")), } } fn try_into_mem_or_constant(&self, compiler: &Compiler) -> Result { if let Some(cst) = try_as_constant(self, compiler) { return Ok(MemOrConstant::Constant(cst)); } - if let IntermediateValue::MemoryAfterFp { offset } = self { + if let Self::MemoryAfterFp { offset } = self { return Ok(MemOrConstant::MemoryAfterFp { offset: eval_const_expression_usize(offset, compiler), }); } - Err(format!("Cannot convert {:?} to MemOrConstant", self)) + Err(format!("Cannot convert {self:?} to MemOrConstant")) } } diff --git a/crates/compiler/src/intermediate_bytecode.rs b/crates/compiler/src/intermediate_bytecode.rs index 76211e86..8d4de9ac 100644 --- a/crates/compiler/src/intermediate_bytecode.rs +++ b/crates/compiler/src/intermediate_bytecode.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use p3_field::{PrimeCharacteristicRing, PrimeField64}; -use vm::*; +use vm::{Label, Operation}; use crate::{F, lang::ConstExpression}; @@ -20,7 +20,7 @@ pub(crate) enum IntermediateValue { impl From for IntermediateValue { fn from(value: ConstExpression) -> Self { - IntermediateValue::Constant(value) + Self::Constant(value) } } impl TryFrom for Operation { @@ -28,9 +28,9 @@ impl TryFrom for Operation { fn try_from(value: HighLevelOperation) -> Result { match value { - HighLevelOperation::Add => Ok(Operation::Add), - HighLevelOperation::Mul => Ok(Operation::Mul), - _ => Err(format!("Cannot convert {:?} to +/x", value)), + HighLevelOperation::Add => Ok(Self::Add), + HighLevelOperation::Mul => Ok(Self::Mul), + _ => Err(format!("Cannot convert {value:?} to +/x")), } } } @@ -43,12 +43,12 @@ pub(crate) enum IntermediaryMemOrFpOrConstant { } impl IntermediateValue { - pub(crate) fn label(label: Label) -> Self { + pub(crate) const fn label(label: Label) -> Self { Self::Constant(ConstExpression::label(label)) } - pub(crate) fn is_constant(&self) -> bool { - matches!(self, IntermediateValue::Constant(_)) + pub(crate) const fn is_constant(&self) -> bool { + matches!(self, Self::Constant(_)) } } @@ -64,11 +64,11 @@ pub enum HighLevelOperation { impl HighLevelOperation { pub fn eval(&self, a: F, b: F) -> F { match self { - HighLevelOperation::Add => a + b, - HighLevelOperation::Mul => a * b, - HighLevelOperation::Sub => a - b, - HighLevelOperation::Div => a / b, - HighLevelOperation::Exp => a.exp_u64(b.as_canonical_u64()), + Self::Add => a + b, + Self::Mul => a * b, + Self::Sub => a - b, + Self::Div => a / b, + Self::Exp => a.exp_u64(b.as_canonical_u64()), } } } @@ -175,7 +175,7 @@ impl IntermediateInstruction { } } - pub(crate) fn equality(left: IntermediateValue, right: IntermediateValue) -> Self { + pub(crate) const fn equality(left: IntermediateValue, right: IntermediateValue) -> Self { Self::Computation { operation: Operation::Add, arg_a: left, @@ -188,9 +188,9 @@ impl IntermediateInstruction { impl ToString for IntermediateValue { fn to_string(&self) -> String { match self { - IntermediateValue::Constant(value) => value.to_string(), - IntermediateValue::Fp => "fp".to_string(), - IntermediateValue::MemoryAfterFp { offset } => { + Self::Constant(value) => value.to_string(), + Self::Fp => "fp".to_string(), + Self::MemoryAfterFp { offset } => { format!("m[fp + {}]", offset.to_string()) } } @@ -202,7 +202,7 @@ impl ToString for IntermediaryMemOrFpOrConstant { match self { Self::MemoryAfterFp { offset } => format!("m[fp + {}]", offset.to_string()), Self::Fp => "fp".to_string(), - Self::Constant(c) => format!("{}", c.to_string()), + Self::Constant(c) => c.to_string(), } } } @@ -269,13 +269,10 @@ impl ToString for IntermediateInstruction { ) } Self::Panic => "panic".to_string(), - Self::Jump { dest, updated_fp } => { - if let Some(fp) = updated_fp { - format!("jump {} with fp = {}", dest.to_string(), fp.to_string()) - } else { - format!("jump {}", dest.to_string()) - } - } + Self::Jump { dest, updated_fp } => updated_fp.as_ref().map_or_else( + || format!("jump {}", dest.to_string()), + |fp| format!("jump {} with fp = {}", dest.to_string(), fp.to_string()), + ), Self::JumpIfNotZero { condition, dest, @@ -330,7 +327,7 @@ impl ToString for IntermediateInstruction { line_info, content .iter() - .map(|c| c.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", ") ), @@ -341,11 +338,11 @@ impl ToString for IntermediateInstruction { impl ToString for HighLevelOperation { fn to_string(&self) -> String { match self { - HighLevelOperation::Add => "+".to_string(), - HighLevelOperation::Mul => "*".to_string(), - HighLevelOperation::Sub => "-".to_string(), - HighLevelOperation::Div => "/".to_string(), - HighLevelOperation::Exp => "**".to_string(), + Self::Add => "+".to_string(), + Self::Mul => "*".to_string(), + Self::Sub => "-".to_string(), + Self::Div => "/".to_string(), + Self::Exp => "**".to_string(), } } } @@ -354,14 +351,14 @@ impl ToString for IntermediateBytecode { fn to_string(&self) -> String { let mut res = String::new(); for (label, instructions) in &self.bytecode { - res.push_str(&format!("\n{}:\n", label)); + res.push_str(&format!("\n{label}:\n")); for instruction in instructions { res.push_str(&format!(" {}\n", instruction.to_string())); } } res.push_str("\nMemory size per function:\n"); for (function_name, size) in &self.memory_size_per_function { - res.push_str(&format!("{}: {}\n", function_name, size)); + res.push_str(&format!("{function_name}: {size}\n")); } res } diff --git a/crates/compiler/src/lang.rs b/crates/compiler/src/lang.rs index a5a91b52..7175766e 100644 --- a/crates/compiler/src/lang.rs +++ b/crates/compiler/src/lang.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use p3_field::PrimeCharacteristicRing; use utils::ToUsize; -use vm::*; +use vm::Label; use crate::{F, intermediate_bytecode::HighLevelOperation, precompiles::Precompile}; @@ -51,7 +51,7 @@ impl SimpleExpr { Self::Constant(ConstantValue::Scalar(scalar).into()) } - pub(crate) fn is_constant(&self) -> bool { + pub(crate) const fn is_constant(&self) -> bool { matches!(self, Self::Constant(_)) } @@ -84,9 +84,8 @@ impl From for SimpleExpr { impl SimpleExpr { pub(crate) fn as_constant(&self) -> Option { match self { - Self::Var(_) => None, + Self::Var(_) | Self::ConstMallocAccess { .. } => None, Self::Constant(constant) => Some(constant.clone()), - Self::ConstMallocAccess { .. } => None, } } } @@ -118,7 +117,7 @@ pub(crate) enum ConstExpression { impl From for ConstExpression { fn from(value: usize) -> Self { - ConstExpression::Value(ConstantValue::Scalar(value)) + Self::Value(ConstantValue::Scalar(value)) } } @@ -128,16 +127,15 @@ impl TryFrom for ConstExpression { fn try_from(value: Expression) -> Result { match value { Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), - Expression::Value(_) => Err(()), - Expression::ArrayAccess { .. } => Err(()), + Expression::Value(_) | Expression::ArrayAccess { .. } => Err(()), Expression::Binary { left, operation, right, } => { - let left_expr = ConstExpression::try_from(*left)?; - let right_expr = ConstExpression::try_from(*right)?; - Ok(ConstExpression::Binary { + let left_expr = Self::try_from(*left)?; + let right_expr = Self::try_from(*right)?; + Ok(Self::Binary { left: Box::new(left_expr), operation, right: Box::new(right_expr), @@ -148,23 +146,23 @@ impl TryFrom for ConstExpression { } impl ConstExpression { - pub(crate) fn zero() -> Self { + pub(crate) const fn zero() -> Self { Self::scalar(0) } - pub(crate) fn one() -> Self { + pub(crate) const fn one() -> Self { Self::scalar(1) } - pub(crate) fn label(label: Label) -> Self { + pub(crate) const fn label(label: Label) -> Self { Self::Value(ConstantValue::Label(label)) } - pub(crate) fn scalar(scalar: usize) -> Self { + pub(crate) const fn scalar(scalar: usize) -> Self { Self::Value(ConstantValue::Scalar(scalar)) } - pub(crate) fn function_size(function_name: Label) -> Self { + pub(crate) const fn function_size(function_name: Label) -> Self { Self::Value(ConstantValue::FunctionSize { function_name }) } @@ -190,11 +188,8 @@ impl ConstExpression { } pub(crate) fn try_naive_simplification(&self) -> Self { - if let Some(value) = self.naive_eval() { - Self::scalar(value.to_usize()) - } else { - self.clone() - } + self.naive_eval() + .map_or_else(|| self.clone(), |value| Self::scalar(value.to_usize())) } } @@ -248,11 +243,11 @@ impl Expression { ArrayFn: Fn(&Var, F) -> Option, { match self { - Expression::Value(value) => value_fn(value), - Expression::ArrayAccess { array, index } => { + Self::Value(value) => value_fn(value), + Self::ArrayAccess { array, index } => { array_fn(array, index.eval_with(value_fn, array_fn)?) } - Expression::Binary { + Self::Binary { left, operation, right, @@ -324,11 +319,11 @@ pub(crate) enum Line { impl ToString for Expression { fn to_string(&self) -> String { match self { - Expression::Value(val) => val.to_string(), - Expression::ArrayAccess { array, index } => { + Self::Value(val) => val.to_string(), + Self::ArrayAccess { array, index } => { format!("{}[{}]", array, index.to_string()) } - Expression::Binary { + Self::Binary { left, operation, right, @@ -348,23 +343,18 @@ impl Line { fn to_string_with_indent(&self, indent: usize) -> String { let spaces = " ".repeat(indent); let line_str = match self { - Line::Assignment { var, value } => { - format!("{} = {}", var.to_string(), value.to_string()) + Self::Assignment { var, value } => { + format!("{} = {}", var, value.to_string()) } - Line::ArrayAssign { + Self::ArrayAssign { array, index, value, } => { - format!( - "{}[{}] = {}", - array.to_string(), - index.to_string(), - value.to_string() - ) + format!("{}[{}] = {}", array, index.to_string(), value.to_string()) } - Line::Assert(condition) => format!("assert {}", condition.to_string()), - Line::IfCondition { + Self::Assert(condition) => format!("assert {}", condition.to_string()), + Self::IfCondition { condition, then_branch, else_branch, @@ -399,7 +389,7 @@ impl Line { ) } } - Line::ForLoop { + Self::ForLoop { iterator, start, end, @@ -414,7 +404,7 @@ impl Line { .join("\n"); format!( "for {} in {}{}..{} {}{{\n{}\n{}}}", - iterator.to_string(), + iterator, start.to_string(), if *rev { "rev " } else { "" }, end.to_string(), @@ -423,37 +413,37 @@ impl Line { spaces ) } - Line::FunctionCall { + Self::FunctionCall { function_name, args, return_data, } => { let args_str = args .iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); let return_data_str = return_data .iter() - .map(|var| var.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); if return_data.is_empty() { - format!("{}({})", function_name, args_str) + format!("{function_name}({args_str})") } else { - format!("{} = {}({})", return_data_str, function_name, args_str) + format!("{return_data_str} = {function_name}({args_str})") } } - Line::FunctionRet { return_data } => { + Self::FunctionRet { return_data } => { let return_data_str = return_data .iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); - format!("return {}", return_data_str) + format!("return {return_data_str}") } - Line::Precompile { + Self::Precompile { precompile, args, res: return_data, @@ -462,63 +452,55 @@ impl Line { "{} = {}({})", return_data .iter() - .map(|var| var.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "), precompile.name.to_string(), args.iter() - .map(|arg| arg.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", ") ) } - Line::Print { + Self::Print { line_info: _, content, } => { let content_str = content .iter() - .map(|c| c.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "); - format!("print({})", content_str) + format!("print({content_str})") } - Line::MAlloc { + Self::MAlloc { var, size, vectorized, } => { if *vectorized { - format!( - "{} = malloc_vectorized({})", - var.to_string(), - size.to_string() - ) + format!("{} = malloc_vectorized({})", var, size.to_string()) } else { - format!("{} = malloc({})", var.to_string(), size.to_string()) + format!("{} = malloc({})", var, size.to_string()) } } - Line::DecomposeBits { var, to_decompose } => { - format!( - "{} = decompose_bits({})", - var.to_string(), - to_decompose.to_string() - ) + Self::DecomposeBits { var, to_decompose } => { + format!("{} = decompose_bits({})", var, to_decompose.to_string()) } - Line::Break => "break".to_string(), - Line::Panic => "panic".to_string(), + Self::Break => "break".to_string(), + Self::Panic => "panic".to_string(), }; - format!("{}{}", spaces, line_str) + format!("{spaces}{line_str}") } } impl ToString for Boolean { fn to_string(&self) -> String { match self { - Boolean::Equal { left, right } => { + Self::Equal { left, right } => { format!("{} == {}", left.to_string(), right.to_string()) } - Boolean::Different { left, right } => { + Self::Different { left, right } => { format!("{} != {}", left.to_string(), right.to_string()) } } @@ -528,13 +510,13 @@ impl ToString for Boolean { impl ToString for ConstantValue { fn to_string(&self) -> String { match self { - ConstantValue::Scalar(scalar) => scalar.to_string(), - ConstantValue::PublicInputStart => "@public_input_start".to_string(), - ConstantValue::PointerToZeroVector => "@pointer_to_zero_vector".to_string(), - ConstantValue::FunctionSize { function_name } => { - format!("@function_size_{}", function_name) + Self::Scalar(scalar) => scalar.to_string(), + Self::PublicInputStart => "@public_input_start".to_string(), + Self::PointerToZeroVector => "@pointer_to_zero_vector".to_string(), + Self::FunctionSize { function_name } => { + format!("@function_size_{function_name}") } - ConstantValue::Label(label) => label.to_string(), + Self::Label(label) => label.to_string(), } } } @@ -542,9 +524,9 @@ impl ToString for ConstantValue { impl ToString for SimpleExpr { fn to_string(&self) -> String { match self { - SimpleExpr::Var(var) => var.to_string(), - SimpleExpr::Constant(constant) => constant.to_string(), - SimpleExpr::ConstMallocAccess { + Self::Var(var) => var.to_string(), + Self::Constant(constant) => constant.to_string(), + Self::ConstMallocAccess { malloc_label, offset, } => { @@ -557,8 +539,8 @@ impl ToString for SimpleExpr { impl ToString for ConstExpression { fn to_string(&self) -> String { match self { - ConstExpression::Value(value) => value.to_string(), - ConstExpression::Binary { + Self::Value(value) => value.to_string(), + Self::Binary { left, operation, right, @@ -599,7 +581,7 @@ impl ToString for Function { .arguments .iter() .map(|arg| match arg { - (name, true) => format!("const {}", name), + (name, true) => format!("const {name}"), (name, false) => name.to_string(), }) .collect::>() diff --git a/crates/compiler/src/lib.rs b/crates/compiler/src/lib.rs index c7005751..c27ff07b 100644 --- a/crates/compiler/src/lib.rs +++ b/crates/compiler/src/lib.rs @@ -1,4 +1,4 @@ -use vm::*; +use vm::{Bytecode, F, PUBLIC_INPUT_START, ZERO_VEC_PTR, execute_bytecode}; use crate::{ a_simplify_lang::simplify_program, b_compile_intermediate::compile_to_intermediate_bytecode, @@ -14,19 +14,20 @@ mod parser; mod precompiles; pub use precompiles::PRECOMPILES; +#[must_use] pub fn compile_program(program: &str) -> Bytecode { let parsed_program = parse_program(program).unwrap(); // println!("Parsed program: {}", parsed_program.to_string()); let simple_program = simplify_program(parsed_program); // println!("Simplified program: {}", simple_program.to_string()); - let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program).unwrap(); + let intermediate_bytecode = compile_to_intermediate_bytecode(&simple_program).unwrap(); // println!("Intermediate Bytecode:\n\n{}", intermediate_bytecode.to_string()); - let compiled = compile_to_low_level_bytecode(intermediate_bytecode).unwrap(); + // println!("Compiled Program:\n\n{}", compiled.to_string()); - compiled + compile_to_low_level_bytecode(intermediate_bytecode).unwrap() } pub fn compile_and_run(program: &str, public_input: &[F], private_input: &[F]) { let bytecode = compile_program(program); - execute_bytecode(&bytecode, &public_input, private_input); + let _ = execute_bytecode(&bytecode, public_input, private_input); } diff --git a/crates/compiler/src/parser.rs b/crates/compiler/src/parser.rs index 23475596..20cb1c6d 100644 --- a/crates/compiler/src/parser.rs +++ b/crates/compiler/src/parser.rs @@ -6,11 +6,17 @@ use pest_derive::Parser; use utils::ToUsize; use vm::F; -use crate::{intermediate_bytecode::*, lang::*, precompiles::PRECOMPILES}; +use crate::{ + intermediate_bytecode::HighLevelOperation, + lang::{ + Boolean, ConstExpression, ConstantValue, Expression, Function, Line, Program, SimpleExpr, + }, + precompiles::PRECOMPILES, +}; #[derive(Parser)] #[grammar = "grammar.pest"] -pub(crate) struct LangParser; +pub struct LangParser; #[derive(Debug)] pub(crate) enum ParseError { @@ -20,15 +26,15 @@ pub(crate) enum ParseError { impl From> for ParseError { fn from(error: pest::error::Error) -> Self { - ParseError::PestError(error) + Self::PestError(error) } } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ParseError::PestError(e) => write!(f, "Parse error: {}", e), - ParseError::SemanticError(e) => write!(f, "Semantic error: {}", e), + Self::PestError(e) => write!(f, "Parse error: {e}"), + Self::SemanticError(e) => write!(f, "Semantic error: {e}"), } } } @@ -92,7 +98,7 @@ fn parse_constant_declaration( }, &|_, _| None, ) - .unwrap_or_else(|| panic!("Failed to evaluate constant: {}", name)) + .unwrap_or_else(|| panic!("Failed to evaluate constant: {name}")) .to_usize(); Ok((name, value)) } @@ -146,7 +152,7 @@ fn parse_parameter(pair: Pair<'_, Rule>) -> Result<(String, bool), ParseError> { let mut inner = pair.into_inner(); let first = inner.next().unwrap(); - if let Rule::const_keyword = first.as_rule() { + if first.as_rule() == Rule::const_keyword { // If the first token is "const", the next one should be the identifier let identifier = inner.next().ok_or_else(|| { ParseError::SemanticError("Expected identifier after 'const'".to_string()) @@ -170,7 +176,7 @@ fn parse_statement( Rule::if_statement => parse_if_statement(inner, constants, trash_var_count), Rule::for_statement => parse_for_statement(inner, constants, trash_var_count), Rule::return_statement => parse_return_statement(inner, constants), - Rule::function_call => parse_function_call(inner, constants, trash_var_count), + Rule::function_call => parse_function_call(&inner, constants, trash_var_count), Rule::assert_eq_statement => parse_assert_eq(inner, constants), Rule::assert_not_eq_statement => parse_assert_not_eq(inner, constants), Rule::break_statement => Ok(Line::Break), @@ -329,7 +335,7 @@ fn parse_binary_expr( let mut inner = pair.into_inner(); let mut expr = parse_expression(inner.next().unwrap(), constants)?; - while let Some(right) = inner.next() { + for right in inner { let right_expr = parse_expression(right, constants)?; expr = Expression::Binary { left: Box::new(expr), @@ -366,7 +372,7 @@ fn parse_tuple_expression( } fn parse_function_call( - pair: Pair<'_, Rule>, + pair: &Pair<'_, Rule>, constants: &BTreeMap, trash_var_count: &mut usize, ) -> Result { @@ -402,7 +408,7 @@ fn parse_function_call( for var in &mut return_data { if var == "_" { *trash_var_count += 1; - *var = format!("@trash_{}", trash_var_count); + *var = format!("@trash_{trash_var_count}"); } } @@ -525,7 +531,7 @@ fn parse_var_or_constant( match pair.as_rule() { Rule::var_or_constant => { - return parse_var_or_constant(pair.into_inner().next().unwrap(), constants); + parse_var_or_constant(pair.into_inner().next().unwrap(), constants) } Rule::identifier | Rule::constant_value => match text { "public_input_start" => Ok(SimpleExpr::Constant(ConstExpression::Value( @@ -569,7 +575,7 @@ mod tests { #[test] fn test_parser() { - let program = r#" + let program = r" // This is a comment @@ -629,7 +635,7 @@ fn my_function1(a, const b, c) -> 2 { return e, d; } } - "#; + "; let parsed = parse_program(program).unwrap(); println!("{}", parsed.to_string()); @@ -637,12 +643,12 @@ fn my_function1(a, const b, c) -> 2 { #[test] fn test_const_parameters() { - let program = r#" + let program = r" fn test_func(const a, b, const c) -> 1 { d = a + b + c; return d; } - "#; + "; let parsed = parse_program(program).unwrap(); println!("{}", parsed.to_string()); @@ -650,7 +656,7 @@ fn test_func(const a, b, const c) -> 1 { #[test] fn test_exponent_operation() { - let program = r#" + let program = r" fn test_exp() -> 1 { a = 2 ** 3; b = x ** y ** z; // Should parse as x ** (y ** z) @@ -658,7 +664,7 @@ fn test_exp() -> 1 { d = a ** 2 * b; // Should parse as (a ** 2) * b return a; } - "#; + "; let parsed = parse_program(program).unwrap(); println!("{}", parsed.to_string()); diff --git a/crates/compiler/src/precompiles.rs b/crates/compiler/src/precompiles.rs index 8714974e..9e53d05d 100644 --- a/crates/compiler/src/precompiles.rs +++ b/crates/compiler/src/precompiles.rs @@ -16,10 +16,10 @@ pub enum PrecompileName { impl ToString for PrecompileName { fn to_string(&self) -> String { match self { - PrecompileName::Poseidon16 => "poseidon16", - PrecompileName::Poseidon24 => "poseidon24", - PrecompileName::DotProduct => "dot_product", - PrecompileName::MultilinearEval => "multilinear_eval", + Self::Poseidon16 => "poseidon16", + Self::Poseidon24 => "poseidon24", + Self::DotProduct => "dot_product", + Self::MultilinearEval => "multilinear_eval", } .to_string() } diff --git a/crates/leanVm/src/bytecode/operand.rs b/crates/leanVm/src/bytecode/operand.rs index fbe16184..d6f29667 100644 --- a/crates/leanVm/src/bytecode/operand.rs +++ b/crates/leanVm/src/bytecode/operand.rs @@ -33,12 +33,14 @@ impl MemOrConstant { } /// Returns a constant operand with value `0`. - #[must_use] pub const fn zero() -> Self { + #[must_use] + pub const fn zero() -> Self { Self::Constant(F::ZERO) } /// Returns a constant operand with value `1`. - #[must_use] pub const fn one() -> Self { + #[must_use] + pub const fn one() -> Self { Self::Constant(F::ONE) } } diff --git a/crates/leanVm/src/intermediate_bytecode/instruction.rs b/crates/leanVm/src/intermediate_bytecode/instruction.rs index 4f414a7e..d5080168 100644 --- a/crates/leanVm/src/intermediate_bytecode/instruction.rs +++ b/crates/leanVm/src/intermediate_bytecode/instruction.rs @@ -116,7 +116,8 @@ impl IntermediateInstruction { } } - #[must_use] pub const fn is_hint(&self) -> bool { + #[must_use] + pub const fn is_hint(&self) -> bool { match self { Self::RequestMemory { .. } | Self::Print { .. } diff --git a/crates/lookup/src/logup_star.rs b/crates/lookup/src/logup_star.rs index 96d3c1e2..c04b3e26 100644 --- a/crates/lookup/src/logup_star.rs +++ b/crates/lookup/src/logup_star.rs @@ -28,7 +28,7 @@ pub struct LogupStarStatements { } #[instrument(skip_all)] -pub fn prove_logup_star>( +pub fn prove_logup_star( prover_state: &mut FSProver>, table: &[IF], indexes: &[PF], @@ -37,7 +37,8 @@ pub fn prove_logup_star>( pushforward: &[EF], // already commited ) -> LogupStarStatements where - EF: ExtensionField>, + IF: Field, + EF: ExtensionField> + ExtensionField, PF: PrimeField64, { let table_length = table.len(); @@ -50,8 +51,8 @@ where .in_scope(|| { ( pack_extension(&table_embedded), - pack_extension(&poly_eq_point), - pack_extension(&pushforward), + pack_extension(poly_eq_point), + pack_extension(pushforward), ) }); @@ -82,7 +83,7 @@ where prover_state.add_extension_scalar(pushforwardt_eval); // delayed opening let mut on_pushforward = vec![Evaluation { - point: sc_point.clone(), + point: sc_point, value: pushforwardt_eval, }]; @@ -150,7 +151,7 @@ where } } -pub fn verify_logup_star( +pub fn verify_logup_star( verifier_state: &mut FSVerifier>, log_table_len: usize, log_indexes_len: usize, @@ -318,10 +319,7 @@ mod tests { let time = std::time::Instant::now(); let poly_eq_point = info_span!("eval_eq").in_scope(|| eval_eq(&point.0)); let pushforward = compute_pushforward(&indexes, table_length, &poly_eq_point); - let claim = Evaluation { - point: point.clone(), - value: eval, - }; + let claim = Evaluation { point, value: eval }; prove_logup_star( &mut prover_state, diff --git a/crates/lookup/src/product_gkr.rs b/crates/lookup/src/product_gkr.rs index 1c18d26c..54677602 100644 --- a/crates/lookup/src/product_gkr.rs +++ b/crates/lookup/src/product_gkr.rs @@ -7,7 +7,7 @@ with custom GKR */ -use p3_field::{ExtensionField, Field, PrimeCharacteristicRing, PrimeField64}; +use p3_field::{ExtensionField, PrimeCharacteristicRing, PrimeField64}; use rayon::prelude::*; use sumcheck::{MleGroupRef, ProductComputation}; use tracing::instrument; @@ -30,7 +30,7 @@ A': [a0*a4, a1*a5, a2*a6, a3*a7] */ #[instrument(skip_all)] -pub fn prove_gkr_product( +pub fn prove_gkr_product( prover_state: &mut FSProver>, final_layer: Vec>, ) -> (EF, Evaluation) @@ -58,7 +58,7 @@ where assert_eq!(layers_not_packed[n - last_packed - 2].len(), 2); let product = layers_not_packed[n - last_packed - 2] .iter() - .cloned() + .copied() .product::(); prover_state.add_extension_scalars(&layers_not_packed[n - last_packed - 2]); @@ -79,7 +79,7 @@ where } #[instrument(skip_all)] -fn prove_gkr_product_step( +fn prove_gkr_product_step( prover_state: &mut FSProver>, up_layer: &[EF], claim: &Evaluation, @@ -97,7 +97,7 @@ where } #[instrument(skip_all)] -fn prove_gkr_product_step_packed( +fn prove_gkr_product_step_packed( prover_state: &mut FSProver>, up_layer_packed: &[EFPacking], claim: &Evaluation, @@ -118,7 +118,7 @@ where } #[instrument(skip_all)] -fn prove_gkr_product_step_core( +fn prove_gkr_product_step_core( prover_state: &mut FSProver>, up_layer: MleGroupRef<'_, EF>, claim: &Evaluation, @@ -143,7 +143,7 @@ where let mixing_challenge = prover_state.sample(); - let mut next_point = sc_point.clone(); + let mut next_point = sc_point; next_point.0.insert(0, mixing_challenge); let next_claim = inner_evals[0] * (EF::ONE - mixing_challenge) + inner_evals[1] * mixing_challenge; @@ -151,7 +151,7 @@ where (next_point, next_claim).into() } -pub fn verify_gkr_product( +pub fn verify_gkr_product( verifier_state: &mut FSVerifier>, n_vars: usize, ) -> Result<(EF, Evaluation), ProofError> @@ -176,7 +176,7 @@ where Ok((product, claim)) } -fn verify_gkr_product_step( +fn verify_gkr_product_step( verifier_state: &mut FSVerifier>, current_layer_log_len: usize, claim: &Evaluation, @@ -201,7 +201,7 @@ where let mixing_challenge = verifier_state.sample(); - let mut next_point = postponed.point.clone(); + let mut next_point = postponed.point; next_point.0.insert(0, mixing_challenge); let next_claim = eval_left * (EF::ONE - mixing_challenge) + eval_right * mixing_challenge; diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs index 2c16affd..0f780f9f 100644 --- a/crates/lookup/src/quotient_gkr.rs +++ b/crates/lookup/src/quotient_gkr.rs @@ -8,7 +8,7 @@ with custom GKR */ use p3_field::{ - ExtensionField, Field, PackedFieldExtension, PrimeCharacteristicRing, PrimeField64, dot_product, + ExtensionField, PackedFieldExtension, PrimeCharacteristicRing, PrimeField64, dot_product, }; use rayon::prelude::*; use sumcheck::{Mle, MleGroupRef, SumcheckComputation, SumcheckComputationPacked}; @@ -49,7 +49,7 @@ with: U0 = AB(0 0 --- ) */ #[instrument(skip_all)] -pub(crate) fn prove_gkr_quotient( +pub(crate) fn prove_gkr_quotient( prover_state: &mut FSProver>, final_layer: Vec>, ) -> (Evaluation, EF, EF) @@ -97,7 +97,7 @@ where } #[instrument(skip_all)] -fn prove_gkr_quotient_step( +fn prove_gkr_quotient_step( prover_state: &mut FSProver>, up_layer: &[EF], claim: &Evaluation, @@ -196,7 +196,7 @@ where let mixing_challenge_a = prover_state.sample(); let mixing_challenge_b = prover_state.sample(); - let mut next_point = sc_point.clone(); + let mut next_point = sc_point; next_point.0.insert(0, mixing_challenge_a); next_point.0[1] = mixing_challenge_b; @@ -218,9 +218,9 @@ where } #[instrument(skip_all)] -fn prove_gkr_quotient_step_packed( +fn prove_gkr_quotient_step_packed( prover_state: &mut FSProver>, - up_layer_packed: &Vec>, + up_layer_packed: &[EFPacking], claim: &Evaluation, ) -> (Evaluation, EF, EF) where @@ -349,7 +349,7 @@ where ) } -pub(crate) fn verify_gkr_quotient( +pub(crate) fn verify_gkr_quotient( verifier_state: &mut FSVerifier>, n_vars: usize, ) -> Result<(EF, Evaluation), ProofError> @@ -374,7 +374,7 @@ where Ok((quotient, claim)) } -fn verify_gkr_quotient_step( +fn verify_gkr_quotient_step( verifier_state: &mut FSVerifier>, current_layer_log_len: usize, claim: &Evaluation, @@ -406,7 +406,7 @@ where let mixing_challenge_a = verifier_state.sample(); let mixing_challenge_b = verifier_state.sample(); - let mut next_point = postponed.point.clone(); + let mut next_point = postponed.point; next_point.0.insert(0, mixing_challenge_a); next_point.0[1] = mixing_challenge_b; @@ -414,13 +414,12 @@ where [q0, q1, q2, q3].into_iter(), eval_eq(&[mixing_challenge_a, mixing_challenge_b]) .iter() - .cloned(), + .copied(), ); Ok((next_point, next_claim).into()) } -#[derive(Debug)] pub(crate) struct GKRQuotientComputation { u4_const: EF, u5_const: EF, diff --git a/crates/pcs/src/batch_pcs.rs b/crates/pcs/src/batch_pcs.rs index 7ce8295d..2165070c 100644 --- a/crates/pcs/src/batch_pcs.rs +++ b/crates/pcs/src/batch_pcs.rs @@ -55,10 +55,9 @@ pub struct WhirBatchPcs( impl BatchPCS for WhirBatchPcs where - F: TwoAdicField, + F: TwoAdicField + ExtensionField>, PF: TwoAdicField, EF: ExtensionField + TwoAdicField + ExtensionField>, - F: ExtensionField>, H: CryptographicHasher, [PF; DIGEST_ELEMS]> + CryptographicHasher, [PFPacking; DIGEST_ELEMS]> + Sync, diff --git a/crates/pcs/src/combinatorics.rs b/crates/pcs/src/combinatorics.rs index 4f13c453..de62f86e 100644 --- a/crates/pcs/src/combinatorics.rs +++ b/crates/pcs/src/combinatorics.rs @@ -1,5 +1,4 @@ -use std::cmp::Reverse; -use std::collections::BinaryHeap; +use std::{cmp::Reverse, collections::BinaryHeap}; #[derive(Debug, Clone)] pub struct TreeOfVariables { @@ -43,8 +42,8 @@ impl TreeOfVariables { impl TreeOfVariablesInner { pub fn total_vars(&self, vars_per_polynomial: &[usize]) -> usize { match self { - TreeOfVariablesInner::Polynomial(i) => vars_per_polynomial[*i], - TreeOfVariablesInner::Composed { left, right } => { + Self::Polynomial(i) => vars_per_polynomial[*i], + Self::Composed { left, right } => { 1 + left .total_vars(vars_per_polynomial) .max(right.total_vars(vars_per_polynomial)) @@ -61,8 +60,8 @@ impl TreeOfVariables { let root = Self::compute_greedy(&vars_per_polynomial); Self { - root, vars_per_polynomial, + root, } } @@ -107,7 +106,7 @@ mod tests { #[test] fn test_tree_of_variables() { let vars_per_polynomial = vec![2]; - let tree = TreeOfVariables::compute_optimal(vars_per_polynomial.clone()); + let tree = TreeOfVariables::compute_optimal(vars_per_polynomial); dbg!(&tree, tree.total_vars()); } } diff --git a/crates/pcs/src/packed_pcs.rs b/crates/pcs/src/packed_pcs.rs index 713bf4ef..878f3894 100644 --- a/crates/pcs/src/packed_pcs.rs +++ b/crates/pcs/src/packed_pcs.rs @@ -27,6 +27,7 @@ pub fn num_packed_vars_for_pols(polynomials: &[&[F]]) -> usize { TreeOfVariables::compute_optimal(vars_per_polynomial).total_vars() } +#[must_use] pub fn num_packed_vars_for_vars(vars_per_polynomial: &[usize]) -> usize { TreeOfVariables::compute_optimal(vars_per_polynomial.to_vec()).total_vars() } @@ -57,11 +58,12 @@ pub fn packed_pcs_commit, Pcs: PCS>( } } +#[must_use] pub fn packed_pcs_global_statements( tree: &TreeOfVariables, statements_per_polynomial: &[Vec>], ) -> Vec> { - check_tree(&tree, statements_per_polynomial).expect("Invalid tree structure for multi-open"); + check_tree(tree, statements_per_polynomial); tree.root .global_statement(&tree.vars_per_polynomial, statements_per_polynomial, &[]) @@ -89,7 +91,7 @@ pub fn packed_pcs_parse_commitment, Pcs: PCS( tree: &TreeOfVariables, statements_per_polynomial: &[Vec>], -) -> Result<(), ProofError> { +) { assert_eq!( statements_per_polynomial.len(), tree.vars_per_polynomial.len() @@ -102,7 +104,6 @@ fn check_tree( assert_eq!(eval.point.len(), vars); } } - Ok(()) } impl TreeOfVariablesInner { @@ -113,11 +114,11 @@ impl TreeOfVariablesInner { vars_per_polynomial: &[usize], ) { match self { - TreeOfVariablesInner::Polynomial(i) => { + Self::Polynomial(i) => { let len = polynomials[*i].len(); buff[..len].copy_from_slice(polynomials[*i]); } - TreeOfVariablesInner::Composed { left, right } => { + Self::Composed { left, right } => { let (left_buff, right_buff) = buff.split_at_mut(buff.len() / 2); let left_buff = &mut left_buff[..1 << left.total_vars(vars_per_polynomial)]; let right_buff = &mut right_buff[..1 << right.total_vars(vars_per_polynomial)]; @@ -136,7 +137,7 @@ impl TreeOfVariablesInner { selectors: &[EF], ) -> Vec> { match self { - TreeOfVariablesInner::Polynomial(i) => { + Self::Polynomial(i) => { let mut res = Vec::new(); for eval in &statements_per_polynomial[*i] { res.push(Evaluation { @@ -148,7 +149,7 @@ impl TreeOfVariablesInner { } res } - TreeOfVariablesInner::Composed { left, right } => { + Self::Composed { left, right } => { let left_vars = left.total_vars(vars_per_polynomial); let right_vars = right.total_vars(vars_per_polynomial); @@ -232,7 +233,10 @@ mod tests { let mut prover_state = build_prover_state(); let dft = EvalsDft::::default(); - let polynomials_refs = polynomials.iter().map(|p| p.as_slice()).collect::>(); + let polynomials_refs = polynomials + .iter() + .map(std::vec::Vec::as_slice) + .collect::>(); let witness = packed_pcs_commit(&pcs, &polynomials_refs, &dft, &mut prover_state); let packed_statements = diff --git a/crates/pcs/src/pcs.rs b/crates/pcs/src/pcs.rs index 72c21b27..dfe0fc17 100644 --- a/crates/pcs/src/pcs.rs +++ b/crates/pcs/src/pcs.rs @@ -66,10 +66,9 @@ where impl PCS for WhirConfigBuilder where - F: TwoAdicField, + F: TwoAdicField + ExtensionField>, PF: TwoAdicField, EF: ExtensionField + TwoAdicField + ExtensionField>, - F: ExtensionField>, H: CryptographicHasher, [PF; DIGEST_ELEMS]> + CryptographicHasher, [PFPacking; DIGEST_ELEMS]> + Sync, diff --git a/crates/pcs/src/ring_switch.rs b/crates/pcs/src/ring_switch.rs index 9c7cd387..3582ca62 100644 --- a/crates/pcs/src/ring_switch.rs +++ b/crates/pcs/src/ring_switch.rs @@ -23,7 +23,7 @@ pub struct RingSwitching { impl RingSwitching { - pub fn new(inner_pcs: InnerPcs) -> Self { + pub const fn new(inner_pcs: InnerPcs) -> Self { Self { inner_pcs, f: PhantomData, @@ -60,7 +60,7 @@ impl< witness: Self::Witness, polynomial: &[F], ) { - let _span = info_span!("RingSwitching::open").entered(); + let span = info_span!("RingSwitching::open").entered(); assert_eq!(statements.len(), 1); let eval = &statements[0]; assert_eq!(>::DIMENSION, EXTENSION_DEGREE); @@ -70,7 +70,7 @@ impl< let point = &eval.point; let packed_point = &point[..point.len() - kappa]; let eval_eq_packed_point = - info_span!("eval_eq of packed_point").in_scope(|| eval_eq(&packed_point)); + info_span!("eval_eq of packed_point").in_scope(|| eval_eq(packed_point)); let s_hat = eval_mixed_tensor::(transmuted_pol, &eval_eq_packed_point); @@ -114,11 +114,11 @@ impl< ); let packed_eval = Evaluation { - point: r_p.clone(), + point: r_p, value: packed_value, }; let inner_statements = vec![packed_eval]; - _span.exit(); + span.exit(); self.inner_pcs.open( dft, prover_state, @@ -211,11 +211,7 @@ fn get_s_prime, const EXTENSION_DEGREE: usize>( + e.scale_columns(EF::ONE - r).scale_rows(EF::ONE - r_prime); } - sc_value - / dot_product( - e.rows::().into_iter(), - lagranged_r_pp.into_iter().cloned(), - ) + sc_value / dot_product(e.rows::().into_iter(), lagranged_r_pp.iter().copied()) } struct TensorAlgebra([[F; D]; D]); @@ -272,7 +268,7 @@ impl TensorAlgebra { Self(data) } - fn one() -> Self { + const fn one() -> Self { let mut res = [[F::ZERO; D]; D]; res[0][0] = F::ONE; Self(res) @@ -371,6 +367,7 @@ mod tests { const DIMENSION: usize = 8; type F = KoalaBear; type EF = BinomialExtensionField; + type InnerPcs = WhirConfigBuilder; #[test] fn test_ring_switching() { @@ -406,8 +403,6 @@ mod tests { - inner_pcs.folding_factor.at_round(0)), ); - type InnerPcs = WhirConfigBuilder; - let statement = vec![Evaluation { point: point.clone(), value: polynomial.evaluate(&point), diff --git a/crates/sumcheck/src/mle.rs b/crates/sumcheck/src/mle.rs index 7427d7c0..1e6897df 100644 --- a/crates/sumcheck/src/mle.rs +++ b/crates/sumcheck/src/mle.rs @@ -23,7 +23,7 @@ pub enum MleGroup<'a, EF: ExtensionField>> { Ref(MleGroupRef<'a, EF>), } -impl<'a, EF: ExtensionField>> From> for MleGroup<'a, EF> { +impl>> From> for MleGroup<'_, EF> { fn from(owned: MleGroupOwned) -> Self { MleGroup::Owned(owned) } @@ -52,6 +52,7 @@ pub enum MleGroupRef<'a, EF: ExtensionField>> { } impl<'a, EF: ExtensionField>> MleGroup<'a, EF> { + #[must_use] pub fn by_ref(&'a self) -> MleGroupRef<'a, EF> { match self { Self::Owned(owned) => owned.by_ref(), @@ -59,6 +60,7 @@ impl<'a, EF: ExtensionField>> MleGroup<'a, EF> { } } + #[must_use] pub fn n_vars(&self) -> usize { match self { Self::Owned(owned) => owned.n_vars(), @@ -66,7 +68,8 @@ impl<'a, EF: ExtensionField>> MleGroup<'a, EF> { } } - pub fn n_columns(&self) -> usize { + #[must_use] + pub const fn n_columns(&self) -> usize { match self { Self::Owned(owned) => owned.n_columns(), Self::Ref(r) => r.n_columns(), @@ -75,21 +78,25 @@ impl<'a, EF: ExtensionField>> MleGroup<'a, EF> { } impl>> MleGroupOwned { - pub fn by_ref<'a>(&'a self) -> MleGroupRef<'a, EF> { + #[must_use] + pub fn by_ref(&self) -> MleGroupRef<'_, EF> { match self { - Self::Base(base) => MleGroupRef::Base(base.iter().map(|v| v.as_slice()).collect()), + Self::Base(base) => { + MleGroupRef::Base(base.iter().map(std::vec::Vec::as_slice).collect()) + } Self::Extension(ext) => { - MleGroupRef::Extension(ext.iter().map(|v| v.as_slice()).collect()) + MleGroupRef::Extension(ext.iter().map(std::vec::Vec::as_slice).collect()) } Self::BasePacked(packed_base) => { - MleGroupRef::BasePacked(packed_base.iter().map(|v| v.as_slice()).collect()) - } - Self::ExtensionPacked(ext_packed) => { - MleGroupRef::ExtensionPacked(ext_packed.iter().map(|v| v.as_slice()).collect()) + MleGroupRef::BasePacked(packed_base.iter().map(std::vec::Vec::as_slice).collect()) } + Self::ExtensionPacked(ext_packed) => MleGroupRef::ExtensionPacked( + ext_packed.iter().map(std::vec::Vec::as_slice).collect(), + ), } } + #[must_use] pub fn n_vars(&self) -> usize { match self { Self::Base(v) => log2_strict_usize(v[0].len()), @@ -99,7 +106,8 @@ impl>> MleGroupOwned { } } - pub fn n_columns(&self) -> usize { + #[must_use] + pub const fn n_columns(&self) -> usize { match self { Self::Base(v) => v.len(), Self::Extension(v) => v.len(), @@ -110,7 +118,8 @@ impl>> MleGroupOwned { } impl>> Mle { - pub fn packed_len(&self) -> usize { + #[must_use] + pub const fn packed_len(&self) -> usize { match self { Self::Base(v) => v.len(), Self::Extension(v) => v.len(), @@ -119,7 +128,8 @@ impl>> Mle { } } - pub fn unpacked_len(&self) -> usize { + #[must_use] + pub const fn unpacked_len(&self) -> usize { let mut res = self.packed_len(); if self.is_packed() { res *= packing_width::(); @@ -127,6 +137,7 @@ impl>> Mle { res } + #[must_use] pub fn n_vars(&self) -> usize { log2_strict_usize(self.unpacked_len()) } @@ -140,35 +151,40 @@ impl>> Mle { } } - pub fn is_packed(&self) -> bool { + #[must_use] + pub const fn is_packed(&self) -> bool { match self { Self::Base(_) | Self::Extension(_) => false, Self::PackedBase(_) | Self::ExtensionPacked(_) => true, } } - pub fn as_base(&self) -> Option<&Vec>> { + #[must_use] + pub const fn as_base(&self) -> Option<&Vec>> { match self { Self::Base(b) => Some(b), _ => None, } } - pub fn as_extension(&self) -> Option<&Vec> { + #[must_use] + pub const fn as_extension(&self) -> Option<&Vec> { match self { Self::Extension(e) => Some(e), _ => None, } } - pub fn as_packed_base(&self) -> Option<&Vec>> { + #[must_use] + pub const fn as_packed_base(&self) -> Option<&Vec>> { match self { Self::PackedBase(pb) => Some(pb), _ => None, } } - pub fn as_extension_packed(&self) -> Option<&Vec>> { + #[must_use] + pub const fn as_extension_packed(&self) -> Option<&Vec>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, @@ -177,7 +193,8 @@ impl>> Mle { } impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { - pub fn group_size(&self) -> usize { + #[must_use] + pub const fn group_size(&self) -> usize { match self { Self::Base(v) => v.len(), Self::Extension(v) => v.len(), @@ -186,6 +203,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } + #[must_use] pub fn n_vars(&self) -> usize { match self { Self::Base(v) => log2_strict_usize(v[0].len()), @@ -195,42 +213,48 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } - pub fn is_packed(&self) -> bool { + #[must_use] + pub const fn is_packed(&self) -> bool { match self { Self::BasePacked(_) | Self::ExtensionPacked(_) => true, Self::Base(_) | Self::Extension(_) => false, } } - pub fn as_base(&self) -> Option<&Vec<&'a [PF]>> { + #[must_use] + pub const fn as_base(&self) -> Option<&Vec<&'a [PF]>> { match self { Self::Base(b) => Some(b), _ => None, } } - pub fn as_extension(&self) -> Option<&Vec<&'a [EF]>> { + #[must_use] + pub const fn as_extension(&self) -> Option<&Vec<&'a [EF]>> { match self { Self::Extension(e) => Some(e), _ => None, } } - pub fn as_packed_base(&self) -> Option<&Vec<&'a [PFPacking]>> { + #[must_use] + pub const fn as_packed_base(&self) -> Option<&Vec<&'a [PFPacking]>> { match self { Self::BasePacked(pb) => Some(pb), _ => None, } } - pub fn as_extension_packed(&self) -> Option<&Vec<&'a [EFPacking]>> { + #[must_use] + pub const fn as_extension_packed(&self) -> Option<&Vec<&'a [EFPacking]>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, } } - pub fn n_columns(&self) -> usize { + #[must_use] + pub const fn n_columns(&self) -> usize { match self { Self::Base(v) => v.len(), Self::Extension(v) => v.len(), @@ -239,6 +263,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } + #[must_use] pub fn pack(&self) -> MleGroup<'a, EF> { match self { Self::Base(base) => MleGroupRef::BasePacked( @@ -257,6 +282,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } // Clone everything in the group, should not be used when n_vars is large + #[must_use] pub fn unpack(&self) -> MleGroupOwned { match self { Self::Base(pols) => MleGroupOwned::Base(pols.iter().map(|v| v.to_vec()).collect()), @@ -302,6 +328,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } + #[allow(clippy::too_many_lines)] pub fn sumcheck_compute( &self, zs: &[usize], @@ -354,7 +381,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } unsafe { - let sum_ptr = all_sums.as_ptr() as *mut EFPacking; + let sum_ptr = all_sums.as_ptr().cast_mut(); *sum_ptr.add(z_index * packed_fold_size + i) = res; } } @@ -409,7 +436,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } unsafe { - let sum_ptr = all_sums.as_ptr() as *mut EFPacking; + let sum_ptr = all_sums.as_ptr().cast_mut(); *sum_ptr.add(z_index * packed_fold_size + i) = res; } } @@ -462,6 +489,7 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } +#[allow(clippy::too_many_arguments)] pub fn sumcheck_compute_not_packed< EF: ExtensionField> + ExtensionField, IF: ExtensionField>, @@ -502,7 +530,7 @@ where }) .collect::>(); unsafe { - let sum_ptr = all_sums.as_ptr() as *mut EF; + let sum_ptr = all_sums.as_ptr().cast_mut(); let mut res = computation.eval(&point, batching_scalars); if let Some(eq_mle_eval) = eq_mle_eval { res *= eq_mle_eval; diff --git a/crates/sumcheck/src/prove.rs b/crates/sumcheck/src/prove.rs index a411aeb1..3ff77726 100644 --- a/crates/sumcheck/src/prove.rs +++ b/crates/sumcheck/src/prove.rs @@ -1,18 +1,14 @@ -use p3_field::ExtensionField; -use p3_field::PrimeCharacteristicRing; +use p3_field::{ExtensionField, PrimeCharacteristicRing}; use rayon::prelude::*; -use utils::pack_extension; -use utils::packing_log_width; -use utils::unpack_extension; -use utils::{FSProver, PF, univariate_selectors}; -use whir_p3::fiat_shamir::FSChallenger; -use whir_p3::poly::dense::WhirDensePolynomial; -use whir_p3::poly::evals::eval_eq; -use whir_p3::poly::multilinear::MultilinearPoint; +use utils::{ + FSProver, PF, pack_extension, packing_log_width, univariate_selectors, unpack_extension, +}; +use whir_p3::{ + fiat_shamir::FSChallenger, + poly::{dense::WhirDensePolynomial, evals::eval_eq, multilinear::MultilinearPoint}, +}; -use crate::MleGroup; -use crate::SumcheckComputation; -use crate::{Mle, MySumcheckComputation}; +use crate::{Mle, MleGroup, MySumcheckComputation, SumcheckComputation}; #[allow(clippy::too_many_arguments)] pub fn prove<'a, EF, SC>( @@ -33,8 +29,8 @@ where let (challenges, mut final_folds, mut sum) = prove_in_parallel_1( vec![skips], vec![multilinears], - vec![computation], - vec![batching_scalars], + &[computation], + &[batching_scalars], vec![eq_factor], vec![is_zerofier], prover_state, @@ -46,12 +42,12 @@ where (challenges, final_folds.pop().unwrap(), sum.pop().unwrap()) } -#[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_arguments, clippy::type_complexity)] pub fn prove_in_parallel_1<'a, EF, SC, M: Into>>( skips: Vec, // skips == 1: classic sumcheck. skips >= 2: sumcheck with univariate skips (eprint 2024/108) multilinears: Vec, - computations: Vec<&SC>, - batching_scalars: Vec<&[EF]>, + computations: &[&SC], + batching_scalars: &[&[EF]], eq_factors: Vec, Option>)>>, // (a, b, c ...), eq_poly(b, c, ...) is_zerofier: Vec, prover_state: &mut FSProver>, @@ -67,8 +63,8 @@ where skips, multilinears, computations, - vec![], - vec![], + &[], + &[], batching_scalars, eq_factors, is_zerofier, @@ -83,10 +79,10 @@ where pub fn prove_in_parallel_3<'a, EF, SC1, SC2, SC3, M: Into>>( mut skips: Vec, // skips == 1: classic sumcheck. skips >= 2: sumcheck with univariate skips (eprint 2024/108) multilinears: Vec, - computations_1: Vec<&SC1>, - computations_2: Vec<&SC2>, - computations_3: Vec<&SC3>, - batching_scalars: Vec<&[EF]>, + computations_1: &[&SC1], + computations_2: &[&SC2], + computations_3: &[&SC3], + batching_scalars: &[&[EF]], mut eq_factors: Vec, Option>)>>, // (a, b, c ...), eq_poly(b, c, ...) mut is_zerofier: Vec, prover_state: &mut FSProver>, @@ -153,14 +149,12 @@ where }; // If Packing is enabled, and there are too little variables, we unpack everything: for &i in &concerned_sumchecks { - if multilinears[i].by_ref().is_packed() { - if n_vars[i] <= 1 + packing_log_width::() { - // unpack - multilinears[i] = multilinears[i].by_ref().unpack().into(); - if let Some((_, eq_mle)) = &mut eq_factors[i] { - *eq_mle = - Mle::Extension(unpack_extension(eq_mle.as_extension_packed().unwrap())); - } + if multilinears[i].by_ref().is_packed() && n_vars[i] <= 1 + packing_log_width::() { + // unpack + multilinears[i] = multilinears[i].by_ref().unpack().into(); + if let Some((_, eq_mle)) = &mut eq_factors[i] { + *eq_mle = + Mle::Extension(unpack_extension(eq_mle.as_extension_packed().unwrap())); } } } @@ -172,7 +166,7 @@ where skips[i], &multilinears[i], computations_1[i], - &eq_factors[i], + eq_factors[i].as_ref(), batching_scalars[i], is_zerofier[i], prover_state, @@ -184,7 +178,7 @@ where skips[i], &multilinears[i], computations_2[i - computations_1.len()], - &eq_factors[i], + eq_factors[i].as_ref(), batching_scalars[i], is_zerofier[i], prover_state, @@ -196,7 +190,7 @@ where skips[i], &multilinears[i], computations_3[i - computations_1.len() - computations_2.len()], - &eq_factors[i], + eq_factors[i].as_ref(), batching_scalars[i], is_zerofier[i], prover_state, @@ -231,7 +225,7 @@ where .by_ref() .as_extension() .unwrap() - .into_iter() + .iter() .map(|m| { assert_eq!(m.len(), 1); m[0] @@ -244,11 +238,11 @@ where } #[allow(clippy::too_many_arguments)] -fn compute_and_send_polynomial<'a, EF, SC>( +fn compute_and_send_polynomial( skips: usize, // the first round will fold 2^skips (instead of 2 in the basic sumcheck) - multilinears: &MleGroup<'a, EF>, + multilinears: &MleGroup<'_, EF>, computation: &SC, - eq_factor: &Option<(Vec, Mle)>, // (a, b, c ...), eq_poly(b, c, ...) + eq_factor: Option<&(Vec, Mle)>, // (a, b, c ...), eq_poly(b, c, ...) batching_scalars: &[EF], is_zerofier: bool, prover_state: &mut FSProver>, @@ -336,9 +330,9 @@ where } #[allow(clippy::too_many_arguments)] -fn on_challenge_received<'a, EF>( +fn on_challenge_received( skips: usize, // the first round will fold 2^skips (instead of 2 in the basic sumcheck) - multilinears: &mut MleGroup<'a, EF>, + multilinears: &mut MleGroup<'_, EF>, n_vars: &mut usize, eq_factor: &mut Option<(Vec, Mle)>, // (a, b, c ...), eq_poly(b, c, ...) sum: &mut EF, @@ -373,5 +367,5 @@ fn on_challenge_received<'a, EF>( *multilinears = multilinears .by_ref() .fold_in_large_field(&folding_scalars) - .into() + .into(); } diff --git a/crates/sumcheck/src/sc_computation.rs b/crates/sumcheck/src/sc_computation.rs index 57c876e0..7705a92a 100644 --- a/crates/sumcheck/src/sc_computation.rs +++ b/crates/sumcheck/src/sc_computation.rs @@ -1,7 +1,7 @@ use std::any::TypeId; use p3_air::Air; -use p3_field::{ExtensionField, Field}; +use p3_field::ExtensionField; use p3_matrix::dense::RowMajorMatrixView; use utils::{ ConstraintFolder, ConstraintFolderPackedBase, ConstraintFolderPackedExtension, EFPacking, PF, @@ -62,7 +62,7 @@ where fn degree(&self) -> usize; } -impl SumcheckComputationPacked for A +impl SumcheckComputationPacked for A where EF: ExtensionField>, A: for<'a> Air> @@ -76,7 +76,7 @@ where } let mut folder = ConstraintFolderPackedBase { main: RowMajorMatrixView::new(point, A::width(self)), - alpha_powers: alpha_powers, + alpha_powers, accumulator: Default::default(), constraint_index: 0, }; @@ -93,7 +93,7 @@ where } let mut folder = ConstraintFolderPackedExtension { main: RowMajorMatrixView::new(point, A::width(self)), - alpha_powers: alpha_powers, + alpha_powers, accumulator: Default::default(), constraint_index: 0, }; diff --git a/crates/sumcheck/src/verify.rs b/crates/sumcheck/src/verify.rs index 5f207743..fb02f64a 100644 --- a/crates/sumcheck/src/verify.rs +++ b/crates/sumcheck/src/verify.rs @@ -100,8 +100,8 @@ where verify_core_in_parallel( verifier_state, - max_degree_per_vars, - sumation_sets, + &max_degree_per_vars, + &sumation_sets, share_initial_challenges, ) } @@ -116,8 +116,8 @@ where { let (sum, challenge_point, challenge_value) = verify_core_in_parallel( verifier_state, - vec![max_degree_per_vars], - vec![sumation_sets], + &[max_degree_per_vars], + &[sumation_sets], true, )?; Ok(( @@ -131,8 +131,8 @@ where fn verify_core_in_parallel( verifier_state: &mut FSVerifier>, - max_degree_per_vars: Vec>, - sumation_sets: Vec>>, + max_degree_per_vars: &[Vec], + sumation_sets: &[Vec>], share_initial_challenges: bool, // otherwise, share the final challenges ) -> Result<(Vec, MultilinearPoint, Vec), ProofError> where @@ -150,7 +150,7 @@ where let n_vars = max_degree_per_vars .iter() - .map(|v| v.len()) + .map(std::vec::Vec::len) .collect::>(); let max_n_vars = Iterator::max(n_vars.iter().copied()).unwrap(); diff --git a/crates/utils/src/display.rs b/crates/utils/src/display.rs index b5faf93f..48df39b7 100644 --- a/crates/utils/src/display.rs +++ b/crates/utils/src/display.rs @@ -1,3 +1,4 @@ +#[must_use] pub fn pretty_integer(i: usize) -> String { // ex: 123456789 -> "123,456,789" let s = i.to_string(); @@ -5,7 +6,7 @@ pub fn pretty_integer(i: usize) -> String { let mut result = String::new(); for (index, ch) in chars.iter().enumerate() { - if index > 0 && (chars.len() - index) % 3 == 0 { + if index > 0 && (chars.len() - index).is_multiple_of(3) { result.push(','); } result.push(*ch); diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index a76d5ce7..3525bad4 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -6,46 +6,48 @@ use rayon::prelude::*; use crate::PF; pub fn transmute_slice(slice: &[Before]) -> &[After] { - let new_len = slice.len() * std::mem::size_of::() / std::mem::size_of::(); + let new_len = std::mem::size_of_val(slice) / std::mem::size_of::(); assert_eq!( - slice.len() * std::mem::size_of::(), + std::mem::size_of_val(slice), new_len * std::mem::size_of::() ); assert_eq!(slice.as_ptr() as usize % std::mem::align_of::(), 0); - unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const After, new_len) } + unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::(), new_len) } } -pub fn shift_range(range: Range, shift: usize) -> Range { +#[must_use] +pub const fn shift_range(range: Range, shift: usize) -> Range { Range { start: range.start + shift, end: range.end + shift, } } -pub fn diff_to_next_power_of_two(n: usize) -> usize { +#[must_use] +pub const fn diff_to_next_power_of_two(n: usize) -> usize { n.next_power_of_two() - n } pub fn left_mut(slice: &mut [A]) -> &mut [A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &mut slice[..mid] } pub fn right_mut(slice: &mut [A]) -> &mut [A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &mut slice[mid..] } pub fn left_ref(slice: &[A]) -> &[A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &slice[..mid] } pub fn right_ref(slice: &[A]) -> &[A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &slice[mid..] } @@ -62,7 +64,10 @@ pub fn remove_end(slice: &[A], n: usize) -> &[A] { } pub fn field_slice_as_base>(slice: &[EF]) -> Option> { - slice.par_iter().map(|x| x.as_base()).collect() + slice + .par_iter() + .map(p3_field::ExtensionField::as_base) + .collect() } pub fn dot_product_with_base>>(slice: &[EF]) -> EF { @@ -72,6 +77,7 @@ pub fn dot_product_with_base>>(slice: &[EF]) -> EF { .sum::() } +#[must_use] pub fn to_big_endian_bits(value: usize, bit_count: usize) -> Vec { (0..bit_count) .rev() diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index b8a08cee..a8f63a81 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -1,7 +1,6 @@ use std::borrow::Borrow; -use p3_field::{BasedVectorSpace, PackedValue}; -use p3_field::{ExtensionField, Field, dot_product}; +use p3_field::{BasedVectorSpace, ExtensionField, Field, PackedValue, dot_product}; use rayon::prelude::*; use tracing::instrument; use whir_p3::poly::evals::EvaluationsList; @@ -19,7 +18,7 @@ pub fn fold_multilinear_in_small_field, D>( let dim = >::DIMENSION; let m_transmuted: &[F] = - unsafe { std::slice::from_raw_parts(std::mem::transmute(m.as_ptr()), m.len() * dim) }; + unsafe { std::slice::from_raw_parts(m.as_ptr().cast::(), m.len() * dim) }; let res_transmuted = { let new_size = m.len() * dim / scalars.len(); diff --git a/crates/utils/src/packed_constraints_folder.rs b/crates/utils/src/packed_constraints_folder.rs index dc18b228..b83c339d 100644 --- a/crates/utils/src/packed_constraints_folder.rs +++ b/crates/utils/src/packed_constraints_folder.rs @@ -1,10 +1,9 @@ -use crate::EFPacking; -use crate::PF; -use crate::PFPacking; use p3_air::AirBuilder; use p3_field::ExtensionField; use p3_matrix::dense::RowMajorMatrixView; +use crate::{EFPacking, PF, PFPacking}; + #[derive(Debug)] pub struct ConstraintFolderPackedBase<'a, EF: ExtensionField>> { pub main: RowMajorMatrixView<'a, PFPacking>, diff --git a/crates/utils/src/point.rs b/crates/utils/src/point.rs index b9a4d1b2..b6b62b5e 100644 --- a/crates/utils/src/point.rs +++ b/crates/utils/src/point.rs @@ -8,9 +8,9 @@ pub struct Evaluation { pub value: F, } -impl Into<(MultilinearPoint, F)> for Evaluation { - fn into(self) -> (MultilinearPoint, F) { - (self.point, self.value) +impl From> for (MultilinearPoint, F) { + fn from(val: Evaluation) -> Self { + (val.point, val.value) } } diff --git a/crates/utils/src/poseidon2.rs b/crates/utils/src/poseidon2.rs index 9c8f193f..66125fe5 100644 --- a/crates/utils/src/poseidon2.rs +++ b/crates/utils/src/poseidon2.rs @@ -1,13 +1,8 @@ -use p3_koala_bear::GenericPoseidon2LinearLayersKoalaBear; -use p3_koala_bear::KoalaBear; -use p3_koala_bear::Poseidon2KoalaBear; +use p3_koala_bear::{GenericPoseidon2LinearLayersKoalaBear, KoalaBear, Poseidon2KoalaBear}; use p3_matrix::dense::RowMajorMatrix; use p3_poseidon2::ExternalLayerConstants; -use p3_poseidon2_air::Poseidon2Air; -use p3_poseidon2_air::RoundConstants; -use p3_poseidon2_air::generate_trace_rows; -use rand::SeedableRng; -use rand::rngs::StdRng; +use p3_poseidon2_air::{Poseidon2Air, RoundConstants, generate_trace_rows}; +use rand::{SeedableRng, rngs::StdRng}; pub type Poseidon16 = Poseidon2KoalaBear<16>; pub type Poseidon24 = Poseidon2KoalaBear<24>; @@ -43,6 +38,7 @@ pub type Poseidon24Air = Poseidon2Air< PARTIAL_ROUNDS_24, >; +#[must_use] pub fn build_poseidon16() -> Poseidon16 { let round_constants = build_poseidon16_constants(); let external_constants = ExternalLayerConstants::new( @@ -55,6 +51,7 @@ pub fn build_poseidon16() -> Poseidon16 { ) } +#[must_use] pub fn build_poseidon24() -> Poseidon24 { let round_constants = build_poseidon24_constants(); let external_constants = ExternalLayerConstants::new( @@ -67,10 +64,12 @@ pub fn build_poseidon24() -> Poseidon24 { ) } +#[must_use] pub fn build_poseidon_16_air() -> Poseidon16Air { Poseidon16Air::new(build_poseidon16_constants()) } +#[must_use] pub fn build_poseidon_24_air() -> Poseidon24Air { Poseidon24Air::new(build_poseidon24_constants()) } @@ -89,6 +88,7 @@ fn build_poseidon24_constants() ) } +#[must_use] pub fn generate_trace_poseidon_16(inputs: Vec<[KoalaBear; 16]>) -> RowMajorMatrix { generate_trace_rows::< KoalaBear, @@ -101,6 +101,7 @@ pub fn generate_trace_poseidon_16(inputs: Vec<[KoalaBear; 16]>) -> RowMajorMatri >(inputs, &build_poseidon16_constants(), 0) } +#[must_use] pub fn generate_trace_poseidon_24(inputs: Vec<[KoalaBear; 24]>) -> RowMajorMatrix { generate_trace_rows::< KoalaBear, diff --git a/crates/utils/src/univariate.rs b/crates/utils/src/univariate.rs index 9a6f95d8..1dd7fb7b 100644 --- a/crates/utils/src/univariate.rs +++ b/crates/utils/src/univariate.rs @@ -1,17 +1,21 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::{Arc, Mutex, OnceLock}, +}; + use p3_field::Field; use rayon::prelude::*; use whir_p3::poly::dense::WhirDensePolynomial; -use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::sync::{Arc, Mutex, OnceLock}; - type CacheKey = (TypeId, usize); +#[allow(clippy::type_complexity)] static SELECTORS_CACHE: OnceLock< Mutex>>>>, > = OnceLock::new(); +#[allow(clippy::significant_drop_tightening)] pub fn univariate_selectors(n: usize) -> Arc>> { let key = (TypeId::of::(), n); let mut map = SELECTORS_CACHE diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index 6320f5f2..2be7d40f 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -1,24 +1,19 @@ use p3_challenger::DuplexChallenger; -use p3_field::BasedVectorSpace; -use p3_field::ExtensionField; -use p3_field::PackedFieldExtension; -use p3_field::PackedValue; -use p3_field::PrimeField64; -use p3_field::{Field, PrimeCharacteristicRing}; +use p3_field::{ + BasedVectorSpace, ExtensionField, Field, PackedFieldExtension, PackedValue, + PrimeCharacteristicRing, PrimeField64, +}; use p3_koala_bear::KoalaBear; -use p3_symmetric::CryptographicHasher; -use p3_symmetric::PaddingFreeSponge; -use p3_symmetric::PseudoCompressionFunction; -use p3_symmetric::TruncatedPermutation; - +use p3_symmetric::{ + CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation, +}; use rayon::prelude::*; -use whir_p3::fiat_shamir::{prover::ProverState, verifier::VerifierState}; -use whir_p3::whir::config::WhirConfigBuilder; +use whir_p3::{ + fiat_shamir::{prover::ProverState, verifier::VerifierState}, + whir::config::WhirConfigBuilder, +}; -use crate::Poseidon16; -use crate::Poseidon24; -use crate::build_poseidon16; -use crate::build_poseidon24; +use crate::{Poseidon16, Poseidon24, build_poseidon16, build_poseidon24}; pub type PF = ::PrimeSubfield; pub type PFPacking = as Field>::Packing; @@ -74,7 +69,7 @@ pub fn pack_extension>>(slice: &[EF]) -> Vec>>(vec: &[EFPacking]) -> Vec { - vec.into_iter() + vec.iter() .flat_map(|x| { let packed_coeffs = x.as_basis_coefficients_slice(); (0..packing_width::()) @@ -84,31 +79,38 @@ pub fn unpack_extension>>(vec: &[EFPacking]) -> Ve .collect() } +#[must_use] pub const fn packing_log_width() -> usize { packing_width::().ilog2() as usize } +#[must_use] pub const fn packing_width() -> usize { PFPacking::::WIDTH } +#[must_use] pub fn build_challenger() -> MyChallenger { MyChallenger::new(build_poseidon16()) } +#[must_use] pub fn build_merkle_hash() -> MyMerkleHash { MyMerkleHash::new(build_poseidon24()) } +#[must_use] pub fn build_merkle_compress() -> MyMerkleCompress { MyMerkleCompress::new(build_poseidon16()) } +#[must_use] pub fn build_prover_state>() -> ProverState { ProverState::new(build_challenger()) } +#[must_use] pub fn build_verifier_state>( prover_state: &ProverState, ) -> VerifierState { diff --git a/crates/vm/src/bytecode.rs b/crates/vm/src/bytecode.rs index 5fa17c19..786645aa 100644 --- a/crates/vm/src/bytecode.rs +++ b/crates/vm/src/bytecode.rs @@ -1,7 +1,9 @@ -use crate::F; -use p3_field::PrimeCharacteristicRing; use std::collections::BTreeMap; +use p3_field::PrimeCharacteristicRing; + +use crate::F; + pub type Label = String; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -81,17 +83,19 @@ pub enum Instruction { } impl Operation { + #[must_use] pub fn compute(&self, a: F, b: F) -> F { match self { - Operation::Add => a + b, - Operation::Mul => a * b, + Self::Add => a + b, + Self::Mul => a * b, } } + #[must_use] pub fn inverse_compute(&self, a: F, b: F) -> Option { match self { - Operation::Add => Some(a - b), - Operation::Mul => { + Self::Add => Some(a - b), + Self::Mul => { if b == F::ZERO { None } else { @@ -124,35 +128,35 @@ pub enum Hint { } impl MemOrConstant { - pub fn zero() -> Self { - MemOrConstant::Constant(F::ZERO) + #[must_use] + pub const fn zero() -> Self { + Self::Constant(F::ZERO) } - pub fn one() -> Self { - MemOrConstant::Constant(F::ONE) + #[must_use] + pub const fn one() -> Self { + Self::Constant(F::ONE) } } impl ToString for Bytecode { fn to_string(&self) -> String { - let mut pc = 0; let mut res = String::new(); - for instruction in &self.instructions { + for (pc, instruction) in self.instructions.iter().enumerate() { for hint in self.hints.get(&pc).unwrap_or(&Vec::new()) { res.push_str(&format!("hint: {}\n", hint.to_string())); } res.push_str(&format!("{:>4}: {}\n", pc, instruction.to_string())); - pc += 1; } - return res; + res } } impl ToString for MemOrConstant { fn to_string(&self) -> String { match self { - Self::Constant(c) => format!("{}", c), - Self::MemoryAfterFp { offset } => format!("m[fp + {}]", offset), + Self::Constant(c) => format!("{c}"), + Self::MemoryAfterFp { offset } => format!("m[fp + {offset}]"), } } } @@ -160,7 +164,7 @@ impl ToString for MemOrConstant { impl ToString for MemOrFp { fn to_string(&self) -> String { match self { - Self::MemoryAfterFp { offset } => format!("m[fp + {}]", offset), + Self::MemoryAfterFp { offset } => format!("m[fp + {offset}]"), Self::Fp => "fp".to_string(), } } @@ -169,9 +173,9 @@ impl ToString for MemOrFp { impl ToString for MemOrFpOrConstant { fn to_string(&self) -> String { match self { - Self::MemoryAfterFp { offset } => format!("m[fp + {}]", offset), + Self::MemoryAfterFp { offset } => format!("m[fp + {offset}]"), Self::Fp => "fp".to_string(), - Self::Constant(c) => format!("{}", c), + Self::Constant(c) => format!("{c}"), } } } @@ -299,7 +303,7 @@ impl ToString for Hint { "print({}) for \"{}\"", content .iter() - .map(|v| v.to_string()) + .map(std::string::ToString::to_string) .collect::>() .join(", "), line_info, diff --git a/crates/vm/src/runner.rs b/crates/vm/src/runner.rs index 20010e3c..1d4a2f7b 100644 --- a/crates/vm/src/runner.rs +++ b/crates/vm/src/runner.rs @@ -4,7 +4,11 @@ use rayon::prelude::*; use utils::{ToUsize, build_poseidon16, build_poseidon24, pretty_integer}; use whir_p3::poly::{evals::EvaluationsList, multilinear::MultilinearPoint}; -use crate::{bytecode::*, *}; +use crate::{ + DIMENSION, EF, F, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, PUBLIC_INPUT_START, + ZERO_VEC_PTR, + bytecode::{Bytecode, Hint, Instruction, MemOrConstant, MemOrFp, MemOrFpOrConstant}, +}; const MAX_MEMORY_SIZE: usize = 1 << 23; @@ -22,15 +26,15 @@ pub enum RunnerError { impl ToString for RunnerError { fn to_string(&self) -> String { match self { - RunnerError::OutOfMemory => "Out of memory".to_string(), - RunnerError::MemoryAlreadySet => "Memory already set".to_string(), - RunnerError::NotAPointer => "Not a pointer".to_string(), - RunnerError::DivByZero => "Division by zero".to_string(), - RunnerError::NotEqual(expected, actual) => { - format!("Computation Invalid: {} != {}", expected, actual) + Self::OutOfMemory => "Out of memory".to_string(), + Self::MemoryAlreadySet => "Memory already set".to_string(), + Self::NotAPointer => "Not a pointer".to_string(), + Self::DivByZero => "Division by zero".to_string(), + Self::NotEqual(expected, actual) => { + format!("Computation Invalid: {expected} != {actual}") } - RunnerError::UndefinedMemory => "Undefined memory access".to_string(), - RunnerError::PCOutOfBounds => "Program counter out of bounds".to_string(), + Self::UndefinedMemory => "Undefined memory access".to_string(), + Self::PCOutOfBounds => "Program counter out of bounds".to_string(), } } } @@ -41,19 +45,20 @@ pub struct Memory(pub Vec>); impl MemOrConstant { pub fn read_value(&self, memory: &Memory, fp: usize) -> Result { match self { - MemOrConstant::Constant(c) => Ok(*c), - MemOrConstant::MemoryAfterFp { offset } => memory.get(fp + *offset), + Self::Constant(c) => Ok(*c), + Self::MemoryAfterFp { offset } => memory.get(fp + *offset), } } + #[must_use] pub fn is_value_unknown(&self, memory: &Memory, fp: usize) -> bool { self.read_value(memory, fp).is_err() } - pub fn memory_address(&self, fp: usize) -> Result { + pub const fn memory_address(&self, fp: usize) -> Result { match self { - MemOrConstant::Constant(_) => Err(RunnerError::NotAPointer), - MemOrConstant::MemoryAfterFp { offset } => Ok(fp + *offset), + Self::Constant(_) => Err(RunnerError::NotAPointer), + Self::MemoryAfterFp { offset } => Ok(fp + *offset), } } } @@ -61,19 +66,20 @@ impl MemOrConstant { impl MemOrFp { pub fn read_value(&self, memory: &Memory, fp: usize) -> Result { match self { - MemOrFp::MemoryAfterFp { offset } => memory.get(fp + *offset), - MemOrFp::Fp => Ok(F::from_usize(fp)), + Self::MemoryAfterFp { offset } => memory.get(fp + *offset), + Self::Fp => Ok(F::from_usize(fp)), } } + #[must_use] pub fn is_value_unknown(&self, memory: &Memory, fp: usize) -> bool { self.read_value(memory, fp).is_err() } - pub fn memory_address(&self, fp: usize) -> Result { + pub const fn memory_address(&self, fp: usize) -> Result { match self { - MemOrFp::MemoryAfterFp { offset } => Ok(fp + *offset), - MemOrFp::Fp => Err(RunnerError::NotAPointer), + Self::MemoryAfterFp { offset } => Ok(fp + *offset), + Self::Fp => Err(RunnerError::NotAPointer), } } } @@ -81,21 +87,21 @@ impl MemOrFp { impl MemOrFpOrConstant { pub fn read_value(&self, memory: &Memory, fp: usize) -> Result { match self { - MemOrFpOrConstant::MemoryAfterFp { offset } => memory.get(fp + *offset), - MemOrFpOrConstant::Fp => Ok(F::from_usize(fp)), - MemOrFpOrConstant::Constant(c) => Ok(*c), + Self::MemoryAfterFp { offset } => memory.get(fp + *offset), + Self::Fp => Ok(F::from_usize(fp)), + Self::Constant(c) => Ok(*c), } } + #[must_use] pub fn is_value_unknown(&self, memory: &Memory, fp: usize) -> bool { self.read_value(memory, fp).is_err() } - pub fn memory_address(&self, fp: usize) -> Result { + pub const fn memory_address(&self, fp: usize) -> Result { match self { - MemOrFpOrConstant::MemoryAfterFp { offset } => Ok(fp + *offset), - MemOrFpOrConstant::Fp => Err(RunnerError::NotAPointer), - MemOrFpOrConstant::Constant(_) => Err(RunnerError::NotAPointer), + Self::MemoryAfterFp { offset } => Ok(fp + *offset), + Self::Fp | Self::Constant(_) => Err(RunnerError::NotAPointer), } } } @@ -174,7 +180,7 @@ impl Memory { } pub fn set_vectorized_slice(&mut self, index: usize, value: &[F]) -> Result<(), RunnerError> { - assert!(value.len() % DIMENSION == 0); + assert!(value.len().is_multiple_of(DIMENSION)); for (i, v) in value.iter().enumerate() { let idx = DIMENSION * index + i; self.set(idx, *v)?; @@ -183,6 +189,7 @@ impl Memory { } } +#[must_use] pub fn execute_bytecode( bytecode: &Bytecode, public_input: &[F], @@ -200,7 +207,7 @@ pub fn execute_bytecode( Ok(first_exec) => first_exec, Err(err) => { if !std_out.is_empty() { - print!("{}", std_out); + print!("{std_out}"); } panic!("Error during bytecode execution: {}", err.to_string()); } @@ -225,13 +232,14 @@ pub struct ExecutionResult { pub fps: Vec, } +#[must_use] pub fn build_public_memory(public_input: &[F]) -> Vec { // padded to a power of two let public_memory_len = (PUBLIC_INPUT_START + public_input.len()).next_power_of_two(); let mut public_memory = F::zero_vec(public_memory_len); public_memory[PUBLIC_INPUT_START..][..public_input.len()].copy_from_slice(public_input); - for i in ZERO_VEC_PTR * 8..(ZERO_VEC_PTR + 2) * 8 { - public_memory[i] = F::ZERO; // zero vector + for pm in public_memory.iter_mut().take((ZERO_VEC_PTR + 2) * 8) { + *pm = F::ZERO; // zero vector } public_memory[POSEIDON_16_NULL_HASH_PTR * 8..(POSEIDON_16_NULL_HASH_PTR + 2) * 8] .copy_from_slice(&build_poseidon16().permute([F::ZERO; 16])); @@ -340,7 +348,7 @@ fn execute_bytecode_helper( // Logs for performance analysis: if values[0] == "123456789" { if values.len() == 1 { - *std_out += &format!("[CHECKPOINT]\n"); + *std_out += "[CHECKPOINT]\n"; } else { assert_eq!(values.len(), 2); let new_no_vec_memory = ap - checkpoint_ap; @@ -361,7 +369,7 @@ fn execute_bytecode_helper( continue; } - let line_info = line_info.replace(";", ""); + let line_info = line_info.replace(';', ""); *std_out += &format!("\"{}\" -> {}\n", line_info, values.join(", ")); // does not increase PC } @@ -434,11 +442,11 @@ fn execute_bytecode_helper( } => { let condition_value = condition.read_value(&memory, fp)?; assert!([F::ZERO, F::ONE].contains(&condition_value),); - if condition_value != F::ZERO { + if condition_value == F::ZERO { + pc += 1; + } else { pc = dest.read_value(&memory, fp)?.to_usize(); fp = updated_fp.read_value(&memory, fp)?.to_usize(); - } else { - pc += 1; } } Instruction::Poseidon2_16 { arg_a, arg_b, res } => { @@ -550,7 +558,7 @@ fn execute_bytecode_helper( if final_execution { if !std_out.is_empty() { - print!("{}", std_out); + print!("{std_out}"); } let runtime_memory_size = memory.0.len() - (PUBLIC_INPUT_START + public_input.len()); println!( diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index b01e465e..d95b95b7 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -39,6 +39,7 @@ impl WotsSecretKey { Self::new(pre_images, poseidon16) } + #[must_use] pub fn new(pre_images: [Digest; N_CHAINS], poseidon16: &Poseidon16) -> Self { let mut public_key = [Default::default(); N_CHAINS]; for i in 0..N_CHAINS { @@ -50,10 +51,12 @@ impl WotsSecretKey { } } - pub fn public_key(&self) -> &WotsPublicKey { + #[must_use] + pub const fn public_key(&self) -> &WotsPublicKey { &self.public_key } + #[must_use] pub fn sign(&self, message: &Message, poseidon16: &Poseidon16) -> WotsSignature { let mut signature = [Default::default(); N_CHAINS]; for i in 0..N_CHAINS { @@ -68,10 +71,11 @@ impl WotsSecretKey { } impl WotsSignature { + #[must_use] pub fn recover_public_key( &self, message: &Message, - signature: &WotsSignature, + signature: &Self, poseidon16: &Poseidon16, ) -> WotsPublicKey { let mut public_key = [Default::default(); N_CHAINS]; @@ -91,8 +95,10 @@ impl WotsSignature { } impl WotsPublicKey { + #[must_use] + #[allow(clippy::assertions_on_constants)] pub fn hash(&self, poseidon24: &Poseidon24) -> Digest { - assert!(N_CHAINS % 2 == 0, "TODO"); + assert!(N_CHAINS.is_multiple_of(2), "TODO"); let mut digest = Default::default(); for (a, b) in self.0.chunks(2).map(|chunk| (chunk[0], chunk[1])) { digest = poseidon24.permute([a, b, digest].concat().try_into().unwrap())[16..24] @@ -121,6 +127,7 @@ pub struct XmssPublicKey { } impl XmssSecretKey { + #[allow(clippy::tuple_array_conversions)] pub fn random(rng: &mut R, poseidon16: &Poseidon16, poseidon24: &Poseidon24) -> Self { let mut wots_secret_keys = Vec::new(); for _ in 0..1 << XMSS_MERKLE_HEIGHT { @@ -149,6 +156,7 @@ impl XmssSecretKey { } } + #[must_use] pub fn sign(&self, message: &Message, index: usize, poseidon16: &Poseidon16) -> XmssSignature { assert!( index < (1 << XMSS_MERKLE_HEIGHT), @@ -158,7 +166,7 @@ impl XmssSecretKey { let mut merkle_proof = Vec::new(); let mut current_index = index; for level in 0..XMSS_MERKLE_HEIGHT { - let is_left = current_index % 2 == 0; + let is_left = current_index.is_multiple_of(2); let neighbour_index = if is_left { current_index + 1 } else { @@ -174,6 +182,7 @@ impl XmssSecretKey { } } + #[must_use] pub fn public_key(&self) -> XmssPublicKey { XmssPublicKey { root: self.merkle_tree.last().unwrap()[0], @@ -182,6 +191,7 @@ impl XmssSecretKey { } impl XmssPublicKey { + #[must_use] pub fn verify( &self, message: &Message, diff --git a/crates/zk_vm/src/air.rs b/crates/zk_vm/src/air.rs index 02a8e89c..75d55a1e 100644 --- a/crates/zk_vm/src/air.rs +++ b/crates/zk_vm/src/air.rs @@ -4,7 +4,13 @@ use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::PrimeCharacteristicRing; use p3_matrix::Matrix; -use crate::*; +use crate::{ + COL_INDEX_ADD, COL_INDEX_AUX, COL_INDEX_DEREF, COL_INDEX_FLAG_A, COL_INDEX_FLAG_B, + COL_INDEX_FLAG_C, COL_INDEX_FP, COL_INDEX_JUZ, COL_INDEX_MEM_ADDRESS_A, + COL_INDEX_MEM_ADDRESS_B, COL_INDEX_MEM_ADDRESS_C, COL_INDEX_MEM_VALUE_A, COL_INDEX_MEM_VALUE_B, + COL_INDEX_MEM_VALUE_C, COL_INDEX_MUL, COL_INDEX_OPERAND_A, COL_INDEX_OPERAND_B, + COL_INDEX_OPERAND_C, COL_INDEX_PC, InAirColumnIndex, N_EXEC_AIR_COLUMNS, +}; /* @@ -35,7 +41,7 @@ Execution columns: */ -pub(crate) struct VMAir; +pub struct VMAir; impl BaseAir for VMAir { fn width(&self) -> usize { @@ -97,29 +103,21 @@ impl Air for VMAir { let nu_a = flag_a.clone() * operand_a.clone() + value_a.clone() * (AB::Expr::ONE - flag_a.clone()); - let nu_b = - flag_b.clone() * operand_b.clone() + value_b.clone() * (AB::Expr::ONE - flag_b.clone()); + let nu_b = flag_b.clone() * operand_b.clone() + value_b * (AB::Expr::ONE - flag_b.clone()); let nu_c = flag_c.clone() * fp.clone() + value_c.clone() * (AB::Expr::ONE - flag_c.clone()); + builder.assert_zero((AB::Expr::ONE - flag_a) * (addr_a - (fp.clone() + operand_a))); + builder.assert_zero((AB::Expr::ONE - flag_b) * (addr_b - (fp.clone() + operand_b))); builder.assert_zero( - (AB::Expr::ONE - flag_a.clone()) * (addr_a.clone() - (fp.clone() + operand_a.clone())), - ); - builder.assert_zero( - (AB::Expr::ONE - flag_b.clone()) * (addr_b.clone() - (fp.clone() + operand_b.clone())), - ); - builder.assert_zero( - (AB::Expr::ONE - flag_c.clone()) * (addr_c.clone() - (fp.clone() + operand_c.clone())), + (AB::Expr::ONE - flag_c) * (addr_c.clone() - (fp.clone() + operand_c.clone())), ); - builder.assert_zero(add.clone() * (nu_b.clone() - (nu_a.clone() + nu_c.clone()))); - builder.assert_zero(mul.clone() * (nu_b.clone() - nu_a.clone() * nu_c.clone())); + builder.assert_zero(add * (nu_b.clone() - (nu_a.clone() + nu_c.clone()))); + builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone())); - builder - .assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); + builder.assert_zero(deref.clone() * (addr_c - (value_a + operand_c))); builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone())); - builder.assert_zero( - deref.clone() * (AB::Expr::ONE - aux.clone()) * (value_c.clone() - fp.clone()), - ); + builder.assert_zero(deref * (AB::Expr::ONE - aux) * (value_c - fp.clone())); builder.assert_zero( (AB::Expr::ONE - juz.clone()) * (next_pc.clone() - (pc.clone() + AB::Expr::ONE)), @@ -127,15 +125,11 @@ impl Air for VMAir { builder.assert_zero((AB::Expr::ONE - juz.clone()) * (next_fp.clone() - fp.clone())); builder.assert_zero(juz.clone() * nu_a.clone() * (AB::Expr::ONE - nu_a.clone())); - builder.assert_zero(juz.clone() * nu_a.clone() * (next_pc.clone() - nu_b.clone())); - builder.assert_zero(juz.clone() * nu_a.clone() * (next_fp.clone() - nu_c.clone())); - builder.assert_zero( - juz.clone() - * (AB::Expr::ONE - nu_a.clone()) - * (next_pc.clone() - (pc.clone() + AB::Expr::ONE)), - ); + builder.assert_zero(juz.clone() * nu_a.clone() * (next_pc.clone() - nu_b)); + builder.assert_zero(juz.clone() * nu_a.clone() * (next_fp.clone() - nu_c)); builder.assert_zero( - juz.clone() * (AB::Expr::ONE - nu_a.clone()) * (next_fp.clone() - fp.clone()), + juz.clone() * (AB::Expr::ONE - nu_a.clone()) * (next_pc - (pc + AB::Expr::ONE)), ); + builder.assert_zero(juz * (AB::Expr::ONE - nu_a) * (next_fp - fp)); } } diff --git a/crates/zk_vm/src/common.rs b/crates/zk_vm/src/common.rs index d41fcee7..913816de 100644 --- a/crates/zk_vm/src/common.rs +++ b/crates/zk_vm/src/common.rs @@ -9,7 +9,7 @@ use utils::{ EFPacking, Evaluation, PF, Poseidon16Air, Poseidon24Air, from_end, padd_with_zero_to_next_power_of_two, remove_end, }; -use vm::*; +use vm::{Bytecode, EF}; use whir_p3::{ fiat_shamir::errors::ProofError, poly::{ @@ -18,9 +18,14 @@ use whir_p3::{ }, }; -use crate::{instruction_encoder::field_representation, *}; +use crate::{ + COL_INDEX_AUX, COL_INDEX_DOT_PRODUCT, COL_INDEX_FLAG_A, COL_INDEX_FLAG_B, COL_INDEX_FLAG_C, + COL_INDEX_FP, COL_INDEX_MEM_VALUE_A, COL_INDEX_MEM_VALUE_B, COL_INDEX_MEM_VALUE_C, + COL_INDEX_MULTILINEAR_EVAL, COL_INDEX_OPERAND_A, COL_INDEX_OPERAND_B, COL_INDEX_POSEIDON_16, + COL_INDEX_POSEIDON_24, instruction_encoder::field_representation, +}; -pub(crate) fn poseidon_16_column_groups(poseidon_16_air: &Poseidon16Air) -> Vec> { +pub fn poseidon_16_column_groups(poseidon_16_air: &Poseidon16Air) -> Vec> { vec![ 0..8, 8..16, @@ -30,7 +35,7 @@ pub(crate) fn poseidon_16_column_groups(poseidon_16_air: &Poseidon16Air) -> Vec< ] } -pub(crate) fn poseidon_24_column_groups(poseidon_24_air: &Poseidon24Air) -> Vec> { +pub fn poseidon_24_column_groups(poseidon_24_air: &Poseidon24Air) -> Vec> { vec![ 0..8, 8..16, @@ -41,7 +46,7 @@ pub(crate) fn poseidon_24_column_groups(poseidon_24_air: &Poseidon24Air) -> Vec< ] } -pub(crate) fn poseidon_lookup_value( +pub fn poseidon_lookup_value( n_poseidons_16: usize, n_poseidons_24: usize, poseidon16_evals: &[Evaluation], @@ -75,10 +80,10 @@ pub(crate) fn poseidon_lookup_value( poseidon24_evals[2].value * s24, poseidon24_evals[5].value * s24, ] - .evaluate(&poseidon_lookup_batching_chalenges) + .evaluate(poseidon_lookup_batching_chalenges) } -pub(crate) fn poseidon_lookup_index_statements( +pub fn poseidon_lookup_index_statements( poseidon_index_evals: &[EF], n_poseidons_16: usize, n_poseidons_24: usize, @@ -158,10 +163,7 @@ pub(crate) fn poseidon_lookup_index_statements( Ok((p16_indexes_statements, p24_indexes_statements)) } -pub(crate) fn fold_bytecode( - bytecode: &Bytecode, - folding_challenges: &MultilinearPoint, -) -> Vec { +pub fn fold_bytecode(bytecode: &Bytecode, folding_challenges: &MultilinearPoint) -> Vec { let encoded_bytecode = padd_with_zero_to_next_power_of_two( &bytecode .instructions @@ -169,10 +171,10 @@ pub(crate) fn fold_bytecode( .flat_map(|i| padd_with_zero_to_next_power_of_two(&field_representation(i))) .collect::>(), ); - fold_multilinear(&encoded_bytecode, &folding_challenges) + fold_multilinear(&encoded_bytecode, folding_challenges) } -pub(crate) fn intitial_and_final_pc_conditions( +pub fn intitial_and_final_pc_conditions( bytecode: &Bytecode, log_n_cycles: usize, ) -> (Evaluation, Evaluation) { @@ -187,7 +189,7 @@ pub(crate) fn intitial_and_final_pc_conditions( (initial_pc_statement, final_pc_statement) } -pub(crate) struct PrecompileFootprint { +pub struct PrecompileFootprint { pub grand_product_challenge_global: EF, pub grand_product_challenge_p16: [EF; 5], pub grand_product_challenge_p24: [EF; 5], @@ -252,7 +254,7 @@ impl SumcheckComputationPacked for PrecompileFootprint { } } -pub(crate) struct DotProductFootprint { +pub struct DotProductFootprint { pub grand_product_challenge_global: EF, pub grand_product_challenge_dot_product: [EF; 6], } diff --git a/crates/zk_vm/src/dot_product_air.rs b/crates/zk_vm/src/dot_product_air.rs index 81e21f75..526a6076 100644 --- a/crates/zk_vm/src/dot_product_air.rs +++ b/crates/zk_vm/src/dot_product_air.rs @@ -22,9 +22,9 @@ use crate::execution_trace::WitnessDotProduct; */ const DOT_PRODUCT_AIR_COLUMNS: usize = 9; -pub(crate) const DOT_PRODUCT_AIR_COLUMN_GROUPS: [Range; 5] = [0..1, 1..2, 2..5, 5..8, 8..9]; +pub const DOT_PRODUCT_AIR_COLUMN_GROUPS: [Range; 5] = [0..1, 1..2, 2..5, 5..8, 8..9]; -pub(crate) struct DotProductAir; +pub struct DotProductAir; impl BaseAir for DotProductAir { fn width(&self) -> usize { @@ -60,7 +60,7 @@ impl Air for DotProductAir { res_up, computation_up, ] = up - .into_iter() + .iter() .map(|v| v.clone().into()) .collect::>() .try_into() @@ -76,7 +76,7 @@ impl Air for DotProductAir { _res_down, computation_down, ] = down - .into_iter() + .iter() .map(|v| v.clone().into()) .collect::>() .try_into() @@ -91,22 +91,16 @@ impl Air for DotProductAir { flag_down.clone() * product_up.clone() + not_flag_down.clone() * (product_up + computation_down), ); - builder.assert_zero( - not_flag_down.clone() * (len_up.clone() - (len_down.clone() + AB::Expr::ONE)), - ); - builder.assert_zero(flag_down.clone() * (len_up.clone() - AB::Expr::ONE)); - builder.assert_zero( - not_flag_down.clone() * (index_a_up.clone() - (index_a_down.clone() - AB::Expr::ONE)), - ); - builder.assert_zero( - not_flag_down.clone() * (index_b_up.clone() - (index_b_down.clone() - AB::Expr::ONE)), - ); + builder.assert_zero(not_flag_down.clone() * (len_up.clone() - (len_down + AB::Expr::ONE))); + builder.assert_zero(flag_down * (len_up - AB::Expr::ONE)); + builder.assert_zero(not_flag_down.clone() * (index_a_up - (index_a_down - AB::Expr::ONE))); + builder.assert_zero(not_flag_down * (index_b_up - (index_b_down - AB::Expr::ONE))); - builder.assert_zero(flag_up.clone() * (computation_up - res_up)); + builder.assert_zero(flag_up * (computation_up - res_up)); } } -pub(crate) fn build_dot_product_columns(witness: &[WitnessDotProduct]) -> (Vec>, usize) { +pub fn build_dot_product_columns(witness: &[WitnessDotProduct]) -> (Vec>, usize) { let ( mut flag, mut len, diff --git a/crates/zk_vm/src/execution_trace.rs b/crates/zk_vm/src/execution_trace.rs index e4837b1e..2a02b6ab 100644 --- a/crates/zk_vm/src/execution_trace.rs +++ b/crates/zk_vm/src/execution_trace.rs @@ -2,7 +2,10 @@ use p3_field::{Field, PrimeCharacteristicRing}; use p3_symmetric::Permutation; use rayon::prelude::*; use utils::{ToUsize, build_poseidon16, build_poseidon24}; -use vm::*; +use vm::{ + Bytecode, EF, ExecutionResult, F, Instruction, POSEIDON_16_NULL_HASH_PTR, + POSEIDON_24_NULL_HASH_PTR, +}; use crate::{ COL_INDEX_FP, COL_INDEX_MEM_ADDRESS_A, COL_INDEX_MEM_ADDRESS_B, COL_INDEX_MEM_ADDRESS_C, @@ -10,7 +13,7 @@ use crate::{ N_EXEC_COLUMNS, N_INSTRUCTION_COLUMNS, instruction_encoder::field_representation, }; -pub(crate) struct WitnessDotProduct { +pub struct WitnessDotProduct { pub cycle: usize, pub addr_0: usize, // vectorized pointer pub addr_1: usize, // vectorized pointer @@ -21,7 +24,7 @@ pub(crate) struct WitnessDotProduct { pub res: EF, } -pub(crate) struct WitnessMultilinearEval { +pub struct WitnessMultilinearEval { pub cycle: usize, pub addr_coeffs: usize, // vectorized pointer, of size 8.2^size pub addr_point: usize, // vectorized pointer, of size `size` @@ -31,7 +34,7 @@ pub(crate) struct WitnessMultilinearEval { pub res: EF, } -pub(crate) struct WitnessPoseidon16 { +pub struct WitnessPoseidon16 { pub cycle: Option, pub addr_input_a: usize, // vectorized pointer (of size 1) pub addr_input_b: usize, // vectorized pointer (of size 1) @@ -40,7 +43,7 @@ pub(crate) struct WitnessPoseidon16 { pub output: [F; 16], } -pub(crate) struct WitnessPoseidon24 { +pub struct WitnessPoseidon24 { pub cycle: Option, pub addr_input_a: usize, // vectorized pointer (of size 2) pub addr_input_b: usize, // vectorized pointer (of size 1) @@ -49,7 +52,7 @@ pub(crate) struct WitnessPoseidon24 { pub output: [F; 8], // last 8 elements of the output } -pub(crate) struct ExecutionTrace { +pub struct ExecutionTrace { pub full_trace: Vec>, pub n_poseidons_16: usize, pub n_poseidons_24: usize, @@ -61,7 +64,7 @@ pub(crate) struct ExecutionTrace { pub memory: Vec, // of length a multiple of public_memory_size } -pub(crate) fn get_execution_trace( +pub fn get_execution_trace( bytecode: &Bytecode, execution_result: &ExecutionResult, ) -> ExecutionTrace { @@ -84,7 +87,7 @@ pub(crate) fn get_execution_trace( .enumerate() { let instruction = &bytecode.instructions[pc]; - let field_repr = field_representation(&instruction); + let field_repr = field_representation(instruction); // println!( // "Cycle {}: PC = {}, FP = {}, Instruction = {}", @@ -95,17 +98,22 @@ pub(crate) fn get_execution_trace( trace[j][cycle] = *field; } - let mut addr_a = F::ZERO; - if field_repr[3].is_zero() { + let addr_a = if field_repr[3].is_zero() { // flag_a == 0 - addr_a = F::from_usize(fp) + field_repr[0]; // fp + operand_a - } + // fp + operand_a + F::from_usize(fp) + field_repr[0] + } else { + F::ZERO + }; + let value_a = memory.0[addr_a.to_usize()].unwrap(); - let mut addr_b = F::ZERO; - if field_repr[4].is_zero() { + let addr_b = if field_repr[4].is_zero() { // flag_b == 0 - addr_b = F::from_usize(fp) + field_repr[1]; // fp + operand_b - } + // fp + operand_b + F::from_usize(fp) + field_repr[1] + } else { + F::ZERO + }; let value_b = memory.0[addr_b.to_usize()].unwrap(); let mut addr_c = F::ZERO; @@ -130,9 +138,9 @@ pub(crate) fn get_execution_trace( match instruction { Instruction::Poseidon2_16 { arg_a, arg_b, res } => { - let addr_input_a = arg_a.read_value(&memory, fp).unwrap().to_usize(); - let addr_input_b = arg_b.read_value(&memory, fp).unwrap().to_usize(); - let addr_output = res.read_value(&memory, fp).unwrap().to_usize(); + let addr_input_a = arg_a.read_value(memory, fp).unwrap().to_usize(); + let addr_input_b = arg_b.read_value(memory, fp).unwrap().to_usize(); + let addr_output = res.read_value(memory, fp).unwrap().to_usize(); let value_a = memory.get_vector(addr_input_a).unwrap(); let value_b = memory.get_vector(addr_input_b).unwrap(); let output = memory.get_vectorized_slice(addr_output, 2).unwrap(); @@ -146,9 +154,9 @@ pub(crate) fn get_execution_trace( }); } Instruction::Poseidon2_24 { arg_a, arg_b, res } => { - let addr_input_a = arg_a.read_value(&memory, fp).unwrap().to_usize(); - let addr_input_b = arg_b.read_value(&memory, fp).unwrap().to_usize(); - let addr_output = res.read_value(&memory, fp).unwrap().to_usize(); + let addr_input_a = arg_a.read_value(memory, fp).unwrap().to_usize(); + let addr_input_b = arg_b.read_value(memory, fp).unwrap().to_usize(); + let addr_output = res.read_value(memory, fp).unwrap().to_usize(); let value_a = memory.get_vectorized_slice(addr_input_a, 2).unwrap(); let value_b = memory.get_vector(addr_input_b).unwrap().to_vec(); let output = memory.get_vector(addr_output).unwrap(); @@ -167,9 +175,9 @@ pub(crate) fn get_execution_trace( res, size, } => { - let addr_0 = arg0.read_value(&memory, fp).unwrap().to_usize(); - let addr_1 = arg1.read_value(&memory, fp).unwrap().to_usize(); - let addr_res = res.read_value(&memory, fp).unwrap().to_usize(); + let addr_0 = arg0.read_value(memory, fp).unwrap().to_usize(); + let addr_1 = arg1.read_value(memory, fp).unwrap().to_usize(); + let addr_res = res.read_value(memory, fp).unwrap().to_usize(); let slice_0 = memory .get_vectorized_slice_extension(addr_0, *size) .unwrap(); @@ -194,9 +202,9 @@ pub(crate) fn get_execution_trace( res, n_vars, } => { - let addr_coeffs = coeffs.read_value(&memory, fp).unwrap().to_usize(); - let addr_point = point.read_value(&memory, fp).unwrap().to_usize(); - let addr_res = res.read_value(&memory, fp).unwrap().to_usize(); + let addr_coeffs = coeffs.read_value(memory, fp).unwrap().to_usize(); + let addr_point = point.read_value(memory, fp).unwrap().to_usize(); + let addr_res = res.read_value(memory, fp).unwrap().to_usize(); let point = memory .get_vectorized_slice_extension(addr_point, *n_vars) .unwrap(); diff --git a/crates/zk_vm/src/instruction_encoder.rs b/crates/zk_vm/src/instruction_encoder.rs index 65fe66fb..469b6d0d 100644 --- a/crates/zk_vm/src/instruction_encoder.rs +++ b/crates/zk_vm/src/instruction_encoder.rs @@ -1,9 +1,13 @@ use p3_field::PrimeCharacteristicRing; -use vm::*; +use vm::{F, Instruction, MemOrConstant, MemOrFp, MemOrFpOrConstant, Operation}; -use crate::*; +use crate::{ + COL_INDEX_AUX, COL_INDEX_FLAG_A, COL_INDEX_FLAG_B, COL_INDEX_FLAG_C, + COL_INDEX_MULTILINEAR_EVAL, COL_INDEX_OPERAND_A, COL_INDEX_OPERAND_B, COL_INDEX_OPERAND_C, + N_INSTRUCTION_COLUMNS, +}; -pub(crate) fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { +pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let mut fields = [F::ZERO; N_INSTRUCTION_COLUMNS]; match instr { Instruction::Computation { diff --git a/crates/zk_vm/src/lib.rs b/crates/zk_vm/src/lib.rs index c4ff2328..c1d7b321 100644 --- a/crates/zk_vm/src/lib.rs +++ b/crates/zk_vm/src/lib.rs @@ -52,6 +52,7 @@ const COL_INDEX_MEM_ADDRESS_A: usize = 20; const COL_INDEX_MEM_ADDRESS_B: usize = 21; const COL_INDEX_MEM_ADDRESS_C: usize = 22; +#[allow(clippy::range_plus_one)] fn exec_column_groups() -> Vec> { [ (0..N_INSTRUCTION_COLUMNS_IN_AIR) @@ -72,7 +73,7 @@ fn exec_column_groups() -> Vec> { pub fn compile_and_run(program: &str, public_input: &[vm::F], private_input: &[vm::F]) { let bytecode = compile_program(program); - execute_bytecode(&bytecode, &public_input, private_input); + let _ = execute_bytecode(&bytecode, public_input, private_input); } pub trait InAirColumnIndex { diff --git a/crates/zk_vm/src/poseidon_tables.rs b/crates/zk_vm/src/poseidon_tables.rs index 8bc14a69..a6d3125c 100644 --- a/crates/zk_vm/src/poseidon_tables.rs +++ b/crates/zk_vm/src/poseidon_tables.rs @@ -6,7 +6,7 @@ use vm::F; use crate::execution_trace::{WitnessPoseidon16, WitnessPoseidon24}; -pub(crate) fn build_poseidon_columns( +pub fn build_poseidon_columns( poseidons_16: &[WitnessPoseidon16], poseidons_24: &[WitnessPoseidon24], ) -> (Vec>, Vec>) { @@ -21,7 +21,7 @@ pub(crate) fn build_poseidon_columns( (cols_16, cols_24) } -pub(crate) fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> Vec { +pub fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> Vec { padd_with_zero_to_next_power_of_two( &[ poseidons_16 @@ -41,7 +41,7 @@ pub(crate) fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> Vec ) } -pub(crate) fn all_poseidon_24_indexes(poseidons_24: &[WitnessPoseidon24]) -> Vec { +pub fn all_poseidon_24_indexes(poseidons_24: &[WitnessPoseidon24]) -> Vec { padd_with_zero_to_next_power_of_two( &[ padd_with_zero_to_next_power_of_two( diff --git a/crates/zk_vm/src/prove_execution.rs b/crates/zk_vm/src/prove_execution.rs index bac608ae..0ebaabc0 100644 --- a/crates/zk_vm/src/prove_execution.rs +++ b/crates/zk_vm/src/prove_execution.rs @@ -1,42 +1,46 @@ -use crate::common::*; -use crate::dot_product_air::DOT_PRODUCT_AIR_COLUMN_GROUPS; -use crate::dot_product_air::DotProductAir; -use crate::dot_product_air::build_dot_product_columns; -use crate::execution_trace::ExecutionTrace; -use crate::execution_trace::get_execution_trace; -use crate::poseidon_tables::build_poseidon_columns; -use crate::poseidon_tables::*; -use crate::{air::VMAir, *}; -use ::air::prove_many_air_2; -use ::air::{table::AirTable, witness::AirWitness}; -use lookup::prove_gkr_product; -use lookup::{compute_pushforward, prove_logup_star}; +use ::air::{prove_many_air_2, table::AirTable, witness::AirWitness}; +use lookup::{compute_pushforward, prove_gkr_product, prove_logup_star}; use p3_air::BaseAir; -use p3_field::BasedVectorSpace; -use p3_field::PrimeCharacteristicRing; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use pcs::num_packed_vars_for_pols; -use pcs::{BatchPCS, packed_pcs_commit, packed_pcs_global_statements}; +use pcs::{BatchPCS, num_packed_vars_for_pols, packed_pcs_commit, packed_pcs_global_statements}; use rayon::prelude::*; use sumcheck::MleGroupRef; use tracing::info_span; -use utils::ToUsize; -use utils::assert_eq_many; -use utils::dot_product_with_base; -use utils::field_slice_as_base; -use utils::fold_multilinear_in_large_field; -use utils::pack_extension; -use utils::to_big_endian_bits; use utils::{ - Evaluation, PF, build_poseidon_16_air, build_poseidon_24_air, build_prover_state, - padd_with_zero_to_next_power_of_two, + Evaluation, PF, ToUsize, assert_eq_many, build_poseidon_16_air, build_poseidon_24_air, + build_prover_state, dot_product_with_base, field_slice_as_base, + fold_multilinear_in_large_field, pack_extension, padd_with_zero_to_next_power_of_two, + to_big_endian_bits, +}; +use vm::{ + Bytecode, DIMENSION, EF, F, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, + execute_bytecode, +}; +use whir_p3::{ + dft::EvalsDft, + poly::{ + evals::{EvaluationsList, eval_eq, fold_multilinear}, + multilinear::MultilinearPoint, + }, + utils::compute_eval_eq, +}; + +use crate::{ + COL_INDEX_FP, COL_INDEX_MEM_ADDRESS_A, COL_INDEX_MEM_ADDRESS_C, COL_INDEX_MEM_VALUE_A, + COL_INDEX_MEM_VALUE_B, COL_INDEX_MEM_VALUE_C, COL_INDEX_PC, N_INSTRUCTION_COLUMNS, + N_INSTRUCTION_COLUMNS_IN_AIR, N_TOTAL_COLUMNS, UNIVARIATE_SKIPS, + air::VMAir, + common::{ + DotProductFootprint, PrecompileFootprint, fold_bytecode, intitial_and_final_pc_conditions, + poseidon_16_column_groups, poseidon_24_column_groups, poseidon_lookup_index_statements, + poseidon_lookup_value, + }, + dot_product_air::{DOT_PRODUCT_AIR_COLUMN_GROUPS, DotProductAir, build_dot_product_columns}, + exec_column_groups, + execution_trace::{ExecutionTrace, get_execution_trace}, + poseidon_tables::{all_poseidon_16_indexes, all_poseidon_24_indexes, build_poseidon_columns}, }; -use vm::Bytecode; -use vm::*; -use whir_p3::dft::EvalsDft; -use whir_p3::poly::evals::{eval_eq, fold_multilinear}; -use whir_p3::poly::{evals::EvaluationsList, multilinear::MultilinearPoint}; -use whir_p3::utils::compute_eval_eq; pub fn prove_execution( bytecode: &Bytecode, @@ -55,8 +59,8 @@ pub fn prove_execution( public_memory_size, memory, } = info_span!("Witness generation").in_scope(|| { - let execution_result = execute_bytecode(&bytecode, &public_input, private_input); - get_execution_trace(&bytecode, &execution_result) + let execution_result = execute_bytecode(bytecode, public_input, private_input); + get_execution_trace(bytecode, &execution_result) }); let public_memory = &memory[..public_memory_size]; @@ -77,8 +81,7 @@ pub fn prove_execution( exec_columns.extend( full_trace[N_INSTRUCTION_COLUMNS..] .iter() - .map(Vec::as_slice) - .collect::>(), + .map(Vec::as_slice), ); let exec_witness = AirWitness::>::new(&exec_columns, &exec_column_groups()); let exec_table = AirTable::::new(VMAir); @@ -440,7 +443,7 @@ pub fn prove_execution( point: MultilinearPoint( [ p16_mixing_scalars_grand_product.0.clone(), - grand_product_p16_statement.point.0.clone(), + grand_product_p16_statement.point.0, ] .concat(), ), @@ -471,7 +474,7 @@ pub fn prove_execution( point: MultilinearPoint( [ p24_mixing_scalars_grand_product.0.clone(), - grand_product_p24_statement.point.0.clone(), + grand_product_p24_statement.point.0, ] .concat(), ), @@ -493,11 +496,11 @@ pub fn prove_execution( MleGroupRef::Extension( dot_product_columns[..5] .iter() - .map(|c| c.as_slice()) + .map(std::vec::Vec::as_slice) .collect::>(), ), // TODO packing &DotProductFootprint { - grand_product_challenge_global: grand_product_challenge_global, + grand_product_challenge_global, grand_product_challenge_dot_product: grand_product_challenge_dot_product .clone() .try_into() @@ -530,7 +533,7 @@ pub fn prove_execution( grand_product_dot_product_table_indexes_mixing_challenges .0 .clone(), - grand_product_dot_product_sumcheck_point.0.clone(), + grand_product_dot_product_sumcheck_point.0, ] .concat(), ), @@ -548,10 +551,13 @@ pub fn prove_execution( 1, // TODO univariate skip? MleGroupRef::Base( // TODO not all columns re required - full_trace.iter().map(|c| c.as_slice()).collect::>(), + full_trace + .iter() + .map(std::vec::Vec::as_slice) + .collect::>(), ), // TODO packing &PrecompileFootprint { - grand_product_challenge_global: grand_product_challenge_global, + grand_product_challenge_global, grand_product_challenge_p16: grand_product_challenge_p16.try_into().unwrap(), grand_product_challenge_p24: grand_product_challenge_p24.try_into().unwrap(), grand_product_challenge_dot_product: grand_product_challenge_dot_product @@ -726,7 +732,7 @@ pub fn prove_execution( (poseidons_24.par_iter().map(|p| (&p.input[0..8]).evaluate(&memory_folding_challenges)).collect::>(), poseidon24_steps), (poseidons_24.par_iter().map(|p| (&p.input[8..16]).evaluate(&memory_folding_challenges)).collect::>(), poseidon24_steps), (poseidons_24.par_iter().map(|p| (&p.input[16..24]).evaluate(&memory_folding_challenges)).collect::>(), poseidon24_steps), - (poseidons_24.par_iter().map(|p| (&p.output).evaluate(&memory_folding_challenges)).collect::>(), poseidon24_steps), + (poseidons_24.par_iter().map(|p| p.output.evaluate(&memory_folding_challenges)).collect::>(), poseidon24_steps), ]; for (chunk_idx, (values, step)) in chunks.into_iter().enumerate() { let offset = chunk_idx * max_n_poseidons; @@ -801,10 +807,10 @@ pub fn prove_execution( ) .evaluate(&bytecode_compression_challenges), }; - let bytecode_lookup_point_2 = grand_product_exec_sumcheck_point.clone(); + let bytecode_lookup_point_2 = grand_product_exec_sumcheck_point; let bytecode_lookup_claim_2 = Evaluation { point: bytecode_lookup_point_2.clone(), - value: padd_with_zero_to_next_power_of_two(&grand_product_exec_evals_on_each_column) + value: padd_with_zero_to_next_power_of_two(grand_product_exec_evals_on_each_column) .evaluate(&bytecode_compression_challenges), }; let alpha_bytecode_lookup = prover_state.sample(); @@ -889,7 +895,7 @@ pub fn prove_execution( let poseidon_lookup_memory_point = MultilinearPoint( [ poseidon_logup_star_statements.on_table.point.0.clone(), - memory_folding_challenges.0.clone(), + memory_folding_challenges.0, ] .concat(), ); @@ -1036,7 +1042,7 @@ pub fn prove_execution( // Second Opening let global_statements_extension = packed_pcs_global_statements( &packed_pcs_witness_extension.tree, - &vec![ + &[ exec_logup_star_statements.on_pushforward, poseidon_logup_star_statements.on_pushforward, bytecode_logup_star_statements.on_pushforward, diff --git a/crates/zk_vm/src/recursion.rs b/crates/zk_vm/src/recursion.rs index 28b72e6e..b32aa947 100644 --- a/crates/zk_vm/src/recursion.rs +++ b/crates/zk_vm/src/recursion.rs @@ -1,8 +1,7 @@ use std::marker::PhantomData; use compiler::compile_program; -use p3_field::BasedVectorSpace; -use p3_field::PrimeCharacteristicRing; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; use pcs::WhirBatchPcs; use rand::{Rng, SeedableRng, rngs::StdRng}; use utils::{ @@ -22,8 +21,7 @@ use whir_p3::{ }, }; -use crate::prove_execution::prove_execution; -use crate::verify_execution::verify_execution; +use crate::{prove_execution::prove_execution, verify_execution::verify_execution}; #[test] pub fn test_whir_verif() { @@ -39,9 +37,9 @@ pub fn run_whir_verif() { const F_BITS = 31; // koala-bear = 31 bits const N_VARS = 25; - const LOG_INV_RATE = 2; + const LOG_INV_RATE = 2; const N_ROUNDS = 3; - + const PADDING_FOR_INITIAL_MERKLE_LEAVES = 7; const FOLDING_FACTOR_0 = 7; @@ -84,15 +82,15 @@ pub fn run_whir_verif() { claimed_sum_side = mul_extension_ret(combination_randomness_gen_0, pcs_eval); claimed_sum_0 = add_extension_ret(ood_eval_0, claimed_sum_side); domain_size_0 = N_VARS + LOG_INV_RATE; - fs_state_5, folding_randomness_1, ood_point_1, root_1, circle_values_1, combination_randomness_powers_1, claimed_sum_1 = + fs_state_5, folding_randomness_1, ood_point_1, root_1, circle_values_1, combination_randomness_powers_1, claimed_sum_1 = whir_round(fs_state_4, root_0, FOLDING_FACTOR_0, 2**FOLDING_FACTOR_0, 1, NUM_QUERIES_0, domain_size_0, claimed_sum_0, GRINDING_BITS_0); domain_size_1 = domain_size_0 - RS_REDUCTION_FACTOR_0; - fs_state_6, folding_randomness_2, ood_point_2, root_2, circle_values_2, combination_randomness_powers_2, claimed_sum_2 = + fs_state_6, folding_randomness_2, ood_point_2, root_2, circle_values_2, combination_randomness_powers_2, claimed_sum_2 = whir_round(fs_state_5, root_1, FOLDING_FACTOR_1, 2**FOLDING_FACTOR_1, 0, NUM_QUERIES_1, domain_size_1, claimed_sum_1, GRINDING_BITS_1); domain_size_2 = domain_size_1 - RS_REDUCTION_FACTOR_1; - fs_state_7, folding_randomness_3, ood_point_3, root_3, circle_values_3, combination_randomness_powers_3, claimed_sum_3 = + fs_state_7, folding_randomness_3, ood_point_3, root_3, circle_values_3, combination_randomness_powers_3, claimed_sum_3 = whir_round(fs_state_6, root_2, FOLDING_FACTOR_2, 2**FOLDING_FACTOR_2, 0, NUM_QUERIES_2, domain_size_2, claimed_sum_2, GRINDING_BITS_2); domain_size_3 = domain_size_2 - RS_REDUCTION_FACTOR_2; @@ -101,7 +99,7 @@ pub fn run_whir_verif() { fs_state_10, final_circle_values, final_folds = sample_stir_indexes_and_fold(fs_state_9, NUM_QUERIES_3, 0, FOLDING_FACTOR_3, 2**FOLDING_FACTOR_3, domain_size_3, root_3, folding_randomness_4, GRINDING_BITS_3); - + for i in 0..NUM_QUERIES_3 { powers_of_2_rev = powers_of_two_rev_base(final_circle_values[i], FINAL_VARS); poly_eq = poly_eq_base(powers_of_2_rev, FINAL_VARS, 2**FINAL_VARS); @@ -280,7 +278,7 @@ pub fn run_whir_verif() { answers = malloc(num_queries); // a vector of vectorized pointers, each pointing to `two_pow_folding_factor` field elements (base if first rounds, extension otherwise) fs_states_b = malloc(num_queries + 1); fs_states_b[0] = fs_state_9; - + // the number of chunk of 8 field elements per merkle leaf opened if is_first_round == 1 { n_chuncks_per_answer = two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield @@ -289,7 +287,7 @@ pub fn run_whir_verif() { } for i in 0..num_queries { - new_fs_state, answer = fs_hint(fs_states_b[i], n_chuncks_per_answer); + new_fs_state, answer = fs_hint(fs_states_b[i], n_chuncks_per_answer); fs_states_b[i + 1] = new_fs_state; answers[i] = answer; } @@ -363,8 +361,8 @@ pub fn run_whir_verif() { fs_state_7, folding_randomness, new_claimed_sum_a = sumcheck(fs_state, folding_factor, claimed_sum); fs_state_8, root, ood_point, ood_eval = parse_commitment(fs_state_7); - - fs_state_11, circle_values, folds = + + fs_state_11, circle_values, folds = sample_stir_indexes_and_fold(fs_state_8, num_queries, is_first_round, folding_factor, two_pow_folding_factor, domain_size, prev_root, folding_randomness, grinding_bits); fs_state_12, combination_randomness_gen = fs_sample_ef(fs_state_11); @@ -511,7 +509,7 @@ pub fn run_whir_verif() { b_ptr = b * 8; res_ptr = res * 8; - + prods = malloc(n * 8); for i in 0..n unroll { for j in 0..8 unroll { @@ -549,7 +547,7 @@ pub fn run_whir_verif() { mul_extension(point, inner_res + i, res + two_pow_n_minus_1 + i); sub_extension(inner_res + i, res + two_pow_n_minus_1 + i, res + i); } - + return res; } @@ -574,7 +572,7 @@ pub fn run_whir_verif() { res[two_pow_n_minus_1 + i] = inner_res[i] * point[0]; res[i] = inner_res[i] - res[two_pow_n_minus_1 + i]; } - + return res; } @@ -644,7 +642,7 @@ pub fn run_whir_verif() { fs_state_3, ood_eval = fs_receive(fs_state_2, 1); // vectorized pointer of len 1 return fs_state_3, root, ood_point, ood_eval; } - + // FIAT SHAMIR layout: // 0 -> transcript (vectorized pointer) // 1 -> vectorized pointer to first half of sponge state @@ -672,7 +670,7 @@ pub fn run_whir_verif() { transcript_ptr = fs_state[0] * 8; l_ptr = fs_state[1] * 8; - + new_l = malloc_vec(1); new_l_ptr = new_l * 8; new_l_ptr[0] = transcript_ptr[0]; @@ -770,7 +768,7 @@ pub fn run_whir_verif() { new_fs_state[1] = fs_state[1]; new_fs_state[2] = fs_state[2]; new_fs_state[3] = fs_state[3]; - return new_fs_state, res; + return new_fs_state, res; } fn fs_receive(fs_state, n) -> 2 { @@ -824,7 +822,7 @@ pub fn run_whir_verif() { // ap = a * 8; // bp = b * 8; // cp = c * 8; - + // cp[0] = (ap[0] * bp[0]) + W * ((ap[1] * bp[7]) + (ap[2] * bp[6]) + (ap[3] * bp[5]) + (ap[4] * bp[4]) + (ap[5] * bp[3]) + (ap[6] * bp[2]) + (ap[7] * bp[1])); // cp[1] = (ap[1] * bp[0]) + (ap[0] * bp[1]) + W * ((ap[2] * bp[7]) + (ap[3] * bp[6]) + (ap[4] * bp[5]) + (ap[5] * bp[4]) + (ap[6] * bp[3]) + (ap[7] * bp[2])); // cp[2] = (ap[2] * bp[0]) + (ap[1] * bp[1]) + (ap[0] * bp[2]) + W * ((ap[3] * bp[7]) + (ap[4] * bp[6]) + (ap[5] * bp[5]) + (ap[6] * bp[4]) + (ap[7] * bp[3])); @@ -874,7 +872,7 @@ pub fn run_whir_verif() { } return; } - + fn assert_eq_extension(a, b) { null_ptr = pointer_to_zero_vector; // TODO avoid having to store this in a variable add_extension(a, null_ptr, b); @@ -997,10 +995,7 @@ pub fn run_whir_verif() { % (1 << first_folding_factor)); assert_eq!(proof_data_padding % 8, 0); proof_data_padding /= 8; - println!( - "1st merkle leaf padding: {} (vectorized)", - proof_data_padding - ); // to align the first merkle leaves (in base field) + println!("1st merkle leaf padding: {proof_data_padding} (vectorized)"); // to align the first merkle leaves (in base field) public_input.extend(F::zero_vec(proof_data_padding * 8)); public_input.extend(prover_state.proof_data()[commitment_size..].to_vec()); diff --git a/crates/zk_vm/src/verify_execution.rs b/crates/zk_vm/src/verify_execution.rs index 1f5079fa..a7a533e1 100644 --- a/crates/zk_vm/src/verify_execution.rs +++ b/crates/zk_vm/src/verify_execution.rs @@ -1,28 +1,38 @@ -use crate::common::*; -use crate::dot_product_air::DOT_PRODUCT_AIR_COLUMN_GROUPS; -use crate::dot_product_air::DotProductAir; -use ::air::table::AirTable; -use ::air::verify_many_air_2; -use lookup::verify_gkr_product; -use lookup::verify_logup_star; +use ::air::{table::AirTable, verify_many_air_2}; +use lookup::{verify_gkr_product, verify_logup_star}; use p3_air::BaseAir; -use p3_field::BasedVectorSpace; -use p3_field::PrimeCharacteristicRing; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use pcs::num_packed_vars_for_vars; -use pcs::packed_pcs_global_statements; -use pcs::{BatchPCS, NumVariables as _, packed_pcs_parse_commitment}; +use pcs::{ + BatchPCS, NumVariables as _, num_packed_vars_for_vars, packed_pcs_global_statements, + packed_pcs_parse_commitment, +}; use sumcheck::SumcheckComputation; -use utils::dot_product_with_base; -use utils::to_big_endian_bits; -use utils::{Evaluation, PF, build_challenger, padd_with_zero_to_next_power_of_two}; -use utils::{ToUsize, build_poseidon_16_air, build_poseidon_24_air}; -use vm::*; -use whir_p3::fiat_shamir::{errors::ProofError, verifier::VerifierState}; -use whir_p3::poly::evals::EvaluationsList; -use whir_p3::poly::multilinear::MultilinearPoint; - -use crate::{air::VMAir, *}; +use utils::{ + Evaluation, PF, ToUsize, build_challenger, build_poseidon_16_air, build_poseidon_24_air, + dot_product_with_base, padd_with_zero_to_next_power_of_two, to_big_endian_bits, +}; +use vm::{ + Bytecode, DIMENSION, EF, F, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, + build_public_memory, +}; +use whir_p3::{ + fiat_shamir::{errors::ProofError, verifier::VerifierState}, + poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, +}; + +use crate::{ + COL_INDEX_FP, COL_INDEX_MEM_VALUE_A, COL_INDEX_MEM_VALUE_B, COL_INDEX_MEM_VALUE_C, + N_INSTRUCTION_COLUMNS, N_INSTRUCTION_COLUMNS_IN_AIR, N_TOTAL_COLUMNS, UNIVARIATE_SKIPS, + air::VMAir, + common::{ + DotProductFootprint, PrecompileFootprint, fold_bytecode, intitial_and_final_pc_conditions, + poseidon_16_column_groups, poseidon_24_column_groups, poseidon_lookup_index_statements, + poseidon_lookup_value, + }, + dot_product_air::{DOT_PRODUCT_AIR_COLUMN_GROUPS, DotProductAir}, + exec_column_groups, +}; pub fn verify_execution( bytecode: &Bytecode, @@ -30,6 +40,15 @@ pub fn verify_execution( proof_data: Vec>, pcs: &impl BatchPCS, EF, EF>, ) -> Result<(), ProofError> { + struct RowMultilinearEval { + addr_coeffs: F, + addr_point: F, + addr_res: F, + n_vars: F, + point: Vec, + res: EF, + } + let mut verifier_state = VerifierState::new(proof_data, build_challenger()); let exec_table = AirTable::::new(VMAir); @@ -51,7 +70,7 @@ pub fn verify_execution( ] = verifier_state .next_base_scalars_const::<8>()? .into_iter() - .map(|x| x.to_usize()) + .map(utils::ToUsize::to_usize) .collect::>() .try_into() .unwrap(); @@ -90,15 +109,6 @@ pub fn verify_execution( let vars_private_memory = vec![log_public_memory; n_private_memory_chunks]; - struct RowMultilinearEval { - addr_coeffs: F, - addr_point: F, - addr_res: F, - n_vars: F, - point: Vec, - res: EF, - } - let mut vm_multilinear_evals = Vec::new(); for _ in 0..n_vm_multilinear_evals { let [addr_coeffs, addr_point, addr_res, n_vars] = @@ -298,7 +308,7 @@ pub fn verify_execution( point: MultilinearPoint( [ p16_mixing_scalars_grand_product.0.clone(), - grand_product_p16_statement.point.0.clone(), + grand_product_p16_statement.point.0, ] .concat(), ), @@ -330,7 +340,7 @@ pub fn verify_execution( point: MultilinearPoint( [ p24_mixing_scalars_grand_product.0.clone(), - grand_product_p24_statement.point.0.clone(), + grand_product_p24_statement.point.0, ] .concat(), ), @@ -358,7 +368,7 @@ pub fn verify_execution( .eq_poly_outside(&grand_product_dot_product_statement.point) * { DotProductFootprint { - grand_product_challenge_global: grand_product_challenge_global, + grand_product_challenge_global, grand_product_challenge_dot_product: grand_product_challenge_dot_product .clone() .try_into() @@ -386,7 +396,7 @@ pub fn verify_execution( grand_product_dot_product_table_indexes_mixing_challenges .0 .clone(), - grand_product_dot_product_sumcheck_claim.point.0.clone(), + grand_product_dot_product_sumcheck_claim.point.0, ] .concat(), ), @@ -417,7 +427,7 @@ pub fn verify_execution( .eq_poly_outside(&grand_product_exec_statement.point) * { PrecompileFootprint { - grand_product_challenge_global: grand_product_challenge_global, + grand_product_challenge_global, grand_product_challenge_p16: grand_product_challenge_p16.try_into().unwrap(), grand_product_challenge_p24: grand_product_challenge_p24.try_into().unwrap(), grand_product_challenge_dot_product: grand_product_challenge_dot_product @@ -504,7 +514,7 @@ pub fn verify_execution( } let bytecode_lookup_point_1 = exec_evals_to_verify[0].point.clone(); let bytecode_lookup_claim_1 = Evaluation { - point: bytecode_lookup_point_1.clone(), + point: bytecode_lookup_point_1, value: padd_with_zero_to_next_power_of_two( &[ (0..N_INSTRUCTION_COLUMNS_IN_AIR) @@ -518,8 +528,8 @@ pub fn verify_execution( }; let bytecode_lookup_claim_2 = Evaluation { - point: grand_product_exec_sumcheck_claim.point.clone(), - value: padd_with_zero_to_next_power_of_two(&grand_product_exec_evals_on_each_column) + point: grand_product_exec_sumcheck_claim.point, + value: padd_with_zero_to_next_power_of_two(grand_product_exec_evals_on_each_column) .evaluate(&bytecode_compression_challenges), }; let alpha_bytecode_lookup = verifier_state.sample(); @@ -563,8 +573,8 @@ pub fn verify_execution( let poseidon_lookup_value = poseidon_lookup_value( n_poseidons_16, n_poseidons_24, - &p16_evals_to_verify, - &p24_evals_to_verify, + p16_evals_to_verify, + p24_evals_to_verify, &poseidon_lookup_batching_chalenges, ); let poseidon_lookup_challenge = Evaluation { @@ -598,7 +608,7 @@ pub fn verify_execution( let poseidon_lookup_memory_point = MultilinearPoint( [ poseidon_logup_star_statements.on_table.point.0.clone(), - memory_folding_challenges.0.clone(), + memory_folding_challenges.0, ] .concat(), ); @@ -647,14 +657,17 @@ pub fn verify_execution( let mut chunk_evals_dot_product_lookup = vec![public_memory.evaluate(&dot_product_lookup_chunk_point)]; - for i in 0..n_private_memory_chunks { + for private_memory_statement in private_memory_statements + .iter_mut() + .take(n_private_memory_chunks) + { let chunk_eval_exec_lookup = verifier_state.next_extension_scalar()?; let chunk_eval_poseidon_lookup = verifier_state.next_extension_scalar()?; let chunk_eval_dot_product_lookup = verifier_state.next_extension_scalar()?; chunk_evals_exec_lookup.push(chunk_eval_exec_lookup); chunk_evals_poseidon_lookup.push(chunk_eval_poseidon_lookup); chunk_evals_dot_product_lookup.push(chunk_eval_dot_product_lookup); - private_memory_statements[i].extend(vec![ + private_memory_statement.extend(vec![ Evaluation { point: exec_lookup_chunk_point.clone(), value: chunk_eval_exec_lookup, @@ -783,7 +796,7 @@ pub fn verify_execution( let global_statements_extension = packed_pcs_global_statements( &parsed_commitment_extension.tree, - &vec![ + &[ exec_logup_star_statements.on_pushforward, poseidon_logup_star_statements.on_pushforward, bytecode_logup_star_statements.on_pushforward,