Skip to content

Commit 39f0b71

Browse files
committed
packing for Grand product sumcheck for Execution
1 parent e147c84 commit 39f0b71

File tree

3 files changed

+87
-71
lines changed

3 files changed

+87
-71
lines changed

crates/lean_prover/src/common.rs

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use p3_field::BasedVectorSpace;
1+
use p3_field::{Algebra, BasedVectorSpace};
22
use p3_field::{ExtensionField, PrimeCharacteristicRing};
33
use p3_util::log2_ceil_usize;
44
use packed_pcs::ColDims;
@@ -157,46 +157,57 @@ pub struct PrecompileFootprint {
157157
pub multilinear_eval_powers: [EF; 6],
158158
}
159159

160-
impl<N: ExtensionField<PF<EF>>> SumcheckComputation<N, EF> for PrecompileFootprint
161-
where
162-
EF: ExtensionField<N>,
163-
{
164-
fn degree(&self) -> usize {
165-
3
166-
}
167-
fn eval(&self, point: &[N], _: &[EF]) -> EF {
168-
// TODO not all columns are used
169-
170-
let nu_a = (EF::ONE - point[COL_INDEX_FLAG_A]) * point[COL_INDEX_MEM_VALUE_A]
160+
impl PrecompileFootprint {
161+
fn air_eval<
162+
PointF: PrimeCharacteristicRing + Copy,
163+
ResultF: Algebra<EF> + Algebra<PointF> + Copy,
164+
>(
165+
&self,
166+
point: &[PointF],
167+
mul_point_f_and_ef: impl Fn(PointF, EF) -> ResultF,
168+
) -> ResultF {
169+
let nu_a = (ResultF::ONE - point[COL_INDEX_FLAG_A]) * point[COL_INDEX_MEM_VALUE_A]
171170
+ point[COL_INDEX_FLAG_A] * point[COL_INDEX_OPERAND_A];
172-
let nu_b = (EF::ONE - point[COL_INDEX_FLAG_B]) * point[COL_INDEX_MEM_VALUE_B]
171+
let nu_b = (ResultF::ONE - point[COL_INDEX_FLAG_B]) * point[COL_INDEX_MEM_VALUE_B]
173172
+ point[COL_INDEX_FLAG_B] * point[COL_INDEX_OPERAND_B];
174-
let nu_c = (EF::ONE - point[COL_INDEX_FLAG_C]) * point[COL_INDEX_MEM_VALUE_C]
173+
let nu_c = (ResultF::ONE - point[COL_INDEX_FLAG_C]) * point[COL_INDEX_MEM_VALUE_C]
175174
+ point[COL_INDEX_FLAG_C] * point[COL_INDEX_FP];
176175

177-
self.global_challenge
178-
+ (self.p16_powers[1]
179-
+ self.p16_powers[2] * nu_a
180-
+ self.p16_powers[3] * nu_b
181-
+ self.p16_powers[4] * nu_c)
182-
* point[COL_INDEX_POSEIDON_16]
183-
+ (self.p24_powers[1]
184-
+ self.p24_powers[2] * nu_a
185-
+ self.p24_powers[3] * nu_b
186-
+ self.p24_powers[4] * nu_c)
176+
(nu_a * self.p16_powers[2]
177+
+ nu_b * self.p16_powers[3]
178+
+ nu_c * self.p16_powers[4]
179+
+ self.p16_powers[1])
180+
* point[COL_INDEX_POSEIDON_16]
181+
+ (nu_a * self.p24_powers[2]
182+
+ nu_b * self.p24_powers[3]
183+
+ nu_c * self.p24_powers[4]
184+
+ self.p24_powers[1])
187185
* point[COL_INDEX_POSEIDON_24]
188-
+ (self.dot_product_powers[1]
189-
+ self.dot_product_powers[2] * nu_a
190-
+ self.dot_product_powers[3] * nu_b
191-
+ self.dot_product_powers[4] * nu_c
192-
+ self.dot_product_powers[5] * point[COL_INDEX_AUX])
186+
+ (nu_a * self.dot_product_powers[2]
187+
+ nu_b * self.dot_product_powers[3]
188+
+ nu_c * self.dot_product_powers[4]
189+
+ mul_point_f_and_ef(point[COL_INDEX_AUX], self.dot_product_powers[5])
190+
+ self.dot_product_powers[1])
193191
* point[COL_INDEX_DOT_PRODUCT]
194-
+ (self.multilinear_eval_powers[1]
195-
+ self.multilinear_eval_powers[2] * nu_a
196-
+ self.multilinear_eval_powers[3] * nu_b
197-
+ self.multilinear_eval_powers[4] * nu_c
198-
+ self.multilinear_eval_powers[5] * point[COL_INDEX_AUX])
192+
+ (nu_a * self.multilinear_eval_powers[2]
193+
+ nu_b * self.multilinear_eval_powers[3]
194+
+ nu_c * self.multilinear_eval_powers[4]
195+
+ mul_point_f_and_ef(point[COL_INDEX_AUX], self.multilinear_eval_powers[5])
196+
+ self.multilinear_eval_powers[1])
199197
* point[COL_INDEX_MULTILINEAR_EVAL]
198+
+ self.global_challenge
199+
}
200+
}
201+
202+
impl<N: ExtensionField<F>> SumcheckComputation<N, EF> for PrecompileFootprint
203+
where
204+
EF: ExtensionField<N>,
205+
{
206+
fn degree(&self) -> usize {
207+
3
208+
}
209+
fn eval(&self, point: &[N], _: &[EF]) -> EF {
210+
self.air_eval(point, |p, c| c * p)
200211
}
201212
}
202213

@@ -205,12 +216,12 @@ impl SumcheckComputationPacked<EF> for PrecompileFootprint {
205216
3
206217
}
207218

208-
fn eval_packed_extension(&self, _point: &[EFPacking<EF>], _: &[EF]) -> EFPacking<EF> {
209-
todo!()
219+
fn eval_packed_extension(&self, point: &[EFPacking<EF>], _: &[EF]) -> EFPacking<EF> {
220+
self.air_eval(point, |p, c| p * c)
210221
}
211222

212-
fn eval_packed_base(&self, _point: &[utils::PFPacking<EF>], _: &[EF]) -> EFPacking<EF> {
213-
todo!()
223+
fn eval_packed_base(&self, point: &[PFPacking<EF>], _: &[EF]) -> EFPacking<EF> {
224+
self.air_eval::<PFPacking<EF>, EFPacking<EF>>(point, |p, c| EFPacking::<EF>::from(p) * c)
214225
}
215226
}
216227

crates/lean_prover/src/prove_execution.rs

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -408,23 +408,25 @@ pub fn prove_execution(
408408
grand_product_dot_product_sumcheck_point,
409409
grand_product_dot_product_sumcheck_inner_evals,
410410
_,
411-
) = sumcheck::prove(
412-
1, // TODO univariate skip?
413-
MleGroupRef::Extension(
414-
dot_product_columns[..5]
415-
.iter()
416-
.map(|c| c.as_slice())
417-
.collect::<Vec<_>>(),
418-
), // TODO packing
419-
&dot_product_footprint_computation,
420-
&dot_product_footprint_computation,
421-
&[],
422-
Some((grand_product_dot_product_statement.point.0.clone(), None)),
423-
false,
424-
&mut prover_state,
425-
grand_product_dot_product_statement.value,
426-
None,
427-
);
411+
) = info_span!("Grand product sumcheck for Dot Product").in_scope(|| {
412+
sumcheck::prove(
413+
1, // TODO univariate skip?
414+
MleGroupRef::Extension(
415+
dot_product_columns[..5]
416+
.iter()
417+
.map(|c| c.as_slice())
418+
.collect::<Vec<_>>(),
419+
), // TODO packing
420+
&dot_product_footprint_computation,
421+
&dot_product_footprint_computation,
422+
&[],
423+
Some((grand_product_dot_product_statement.point.0.clone(), None)),
424+
false,
425+
&mut prover_state,
426+
grand_product_dot_product_statement.value,
427+
None,
428+
)
429+
});
428430
assert_eq!(grand_product_dot_product_sumcheck_inner_evals.len(), 5);
429431
prover_state.add_extension_scalars(&grand_product_dot_product_sumcheck_inner_evals);
430432

@@ -460,21 +462,24 @@ pub fn prove_execution(
460462
};
461463

462464
let (grand_product_exec_sumcheck_point, grand_product_exec_sumcheck_inner_evals, _) =
463-
sumcheck::prove(
464-
1, // TODO univariate skip?
465-
MleGroupRef::Base(
466-
// TODO not all columns re required
467-
full_trace.iter().map(|c| c.as_slice()).collect::<Vec<_>>(),
468-
), // TODO packing
469-
&precompile_foot_print_computation,
470-
&precompile_foot_print_computation,
471-
&[],
472-
Some((grand_product_exec_statement.point.0.clone(), None)),
473-
false,
474-
&mut prover_state,
475-
grand_product_exec_statement.value,
476-
None,
477-
);
465+
info_span!("Grand product sumcheck for Execution").in_scope(|| {
466+
sumcheck::prove(
467+
1, // TODO univariate skip
468+
MleGroupRef::Base(
469+
// TODO not all columns re required
470+
full_trace.iter().map(|c| c.as_slice()).collect::<Vec<_>>(),
471+
)
472+
.pack(),
473+
&precompile_foot_print_computation,
474+
&precompile_foot_print_computation,
475+
&[],
476+
Some((grand_product_exec_statement.point.0.clone(), None)),
477+
false,
478+
&mut prover_state,
479+
grand_product_exec_statement.value,
480+
None,
481+
)
482+
});
478483

479484
prover_state.add_extension_scalars(&grand_product_exec_sumcheck_inner_evals);
480485
assert_eq!(

crates/rec_aggregation/src/recursion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ fn run_recursion_benchmark() -> RecursionBenchStats {
183183
// #[rustfmt::skip] // debug
184184
// std::fs::write("public_input.txt", build_public_memory(&public_input).chunks_exact(8).enumerate().map(|(i, chunk)| { format!("{} - {}: {}\n", i, i * 8, chunk.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", ")) }).collect::<String>(),).unwrap();
185185

186-
// utils::init_tracing();
186+
utils::init_tracing();
187187
let (bytecode, function_locations) = compile_program(&program_str);
188188
let time = Instant::now();
189189
let (proof_data, proof_size) = prove_execution(

0 commit comments

Comments
 (0)