Skip to content

Commit 7e8b93c

Browse files
committed
starting to make logup* "padding aware"
1 parent 4b0ae0a commit 7e8b93c

File tree

3 files changed

+102
-40
lines changed

3 files changed

+102
-40
lines changed

crates/lean_prover/src/prove_execution.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ pub fn prove_execution(
827827
+ memory_poly_eq_point_alpha.square() * base_memory_lookup_statement_3.value,
828828
&base_memory_poly_eq_point,
829829
&base_memory_pushforward,
830+
Some(non_zero_memory_size),
830831
);
831832
let poseidon_logup_star_statements = prove_logup_star(
832833
&mut prover_state,
@@ -839,6 +840,7 @@ pub fn prove_execution(
839840
.sum(),
840841
&poseidon_poly_eq_point,
841842
&poseidon_pushforward,
843+
Some(non_zero_memory_size.div_ceil(VECTOR_LEN)),
842844
);
843845

844846
let bytecode_logup_star_statements = prove_logup_star(
@@ -848,6 +850,7 @@ pub fn prove_execution(
848850
bytecode_lookup_claim_1.value + alpha_bytecode_lookup * bytecode_lookup_claim_2.value,
849851
&bytecode_poly_eq_point,
850852
&bytecode_pushforward,
853+
Some(bytecode.instructions.len()),
851854
);
852855

853856
let poseidon_lookup_memory_point = MultilinearPoint(

crates/lookup/src/logup_star.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,18 @@ pub fn prove_logup_star<EF>(
3030
claimed_value: EF,
3131
poly_eq_point: &[EF],
3232
pushforward: &[EF], // already commited
33+
max_index: Option<usize>,
3334
) -> LogupStarStatements<EF>
3435
where
3536
EF: ExtensionField<PF<EF>>,
3637
PF<EF>: PrimeField64,
3738
{
3839
let table_length = table.unpacked_len();
3940
let indexes_length = indexes.len();
41+
let max_index = max_index
42+
.unwrap_or(table_length)
43+
.next_multiple_of(packing_width::<EF>());
44+
let max_index_packed = max_index / packing_width::<EF>();
4045

4146
let (poly_eq_point_packed, pushforward_packed, mut table_packed) = info_span!("packing")
4247
.in_scope(|| {
@@ -101,7 +106,8 @@ where
101106
layer
102107
});
103108

104-
let (claim_left, _, eval_c_minux_indexes) = prove_gkr_quotient(prover_state, gkr_layer_left);
109+
let (claim_left, _, eval_c_minux_indexes) =
110+
prove_gkr_quotient(prover_state, gkr_layer_left, None);
105111

106112
let gkr_layer_right = info_span!("building right").in_scope(|| {
107113
let mut layer =
@@ -113,12 +119,18 @@ where
113119
.map(|i| random_challenge - PF::<EF>::from_usize(i))
114120
.collect::<Vec<_>>(),
115121
);
116-
parallel_clone(&pushforward_packed, &mut layer[..half_len_packed]);
122+
parallel_clone(
123+
&pushforward_packed[..max_index_packed],
124+
&mut layer[..max_index_packed],
125+
);
126+
layer[max_index_packed..half_len_packed]
127+
.par_iter_mut()
128+
.for_each(|x| *x = EFPacking::<EF>::ZERO);
117129
parallel_clone(&challenge_minus_increment, &mut layer[half_len_packed..]);
118130
layer
119131
});
120132
let (claim_right, pushforward_final_eval, _) =
121-
prove_gkr_quotient(prover_state, gkr_layer_right);
133+
prove_gkr_quotient(prover_state, gkr_layer_right, Some(max_index_packed));
122134

123135
let final_point_left = claim_left.point[1..].to_vec();
124136
let indexes_final_eval = random_challenge - eval_c_minux_indexes;
@@ -303,6 +315,7 @@ mod tests {
303315
claim.value,
304316
&poly_eq_point,
305317
&pushforward,
318+
None,
306319
);
307320
println!("Proving logup_star took {} ms", time.elapsed().as_millis());
308321

crates/lookup/src/quotient_gkr.rs

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,36 @@ with: U0 = AB(0 0 --- )
3939
pub fn prove_gkr_quotient<EF>(
4040
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
4141
final_layer: Vec<EFPacking<EF>>,
42+
n_non_zeros_numerator: Option<usize>, // final_layer[n_non_zeros_numerator..n / 2] are zeros
4243
) -> (Evaluation<EF>, EF, EF)
4344
where
4445
EF: ExtensionField<PF<EF>>,
4546
PF<EF>: PrimeField64,
4647
{
4748
let n = (final_layer.len() * packing_width::<EF>()).ilog2() as usize;
49+
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(final_layer.len() / 2);
4850
let mut layers_packed = Vec::new();
4951
let mut layers_not_packed = Vec::new();
5052
let last_packed = n
5153
.checked_sub(6 + packing_log_width::<EF>())
5254
.expect("TODO small GKR, no packing");
5355
layers_packed.push(final_layer);
5456
for i in 0..last_packed {
55-
layers_packed.push(sum_quotients_2_by_2(&layers_packed[i]));
57+
layers_packed.push(sum_quotients_2_by_2(
58+
&layers_packed[i],
59+
if i == 0 {
60+
Some(n_non_zeros_numerator)
61+
} else {
62+
None
63+
},
64+
));
5665
}
57-
layers_not_packed.push(sum_quotients_2_by_2(&unpack_extension(
58-
&layers_packed[last_packed],
59-
)));
66+
layers_not_packed.push(sum_quotients_2_by_2(
67+
&unpack_extension(&layers_packed[last_packed]),
68+
None,
69+
));
6070
for i in 0..n - last_packed - 2 {
61-
layers_not_packed.push(sum_quotients_2_by_2(&layers_not_packed[i]));
71+
layers_not_packed.push(sum_quotients_2_by_2(&layers_not_packed[i], None));
6272
}
6373

6474
assert_eq!(layers_not_packed[n - last_packed - 2].len(), 2);
@@ -75,9 +85,17 @@ where
7585
(claim, up_layer_eval_left, up_layer_eval_right) =
7686
prove_gkr_quotient_step(prover_state, layer, &claim);
7787
}
78-
for layer in layers_packed.iter().rev() {
79-
(claim, up_layer_eval_left, up_layer_eval_right) =
80-
prove_gkr_quotient_step_packed(prover_state, layer, &claim);
88+
for (i, layer) in layers_packed.iter().enumerate().rev() {
89+
(claim, up_layer_eval_left, up_layer_eval_right) = prove_gkr_quotient_step_packed(
90+
prover_state,
91+
layer,
92+
&claim,
93+
if i == 0 {
94+
Some(n_non_zeros_numerator)
95+
} else {
96+
None
97+
},
98+
);
8199
}
82100

83101
(claim, up_layer_eval_left, up_layer_eval_right)
@@ -203,6 +221,7 @@ fn prove_gkr_quotient_step_packed<EF>(
203221
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
204222
up_layer_packed: &[EFPacking<EF>],
205223
claim: &Evaluation<EF>,
224+
n_non_zeros_numerator: Option<usize>,
206225
) -> (Evaluation<EF>, EF, EF)
207226
where
208227
EF: ExtensionField<PF<EF>>,
@@ -212,36 +231,50 @@ where
212231
up_layer_packed.len() * packing_width::<EF>(),
213232
2 << claim.point.0.len()
214233
);
234+
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(up_layer_packed.len() / 2);
215235

216236
let len_packed = up_layer_packed.len();
217237
let mid_len_packed = len_packed / 2;
218238
let quarter_len_packed = mid_len_packed / 2;
239+
let three_quarter_len_packed = quarter_len_packed * 3;
240+
241+
assert!(n_non_zeros_numerator >= quarter_len_packed);
219242

220243
let eq_poly = eval_eq(&claim.point.0[1..]);
221244

222245
let mut eq_poly_packed = pack_extension(&eq_poly);
223246

224-
let mut all_sums_x_packed = EFPacking::<EF>::zero_vec(quarter_len_packed);
225-
let mut all_sums_one_minus_x_packed = EFPacking::<EF>::zero_vec(quarter_len_packed);
226-
227-
all_sums_x_packed
228-
.par_iter_mut()
229-
.zip(all_sums_one_minus_x_packed.par_iter_mut())
230-
.enumerate()
231-
.for_each(|(i, (x, one_minus_x))| {
247+
let (sum_x_packed, sum_one_minus_x_packed): (EFPacking<EF>, EFPacking<EF>) = (0
248+
..n_non_zeros_numerator - quarter_len_packed)
249+
.into_par_iter()
250+
.map(|i| {
232251
let eq_eval = eq_poly_packed[i];
233252
let u0 = up_layer_packed[i];
234253
let u1 = up_layer_packed[quarter_len_packed + i];
235254
let u2 = up_layer_packed[mid_len_packed + i];
236-
let u3 = up_layer_packed[mid_len_packed + quarter_len_packed + i];
237-
*x = eq_eval * u2 * u3;
238-
*one_minus_x = eq_eval * (u0 * u3 + u1 * u2);
239-
});
240-
241-
let sum_x_packed = all_sums_x_packed.into_par_iter().sum::<EFPacking<EF>>();
242-
let sum_one_minus_x_packed = all_sums_one_minus_x_packed
243-
.into_par_iter()
244-
.sum::<EFPacking<EF>>();
255+
let u3 = up_layer_packed[three_quarter_len_packed + i];
256+
let x = eq_eval * u2 * u3;
257+
let one_minus_x = eq_eval * (u0 * u3 + u1 * u2);
258+
(x, one_minus_x)
259+
})
260+
.chain(
261+
(n_non_zeros_numerator - quarter_len_packed..quarter_len_packed)
262+
.into_par_iter()
263+
.map(|i| {
264+
let eq_eval = eq_poly_packed[i];
265+
let u0 = up_layer_packed[i];
266+
let u2 = up_layer_packed[mid_len_packed + i];
267+
let u3 = up_layer_packed[three_quarter_len_packed + i];
268+
let eq_eval_times_u3 = eq_eval * u3;
269+
let x = eq_eval_times_u3 * u2;
270+
let one_minus_x = eq_eval_times_u3 * u0;
271+
(x, one_minus_x)
272+
}),
273+
)
274+
.reduce(
275+
|| (EFPacking::<EF>::ZERO, EFPacking::<EF>::ZERO),
276+
|(acc_x, acc_one_minus_x), (x, one_minus_x)| (acc_x + x, acc_one_minus_x + one_minus_x),
277+
);
245278

246279
let sum_x = EFPacking::<EF>::to_ext_iter([sum_x_packed]).sum::<EF>();
247280
let sum_one_minus_x = EFPacking::<EF>::to_ext_iter([sum_one_minus_x_packed]).sum::<EF>();
@@ -429,17 +462,29 @@ impl<EF: ExtensionField<PF<EF>>> SumcheckComputationPacked<EF> for GKRQuotientCo
429462
}
430463
}
431464

432-
fn sum_quotients_2_by_2<EF: PrimeCharacteristicRing + Sync + Send + Copy>(layer: &[EF]) -> Vec<EF> {
465+
fn sum_quotients_2_by_2<EF: PrimeCharacteristicRing + Sync + Send + Copy>(
466+
layer: &[EF],
467+
n_non_zeros_numerator: Option<usize>,
468+
) -> Vec<EF> {
433469
let n = layer.len();
434-
(0..n / 2)
470+
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(n / 2);
471+
assert!(n_non_zeros_numerator >= n / 4);
472+
let n_over_2 = n / 2;
473+
let n_over_4 = n_over_2 / 2;
474+
let n_times_3_over_4 = n_over_2 + n_over_4;
475+
(0..n_non_zeros_numerator - n / 4)
435476
.into_par_iter()
436-
.map(|i| {
437-
if i < n / 4 {
438-
layer[i] * layer[n * 3 / 4 + i] + layer[n / 4 + i] * layer[n / 2 + i]
439-
} else {
440-
layer[n / 4 + i] * layer[n / 2 + i]
441-
}
442-
})
477+
.map(|i| layer[i] * layer[n_times_3_over_4 + i] + layer[n_over_4 + i] * layer[n_over_2 + i])
478+
.chain(
479+
(n_non_zeros_numerator - n / 4..n / 4)
480+
.into_par_iter()
481+
.map(|i| layer[i] * layer[n_times_3_over_4 + i]),
482+
)
483+
.chain(
484+
(n / 4..n / 2)
485+
.into_par_iter()
486+
.map(|i| layer[n_over_4 + i] * layer[n_over_2 + i]),
487+
)
443488
.collect()
444489
}
445490

@@ -468,7 +513,7 @@ mod tests {
468513
let mut rng = StdRng::seed_from_u64(0);
469514

470515
let big = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
471-
let small = sum_quotients_2_by_2(&big);
516+
let small = sum_quotients_2_by_2(&big, None);
472517

473518
// sanity check
474519
assert_eq!(
@@ -483,7 +528,8 @@ mod tests {
483528

484529
let time = Instant::now();
485530
let claim = Evaluation { point, value: eval };
486-
let _ = prove_gkr_quotient_step_packed(&mut prover_state, &pack_extension(&big), &claim);
531+
let _ =
532+
prove_gkr_quotient_step_packed(&mut prover_state, &pack_extension(&big), &claim, None);
487533
dbg!(time.elapsed());
488534

489535
let mut verifier_state = build_verifier_state(&prover_state);
@@ -504,7 +550,7 @@ mod tests {
504550

505551
let mut prover_state = build_prover_state();
506552

507-
let _ = prove_gkr_quotient(&mut prover_state, pack_extension(&layer));
553+
let _ = prove_gkr_quotient(&mut prover_state, pack_extension(&layer), None);
508554

509555
let mut verifier_state = build_verifier_state(&prover_state);
510556

0 commit comments

Comments
 (0)