@@ -3,7 +3,13 @@ use std::{
33 fmt:: { Display , Formatter } ,
44} ;
55
6+ use p3_field:: { BasedVectorSpace , PrimeCharacteristicRing , dot_product} ;
7+ use p3_symmetric:: Permutation ;
8+ use utils:: { Poseidon16 , Poseidon24 , ToUsize } ;
9+ use whir_p3:: poly:: { evals:: EvaluationsList , multilinear:: MultilinearPoint } ;
10+
611use super :: { MemOrConstant , MemOrFp , MemOrFpOrConstant , Operation } ;
12+ use crate :: { DIMENSION , EF , F , Memory , RunnerError } ;
713
814#[ derive( Debug , Clone , PartialEq , Eq , PartialOrd , Ord , Hash ) ]
915pub enum Instruction {
@@ -49,6 +55,181 @@ pub enum Instruction {
4955 } ,
5056}
5157
58+ impl Instruction {
59+ pub fn execute (
60+ & self ,
61+ memory : & mut Memory ,
62+ fp : & mut usize ,
63+ pc : & mut usize ,
64+ p16 : & Poseidon16 ,
65+ p24 : & Poseidon24 ,
66+ poseidon16_calls : & mut usize ,
67+ poseidon24_calls : & mut usize ,
68+ dot_ext_ext_calls : & mut usize ,
69+ dot_base_ext_calls : & mut usize ,
70+ ) -> Result < ( ) , RunnerError > {
71+ match self {
72+ Self :: Computation {
73+ operation,
74+ arg_a,
75+ arg_c,
76+ res,
77+ } => {
78+ if res. is_value_unknown ( memory, * fp) {
79+ let addr = res. memory_address ( * fp) ?;
80+ let a = arg_a. read_value ( memory, * fp) ?;
81+ let b = arg_c. read_value ( memory, * fp) ?;
82+ memory. set ( addr, operation. compute ( a, b) ) ?;
83+ } else if arg_a. is_value_unknown ( memory, * fp) {
84+ let addr = arg_a. memory_address ( * fp) ?;
85+ let r = res. read_value ( memory, * fp) ?;
86+ let b = arg_c. read_value ( memory, * fp) ?;
87+ let a = operation
88+ . inverse_compute ( r, b)
89+ . ok_or ( RunnerError :: DivByZero ) ?;
90+ memory. set ( addr, a) ?;
91+ } else if arg_c. is_value_unknown ( memory, * fp) {
92+ let addr = arg_c. memory_address ( * fp) ?;
93+ let r = res. read_value ( memory, * fp) ?;
94+ let a = arg_a. read_value ( memory, * fp) ?;
95+ let b = operation
96+ . inverse_compute ( r, a)
97+ . ok_or ( RunnerError :: DivByZero ) ?;
98+ memory. set ( addr, b) ?;
99+ } else {
100+ let a = arg_a. read_value ( memory, * fp) ?;
101+ let b = arg_c. read_value ( memory, * fp) ?;
102+ let r = res. read_value ( memory, * fp) ?;
103+ let c = operation. compute ( a, b) ;
104+ if r != c {
105+ return Err ( RunnerError :: NotEqual ( c, r) ) ;
106+ }
107+ }
108+ * pc += 1 ;
109+ }
110+
111+ Self :: Deref {
112+ shift_0,
113+ shift_1,
114+ res,
115+ } => {
116+ let ptr = memory. get ( * fp + * shift_0) ?. to_usize ( ) ;
117+ if res. is_value_unknown ( memory, * fp) {
118+ let addr_res = res. memory_address ( * fp) ?;
119+ let v = memory. get ( ptr + * shift_1) ?;
120+ memory. set ( addr_res, v) ?;
121+ } else {
122+ let v = res. read_value ( memory, * fp) ?;
123+ memory. set ( ptr + * shift_1, v) ?;
124+ }
125+ * pc += 1 ;
126+ }
127+
128+ Self :: JumpIfNotZero {
129+ condition,
130+ dest,
131+ updated_fp,
132+ } => {
133+ let c = condition. read_value ( memory, * fp) ?;
134+ assert ! ( [ F :: ZERO , F :: ONE ] . contains( & c) ) ;
135+ if c == F :: ZERO {
136+ * pc += 1 ;
137+ } else {
138+ * pc = dest. read_value ( memory, * fp) ?. to_usize ( ) ;
139+ * fp = updated_fp. read_value ( memory, * fp) ?. to_usize ( ) ;
140+ }
141+ }
142+
143+ Self :: Poseidon2_16 { arg_a, arg_b, res } => {
144+ * poseidon16_calls += 1 ;
145+
146+ let a_ptr = arg_a. read_value ( memory, * fp) ?. to_usize ( ) ;
147+ let b_ptr = arg_b. read_value ( memory, * fp) ?. to_usize ( ) ;
148+ let r_ptr = res. read_value ( memory, * fp) ?. to_usize ( ) ;
149+
150+ let a = memory. get_vector ( a_ptr) ?;
151+ let b = memory. get_vector ( b_ptr) ?;
152+
153+ let mut state = [ F :: ZERO ; DIMENSION * 2 ] ;
154+ state[ ..DIMENSION ] . copy_from_slice ( & a) ;
155+ state[ DIMENSION ..] . copy_from_slice ( & b) ;
156+ p16. permute_mut ( & mut state) ;
157+
158+ memory. set_vectorized_slice ( r_ptr, & state) ?;
159+ * pc += 1 ;
160+ }
161+
162+ Self :: Poseidon2_24 { arg_a, arg_b, res } => {
163+ * poseidon24_calls += 1 ;
164+
165+ let a_ptr = arg_a. read_value ( memory, * fp) ?. to_usize ( ) ;
166+ let b_ptr = arg_b. read_value ( memory, * fp) ?. to_usize ( ) ;
167+ let r_ptr = res. read_value ( memory, * fp) ?. to_usize ( ) ;
168+
169+ let a0 = memory. get_vector ( a_ptr) ?;
170+ let a1 = memory. get_vector ( a_ptr + 1 ) ?;
171+ let b = memory. get_vector ( b_ptr) ?;
172+
173+ let mut state = [ F :: ZERO ; DIMENSION * 3 ] ;
174+ state[ ..DIMENSION ] . copy_from_slice ( & a0) ;
175+ state[ DIMENSION ..2 * DIMENSION ] . copy_from_slice ( & a1) ;
176+ state[ 2 * DIMENSION ..] . copy_from_slice ( & b) ;
177+ p24. permute_mut ( & mut state) ;
178+
179+ memory. set_vectorized_slice ( r_ptr, & state[ 2 * DIMENSION ..] ) ?;
180+ * pc += 1 ;
181+ }
182+
183+ Self :: DotProductExtensionExtension {
184+ arg0,
185+ arg1,
186+ res,
187+ size,
188+ } => {
189+ * dot_ext_ext_calls += 1 ;
190+
191+ let p0 = arg0. read_value ( memory, * fp) ?. to_usize ( ) ;
192+ let p1 = arg1. read_value ( memory, * fp) ?. to_usize ( ) ;
193+ let pr = res. read_value ( memory, * fp) ?. to_usize ( ) ;
194+
195+ let s0 = memory. get_vectorized_slice_extension :: < EF > ( p0, * size) ?;
196+ let s1 = memory. get_vectorized_slice_extension :: < EF > ( p1, * size) ?;
197+
198+ let dp: [ F ; DIMENSION ] = dot_product :: < EF , _ , _ > ( s0. into_iter ( ) , s1. into_iter ( ) )
199+ . as_basis_coefficients_slice ( )
200+ . try_into ( )
201+ . unwrap ( ) ;
202+ memory. set_vector ( pr, dp) ?;
203+ * pc += 1 ;
204+ }
205+
206+ Self :: MultilinearEval {
207+ coeffs,
208+ point,
209+ res,
210+ n_vars,
211+ } => {
212+ * dot_base_ext_calls += 1 ;
213+
214+ let pcf = coeffs. read_value ( memory, * fp) ?. to_usize ( ) ;
215+ let ppt = point. read_value ( memory, * fp) ?. to_usize ( ) ;
216+ let pr = res. read_value ( memory, * fp) ?. to_usize ( ) ;
217+
218+ let start = pcf << * n_vars;
219+ let len = 1usize << * n_vars;
220+ let coeffs = memory. slice ( start, len) ?;
221+ let point = memory. get_vectorized_slice_extension :: < EF > ( ppt, * n_vars) ?;
222+
223+ let eval = coeffs. evaluate ( & MultilinearPoint ( point) ) ;
224+ let out: [ F ; DIMENSION ] = eval. as_basis_coefficients_slice ( ) . try_into ( ) . unwrap ( ) ;
225+ memory. set_vector ( pr, out) ?;
226+ * pc += 1 ;
227+ }
228+ }
229+ Ok ( ( ) )
230+ }
231+ }
232+
52233impl Display for Instruction {
53234 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> fmt:: Result {
54235 match self {
0 commit comments