diff --git a/crates/vm/src/bytecode.rs b/crates/vm/src/bytecode.rs deleted file mode 100644 index 6de5ab02..00000000 --- a/crates/vm/src/bytecode.rs +++ /dev/null @@ -1,286 +0,0 @@ -use std::{ - collections::BTreeMap, - fmt, - fmt::{Display, Formatter}, -}; - -use p3_field::PrimeCharacteristicRing; - -use crate::F; - -pub type Label = String; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Bytecode { - pub instructions: Vec, - pub hints: BTreeMap>, // pc -> hints - pub starting_frame_memory: usize, - pub ending_pc: usize, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum MemOrConstant { - Constant(F), - MemoryAfterFp { offset: usize }, // m[fp + offset] -} -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum MemOrFpOrConstant { - MemoryAfterFp { offset: usize }, // m[fp + offset] - Fp, - Constant(F), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum MemOrFp { - MemoryAfterFp { offset: usize }, // m[fp + offset] - Fp, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Operation { - Add, - Mul, -} - -impl fmt::Display for Operation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Add => write!(f, "+"), - Self::Mul => write!(f, "x"), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Instruction { - // 3 basic instructions - Computation { - operation: Operation, - arg_a: MemOrConstant, - arg_c: MemOrFp, - res: MemOrConstant, - }, - Deref { - shift_0: usize, - shift_1: usize, - res: MemOrFpOrConstant, - }, // res = m[m[fp + shift_0] + shift_1] - JumpIfNotZero { - condition: MemOrConstant, - dest: MemOrConstant, - updated_fp: MemOrFp, - }, - // 4 precompiles: - Poseidon2_16 { - arg_a: MemOrConstant, // vectorized pointer, of size 1 - arg_b: MemOrConstant, // vectorized pointer, of size 1 - res: MemOrFp, // vectorized pointer, of size 2 (The Fp would never be used in practice) - }, - Poseidon2_24 { - arg_a: MemOrConstant, // vectorized pointer, of size 2 (2 first inputs) - arg_b: MemOrConstant, // vectorized pointer, of size 1 (3rd = last input) - res: MemOrFp, // vectorized pointer, of size 1 (3rd = last output) (The Fp would never be used in practice) - }, - DotProductExtensionExtension { - arg0: MemOrConstant, // vectorized pointer - arg1: MemOrConstant, // vectorized pointer - res: MemOrFp, // vectorized pointer, of size 1 (never Fp in practice) - size: usize, - }, - MultilinearEval { - coeffs: MemOrConstant, // vectorized pointer, chunk size = 2^n_vars - point: MemOrConstant, // vectorized pointer, of size `n_vars` - res: MemOrFp, // vectorized pointer, of size 1 (never fp in practice) - n_vars: usize, - }, -} - -impl Operation { - #[must_use] - pub fn compute(&self, a: F, b: F) -> F { - match self { - Self::Add => a + b, - Self::Mul => a * b, - } - } - - #[must_use] - pub fn inverse_compute(&self, a: F, b: F) -> Option { - match self { - Self::Add => Some(a - b), - Self::Mul => { - if b == F::ZERO { - None - } else { - Some(a / b) - } - } - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Hint { - Inverse { - arg: MemOrConstant, // the value to invert (return 0 if arg is zero) - res_offset: usize, // m[fp + res_offset] will contain the result - }, - RequestMemory { - offset: usize, // m[fp + offset] where the hint will be stored - size: MemOrConstant, // the hint - vectorized: bool, - }, - DecomposeBits { - res_offset: usize, // m[fp + res_offset..fp + res_offset + 31] will contain the decomposed bits - to_decompose: MemOrConstant, - }, - Print { - line_info: String, - content: Vec, - }, -} - -impl MemOrConstant { - #[must_use] - pub const fn zero() -> Self { - Self::Constant(F::ZERO) - } - - #[must_use] - pub const fn one() -> Self { - Self::Constant(F::ONE) - } -} - -impl Display for Bytecode { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - for (pc, instruction) in self.instructions.iter().enumerate() { - if let Some(hints) = self.hints.get(&pc) { - for hint in hints { - writeln!(f, "hint: {hint}")?; - } - } - writeln!(f, "{pc:>4}: {instruction}")?; - } - Ok(()) - } -} - -impl Display for MemOrConstant { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Constant(c) => write!(f, "{c}"), - Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), - } - } -} - -impl Display for MemOrFp { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), - Self::Fp => f.write_str("fp"), - } - } -} - -impl Display for MemOrFpOrConstant { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), - Self::Fp => f.write_str("fp"), - Self::Constant(c) => write!(f, "{c}"), - } - } -} - -impl Display for Instruction { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Computation { - operation, - arg_a, - arg_c, - res, - } => { - write!(f, "{res} = {arg_a} {operation} {arg_c}") - } - Self::Deref { - shift_0, - shift_1, - res, - } => { - write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]") - } - Self::DotProductExtensionExtension { - arg0, - arg1, - res, - size, - } => { - write!(f, "dot_product({arg0}, {arg1}, {res}, {size})") - } - Self::MultilinearEval { - coeffs, - point, - res, - n_vars, - } => { - write!(f, "multilinear_eval({coeffs}, {point}, {res}, {n_vars})") - } - Self::JumpIfNotZero { - condition, - dest, - updated_fp, - } => { - write!( - f, - "if {condition} != 0 jump to {dest} with next(fp) = {updated_fp}" - ) - } - Self::Poseidon2_16 { arg_a, arg_b, res } => { - write!(f, "{res} = poseidon2_16({arg_a}, {arg_b})") - } - Self::Poseidon2_24 { arg_a, arg_b, res } => { - write!(f, "{res} = poseidon2_24({arg_a}, {arg_b})") - } - } - } -} - -impl Display for Hint { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::RequestMemory { - offset, - size, - vectorized, - } => { - write!( - f, - "m[fp + {offset}] = {}({size})", - if *vectorized { "malloc_vec" } else { "malloc" } - ) - } - Self::DecomposeBits { - res_offset, - to_decompose, - } => { - write!(f, "m[fp + {res_offset}] = decompose_bits({to_decompose})") - } - Self::Print { line_info, content } => { - write!(f, "print(")?; - for (i, c) in content.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{c}")?; - } - write!(f, ") for \"{line_info}\"") - } - Self::Inverse { arg, res_offset } => { - write!(f, "m[fp + {res_offset}] = inverse({arg})") - } - } - } -} diff --git a/crates/vm/src/bytecode/hint.rs b/crates/vm/src/bytecode/hint.rs new file mode 100644 index 00000000..b9f3e60a --- /dev/null +++ b/crates/vm/src/bytecode/hint.rs @@ -0,0 +1,88 @@ +use std::{ + fmt, + fmt::{Display, Formatter}, +}; + +use super::MemOrConstant; + +/// Hints are special instructions for the prover to resolve non-determinism. +/// +/// They are not part of the verified computation trace. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Hint { + /// A hint for the prover to allocate a new memory segment for a function's stack frame. + /// + /// This is the core mechanism for memory management in a VM without an `ap` (allocation pointer) + /// register. The compiler pre-calculates the required memory size for each function. + RequestMemory { + /// The offset from `fp` where the pointer to the newly allocated segment will be stored. + offset: usize, + /// The requested size of the memory segment in scalar field elements. + size: MemOrConstant, + /// If true, the start of the allocated memory is aligned to an 8-element boundary + /// to facilitate vectorized memory access for extension field operations. + /// The value stored at `m[fp + offset]` will be the aligned address divided by 8. + vectorized: bool, + }, + /// A hint for the prover to compute the bit decomposition of a base field element. + /// + /// This is a non-deterministic operation used for operations like range checks + /// or other logic required by the XMSS signature scheme. + DecomposeBits { + /// The starting offset from `fp` where the resulting bits will be stored. + res_offset: usize, + /// The field element that needs to be decomposed into its bits. + to_decompose: MemOrConstant, + }, + /// A hint used for debugging to print values from memory during execution. + Print { + /// A string containing line information (e.g., file and line number) for context. + line_info: String, + /// A list of memory locations or constants whose values should be printed. + content: Vec, + }, + /// A hint for the prover to compute the modular inverse of a field element. + Inverse { + /// The value to be inverted. + arg: MemOrConstant, + /// The offset from `fp` where the result (`arg^-1`) will be stored. If `arg` is zero, zero is stored. + res_offset: usize, + }, +} + +impl Display for Hint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::RequestMemory { + offset, + size, + vectorized, + } => { + write!( + f, + "m[fp + {offset}] = {}({size})", + if *vectorized { "malloc_vec" } else { "malloc" } + ) + } + Self::DecomposeBits { + res_offset, + to_decompose, + } => { + write!(f, "m[fp + {res_offset}] = decompose_bits({to_decompose})") + } + Self::Print { line_info, content } => { + write!(f, "print(")?; + for (i, c) in content.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{c}")?; + } + write!(f, ") for \"{line_info}\"") + } + Self::Inverse { arg, res_offset } => { + write!(f, "m[fp + {res_offset}] = inverse({arg})") + } + } + } +} diff --git a/crates/vm/src/bytecode/instruction.rs b/crates/vm/src/bytecode/instruction.rs new file mode 100644 index 00000000..5af39915 --- /dev/null +++ b/crates/vm/src/bytecode/instruction.rs @@ -0,0 +1,104 @@ +use std::{ + fmt, + fmt::{Display, Formatter}, +}; + +use super::{MemOrConstant, MemOrFp, MemOrFpOrConstant, Operation}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Instruction { + // 3 basic instructions + Computation { + operation: Operation, + arg_a: MemOrConstant, + arg_c: MemOrFp, + res: MemOrConstant, + }, + Deref { + shift_0: usize, + shift_1: usize, + res: MemOrFpOrConstant, + }, // res = m[m[fp + shift_0] + shift_1] + JumpIfNotZero { + condition: MemOrConstant, + dest: MemOrConstant, + updated_fp: MemOrFp, + }, + // 4 precompiles: + Poseidon2_16 { + arg_a: MemOrConstant, // vectorized pointer, of size 1 + arg_b: MemOrConstant, // vectorized pointer, of size 1 + res: MemOrFp, // vectorized pointer, of size 2 (The Fp would never be used in practice) + }, + Poseidon2_24 { + arg_a: MemOrConstant, // vectorized pointer, of size 2 (2 first inputs) + arg_b: MemOrConstant, // vectorized pointer, of size 1 (3rd = last input) + res: MemOrFp, // vectorized pointer, of size 1 (3rd = last output) (The Fp would never be used in practice) + }, + DotProductExtensionExtension { + arg0: MemOrConstant, // vectorized pointer + arg1: MemOrConstant, // vectorized pointer + res: MemOrFp, // vectorized pointer, of size 1 (never Fp in practice) + size: usize, + }, + MultilinearEval { + coeffs: MemOrConstant, // vectorized pointer, chunk size = 2^n_vars + point: MemOrConstant, // vectorized pointer, of size `n_vars` + res: MemOrFp, // vectorized pointer, of size 1 (never fp in practice) + n_vars: usize, + }, +} + +impl Display for Instruction { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Computation { + operation, + arg_a, + arg_c, + res, + } => { + write!(f, "{res} = {arg_a} {operation} {arg_c}") + } + Self::Deref { + shift_0, + shift_1, + res, + } => { + write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]") + } + Self::DotProductExtensionExtension { + arg0, + arg1, + res, + size, + } => { + write!(f, "dot_product({arg0}, {arg1}, {res}, {size})") + } + Self::MultilinearEval { + coeffs, + point, + res, + n_vars, + } => { + write!(f, "multilinear_eval({coeffs}, {point}, {res}, {n_vars})") + } + Self::JumpIfNotZero { + condition, + dest, + updated_fp, + } => { + write!( + f, + "if {condition} != 0 jump to {dest} with next(fp) = {updated_fp}" + ) + } + Self::Poseidon2_16 { arg_a, arg_b, res } => { + write!(f, "{res} = poseidon2_16({arg_a}, {arg_b})") + } + Self::Poseidon2_24 { arg_a, arg_b, res } => { + write!(f, "{res} = poseidon2_24({arg_a}, {arg_b})") + } + } + } +} diff --git a/crates/vm/src/bytecode/mod.rs b/crates/vm/src/bytecode/mod.rs new file mode 100644 index 00000000..52f3585e --- /dev/null +++ b/crates/vm/src/bytecode/mod.rs @@ -0,0 +1,38 @@ +use std::{ + collections::BTreeMap, + fmt, + fmt::{Display, Formatter}, +}; + +pub mod operand; +pub use operand::*; +pub mod hint; +pub use hint::*; +pub mod operation; +pub use operation::*; +pub mod instruction; +pub use instruction::*; + +pub type Label = String; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Bytecode { + pub instructions: Vec, + pub hints: BTreeMap>, // pc -> hints + pub starting_frame_memory: usize, + pub ending_pc: usize, +} + +impl Display for Bytecode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for (pc, instruction) in self.instructions.iter().enumerate() { + if let Some(hints) = self.hints.get(&pc) { + for hint in hints { + writeln!(f, "hint: {hint}")?; + } + } + writeln!(f, "{pc:>4}: {instruction}")?; + } + Ok(()) + } +} diff --git a/crates/vm/src/bytecode/operand.rs b/crates/vm/src/bytecode/operand.rs new file mode 100644 index 00000000..2f4dff03 --- /dev/null +++ b/crates/vm/src/bytecode/operand.rs @@ -0,0 +1,122 @@ +use std::{ + fmt, + fmt::{Display, Formatter}, +}; + +use p3_field::PrimeCharacteristicRing; + +use crate::F; + +/// Represents a value that can either be a constant or a value from memory. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MemOrConstant { + /// A constant value (a field element). + Constant(F), + /// A memory location specified by a positive offset from the frame pointer (`fp`). + /// + /// Represents the scalar value at `m[fp + shift]`. + MemoryAfterFp { + /// The offset from `fp` where the memory location is located. + offset: usize, + }, +} + +impl MemOrConstant { + /// Converts the operand into its raw field representation for the trace. + /// + /// Returns a tuple `(operand, flag)` where: + /// - `operand`: The field element representing the constant value or memory offset. + /// - `flag`: A flag that is `1` for a constant and `0` for a memory access. + #[must_use] + pub fn to_field_and_flag(&self) -> (F, F) { + match self { + // If it's a constant, the flag is 1 and the value is the constant itself. + Self::Constant(c) => (*c, F::ONE), + // If it's a memory location, the flag is 0 and the value is the offset. + Self::MemoryAfterFp { offset } => (F::from_usize(*offset), F::ZERO), + } + } + + /// Returns a constant operand with value `0`. + #[must_use] + pub const fn zero() -> Self { + Self::Constant(F::ZERO) + } + + /// Returns a constant operand with value `1`. + #[must_use] + pub const fn one() -> Self { + Self::Constant(F::ONE) + } +} + +impl Display for MemOrConstant { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Constant(c) => write!(f, "{c}"), + Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), + } + } +} + +/// Represents a value that can be a memory location, the `fp` register itself, or a constant. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MemOrFpOrConstant { + /// A memory location specified by a positive offset from `fp`. Represents `m[fp + shift]`. + MemoryAfterFp { + /// The offset from `fp` where the memory location is located. + offset: usize, + }, + /// The value of the frame pointer (`fp`) register itself. + Fp, + /// A constant value (a field element). + Constant(F), +} + +impl Display for MemOrFpOrConstant { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), + Self::Fp => f.write_str("fp"), + Self::Constant(c) => write!(f, "{c}"), + } + } +} + +/// Represents a value that is either a memory location or the `fp` register itself. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MemOrFp { + /// A memory location specified by a positive offset from `fp`. Represents `m[fp + shift]`. + MemoryAfterFp { + /// The offset from `fp` where the memory location is located. + offset: usize, + }, + /// The value of the frame pointer (`fp`) register itself. + Fp, +} + +impl MemOrFp { + /// Converts the operand into its raw field representation for the trace. + /// + /// Returns a tuple `(operand, flag)` where: + /// - `operand`: The field element representing the memory offset (or 0 if `Fp`). + /// - `flag`: A flag that is `1` for the `Fp` register and `0` for a memory access. + #[must_use] + pub fn to_field_and_flag(&self) -> (F, F) { + match self { + // If it's the frame pointer, the flag is 1 and the operand value is 0. + Self::Fp => (F::ZERO, F::ONE), + // If it's a memory location, the flag is 0 and the value is the offset. + Self::MemoryAfterFp { offset } => (F::from_usize(*offset), F::ZERO), + } + } +} + +impl Display for MemOrFp { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), + Self::Fp => f.write_str("fp"), + } + } +} diff --git a/crates/vm/src/bytecode/operation.rs b/crates/vm/src/bytecode/operation.rs new file mode 100644 index 00000000..5612930f --- /dev/null +++ b/crates/vm/src/bytecode/operation.rs @@ -0,0 +1,115 @@ +use std::fmt; + +use p3_field::Field; + +use crate::F; + +/// The basic arithmetic operations supported by the VM's `Computation` instruction. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Operation { + /// Field addition in the base field. + Add, + /// Field multiplication in the base field. + Mul, +} + +impl Operation { + /// Computes the result of applying this arithmetic operation to two operands. + /// + /// # Parameters + /// + /// - `a`: The left-hand operand. + /// - `b`: The right-hand operand. + /// + /// # Returns + /// + /// The result of `a ⊕ b`, where `⊕` is the operation represented by `self`. + /// For example: + /// - If `self` is `Add`, returns `a + b`. + /// - If `self` is `Mul`, returns `a * b`. + #[must_use] + pub fn compute(&self, a: F, b: F) -> F { + match self { + Self::Add => a + b, + Self::Mul => a * b, + } + } + + /// Computes the inverse of the operation with respect to the right-hand operand. + /// + /// Solves for `a` given the result `c = a ⊕ b`, by computing `a = c ⊖ b`, where `⊖` + /// is the inverse of the operation represented by `self`. + /// + /// # Parameters + /// + /// - `a`: The result value (i.e., `a ⊕ b`). + /// - `b`: The right-hand operand of the original operation. + /// + /// # Returns + /// + /// - `Some(a)` if the inverse exists. + /// - `None` if the inverse does not exist (e.g., `b == 0` for `Mul`). + #[must_use] + pub fn inverse_compute(&self, a: F, b: F) -> Option { + match self { + Self::Add => Some(a - b), + Self::Mul => (!b.is_zero()).then(|| a / b), + } + } +} + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "+"), + Self::Mul => write!(f, "x"), + } + } +} + +#[cfg(test)] +mod tests { + use p3_field::PrimeCharacteristicRing; + + use super::*; + + #[test] + fn test_compute_add() { + let op = Operation::Add; + let val1 = F::from_u32(100); + let val2 = F::from_u32(50); + assert_eq!(op.compute(val1, val2), F::from_u32(150)); + } + + #[test] + fn test_compute_mul() { + let op = Operation::Mul; + let val1 = F::from_u32(10); + let val2 = F::from_u32(5); + assert_eq!(op.compute(val1, val2), F::from_u32(50)); + } + + #[test] + fn test_inverse_compute_add() { + let op = Operation::Add; + let val1 = F::from_u32(150); + let val2 = F::from_u32(50); + assert_eq!(op.inverse_compute(val1, val2), Some(F::from_u32(100))); + } + + #[test] + fn test_inverse_compute_mul_success() { + let op = Operation::Mul; + let val1 = F::from_u32(50); + let val2 = F::from_u32(5); + assert_eq!(op.inverse_compute(val1, val2), Some(F::from_u32(10))); + } + + #[test] + fn test_inverse_compute_mul_by_zero() { + let op = Operation::Mul; + let val1 = F::from_u32(50); + let val2 = F::ZERO; + assert_eq!(op.inverse_compute(val1, val2), None); + } +}