Skip to content

Commit 29da7d1

Browse files
authored
air: better AirWitness implementation (#28)
1 parent deeaeec commit 29da7d1

File tree

3 files changed

+320
-72
lines changed

3 files changed

+320
-72
lines changed

crates/air/src/prove.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use whir_p3::{
1616
use super::table::AirTable;
1717
use crate::{
1818
MyAir,
19-
utils::{column_down, column_up, columns_up_and_down, matrix_down_folded, matrix_up_folded},
19+
utils::{column_down, column_up, matrix_down_folded, matrix_up_folded},
2020
witness::AirWitness,
2121
};
2222

@@ -154,15 +154,15 @@ pub fn prove_many_air_3<
154154
.chain(witnesses_2)
155155
.map(|w| {
156156
if structured_air {
157-
MleGroupOwned::Base(columns_up_and_down(w)).into()
157+
MleGroupOwned::Base(w.shifted_columns()).into()
158158
} else {
159159
MleGroupRef::Base(w.cols.clone()).into()
160160
}
161161
})
162162
.collect::<Vec<MleGroup<'_, EF>>>();
163163
columns_for_zero_check.extend(witnesses_3.iter().map(|w| {
164164
if structured_air {
165-
MleGroupOwned::Extension(columns_up_and_down(w)).into()
165+
MleGroupOwned::Extension(w.shifted_columns()).into()
166166
} else {
167167
MleGroupRef::Extension(w.cols.clone()).into()
168168
}

crates/air/src/utils.rs

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use p3_field::Field;
2-
use rayon::prelude::*;
32
use tracing::instrument;
43
use whir_p3::poly::{evals::eval_eq, multilinear::MultilinearPoint};
54

@@ -155,25 +154,6 @@ fn next_mle<F: Field>(point: &[F]) -> F {
155154
.sum()
156155
}
157156

158-
/// Generates "up" and "down" shifted columns for a set of AIR columns.
159-
///
160-
/// This is a utility function that applies `column_up` and `column_down` in parallel
161-
/// to a slice of columns, as required by the zerocheck protocol.
162-
///
163-
/// ### Arguments
164-
/// * `columns`: A slice of column slices (`&[&[F]]`).
165-
///
166-
/// ### Returns
167-
/// A `Vec` containing the results in the order `[up(c1), up(c2), ..., down(c1), down(c2), ...]`.
168-
pub fn columns_up_and_down<F: Field>(columns: &[&[F]]) -> Vec<Vec<F>> {
169-
// Process "up" columns in parallel using Rayon.
170-
let up_cols = columns.par_iter().map(|c| column_up(c));
171-
// Process "down" columns in parallel.
172-
let down_cols = columns.par_iter().map(|c| column_down(c));
173-
// Chain the two parallel iterators and collect the results into a single vector.
174-
up_cols.chain(down_cols).collect()
175-
}
176-
177157
/// Creates the "up" version of a column (`c_up`).
178158
///
179159
/// This corresponds to the `c_up` definition from the paper. It copies the column but
@@ -565,44 +545,6 @@ mod tests {
565545
assert_eq!(column_down(&col_len2), expected_len2);
566546
}
567547

568-
#[test]
569-
fn test_columns_up_and_down() {
570-
// Create two sample columns to process.
571-
let col1 = vec![F::from_u32(1), F::from_u32(2), F::from_u32(3)];
572-
let col2 = vec![F::from_u32(4), F::from_u32(5), F::from_u32(6)];
573-
// The function takes a slice of column slices as input.
574-
let columns = vec![col1.as_slice(), col2.as_slice()];
575-
576-
// The function first applies `column_up` to all input columns,
577-
// then applies `column_down` to all input columns, and finally
578-
// collects the results in that order.
579-
//
580-
// Input Columns:
581-
//
582-
// col1 | col2
583-
// -----|-----
584-
// 1 | 4
585-
// 2 | 5
586-
// 3 | 6
587-
//
588-
// Expected Output Structure:
589-
//
590-
// [ up(col1), up(col2), down(col1), down(col2) ]
591-
//
592-
// up(col1) = [1, 2, 2]
593-
// up(col2) = [4, 5, 5]
594-
// down(col1) = [2, 3, 3]
595-
// down(col2) = [5, 6, 6]
596-
let expected = vec![
597-
vec![F::from_u32(1), F::from_u32(2), F::from_u32(2)], // up(col1)
598-
vec![F::from_u32(4), F::from_u32(5), F::from_u32(5)], // up(col2)
599-
vec![F::from_u32(2), F::from_u32(3), F::from_u32(3)], // down(col1)
600-
vec![F::from_u32(5), F::from_u32(6), F::from_u32(6)], // down(col2)
601-
];
602-
// Assert that the function correctly processes and collects all results.
603-
assert_eq!(columns_up_and_down(&columns), expected);
604-
}
605-
606548
#[test]
607549
fn test_matrix_up_folded_vs_lde() {
608550
// Set n=3 variables, meaning we are testing the logic on an 8x8 matrix (since 2^3 = 8).

0 commit comments

Comments
 (0)