Skip to content
14 changes: 9 additions & 5 deletions crates/air/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
univariate_skips: usize,
witness: AirWitness<'a, PF<EF>>,
) -> Vec<Evaluation<EF>> {
prove_air::<PF<EF>, EF, A, AP>(prover_state, univariate_skips, &self, witness)
prove_air::<PF<EF>, EF, A, AP>(prover_state, univariate_skips, self, witness)
}

#[instrument(name = "air: prove in extension", skip_all)]
Expand All @@ -131,7 +131,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
univariate_skips: usize,
witness: AirWitness<'a, EF>,
) -> Vec<Evaluation<EF>> {
prove_air::<EF, EF, A, AP>(prover_state, univariate_skips, &self, witness)
prove_air::<EF, EF, A, AP>(prover_state, univariate_skips, self, witness)
}
}

Expand Down Expand Up @@ -226,9 +226,13 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
let mut column_scalars = vec![];
let mut index = 0;
for group in &witness.column_groups {
for i in index..index + group.len() {
column_scalars.push(poly_eq_batching_scalars[i]);
}
column_scalars.extend(
poly_eq_batching_scalars
.iter()
.skip(index)
.take(group.len())
.cloned(),
);
index += witness.max_columns_per_group().next_power_of_two();
}

Expand Down
32 changes: 16 additions & 16 deletions crates/air/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,18 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, EF, EF>,
>(&mut constraints_checker));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, PF<EF>, EF>,
>(&mut constraints_checker));
}
}
handle_errors(row, &mut constraints_checker)?;
Expand All @@ -110,18 +110,18 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, EF, EF>,
>(&mut constraints_checker));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, PF<EF>, EF>,
>(&mut constraints_checker));
}
}
handle_errors(row, &mut constraints_checker)?;
Expand Down
10 changes: 5 additions & 5 deletions crates/air/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ fn generate_structured_trace<const N_COLUMNS: usize, const N_PREPROCESSED_COLUMN
}
let mut witness_cols = vec![vec![F::ZERO]; N_COLUMNS - N_PREPROCESSED_COLUMNS];
for i in 1..n_rows {
for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS {
let witness_cols_j_i_min_1 = witness_cols[j][i - 1];
witness_cols[j].push(
for (j, col) in witness_cols.iter_mut().enumerate() {
let witness_cols_j_i_min_1 = col[i - 1];
col.push(
witness_cols_j_i_min_1
+ F::from_usize(j + N_PREPROCESSED_COLUMNS)
+ (0..N_PREPROCESSED_COLUMNS)
Expand All @@ -132,8 +132,8 @@ fn generate_unstructured_trace<const N_COLUMNS: usize, const N_PREPROCESSED_COLU
}
let mut witness_cols = vec![vec![]; N_COLUMNS - N_PREPROCESSED_COLUMNS];
for i in 0..n_rows {
for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS {
witness_cols[j].push(
for (j, col) in witness_cols.iter_mut().enumerate() {
col.push(
F::from_usize(j + N_PREPROCESSED_COLUMNS)
+ (0..N_PREPROCESSED_COLUMNS)
.map(|k| trace[k][i])
Expand Down
7 changes: 3 additions & 4 deletions crates/air/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn verify_air<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>>(
table.n_columns(),
univariate_skips,
&inner_sums,
&column_groups,
column_groups,
&Evaluation {
point: MultilinearPoint(
outer_statement.point[1..log_length - univariate_skips + 1].to_vec(),
Expand Down Expand Up @@ -185,6 +185,7 @@ fn verify_many_unstructured_columns<EF: ExtensionField<PF<EF>>>(
Ok(evaluations_remaining_to_verify)
}

#[allow(clippy::too_many_arguments)]
fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
verifier_state: &mut FSVerifier<EF, impl FSChallenger<EF>>,
n_columns: usize,
Expand All @@ -205,9 +206,7 @@ fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
let mut column_scalars = vec![];
let mut index = 0;
for group in column_groups {
for i in index..index + group.len() {
column_scalars.push(poly_eq_batching_scalars[i]);
}
column_scalars.extend_from_slice(&poly_eq_batching_scalars[index..index + group.len()]);
index += max_columns_per_group.next_power_of_two();
}

Expand Down
70 changes: 34 additions & 36 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ fn simplify_lines(
unimplemented!("Reverse for non-unrolled loops are not implemented yet");
}

let mut loop_const_malloc = ConstMalloc::default();
loop_const_malloc.counter = const_malloc.counter;
let mut loop_const_malloc = ConstMalloc {
counter: const_malloc.counter,
..ConstMalloc::default()
};
let valid_aux_vars_in_array_manager_before = array_manager.valid.clone();
array_manager.valid.clear();
let simplified_body = simplify_lines(
Expand Down Expand Up @@ -678,16 +680,15 @@ fn simplify_expr(
match expr {
Expression::Value(value) => value.simplify_if_const(),
Expression::ArrayAccess { array, index } => {
if let SimpleExpr::Var(array_var) = array {
if let Some(label) = const_malloc.map.get(array_var) {
if let Ok(mut offset) = ConstExpression::try_from(*index.clone()) {
offset = offset.try_naive_simplification();
return SimpleExpr::ConstMallocAccess {
malloc_label: *label,
offset,
};
}
}
if let SimpleExpr::Var(array_var) = array
&& let Some(label) = const_malloc.map.get(array_var)
&& let Ok(mut offset) = ConstExpression::try_from(*index.clone())
{
offset = offset.try_naive_simplification();
return SimpleExpr::ConstMallocAccess {
malloc_label: *label,
offset,
};
}

let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index]
Expand Down Expand Up @@ -1082,30 +1083,27 @@ fn handle_array_assignment(
) {
let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc);

if let SimpleExpr::Constant(offset) = simplified_index.clone() {
if let SimpleExpr::Var(array_var) = &array {
if let Some(label) = const_malloc.map.get(array_var) {
if let ArrayAccessType::ArrayIsAssigned(Expression::Binary {
left,
operation,
right,
}) = access_type
{
let arg0 = simplify_expr(&left, res, counters, array_manager, const_malloc);
let arg1 = simplify_expr(&right, res, counters, array_manager, const_malloc);
res.push(SimpleLine::Assignment {
var: VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
},
operation,
arg0,
arg1,
});
return;
}
}
}
if let SimpleExpr::Constant(offset) = simplified_index.clone()
&& let SimpleExpr::Var(array_var) = &array
&& let Some(label) = const_malloc.map.get(array_var)
&& let ArrayAccessType::ArrayIsAssigned(Expression::Binary {
left,
operation,
right,
}) = access_type
{
let arg0 = simplify_expr(&left, res, counters, array_manager, const_malloc);
let arg1 = simplify_expr(&right, res, counters, array_manager, const_malloc);
res.push(SimpleLine::Assignment {
var: VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
},
operation,
arg0,
arg1,
});
return;
}

let value_simplified = match access_type {
Expand Down
16 changes: 8 additions & 8 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Compiler {
}

impl SimpleExpr {
fn into_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant {
fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant {
match self {
Self::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp {
offset: compiler.get_offset(&var.clone().into()),
Expand Down Expand Up @@ -368,7 +368,7 @@ fn compile_lines(
}

SimpleLine::RawAccess { res, index, shift } => {
validate_vars_declared(&[index.clone()], declared_vars)?;
validate_vars_declared(std::slice::from_ref(index), declared_vars)?;
if let SimpleExpr::Var(var) = res {
declared_vars.insert(var.clone());
}
Expand All @@ -379,7 +379,7 @@ fn compile_lines(
instructions.push(IntermediateInstruction::Deref {
shift_0,
shift_1: shift.clone(),
res: res.into_mem_after_fp_or_constant(compiler),
res: res.to_mem_after_fp_or_constant(compiler),
});
}

Expand Down Expand Up @@ -623,10 +623,10 @@ fn validate_vars_declared<VoC: Borrow<SimpleExpr>>(
declared: &BTreeSet<Var>,
) -> Result<(), String> {
for voc in vocs {
if let SimpleExpr::Var(v) = voc.borrow() {
if !declared.contains(v) {
return Err(format!("Variable {v} not declared"));
}
if let SimpleExpr::Var(v) = voc.borrow()
&& !declared.contains(v)
{
return Err(format!("Variable {v} not declared"));
}
}
Ok(())
Expand Down Expand Up @@ -665,7 +665,7 @@ fn setup_function_call(
instructions.push(IntermediateInstruction::Deref {
shift_0: new_fp_pos.into(),
shift_1: (2 + i).into(),
res: arg.into_mem_after_fp_or_constant(compiler),
res: arg.to_mem_after_fp_or_constant(compiler),
});
}

Expand Down
28 changes: 14 additions & 14 deletions crates/lean_compiler/src/c_compile_final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,23 +152,23 @@ fn compile_block(
mut arg_c,
res,
} => {
if let Some(arg_a_cst) = try_as_constant(&arg_a, compiler) {
if let Some(arg_b_cst) = try_as_constant(&arg_c, compiler) {
// res = constant +/x constant
if let Some(arg_a_cst) = try_as_constant(&arg_a, compiler)
&& let Some(arg_b_cst) = try_as_constant(&arg_c, compiler)
{
// res = constant +/x constant

let op_res = operation.compute(arg_a_cst, arg_b_cst);
let op_res = operation.compute(arg_a_cst, arg_b_cst);

let res: MemOrFp = res.try_into_mem_or_fp(compiler).unwrap();
let res: MemOrFp = res.try_into_mem_or_fp(compiler).unwrap();

low_level_bytecode.push(Instruction::Computation {
operation: Operation::Add,
arg_a: MemOrConstant::zero(),
arg_c: res,
res: MemOrConstant::Constant(op_res),
});
pc += 1;
continue;
}
low_level_bytecode.push(Instruction::Computation {
operation: Operation::Add,
arg_a: MemOrConstant::zero(),
arg_c: res,
res: MemOrConstant::Constant(op_res),
});
pc += 1;
continue;
}

if arg_c.is_constant() {
Expand Down
1 change: 1 addition & 0 deletions crates/lean_prover/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use whir_p3::poly::{evals::fold_multilinear, multilinear::MultilinearPoint};
use crate::*;
use lean_vm::*;

#[allow(clippy::too_many_arguments)]
pub fn get_base_dims(
n_cycles: usize,
log_public_memory: usize,
Expand Down
13 changes: 6 additions & 7 deletions crates/lean_prover/src/prove_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,12 +833,11 @@ pub fn prove_execution(
let index_a: F = dot_product_columns[2][i].as_base().unwrap();
let index_b: F = dot_product_columns[3][i].as_base().unwrap();
let index_res: F = dot_product_columns[4][i].as_base().unwrap();
for j in 0..DIMENSION {
dot_product_indexes_spread[j][i] = index_a + F::from_usize(j);
dot_product_indexes_spread[j][i + dot_product_table_length] =
index_b + F::from_usize(j);
dot_product_indexes_spread[j][i + 2 * dot_product_table_length] =
index_res + F::from_usize(j);
for (j, slice) in dot_product_indexes_spread.iter_mut().enumerate() {
let offset = F::from_usize(j);
slice[i] = index_a + offset;
slice[i + dot_product_table_length] = index_b + offset;
slice[i + 2 * dot_product_table_length] = index_res + offset;
}
}
let dot_product_values_spread = dot_product_indexes_spread
Expand Down Expand Up @@ -1020,7 +1019,7 @@ pub fn prove_execution(
let packed_pcs_witness_extension = packed_pcs_commit(
&pcs.pcs_b(
log2_strict_usize(packed_pcs_witness_base.packed_polynomial.len()),
num_packed_vars_for_dims::<EF, EF>(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK),
num_packed_vars_for_dims::<EF>(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK),
),
&extension_pols,
&extension_dims,
Expand Down
10 changes: 7 additions & 3 deletions crates/lean_prover/src/verify_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ pub fn verify_execution(
let parsed_commitment_extension = packed_pcs_parse_commitment(
&pcs.pcs_b(
parsed_commitment_base.num_variables(),
num_packed_vars_for_dims::<EF, EF>(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK),
num_packed_vars_for_dims::<EF>(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK),
),
&mut verifier_state,
&extension_dims,
Expand Down Expand Up @@ -893,8 +893,12 @@ pub fn verify_execution(
);

let mut dot_product_indexes_inner_evals_incr = vec![EF::ZERO; 8];
for i in 0..DIMENSION {
dot_product_indexes_inner_evals_incr[i] = dot_product_logup_star_indexes_inner_value
for (i, value) in dot_product_indexes_inner_evals_incr
.iter_mut()
.enumerate()
.take(DIMENSION)
{
*value = dot_product_logup_star_indexes_inner_value
+ EF::from_usize(i)
* [F::ONE, F::ONE, F::ONE, F::ZERO].evaluate(&MultilinearPoint(
mem_lookup_eval_indexes_partial_point.0[3 + index_diff..5 + index_diff]
Expand Down
Loading
Loading