Skip to content

Commit 3358c6a

Browse files
committed
wip
1 parent 104f96b commit 3358c6a

File tree

1 file changed

+123
-31
lines changed

1 file changed

+123
-31
lines changed

crates/lookup/src/quotient_gkr.rs

Lines changed: 123 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ fn prove_gkr_quotient_step_packed<EF>(
213213
prover_state: &mut FSProver<EF, impl FSChallenger<EF>>,
214214
up_layer_packed: &[EFPacking<EF>],
215215
claim: &Evaluation<EF>,
216-
n_non_zeros_numerator: Option<usize>,
216+
_n_non_zeros_numerator: Option<usize>, // TODO
217217
) -> (Evaluation<EF>, EF, EF)
218218
where
219219
EF: ExtensionField<PF<EF>>,
@@ -223,49 +223,112 @@ where
223223
up_layer_packed.len() * packing_width::<EF>(),
224224
2 << claim.point.0.len()
225225
);
226-
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(up_layer_packed.len() / 2);
227226

228227
let len_packed = up_layer_packed.len();
229228
let mid_len_packed = len_packed / 2;
230229
let quarter_len_packed = mid_len_packed / 2;
231-
let three_quarter_len_packed = quarter_len_packed * 3;
232-
233-
assert!(n_non_zeros_numerator >= quarter_len_packed);
234230

235231
let mut eq_poly_packed = eval_eq_packed(&claim.point.0[1..]);
236232
// TODO for the top layer, the denomiators have a structured form: constant - index.
237233
// We can skip one EE multilication in the sumcheck computation.
238234

239-
let sum_x_packed: EFPacking<EF> = (0..quarter_len_packed)
240-
.into_par_iter()
241-
.map(|i| {
242-
let eq_eval = eq_poly_packed[i];
243-
let u2 = up_layer_packed[mid_len_packed + i];
244-
let u3 = up_layer_packed[three_quarter_len_packed + i];
245-
eq_eval * u2 * u3
246-
})
247-
.sum();
235+
let up_layer_octics = split_at_many(
236+
&up_layer_packed,
237+
&(1..8)
238+
.map(|i| i * up_layer_packed.len() / 8)
239+
.collect::<Vec<_>>(),
240+
);
241+
242+
let (eq_mle_left, eq_mle_right) = eq_poly_packed.split_at(eq_poly_packed.len() / 2);
243+
244+
let (sum_x_packed, c0_term_single, c2_term_single, c0_term_double, c2_term_double) =
245+
up_layer_octics[0]
246+
.par_iter()
247+
.zip(up_layer_octics[1].par_iter())
248+
.zip(up_layer_octics[2].par_iter())
249+
.zip(up_layer_octics[3].par_iter())
250+
.zip(up_layer_octics[4].par_iter())
251+
.zip(up_layer_octics[5].par_iter())
252+
.zip(up_layer_octics[6].par_iter())
253+
.zip(up_layer_octics[7].par_iter())
254+
.zip(eq_mle_left.par_iter())
255+
.zip(eq_mle_right.par_iter())
256+
.map(
257+
|(
258+
(
259+
(
260+
(
261+
(((((u0_left, u0_right), u1_left), u1_right), u2_left), u2_right),
262+
u3_left,
263+
),
264+
u3_right,
265+
),
266+
&eq_val_left,
267+
),
268+
&eq_val_right,
269+
)| {
270+
let x_sum_left = eq_val_left * *u2_left * *u3_left;
271+
let x_sum_right = eq_val_right * *u2_right * *u3_right;
272+
273+
// first sumcheck polynomial
274+
let x_sum = x_sum_left + x_sum_right;
275+
276+
// anticipation for the next sumcheck polynomial
277+
let c0_term_single = x_sum_left;
278+
let mut c2_term_single = (*u3_right - *u3_left) * (*u2_right - *u2_left);
279+
c2_term_single *= eq_val_left;
280+
281+
let c0_term_double_a = *u0_left * *u3_left;
282+
let c2_term_double_a = (*u0_right - *u0_left) * (*u3_right - *u3_left);
283+
let c0_term_double_b = *u2_left * *u1_left;
284+
let c2_term_double_b = (*u1_right - *u1_left) * (*u2_right - *u2_left);
285+
let mut c0_term_double = c0_term_double_a + c0_term_double_b;
286+
let mut c2_term_double = c2_term_double_a + c2_term_double_b;
287+
c0_term_double *= eq_val_left;
288+
c2_term_double *= eq_val_left;
289+
290+
(
291+
x_sum,
292+
c0_term_single,
293+
c2_term_single,
294+
c0_term_double,
295+
c2_term_double,
296+
)
297+
},
298+
)
299+
.reduce(
300+
|| {
301+
(
302+
EFPacking::<EF>::ZERO,
303+
EFPacking::<EF>::ZERO,
304+
EFPacking::<EF>::ZERO,
305+
EFPacking::<EF>::ZERO,
306+
EFPacking::<EF>::ZERO,
307+
)
308+
},
309+
|(x, a0, a1, a2, a3), (y, b0, b1, b2, b3)| {
310+
(x + y, a0 + b0, a1 + b1, a2 + b2, a3 + b3)
311+
},
312+
);
248313

249314
let sum_x = EFPacking::<EF>::to_ext_iter([sum_x_packed]).sum::<EF>();
250315
let sum_one_minus_x = (claim.value - sum_x * claim.point[0]) / (EF::ONE - claim.point[0]);
251316

252-
let first_sumcheck_polynomial =
317+
let sumcheck_polynomial_1 =
253318
&DensePolynomial::new(vec![
254319
EF::ONE - claim.point[0],
255320
claim.point[0].double() - EF::ONE,
256321
]) * &DensePolynomial::new(vec![sum_one_minus_x, sum_x - sum_one_minus_x]);
257322

258323
// sanity check
259324
assert_eq!(
260-
first_sumcheck_polynomial.evaluate(EF::ZERO) + first_sumcheck_polynomial.evaluate(EF::ONE),
325+
sumcheck_polynomial_1.evaluate(EF::ZERO) + sumcheck_polynomial_1.evaluate(EF::ONE),
261326
claim.value
262327
);
263328

264-
prover_state.add_extension_scalars(&first_sumcheck_polynomial.coeffs);
265-
266-
let first_sumcheck_challenge = prover_state.sample();
267-
268-
let next_sum = first_sumcheck_polynomial.evaluate(first_sumcheck_challenge);
329+
prover_state.add_extension_scalars(&sumcheck_polynomial_1.coeffs);
330+
let sumcheck_challenge_1 = prover_state.sample();
331+
let sum_1 = sumcheck_polynomial_1.evaluate(sumcheck_challenge_1);
269332

270333
let (u0_folded_packed, u1_folded_packed, u2_folded_packed, u3_folded_packed) = (
271334
&up_layer_packed[..quarter_len_packed],
@@ -274,35 +337,64 @@ where
274337
&up_layer_packed[mid_len_packed + quarter_len_packed..],
275338
);
276339

277-
let u4_const = first_sumcheck_challenge;
278-
let u5_const = EF::ONE - first_sumcheck_challenge;
279-
let missing_mul_factor = (first_sumcheck_challenge * claim.point[0]
280-
+ (EF::ONE - first_sumcheck_challenge) * (EF::ONE - claim.point[0]))
340+
let u4_const = sumcheck_challenge_1;
341+
let u5_const = EF::ONE - sumcheck_challenge_1;
342+
let mut missing_mul_factor = (sumcheck_challenge_1 * claim.point[0]
343+
+ (EF::ONE - sumcheck_challenge_1) * (EF::ONE - claim.point[0]))
281344
/ (EF::ONE - claim.point[1]);
282345

283-
eq_poly_packed.resize(eq_poly_packed.len() / 2, Default::default());
346+
let c0 = c0_term_single * u4_const + c0_term_double * u5_const;
347+
let c2 = c2_term_single * u4_const + c2_term_double * u5_const;
348+
349+
let c0 = EFPacking::<EF>::to_ext_iter([c0]).into_iter().sum::<EF>();
350+
let c2 = EFPacking::<EF>::to_ext_iter([c2]).into_iter().sum::<EF>();
351+
352+
let first_eq_factor = claim.point[1];
353+
let c1 = ((sum_1 / missing_mul_factor) - c2 * first_eq_factor - c0) / first_eq_factor;
354+
355+
let mut sumcheck_polynomial_2 = DensePolynomial::new(vec![
356+
c0 * missing_mul_factor,
357+
c1 * missing_mul_factor,
358+
c2 * missing_mul_factor,
359+
]);
360+
361+
sumcheck_polynomial_2 *= &DensePolynomial::lagrange_interpolation(&[
362+
(PF::<EF>::ZERO, EF::ONE - first_eq_factor),
363+
(PF::<EF>::ONE, first_eq_factor),
364+
])
365+
.unwrap();
366+
367+
prover_state.add_extension_scalars(&sumcheck_polynomial_2.coeffs);
368+
let sumcheck_challenge_2 = prover_state.sample();
369+
let sum_2 = sumcheck_polynomial_2.evaluate(sumcheck_challenge_2);
370+
371+
eq_poly_packed.resize(eq_poly_packed.len() / 4, Default::default());
372+
missing_mul_factor *= ((EF::ONE - claim.point[1]) * (EF::ONE - sumcheck_challenge_2)
373+
+ claim.point[1] * sumcheck_challenge_2)
374+
/ (EF::ONE - claim.point.get(2).copied().unwrap_or_default());
284375

285-
let (mut sc_point, quarter_evals, _) = sumcheck_prove::<EF, _, _, _>(
376+
let (mut sc_point, quarter_evals, _) = sumcheck_fold_and_prove::<EF, _, _, _>(
286377
1,
287378
MleGroupRef::ExtensionPacked(vec![
288379
u0_folded_packed,
289380
u1_folded_packed,
290381
u2_folded_packed,
291382
u3_folded_packed,
292383
]),
384+
Some(vec![EF::ONE - sumcheck_challenge_2, sumcheck_challenge_2]),
293385
&GKRQuotientComputation { u4_const, u5_const },
294386
&GKRQuotientComputation { u4_const, u5_const },
295387
&[],
296388
Some((
297-
claim.point.0[1..].to_vec(),
389+
claim.point.0[2..].to_vec(),
298390
Some(MleOwned::ExtensionPacked(eq_poly_packed)),
299391
)),
300392
false,
301393
prover_state,
302-
next_sum,
394+
sum_2,
303395
Some(missing_mul_factor),
304396
);
305-
sc_point.insert(0, first_sumcheck_challenge);
397+
sc_point.splice(0..0, [sumcheck_challenge_1, sumcheck_challenge_2]);
306398

307399
prover_state.add_extension_scalars(&quarter_evals);
308400

@@ -468,7 +560,7 @@ mod tests {
468560

469561
#[test]
470562
fn test_gkr_quotient_step() {
471-
let log_n = 21;
563+
let log_n = 12;
472564
let n = 1 << log_n;
473565

474566
let mut rng = StdRng::seed_from_u64(0);

0 commit comments

Comments
 (0)