@@ -239,24 +239,14 @@ where
239239 // TODO for the top layer, the denomiators have a structured form: constant - index.
240240 // We can skip one EE multilication in the sumcheck computation.
241241
242- let sum_x_packed: EFPacking < EF > = ( 0 ..n_non_zeros_numerator - quarter_len_packed)
242+ let sum_x_packed: EFPacking < EF > = ( 0 ..quarter_len_packed)
243243 . into_par_iter ( )
244244 . map ( |i| {
245245 let eq_eval = eq_poly_packed[ i] ;
246246 let u2 = up_layer_packed[ mid_len_packed + i] ;
247247 let u3 = up_layer_packed[ three_quarter_len_packed + i] ;
248248 eq_eval * u2 * u3
249249 } )
250- . chain (
251- ( n_non_zeros_numerator - quarter_len_packed..quarter_len_packed)
252- . into_par_iter ( )
253- . map ( |i| {
254- let eq_eval = eq_poly_packed[ i] ;
255- let u2 = up_layer_packed[ mid_len_packed + i] ;
256- let u3 = up_layer_packed[ three_quarter_len_packed + i] ;
257- eq_eval * u3 * u2
258- } ) ,
259- )
260250 . sum ( ) ;
261251
262252 let sum_x = EFPacking :: < EF > :: to_ext_iter ( [ sum_x_packed] ) . sum :: < EF > ( ) ;
@@ -414,30 +404,52 @@ where
414404 Ok ( Evaluation :: new ( next_point, next_claim) )
415405}
416406
417- fn sum_quotients_2_by_2 < EF : PrimeCharacteristicRing + Sync + Send + Copy > (
418- layer : & [ EF ] ,
407+ fn sum_quotients_2_by_2 < F : PrimeCharacteristicRing + Sync + Send + Copy > (
408+ layer : & [ F ] ,
419409 n_non_zeros_numerator : Option < usize > ,
420- ) -> Vec < EF > {
410+ ) -> Vec < F > {
421411 let n = layer. len ( ) ;
422- let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( n / 2 ) ;
423- assert ! ( n_non_zeros_numerator >= n / 4 ) ;
412+ let denominators = & layer[ n / 2 ..] ;
413+ sum_quotients_2_by_2_num_and_den ( & layer[ ..n / 2 ] , |i| denominators[ i] , n_non_zeros_numerator)
414+ }
415+
416+ fn sum_quotients_2_by_2_num_and_den < F : PrimeCharacteristicRing + Sync + Send + Copy > (
417+ numerators : & [ F ] ,
418+ denominators : impl Fn ( usize ) -> F + Sync + Send ,
419+ n_non_zeros_numerator : Option < usize > ,
420+ ) -> Vec < F > {
421+ let n = numerators. len ( ) ;
422+ let mut res = unsafe { uninitialized_vec ( n) } ;
423+ let n_non_zeros_numerator = n_non_zeros_numerator. unwrap_or ( n) ;
424+ assert ! ( n_non_zeros_numerator >= n / 2 ) ;
424425 let n_over_2 = n / 2 ;
425- let n_over_4 = n_over_2 / 2 ;
426- let n_times_3_over_4 = n_over_2 + n_over_4;
427- ( 0 ..n_non_zeros_numerator - n / 4 )
428- . into_par_iter ( )
429- . map ( |i| layer[ i] * layer[ n_times_3_over_4 + i] + layer[ n_over_4 + i] * layer[ n_over_2 + i] )
430- . chain (
431- ( n_non_zeros_numerator - n / 4 ..n / 4 )
432- . into_par_iter ( )
433- . map ( |i| layer[ i] * layer[ n_times_3_over_4 + i] ) ,
434- )
435- . chain (
436- ( n / 4 ..n / 2 )
437- . into_par_iter ( )
438- . map ( |i| layer[ n_over_4 + i] * layer[ n_over_2 + i] ) ,
439- )
440- . collect ( )
426+
427+ let ( new_numerators, new_denominators) = res. split_at_mut ( n / 2 ) ;
428+ new_numerators[ ..n_non_zeros_numerator - n / 2 ]
429+ . par_iter_mut ( )
430+ . zip ( new_denominators[ ..n_non_zeros_numerator - n / 2 ] . par_iter_mut ( ) )
431+ . enumerate ( )
432+ . for_each ( |( i, ( num, den) ) | {
433+ let prev_num_1 = numerators[ i] ;
434+ let prev_num_2 = numerators[ n_over_2 + i] ;
435+ let prev_den_1 = denominators ( i) ;
436+ let prev_den_2 = denominators ( n_over_2 + i) ;
437+ * num = prev_num_1 * prev_den_2 + prev_num_2 * prev_den_1;
438+ * den = prev_den_1 * prev_den_2;
439+ } ) ;
440+ new_numerators[ n_non_zeros_numerator - n / 2 ..]
441+ . par_iter_mut ( )
442+ . zip ( new_denominators[ n_non_zeros_numerator - n / 2 ..] . par_iter_mut ( ) )
443+ . enumerate ( )
444+ . for_each ( |( i, ( num, den) ) | {
445+ let idx = i + n_non_zeros_numerator - n / 2 ;
446+ let prev_num_1 = numerators[ idx] ;
447+ let prev_den_1 = denominators ( idx) ;
448+ let prev_den_2 = denominators ( n_over_2 + idx) ;
449+ * num = prev_num_1 * prev_den_2;
450+ * den = prev_den_1 * prev_den_2;
451+ } ) ;
452+ res
441453}
442454
443455#[ cfg( test) ]
0 commit comments