Skip to content

Commit cb1a980

Browse files
committed
simplify AIR: remove column groups
1 parent ad5baa8 commit cb1a980

File tree

13 files changed

+227
-447
lines changed

13 files changed

+227
-447
lines changed

crates/air/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ pub mod table;
1313
mod uni_skip_utils;
1414
mod utils;
1515
mod verify;
16-
pub mod witness;
1716

1817
#[cfg(test)]
1918
mod test;

crates/air/src/prove.rs

Lines changed: 69 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
use std::any::TypeId;
22

33
use p3_air::BaseAir;
4-
use p3_field::{ExtensionField, Field, cyclic_subgroup_known_order, dot_product};
5-
use p3_util::log2_ceil_usize;
4+
use p3_field::{ExtensionField, Field, cyclic_subgroup_known_order};
5+
use p3_util::{log2_ceil_usize, log2_strict_usize};
66
use sumcheck::{MleGroup, MleGroupOwned, MleGroupRef, ProductComputation};
77
use tracing::{info_span, instrument};
88
use utils::PF;
9-
use utils::{FSProver, add_multilinears, from_end, multilinears_linear_combination};
9+
use utils::{FSProver, add_multilinears, multilinears_linear_combination};
1010
use whir_p3::fiat_shamir::FSChallenger;
1111
use whir_p3::poly::evals::{eval_eq, fold_multilinear, scale_poly};
1212
use whir_p3::poly::multilinear::Evaluation;
1313
use whir_p3::poly::{evals::EvaluationsList, multilinear::MultilinearPoint};
1414

15-
use crate::witness::AirWitness;
1615
use crate::{NormalAir, PackedAir};
1716
use crate::{
1817
uni_skip_utils::{matrix_down_folded, matrix_up_folded},
@@ -38,43 +37,42 @@ fn prove_air<
3837
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
3938
univariate_skips: usize,
4039
table: &AirTable<EF, A, AP>,
41-
witness: AirWitness<'a, WF>,
40+
witness: &[&'a [WF]],
4241
) -> Vec<Evaluation<EF>> {
42+
let n_rows = witness[0].len();
43+
assert!(witness.iter().all(|col| col.len() == n_rows));
44+
let log_n_rows = log2_strict_usize(n_rows);
4345
assert!(
44-
univariate_skips < witness.log_n_rows(),
46+
univariate_skips < log_n_rows,
4547
"TODO handle the case UNIVARIATE_SKIPS >= log_length"
4648
);
4749

4850
let structured_air = <A as BaseAir<PF<EF>>>::structured(&table.air);
4951

50-
let log_length = witness.log_n_rows();
51-
5252
let constraints_batching_scalar = prover_state.sample();
5353

5454
let constraints_batching_scalars =
5555
cyclic_subgroup_known_order(constraints_batching_scalar, table.n_constraints)
5656
.collect::<Vec<_>>();
5757

58-
let n_sc_rounds = log_length + 1 - univariate_skips;
58+
let n_sc_rounds = log_n_rows + 1 - univariate_skips;
5959

6060
let zerocheck_challenges = prover_state.sample_vec(n_sc_rounds);
6161

6262
let columns_for_zero_check: MleGroup<'_, EF> = if TypeId::of::<WF>() == TypeId::of::<PF<EF>>() {
63-
let columns =
64-
unsafe { std::mem::transmute::<&Vec<&'a [WF]>, &Vec<&'a [PF<EF>]>>(&witness.cols) };
63+
let columns = unsafe { std::mem::transmute::<&[&[WF]], &[&[PF<EF>]]>(witness) };
6564
if structured_air {
6665
MleGroupOwned::Base(columns_up_and_down(columns)).into()
6766
} else {
68-
MleGroupRef::Base(columns.clone()).into()
67+
MleGroupRef::Base(columns.to_vec()).into()
6968
}
7069
} else {
7170
assert!(TypeId::of::<WF>() == TypeId::of::<EF>());
72-
let columns =
73-
unsafe { std::mem::transmute::<&Vec<&'a [WF]>, &Vec<&'a [EF]>>(&witness.cols) };
71+
let columns = unsafe { std::mem::transmute::<&[&'a [WF]], &[&'a [EF]]>(witness) };
7472
if structured_air {
7573
MleGroupOwned::Extension(columns_up_and_down(columns)).into()
7674
} else {
77-
MleGroupRef::Extension(columns.clone()).into()
75+
MleGroupRef::Extension(columns.to_vec()).into()
7876
}
7977
};
8078

@@ -101,158 +99,126 @@ fn prove_air<
10199
open_structured_columns(
102100
prover_state,
103101
univariate_skips,
104-
&witness,
102+
witness,
105103
&outer_sumcheck_challenge,
106104
)
107105
} else {
108106
open_unstructured_columns(
109107
prover_state,
110108
univariate_skips,
111-
&witness,
109+
witness,
112110
&outer_sumcheck_challenge,
113111
)
114112
}
115113
}
116114

117115
impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<EF, A, AP> {
118116
#[instrument(name = "air: prove in base", skip_all)]
119-
pub fn prove_base<'a>(
117+
pub fn prove_base(
120118
&self,
121119
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
122120
univariate_skips: usize,
123-
witness: AirWitness<'a, PF<EF>>,
121+
witness: &[&[PF<EF>]],
124122
) -> Vec<Evaluation<EF>> {
125123
prove_air::<PF<EF>, EF, A, AP>(prover_state, univariate_skips, self, witness)
126124
}
127125

128126
#[instrument(name = "air: prove in extension", skip_all)]
129-
pub fn prove_extension<'a>(
127+
pub fn prove_extension(
130128
&self,
131129
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
132130
univariate_skips: usize,
133-
witness: AirWitness<'a, EF>,
131+
witness: &[&[EF]],
134132
) -> Vec<Evaluation<EF>> {
135133
prove_air::<EF, EF, A, AP>(prover_state, univariate_skips, self, witness)
136134
}
137135
}
138136

139-
fn eval_unstructured_column_groups<EF: ExtensionField<PF<EF>> + ExtensionField<IF>, IF: Field>(
140-
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
141-
univariate_skips: usize,
142-
witnesses: &AirWitness<'_, IF>,
143-
outer_sumcheck_challenge: &[EF],
144-
columns_batching_scalars: &[EF],
145-
) -> Vec<Vec<EF>> {
146-
let mut all_sub_evals = vec![];
147-
for group in &witnesses.column_groups {
148-
let batched_column = multilinears_linear_combination(
149-
&witnesses.cols[group.clone()],
150-
&eval_eq(from_end(
151-
columns_batching_scalars,
152-
log2_ceil_usize(group.len()),
153-
))[..group.len()],
154-
);
155-
156-
// TODO opti
157-
let sub_evals = fold_multilinear(
158-
&batched_column,
159-
&MultilinearPoint(
160-
outer_sumcheck_challenge[1..witnesses.log_n_rows() - univariate_skips + 1].to_vec(),
161-
),
162-
);
163-
164-
prover_state.add_extension_scalars(&sub_evals);
165-
all_sub_evals.push(sub_evals);
166-
}
167-
all_sub_evals
168-
}
169-
170137
#[instrument(skip_all)]
171138
fn open_unstructured_columns<
172-
'a,
173139
WF: ExtensionField<PF<EF>>,
174140
EF: ExtensionField<PF<EF>> + ExtensionField<WF>,
175141
>(
176142
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
177143
univariate_skips: usize,
178-
witness: &AirWitness<'a, WF>,
144+
witness: &[&[WF]],
179145
outer_sumcheck_challenge: &[EF],
180146
) -> Vec<Evaluation<EF>> {
181-
let columns_batching_scalars =
182-
prover_state.sample_vec(log2_ceil_usize(witness.max_columns_per_group()));
147+
let log_n_rows = log2_strict_usize(witness[0].len());
183148

184-
let sub_evals = eval_unstructured_column_groups(
185-
prover_state,
186-
univariate_skips,
149+
let columns_batching_scalars = prover_state.sample_vec(log2_ceil_usize(witness.len()));
150+
151+
let batched_column = multilinears_linear_combination(
187152
witness,
188-
outer_sumcheck_challenge,
189-
&columns_batching_scalars,
153+
&eval_eq(&columns_batching_scalars)[..witness.len()],
154+
);
155+
156+
// TODO opti
157+
let sub_evals = fold_multilinear(
158+
&batched_column,
159+
&MultilinearPoint(outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec()),
190160
);
191161

162+
prover_state.add_extension_scalars(&sub_evals);
163+
192164
let epsilons = MultilinearPoint(prover_state.sample_vec(univariate_skips));
165+
let common_point = MultilinearPoint(
166+
[
167+
epsilons.0.clone(),
168+
outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(),
169+
]
170+
.concat(),
171+
);
193172

194173
let mut evaluations_remaining_to_prove = vec![];
195-
for (group, sub_evals) in witness.column_groups.iter().zip(sub_evals) {
196-
assert_eq!(sub_evals.len(), 1 << epsilons.len());
197-
198-
evaluations_remaining_to_prove.push(Evaluation::new(
199-
[
200-
from_end(&columns_batching_scalars, log2_ceil_usize(group.len())).to_vec(),
201-
epsilons.0.clone(),
202-
outer_sumcheck_challenge[1..witness.log_n_rows() - univariate_skips + 1].to_vec(),
203-
]
204-
.concat(),
205-
sub_evals.evaluate(&epsilons),
206-
));
174+
assert_eq!(sub_evals.len(), 1 << epsilons.len());
175+
176+
for col in witness {
177+
// TODO compute oe time eq(.) then inner product with everything
178+
let value = col.evaluate(&common_point);
179+
prover_state.add_extension_scalars(&[value]);
180+
evaluations_remaining_to_prove.push(Evaluation {
181+
point: common_point.clone(),
182+
value,
183+
});
207184
}
185+
208186
evaluations_remaining_to_prove
209187
}
210188

211189
#[instrument(skip_all)]
212-
fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>, IF: Field>(
190+
fn open_structured_columns<EF: ExtensionField<PF<EF>> + ExtensionField<IF>, IF: Field>(
213191
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
214192
univariate_skips: usize,
215-
witness: &AirWitness<'a, IF>,
193+
witness: &[&[IF]],
216194
outer_sumcheck_challenge: &[EF],
217195
) -> Vec<Evaluation<EF>> {
218-
let log_n_groups = log2_ceil_usize(witness.column_groups.len());
219-
let batching_scalars =
220-
prover_state.sample_vec(log_n_groups + witness.log_max_columns_per_group());
196+
let n_columns = witness.len();
197+
let n_rows = witness[0].len();
198+
let log_n_rows = log2_strict_usize(n_rows);
199+
let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_columns));
221200
let alpha = prover_state.sample();
222201

223202
let poly_eq_batching_scalars = eval_eq(&batching_scalars);
224-
let mut column_scalars = vec![];
225-
let mut index = 0;
226-
for group in &witness.column_groups {
227-
column_scalars.extend(
228-
poly_eq_batching_scalars
229-
.iter()
230-
.skip(index)
231-
.take(group.len())
232-
.copied(),
233-
);
234-
index += witness.max_columns_per_group().next_power_of_two();
235-
}
236203

237-
let batched_column = multilinears_linear_combination(&witness.cols, &column_scalars);
204+
let batched_column =
205+
multilinears_linear_combination(witness, &poly_eq_batching_scalars[..n_columns]);
238206
let batched_column_mixed = add_multilinears(
239207
&column_up(&batched_column),
240208
&scale_poly(&column_down(&batched_column), alpha),
241209
);
242210
// TODO do not recompute this (we can deduce it from already computed values)
243211
let sub_evals = fold_multilinear(
244212
&batched_column_mixed,
245-
&MultilinearPoint(
246-
outer_sumcheck_challenge[1..witness.log_n_rows() - univariate_skips + 1].to_vec(),
247-
),
213+
&MultilinearPoint(outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec()),
248214
);
249215
prover_state.add_extension_scalars(&sub_evals);
250216

251217
let epsilons = prover_state.sample_vec(univariate_skips);
252218

253219
let point = [
254220
epsilons,
255-
outer_sumcheck_challenge[1..witness.log_n_rows() - univariate_skips + 1].to_vec(),
221+
outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(),
256222
]
257223
.concat();
258224

@@ -267,8 +233,7 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
267233
batched_column,
268234
]);
269235

270-
let n_groups = witness.column_groups.len();
271-
let (inner_challenges, inner_evals, _) = sumcheck::prove::<EF, _, _, _>(
236+
let (inner_challenges, _, _) = sumcheck::prove::<EF, _, _, _>(
272237
1,
273238
inner_mle,
274239
&ProductComputation,
@@ -284,43 +249,14 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
284249
// TODO using inner_evals[1], we can avoid 1 of the evaluations below (the last one)
285250

286251
let mut evaluations_remaining_to_prove = vec![];
287-
for i in 0..n_groups {
288-
let group = &witness.column_groups[i];
289-
let point = MultilinearPoint(
290-
[
291-
from_end(
292-
&batching_scalars[log_n_groups..],
293-
log2_ceil_usize(group.len()),
294-
)
295-
.to_vec(),
296-
inner_challenges.0.clone(),
297-
]
298-
.concat(),
299-
);
300-
let value = {
301-
let mut padded_group = IF::zero_vec(group.len().next_power_of_two() * witness.n_rows());
302-
for (i, col) in witness.cols[group.clone()].iter().enumerate() {
303-
padded_group[i * witness.n_rows()..(i + 1) * witness.n_rows()].copy_from_slice(col);
304-
}
305-
padded_group.evaluate(&point)
306-
};
307-
prover_state.add_extension_scalars(&[value]);
308-
evaluations_remaining_to_prove.push(Evaluation { point, value });
252+
for col in witness {
253+
let value = col.evaluate(&inner_challenges);
254+
prover_state.add_extension_scalar(value);
255+
evaluations_remaining_to_prove.push(Evaluation {
256+
point: inner_challenges.clone(),
257+
value,
258+
});
309259
}
310260

311-
assert_eq!(
312-
inner_evals[1],
313-
dot_product(
314-
eval_eq(&batching_scalars[..log_n_groups]).into_iter(),
315-
(0..n_groups).map(|i| evaluations_remaining_to_prove[i].value
316-
* batching_scalars[log_n_groups
317-
..log_n_groups + witness.log_max_columns_per_group()
318-
- log2_ceil_usize(witness.column_groups[i].len())]
319-
.iter()
320-
.map(|&x| EF::ONE - x)
321-
.product::<EF>())
322-
)
323-
);
324-
325261
evaluations_remaining_to_prove
326262
}

crates/air/src/table.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use p3_uni_stark::get_symbolic_constraints;
88
use tracing::instrument;
99
use utils::{ConstraintChecker, PF};
1010

11-
use crate::{NormalAir, PackedAir, witness::AirWitness};
11+
use crate::{NormalAir, PackedAir};
1212

1313
#[derive(Debug)]
1414
pub struct AirTable<EF: Field, A, AP> {
@@ -41,12 +41,14 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
4141
#[instrument(name = "Check trace validity", skip_all)]
4242
pub fn check_trace_validity<IF: ExtensionField<PF<EF>>>(
4343
&self,
44-
witness: &AirWitness<'_, IF>,
44+
witness: &[&[IF]],
4545
) -> Result<(), String>
4646
where
4747
EF: ExtensionField<IF>,
4848
{
49-
if witness.n_columns() != self.n_columns() {
49+
let n_rows = witness[0].len();
50+
assert!(witness.iter().all(|col| col.len() == n_rows));
51+
if witness.len() != self.n_columns() {
5052
return Err("Invalid number of columns".to_string());
5153
}
5254
let handle_errors = |row: usize, constraint_checker: &mut ConstraintChecker<'_, IF, EF>| {
@@ -65,7 +67,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
6567
Ok(())
6668
};
6769
if <A as BaseAir<PF<EF>>>::structured(&self.air) {
68-
for row in 0..witness.n_rows() - 1 {
70+
for row in 0..n_rows - 1 {
6971
let up = (0..self.n_columns())
7072
.map(|j| witness[j][row])
7173
.collect::<Vec<_>>();
@@ -98,7 +100,7 @@ impl<EF: ExtensionField<PF<EF>>, A: NormalAir<EF>, AP: PackedAir<EF>> AirTable<E
98100
handle_errors(row, &mut constraints_checker)?;
99101
}
100102
} else {
101-
for row in 0..witness.n_rows() {
103+
for row in 0..n_rows {
102104
let up = (0..self.n_columns())
103105
.map(|j| witness[j][row])
104106
.collect::<Vec<_>>();

0 commit comments

Comments
 (0)