Skip to content

Commit a85348b

Browse files
committed
faster matrix_down_folded
1 parent 40d22b9 commit a85348b

File tree

1 file changed

+53
-15
lines changed

1 file changed

+53
-15
lines changed

crates/air/src/utils.rs

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
use p3_field::Field;
22
use rayon::prelude::*;
33
use tracing::instrument;
4-
use whir_p3::poly::{
5-
evals::{eval_eq, scale_poly},
6-
multilinear::MultilinearPoint,
7-
};
4+
use whir_p3::poly::{evals::eval_eq, multilinear::MultilinearPoint};
85

96
/// Evaluates the LDE of the "UP" matrix polynomial.
107
///
@@ -258,24 +255,65 @@ pub fn matrix_up_folded<F: Field>(outer_challenges: &[F]) -> Vec<F> {
258255
folded
259256
}
260257

258+
/// Computes the folded evaluation vector for the "DOWN" matrix polynomial.
259+
///
260+
/// This function pre-computes the evaluations of the `matrix_down_lde` polynomial for a
261+
/// fixed set of `outer_challenges` over the entire boolean hypercube of the remaining variables.
262+
///
263+
/// ### Behavior
264+
/// The function constructs the evaluation vector by building the underlying polynomial term by term.
265+
/// It iterates through each possible bit position `k`, calculates a corresponding scalar coefficient,
266+
/// and combines it with the evaluations of an equality polynomial over the other variables.
267+
///
268+
/// ### Arguments
269+
/// * `outer_challenges`: An `n`-element slice representing the point at which the first
270+
/// `n` variables of the LDE have been fixed.
271+
///
272+
/// ### Returns
273+
/// A `Vec<F>` of size `2^n` containing the evaluations, which represents one column of the
274+
/// "DOWN" matrix evaluated at the challenge point.
261275
#[instrument(name = "matrix_down_folded", skip_all)]
262276
pub fn matrix_down_folded<F: Field>(outer_challenges: &[F]) -> Vec<F> {
277+
// Get the number of variables, `n`.
263278
let n = outer_challenges.len();
264-
let mut folded = vec![F::ZERO; 1 << n];
279+
// Calculate the size of the evaluation domain (2^n).
280+
let size = 1 << n;
281+
// Initialize the result vector with zeros.
282+
let mut folded = vec![F::ZERO; size];
283+
284+
// Precompute products of suffixes of the challenges for efficient lookups.
285+
// `suffix_prods[i]` will store the product of challenges from index `i` to the end.
286+
// e.g., for challenges [c0, c1, c2], suffix_prods = [c0*c1*c2, c1*c2, c2, 1]
287+
let mut suffix_prods = vec![F::ONE; n + 1];
288+
for i in (0..n).rev() {
289+
suffix_prods[i] = suffix_prods[i + 1] * outer_challenges[i];
290+
}
291+
292+
// This loop constructs the final folded polynomial term by term, iterating through
293+
// each possible carry position `k` (from right to left).
265294
for k in 0..n {
266-
let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1])
267-
* outer_challenges[n - k..].iter().copied().product::<F>();
268-
let mut eq_mle = eval_eq(&outer_challenges[0..n - k - 1]);
269-
eq_mle = scale_poly(&eq_mle, outer_challenges_prod);
270-
for (mut i, v) in eq_mle.iter_mut().enumerate() {
271-
i <<= k + 1;
272-
i += 1 << k;
273-
folded[i] += *v;
295+
// Calculate the scalar coefficient for this term of the polynomial sum,
296+
// using the precomputed suffix product for efficiency.
297+
let scalar = (F::ONE - outer_challenges[n - k - 1]) * suffix_prods[n - k];
298+
299+
// Get the evaluations of the equality polynomial for the high-order bits.
300+
let eq_mle = eval_eq(&outer_challenges[0..n - k - 1]);
301+
302+
// This loop adds the scaled evaluations into the final `folded` vector.
303+
for (i, &v) in eq_mle.iter().enumerate() {
304+
// This bit-shifting logic calculates the correct target index in the `folded` vector,
305+
// which corresponds to constructing the polynomial via tensor products.
306+
let final_idx = (i << (k + 1)) + (1 << k);
307+
// The value from the equality polynomial is scaled and added to the target position.
308+
folded[final_idx] += v * scalar;
274309
}
275310
}
276-
// bottom left corner:
277-
folded[(1 << n) - 1] += outer_challenges.iter().copied().product::<F>();
278311

312+
// Add the correction for the bottom-right corner of the matrix, using the
313+
// precomputed total product of all challenges.
314+
folded[size - 1] += suffix_prods[0];
315+
316+
// Return the completed evaluation vector.
279317
folded
280318
}
281319

0 commit comments

Comments
 (0)