Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions crates/air/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
univariate_skips: usize,
witness: AirWitness<'a, PF<EF>>,
) -> Vec<Evaluation<EF>> {
prove_air::<PF<EF>, EF, A, AP>(prover_state, univariate_skips, &self, witness)
prove_air::<PF<EF>, EF, A, AP>(prover_state, univariate_skips, self, witness)
}

#[instrument(name = "air: prove in extension", skip_all)]
Expand All @@ -132,7 +132,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
univariate_skips: usize,
witness: AirWitness<'a, EF>,
) -> Vec<Evaluation<EF>> {
prove_air::<EF, EF, A, AP>(prover_state, univariate_skips, &self, witness)
prove_air::<EF, EF, A, AP>(prover_state, univariate_skips, self, witness)
}
}

Expand Down Expand Up @@ -224,9 +224,13 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
let mut column_scalars = vec![];
let mut index = 0;
for group in &witness.column_groups {
for i in index..index + group.len() {
column_scalars.push(poly_eq_batching_scalars[i]);
}
column_scalars.extend(
poly_eq_batching_scalars
.iter()
.skip(index)
.take(group.len())
.copied(),
);
index += witness.max_columns_per_group().next_power_of_two();
}

Expand Down
32 changes: 16 additions & 16 deletions crates/air/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,18 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, EF, EF>,
>(&mut constraints_checker));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, PF<EF>, EF>,
>(&mut constraints_checker));
}
}
handle_errors(row, &mut constraints_checker)?;
Expand All @@ -110,18 +110,18 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, EF, EF>,
>(&mut constraints_checker));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
self.air.eval(transmute::<
&mut ConstraintChecker<'_, IF, EF>,
&mut ConstraintChecker<'_, PF<EF>, EF>,
>(&mut constraints_checker));
}
}
handle_errors(row, &mut constraints_checker)?;
Expand Down
78 changes: 60 additions & 18 deletions crates/air/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,41 @@ fn generate_structured_trace<const N_COLUMNS: usize, const N_PREPROCESSED_COLUMN
trace.push((0..n_rows).map(|_| rng.random()).collect::<Vec<F>>());
}
let mut witness_cols = vec![vec![F::ZERO]; N_COLUMNS - N_PREPROCESSED_COLUMNS];
for i in 1..n_rows {
for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS {
let witness_cols_j_i_min_1 = witness_cols[j][i - 1];
witness_cols[j].push(
witness_cols_j_i_min_1
+ F::from_usize(j + N_PREPROCESSED_COLUMNS)
+ (0..N_PREPROCESSED_COLUMNS)
.map(|k| trace[k][i])
.product::<F>(),
);
let mut prev_values = vec![F::ZERO; N_COLUMNS - N_PREPROCESSED_COLUMNS];
let mut column_iters = trace[..N_PREPROCESSED_COLUMNS]
.iter()
.map(|col| col.iter())
.collect::<Vec<_>>();
if column_iters.is_empty() {
trace.extend(witness_cols);
return trace;
}
for iter in &mut column_iters {
iter.next(); // skip first row, already initialised
}
loop {
let mut row_product = F::ONE;
let mut progressed = true;
for iter in &mut column_iters {
match iter.next() {
Some(value) => row_product *= *value,
None => {
progressed = false;
break;
}
}
}
if !progressed {
break;
}
for (j, (witness_col, prev)) in witness_cols
.iter_mut()
.zip(prev_values.iter_mut())
.enumerate()
{
let next_val = *prev + F::from_usize(j + N_PREPROCESSED_COLUMNS) + row_product;
witness_col.push(next_val);
*prev = next_val;
}
}
trace.extend(witness_cols);
Expand All @@ -131,14 +156,31 @@ fn generate_unstructured_trace<const N_COLUMNS: usize, const N_PREPROCESSED_COLU
trace.push((0..n_rows).map(|_| rng.random()).collect::<Vec<F>>());
}
let mut witness_cols = vec![vec![]; N_COLUMNS - N_PREPROCESSED_COLUMNS];
for i in 0..n_rows {
for j in 0..N_COLUMNS - N_PREPROCESSED_COLUMNS {
witness_cols[j].push(
F::from_usize(j + N_PREPROCESSED_COLUMNS)
+ (0..N_PREPROCESSED_COLUMNS)
.map(|k| trace[k][i])
.product::<F>(),
);
let mut column_iters = trace[..N_PREPROCESSED_COLUMNS]
.iter()
.map(|col| col.iter())
.collect::<Vec<_>>();
if column_iters.is_empty() {
trace.extend(witness_cols);
return trace;
}
loop {
let mut row_product = F::ONE;
let mut progressed = true;
for iter in &mut column_iters {
match iter.next() {
Some(value) => row_product *= *value,
None => {
progressed = false;
break;
}
}
}
if !progressed {
break;
}
for (j, witness_col) in witness_cols.iter_mut().enumerate() {
witness_col.push(F::from_usize(j + N_PREPROCESSED_COLUMNS) + row_product);
}
}
trace.extend(witness_cols);
Expand Down
58 changes: 39 additions & 19 deletions crates/air/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,18 @@ fn verify_air<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>>(
if structured_air {
verify_structured_columns(
verifier_state,
table.n_columns(),
univariate_skips,
&inner_sums,
&column_groups,
&Evaluation::new(
outer_statement.point[1..log_length - univariate_skips + 1].to_vec(),
outer_statement.value,
),
&outer_selector_evals,
log_length,
StructuredColumnsArgs {
n_columns: table.n_columns(),
univariate_skips,
all_inner_sums: &inner_sums,
column_groups,
outer_sumcheck_challenge: &Evaluation::new(
outer_statement.point[1..log_length - univariate_skips + 1].to_vec(),
outer_statement.value,
),
outer_selector_evals: &outer_selector_evals,
log_n_rows: log_length,
},
)
} else {
verify_many_unstructured_columns(
Expand Down Expand Up @@ -181,16 +183,30 @@ fn verify_many_unstructured_columns<EF: ExtensionField<PF<EF>>>(
Ok(evaluations_remaining_to_verify)
}

fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
verifier_state: &mut FSVerifier<EF, impl FSChallenger<EF>>,
#[derive(Debug)]
struct StructuredColumnsArgs<'a, EF> {
n_columns: usize,
univariate_skips: usize,
all_inner_sums: &[EF],
column_groups: &[Range<usize>],
outer_sumcheck_challenge: &Evaluation<EF>,
outer_selector_evals: &[EF],
all_inner_sums: &'a [EF],
column_groups: &'a [Range<usize>],
outer_sumcheck_challenge: &'a Evaluation<EF>,
outer_selector_evals: &'a [EF],
log_n_rows: usize,
}

fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
verifier_state: &mut FSVerifier<EF, impl FSChallenger<EF>>,
args: StructuredColumnsArgs<'_, EF>,
) -> Result<Vec<Evaluation<EF>>, ProofError> {
let StructuredColumnsArgs {
n_columns,
univariate_skips,
all_inner_sums,
column_groups,
outer_sumcheck_challenge,
outer_selector_evals,
log_n_rows,
} = args;
let log_n_groups = log2_ceil_usize(column_groups.len());
let max_columns_per_group = Iterator::max(column_groups.iter().map(|g| g.len())).unwrap();
let log_max_columns_per_group = log2_ceil_usize(max_columns_per_group);
Expand All @@ -201,9 +217,13 @@ fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
let mut column_scalars = vec![];
let mut index = 0;
for group in column_groups {
for i in index..index + group.len() {
column_scalars.push(poly_eq_batching_scalars[i]);
}
column_scalars.extend(
poly_eq_batching_scalars
.iter()
.skip(index)
.take(group.len())
.copied(),
);
index += max_columns_per_group.next_power_of_two();
}

Expand Down
70 changes: 34 additions & 36 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ fn simplify_lines(
unimplemented!("Reverse for non-unrolled loops are not implemented yet");
}

let mut loop_const_malloc = ConstMalloc::default();
loop_const_malloc.counter = const_malloc.counter;
let mut loop_const_malloc = ConstMalloc {
counter: const_malloc.counter,
..ConstMalloc::default()
};
let valid_aux_vars_in_array_manager_before = array_manager.valid.clone();
array_manager.valid.clear();
let simplified_body = simplify_lines(
Expand Down Expand Up @@ -678,16 +680,15 @@ fn simplify_expr(
match expr {
Expression::Value(value) => value.simplify_if_const(),
Expression::ArrayAccess { array, index } => {
if let SimpleExpr::Var(array_var) = array {
if let Some(label) = const_malloc.map.get(array_var) {
if let Ok(mut offset) = ConstExpression::try_from(*index.clone()) {
offset = offset.try_naive_simplification();
return SimpleExpr::ConstMallocAccess {
malloc_label: *label,
offset,
};
}
}
if let SimpleExpr::Var(array_var) = array
&& let Some(label) = const_malloc.map.get(array_var)
&& let Ok(mut offset) = ConstExpression::try_from(*index.clone())
{
offset = offset.try_naive_simplification();
return SimpleExpr::ConstMallocAccess {
malloc_label: *label,
offset,
};
}

let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index]
Expand Down Expand Up @@ -1082,30 +1083,27 @@ fn handle_array_assignment(
) {
let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc);

if let SimpleExpr::Constant(offset) = simplified_index.clone() {
if let SimpleExpr::Var(array_var) = &array {
if let Some(label) = const_malloc.map.get(array_var) {
if let ArrayAccessType::ArrayIsAssigned(Expression::Binary {
left,
operation,
right,
}) = access_type
{
let arg0 = simplify_expr(&left, res, counters, array_manager, const_malloc);
let arg1 = simplify_expr(&right, res, counters, array_manager, const_malloc);
res.push(SimpleLine::Assignment {
var: VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
},
operation,
arg0,
arg1,
});
return;
}
}
}
if let SimpleExpr::Constant(offset) = simplified_index.clone()
&& let SimpleExpr::Var(array_var) = &array
&& let Some(label) = const_malloc.map.get(array_var)
&& let ArrayAccessType::ArrayIsAssigned(Expression::Binary {
left,
operation,
right,
}) = &access_type
{
let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc);
let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc);
res.push(SimpleLine::Assignment {
var: VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
},
operation: *operation,
arg0,
arg1,
});
return;
}

let value_simplified = match access_type {
Expand Down
Loading
Loading