Skip to content

Commit bc98834

Browse files
committed
faster open_structured_columns in AIR
1 parent a0aa176 commit bc98834

File tree

7 files changed

+81
-57
lines changed

7 files changed

+81
-57
lines changed

Cargo.lock

Lines changed: 8 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/air/src/prove.rs

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,25 @@ fn open_structured_columns<EF: ExtensionField<PF<EF>> + ExtensionField<IF>, IF:
195195

196196
let batched_column =
197197
multilinears_linear_combination(witness, &poly_eq_batching_scalars[..n_columns]);
198-
let mut batched_column_mixed = column_up(&batched_column);
199-
add_multilinears_inplace(
200-
&mut batched_column_mixed,
201-
&scale_poly(&column_down(&batched_column), alpha),
202-
);
198+
199+
let batched_column_mixed = info_span!("mixing up / down").in_scope(|| {
200+
let mut batched_column_mixed = column_down(&batched_column);
201+
add_multilinears_inplace(
202+
&mut batched_column_mixed,
203+
&scale_poly(&column_up(&batched_column), alpha),
204+
);
205+
batched_column_mixed
206+
});
207+
203208
// TODO do not recompute this (we can deduce it from already computed values)
204-
let sub_evals = fold_multilinear_chunks(
205-
&batched_column_mixed,
206-
&MultilinearPoint(outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec()),
207-
);
209+
let sub_evals = info_span!("fold_multilinear_chunks").in_scope(|| {
210+
fold_multilinear_chunks(
211+
&batched_column_mixed,
212+
&MultilinearPoint(
213+
outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(),
214+
),
215+
)
216+
});
208217
prover_state.add_extension_scalars(&sub_evals);
209218

210219
let epsilons = prover_state.sample_vec(univariate_skips);
@@ -216,29 +225,39 @@ fn open_structured_columns<EF: ExtensionField<PF<EF>> + ExtensionField<IF>, IF:
216225
.concat();
217226

218227
// TODO do not recompute this (we can deduce it from already computed values)
219-
let inner_sum = batched_column_mixed.evaluate(&MultilinearPoint(point.clone()));
220-
221-
let mut mat_up = matrix_up_folded(&point);
222-
add_multilinears_inplace(&mut mat_up, &scale_poly(&matrix_down_folded(&point), alpha));
223-
let inner_mle = MleGroupOwned::Extension(vec![mat_up, batched_column]);
224-
225-
let (inner_challenges, _, _) = sumcheck_prove::<EF, _, _, _>(
226-
1,
227-
inner_mle,
228-
&ProductComputation,
229-
&ProductComputation,
230-
&[],
231-
None,
232-
false,
233-
prover_state,
234-
inner_sum,
235-
None,
236-
);
228+
let inner_sum = info_span!("mixed column eval")
229+
.in_scope(|| batched_column_mixed.evaluate(&MultilinearPoint(point.clone())));
230+
231+
let mut mat_up = matrix_up_folded(&point, alpha);
232+
matrix_down_folded(&point, &mut mat_up);
233+
let inner_mle = info_span!("packing").in_scope(|| {
234+
MleGroupOwned::ExtensionPacked(vec![
235+
pack_extension(&mat_up),
236+
pack_extension(&batched_column),
237+
])
238+
});
237239

238-
let evaluations_remaining_to_prove = witness
239-
.iter()
240-
.map(|col| col.evaluate(&inner_challenges))
241-
.collect::<Vec<_>>();
240+
let (inner_challenges, _, _) = info_span!("structured columns sumcheck").in_scope(|| {
241+
sumcheck_prove::<EF, _, _, _>(
242+
1,
243+
inner_mle,
244+
&ProductComputation,
245+
&ProductComputation,
246+
&[],
247+
None,
248+
false,
249+
prover_state,
250+
inner_sum,
251+
None,
252+
)
253+
});
254+
255+
let evaluations_remaining_to_prove = info_span!("final evals").in_scope(|| {
256+
witness
257+
.iter()
258+
.map(|col| col.evaluate(&inner_challenges))
259+
.collect::<Vec<_>>()
260+
});
242261
prover_state.add_extension_scalars(&evaluations_remaining_to_prove);
243262

244263
(inner_challenges, evaluations_remaining_to_prove)

crates/air/src/uni_skip_utils.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
use multilinear_toolkit::prelude::*;
2-
use p3_field::Field;
2+
use tracing::instrument;
33

4-
pub fn matrix_up_folded<F: Field>(outer_challenges: &[F]) -> Vec<F> {
4+
#[instrument(skip_all)]
5+
pub fn matrix_up_folded<F: ExtensionField<PF<F>>>(outer_challenges: &[F], alpha: F) -> Vec<F> {
56
let n = outer_challenges.len();
6-
let mut folded = eval_eq(outer_challenges);
7+
let mut folded = eval_eq_scaled(outer_challenges, alpha);
78
let outer_challenges_prod: F = outer_challenges.iter().copied().product();
8-
folded[(1 << n) - 1] -= outer_challenges_prod;
9-
folded[(1 << n) - 2] += outer_challenges_prod;
9+
folded[(1 << n) - 1] -= outer_challenges_prod * alpha;
10+
folded[(1 << n) - 2] += outer_challenges_prod * alpha;
1011
folded
1112
}
1213

13-
pub fn matrix_down_folded<F: Field>(outer_challenges: &[F]) -> Vec<F> {
14+
#[instrument(skip_all)]
15+
pub fn matrix_down_folded<F: ExtensionField<PF<F>>>(outer_challenges: &[F], dest: &mut [F]) {
1416
let n = outer_challenges.len();
15-
let mut folded = vec![F::ZERO; 1 << n];
1617
for k in 0..n {
1718
let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1])
1819
* outer_challenges[n - k..].iter().copied().product::<F>();
19-
let mut eq_mle = eval_eq(&outer_challenges[0..n - k - 1]);
20-
eq_mle = scale_poly(&eq_mle, outer_challenges_prod);
20+
let mut eq_mle = eval_eq_scaled(&outer_challenges[0..n - k - 1], outer_challenges_prod);
2121
for (mut i, v) in eq_mle.iter_mut().enumerate() {
2222
i <<= k + 1;
2323
i += 1 << k;
24-
folded[i] += *v;
24+
dest[i] += *v;
2525
}
2626
}
2727
// bottom left corner:
28-
folded[(1 << n) - 1] += outer_challenges.iter().copied().product::<F>();
29-
30-
folded
28+
dest[(1 << n) - 1] += outer_challenges.iter().copied().product::<F>();
3129
}

crates/air/src/utils.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,14 @@ pub(crate) fn columns_up_and_down<F: Field>(columns: &[&[F]]) -> Vec<Vec<F>> {
145145
}
146146

147147
pub(crate) fn column_up<F: Field>(column: &[F]) -> Vec<F> {
148-
let mut up = column.to_vec();
148+
let mut up = parallel_clone_vec(column);
149149
up[column.len() - 1] = up[column.len() - 2];
150150
up
151151
}
152152

153153
pub(crate) fn column_down<F: Field>(column: &[F]) -> Vec<F> {
154-
let mut down = column[1..].to_vec();
155-
down.push(*down.last().unwrap());
154+
let mut down = unsafe { uninitialized_vec(column.len()) };
155+
parallel_clone(&column[1..], &mut down[..column.len() - 1]);
156+
down[column.len() - 1] = down[column.len() - 2];
156157
down
157158
}

crates/air/src/verify.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,11 @@ fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
177177
) != dot_product::<EF, _, _>(
178178
all_witness_up.iter().copied(),
179179
poly_eq_batching_scalars.iter().copied(),
180-
) + dot_product::<EF, _, _>(
181-
all_witness_down.iter().copied(),
182-
poly_eq_batching_scalars.iter().copied(),
183180
) * alpha
181+
+ dot_product::<EF, _, _>(
182+
all_witness_down.iter().copied(),
183+
poly_eq_batching_scalars.iter().copied(),
184+
)
184185
{
185186
return Err(ProofError::InvalidProof);
186187
}
@@ -202,7 +203,7 @@ fn verify_structured_columns<EF: ExtensionField<PF<EF>>>(
202203
let up = matrix_up_lde(&matrix_lde_point);
203204
let down = matrix_down_lde(&matrix_lde_point);
204205

205-
let final_value = inner_sumcheck_stement.value / (up + alpha * down);
206+
let final_value = inner_sumcheck_stement.value / (up * alpha + down);
206207

207208
let evaluations_remaining_to_verify = verifier_state.next_extension_scalars_vec(n_columns)?;
208209

crates/utils/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ tracing-forest.workspace = true
1717
p3-poseidon2.workspace = true
1818
p3-poseidon2-air.workspace = true
1919
tracing-subscriber.workspace = true
20+
tracing.workspace = true
2021
p3-util.workspace = true
2122
multilinear-toolkit.workspace = true
2223

crates/utils/src/multilinear.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use p3_field::{ExtensionField, Field, dot_product};
55
use p3_util::log2_strict_usize;
66

77
use multilinear_toolkit::prelude::*;
8+
use tracing::instrument;
89

10+
#[instrument(skip_all)]
911
pub fn multilinears_linear_combination<
1012
F: Field,
1113
EF: ExtensionField<F>,

0 commit comments

Comments
 (0)