diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index 8b4eac5c..5955ec4c 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -122,7 +122,7 @@ impl>, A: NormalAir, AP: PackedAir> AirTable>, ) -> Vec> { - prove_air::, EF, A, AP>(prover_state, univariate_skips, &self, witness) + prove_air::, EF, A, AP>(prover_state, univariate_skips, self, witness) } #[instrument(name = "air: prove in extension", skip_all)] @@ -132,7 +132,7 @@ impl>, A: NormalAir, AP: PackedAir> AirTable, ) -> Vec> { - prove_air::(prover_state, univariate_skips, &self, witness) + prove_air::(prover_state, univariate_skips, self, witness) } } @@ -224,9 +224,13 @@ fn open_structured_columns<'a, EF: ExtensionField> + ExtensionField, let mut column_scalars = vec![]; let mut index = 0; for group in &witness.column_groups { - for i in index..index + group.len() { - column_scalars.push(poly_eq_batching_scalars[i]); - } + column_scalars.extend( + poly_eq_batching_scalars + .iter() + .skip(index) + .take(group.len()) + .copied(), + ); index += witness.max_columns_per_group().next_power_of_two(); } diff --git a/crates/air/src/table.rs b/crates/air/src/table.rs index e7891928..fdf4bfb5 100644 --- a/crates/air/src/table.rs +++ b/crates/air/src/table.rs @@ -81,18 +81,18 @@ impl>, A: NormalAir, AP: PackedAir> AirTable() == TypeId::of::() { unsafe { - self.air - .eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>( - &mut constraints_checker, - )); + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, EF, EF>, + >(&mut constraints_checker)); } } else { assert_eq!(TypeId::of::(), TypeId::of::>()); unsafe { - self.air - .eval(transmute::<_, &mut ConstraintChecker<'_, PF, EF>>( - &mut constraints_checker, - )); + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, PF, EF>, + >(&mut constraints_checker)); } } handle_errors(row, &mut constraints_checker)?; @@ -110,18 +110,18 @@ impl>, A: NormalAir, AP: PackedAir> AirTable() == TypeId::of::() { unsafe { - self.air - .eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>( - &mut constraints_checker, - )); + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, EF, EF>, + >(&mut constraints_checker)); } } else { assert_eq!(TypeId::of::(), TypeId::of::>()); unsafe { - self.air - .eval(transmute::<_, &mut ConstraintChecker<'_, PF, EF>>( - &mut constraints_checker, - )); + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, PF, EF>, + >(&mut constraints_checker)); } } handle_errors(row, &mut constraints_checker)?; diff --git a/crates/air/src/test.rs b/crates/air/src/test.rs index 15a02bb3..62c5c333 100644 --- a/crates/air/src/test.rs +++ b/crates/air/src/test.rs @@ -105,16 +105,41 @@ fn generate_structured_trace>()); } let mut witness_cols = vec![vec![F::ZERO]; N_COLUMNS - N_PREPROCESSED_COLUMNS]; - for i in 1..n_rows { - for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS { - let witness_cols_j_i_min_1 = witness_cols[j][i - 1]; - witness_cols[j].push( - witness_cols_j_i_min_1 - + F::from_usize(j + N_PREPROCESSED_COLUMNS) - + (0..N_PREPROCESSED_COLUMNS) - .map(|k| trace[k][i]) - .product::(), - ); + let mut prev_values = vec![F::ZERO; N_COLUMNS - N_PREPROCESSED_COLUMNS]; + let mut column_iters = trace[..N_PREPROCESSED_COLUMNS] + .iter() + .map(|col| col.iter()) + .collect::>(); + if column_iters.is_empty() { + trace.extend(witness_cols); + return trace; + } + for iter in &mut column_iters { + iter.next(); // skip first row, already initialised + } + loop { + let mut row_product = F::ONE; + let mut progressed = true; + for iter in &mut column_iters { + match iter.next() { + Some(value) => row_product *= *value, + None => { + progressed = false; + break; + } + } + } + if !progressed { + break; + } + for (j, (witness_col, prev)) in witness_cols + .iter_mut() + .zip(prev_values.iter_mut()) + .enumerate() + { + let next_val = *prev + F::from_usize(j + N_PREPROCESSED_COLUMNS) + row_product; + witness_col.push(next_val); + *prev = next_val; } } trace.extend(witness_cols); @@ -131,14 +156,31 @@ fn generate_unstructured_trace>()); } let mut witness_cols = vec![vec![]; N_COLUMNS - N_PREPROCESSED_COLUMNS]; - for i in 0..n_rows { - for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS { - witness_cols[j].push( - F::from_usize(j + N_PREPROCESSED_COLUMNS) - + (0..N_PREPROCESSED_COLUMNS) - .map(|k| trace[k][i]) - .product::(), - ); + let mut column_iters = trace[..N_PREPROCESSED_COLUMNS] + .iter() + .map(|col| col.iter()) + .collect::>(); + if column_iters.is_empty() { + trace.extend(witness_cols); + return trace; + } + loop { + let mut row_product = F::ONE; + let mut progressed = true; + for iter in &mut column_iters { + match iter.next() { + Some(value) => row_product *= *value, + None => { + progressed = false; + break; + } + } + } + if !progressed { + break; + } + for (j, witness_col) in witness_cols.iter_mut().enumerate() { + witness_col.push(F::from_usize(j + N_PREPROCESSED_COLUMNS) + row_product); } } trace.extend(witness_cols); diff --git a/crates/air/src/verify.rs b/crates/air/src/verify.rs index e0a881d1..29892701 100644 --- a/crates/air/src/verify.rs +++ b/crates/air/src/verify.rs @@ -84,16 +84,18 @@ fn verify_air>, A: NormalAir, AP: PackedAir>( if structured_air { verify_structured_columns( verifier_state, - table.n_columns(), - univariate_skips, - &inner_sums, - &column_groups, - &Evaluation::new( - outer_statement.point[1..log_length - univariate_skips + 1].to_vec(), - outer_statement.value, - ), - &outer_selector_evals, - log_length, + StructuredColumnsArgs { + n_columns: table.n_columns(), + univariate_skips, + all_inner_sums: &inner_sums, + column_groups, + outer_sumcheck_challenge: &Evaluation::new( + outer_statement.point[1..log_length - univariate_skips + 1].to_vec(), + outer_statement.value, + ), + outer_selector_evals: &outer_selector_evals, + log_n_rows: log_length, + }, ) } else { verify_many_unstructured_columns( @@ -181,16 +183,30 @@ fn verify_many_unstructured_columns>>( Ok(evaluations_remaining_to_verify) } -fn verify_structured_columns>>( - verifier_state: &mut FSVerifier>, +#[derive(Debug)] +struct StructuredColumnsArgs<'a, EF> { n_columns: usize, univariate_skips: usize, - all_inner_sums: &[EF], - column_groups: &[Range], - outer_sumcheck_challenge: &Evaluation, - outer_selector_evals: &[EF], + all_inner_sums: &'a [EF], + column_groups: &'a [Range], + outer_sumcheck_challenge: &'a Evaluation, + outer_selector_evals: &'a [EF], log_n_rows: usize, +} + +fn verify_structured_columns>>( + verifier_state: &mut FSVerifier>, + args: StructuredColumnsArgs<'_, EF>, ) -> Result>, ProofError> { + let StructuredColumnsArgs { + n_columns, + univariate_skips, + all_inner_sums, + column_groups, + outer_sumcheck_challenge, + outer_selector_evals, + log_n_rows, + } = args; let log_n_groups = log2_ceil_usize(column_groups.len()); let max_columns_per_group = Iterator::max(column_groups.iter().map(|g| g.len())).unwrap(); let log_max_columns_per_group = log2_ceil_usize(max_columns_per_group); @@ -201,9 +217,13 @@ fn verify_structured_columns>>( let mut column_scalars = vec![]; let mut index = 0; for group in column_groups { - for i in index..index + group.len() { - column_scalars.push(poly_eq_batching_scalars[i]); - } + column_scalars.extend( + poly_eq_batching_scalars + .iter() + .skip(index) + .take(group.len()) + .copied(), + ); index += max_columns_per_group.next_power_of_two(); } diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 193b52df..f3ce3113 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -460,8 +460,10 @@ fn simplify_lines( unimplemented!("Reverse for non-unrolled loops are not implemented yet"); } - let mut loop_const_malloc = ConstMalloc::default(); - loop_const_malloc.counter = const_malloc.counter; + let mut loop_const_malloc = ConstMalloc { + counter: const_malloc.counter, + ..ConstMalloc::default() + }; let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); array_manager.valid.clear(); let simplified_body = simplify_lines( @@ -678,16 +680,15 @@ fn simplify_expr( match expr { Expression::Value(value) => value.simplify_if_const(), Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - if let Some(label) = const_malloc.map.get(array_var) { - if let Ok(mut offset) = ConstExpression::try_from(*index.clone()) { - offset = offset.try_naive_simplification(); - return SimpleExpr::ConstMallocAccess { - malloc_label: *label, - offset, - }; - } - } + if let SimpleExpr::Var(array_var) = array + && let Some(label) = const_malloc.map.get(array_var) + && let Ok(mut offset) = ConstExpression::try_from(*index.clone()) + { + offset = offset.try_naive_simplification(); + return SimpleExpr::ConstMallocAccess { + malloc_label: *label, + offset, + }; } let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index] @@ -1082,30 +1083,27 @@ fn handle_array_assignment( ) { let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc); - if let SimpleExpr::Constant(offset) = simplified_index.clone() { - if let SimpleExpr::Var(array_var) = &array { - if let Some(label) = const_malloc.map.get(array_var) { - if let ArrayAccessType::ArrayIsAssigned(Expression::Binary { - left, - operation, - right, - }) = access_type - { - let arg0 = simplify_expr(&left, res, counters, array_manager, const_malloc); - let arg1 = simplify_expr(&right, res, counters, array_manager, const_malloc); - res.push(SimpleLine::Assignment { - var: VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - }, - operation, - arg0, - arg1, - }); - return; - } - } - } + if let SimpleExpr::Constant(offset) = simplified_index.clone() + && let SimpleExpr::Var(array_var) = &array + && let Some(label) = const_malloc.map.get(array_var) + && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { + left, + operation, + right, + }) = &access_type + { + let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc); + let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc); + res.push(SimpleLine::Assignment { + var: VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *label, + offset, + }, + operation: *operation, + arg0, + arg1, + }); + return; } let value_simplified = match access_type { diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index ea49e88d..8b80f685 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -47,7 +47,7 @@ impl Compiler { } impl SimpleExpr { - fn into_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant { + fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant { match self { Self::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), @@ -368,7 +368,7 @@ fn compile_lines( } SimpleLine::RawAccess { res, index, shift } => { - validate_vars_declared(&[index.clone()], declared_vars)?; + validate_vars_declared(std::slice::from_ref(index), declared_vars)?; if let SimpleExpr::Var(var) = res { declared_vars.insert(var.clone()); } @@ -379,7 +379,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Deref { shift_0, shift_1: shift.clone(), - res: res.into_mem_after_fp_or_constant(compiler), + res: res.to_mem_after_fp_or_constant(compiler), }); } @@ -623,10 +623,10 @@ fn validate_vars_declared>( declared: &BTreeSet, ) -> Result<(), String> { for voc in vocs { - if let SimpleExpr::Var(v) = voc.borrow() { - if !declared.contains(v) { - return Err(format!("Variable {v} not declared")); - } + if let SimpleExpr::Var(v) = voc.borrow() + && !declared.contains(v) + { + return Err(format!("Variable {v} not declared")); } } Ok(()) @@ -665,7 +665,7 @@ fn setup_function_call( instructions.push(IntermediateInstruction::Deref { shift_0: new_fp_pos.into(), shift_1: (2 + i).into(), - res: arg.into_mem_after_fp_or_constant(compiler), + res: arg.to_mem_after_fp_or_constant(compiler), }); } diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index b59035e5..8b4787bc 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -152,23 +152,23 @@ fn compile_block( mut arg_c, res, } => { - if let Some(arg_a_cst) = try_as_constant(&arg_a, compiler) { - if let Some(arg_b_cst) = try_as_constant(&arg_c, compiler) { - // res = constant +/x constant + if let Some(arg_a_cst) = try_as_constant(&arg_a, compiler) + && let Some(arg_b_cst) = try_as_constant(&arg_c, compiler) + { + // res = constant +/x constant - let op_res = operation.compute(arg_a_cst, arg_b_cst); + let op_res = operation.compute(arg_a_cst, arg_b_cst); - let res: MemOrFp = res.try_into_mem_or_fp(compiler).unwrap(); + let res: MemOrFp = res.try_into_mem_or_fp(compiler).unwrap(); - low_level_bytecode.push(Instruction::Computation { - operation: Operation::Add, - arg_a: MemOrConstant::zero(), - arg_c: res, - res: MemOrConstant::Constant(op_res), - }); - pc += 1; - continue; - } + low_level_bytecode.push(Instruction::Computation { + operation: Operation::Add, + arg_a: MemOrConstant::zero(), + arg_c: res, + res: MemOrConstant::Constant(op_res), + }); + pc += 1; + continue; } if arg_c.is_constant() { diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs index a0e63849..3ef9460f 100644 --- a/crates/lean_prover/src/common.rs +++ b/crates/lean_prover/src/common.rs @@ -18,12 +18,12 @@ pub fn get_base_dims( log_public_memory: usize, private_memory_len: usize, bytecode_ending_pc: usize, - n_poseidons_16: usize, - n_poseidons_24: usize, - p16_air_width: usize, - p24_air_width: usize, + poseidon_counts: (usize, usize), + poseidon_widths: (usize, usize), n_rows_table_dot_products: usize, ) -> Vec> { + let (n_poseidons_16, n_poseidons_24) = poseidon_counts; + let (p16_air_width, p24_air_width) = poseidon_widths; let (default_p16_row, default_p24_row) = build_poseidon_columns( &[WitnessPoseidon16::poseidon_of_zero()], &[WitnessPoseidon24::poseidon_of_zero()], diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 4125b12e..f5d7f671 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -211,10 +211,8 @@ pub fn prove_execution( log_public_memory, private_memory.len(), bytecode.ending_pc, - n_poseidons_16, - n_poseidons_24, - p16_air.width(), - p24_air.width(), + (n_poseidons_16, n_poseidons_24), + (p16_air.width(), p24_air.width()), n_rows_table_dot_products, ); @@ -810,12 +808,14 @@ pub fn prove_execution( let index_a: F = dot_product_columns[2][i].as_base().unwrap(); let index_b: F = dot_product_columns[3][i].as_base().unwrap(); let index_res: F = dot_product_columns[4][i].as_base().unwrap(); - for j in 0..DIMENSION { - dot_product_indexes_spread[j][i] = index_a + F::from_usize(j); - dot_product_indexes_spread[j][i + dot_product_table_length] = - index_b + F::from_usize(j); - dot_product_indexes_spread[j][i + 2 * dot_product_table_length] = - index_res + F::from_usize(j); + for (j, column) in dot_product_indexes_spread + .iter_mut() + .enumerate() + .take(DIMENSION) + { + column[i] = index_a + F::from_usize(j); + column[i + dot_product_table_length] = index_b + F::from_usize(j); + column[i + 2 * dot_product_table_length] = index_res + F::from_usize(j); } } let dot_product_values_spread = dot_product_indexes_spread @@ -992,7 +992,7 @@ pub fn prove_execution( &second_batched_whir_config_builder::( whir_config_builder.clone(), log2_strict_usize(packed_pcs_witness_base.packed_polynomial.len()), - num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), + num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), ), &extension_pols, &extension_dims, diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index e8ae8eb7..b69489fd 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -174,10 +174,8 @@ pub fn verify_execution( log_public_memory, private_memory_len, bytecode.ending_pc, - n_poseidons_16, - n_poseidons_24, - p16_air.width(), - p24_air.width(), + (n_poseidons_16, n_poseidons_24), + (p16_air.width(), p24_air.width()), n_rows_table_dot_products, ); @@ -548,7 +546,7 @@ pub fn verify_execution( &second_batched_whir_config_builder::( whir_config_builder.clone(), parsed_commitment_base.num_variables, - num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), + num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), ), &mut verifier_state, &extension_dims, @@ -859,8 +857,12 @@ pub fn verify_execution( ); let mut dot_product_indexes_inner_evals_incr = vec![EF::ZERO; 8]; - for i in 0..DIMENSION { - dot_product_indexes_inner_evals_incr[i] = dot_product_logup_star_indexes_inner_value + for (i, value) in dot_product_indexes_inner_evals_incr + .iter_mut() + .enumerate() + .take(DIMENSION) + { + *value = dot_product_logup_star_indexes_inner_value + EF::from_usize(i) * [F::ONE, F::ONE, F::ONE, F::ZERO].evaluate(&MultilinearPoint( mem_lookup_eval_indexes_partial_point.0[3 + index_diff..5 + index_diff] diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index 02d651b4..f803ae24 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -267,10 +267,13 @@ pub fn get_execution_trace( } // repeat the last row to get to a power of two - for j in 0..N_INSTRUCTION_COLUMNS + N_EXEC_COLUMNS { - let last_value = trace[j][n_cycles - 1]; - for i in n_cycles..(1 << log_n_cycles_rounded_up) { - trace[j][i] = last_value; + for column in trace + .iter_mut() + .take(N_INSTRUCTION_COLUMNS + N_EXEC_COLUMNS) + { + let last_value = column[n_cycles - 1]; + for cell in column.iter_mut().skip(n_cycles) { + *cell = last_value; } } diff --git a/crates/lean_vm/src/memory.rs b/crates/lean_vm/src/memory.rs index 91a72f5b..a528244c 100644 --- a/crates/lean_vm/src/memory.rs +++ b/crates/lean_vm/src/memory.rs @@ -51,8 +51,8 @@ impl Memory { pub fn get_ef_element(&self, index: usize) -> Result { // index: non vectorized pointer let mut coeffs = [F::ZERO; DIMENSION]; - for i in 0..DIMENSION { - coeffs[i] = self.get(index + i)?; + for (offset, coeff) in coeffs.iter_mut().enumerate() { + *coeff = self.get(index + offset)?; } Ok(EF::from_basis_coefficients_slice(&coeffs).unwrap()) } diff --git a/crates/lean_vm/src/runner.rs b/crates/lean_vm/src/runner.rs index ccfacb6a..ee366261 100644 --- a/crates/lean_vm/src/runner.rs +++ b/crates/lean_vm/src/runner.rs @@ -99,14 +99,16 @@ pub fn execute_bytecode( let mut instruction_history = ExecutionHistory::default(); let first_exec = match execute_bytecode_helper( bytecode, - public_input, - private_input, - MAX_MEMORY_SIZE / 2, - false, - &mut std_out, - &mut instruction_history, - false, - function_locations, + ExecuteBytecodeParams { + public_input, + private_input, + no_vec_runtime_memory: MAX_MEMORY_SIZE / 2, + final_execution: false, + std_out: &mut std_out, + instruction_history: &mut instruction_history, + profiler: false, + function_locations, + }, ) { Ok(first_exec) => first_exec, Err(err) => { @@ -129,14 +131,16 @@ pub fn execute_bytecode( instruction_history = ExecutionHistory::default(); execute_bytecode_helper( bytecode, - public_input, - private_input, - first_exec.no_vec_runtime_memory, - true, - &mut String::new(), - &mut instruction_history, - profiler, - function_locations, + ExecuteBytecodeParams { + public_input, + private_input, + no_vec_runtime_memory: first_exec.no_vec_runtime_memory, + final_execution: true, + std_out: &mut String::new(), + instruction_history: &mut instruction_history, + profiler, + function_locations, + }, ) .unwrap() } @@ -157,14 +161,24 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { public_memory[PUBLIC_INPUT_START..][..public_input.len()].copy_from_slice(public_input); // "zero" vector - for i in ZERO_VEC_PTR * VECTOR_LEN..(ZERO_VEC_PTR + 2) * VECTOR_LEN { - public_memory[i] = F::ZERO; + let zero_start = ZERO_VEC_PTR * VECTOR_LEN; + for slot in public_memory + .iter_mut() + .skip(zero_start) + .take(2 * VECTOR_LEN) + { + *slot = F::ZERO; } // "one" vector public_memory[ONE_VEC_PTR * VECTOR_LEN] = F::ONE; - for i in ONE_VEC_PTR * VECTOR_LEN + 1..(ONE_VEC_PTR + 1) * VECTOR_LEN { - public_memory[i] = F::ZERO; + let one_start = ONE_VEC_PTR * VECTOR_LEN + 1; + for slot in public_memory + .iter_mut() + .skip(one_start) + .take(VECTOR_LEN - 1) + { + *slot = F::ZERO; } public_memory @@ -176,17 +190,32 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { public_memory } -fn execute_bytecode_helper( - bytecode: &Bytecode, - public_input: &[F], - private_input: &[F], +#[derive(Debug)] +struct ExecuteBytecodeParams<'a> { + public_input: &'a [F], + private_input: &'a [F], no_vec_runtime_memory: usize, final_execution: bool, - std_out: &mut String, - instruction_history: &mut ExecutionHistory, + std_out: &'a mut String, + instruction_history: &'a mut ExecutionHistory, profiler: bool, - function_locations: &BTreeMap, + function_locations: &'a BTreeMap, +} + +fn execute_bytecode_helper( + bytecode: &Bytecode, + params: ExecuteBytecodeParams<'_>, ) -> Result { + let ExecuteBytecodeParams { + public_input, + private_input, + no_vec_runtime_memory, + final_execution, + std_out, + instruction_history, + profiler, + function_locations, + } = params; let poseidon_16 = get_poseidon16(); let poseidon_24 = get_poseidon24(); diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs index 0c3259ee..6ff6076d 100644 --- a/crates/lookup/src/quotient_gkr.rs +++ b/crates/lookup/src/quotient_gkr.rs @@ -215,7 +215,7 @@ where fn prove_gkr_quotient_step_packed( prover_state: &mut FSProver>, - up_layer_packed: &Vec>, + up_layer_packed: &[EFPacking], claim: &Evaluation, ) -> (Evaluation, EF, EF) where diff --git a/crates/packed_pcs/src/lib.rs b/crates/packed_pcs/src/lib.rs index 74328071..f3c82aba 100644 --- a/crates/packed_pcs/src/lib.rs +++ b/crates/packed_pcs/src/lib.rs @@ -146,7 +146,7 @@ fn split_in_chunks( } } -fn compute_chunks>( +fn compute_chunks( dims: &[ColDims], log_smallest_decomposition_chunk: usize, ) -> (BTreeMap>, usize) { @@ -178,11 +178,11 @@ fn compute_chunks>( (chunks_decomposition, packed_n_vars) } -pub fn num_packed_vars_for_dims>( +pub fn num_packed_vars_for_dims( dims: &[ColDims], log_smallest_decomposition_chunk: usize, ) -> usize { - let (_, packed_n_vars) = compute_chunks::(dims, log_smallest_decomposition_chunk); + let (_, packed_n_vars) = compute_chunks::(dims, log_smallest_decomposition_chunk); packed_n_vars } @@ -193,7 +193,7 @@ pub struct MultiCommitmentWitness> { } #[instrument(skip_all)] -pub fn packed_pcs_commit, H, C>( +pub fn packed_pcs_commit( whir_config_builder: &WhirConfigBuilder, polynomials: &[&[F]], dims: &[ColDims], @@ -202,9 +202,9 @@ pub fn packed_pcs_commit, H, C>( log_smallest_decomposition_chunk: usize, ) -> MultiCommitmentWitness where + F: Field + TwoAdicField + ExtensionField>, PF: TwoAdicField, EF: ExtensionField + TwoAdicField + ExtensionField>, - F: TwoAdicField + ExtensionField>, H: MerkleHasher, C: MerkleCompress, [PF; MY_DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, @@ -221,7 +221,7 @@ where ); } let (chunks_decomposition, packed_n_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + compute_chunks::(dims, log_smallest_decomposition_chunk); { // logging @@ -292,17 +292,16 @@ pub fn packed_pcs_global_statements_for_prover< // - current packing is not optimal in the end: can lead to [16][4][2][2] (instead of [16][8]) let (chunks_decomposition, packed_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + compute_chunks::(dims, log_smallest_decomposition_chunk); let statements_flattened = statements_per_polynomial .iter() .enumerate() - .map(|(poly_index, poly_statements)| { + .flat_map(|(poly_index, poly_statements)| { poly_statements .iter() .map(move |statement| (poly_index, statement)) }) - .flatten() .collect::>(); let sub_packed_statements_and_evals_to_send = statements_flattened @@ -311,7 +310,9 @@ pub fn packed_pcs_global_statements_for_prover< let dim = &dims[*poly_index]; let pol = polynomials[*poly_index]; - let chunks = &chunks_decomposition[&poly_index]; + let chunks = chunks_decomposition + .get(poly_index) + .expect("missing chunk definition for polynomial"); assert!(!chunks.is_empty()); let mut sub_packed_statements = Vec::new(); let mut evals_to_send = Vec::new(); @@ -352,14 +353,14 @@ pub fn packed_pcs_global_statements_for_prover< if !initial_booleans.is_empty() && initial_booleans.len() < offset_in_original_booleans.len() - && &initial_booleans - == &offset_in_original_booleans[..initial_booleans.len()] + && initial_booleans + == offset_in_original_booleans[..initial_booleans.len()] { tracing::warn!("TODO: sparse statement accroos mutiple chunks"); } if initial_booleans.len() >= offset_in_original_booleans.len() { - if &initial_booleans[..missing_vars] != &offset_in_original_booleans { + if initial_booleans[..missing_vars] != offset_in_original_booleans { // this chunk is not concerned by this sparse evaluation return (None, EF::ZERO); } else { @@ -440,7 +441,7 @@ pub fn packed_pcs_parse_commitment< dims: &[ColDims], log_smallest_decomposition_chunk: usize, ) -> Result, ProofError> { - let (_, packed_n_vars) = compute_chunks::(&dims, log_smallest_decomposition_chunk); + let (_, packed_n_vars) = compute_chunks::(dims, log_smallest_decomposition_chunk); WhirConfig::new(whir_config_builder.clone(), packed_n_vars).parse_commitment(verifier_state) } @@ -456,12 +457,14 @@ pub fn packed_pcs_global_statements_for_verifier< ) -> Result>, ProofError> { assert_eq!(dims.len(), statements_per_polynomial.len()); let (chunks_decomposition, packed_n_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + compute_chunks::(dims, log_smallest_decomposition_chunk); let mut packed_statements = Vec::new(); for (poly_index, statements) in statements_per_polynomial.iter().enumerate() { let dim = &dims[poly_index]; let has_public_data = dim.log_public.is_some(); - let chunks = &chunks_decomposition[&poly_index]; + let chunks = chunks_decomposition + .get(&poly_index) + .expect("missing chunk definition for polynomial"); assert!(!chunks.is_empty()); for statement in statements { if chunks.len() == 1 { @@ -494,7 +497,7 @@ pub fn packed_pcs_global_statements_for_verifier< to_big_endian_bits(chunk.offset_in_original >> chunk.n_vars, missing_vars); if initial_booleans.len() >= offset_in_original_booleans.len() { - if &initial_booleans[..missing_vars] != &offset_in_original_booleans { + if initial_booleans[..missing_vars] != offset_in_original_booleans { // this chunk is not concerned by this sparse evaluation sub_values.push(EF::ZERO); } else { diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index a7df67da..e213679a 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -224,9 +224,7 @@ fn test_xmss_aggregate() { &xmss_signature_size_padded.to_string(), ); - let bitfield = (0..n_public_keys) - .map(|i| i % INV_BITFIELD_DENSITY == 0) - .collect::>(); + let bitfield = vec![true; n_public_keys]; let mut rng = StdRng::seed_from_u64(0); let message_hash: [F; 8] = rng.random(); diff --git a/crates/sumcheck/src/mle.rs b/crates/sumcheck/src/mle.rs index 42ad758c..bef7e780 100644 --- a/crates/sumcheck/src/mle.rs +++ b/crates/sumcheck/src/mle.rs @@ -306,21 +306,25 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { } } - pub fn sumcheck_compute( + pub fn sumcheck_compute<'params, SC, SCP>( &self, - zs: &[usize], - skips: usize, - eq_mle: Option<&Mle>, - folding_scalars: &[Vec>], - computation: &SC, - computation_packed: &SCP, - batching_scalars: &[EF], - missing_mul_factor: Option, + params: SumcheckComputeParams<'params, EF, SC, SCP>, ) -> Vec<(PF, EF)> where SC: SumcheckComputation, EF> + SumcheckComputation, SCP: SumcheckComputationPacked, { + let SumcheckComputeParams { + zs, + skips, + eq_mle, + folding_scalars, + computation, + computation_packed, + batching_scalars, + missing_mul_factor, + } = params; + let fold_size = 1 << (self.n_vars() - skips); let packed_fold_size = if self.is_packed() { fold_size / packing_width::() @@ -441,52 +445,87 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { let eq_mle = eq_mle.map(|eq_mle| eq_mle.as_extension().unwrap().as_slice()); sumcheck_compute_not_packed( multilinears, - zs, - skips, - eq_mle, - folding_scalars, - computation, - batching_scalars, - missing_mul_factor, - fold_size, + SumcheckComputeNotPackedParams { + zs, + skips, + eq_mle, + folding_scalars, + computation, + batching_scalars, + missing_mul_factor, + fold_size, + }, ) } Self::Extension(multilinears) => { let eq_mle = eq_mle.map(|eq_mle| eq_mle.as_extension().unwrap().as_slice()); sumcheck_compute_not_packed( multilinears, - zs, - skips, - eq_mle, - folding_scalars, - computation, - batching_scalars, - missing_mul_factor, - fold_size, + SumcheckComputeNotPackedParams { + zs, + skips, + eq_mle, + folding_scalars, + computation, + batching_scalars, + missing_mul_factor, + fold_size, + }, ) } } } } +#[derive(Debug)] +pub struct SumcheckComputeParams<'a, EF: ExtensionField>, SC, SCP> { + pub zs: &'a [usize], + pub skips: usize, + pub eq_mle: Option<&'a Mle>, + pub folding_scalars: &'a [Vec>], + pub computation: &'a SC, + pub computation_packed: &'a SCP, + pub batching_scalars: &'a [EF], + pub missing_mul_factor: Option, +} + +#[derive(Debug)] +pub struct SumcheckComputeNotPackedParams<'a, EF, SC> +where + EF: ExtensionField>, +{ + pub zs: &'a [usize], + pub skips: usize, + pub eq_mle: Option<&'a [EF]>, + pub folding_scalars: &'a [Vec>], + pub computation: &'a SC, + pub batching_scalars: &'a [EF], + pub missing_mul_factor: Option, + pub fold_size: usize, +} + pub fn sumcheck_compute_not_packed< EF: ExtensionField> + ExtensionField, IF: ExtensionField>, SC, >( multilinears: &[&[IF]], - zs: &[usize], - skips: usize, - eq_mle: Option<&[EF]>, - folding_scalars: &[Vec>], - computation: &SC, - batching_scalars: &[EF], - missing_mul_factor: Option, - fold_size: usize, + params: SumcheckComputeNotPackedParams<'_, EF, SC>, ) -> Vec<(PF, EF)> where SC: SumcheckComputation, { + let SumcheckComputeNotPackedParams { + zs, + skips, + eq_mle, + folding_scalars, + computation, + batching_scalars, + missing_mul_factor, + fold_size, + } = params; + let all_sums = unsafe { uninitialized_vec::(zs.len() * fold_size) }; // sums for zs[0], sums for zs[1], ... (0..fold_size).into_par_iter().for_each(|i| { let eq_mle_eval = eq_mle.as_ref().map(|eq_mle| eq_mle[i]); diff --git a/crates/sumcheck/src/prove.rs b/crates/sumcheck/src/prove.rs index f618d6da..2fa946ba 100644 --- a/crates/sumcheck/src/prove.rs +++ b/crates/sumcheck/src/prove.rs @@ -14,6 +14,7 @@ use crate::Mle; use crate::MleGroup; use crate::SumcheckComputation; use crate::SumcheckComputationPacked; +use crate::SumcheckComputeParams; #[allow(clippy::too_many_arguments)] pub fn prove<'a, EF, SC, SCP, M: Into>>( @@ -151,16 +152,20 @@ where }) .collect::>>>(); - p_evals.extend(multilinears.by_ref().sumcheck_compute( - &zs, - skips, - eq_factor.as_ref().map(|(_, eq_mle)| eq_mle), - &folding_scalars, - computation, - computations_packed, - batching_scalars, - missing_mul_factor, - )); + p_evals.extend( + multilinears + .by_ref() + .sumcheck_compute(SumcheckComputeParams { + zs: &zs, + skips, + eq_mle: eq_factor.as_ref().map(|(_, eq_mle)| eq_mle), + folding_scalars: &folding_scalars, + computation, + computation_packed: computations_packed, + batching_scalars, + missing_mul_factor, + }), + ); if !is_zerofier { let missing_sum_z = if let Some((eq_factor, _)) = eq_factor { diff --git a/crates/utils/src/display.rs b/crates/utils/src/display.rs index b5faf93f..df4451e9 100644 --- a/crates/utils/src/display.rs +++ b/crates/utils/src/display.rs @@ -5,7 +5,7 @@ pub fn pretty_integer(i: usize) -> String { let mut result = String::new(); for (index, ch) in chars.iter().enumerate() { - if index > 0 && (chars.len() - index) % 3 == 0 { + if index > 0 && (chars.len() - index).is_multiple_of(3) { result.push(','); } result.push(*ch); diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index ecfe6051..9eba2d28 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -30,25 +30,25 @@ pub const fn diff_to_next_power_of_two(n: usize) -> usize { } pub fn left_mut(slice: &mut [A]) -> &mut [A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &mut slice[..mid] } pub fn right_mut(slice: &mut [A]) -> &mut [A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &mut slice[mid..] } pub fn left_ref(slice: &[A]) -> &[A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &slice[..mid] } pub fn right_ref(slice: &[A]) -> &[A] { - assert!(slice.len() % 2 == 0); + assert!(slice.len().is_multiple_of(2)); let mid = slice.len() / 2; &slice[mid..] } diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 6365ca22..1d151c4f 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -21,7 +21,7 @@ pub fn fold_multilinear_in_small_field, D>( let dim = >::DIMENSION; let m_transmuted: &[F] = - unsafe { std::slice::from_raw_parts(std::mem::transmute(m.as_ptr()), m.len() * dim) }; + unsafe { std::slice::from_raw_parts(m.as_ptr().cast::(), m.len() * dim) }; let res_transmuted = { let new_size = m.len() * dim / scalars.len(); @@ -183,19 +183,17 @@ pub fn multilinear_eval_constants_at_right(limit: usize, point: &[F]) return F::ZERO; } - if point.len() == 0 { + if point.is_empty() { assert!(limit <= 1); if limit == 1 { F::ZERO } else { F::ONE } } else { let main_bit = limit >> (n_vars - 1); if main_bit == 1 { // limit is at the right half - return point[0] - * multilinear_eval_constants_at_right(limit - (1 << (n_vars - 1)), &point[1..]); + point[0] * multilinear_eval_constants_at_right(limit - (1 << (n_vars - 1)), &point[1..]) } else { // limit is at left half - return point[0] - + (F::ONE - point[0]) * multilinear_eval_constants_at_right(limit, &point[1..]); + point[0] + (F::ONE - point[0]) * multilinear_eval_constants_at_right(limit, &point[1..]) } } } @@ -279,9 +277,9 @@ mod tests { let n_point_vars = 7; let mut rng = StdRng::seed_from_u64(0); let mut pol = F::zero_vec(1 << n_point_vars); - for i in 0..(1 << n_vars) { - pol[i] = rng.random(); - } + pol.iter_mut() + .take(1 << n_vars) + .for_each(|coeff| *coeff = rng.random()); let point = (0..n_point_vars).map(|_| rng.random()).collect::>(); assert_eq!( evaluate_as_larger_multilinear_pol(&pol[..1 << n_vars], &point), @@ -297,9 +295,10 @@ mod tests { for limit in [0, 1, 2, 45, 74, 451, 741, 1022, 1023] { let eval = multilinear_eval_constants_at_right(limit, &point); let mut pol = F::zero_vec(1 << n_vars); - for i in limit..(1 << n_vars) { - pol[i] = F::ONE; - } + pol.iter_mut() + .take(1 << n_vars) + .skip(limit) + .for_each(|coeff| *coeff = F::ONE); assert_eq!(eval, pol.evaluate(&MultilinearPoint(point.clone()))); } } diff --git a/crates/utils/src/univariate.rs b/crates/utils/src/univariate.rs index ddace427..bb60085b 100644 --- a/crates/utils/src/univariate.rs +++ b/crates/utils/src/univariate.rs @@ -7,10 +7,10 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; type CacheKey = (TypeId, usize); +type CacheValue = Arc>>; +type SelectorsCache = Mutex>; -static SELECTORS_CACHE: OnceLock< - Mutex>>>>, -> = OnceLock::new(); +static SELECTORS_CACHE: OnceLock = OnceLock::new(); pub fn univariate_selectors(n: usize) -> Arc>> { let key = (TypeId::of::(), n); diff --git a/src/examples/prove_poseidon2.rs b/src/examples/prove_poseidon2.rs index bb380ba4..86eacd27 100644 --- a/src/examples/prove_poseidon2.rs +++ b/src/examples/prove_poseidon2.rs @@ -5,13 +5,20 @@ use p3_field::PrimeField64; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use p3_symmetric::Permutation; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use packed_pcs::*; +use packed_pcs::{ + ColDims, packed_pcs_commit, packed_pcs_global_statements_for_prover, + packed_pcs_global_statements_for_verifier, packed_pcs_parse_commitment, +}; use rand::{Rng, SeedableRng, rngs::StdRng}; +use std::collections::BTreeMap; use std::fmt; +use std::ops::Range; use std::time::{Duration, Instant}; use utils::{ - build_merkle_compress, build_merkle_hash, build_poseidon_16_air, build_poseidon_16_air_packed, - build_poseidon_24_air, build_poseidon_24_air_packed, build_prover_state, build_verifier_state, + FSProver, MY_DIGEST_ELEMS, MyChallenger, MyMerkleCompress, MyMerkleHash, MyWhirConfigBuilder, + PF, PFPacking, Poseidon16Air, Poseidon24Air, build_merkle_compress, build_merkle_hash, + build_poseidon_16_air, build_poseidon_16_air_packed, build_poseidon_24_air, + build_poseidon_24_air_packed, build_prover_state, build_verifier_state, generate_trace_poseidon_16, generate_trace_poseidon_24, get_poseidon16, get_poseidon24, init_tracing, padd_with_zero_to_next_power_of_two, }; @@ -20,6 +27,7 @@ use whir_p3::whir::config::{FoldingFactor, SecurityAssumption, WhirConfig, WhirC type F = KoalaBear; type EF = QuinticExtensionFieldKB; +type MyWhirConfig = WhirConfig, EF, MyMerkleHash, MyMerkleCompress, MY_DIGEST_ELEMS>; #[derive(Clone, Debug)] pub struct Poseidon2Benchmark { @@ -38,7 +46,7 @@ impl fmt::Display for Poseidon2Benchmark { 1 << self.log_n_poseidons_16, 1 << self.log_n_poseidons_24, self.prover_time.as_millis() as f64 / 1000.0, - (((1 << self.log_n_poseidons_16) + (1 << self.log_n_poseidons_24)) as f64 + (f64::from((1 << self.log_n_poseidons_16) + (1 << self.log_n_poseidons_24)) / self.prover_time.as_secs_f64()) .round() as usize )?; @@ -51,25 +59,43 @@ impl fmt::Display for Poseidon2Benchmark { } } -pub fn prove_poseidon2( - log_n_poseidons_16: usize, - log_n_poseidons_24: usize, - univariate_skips: usize, - folding_factor: FoldingFactor, - log_inv_rate: usize, - soundness_type: SecurityAssumption, - pow_bits: usize, - security_level: usize, - rs_domain_initial_reduction_factor: usize, - max_num_variables_to_send_coeffs: usize, - display_logs: bool, -) -> Poseidon2Benchmark { - if display_logs { - init_tracing(); - } +#[derive(Clone, Debug)] +pub struct Poseidon2Config { + pub log_n_poseidons_16: usize, + pub log_n_poseidons_24: usize, + pub univariate_skips: usize, + pub folding_factor: FoldingFactor, + pub log_inv_rate: usize, + pub soundness_type: SecurityAssumption, + pub pow_bits: usize, + pub security_level: usize, + pub rs_domain_initial_reduction_factor: usize, + pub max_num_variables_to_send_coeffs: usize, + pub display_logs: bool, +} + +struct PoseidonSetup { + n_columns_24: usize, + log_table_area_16: usize, + log_table_area_24: usize, + witness_columns_16: Vec>, + witness_columns_24: Vec>, + column_groups_16: Vec>, + column_groups_24: Vec>, + table_16: AirTable, Poseidon16Air>>, + table_24: AirTable, Poseidon24Air>>, +} + +struct ProverArtifacts { + prover_time: Duration, + whir_config_builder: MyWhirConfigBuilder, + whir_config: MyWhirConfig, + dims: [ColDims; 2], +} - let n_poseidons_16 = 1 << log_n_poseidons_16; - let n_poseidons_24 = 1 << log_n_poseidons_24; +fn prepare_poseidon(config: &Poseidon2Config) -> PoseidonSetup { + let n_poseidons_16 = 1 << config.log_n_poseidons_16; + let n_poseidons_24 = 1 << config.log_n_poseidons_24; let poseidon_air_16 = build_poseidon_16_air(); let poseidon_air_16_packed = build_poseidon_16_air_packed(); @@ -78,8 +104,8 @@ pub fn prove_poseidon2( let n_columns_16 = poseidon_air_16.width(); let n_columns_24 = poseidon_air_24.width(); - let log_table_area_16 = log_n_poseidons_16 + log2_ceil_usize(n_columns_16); - let log_table_area_24 = log_n_poseidons_24 + log2_ceil_usize(n_columns_24); + let log_table_area_16 = config.log_n_poseidons_16 + log2_ceil_usize(n_columns_16); + let log_table_area_24 = config.log_n_poseidons_24 + log2_ceil_usize(n_columns_24); let mut rng = StdRng::seed_from_u64(0); let inputs_16: Vec<[F; 16]> = (0..n_poseidons_16).map(|_| Default::default()).collect(); @@ -114,76 +140,103 @@ pub fn prove_poseidon2( .to_vec() }) .collect::>(); - let column_groups_16 = vec![0..n_columns_16]; - let column_groups_24 = vec![0..n_columns_24]; - let witness_16 = AirWitness::new(&witness_columns_16, &column_groups_16); - let witness_24 = AirWitness::new(&witness_columns_24, &column_groups_24); - - let table_16 = AirTable::::new(poseidon_air_16, poseidon_air_16_packed); - let table_24 = AirTable::::new(poseidon_air_24, poseidon_air_24_packed); - - let t = Instant::now(); + let column_groups_16 = vec![Range { + start: 0, + end: n_columns_16, + }]; + let column_groups_24 = vec![Range { + start: 0, + end: n_columns_24, + }]; + + let table_16: AirTable, Poseidon16Air>> = + AirTable::new(poseidon_air_16, poseidon_air_16_packed); + let table_24: AirTable, Poseidon24Air>> = + AirTable::new(poseidon_air_24, poseidon_air_24_packed); + + PoseidonSetup { + n_columns_24, + log_table_area_16, + log_table_area_24, + witness_columns_16, + witness_columns_24, + column_groups_16, + column_groups_24, + table_16, + table_24, + } +} - let mut prover_state = build_prover_state(); +fn run_prover_phase( + config: &Poseidon2Config, + setup: &PoseidonSetup, + witness_16: AirWitness<'_, F>, + witness_24: AirWitness<'_, F>, + prover_state: &mut FSProver, +) -> ProverArtifacts { + let start = Instant::now(); let whir_config_builder = WhirConfigBuilder { - folding_factor, - soundness_type, + folding_factor: config.folding_factor, + soundness_type: config.soundness_type, merkle_hash: build_merkle_hash(), merkle_compress: build_merkle_compress(), - pow_bits, - max_num_variables_to_send_coeffs, - rs_domain_initial_reduction_factor, - security_level, - starting_log_inv_rate: log_inv_rate, + pow_bits: config.pow_bits, + max_num_variables_to_send_coeffs: config.max_num_variables_to_send_coeffs, + rs_domain_initial_reduction_factor: config.rs_domain_initial_reduction_factor, + security_level: config.security_level, + starting_log_inv_rate: config.log_inv_rate, }; - // let pcs = RingSwitching::::new(pcs); let dft = EvalsDft::new( - 1 << (log2_ceil_usize(n_columns_24) - + log_n_poseidons_16.max(log_n_poseidons_24) - + log_inv_rate + 1 << (log2_ceil_usize(setup.n_columns_24) + + config.log_n_poseidons_16.max(config.log_n_poseidons_24) + + config.log_inv_rate - whir_config_builder.folding_factor.at_round(0)), ); let commited_trace_polynomial_16 = - padd_with_zero_to_next_power_of_two(&witness_columns_16.concat()); + padd_with_zero_to_next_power_of_two(&setup.witness_columns_16.concat()); let commited_trace_polynomial_24 = - padd_with_zero_to_next_power_of_two(&witness_columns_24.concat()); + padd_with_zero_to_next_power_of_two(&setup.witness_columns_24.concat()); let dims = [ - ColDims::dense(log_table_area_16), - ColDims::dense(log_table_area_24), + ColDims::dense(setup.log_table_area_16), + ColDims::dense(setup.log_table_area_24), ]; - let log_smallest_decomposition_chunk = 0; // UNUSED because verything is power of 2 - - let commited_data = [ + let log_smallest_decomposition_chunk = 0; + let commited_slices = [ commited_trace_polynomial_16.as_slice(), commited_trace_polynomial_24.as_slice(), ]; + let commitment_witness = packed_pcs_commit( &whir_config_builder, - &commited_data, + &commited_slices, &dims, &dft, - &mut prover_state, - 0, + prover_state, + log_smallest_decomposition_chunk, ); let evaluations_remaining_to_prove_16 = - table_16.prove_base(&mut prover_state, univariate_skips, witness_16); + setup + .table_16 + .prove_base(prover_state, config.univariate_skips, witness_16); let evaluations_remaining_to_prove_24 = - table_24.prove_base(&mut prover_state, univariate_skips, witness_24); + setup + .table_24 + .prove_base(prover_state, config.univariate_skips, witness_24); let global_statements_to_prove = packed_pcs_global_statements_for_prover( - &commited_data, + &commited_slices, &dims, log_smallest_decomposition_chunk, &[ evaluations_remaining_to_prove_16, evaluations_remaining_to_prove_24, ], - &mut prover_state, + prover_state, ); let whir_config = WhirConfig::new( whir_config_builder.clone(), @@ -191,54 +244,70 @@ pub fn prove_poseidon2( ); whir_config.prove( &dft, - &mut prover_state, + prover_state, global_statements_to_prove, commitment_witness.inner_witness, &commitment_witness.packed_polynomial, ); - let prover_time = t.elapsed(); - let time = Instant::now(); + ProverArtifacts { + prover_time: start.elapsed(), + whir_config_builder, + whir_config, + dims, + } +} - let mut verifier_state = build_verifier_state(&prover_state); +fn run_verifier_phase( + config: &Poseidon2Config, + setup: &PoseidonSetup, + artifacts: &ProverArtifacts, + prover_state: &FSProver, +) -> Duration { + let start = Instant::now(); + let mut verifier_state = build_verifier_state(prover_state); + let log_smallest_decomposition_chunk = 0; let packed_parsed_commitment = packed_pcs_parse_commitment( - &whir_config_builder, + &artifacts.whir_config_builder, &mut verifier_state, - &dims, + &artifacts.dims, log_smallest_decomposition_chunk, ) .unwrap(); - let evaluations_remaining_to_verify_16 = table_16 + let evaluations_remaining_to_verify_16 = setup + .table_16 .verify( &mut verifier_state, - univariate_skips, - log_n_poseidons_16, - &column_groups_16, + config.univariate_skips, + config.log_n_poseidons_16, + &setup.column_groups_16, ) .unwrap(); - let evaluations_remaining_to_verify_24 = table_24 + let evaluations_remaining_to_verify_24 = setup + .table_24 .verify( &mut verifier_state, - univariate_skips, - log_n_poseidons_24, - &column_groups_24, + config.univariate_skips, + config.log_n_poseidons_24, + &setup.column_groups_24, ) .unwrap(); let global_statements_to_verify = packed_pcs_global_statements_for_verifier( - &dims, + &artifacts.dims, log_smallest_decomposition_chunk, &[ evaluations_remaining_to_verify_16, evaluations_remaining_to_verify_24, ], &mut verifier_state, - &Default::default(), + &BTreeMap::default(), ) .unwrap(); - whir_config + artifacts + .whir_config .verify( &mut verifier_state, &packed_parsed_commitment, @@ -246,14 +315,28 @@ pub fn prove_poseidon2( ) .unwrap(); - let verifier_time = time.elapsed(); + start.elapsed() +} + +pub fn prove_poseidon2(config: &Poseidon2Config) -> Poseidon2Benchmark { + if config.display_logs { + init_tracing(); + } + + let setup = prepare_poseidon(config); + let witness_16 = AirWitness::new(&setup.witness_columns_16, &setup.column_groups_16); + let witness_24 = AirWitness::new(&setup.witness_columns_24, &setup.column_groups_24); + + let mut prover_state = build_prover_state(); + let artifacts = run_prover_phase(config, &setup, witness_16, witness_24, &mut prover_state); + let verifier_time = run_verifier_phase(config, &setup, &artifacts, &prover_state); let proof_size = prover_state.proof_data().len() as f64 * (F::ORDER_U64 as f64).log2() / 8.0; Poseidon2Benchmark { - log_n_poseidons_16, - log_n_poseidons_24, - prover_time, + log_n_poseidons_16: config.log_n_poseidons_16, + log_n_poseidons_24: config.log_n_poseidons_24, + prover_time: artifacts.prover_time, verifier_time, proof_size, } @@ -267,19 +350,20 @@ mod tests { #[test] fn test_prove_poseidon2() { - let benchmark = prove_poseidon2( - 13, - 12, - 4, - FoldingFactor::new(5, 3), - 2, - SecurityAssumption::CapacityBound, - 13, - 128, - 1, - 5, - false, - ); + let config = Poseidon2Config { + log_n_poseidons_16: 13, + log_n_poseidons_24: 12, + univariate_skips: 4, + folding_factor: FoldingFactor::new(5, 3), + log_inv_rate: 2, + soundness_type: SecurityAssumption::CapacityBound, + pow_bits: 13, + security_level: 128, + rs_domain_initial_reduction_factor: 1, + max_num_variables_to_send_coeffs: 5, + display_logs: false, + }; + let benchmark = prove_poseidon2(&config); println!("\n{benchmark}"); } } diff --git a/src/main.rs b/src/main.rs index b8c4929c..00fd38ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,22 +2,23 @@ mod examples; -use crate::examples::prove_poseidon2::prove_poseidon2; +use crate::examples::prove_poseidon2::{Poseidon2Config, prove_poseidon2}; use whir_p3::whir::config::{FoldingFactor, SecurityAssumption}; fn main() { - let benchmark = prove_poseidon2( - 17, - 17, - 4, - FoldingFactor::new(7, 4), - 1, - SecurityAssumption::CapacityBound, - 16, - 128, - 5, - 3, - true, - ); + let config = Poseidon2Config { + log_n_poseidons_16: 17, + log_n_poseidons_24: 17, + univariate_skips: 4, + folding_factor: FoldingFactor::new(7, 4), + log_inv_rate: 1, + soundness_type: SecurityAssumption::CapacityBound, + pow_bits: 16, + security_level: 128, + rs_domain_initial_reduction_factor: 5, + max_num_variables_to_send_coeffs: 3, + display_logs: true, + }; + let benchmark = prove_poseidon2(&config); println!("\n{benchmark}"); }