diff --git a/crates/air/src/lib.rs b/crates/air/src/lib.rs index af14a77b..c052e322 100644 --- a/crates/air/src/lib.rs +++ b/crates/air/src/lib.rs @@ -11,7 +11,6 @@ use p3_uni_stark::SymbolicAirBuilder; mod prove; pub use prove::*; pub mod table; -mod uni_skip_utils; mod utils; mod verify; pub use verify::*; diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index c55b5412..74c05975 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -16,8 +16,7 @@ use whir_p3::{ use super::table::AirTable; use crate::{ MyAir, - uni_skip_utils::{matrix_down_folded, matrix_up_folded}, - utils::{column_down, column_up, columns_up_and_down}, + utils::{column_down, column_up, columns_up_and_down, matrix_down_folded, matrix_up_folded}, witness::AirWitness, }; diff --git a/crates/air/src/uni_skip_utils.rs b/crates/air/src/uni_skip_utils.rs deleted file mode 100644 index dfb9438d..00000000 --- a/crates/air/src/uni_skip_utils.rs +++ /dev/null @@ -1,34 +0,0 @@ -use p3_field::Field; -use tracing::instrument; -use whir_p3::poly::evals::{eval_eq, scale_poly}; - -#[instrument(name = "matrix_up_folded", skip_all)] -pub fn matrix_up_folded(outer_challenges: &[F]) -> Vec { - let n = outer_challenges.len(); - let mut folded = eval_eq(outer_challenges); - let outer_challenges_prod: F = outer_challenges.iter().copied().product(); - folded[(1 << n) - 1] -= outer_challenges_prod; - folded[(1 << n) - 2] += outer_challenges_prod; - folded -} - -#[instrument(name = "matrix_down_folded", skip_all)] -pub fn matrix_down_folded(outer_challenges: &[F]) -> Vec { - let n = outer_challenges.len(); - let mut folded = vec![F::ZERO; 1 << n]; - for k in 0..n { - let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1]) - * outer_challenges[n - k..].iter().copied().product::(); - let mut eq_mle = eval_eq(&outer_challenges[0..n - k - 1]); - eq_mle = scale_poly(&eq_mle, outer_challenges_prod); - for (mut i, v) in eq_mle.iter_mut().enumerate() { - i <<= k + 1; - i += 1 << k; - folded[i] += *v; - } - } - // bottom left corner: - folded[(1 << n) - 1] += outer_challenges.iter().copied().product::(); - - folded -} diff --git a/crates/air/src/utils.rs b/crates/air/src/utils.rs index 5ed9df37..128b2edf 100644 --- a/crates/air/src/utils.rs +++ b/crates/air/src/utils.rs @@ -1,61 +1,87 @@ use p3_field::Field; use rayon::prelude::*; -use whir_p3::poly::multilinear::MultilinearPoint; +use tracing::instrument; +use whir_p3::poly::{evals::eval_eq, multilinear::MultilinearPoint}; +/// Evaluates the LDE of the "UP" matrix polynomial. +/// +/// This function represents a matrix used to select the "current" row values (e.g., `c[r]`) +/// in AIR constraints. +/// +/// ### Behavior +/// The polynomial behaves like an identity matrix for all rows except the last one. To handle +/// the boundary condition of an execution trace, the last row is modified to be a copy of the +/// second-to-last row of the identity matrix. +/// +/// ### Example +/// For a trace with `N=4` rows, the matrix represented by this polynomial is: +/// ```text +/// [[1, 0, 0, 0], +/// [0, 1, 0, 0], +/// [0, 0, 1, 0], +/// [0, 0, 1, 0]] <-- The last row is a copy of the row above it from the identity matrix. +/// ``` +/// +/// ### Arguments +/// * `point`: A slice of `2n` field elements `[s1, s2]`, where: +/// - `s1` is an `n`-element vector representing the **row index**. +/// - `s2` is an `n`-element vector for the **column index**. +/// +/// ### Returns +/// A single field element `F` representing the polynomial's evaluation at the given `point`. pub fn matrix_up_lde(point: &[F]) -> F { - /* - Matrix UP: - - (1 0 0 0 ... 0 0 0) - (0 1 0 0 ... 0 0 0) - (0 0 1 0 ... 0 0 0) - (0 0 0 1 ... 0 0 0) - ... ... ... - (0 0 0 0 ... 1 0 0) - (0 0 0 0 ... 0 1 0) - (0 0 0 0 ... 0 1 0) - - Square matrix of size self.n_columns x sef.n_columns - As a multilinear polynomial in 2 * log_length variables: - - self.n_columns first variables -> encoding the row index - - self.n_columns last variables -> encoding the column index - */ - - assert_eq!(point.len() % 2, 0); + // Ensure the input point has an even number of variables, so it can be split in half. + assert!(point.len().is_multiple_of(2)); + // Determine `n`, the number of variables for a single index (row or column). let n = point.len() / 2; + // Split the 2n-element point into two n-element halves: `s1` (column) and `s2` (row). let (s1, s2) = point.split_at(n); + + // The polynomial is composed of two main parts: + // 1. The equality polynomial `eq(s1, s2)`, which evaluates to 1 if the column index + // equals the row index, and 0 otherwise. This term constructs the main diagonal + // of the identity matrix. + // 2. A correction term that modifies the matrix based on the input point. This term + // is specifically designed to be non-zero only under the conditions that adjust + // the last row of the matrix, ensuring the correct boundary behavior. MultilinearPoint(s1.to_vec()).eq_poly_outside(&MultilinearPoint(s2.to_vec())) + point[..point.len() - 1].iter().copied().product::() - * (F::ONE - point[point.len() - 1] * F::TWO) + * (F::ONE - point.last().unwrap().double()) } +/// Evaluates the LDE of the "DOWN" matrix polynomial. +/// +/// This function represents a matrix used to select the "next" row values (e.g., `c[r+1]`). +/// It maps row `r` to row `r+1`. For the boundary, it maps the last row to itself. +/// +/// ### Behavior +/// The polynomial represents a matrix with `1`s on the superdiagonal (`M[r, r+1] = 1`). +/// For the last row, the entry `M[N-1, N-1]` is also `1` to handle the boundary. +/// +/// ### Example +/// For a trace with `N=4` rows, the matrix is: +/// ```text +/// [[0, 1, 0, 0], +/// [0, 0, 1, 0], +/// [0, 0, 0, 1], +/// [0, 0, 0, 1]] // Last row maps to itself. +/// ``` +/// +/// ### Arguments +/// * `point`: A slice of `2n` field elements `[s1, s2]`, where: +/// - `s1` is an `n`-element vector representing the **row index**. +/// - `s2` is an `n`-element vector for the **column index**. +/// +/// ### Returns +/// A single field element `F` representing the polynomial's evaluation at `point`. pub fn matrix_down_lde(point: &[F]) -> F { - /* - Matrix DOWN: - - (0 1 0 0 ... 0 0 0) - (0 0 1 0 ... 0 0 0) - (0 0 0 1 ... 0 0 0) - (0 0 0 0 ... 0 0 0) - (0 0 0 0 ... 0 0 0) - ... ... ... - (0 0 0 0 ... 0 1 0) - (0 0 0 0 ... 0 0 1) - (0 0 0 0 ... 0 0 1) - - Square matrix of size self.n_columns x sef.n_columns - As a multilinear polynomial in 2 * log_length variables: - - self.n_columns first variables -> encoding the row index - - self.n_columns last variables -> encoding the column index - - TODO OPTIMIZATIOn: - the lde currently is in log(table_length)^2, but it could be log(table_length) using a recursive construction - (However it is not representable as a polynomial in this case, but as a fraction instead) - - */ - next_mle(point) + point.iter().copied().product::() - - // bottom right corner + // The polynomial is the sum of two components: + // 1. `next_mle(point)`: This polynomial is 1 if `s2` (column index) is the integer + // successor of `s1` (row index). This creates the superdiagonal of the matrix. + next_mle(point) + // 2. `product(point)`: This is 1 only if all bits in the point are 1. This handles + // the bottom-right corner of the matrix, ensuring the last row maps to itself. + + point.iter().copied().product::() } /// Returns a multilinear polynomial in 2n variables that evaluates to 1 @@ -84,9 +110,8 @@ pub fn matrix_down_lde(point: &[F]) -> F { /// Field element: 1 if y = x + 1, 0 otherwise. fn next_mle(point: &[F]) -> F { // Check that the point length is even: we split into x and y of equal length. - assert_eq!( - point.len() % 2, - 0, + assert!( + point.len().is_multiple_of(2), "Input point must have an even number of variables." ); let n = point.len() / 2; @@ -130,22 +155,538 @@ fn next_mle(point: &[F]) -> F { .sum() } +/// Generates "up" and "down" shifted columns for a set of AIR columns. +/// +/// This is a utility function that applies `column_up` and `column_down` in parallel +/// to a slice of columns, as required by the zerocheck protocol. +/// +/// ### Arguments +/// * `columns`: A slice of column slices (`&[&[F]]`). +/// +/// ### Returns +/// A `Vec` containing the results in the order `[up(c1), up(c2), ..., down(c1), down(c2), ...]`. pub fn columns_up_and_down(columns: &[&[F]]) -> Vec> { - columns - .par_iter() - .map(|c| column_up(c)) - .chain(columns.par_iter().map(|c| column_down(c))) - .collect() + // Process "up" columns in parallel using Rayon. + let up_cols = columns.par_iter().map(|c| column_up(c)); + // Process "down" columns in parallel. + let down_cols = columns.par_iter().map(|c| column_down(c)); + // Chain the two parallel iterators and collect the results into a single vector. + up_cols.chain(down_cols).collect() } +/// Creates the "up" version of a column (`c_up`). +/// +/// This corresponds to the `c_up` definition from the paper. It copies the column but +/// replaces the last element with the second-to-last element to handle the boundary case. +/// +/// ### Example +/// `[a, b, c, d]` becomes `[a, b, c, c]` +/// +/// ### Arguments +/// * `column`: A slice representing a single AIR column. +/// +/// ### Returns +/// A new `Vec` with the transformation applied. pub fn column_up(column: &[F]) -> Vec { + // Provide a helpful error in debug mode if the column is too short. + debug_assert!(column.len() >= 2, "column_up requires length >= 2"); + // Create a mutable copy of the input column. let mut up = column.to_vec(); + // Overwrite the last element with the value of the second-to-last element. up[column.len() - 1] = up[column.len() - 2]; up } +/// Creates the "down" version of a column (`c_down`). +/// +/// This corresponds to the `c_down` definition from the paper. It shifts the column's elements +/// up by one position and duplicates the new last element to handle the boundary case. +/// +/// ### Example +/// `[a, b, c, d]` becomes `[b, c, d, d]` +/// +/// ### Arguments +/// * `column`: A slice representing a single AIR column. +/// +/// ### Returns +/// A new `Vec` with the transformation applied. pub fn column_down(column: &[F]) -> Vec { - let mut down = column[1..].to_vec(); - down.push(*down.last().unwrap()); - down + // Provide a helpful error in debug mode if the column is empty. + debug_assert!( + !column.is_empty(), + "column_down requires a non-empty column" + ); + // Get the last element, which will be appended at the end. + let last_val = column[column.len() - 1]; + // - Create an iterator that skips the first element, + // - Chains the last element at the end to maintain original length. + column + .iter() + .skip(1) + .copied() + .chain(std::iter::once(last_val)) + .collect() +} + +/// Computes the folded evaluation vector for the "UP" matrix polynomial. +/// +/// This function pre-computes the evaluations of the `matrix_up_lde` polynomial for a +/// fixed set of `outer_challenges` over the entire boolean hypercube of the remaining variables. +/// +/// ### Arguments +/// * `outer_challenges`: An `n`-element slice representing the point at which the first +/// `n` variables of the LDE have been fixed. +/// +/// ### Returns +/// A `Vec` of size `2^n` containing the evaluations. +#[instrument(name = "matrix_up_folded", skip_all)] +pub fn matrix_up_folded(outer_challenges: &[F]) -> Vec { + // Get the number of variables, `n`. + let n = outer_challenges.len(); + // Calculate the size of the evaluation domain (2^n). + let size = 1 << n; + // Start with the evaluations of the equality polynomial `eq(X, challenges)`. + let mut folded = eval_eq(outer_challenges); + // Calculate the product of all challenges for the correction term. + let outer_challenges_prod: F = outer_challenges.iter().copied().product(); + // Apply corrections to the last two elements of the evaluation vector. + folded[size - 1] -= outer_challenges_prod; + folded[size - 2] += outer_challenges_prod; + folded +} + +/// Computes the folded evaluation vector for the "DOWN" matrix polynomial. +/// +/// This function pre-computes the evaluations of the `matrix_down_lde` polynomial for a +/// fixed set of `outer_challenges` over the entire boolean hypercube of the remaining variables. +/// +/// ### Behavior +/// The function constructs the evaluation vector by building the underlying polynomial term by term. +/// It iterates through each possible bit position `k`, calculates a corresponding scalar coefficient, +/// and combines it with the evaluations of an equality polynomial over the other variables. +/// +/// ### Arguments +/// * `outer_challenges`: An `n`-element slice representing the point at which the first +/// `n` variables of the LDE have been fixed. +/// +/// ### Returns +/// A `Vec` of size `2^n` containing the evaluations, which represents one column of the +/// "DOWN" matrix evaluated at the challenge point. +#[instrument(name = "matrix_down_folded", skip_all)] +pub fn matrix_down_folded(outer_challenges: &[F]) -> Vec { + // Get the number of variables, `n`. + let n = outer_challenges.len(); + // Calculate the size of the evaluation domain (2^n). + let size = 1 << n; + // Initialize the result vector with zeros. + let mut folded = vec![F::ZERO; size]; + + // Precompute products of suffixes of the challenges for efficient lookups. + // `suffix_prods[i]` will store the product of challenges from index `i` to the end. + // e.g., for challenges [c0, c1, c2], suffix_prods = [c0*c1*c2, c1*c2, c2, 1] + let mut suffix_prods = vec![F::ONE; n + 1]; + for i in (0..n).rev() { + suffix_prods[i] = suffix_prods[i + 1] * outer_challenges[i]; + } + + // This loop constructs the final folded polynomial term by term, iterating through + // each possible carry position `k` (from right to left). + for k in 0..n { + // Calculate the scalar coefficient for this term of the polynomial sum, + // using the precomputed suffix product for efficiency. + let scalar = (F::ONE - outer_challenges[n - k - 1]) * suffix_prods[n - k]; + + // Get the evaluations of the equality polynomial for the high-order bits. + let eq_mle = eval_eq(&outer_challenges[0..n - k - 1]); + + // This loop adds the scaled evaluations into the final `folded` vector. + for (i, &v) in eq_mle.iter().enumerate() { + // This bit-shifting logic calculates the correct target index in the `folded` vector, + // which corresponds to constructing the polynomial via tensor products. + let final_idx = (i << (k + 1)) + (1 << k); + // The value from the equality polynomial is scaled and added to the target position. + folded[final_idx] += v * scalar; + } + } + + // Add the correction for the bottom-right corner of the matrix, using the + // precomputed total product of all challenges. + folded[size - 1] += suffix_prods[0]; + + // Return the completed evaluation vector. + folded +} + +#[cfg(test)] +mod tests { + use p3_field::PrimeCharacteristicRing; + use p3_koala_bear::KoalaBear; + + use super::*; + + type F = KoalaBear; + + /// Helper function to convert an integer to its big-endian bit representation as field elements. + /// e.g., `int_to_bits(5, 3)` -> `[F::ONE, F::ZERO, F::ONE]` (for 101_2) + fn int_to_bits(n: u32, num_bits: usize) -> Vec { + (0..num_bits) + .map(|i| { + if (n >> (num_bits - 1 - i)) & 1 == 1 { + F::ONE + } else { + F::ZERO + } + }) + .collect() + } + + #[test] + #[should_panic] + fn matrix_up_lde_panics_on_odd_len() { + // 3 variables is invalid (must be even) + matrix_up_lde::(&[F::ZERO, F::ONE, F::ONE]); + } + + #[test] + fn test_next_mle_successor_cases() { + let n = 2; // Testing with 2-bit numbers, so point length is 4. + + // Case: 0 -> 1. x=(0,0), y=(0,1) + let x = int_to_bits(0, n); + let y = int_to_bits(1, n); + assert_eq!(next_mle(&[x, y].concat()), F::ONE, "Failed for 0 -> 1"); + + // Case: 1 -> 2. x=(0,1), y=(1,0) + let x = int_to_bits(1, n); + let y = int_to_bits(2, n); + assert_eq!(next_mle(&[x, y].concat()), F::ONE, "Failed for 1 -> 2"); + + // Case: 2 -> 3. x=(1,0), y=(1,1) + let x = int_to_bits(2, n); + let y = int_to_bits(3, n); + assert_eq!(next_mle(&[x, y].concat()), F::ONE, "Failed for 2 -> 3"); + } + + #[test] + fn test_next_mle_non_successor_cases() { + let n = 2; + + // Case: Not a successor. x=(0,0), y=(1,0) (0 -> 2) + let x = int_to_bits(0, n); + let y = int_to_bits(2, n); + assert_eq!(next_mle(&[x, y].concat()), F::ZERO, "Failed for 0 -> 2"); + + // Case: Identity. x=(1,0), y=(1,0) (2 -> 2) + let x = int_to_bits(2, n); + let y = int_to_bits(2, n); + assert_eq!(next_mle(&[x, y].concat()), F::ZERO, "Failed for 2 -> 2"); + + // Case: Wrap around (not handled by this function). x=(1,1), y=(0,0) (3 -> 0) + let x = int_to_bits(3, n); + let y = int_to_bits(0, n); + assert_eq!( + next_mle(&[x, y].concat()), + F::ZERO, + "Failed for 3 -> 0 (wrap-around)" + ); + } + + #[test] + #[should_panic] + fn test_matrix_up_lde_panics_on_odd_len() { + // The function expects an even-length slice (2*n variables), + matrix_up_lde::(&[F::ZERO, F::ONE, F::ONE]); + } + + #[test] + fn test_matrix_up_lde_on_hypercube() { + // Set n=2, meaning row and column indices are 2-bit numbers. + // This corresponds to a 4x4 matrix (2^n x 2^n). + let n = 2; + // The matrix M is the identity matrix, except the last row is a copy of the second-to-last + // row of the identity matrix. M[3] = (0,0,1,0). + let expected_matrix = [ + // c0 c1 c2 c3 + [F::ONE, F::ZERO, F::ZERO, F::ZERO], // row 0 + [F::ZERO, F::ONE, F::ZERO, F::ZERO], // row 1 + [F::ZERO, F::ZERO, F::ONE, F::ZERO], // row 2 + [F::ZERO, F::ZERO, F::ONE, F::ZERO], // row 3 + ]; + + // Iterate through every row index of the 4x4 matrix. + for r_idx in 0..4 { + // Iterate through every column index of the 4x4 matrix. + for c_idx in 0..4 { + // Convert the integer row index (e.g., 2) into its bit vector ([1, 0]). + let r_bits = int_to_bits(r_idx, n); + // Convert the integer column index (e.g., 3) into its bit vector ([1, 1]). + let c_bits = int_to_bits(c_idx, n); + // The function's input `point` is the concatenation of the row and column bits. + let point = [r_bits.as_slice(), c_bits.as_slice()].concat(); + + // Get the expected value (0 or 1) from our hardcoded matrix for the current (row, col). + let expected = expected_matrix[r_idx as usize][c_idx as usize]; + // Call the function with the generated point to get its actual evaluation. + let actual = matrix_up_lde(&point); + + // Assert that the actual value matches the expected value. + assert_eq!(actual, expected, "Mismatch at M_up[{r_idx}, {c_idx}]"); + } + } + } + + #[test] + #[should_panic] + fn test_matrix_down_lde_panics_on_odd_len() { + // The function expects an even-length slice (2*n variables). + matrix_down_lde::(&[F::ZERO, F::ONE, F::ONE]); + } + + #[test] + fn test_matrix_down_lde_on_hypercube() { + // Set n=2 for a 4x4 matrix. + let n = 2; + // The matrix M has 1s on the superdiagonal (col = row + 1) + // and the last two rows are identical (0,0,0,1). + let expected_matrix = [ + // c0 c1 c2 c3 + [F::ZERO, F::ONE, F::ZERO, F::ZERO], // row 0 + [F::ZERO, F::ZERO, F::ONE, F::ZERO], // row 1 + [F::ZERO, F::ZERO, F::ZERO, F::ONE], // row 2 + [F::ZERO, F::ZERO, F::ZERO, F::ONE], // row 3 + ]; + + // Iterate through every row index of the 4x4 matrix. + for r_idx in 0..4 { + // Iterate through every column index of the 4x4 matrix. + for c_idx in 0..4 { + // Convert the integer row index to its bit vector. + let r_bits = int_to_bits(r_idx, n); + // Convert the integer column index to its bit vector. + let c_bits = int_to_bits(c_idx, n); + // Concatenate bits to form the function's input point. + let point = [r_bits.as_slice(), c_bits.as_slice()].concat(); + + // Get the expected value from our hardcoded matrix. + let expected = expected_matrix[r_idx as usize][c_idx as usize]; + // Call the function to get the actual evaluated value. + let actual = matrix_down_lde(&point); + + // Assert that the actual value matches the expected value from the matrix. + assert_eq!(actual, expected, "Mismatch at M_down[{r_idx}, {c_idx}]"); + } + } + } + + #[test] + fn test_column_up() { + // Create a sample column vector with four distinct field elements. + let col = vec![ + F::from_u32(10), + F::from_u32(20), + F::from_u32(30), + F::from_u32(40), + ]; + + // The `column_up` function duplicates the second-to-last element into the last position. + // + // Transformation logic: + // + // [10] [10] + // [20] [20] + // [30] ---> [30] + // [40] [30] <-- This value is copied from the one above. + // + let expected = vec![ + F::from_u32(10), + F::from_u32(20), + F::from_u32(30), + F::from_u32(30), + ]; + // Assert that the function produces the correct "up" column. + assert_eq!(column_up(&col), expected); + + // Test the edge case with a column of length 2. + let col_len2 = vec![F::from_u32(5), F::from_u32(8)]; + // + // Transformation logic for length 2: + // + // [5] ---> [5] + // [8] [5] <-- Copied from above. + // + let expected_len2 = vec![F::from_u32(5), F::from_u32(5)]; + assert_eq!(column_up(&col_len2), expected_len2); + } + + #[test] + #[should_panic] + fn test_column_up_panics_on_len_1() { + column_up(&[F::ONE]); + } + + #[test] + fn test_column_down() { + // Create a sample column vector. + let col = vec![ + F::from_u32(10), + F::from_u32(20), + F::from_u32(30), + F::from_u32(40), + ]; + + // The `column_down` function shifts all elements up by one position + // and then duplicates the new last element to maintain the original length. + // + // Transformation logic: + // + // [10] [20] <-- Shifted up + // [20] [30] <-- Shifted up + // [30] ---> [40] <-- Shifted up + // [40] [40] <-- Duplicated from the new last element above + // + let expected = vec![ + F::from_u32(20), + F::from_u32(30), + F::from_u32(40), + F::from_u32(40), + ]; + // Assert that the function produces the correct "down" column. + assert_eq!(column_down(&col), expected); + + // Test the edge case with a column of length 2. + let col_len2 = vec![F::from_u32(5), F::from_u32(8)]; + // + // Transformation logic for length 2: + // + // [5] ---> [8] + // [8] [8] <-- Duplicated + // + let expected_len2 = vec![F::from_u32(8), F::from_u32(8)]; + assert_eq!(column_down(&col_len2), expected_len2); + } + + #[test] + fn test_columns_up_and_down() { + // Create two sample columns to process. + let col1 = vec![F::from_u32(1), F::from_u32(2), F::from_u32(3)]; + let col2 = vec![F::from_u32(4), F::from_u32(5), F::from_u32(6)]; + // The function takes a slice of column slices as input. + let columns = vec![col1.as_slice(), col2.as_slice()]; + + // The function first applies `column_up` to all input columns, + // then applies `column_down` to all input columns, and finally + // collects the results in that order. + // + // Input Columns: + // + // col1 | col2 + // -----|----- + // 1 | 4 + // 2 | 5 + // 3 | 6 + // + // Expected Output Structure: + // + // [ up(col1), up(col2), down(col1), down(col2) ] + // + // up(col1) = [1, 2, 2] + // up(col2) = [4, 5, 5] + // down(col1) = [2, 3, 3] + // down(col2) = [5, 6, 6] + let expected = vec![ + vec![F::from_u32(1), F::from_u32(2), F::from_u32(2)], // up(col1) + vec![F::from_u32(4), F::from_u32(5), F::from_u32(5)], // up(col2) + vec![F::from_u32(2), F::from_u32(3), F::from_u32(3)], // down(col1) + vec![F::from_u32(5), F::from_u32(6), F::from_u32(6)], // down(col2) + ]; + // Assert that the function correctly processes and collects all results. + assert_eq!(columns_up_and_down(&columns), expected); + } + + #[test] + fn test_matrix_up_folded_vs_lde() { + // Set n=3 variables, meaning we are testing the logic on an 8x8 matrix (since 2^3 = 8). + let n = 3; + // Calculate the number of rows/columns, which is 8. + let num_coords = 1 << n; + + // This test verifies that `matrix_up_folded` is consistent with `matrix_up_lde`. + // We iterate through each column of the matrix. + for c_idx in 0..num_coords { + // Convert the integer column index (e.g., 6) into its bit representation (e.g., [1, 1, 0]). + let c_bits = int_to_bits(c_idx as u32, n); + // Call `matrix_up_folded` to get the evaluations for the entire column `c_idx`. + // + // This vector represents the polynomial M(X, c_bits) evaluated for all boolean X. + let folded_col = matrix_up_folded(&c_bits); + // Sanity check: the resulting vector should have 8 elements, one for each row. + assert_eq!(folded_col.len(), num_coords as usize); + + // Now, for the fixed column, we check each row's value. + for r_idx in 0..num_coords { + // Convert the integer row index (e.g., 7) into its bit representation (e.g., [1, 1, 1]). + let r_bits = int_to_bits(r_idx as u32, n); + + // We construct the input point for the LDE by concatenating: + // - the COLUMN bits first, + // - then the ROW bits. + // This aligns the test's variable ordering with the one implicitly used by the `folded` functions. + let point = [c_bits.as_slice(), r_bits.as_slice()].concat(); + + // Evaluate the `lde` function at this specific point to get the "ground truth" value. + let lde_val = matrix_up_lde(&point); + // Get the corresponding value from the pre-calculated `folded_col` vector. + let folded_val = folded_col[r_idx as usize]; + + // Assert that the value from the folded evaluation matches the direct LDE evaluation. + // + // This confirms that both functions represent the same underlying polynomial. + assert_eq!( + lde_val, folded_val, + "Mismatch for M_up(col={c_idx}, row={r_idx})" + ); + } + } + } + + #[test] + fn test_matrix_down_folded_vs_lde() { + // Set n=3 variables for an 8x8 matrix. + let n = 3; + // Calculate the number of rows/columns. + let num_coords = 1 << n; + + // We iterate through each column to verify the consistency between the `folded` and `lde` functions. + for c_idx in 0..num_coords { + // Convert the column index to its bit representation. + let c_bits = int_to_bits(c_idx as u32, n); + // Get the entire column's evaluations from the `folded` function. + let folded_col = matrix_down_folded(&c_bits); + // Check if the output vector has the correct number of row entries. + assert_eq!(folded_col.len(), num_coords as usize); + + // For the given column, check each row's value. + for r_idx in 0..num_coords { + // Convert the row index to its bit representation. + let r_bits = int_to_bits(r_idx as u32, n); + + // We construct the input point for the LDE by concatenating: + // - the COLUMN bits first, + // - then the ROW bits. + // This aligns the test's variable ordering with the one implicitly used by the `folded` functions. + let point = [c_bits.as_slice(), r_bits.as_slice()].concat(); + + // Calculate the expected value by calling the `lde` function directly. + let lde_val = matrix_down_lde(&point); + // Get the actual value from the `folded` function's result vector. + let folded_val = folded_col[r_idx as usize]; + + // Assert that the two values are identical. + assert_eq!( + lde_val, folded_val, + "Mismatch for M_down(col={c_idx}, row={r_idx})" + ); + } + } + } }