diff --git a/crates/compiler/src/c_compile_final.rs b/crates/compiler/src/c_compile_final.rs index 9161539e..f4d9712b 100644 --- a/crates/compiler/src/c_compile_final.rs +++ b/crates/compiler/src/c_compile_final.rs @@ -15,26 +15,6 @@ use crate::{ lang::{ConstExpression, ConstantValue}, }; -impl IntermediateInstruction { - const fn is_hint(&self) -> bool { - match self { - Self::RequestMemory { .. } - | Self::Print { .. } - | Self::DecomposeBits { .. } - | Self::Inverse { .. } => true, - Self::Computation { .. } - | Self::Panic - | Self::Deref { .. } - | Self::JumpIfNotZero { .. } - | Self::Jump { .. } - | Self::Poseidon2_16 { .. } - | Self::Poseidon2_24 { .. } - | Self::DotProduct { .. } - | Self::MultilinearEval { .. } => false, - } - } -} - struct Compiler { memory_size_per_function: BTreeMap, label_to_pc: BTreeMap, @@ -329,6 +309,7 @@ impl IntermediateValue { _ => Err(format!("Cannot convert {self:?} to MemOrFp")), } } + fn try_into_mem_or_constant(&self, compiler: &Compiler) -> Result { if let Some(cst) = try_as_constant(self, compiler) { return Ok(MemOrConstant::Constant(cst)); diff --git a/crates/compiler/src/intermediate_bytecode.rs b/crates/compiler/src/intermediate_bytecode.rs deleted file mode 100644 index 32da4694..00000000 --- a/crates/compiler/src/intermediate_bytecode.rs +++ /dev/null @@ -1,360 +0,0 @@ -use std::{collections::BTreeMap, fmt}; - -use p3_field::{PrimeCharacteristicRing, PrimeField64}; -use vm::{Label, Operation}; - -use crate::{F, lang::ConstExpression}; - -#[derive(Debug, Clone)] -pub(crate) struct IntermediateBytecode { - pub bytecode: BTreeMap>, - pub memory_size_per_function: BTreeMap, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum IntermediateValue { - Constant(ConstExpression), - Fp, - MemoryAfterFp { offset: ConstExpression }, // m[fp + offset] -} - -impl From for IntermediateValue { - fn from(value: ConstExpression) -> Self { - Self::Constant(value) - } -} -impl TryFrom for Operation { - type Error = String; - - fn try_from(value: HighLevelOperation) -> Result { - match value { - HighLevelOperation::Add => Ok(Self::Add), - HighLevelOperation::Mul => Ok(Self::Mul), - _ => Err(format!("Cannot convert {value:?} to +/x")), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum IntermediaryMemOrFpOrConstant { - MemoryAfterFp { offset: ConstExpression }, // m[fp + offset] - Fp, - Constant(ConstExpression), -} - -impl IntermediateValue { - pub(crate) const fn label(label: Label) -> Self { - Self::Constant(ConstExpression::label(label)) - } - - pub(crate) const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum HighLevelOperation { - Add, - Mul, - Sub, - Div, // in the end everything compiles to either Add or Mul - Exp, // Exponentiation, only for const expressions -} - -impl HighLevelOperation { - pub fn eval(&self, a: F, b: F) -> F { - match self { - Self::Add => a + b, - Self::Mul => a * b, - Self::Sub => a - b, - Self::Div => a / b, - Self::Exp => a.exp_u64(b.as_canonical_u64()), - } - } -} - -impl fmt::Display for HighLevelOperation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Add => write!(f, "+"), - Self::Mul => write!(f, "*"), - Self::Sub => write!(f, "-"), - Self::Div => write!(f, "/"), - Self::Exp => write!(f, "**"), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) enum IntermediateInstruction { - Computation { - operation: Operation, - arg_a: IntermediateValue, - arg_c: IntermediateValue, - res: IntermediateValue, - }, - Deref { - shift_0: ConstExpression, - shift_1: ConstExpression, - res: IntermediaryMemOrFpOrConstant, - }, // res = m[m[fp + shift_0]] - Panic, - Jump { - dest: IntermediateValue, - updated_fp: Option, - }, - JumpIfNotZero { - condition: IntermediateValue, - dest: IntermediateValue, - updated_fp: Option, - }, - Poseidon2_16 { - arg_a: IntermediateValue, // vectorized pointer, of size 1 - arg_b: IntermediateValue, // vectorized pointer, of size 1 - res: IntermediateValue, // vectorized pointer, of size 2 - }, - Poseidon2_24 { - arg_a: IntermediateValue, // vectorized pointer, of size 2 (2 first inputs) - arg_b: IntermediateValue, // vectorized pointer, of size 1 (3rd = last input) - res: IntermediateValue, // vectorized pointer, of size 1 (3rd = last output) - }, - DotProduct { - arg0: IntermediateValue, // vectorized pointer - arg1: IntermediateValue, // vectorized pointer - res: IntermediateValue, // vectorized pointer - size: ConstExpression, - }, - MultilinearEval { - coeffs: IntermediateValue, // vectorized pointer, chunk size = 2^n_vars - point: IntermediateValue, // vectorized pointer, of size `n_vars` - res: IntermediateValue, // vectorized pointer, of size 1 - n_vars: ConstExpression, - }, - // HINTS (does not appears in the final bytecode) - Inverse { - // If the value is zero, it will return zero. - arg: IntermediateValue, // the value to invert - res_offset: usize, // m[fp + res_offset] will contain the result - }, - RequestMemory { - offset: ConstExpression, // m[fp + offset] where the hint will be stored - size: IntermediateValue, // the hint - vectorized: bool, // if true, will be 8-alligned, and the returned pointer will be "divied" by 8 (i.e. everything is in chunks of 8 field elements) - }, - DecomposeBits { - res_offset: usize, // m[fp + res_offset..fp + res_offset + 31] will contain the decomposed bits - to_decompose: IntermediateValue, - }, - Print { - line_info: String, // information about the line where the print occurs - content: Vec, // values to print - }, -} - -impl IntermediateInstruction { - pub(crate) fn computation( - operation: HighLevelOperation, - arg_a: IntermediateValue, - arg_c: IntermediateValue, - res: IntermediateValue, - ) -> Self { - match operation { - HighLevelOperation::Add => Self::Computation { - operation: Operation::Add, - arg_a, - arg_c, - res, - }, - HighLevelOperation::Mul => Self::Computation { - operation: Operation::Mul, - arg_a, - arg_c, - res, - }, - HighLevelOperation::Sub => Self::Computation { - operation: Operation::Add, - arg_a: res, - arg_c, - res: arg_a, - }, - HighLevelOperation::Div => Self::Computation { - operation: Operation::Mul, - arg_a: res, - arg_c, - res: arg_a, - }, - HighLevelOperation::Exp => unreachable!(), - } - } - - pub(crate) const fn equality(left: IntermediateValue, right: IntermediateValue) -> Self { - Self::Computation { - operation: Operation::Add, - arg_a: left, - arg_c: IntermediateValue::Constant(ConstExpression::zero()), - res: right, - } - } -} - -impl ToString for IntermediateValue { - fn to_string(&self) -> String { - match self { - Self::Constant(value) => value.to_string(), - Self::Fp => "fp".to_string(), - Self::MemoryAfterFp { offset } => { - format!("m[fp + {offset}]") - } - } - } -} - -impl ToString for IntermediaryMemOrFpOrConstant { - fn to_string(&self) -> String { - match self { - Self::MemoryAfterFp { offset } => format!("m[fp + {offset}]"), - Self::Fp => "fp".to_string(), - Self::Constant(c) => c.to_string(), - } - } -} - -impl ToString for IntermediateInstruction { - fn to_string(&self) -> String { - match self { - Self::Deref { - shift_0, - shift_1, - res, - } => format!("{} = m[m[fp + {}] + {}]", res.to_string(), shift_0, shift_1), - Self::DotProduct { - arg0, - arg1, - res, - size, - } => format!( - "dot_product({}, {}, {}, {})", - arg0.to_string(), - arg1.to_string(), - res.to_string(), - size - ), - Self::MultilinearEval { - coeffs, - point, - res, - n_vars, - } => format!( - "multilinear_eval({}, {}, {}, {})", - coeffs.to_string(), - point.to_string(), - res.to_string(), - n_vars - ), - Self::DecomposeBits { - res_offset, - to_decompose, - } => { - format!( - "m[fp + {}..] = decompose_bits({})", - res_offset, - to_decompose.to_string() - ) - } - Self::Computation { - operation, - arg_a, - arg_c, - res, - } => { - format!( - "{} = {} {} {}", - res.to_string(), - arg_a.to_string(), - operation.to_string(), - arg_c.to_string() - ) - } - Self::Panic => "panic".to_string(), - Self::Jump { dest, updated_fp } => updated_fp.as_ref().map_or_else( - || format!("jump {}", dest.to_string()), - |fp| format!("jump {} with fp = {}", dest.to_string(), fp.to_string()), - ), - Self::JumpIfNotZero { - condition, - dest, - updated_fp, - } => { - if let Some(fp) = updated_fp { - format!( - "jump_if_not_zero {} to {} with fp = {}", - condition.to_string(), - dest.to_string(), - fp.to_string() - ) - } else { - format!( - "jump_if_not_zero {} to {}", - condition.to_string(), - dest.to_string() - ) - } - } - Self::Poseidon2_16 { arg_a, arg_b, res } => { - format!( - "{} = poseidon2_16({}, {})", - arg_a.to_string(), - arg_b.to_string(), - res.to_string(), - ) - } - Self::Poseidon2_24 { arg_a, arg_b, res } => { - format!( - "{} = poseidon2_24({}, {})", - res.to_string(), - arg_a.to_string(), - arg_b.to_string(), - ) - } - Self::Inverse { arg, res_offset } => { - format!("m[fp + {}] = inverse({})", res_offset, arg.to_string()) - } - Self::RequestMemory { - offset, - size, - vectorized, - } => format!( - "m[fp + {}] = {}({})", - offset, - if *vectorized { "malloc_vec" } else { "malloc" }, - size.to_string(), - ), - Self::Print { line_info, content } => format!( - "print {}: {}", - line_info, - content - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", ") - ), - } - } -} - -impl ToString for IntermediateBytecode { - fn to_string(&self) -> String { - let mut res = String::new(); - for (label, instructions) in &self.bytecode { - res.push_str(&format!("\n{label}:\n")); - for instruction in instructions { - res.push_str(&format!(" {}\n", instruction.to_string())); - } - } - res.push_str("\nMemory size per function:\n"); - for (function_name, size) in &self.memory_size_per_function { - res.push_str(&format!("{function_name}: {size}\n")); - } - res - } -} diff --git a/crates/compiler/src/intermediate_bytecode/instruction.rs b/crates/compiler/src/intermediate_bytecode/instruction.rs new file mode 100644 index 00000000..bf51c1e7 --- /dev/null +++ b/crates/compiler/src/intermediate_bytecode/instruction.rs @@ -0,0 +1,226 @@ +use std::fmt; + +use vm::Operation; + +use super::{HighLevelOperation, IntermediaryMemOrFpOrConstant, IntermediateValue}; +use crate::lang::ConstExpression; + +#[derive(Debug, Clone)] +pub enum IntermediateInstruction { + Computation { + operation: Operation, + arg_a: IntermediateValue, + arg_c: IntermediateValue, + res: IntermediateValue, + }, + Deref { + shift_0: ConstExpression, + shift_1: ConstExpression, + res: IntermediaryMemOrFpOrConstant, + }, // res = m[m[fp + shift_0]] + Panic, + Jump { + dest: IntermediateValue, + updated_fp: Option, + }, + JumpIfNotZero { + condition: IntermediateValue, + dest: IntermediateValue, + updated_fp: Option, + }, + Poseidon2_16 { + arg_a: IntermediateValue, // vectorized pointer, of size 1 + arg_b: IntermediateValue, // vectorized pointer, of size 1 + res: IntermediateValue, // vectorized pointer, of size 2 + }, + Poseidon2_24 { + arg_a: IntermediateValue, // vectorized pointer, of size 2 (2 first inputs) + arg_b: IntermediateValue, // vectorized pointer, of size 1 (3rd = last input) + res: IntermediateValue, // vectorized pointer, of size 1 (3rd = last output) + }, + DotProduct { + arg0: IntermediateValue, // vectorized pointer + arg1: IntermediateValue, // vectorized pointer + res: IntermediateValue, // vectorized pointer + size: ConstExpression, + }, + MultilinearEval { + coeffs: IntermediateValue, // vectorized pointer, chunk size = 2^n_vars + point: IntermediateValue, // vectorized pointer, of size `n_vars` + res: IntermediateValue, // vectorized pointer, of size 1 + n_vars: ConstExpression, + }, + // HINTS (does not appears in the final bytecode) + Inverse { + // If the value is zero, it will return zero. + arg: IntermediateValue, // the value to invert + res_offset: usize, // m[fp + res_offset] will contain the result + }, + RequestMemory { + offset: ConstExpression, // m[fp + offset] where the hint will be stored + size: IntermediateValue, // the hint + vectorized: bool, // if true, will be 8-alligned, and the returned pointer will be "divied" by 8 (i.e. everything is in chunks of 8 field elements) + }, + DecomposeBits { + res_offset: usize, // m[fp + res_offset..fp + res_offset + 31] will contain the decomposed bits + to_decompose: IntermediateValue, + }, + Print { + line_info: String, // information about the line where the print occurs + content: Vec, // values to print + }, +} + +impl IntermediateInstruction { + #[must_use] + pub fn computation( + operation: HighLevelOperation, + arg_a: IntermediateValue, + arg_c: IntermediateValue, + res: IntermediateValue, + ) -> Self { + match operation { + HighLevelOperation::Add => Self::Computation { + operation: Operation::Add, + arg_a, + arg_c, + res, + }, + HighLevelOperation::Mul => Self::Computation { + operation: Operation::Mul, + arg_a, + arg_c, + res, + }, + HighLevelOperation::Sub => Self::Computation { + operation: Operation::Add, + arg_a: res, + arg_c, + res: arg_a, + }, + HighLevelOperation::Div => Self::Computation { + operation: Operation::Mul, + arg_a: res, + arg_c, + res: arg_a, + }, + HighLevelOperation::Exp => unreachable!(), + } + } + + #[must_use] + pub const fn equality(left: IntermediateValue, right: IntermediateValue) -> Self { + Self::Computation { + operation: Operation::Add, + arg_a: left, + arg_c: IntermediateValue::Constant(ConstExpression::zero()), + res: right, + } + } + + #[must_use] + pub const fn is_hint(&self) -> bool { + match self { + Self::RequestMemory { .. } + | Self::Print { .. } + | Self::DecomposeBits { .. } + | Self::Inverse { .. } => true, + Self::Computation { .. } + | Self::Panic + | Self::Deref { .. } + | Self::JumpIfNotZero { .. } + | Self::Jump { .. } + | Self::Poseidon2_16 { .. } + | Self::Poseidon2_24 { .. } + | Self::DotProduct { .. } + | Self::MultilinearEval { .. } => false, + } + } +} + +impl fmt::Display for IntermediateInstruction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Deref { + shift_0, + shift_1, + res, + } => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"), + Self::DotProduct { + arg0, + arg1, + res, + size, + } => write!(f, "dot_product({arg0}, {arg1}, {res}, {size})"), + Self::MultilinearEval { + coeffs, + point, + res, + n_vars, + } => write!(f, "multilinear_eval({coeffs}, {point}, {res}, {n_vars})"), + Self::DecomposeBits { + res_offset, + to_decompose, + } => { + write!(f, "m[fp + {res_offset}..] = decompose_bits({to_decompose})") + } + Self::Computation { + operation, + arg_a, + arg_c, + res, + } => { + write!(f, "{res} = {arg_a} {operation} {arg_c}") + } + Self::Panic => write!(f, "panic"), + Self::Jump { dest, updated_fp } => { + if let Some(fp) = updated_fp { + write!(f, "jump {dest} with fp = {fp}") + } else { + write!(f, "jump {dest}") + } + } + Self::JumpIfNotZero { + condition, + dest, + updated_fp, + } => { + if let Some(fp) = updated_fp { + write!(f, "jump_if_not_zero {condition} to {dest} with fp = {fp}") + } else { + write!(f, "jump_if_not_zero {condition} to {dest}") + } + } + Self::Poseidon2_16 { arg_a, arg_b, res } => { + write!(f, "{arg_a} = poseidon2_16({arg_b}, {res})") + } + Self::Poseidon2_24 { arg_a, arg_b, res } => { + write!(f, "{res} = poseidon2_24({arg_a}, {arg_b})") + } + Self::Inverse { arg, res_offset } => { + write!(f, "m[fp + {res_offset}] = inverse({arg})") + } + Self::RequestMemory { + offset, + size, + vectorized, + } => write!( + f, + "m[fp + {}] = {}({})", + offset, + if *vectorized { "malloc_vec" } else { "malloc" }, + size, + ), + Self::Print { line_info, content } => write!( + f, + "print {}: {}", + line_info, + content + .iter() + .map(ToString::to_string) + .collect::>() + .join(", ") + ), + } + } +} diff --git a/crates/compiler/src/intermediate_bytecode/intermediate_value.rs b/crates/compiler/src/intermediate_bytecode/intermediate_value.rs new file mode 100644 index 00000000..293c1e70 --- /dev/null +++ b/crates/compiler/src/intermediate_bytecode/intermediate_value.rs @@ -0,0 +1,52 @@ +use std::fmt; + +use vm::F; + +use crate::{ + Compiler, + lang::{ConstExpression, Label}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IntermediateValue { + Constant(ConstExpression), + Fp, + MemoryAfterFp { offset: ConstExpression }, // m[fp + offset] +} + +impl IntermediateValue { + #[must_use] + pub const fn label(label: Label) -> Self { + Self::Constant(ConstExpression::label(label)) + } + + #[must_use] + pub const fn is_constant(&self) -> bool { + matches!(self, Self::Constant(_)) + } + + #[must_use] + pub fn try_as_constant(&self, compiler: &Compiler) -> Option { + if let Self::Constant(c) = self { + Some(c.eval(compiler)) + } else { + None + } + } +} + +impl From for IntermediateValue { + fn from(value: ConstExpression) -> Self { + Self::Constant(value) + } +} + +impl fmt::Display for IntermediateValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Constant(value) => write!(f, "{value}"), + Self::Fp => write!(f, "fp"), + Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), + } + } +} diff --git a/crates/compiler/src/intermediate_bytecode/mod.rs b/crates/compiler/src/intermediate_bytecode/mod.rs new file mode 100644 index 00000000..ba77282d --- /dev/null +++ b/crates/compiler/src/intermediate_bytecode/mod.rs @@ -0,0 +1,60 @@ +use std::{collections::BTreeMap, fmt}; + +pub mod intermediate_value; +pub use intermediate_value::*; +pub mod operation; +pub use operation::*; +pub mod instruction; +pub use instruction::*; +use vm::Label; + +use crate::lang::ConstExpression; + +#[derive(Debug, Clone)] +pub struct IntermediateBytecode { + pub bytecode: BTreeMap>, + pub memory_size_per_function: BTreeMap, +} + +impl fmt::Display for IntermediateBytecode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Iterate through each labeled block of instructions in the bytecode. + for (label, instructions) in &self.bytecode { + // Write the label for the current block, followed by a newline. + writeln!(f, "\n{label}:")?; + // Iterate through each instruction within the block. + for instruction in instructions { + // Write the instruction, indented with two spaces for readability. + writeln!(f, " {instruction}")?; + } + } + + // Write the header for the memory size section. + writeln!(f, "\nMemory size per function:")?; + // Iterate through the recorded memory sizes for each function. + for (function_name, size) in &self.memory_size_per_function { + // Write the function name and its corresponding memory size. + writeln!(f, "{function_name}: {size}")?; + } + + // Return Ok to indicate that formatting was successful. + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum IntermediaryMemOrFpOrConstant { + MemoryAfterFp { offset: ConstExpression }, // m[fp + offset] + Fp, + Constant(ConstExpression), +} + +impl fmt::Display for IntermediaryMemOrFpOrConstant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MemoryAfterFp { offset } => write!(f, "m[fp + {offset}]"), + Self::Fp => write!(f, "fp"), + Self::Constant(c) => write!(f, "{c}"), + } + } +} diff --git a/crates/compiler/src/intermediate_bytecode/operation.rs b/crates/compiler/src/intermediate_bytecode/operation.rs new file mode 100644 index 00000000..c708eb9c --- /dev/null +++ b/crates/compiler/src/intermediate_bytecode/operation.rs @@ -0,0 +1,38 @@ +use std::fmt; + +use p3_field::{PrimeCharacteristicRing, PrimeField64}; +use vm::F; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum HighLevelOperation { + Add, + Mul, + Sub, + Div, // in the end everything compiles to either Add or Mul + Exp, // Exponentiation, only for const expressions +} + +impl HighLevelOperation { + #[must_use] + pub fn eval(&self, a: F, b: F) -> F { + match self { + Self::Add => a + b, + Self::Mul => a * b, + Self::Sub => a - b, + Self::Div => a / b, + Self::Exp => a.exp_u64(b.as_canonical_u64()), + } + } +} + +impl fmt::Display for HighLevelOperation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "+"), + Self::Mul => write!(f, "*"), + Self::Sub => write!(f, "-"), + Self::Div => write!(f, "/"), + Self::Exp => write!(f, "**"), + } + } +} diff --git a/crates/vm/src/bytecode.rs b/crates/vm/src/bytecode.rs index 786645aa..0818b57c 100644 --- a/crates/vm/src/bytecode.rs +++ b/crates/vm/src/bytecode.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, fmt}; use p3_field::PrimeCharacteristicRing; @@ -38,6 +38,15 @@ pub enum Operation { Mul, } +impl fmt::Display for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "+"), + Self::Mul => write!(f, "x"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Instruction { // 3 basic instructions @@ -180,15 +189,6 @@ impl ToString for MemOrFpOrConstant { } } -impl ToString for Operation { - fn to_string(&self) -> String { - match self { - Self::Add => "+".to_string(), - Self::Mul => "x".to_string(), - } - } -} - impl ToString for Instruction { fn to_string(&self) -> String { match self { @@ -202,7 +202,7 @@ impl ToString for Instruction { "{} = {} {} {}", res.to_string(), arg_a.to_string(), - operation.to_string(), + operation, arg_c.to_string() ) }