|
1 | 1 | use std::borrow::Borrow; |
2 | 2 |
|
3 | | -use p3_field::{BasedVectorSpace, ExtensionField, Field, PackedValue, dot_product}; |
| 3 | +use p3_field::{ExtensionField, Field, dot_product}; |
4 | 4 | use rayon::prelude::*; |
5 | 5 | use tracing::instrument; |
6 | 6 | use whir_p3::poly::evals::EvaluationsList; |
7 | 7 |
|
8 | 8 | use crate::{EFPacking, PF}; |
9 | 9 |
|
10 | | -pub fn fold_multilinear_in_small_field<F: Field, EF: ExtensionField<F>, D>( |
11 | | - m: &[D], |
12 | | - scalars: &[F], |
13 | | -) -> Vec<EF> { |
14 | | - // TODO ... |
15 | | - assert!(scalars.len().is_power_of_two() && scalars.len() <= m.len()); |
16 | | - let new_size = m.len() / scalars.len(); |
17 | | - |
18 | | - let dim = <EF as BasedVectorSpace<F>>::DIMENSION; |
19 | | - |
20 | | - let m_transmuted: &[F] = |
21 | | - unsafe { std::slice::from_raw_parts(m.as_ptr().cast::<F>(), m.len() * dim) }; |
22 | | - let res_transmuted = { |
23 | | - let new_size = m.len() * dim / scalars.len(); |
24 | | - |
25 | | - if new_size < F::Packing::WIDTH { |
26 | | - (0..new_size) |
27 | | - .into_par_iter() |
28 | | - .map(|i| { |
29 | | - scalars |
30 | | - .iter() |
31 | | - .enumerate() |
32 | | - .map(|(j, s)| *s * m_transmuted[i + j * new_size]) |
33 | | - .sum() |
34 | | - }) |
35 | | - .collect() |
36 | | - } else { |
37 | | - let inners = (0..scalars.len()) |
38 | | - .map(|i| &m_transmuted[i * new_size..(i + 1) * new_size]) |
39 | | - .collect::<Vec<_>>(); |
40 | | - let inners_packed = inners |
41 | | - .iter() |
42 | | - .map(|&inner| F::Packing::pack_slice(inner)) |
43 | | - .collect::<Vec<_>>(); |
44 | | - |
45 | | - let packed_res = (0..new_size / F::Packing::WIDTH) |
46 | | - .into_par_iter() |
47 | | - .map(|i| { |
48 | | - scalars |
49 | | - .iter() |
50 | | - .enumerate() |
51 | | - .map(|(j, s)| inners_packed[j][i] * *s) |
52 | | - .sum::<F::Packing>() |
53 | | - }) |
54 | | - .collect::<Vec<_>>(); |
55 | | - |
56 | | - let mut unpacked: Vec<F> = unsafe { std::mem::transmute(packed_res) }; |
57 | | - unsafe { |
58 | | - unpacked.set_len(new_size); |
59 | | - } |
60 | | - |
61 | | - unpacked |
62 | | - } |
63 | | - }; |
64 | | - let res: Vec<EF> = unsafe { |
65 | | - let mut res: Vec<EF> = std::mem::transmute(res_transmuted); |
66 | | - res.set_len(new_size); |
67 | | - res |
68 | | - }; |
69 | | - |
70 | | - res |
71 | | -} |
72 | | - |
73 | 10 | pub fn fold_multilinear_in_large_field<F: Field, EF: ExtensionField<F>>( |
74 | 11 | m: &[F], |
75 | 12 | scalars: &[EF], |
@@ -145,40 +82,6 @@ pub fn batch_fold_multilinear_in_large_field_packed<EF: ExtensionField<PF<EF>>>( |
145 | 82 | .collect() |
146 | 83 | } |
147 | 84 |
|
148 | | -pub fn batch_fold_multilinear_in_small_field<F: Field, EF: ExtensionField<F>>( |
149 | | - polys: &[&[EF]], |
150 | | - scalars: &[F], |
151 | | -) -> Vec<Vec<EF>> { |
152 | | - polys |
153 | | - .par_iter() |
154 | | - .map(|poly| fold_multilinear_in_small_field(poly, scalars)) |
155 | | - .collect() |
156 | | -} |
157 | | - |
158 | | -pub fn batch_fold_multilinear_in_small_field_packed<EF: ExtensionField<PF<EF>>>( |
159 | | - polys: &[&[EFPacking<EF>]], |
160 | | - scalars: &[PF<EF>], |
161 | | -) -> Vec<Vec<EF>> { |
162 | | - polys |
163 | | - .par_iter() |
164 | | - .map(|poly| fold_multilinear_in_small_field(poly, scalars)) |
165 | | - .collect() |
166 | | -} |
167 | | - |
168 | | -// pub fn packed_multilinear<F: Field>(pols: &[Vec<F>]) -> Vec<F> { |
169 | | -// let n_vars = pols[0].num_variables(); |
170 | | -// assert!(pols.iter().all(|p| p.num_variables() == n_vars)); |
171 | | -// let packed_len = (pols.len() << n_vars).next_power_of_two(); |
172 | | -// let mut dst = F::zero_vec(packed_len); |
173 | | -// let mut offset = 0; |
174 | | -// // TODO parallelize |
175 | | -// for pol in pols { |
176 | | -// dst[offset..offset + pol.num_evals()].copy_from_slice(pol); |
177 | | -// offset += pol.num_evals(); |
178 | | -// } |
179 | | -// dst |
180 | | -// } |
181 | | - |
182 | 85 | #[instrument(name = "add_multilinears", skip_all)] |
183 | 86 | pub fn add_multilinears<F: Field>(pol1: &[F], pol2: &[F]) -> Vec<F> { |
184 | 87 | assert_eq!(pol1.len(), pol2.len()); |
|
0 commit comments