Skip to content

Commit 976429c

Browse files
committed
improve sum_quotients_2_by_2
1 parent d11ad4e commit 976429c

File tree

2 files changed

+46
-34
lines changed

2 files changed

+46
-34
lines changed

crates/lookup/src/logup_star.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,10 @@ mod tests {
271271
fn test_logup_star() {
272272
init_tracing();
273273

274-
let log_table_len = 19;
274+
let log_table_len = 21;
275275
let table_length = 1 << log_table_len;
276276

277-
let log_indexes_len = log_table_len + 2;
277+
let log_indexes_len = log_table_len + 1;
278278
let indexes_len = 1 << log_indexes_len;
279279

280280
let mut rng = StdRng::seed_from_u64(0);

crates/lookup/src/quotient_gkr.rs

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -239,24 +239,14 @@ where
239239
// TODO for the top layer, the denomiators have a structured form: constant - index.
240240
// We can skip one EE multilication in the sumcheck computation.
241241

242-
let sum_x_packed: EFPacking<EF> = (0..n_non_zeros_numerator - quarter_len_packed)
242+
let sum_x_packed: EFPacking<EF> = (0..quarter_len_packed)
243243
.into_par_iter()
244244
.map(|i| {
245245
let eq_eval = eq_poly_packed[i];
246246
let u2 = up_layer_packed[mid_len_packed + i];
247247
let u3 = up_layer_packed[three_quarter_len_packed + i];
248248
eq_eval * u2 * u3
249249
})
250-
.chain(
251-
(n_non_zeros_numerator - quarter_len_packed..quarter_len_packed)
252-
.into_par_iter()
253-
.map(|i| {
254-
let eq_eval = eq_poly_packed[i];
255-
let u2 = up_layer_packed[mid_len_packed + i];
256-
let u3 = up_layer_packed[three_quarter_len_packed + i];
257-
eq_eval * u3 * u2
258-
}),
259-
)
260250
.sum();
261251

262252
let sum_x = EFPacking::<EF>::to_ext_iter([sum_x_packed]).sum::<EF>();
@@ -414,30 +404,52 @@ where
414404
Ok(Evaluation::new(next_point, next_claim))
415405
}
416406

417-
fn sum_quotients_2_by_2<EF: PrimeCharacteristicRing + Sync + Send + Copy>(
418-
layer: &[EF],
407+
fn sum_quotients_2_by_2<F: PrimeCharacteristicRing + Sync + Send + Copy>(
408+
layer: &[F],
419409
n_non_zeros_numerator: Option<usize>,
420-
) -> Vec<EF> {
410+
) -> Vec<F> {
421411
let n = layer.len();
422-
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(n / 2);
423-
assert!(n_non_zeros_numerator >= n / 4);
412+
let denominators = &layer[n / 2..];
413+
sum_quotients_2_by_2_num_and_den(&layer[..n / 2], |i| denominators[i], n_non_zeros_numerator)
414+
}
415+
416+
fn sum_quotients_2_by_2_num_and_den<F: PrimeCharacteristicRing + Sync + Send + Copy>(
417+
numerators: &[F],
418+
denominators: impl Fn(usize) -> F + Sync + Send,
419+
n_non_zeros_numerator: Option<usize>,
420+
) -> Vec<F> {
421+
let n = numerators.len();
422+
let mut res = unsafe { uninitialized_vec(n) };
423+
let n_non_zeros_numerator = n_non_zeros_numerator.unwrap_or(n);
424+
assert!(n_non_zeros_numerator >= n / 2);
424425
let n_over_2 = n / 2;
425-
let n_over_4 = n_over_2 / 2;
426-
let n_times_3_over_4 = n_over_2 + n_over_4;
427-
(0..n_non_zeros_numerator - n / 4)
428-
.into_par_iter()
429-
.map(|i| layer[i] * layer[n_times_3_over_4 + i] + layer[n_over_4 + i] * layer[n_over_2 + i])
430-
.chain(
431-
(n_non_zeros_numerator - n / 4..n / 4)
432-
.into_par_iter()
433-
.map(|i| layer[i] * layer[n_times_3_over_4 + i]),
434-
)
435-
.chain(
436-
(n / 4..n / 2)
437-
.into_par_iter()
438-
.map(|i| layer[n_over_4 + i] * layer[n_over_2 + i]),
439-
)
440-
.collect()
426+
427+
let (new_numerators, new_denominators) = res.split_at_mut(n / 2);
428+
new_numerators[..n_non_zeros_numerator - n / 2]
429+
.par_iter_mut()
430+
.zip(new_denominators[..n_non_zeros_numerator - n / 2].par_iter_mut())
431+
.enumerate()
432+
.for_each(|(i, (num, den))| {
433+
let prev_num_1 = numerators[i];
434+
let prev_num_2 = numerators[n_over_2 + i];
435+
let prev_den_1 = denominators(i);
436+
let prev_den_2 = denominators(n_over_2 + i);
437+
*num = prev_num_1 * prev_den_2 + prev_num_2 * prev_den_1;
438+
*den = prev_den_1 * prev_den_2;
439+
});
440+
new_numerators[n_non_zeros_numerator - n / 2..]
441+
.par_iter_mut()
442+
.zip(new_denominators[n_non_zeros_numerator - n / 2..].par_iter_mut())
443+
.enumerate()
444+
.for_each(|(i, (num, den))| {
445+
let idx = i + n_non_zeros_numerator - n / 2;
446+
let prev_num_1 = numerators[idx];
447+
let prev_den_1 = denominators(idx);
448+
let prev_den_2 = denominators(n_over_2 + idx);
449+
*num = prev_num_1 * prev_den_2;
450+
*den = prev_den_1 * prev_den_2;
451+
});
452+
res
441453
}
442454

443455
#[cfg(test)]

0 commit comments

Comments
 (0)