Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 50 additions & 71 deletions crates/air/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,23 @@ use crate::{MyAir, witness::AirWitness};
pub struct AirTable<EF: Field, A> {
pub air: A,
pub n_constraints: usize,

_phantom: std::marker::PhantomData<EF>,
_phantom: PhantomData<EF>,
}

impl<EF: ExtensionField<PF<EF>>, A: MyAir<EF>> AirTable<EF, A> {
pub fn new(air: A) -> Self {
let symbolic_constraints = get_symbolic_constraints(&air, 0, 0);
let n_constraints = symbolic_constraints.len();
let constraint_degree = Iterator::max(
symbolic_constraints
.iter()
.map(p3_uni_stark::SymbolicExpression::degree_multiple),
)
.unwrap();
let constraint_degree = symbolic_constraints
.iter()
.map(p3_uni_stark::SymbolicExpression::degree_multiple)
.max()
.unwrap();
assert_eq!(constraint_degree, air.degree());
Self {
air,
n_constraints,
_phantom: std::marker::PhantomData,
_phantom: PhantomData,
}
}

Expand All @@ -45,15 +43,43 @@ impl<EF: ExtensionField<PF<EF>>, A: MyAir<EF>> AirTable<EF, A> {
A: MyAir<EF>,
EF: ExtensionField<IF>,
{
if witness.n_columns() != self.n_columns() {
let width = self.air.width();
let rows = witness.n_rows();

if witness.n_columns() != width {
return Err("Invalid number of columns".to_string());
}
let handle_errors = |row: usize, constraint_checker: &mut ConstraintChecker<'_, IF, EF>| {
if !constraint_checker.errors.is_empty() {

// Erased-type dispatch to the concrete ConstraintChecker expected by `air.eval`.
let eval_erased = |checker: &mut ConstraintChecker<'_, IF, EF>| unsafe {
if TypeId::of::<IF>() == TypeId::of::<EF>() {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(checker));
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
checker,
));
}
};

// Common per-row runner.
let run_row = |row: usize, slice: Vec<IF>| -> Result<(), String> {
let mut checker = ConstraintChecker {
main: RowMajorMatrixView::new(&slice, width),
constraint_index: 0,
errors: Vec::new(),
field: PhantomData,
};

eval_erased(&mut checker);

if !checker.errors.is_empty() {
return Err(format!(
"Trace is not valid at row {}: contraints not respected: {}",
row,
constraint_checker
checker
.errors
.iter()
.map(std::string::ToString::to_string)
Expand All @@ -63,69 +89,22 @@ impl<EF: ExtensionField<PF<EF>>, A: MyAir<EF>> AirTable<EF, A> {
}
Ok(())
};

if self.air.structured() {
for row in 0..witness.n_rows() - 1 {
let up = (0..self.n_columns())
.map(|j| witness[j][row])
.collect::<Vec<_>>();
let down = (0..self.n_columns())
.map(|j| witness[j][row + 1])
.collect::<Vec<_>>();
let up_and_down = [up, down].concat();
let mut constraints_checker = ConstraintChecker::<IF, EF> {
main: RowMajorMatrixView::new(&up_and_down, self.air.width()),
constraint_index: 0,
errors: Vec::new(),
field: PhantomData,
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
}
}
handle_errors(row, &mut constraints_checker)?;
// same semantics as original: panics if rows == 0 due to `rows - 1`
for row in 0..rows - 1 {
let mut up_and_down = Vec::with_capacity(width * 2);
up_and_down.extend((0..width).map(|j| witness[j][row]));
up_and_down.extend((0..width).map(|j| witness[j][row + 1]));
run_row(row, up_and_down)?;
}
} else {
for row in 0..witness.n_rows() {
let up = (0..self.n_columns())
.map(|j| witness[j][row])
.collect::<Vec<_>>();
let mut constraints_checker = ConstraintChecker {
main: RowMajorMatrixView::new(&up, self.air.width()),
constraint_index: 0,
errors: Vec::new(),
field: PhantomData,
};
if TypeId::of::<IF>() == TypeId::of::<EF>() {
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, EF, EF>>(
&mut constraints_checker,
));
}
} else {
assert_eq!(TypeId::of::<IF>(), TypeId::of::<PF<EF>>());
unsafe {
self.air
.eval(transmute::<_, &mut ConstraintChecker<'_, PF<EF>, EF>>(
&mut constraints_checker,
));
}
}
handle_errors(row, &mut constraints_checker)?;
for row in 0..rows {
let up = (0..width).map(|j| witness[j][row]).collect();
run_row(row, up)?;
}
}

Ok(())
}
}