11use std:: any:: TypeId ;
22
33use p3_air:: BaseAir ;
4- use p3_field:: { ExtensionField , Field , cyclic_subgroup_known_order, dot_product } ;
5- use p3_util:: log2_ceil_usize;
4+ use p3_field:: { ExtensionField , Field , cyclic_subgroup_known_order} ;
5+ use p3_util:: { log2_ceil_usize, log2_strict_usize } ;
66use sumcheck:: { MleGroup , MleGroupOwned , MleGroupRef , ProductComputation } ;
77use tracing:: { info_span, instrument} ;
88use utils:: PF ;
9- use utils:: { FSProver , add_multilinears, from_end , multilinears_linear_combination} ;
9+ use utils:: { FSProver , add_multilinears, multilinears_linear_combination} ;
1010use whir_p3:: fiat_shamir:: FSChallenger ;
1111use whir_p3:: poly:: evals:: { eval_eq, fold_multilinear, scale_poly} ;
1212use whir_p3:: poly:: multilinear:: Evaluation ;
1313use whir_p3:: poly:: { evals:: EvaluationsList , multilinear:: MultilinearPoint } ;
1414
15- use crate :: witness:: AirWitness ;
1615use crate :: { NormalAir , PackedAir } ;
1716use crate :: {
1817 uni_skip_utils:: { matrix_down_folded, matrix_up_folded} ,
@@ -38,43 +37,42 @@ fn prove_air<
3837 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
3938 univariate_skips : usize ,
4039 table : & AirTable < EF , A , AP > ,
41- witness : AirWitness < ' a , WF > ,
40+ witness : & [ & ' a [ WF ] ] ,
4241) -> Vec < Evaluation < EF > > {
42+ let n_rows = witness[ 0 ] . len ( ) ;
43+ assert ! ( witness. iter( ) . all( |col| col. len( ) == n_rows) ) ;
44+ let log_n_rows = log2_strict_usize ( n_rows) ;
4345 assert ! (
44- univariate_skips < witness . log_n_rows( ) ,
46+ univariate_skips < log_n_rows,
4547 "TODO handle the case UNIVARIATE_SKIPS >= log_length"
4648 ) ;
4749
4850 let structured_air = <A as BaseAir < PF < EF > > >:: structured ( & table. air ) ;
4951
50- let log_length = witness. log_n_rows ( ) ;
51-
5252 let constraints_batching_scalar = prover_state. sample ( ) ;
5353
5454 let constraints_batching_scalars =
5555 cyclic_subgroup_known_order ( constraints_batching_scalar, table. n_constraints )
5656 . collect :: < Vec < _ > > ( ) ;
5757
58- let n_sc_rounds = log_length + 1 - univariate_skips;
58+ let n_sc_rounds = log_n_rows + 1 - univariate_skips;
5959
6060 let zerocheck_challenges = prover_state. sample_vec ( n_sc_rounds) ;
6161
6262 let columns_for_zero_check: MleGroup < ' _ , EF > = if TypeId :: of :: < WF > ( ) == TypeId :: of :: < PF < EF > > ( ) {
63- let columns =
64- unsafe { std:: mem:: transmute :: < & Vec < & ' a [ WF ] > , & Vec < & ' a [ PF < EF > ] > > ( & witness. cols ) } ;
63+ let columns = unsafe { std:: mem:: transmute :: < & [ & [ WF ] ] , & [ & [ PF < EF > ] ] > ( witness) } ;
6564 if structured_air {
6665 MleGroupOwned :: Base ( columns_up_and_down ( columns) ) . into ( )
6766 } else {
68- MleGroupRef :: Base ( columns. clone ( ) ) . into ( )
67+ MleGroupRef :: Base ( columns. to_vec ( ) ) . into ( )
6968 }
7069 } else {
7170 assert ! ( TypeId :: of:: <WF >( ) == TypeId :: of:: <EF >( ) ) ;
72- let columns =
73- unsafe { std:: mem:: transmute :: < & Vec < & ' a [ WF ] > , & Vec < & ' a [ EF ] > > ( & witness. cols ) } ;
71+ let columns = unsafe { std:: mem:: transmute :: < & [ & ' a [ WF ] ] , & [ & ' a [ EF ] ] > ( witness) } ;
7472 if structured_air {
7573 MleGroupOwned :: Extension ( columns_up_and_down ( columns) ) . into ( )
7674 } else {
77- MleGroupRef :: Extension ( columns. clone ( ) ) . into ( )
75+ MleGroupRef :: Extension ( columns. to_vec ( ) ) . into ( )
7876 }
7977 } ;
8078
@@ -101,158 +99,126 @@ fn prove_air<
10199 open_structured_columns (
102100 prover_state,
103101 univariate_skips,
104- & witness,
102+ witness,
105103 & outer_sumcheck_challenge,
106104 )
107105 } else {
108106 open_unstructured_columns (
109107 prover_state,
110108 univariate_skips,
111- & witness,
109+ witness,
112110 & outer_sumcheck_challenge,
113111 )
114112 }
115113}
116114
117115impl < EF : ExtensionField < PF < EF > > , A : NormalAir < EF > , AP : PackedAir < EF > > AirTable < EF , A , AP > {
118116 #[ instrument( name = "air: prove in base" , skip_all) ]
119- pub fn prove_base < ' a > (
117+ pub fn prove_base (
120118 & self ,
121119 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
122120 univariate_skips : usize ,
123- witness : AirWitness < ' a , PF < EF > > ,
121+ witness : & [ & [ PF < EF > ] ] ,
124122 ) -> Vec < Evaluation < EF > > {
125123 prove_air :: < PF < EF > , EF , A , AP > ( prover_state, univariate_skips, self , witness)
126124 }
127125
128126 #[ instrument( name = "air: prove in extension" , skip_all) ]
129- pub fn prove_extension < ' a > (
127+ pub fn prove_extension (
130128 & self ,
131129 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
132130 univariate_skips : usize ,
133- witness : AirWitness < ' a , EF > ,
131+ witness : & [ & [ EF ] ] ,
134132 ) -> Vec < Evaluation < EF > > {
135133 prove_air :: < EF , EF , A , AP > ( prover_state, univariate_skips, self , witness)
136134 }
137135}
138136
139- fn eval_unstructured_column_groups < EF : ExtensionField < PF < EF > > + ExtensionField < IF > , IF : Field > (
140- prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
141- univariate_skips : usize ,
142- witnesses : & AirWitness < ' _ , IF > ,
143- outer_sumcheck_challenge : & [ EF ] ,
144- columns_batching_scalars : & [ EF ] ,
145- ) -> Vec < Vec < EF > > {
146- let mut all_sub_evals = vec ! [ ] ;
147- for group in & witnesses. column_groups {
148- let batched_column = multilinears_linear_combination (
149- & witnesses. cols [ group. clone ( ) ] ,
150- & eval_eq ( from_end (
151- columns_batching_scalars,
152- log2_ceil_usize ( group. len ( ) ) ,
153- ) ) [ ..group. len ( ) ] ,
154- ) ;
155-
156- // TODO opti
157- let sub_evals = fold_multilinear (
158- & batched_column,
159- & MultilinearPoint (
160- outer_sumcheck_challenge[ 1 ..witnesses. log_n_rows ( ) - univariate_skips + 1 ] . to_vec ( ) ,
161- ) ,
162- ) ;
163-
164- prover_state. add_extension_scalars ( & sub_evals) ;
165- all_sub_evals. push ( sub_evals) ;
166- }
167- all_sub_evals
168- }
169-
170137#[ instrument( skip_all) ]
171138fn open_unstructured_columns <
172- ' a ,
173139 WF : ExtensionField < PF < EF > > ,
174140 EF : ExtensionField < PF < EF > > + ExtensionField < WF > ,
175141> (
176142 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
177143 univariate_skips : usize ,
178- witness : & AirWitness < ' a , WF > ,
144+ witness : & [ & [ WF ] ] ,
179145 outer_sumcheck_challenge : & [ EF ] ,
180146) -> Vec < Evaluation < EF > > {
181- let columns_batching_scalars =
182- prover_state. sample_vec ( log2_ceil_usize ( witness. max_columns_per_group ( ) ) ) ;
147+ let log_n_rows = log2_strict_usize ( witness[ 0 ] . len ( ) ) ;
183148
184- let sub_evals = eval_unstructured_column_groups (
185- prover_state ,
186- univariate_skips ,
149+ let columns_batching_scalars = prover_state . sample_vec ( log2_ceil_usize ( witness . len ( ) ) ) ;
150+
151+ let batched_column = multilinears_linear_combination (
187152 witness,
188- outer_sumcheck_challenge,
189- & columns_batching_scalars,
153+ & eval_eq ( & columns_batching_scalars) [ ..witness. len ( ) ] ,
154+ ) ;
155+
156+ // TODO opti
157+ let sub_evals = fold_multilinear (
158+ & batched_column,
159+ & MultilinearPoint ( outer_sumcheck_challenge[ 1 ..log_n_rows - univariate_skips + 1 ] . to_vec ( ) ) ,
190160 ) ;
191161
162+ prover_state. add_extension_scalars ( & sub_evals) ;
163+
192164 let epsilons = MultilinearPoint ( prover_state. sample_vec ( univariate_skips) ) ;
165+ let common_point = MultilinearPoint (
166+ [
167+ epsilons. 0 . clone ( ) ,
168+ outer_sumcheck_challenge[ 1 ..log_n_rows - univariate_skips + 1 ] . to_vec ( ) ,
169+ ]
170+ . concat ( ) ,
171+ ) ;
193172
194173 let mut evaluations_remaining_to_prove = vec ! [ ] ;
195- for ( group, sub_evals) in witness. column_groups . iter ( ) . zip ( sub_evals) {
196- assert_eq ! ( sub_evals. len( ) , 1 << epsilons. len( ) ) ;
197-
198- evaluations_remaining_to_prove. push ( Evaluation :: new (
199- [
200- from_end ( & columns_batching_scalars, log2_ceil_usize ( group. len ( ) ) ) . to_vec ( ) ,
201- epsilons. 0 . clone ( ) ,
202- outer_sumcheck_challenge[ 1 ..witness. log_n_rows ( ) - univariate_skips + 1 ] . to_vec ( ) ,
203- ]
204- . concat ( ) ,
205- sub_evals. evaluate ( & epsilons) ,
206- ) ) ;
174+ assert_eq ! ( sub_evals. len( ) , 1 << epsilons. len( ) ) ;
175+
176+ for col in witness {
177+ // TODO compute oe time eq(.) then inner product with everything
178+ let value = col. evaluate ( & common_point) ;
179+ prover_state. add_extension_scalars ( & [ value] ) ;
180+ evaluations_remaining_to_prove. push ( Evaluation {
181+ point : common_point. clone ( ) ,
182+ value,
183+ } ) ;
207184 }
185+
208186 evaluations_remaining_to_prove
209187}
210188
211189#[ instrument( skip_all) ]
212- fn open_structured_columns < ' a , EF : ExtensionField < PF < EF > > + ExtensionField < IF > , IF : Field > (
190+ fn open_structured_columns < EF : ExtensionField < PF < EF > > + ExtensionField < IF > , IF : Field > (
213191 prover_state : & mut FSProver < EF , impl FSChallenger < EF > > ,
214192 univariate_skips : usize ,
215- witness : & AirWitness < ' a , IF > ,
193+ witness : & [ & [ IF ] ] ,
216194 outer_sumcheck_challenge : & [ EF ] ,
217195) -> Vec < Evaluation < EF > > {
218- let log_n_groups = log2_ceil_usize ( witness. column_groups . len ( ) ) ;
219- let batching_scalars =
220- prover_state. sample_vec ( log_n_groups + witness. log_max_columns_per_group ( ) ) ;
196+ let n_columns = witness. len ( ) ;
197+ let n_rows = witness[ 0 ] . len ( ) ;
198+ let log_n_rows = log2_strict_usize ( n_rows) ;
199+ let batching_scalars = prover_state. sample_vec ( log2_ceil_usize ( n_columns) ) ;
221200 let alpha = prover_state. sample ( ) ;
222201
223202 let poly_eq_batching_scalars = eval_eq ( & batching_scalars) ;
224- let mut column_scalars = vec ! [ ] ;
225- let mut index = 0 ;
226- for group in & witness. column_groups {
227- column_scalars. extend (
228- poly_eq_batching_scalars
229- . iter ( )
230- . skip ( index)
231- . take ( group. len ( ) )
232- . copied ( ) ,
233- ) ;
234- index += witness. max_columns_per_group ( ) . next_power_of_two ( ) ;
235- }
236203
237- let batched_column = multilinears_linear_combination ( & witness. cols , & column_scalars) ;
204+ let batched_column =
205+ multilinears_linear_combination ( witness, & poly_eq_batching_scalars[ ..n_columns] ) ;
238206 let batched_column_mixed = add_multilinears (
239207 & column_up ( & batched_column) ,
240208 & scale_poly ( & column_down ( & batched_column) , alpha) ,
241209 ) ;
242210 // TODO do not recompute this (we can deduce it from already computed values)
243211 let sub_evals = fold_multilinear (
244212 & batched_column_mixed,
245- & MultilinearPoint (
246- outer_sumcheck_challenge[ 1 ..witness. log_n_rows ( ) - univariate_skips + 1 ] . to_vec ( ) ,
247- ) ,
213+ & MultilinearPoint ( outer_sumcheck_challenge[ 1 ..log_n_rows - univariate_skips + 1 ] . to_vec ( ) ) ,
248214 ) ;
249215 prover_state. add_extension_scalars ( & sub_evals) ;
250216
251217 let epsilons = prover_state. sample_vec ( univariate_skips) ;
252218
253219 let point = [
254220 epsilons,
255- outer_sumcheck_challenge[ 1 ..witness . log_n_rows ( ) - univariate_skips + 1 ] . to_vec ( ) ,
221+ outer_sumcheck_challenge[ 1 ..log_n_rows - univariate_skips + 1 ] . to_vec ( ) ,
256222 ]
257223 . concat ( ) ;
258224
@@ -267,8 +233,7 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
267233 batched_column,
268234 ] ) ;
269235
270- let n_groups = witness. column_groups . len ( ) ;
271- let ( inner_challenges, inner_evals, _) = sumcheck:: prove :: < EF , _ , _ , _ > (
236+ let ( inner_challenges, _, _) = sumcheck:: prove :: < EF , _ , _ , _ > (
272237 1 ,
273238 inner_mle,
274239 & ProductComputation ,
@@ -284,43 +249,14 @@ fn open_structured_columns<'a, EF: ExtensionField<PF<EF>> + ExtensionField<IF>,
284249 // TODO using inner_evals[1], we can avoid 1 of the evaluations below (the last one)
285250
286251 let mut evaluations_remaining_to_prove = vec ! [ ] ;
287- for i in 0 ..n_groups {
288- let group = & witness. column_groups [ i] ;
289- let point = MultilinearPoint (
290- [
291- from_end (
292- & batching_scalars[ log_n_groups..] ,
293- log2_ceil_usize ( group. len ( ) ) ,
294- )
295- . to_vec ( ) ,
296- inner_challenges. 0 . clone ( ) ,
297- ]
298- . concat ( ) ,
299- ) ;
300- let value = {
301- let mut padded_group = IF :: zero_vec ( group. len ( ) . next_power_of_two ( ) * witness. n_rows ( ) ) ;
302- for ( i, col) in witness. cols [ group. clone ( ) ] . iter ( ) . enumerate ( ) {
303- padded_group[ i * witness. n_rows ( ) ..( i + 1 ) * witness. n_rows ( ) ] . copy_from_slice ( col) ;
304- }
305- padded_group. evaluate ( & point)
306- } ;
307- prover_state. add_extension_scalars ( & [ value] ) ;
308- evaluations_remaining_to_prove. push ( Evaluation { point, value } ) ;
252+ for col in witness {
253+ let value = col. evaluate ( & inner_challenges) ;
254+ prover_state. add_extension_scalar ( value) ;
255+ evaluations_remaining_to_prove. push ( Evaluation {
256+ point : inner_challenges. clone ( ) ,
257+ value,
258+ } ) ;
309259 }
310260
311- assert_eq ! (
312- inner_evals[ 1 ] ,
313- dot_product(
314- eval_eq( & batching_scalars[ ..log_n_groups] ) . into_iter( ) ,
315- ( 0 ..n_groups) . map( |i| evaluations_remaining_to_prove[ i] . value
316- * batching_scalars[ log_n_groups
317- ..log_n_groups + witness. log_max_columns_per_group( )
318- - log2_ceil_usize( witness. column_groups[ i] . len( ) ) ]
319- . iter( )
320- . map( |& x| EF :: ONE - x)
321- . product:: <EF >( ) )
322- )
323- ) ;
324-
325261 evaluations_remaining_to_prove
326262}
0 commit comments