diff --git a/crates/vm/src/bytecode/instruction.rs b/crates/vm/src/bytecode/instruction.rs index 5af39915..42fd7057 100644 --- a/crates/vm/src/bytecode/instruction.rs +++ b/crates/vm/src/bytecode/instruction.rs @@ -3,7 +3,13 @@ use std::{ fmt::{Display, Formatter}, }; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing, dot_product}; +use p3_symmetric::Permutation; +use utils::{Poseidon16, Poseidon24, ToUsize}; +use whir_p3::poly::{evals::EvaluationsList, multilinear::MultilinearPoint}; + use super::{MemOrConstant, MemOrFp, MemOrFpOrConstant, Operation}; +use crate::{DIMENSION, EF, F, Memory, RunnerError}; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Instruction { @@ -49,6 +55,181 @@ pub enum Instruction { }, } +impl Instruction { + pub fn execute( + &self, + memory: &mut Memory, + fp: &mut usize, + pc: &mut usize, + p16: &Poseidon16, + p24: &Poseidon24, + poseidon16_calls: &mut usize, + poseidon24_calls: &mut usize, + dot_ext_ext_calls: &mut usize, + dot_base_ext_calls: &mut usize, + ) -> Result<(), RunnerError> { + match self { + Self::Computation { + operation, + arg_a, + arg_c, + res, + } => { + if res.is_value_unknown(memory, *fp) { + let addr = res.memory_address(*fp)?; + let a = arg_a.read_value(memory, *fp)?; + let b = arg_c.read_value(memory, *fp)?; + memory.set(addr, operation.compute(a, b))?; + } else if arg_a.is_value_unknown(memory, *fp) { + let addr = arg_a.memory_address(*fp)?; + let r = res.read_value(memory, *fp)?; + let b = arg_c.read_value(memory, *fp)?; + let a = operation + .inverse_compute(r, b) + .ok_or(RunnerError::DivByZero)?; + memory.set(addr, a)?; + } else if arg_c.is_value_unknown(memory, *fp) { + let addr = arg_c.memory_address(*fp)?; + let r = res.read_value(memory, *fp)?; + let a = arg_a.read_value(memory, *fp)?; + let b = operation + .inverse_compute(r, a) + .ok_or(RunnerError::DivByZero)?; + memory.set(addr, b)?; + } else { + let a = arg_a.read_value(memory, *fp)?; + let b = arg_c.read_value(memory, *fp)?; + let r = res.read_value(memory, *fp)?; + let c = operation.compute(a, b); + if r != c { + return Err(RunnerError::NotEqual(c, r)); + } + } + *pc += 1; + } + + Self::Deref { + shift_0, + shift_1, + res, + } => { + let ptr = memory.get(*fp + *shift_0)?.to_usize(); + if res.is_value_unknown(memory, *fp) { + let addr_res = res.memory_address(*fp)?; + let v = memory.get(ptr + *shift_1)?; + memory.set(addr_res, v)?; + } else { + let v = res.read_value(memory, *fp)?; + memory.set(ptr + *shift_1, v)?; + } + *pc += 1; + } + + Self::JumpIfNotZero { + condition, + dest, + updated_fp, + } => { + let c = condition.read_value(memory, *fp)?; + assert!([F::ZERO, F::ONE].contains(&c)); + if c == F::ZERO { + *pc += 1; + } else { + *pc = dest.read_value(memory, *fp)?.to_usize(); + *fp = updated_fp.read_value(memory, *fp)?.to_usize(); + } + } + + Self::Poseidon2_16 { arg_a, arg_b, res } => { + *poseidon16_calls += 1; + + let a_ptr = arg_a.read_value(memory, *fp)?.to_usize(); + let b_ptr = arg_b.read_value(memory, *fp)?.to_usize(); + let r_ptr = res.read_value(memory, *fp)?.to_usize(); + + let a = memory.get_vector(a_ptr)?; + let b = memory.get_vector(b_ptr)?; + + let mut state = [F::ZERO; DIMENSION * 2]; + state[..DIMENSION].copy_from_slice(&a); + state[DIMENSION..].copy_from_slice(&b); + p16.permute_mut(&mut state); + + memory.set_vectorized_slice(r_ptr, &state)?; + *pc += 1; + } + + Self::Poseidon2_24 { arg_a, arg_b, res } => { + *poseidon24_calls += 1; + + let a_ptr = arg_a.read_value(memory, *fp)?.to_usize(); + let b_ptr = arg_b.read_value(memory, *fp)?.to_usize(); + let r_ptr = res.read_value(memory, *fp)?.to_usize(); + + let a0 = memory.get_vector(a_ptr)?; + let a1 = memory.get_vector(a_ptr + 1)?; + let b = memory.get_vector(b_ptr)?; + + let mut state = [F::ZERO; DIMENSION * 3]; + state[..DIMENSION].copy_from_slice(&a0); + state[DIMENSION..2 * DIMENSION].copy_from_slice(&a1); + state[2 * DIMENSION..].copy_from_slice(&b); + p24.permute_mut(&mut state); + + memory.set_vectorized_slice(r_ptr, &state[2 * DIMENSION..])?; + *pc += 1; + } + + Self::DotProductExtensionExtension { + arg0, + arg1, + res, + size, + } => { + *dot_ext_ext_calls += 1; + + let p0 = arg0.read_value(memory, *fp)?.to_usize(); + let p1 = arg1.read_value(memory, *fp)?.to_usize(); + let pr = res.read_value(memory, *fp)?.to_usize(); + + let s0 = memory.get_vectorized_slice_extension::(p0, *size)?; + let s1 = memory.get_vectorized_slice_extension::(p1, *size)?; + + let dp: [F; DIMENSION] = dot_product::(s0.into_iter(), s1.into_iter()) + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + memory.set_vector(pr, dp)?; + *pc += 1; + } + + Self::MultilinearEval { + coeffs, + point, + res, + n_vars, + } => { + *dot_base_ext_calls += 1; + + let pcf = coeffs.read_value(memory, *fp)?.to_usize(); + let ppt = point.read_value(memory, *fp)?.to_usize(); + let pr = res.read_value(memory, *fp)?.to_usize(); + + let start = pcf << *n_vars; + let len = 1usize << *n_vars; + let coeffs = memory.slice(start, len)?; + let point = memory.get_vectorized_slice_extension::(ppt, *n_vars)?; + + let eval = coeffs.evaluate(&MultilinearPoint(point)); + let out: [F; DIMENSION] = eval.as_basis_coefficients_slice().try_into().unwrap(); + memory.set_vector(pr, out)?; + *pc += 1; + } + } + Ok(()) + } +} + impl Display for Instruction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { diff --git a/crates/vm/src/runner.rs b/crates/vm/src/runner.rs index d78f9160..b5019f9a 100644 --- a/crates/vm/src/runner.rs +++ b/crates/vm/src/runner.rs @@ -1,12 +1,10 @@ -use p3_field::{BasedVectorSpace, PrimeCharacteristicRing, dot_product}; +use p3_field::PrimeCharacteristicRing; use p3_symmetric::Permutation; -use utils::{ToUsize, build_poseidon16, build_poseidon24, pretty_integer}; -use whir_p3::poly::{evals::EvaluationsList, multilinear::MultilinearPoint}; +use utils::{build_poseidon16, build_poseidon24, pretty_integer}; use crate::{ - DIMENSION, EF, F, MAX_MEMORY_SIZE, Memory, POSEIDON_16_NULL_HASH_PTR, - POSEIDON_24_NULL_HASH_PTR, PUBLIC_INPUT_START, RunnerError, ZERO_VEC_PTR, - bytecode::{Bytecode, Instruction}, + DIMENSION, F, MAX_MEMORY_SIZE, Memory, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, + PUBLIC_INPUT_START, RunnerError, ZERO_VEC_PTR, bytecode::Bytecode, }; #[must_use] @@ -136,180 +134,17 @@ fn execute_bytecode_helper( )?; } - let instruction = &bytecode.instructions[pc]; - match instruction { - Instruction::Computation { - operation, - arg_a, - arg_c, - res, - } => { - if res.is_value_unknown(&memory, fp) { - let memory_address_res = res.memory_address(fp)?; - let a_value = arg_a.read_value(&memory, fp)?; - let b_value = arg_c.read_value(&memory, fp)?; - let res_value = operation.compute(a_value, b_value); - memory.set(memory_address_res, res_value)?; - } else if arg_a.is_value_unknown(&memory, fp) { - let memory_address_a = arg_a.memory_address(fp)?; - let res_value = res.read_value(&memory, fp)?; - let b_value = arg_c.read_value(&memory, fp)?; - let a_value = operation - .inverse_compute(res_value, b_value) - .ok_or(RunnerError::DivByZero)?; - memory.set(memory_address_a, a_value)?; - } else if arg_c.is_value_unknown(&memory, fp) { - let memory_address_b = arg_c.memory_address(fp)?; - let res_value = res.read_value(&memory, fp)?; - let a_value = arg_a.read_value(&memory, fp)?; - let b_value = operation - .inverse_compute(res_value, a_value) - .ok_or(RunnerError::DivByZero)?; - memory.set(memory_address_b, b_value)?; - } else { - let a_value = arg_a.read_value(&memory, fp)?; - let b_value = arg_c.read_value(&memory, fp)?; - let res_value = res.read_value(&memory, fp)?; - let computed_value = operation.compute(a_value, b_value); - if res_value != computed_value { - return Err(RunnerError::NotEqual(computed_value, res_value)); - } - } - - pc += 1; - } - Instruction::Deref { - shift_0, - shift_1, - res, - } => { - if res.is_value_unknown(&memory, fp) { - let memory_address_res = res.memory_address(fp)?; - let ptr = memory.get(fp + shift_0)?; - let value = memory.get(ptr.to_usize() + shift_1)?; - memory.set(memory_address_res, value)?; - } else { - let value = res.read_value(&memory, fp)?; - let ptr = memory.get(fp + shift_0)?; - memory.set(ptr.to_usize() + shift_1, value)?; - } - pc += 1; - } - Instruction::JumpIfNotZero { - condition, - dest, - updated_fp, - } => { - let condition_value = condition.read_value(&memory, fp)?; - assert!([F::ZERO, F::ONE].contains(&condition_value),); - if condition_value == F::ZERO { - pc += 1; - } else { - pc = dest.read_value(&memory, fp)?.to_usize(); - fp = updated_fp.read_value(&memory, fp)?.to_usize(); - } - } - Instruction::Poseidon2_16 { arg_a, arg_b, res } => { - poseidon16_calls += 1; - - let a_value = arg_a.read_value(&memory, fp)?; - let b_value = arg_b.read_value(&memory, fp)?; - let res_value = res.read_value(&memory, fp)?; - - let arg0 = memory.get_vector(a_value.to_usize())?; - let arg1 = memory.get_vector(b_value.to_usize())?; - - let mut input = [F::ZERO; DIMENSION * 2]; - input[..DIMENSION].copy_from_slice(&arg0); - input[DIMENSION..].copy_from_slice(&arg1); - - poseidon_16.permute_mut(&mut input); - - let res0: [F; DIMENSION] = input[..DIMENSION].try_into().unwrap(); - let res1: [F; DIMENSION] = input[DIMENSION..].try_into().unwrap(); - - memory.set_vector(res_value.to_usize(), res0)?; - memory.set_vector(1 + res_value.to_usize(), res1)?; - - pc += 1; - } - Instruction::Poseidon2_24 { arg_a, arg_b, res } => { - poseidon24_calls += 1; - - let a_value = arg_a.read_value(&memory, fp)?; - let b_value = arg_b.read_value(&memory, fp)?; - let res_value = res.read_value(&memory, fp)?; - - let arg0 = memory.get_vector(a_value.to_usize())?; - let arg1 = memory.get_vector(1 + a_value.to_usize())?; - let arg2 = memory.get_vector(b_value.to_usize())?; - - let mut input = [F::ZERO; DIMENSION * 3]; - input[..DIMENSION].copy_from_slice(&arg0); - input[DIMENSION..2 * DIMENSION].copy_from_slice(&arg1); - input[2 * DIMENSION..].copy_from_slice(&arg2); - - poseidon_24.permute_mut(&mut input); - - let res: [F; DIMENSION] = input[2 * DIMENSION..].try_into().unwrap(); - - memory.set_vector(res_value.to_usize(), res)?; - - pc += 1; - } - Instruction::DotProductExtensionExtension { - arg0, - arg1, - res, - size, - } => { - dot_product_ext_ext_calls += 1; - - let ptr_arg_0 = arg0.read_value(&memory, fp)?.to_usize(); - let ptr_arg_1 = arg1.read_value(&memory, fp)?.to_usize(); - let ptr_res = res.read_value(&memory, fp)?.to_usize(); - - let slice_0 = (ptr_arg_0..ptr_arg_0 + *size) - .map(|i| Ok(EF::from_basis_coefficients_slice(&memory.get_vector(i)?).unwrap())) - .collect::, _>>()?; - - let slice_1 = (ptr_arg_1..ptr_arg_1 + *size) - .map(|i| Ok(EF::from_basis_coefficients_slice(&memory.get_vector(i)?).unwrap())) - .collect::, _>>()?; - - let dot_product = dot_product::(slice_0.into_iter(), slice_1.into_iter()) - .as_basis_coefficients_slice() - .try_into() - .unwrap(); - memory.set_vector(ptr_res, dot_product)?; - - pc += 1; - } - Instruction::MultilinearEval { - coeffs, - point, - res, - n_vars, - } => { - dot_product_base_ext_calls += 1; - - let ptr_coeffs = coeffs.read_value(&memory, fp)?.to_usize(); - let ptr_point = point.read_value(&memory, fp)?.to_usize(); - let ptr_res = res.read_value(&memory, fp)?.to_usize(); - let slice_coeffs = (ptr_coeffs << *n_vars..(1 + ptr_coeffs) << *n_vars) - .map(|i| memory.get(i)) - .collect::, _>>()?; - let point = (ptr_point..ptr_point + *n_vars) - .map(|i| Ok(EF::from_basis_coefficients_slice(&memory.get_vector(i)?).unwrap())) - .collect::, _>>()?; - - let eval = slice_coeffs.evaluate(&MultilinearPoint(point.clone())); - let eval_base: [F; 8] = eval.as_basis_coefficients_slice().try_into().unwrap(); - memory.set_vector(ptr_res, eval_base)?; - - pc += 1; - } - } + bytecode.instructions[pc].execute( + &mut memory, + &mut fp, + &mut pc, + &poseidon_16, + &poseidon_24, + &mut poseidon16_calls, + &mut poseidon24_calls, + &mut dot_product_ext_ext_calls, + &mut dot_product_base_ext_calls, + )?; } debug_assert_eq!(pc, bytecode.ending_pc);