@@ -39,26 +39,36 @@ with: U0 = AB(0 0 --- )
3939pub fn prove_gkr_quotient < EF > (
4040 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
4141 final_layer : Vec < EFPacking < EF > > ,
42+ n_non_zeros_numerator : Option < usize > , // final_layer[n_non_zeros_numerator..n / 2] are zeros
4243) -> ( Evaluation < EF > , EF , EF )
4344where
4445 EF : ExtensionField < PF < EF > > ,
4546 PF < EF > : PrimeField64 ,
4647{
4748 let n = ( final_layer. len ( ) * packing_width :: < EF > ( ) ) . ilog2 ( ) as usize ;
49+ let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( final_layer. len ( ) / 2 ) ;
4850 let mut layers_packed = Vec :: new ( ) ;
4951 let mut layers_not_packed = Vec :: new ( ) ;
5052 let last_packed = n
5153 . checked_sub ( 6 + packing_log_width :: < EF > ( ) )
5254 . expect ( "TODO small GKR, no packing" ) ;
5355 layers_packed. push ( final_layer) ;
5456 for i in 0 ..last_packed {
55- layers_packed. push ( sum_quotients_2_by_2 ( & layers_packed[ i] ) ) ;
57+ layers_packed. push ( sum_quotients_2_by_2 (
58+ & layers_packed[ i] ,
59+ if i == 0 {
60+ Some ( n_non_zeros_numerator)
61+ } else {
62+ None
63+ } ,
64+ ) ) ;
5665 }
57- layers_not_packed. push ( sum_quotients_2_by_2 ( & unpack_extension (
58- & layers_packed[ last_packed] ,
59- ) ) ) ;
66+ layers_not_packed. push ( sum_quotients_2_by_2 (
67+ & unpack_extension ( & layers_packed[ last_packed] ) ,
68+ None ,
69+ ) ) ;
6070 for i in 0 ..n - last_packed - 2 {
61- layers_not_packed. push ( sum_quotients_2_by_2 ( & layers_not_packed[ i] ) ) ;
71+ layers_not_packed. push ( sum_quotients_2_by_2 ( & layers_not_packed[ i] , None ) ) ;
6272 }
6373
6474 assert_eq ! ( layers_not_packed[ n - last_packed - 2 ] . len( ) , 2 ) ;
7585 ( claim, up_layer_eval_left, up_layer_eval_right) =
7686 prove_gkr_quotient_step ( prover_state, layer, & claim) ;
7787 }
78- for layer in layers_packed. iter ( ) . rev ( ) {
79- ( claim, up_layer_eval_left, up_layer_eval_right) =
80- prove_gkr_quotient_step_packed ( prover_state, layer, & claim) ;
88+ for ( i, layer) in layers_packed. iter ( ) . enumerate ( ) . rev ( ) {
89+ ( claim, up_layer_eval_left, up_layer_eval_right) = prove_gkr_quotient_step_packed (
90+ prover_state,
91+ layer,
92+ & claim,
93+ if i == 0 {
94+ Some ( n_non_zeros_numerator)
95+ } else {
96+ None
97+ } ,
98+ ) ;
8199 }
82100
83101 ( claim, up_layer_eval_left, up_layer_eval_right)
@@ -203,6 +221,7 @@ fn prove_gkr_quotient_step_packed<EF>(
203221 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
204222 up_layer_packed : & [ EFPacking < EF > ] ,
205223 claim : & Evaluation < EF > ,
224+ n_non_zeros_numerator : Option < usize > ,
206225) -> ( Evaluation < EF > , EF , EF )
207226where
208227 EF : ExtensionField < PF < EF > > ,
@@ -212,36 +231,50 @@ where
212231 up_layer_packed. len( ) * packing_width:: <EF >( ) ,
213232 2 << claim. point. 0 . len( )
214233 ) ;
234+ let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( up_layer_packed. len ( ) / 2 ) ;
215235
216236 let len_packed = up_layer_packed. len ( ) ;
217237 let mid_len_packed = len_packed / 2 ;
218238 let quarter_len_packed = mid_len_packed / 2 ;
239+ let three_quarter_len_packed = quarter_len_packed * 3 ;
240+
241+ assert ! ( n_non_zeros_numerator >= quarter_len_packed) ;
219242
220243 let eq_poly = eval_eq ( & claim. point . 0 [ 1 ..] ) ;
221244
222245 let mut eq_poly_packed = pack_extension ( & eq_poly) ;
223246
224- let mut all_sums_x_packed = EFPacking :: < EF > :: zero_vec ( quarter_len_packed) ;
225- let mut all_sums_one_minus_x_packed = EFPacking :: < EF > :: zero_vec ( quarter_len_packed) ;
226-
227- all_sums_x_packed
228- . par_iter_mut ( )
229- . zip ( all_sums_one_minus_x_packed. par_iter_mut ( ) )
230- . enumerate ( )
231- . for_each ( |( i, ( x, one_minus_x) ) | {
247+ let ( sum_x_packed, sum_one_minus_x_packed) : ( EFPacking < EF > , EFPacking < EF > ) = ( 0
248+ ..n_non_zeros_numerator - quarter_len_packed)
249+ . into_par_iter ( )
250+ . map ( |i| {
232251 let eq_eval = eq_poly_packed[ i] ;
233252 let u0 = up_layer_packed[ i] ;
234253 let u1 = up_layer_packed[ quarter_len_packed + i] ;
235254 let u2 = up_layer_packed[ mid_len_packed + i] ;
236- let u3 = up_layer_packed[ mid_len_packed + quarter_len_packed + i] ;
237- * x = eq_eval * u2 * u3;
238- * one_minus_x = eq_eval * ( u0 * u3 + u1 * u2) ;
239- } ) ;
240-
241- let sum_x_packed = all_sums_x_packed. into_par_iter ( ) . sum :: < EFPacking < EF > > ( ) ;
242- let sum_one_minus_x_packed = all_sums_one_minus_x_packed
243- . into_par_iter ( )
244- . sum :: < EFPacking < EF > > ( ) ;
255+ let u3 = up_layer_packed[ three_quarter_len_packed + i] ;
256+ let x = eq_eval * u2 * u3;
257+ let one_minus_x = eq_eval * ( u0 * u3 + u1 * u2) ;
258+ ( x, one_minus_x)
259+ } )
260+ . chain (
261+ ( n_non_zeros_numerator - quarter_len_packed..quarter_len_packed)
262+ . into_par_iter ( )
263+ . map ( |i| {
264+ let eq_eval = eq_poly_packed[ i] ;
265+ let u0 = up_layer_packed[ i] ;
266+ let u2 = up_layer_packed[ mid_len_packed + i] ;
267+ let u3 = up_layer_packed[ three_quarter_len_packed + i] ;
268+ let eq_eval_times_u3 = eq_eval * u3;
269+ let x = eq_eval_times_u3 * u2;
270+ let one_minus_x = eq_eval_times_u3 * u0;
271+ ( x, one_minus_x)
272+ } ) ,
273+ )
274+ . reduce (
275+ || ( EFPacking :: < EF > :: ZERO , EFPacking :: < EF > :: ZERO ) ,
276+ |( acc_x, acc_one_minus_x) , ( x, one_minus_x) | ( acc_x + x, acc_one_minus_x + one_minus_x) ,
277+ ) ;
245278
246279 let sum_x = EFPacking :: < EF > :: to_ext_iter ( [ sum_x_packed] ) . sum :: < EF > ( ) ;
247280 let sum_one_minus_x = EFPacking :: < EF > :: to_ext_iter ( [ sum_one_minus_x_packed] ) . sum :: < EF > ( ) ;
@@ -429,17 +462,29 @@ impl<EF: ExtensionField<PF<EF>>> SumcheckComputationPacked<EF> for GKRQuotientCo
429462 }
430463}
431464
432- fn sum_quotients_2_by_2 < EF : PrimeCharacteristicRing + Sync + Send + Copy > ( layer : & [ EF ] ) -> Vec < EF > {
465+ fn sum_quotients_2_by_2 < EF : PrimeCharacteristicRing + Sync + Send + Copy > (
466+ layer : & [ EF ] ,
467+ n_non_zeros_numerator : Option < usize > ,
468+ ) -> Vec < EF > {
433469 let n = layer. len ( ) ;
434- ( 0 ..n / 2 )
470+ let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( n / 2 ) ;
471+ assert ! ( n_non_zeros_numerator >= n / 4 ) ;
472+ let n_over_2 = n / 2 ;
473+ let n_over_4 = n_over_2 / 2 ;
474+ let n_times_3_over_4 = n_over_2 + n_over_4;
475+ ( 0 ..n_non_zeros_numerator - n / 4 )
435476 . into_par_iter ( )
436- . map ( |i| {
437- if i < n / 4 {
438- layer[ i] * layer[ n * 3 / 4 + i] + layer[ n / 4 + i] * layer[ n / 2 + i]
439- } else {
440- layer[ n / 4 + i] * layer[ n / 2 + i]
441- }
442- } )
477+ . map ( |i| layer[ i] * layer[ n_times_3_over_4 + i] + layer[ n_over_4 + i] * layer[ n_over_2 + i] )
478+ . chain (
479+ ( n_non_zeros_numerator - n / 4 ..n / 4 )
480+ . into_par_iter ( )
481+ . map ( |i| layer[ i] * layer[ n_times_3_over_4 + i] ) ,
482+ )
483+ . chain (
484+ ( n / 4 ..n / 2 )
485+ . into_par_iter ( )
486+ . map ( |i| layer[ n_over_4 + i] * layer[ n_over_2 + i] ) ,
487+ )
443488 . collect ( )
444489}
445490
@@ -468,7 +513,7 @@ mod tests {
468513 let mut rng = StdRng :: seed_from_u64 ( 0 ) ;
469514
470515 let big = ( 0 ..n) . map ( |_| rng. random ( ) ) . collect :: < Vec < EF > > ( ) ;
471- let small = sum_quotients_2_by_2 ( & big) ;
516+ let small = sum_quotients_2_by_2 ( & big, None ) ;
472517
473518 // sanity check
474519 assert_eq ! (
@@ -483,7 +528,8 @@ mod tests {
483528
484529 let time = Instant :: now ( ) ;
485530 let claim = Evaluation { point, value : eval } ;
486- let _ = prove_gkr_quotient_step_packed ( & mut prover_state, & pack_extension ( & big) , & claim) ;
531+ let _ =
532+ prove_gkr_quotient_step_packed ( & mut prover_state, & pack_extension ( & big) , & claim, None ) ;
487533 dbg ! ( time. elapsed( ) ) ;
488534
489535 let mut verifier_state = build_verifier_state ( & prover_state) ;
@@ -504,7 +550,7 @@ mod tests {
504550
505551 let mut prover_state = build_prover_state ( ) ;
506552
507- let _ = prove_gkr_quotient ( & mut prover_state, pack_extension ( & layer) ) ;
553+ let _ = prove_gkr_quotient ( & mut prover_state, pack_extension ( & layer) , None ) ;
508554
509555 let mut verifier_state = build_verifier_state ( & prover_state) ;
510556
0 commit comments