diff --git a/Cargo.lock b/Cargo.lock index 39e518f1..85802607 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1837,6 +1837,29 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "env_filter" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf3c259d255ca70051b30e2e95b5446cdb8949ac4cd22c0d7fd634d89f568e2" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -3276,6 +3299,30 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ee5b5339afb4c41626dde77b7a611bd4f2c202b897852b4bcf5d03eddc61010" +[[package]] +name = "jiff" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89a5b5e10d5a9ad6e5d1f4bd58225f655d6fe9767575a5e8ac5a6fe64e04495" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff7a39c8862fc1369215ccf0a8f12dd4598c7f6484704359f0351bd617034dbf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -4522,6 +4569,15 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f59e70c4aef1e55797c2e8fd94a4f2a973fc972cfde0e0b05f683667b0cd39dd" +[[package]] +name = "portable-atomic-util" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -7542,6 +7598,7 @@ dependencies = [ "bytemuck", "elf", "enum-map", + "env_logger", "eyre", "hashbrown 0.14.5", "hex", @@ -7549,6 +7606,7 @@ dependencies = [ "log", "nohash-hasher", "num", + "num_enum", "p3-field", "p3-koala-bear", "p3-maybe-rayon", @@ -7564,6 +7622,7 @@ dependencies = [ "thiserror 1.0.69", "tiny-keccak", "tracing", + "tracing-subscriber 0.3.22", "typenum", "vec_map", "zkm-curves", @@ -7673,6 +7732,7 @@ dependencies = [ name = "zkm-derive" version = "1.2.4" dependencies = [ + "proc-macro2", "quote", "syn 1.0.109", ] @@ -7702,6 +7762,20 @@ dependencies = [ "zkm-primitives 1.2.4", ] +[[package]] +name = "zkm-picus" +version = "1.2.4" +dependencies = [ + "clap", + "p3-air", + "p3-field", + "p3-koala-bear", + "p3-matrix", + "zkm-core-executor", + "zkm-core-machine", + "zkm-stark", +] + [[package]] name = "zkm-primitives" version = "1.2.3" diff --git a/Cargo.toml b/Cargo.toml index 03e9e45d..7c5a8e32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "crates/cuda", "crates/curves", "crates/derive", + "crates/picus", "crates/primitives", "crates/prover", "crates/recursion/circuit", @@ -115,6 +116,7 @@ zkm-build = { path = "crates/build" } zkm-sdk = { path = "crates/sdk" } zkm-cuda = { path = "crates/cuda" } zkm-verifier = { path = "crates/verifier" } +zkm-picus = {path = "crates/picus"} zkm-lib = { path = "crates/zkvm/lib", default-features = false } zkm-zkvm = { path = "crates/zkvm/entrypoint", default-features = false } diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index 7eeef049..5352051a 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -46,6 +46,9 @@ vec_map = { version = "0.8.2", features = ["serde"] } enum-map = { version = "2.7.3", features = ["serde"] } sha2 = { workspace = true } anyhow = { workspace = true } +tracing-subscriber = "0.3.19" +env_logger = "0.11.6" +num_enum = "0.7.5" [dev-dependencies] test-artifacts = { path = "../../test-artifacts" } diff --git a/crates/core/executor/src/opcode.rs b/crates/core/executor/src/opcode.rs index 5407508a..e5d1d059 100644 --- a/crates/core/executor/src/opcode.rs +++ b/crates/core/executor/src/opcode.rs @@ -1,6 +1,7 @@ //! Opcodes for ZKM. use enum_map::Enum; +use num_enum::TryFromPrimitive; use p3_field::Field; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -8,7 +9,18 @@ use std::fmt::Display; /// An opcode (short for "operation code") specifies the operation to be performed by the processor. #[allow(non_camel_case_types)] #[derive( - Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Enum, + TryFromPrimitive, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + PartialOrd, + Ord, + Enum, )] #[repr(u8)] pub enum Opcode { diff --git a/crates/core/machine/src/alu/add_sub/mod.rs b/crates/core/machine/src/alu/add_sub/mod.rs index c4d13423..6fa46bd0 100644 --- a/crates/core/machine/src/alu/add_sub/mod.rs +++ b/crates/core/machine/src/alu/add_sub/mod.rs @@ -13,9 +13,9 @@ use zkm_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_stark::{ - air::{MachineAir, ZKMAirBuilder}, + air::{MachineAir, PicusInfo, ZKMAirBuilder}, Word, }; @@ -38,7 +38,7 @@ pub const NUM_ADD_SUB_COLS: usize = size_of::>(); pub struct AddSubChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Clone, Copy)] #[repr(C)] pub struct AddSubCols { /// The current/next pc, used for instruction lookup table. @@ -56,9 +56,11 @@ pub struct AddSubCols { pub operand_2: Word, /// Flag indicating whether the opcode is `ADD`. + #[picus(selector)] pub is_add: T, /// Flag indicating whether the opcode is `SUB`. + #[picus(selector)] pub is_sub: T, } @@ -78,6 +80,9 @@ impl MachineAir for AddSubChip { next_power_of_two(input.add_sub_events.len(), input.fixed_log2_rows::(self)); Some(nb_rows) } + fn picus_info(&self) -> PicusInfo { + AddSubCols::::picus_info() + } fn generate_trace( &self, diff --git a/crates/core/machine/src/alu/bitwise/mod.rs b/crates/core/machine/src/alu/bitwise/mod.rs index 28c6dd20..35d0fd5c 100644 --- a/crates/core/machine/src/alu/bitwise/mod.rs +++ b/crates/core/machine/src/alu/bitwise/mod.rs @@ -13,9 +13,9 @@ use zkm_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_stark::{ - air::{MachineAir, ZKMAirBuilder}, + air::{MachineAir, PicusInfo, ZKMAirBuilder}, Word, }; @@ -29,7 +29,7 @@ pub const NUM_BITWISE_COLS: usize = size_of::>(); pub struct BitwiseChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Clone, Copy)] #[repr(C)] pub struct BitwiseCols { /// The current/next pc, used for instruction lookup table. @@ -46,15 +46,19 @@ pub struct BitwiseCols { pub c: Word, /// If the opcode is NOR. + #[picus(selector)] pub is_nor: T, /// If the opcode is XOR. + #[picus(selector)] pub is_xor: T, // If the opcode is OR. + #[picus(selector)] pub is_or: T, /// If the opcode is AND. + #[picus(selector)] pub is_and: T, } diff --git a/crates/core/machine/src/alu/clo_clz/mod.rs b/crates/core/machine/src/alu/clo_clz/mod.rs index 10a0e3f2..7fe39899 100644 --- a/crates/core/machine/src/alu/clo_clz/mod.rs +++ b/crates/core/machine/src/alu/clo_clz/mod.rs @@ -22,8 +22,8 @@ use zkm_core_executor::{ events::{ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; -use zkm_stark::{air::MachineAir, Word}; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; +use zkm_stark::{air::MachineAir, PicusInfo, Word}; use crate::{air::ZKMCoreAirBuilder, utils::pad_rows_fixed, CoreChipError}; @@ -39,7 +39,7 @@ const BYTE_SIZE: usize = 8; pub struct CloClzChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct CloClzCols { /// The current/next pc, used for instruction lookup table. @@ -63,9 +63,11 @@ pub struct CloClzCols { pub sr1: Word, /// Flag to indicate whether the opcode is CLZ. + #[picus(selector)] pub is_clz: T, /// Flag to indicate whether the opcode is CLO. + #[picus(selector)] pub is_clo: T, /// Selector to know whether this row is enabled. @@ -83,6 +85,10 @@ impl MachineAir for CloClzChip { "CloClz".to_string() } + fn picus_info(&self) -> PicusInfo { + CloClzCols::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, diff --git a/crates/core/machine/src/alu/divrem/mod.rs b/crates/core/machine/src/alu/divrem/mod.rs index 43476fb4..9e0d0da6 100644 --- a/crates/core/machine/src/alu/divrem/mod.rs +++ b/crates/core/machine/src/alu/divrem/mod.rs @@ -76,9 +76,12 @@ use zkm_core_executor::{ }; use crate::{memory::MemoryReadWriteCols, CoreChipError}; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_primitives::consts::WORD_SIZE; -use zkm_stark::{air::MachineAir, Word}; +use zkm_stark::{ + air::{MachineAir, PicusInfo}, + Word, +}; use crate::{ air::{WordAirBuilder, ZKMCoreAirBuilder}, @@ -101,7 +104,7 @@ const LONG_WORD_SIZE: usize = 2 * WORD_SIZE; pub struct DivRemChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct DivRemCols { /// The current/next pc, used for instruction lookup table. @@ -139,15 +142,19 @@ pub struct DivRemCols { pub is_c_0: IsZeroWordOperation, /// Flag to indicate whether the opcode is DIV. + #[picus(selector)] pub is_div: T, /// Flag to indicate whether the opcode is DIVU. + #[picus(selector)] pub is_divu: T, /// Flag to indicate whether the opcode is MOD. + #[picus(selector)] pub is_mod: T, /// Flag to indicate whether the opcode is MODU. + #[picus(selector)] pub is_modu: T, /// Flag to indicate whether the division operation overflows. diff --git a/crates/core/machine/src/alu/lt/mod.rs b/crates/core/machine/src/alu/lt/mod.rs index 0bc07f21..718a7372 100644 --- a/crates/core/machine/src/alu/lt/mod.rs +++ b/crates/core/machine/src/alu/lt/mod.rs @@ -13,10 +13,10 @@ use zkm_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_stark::{ air::{BaseAirBuilder, MachineAir, ZKMAirBuilder}, - Word, + PicusInfo, Word, }; use crate::{ @@ -32,7 +32,7 @@ pub const NUM_LT_COLS: usize = size_of::>(); pub struct LtChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Clone, Copy)] #[repr(C)] pub struct LtCols { /// The current/next pc, used for instruction lookup table. @@ -40,9 +40,11 @@ pub struct LtCols { pub next_pc: T, /// If the opcode is SLT. + #[picus(selector)] pub is_slt: T, /// If the opcode is SLTU. + #[picus(selector)] pub is_sltu: T, /// The output operand. @@ -104,6 +106,10 @@ impl MachineAir for LtChip { "Lt".to_string() } + fn picus_info(&self) -> PicusInfo { + LtCols::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, diff --git a/crates/core/machine/src/alu/mul/mod.rs b/crates/core/machine/src/alu/mul/mod.rs index 1f882647..b7e51414 100644 --- a/crates/core/machine/src/alu/mul/mod.rs +++ b/crates/core/machine/src/alu/mul/mod.rs @@ -44,9 +44,9 @@ use zkm_core_executor::{ events::{ByteLookupEvent, ByteRecord, CompAluEvent, MemoryAccessPosition, MemoryRecordEnum}, ByteOpcode, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_primitives::consts::WORD_SIZE; -use zkm_stark::{air::MachineAir, Word}; +use zkm_stark::{air::MachineAir, PicusInfo, Word}; use crate::{ air::{WordAirBuilder, ZKMCoreAirBuilder}, @@ -74,10 +74,11 @@ const BYTE_MASK: u8 = 0xff; pub struct MulChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct MulCols { /// The current/next pc, used for instruction lookup table. + #[picus(input)] pub pc: T, pub next_pc: T, @@ -112,15 +113,17 @@ pub struct MulCols { pub c_sign_extend: T, /// Flag indicating whether the opcode is `MUL`. + #[picus(selector)] pub is_mul: T, /// Flag indicating whether the opcode is `MULT`. + #[picus(selector)] pub is_mult: T, /// Flag indicating whether the opcode is `MULTU`. + #[picus(selector)] pub is_multu: T, - /// Selector to know whether this row is enabled. pub is_real: T, /// Access to hi register @@ -146,6 +149,10 @@ impl MachineAir for MulChip { "Mul".to_string() } + fn picus_info(&self) -> PicusInfo { + MulCols::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, @@ -395,11 +402,11 @@ where let product = { for i in 0..PRODUCT_SIZE { if i == 0 { - builder.assert_eq(local.product[i], m[i].clone() - local.carry[i] * base); + builder.assert_eq(m[i].clone(), local.carry[i] * base + local.product[i]); } else { builder.assert_eq( - local.product[i], - m[i].clone() + local.carry[i - 1] - local.carry[i] * base, + local.product[i] + local.carry[i] * base - local.carry[i - 1], + m[i].clone(), ); } } diff --git a/crates/core/machine/src/alu/sll/mod.rs b/crates/core/machine/src/alu/sll/mod.rs index 46b5b2a5..006d6aa2 100644 --- a/crates/core/machine/src/alu/sll/mod.rs +++ b/crates/core/machine/src/alu/sll/mod.rs @@ -45,9 +45,9 @@ use zkm_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_primitives::consts::WORD_SIZE; -use zkm_stark::{air::MachineAir, Word}; +use zkm_stark::{air::MachineAir, PicusInfo, Word}; use crate::{air::ZKMCoreAirBuilder, utils::pad_rows_fixed, CoreChipError}; @@ -62,7 +62,7 @@ pub const BYTE_SIZE: usize = 8; pub struct ShiftLeft; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct ShiftLeftCols { /// The current/next pc, used for instruction lookup table. @@ -96,6 +96,7 @@ pub struct ShiftLeftCols { /// A boolean array whose `i`th element indicates whether `num_bytes_to_shift = i`. pub shift_by_n_bytes: [T; WORD_SIZE], + #[picus(selector)] pub is_real: T, } @@ -110,6 +111,10 @@ impl MachineAir for ShiftLeft { "ShiftLeft".to_string() } + fn picus_info(&self) -> PicusInfo { + ShiftLeftCols::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, diff --git a/crates/core/machine/src/alu/sr/mod.rs b/crates/core/machine/src/alu/sr/mod.rs index f7c32e03..5f67010d 100644 --- a/crates/core/machine/src/alu/sr/mod.rs +++ b/crates/core/machine/src/alu/sr/mod.rs @@ -57,9 +57,9 @@ use zkm_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, }; -use zkm_derive::AlignedBorrow; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; use zkm_primitives::consts::WORD_SIZE; -use zkm_stark::{air::MachineAir, Word}; +use zkm_stark::{air::MachineAir, PicusInfo, Word}; use crate::{ air::ZKMCoreAirBuilder, @@ -83,7 +83,7 @@ const BYTE_SIZE: usize = 8; pub struct ShiftRightChip; /// The column layout for the chip. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct ShiftRightCols { /// The current/next pc, used for instruction lookup table. @@ -124,12 +124,15 @@ pub struct ShiftRightCols { pub c_least_sig_byte: [T; BYTE_SIZE], /// If the opcode is SRL. + #[picus(selector)] pub is_srl: T, /// If the opcode is ROR. + #[picus(selector)] pub is_ror: T, /// If the opcode is SRA. + #[picus(selector)] pub is_sra: T, /// Selector to know whether this row is enabled. @@ -147,6 +150,10 @@ impl MachineAir for ShiftRightChip { "ShiftRight".to_string() } + fn picus_info(&self) -> PicusInfo { + ShiftRightCols::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, diff --git a/crates/core/machine/src/control_flow/branch/columns.rs b/crates/core/machine/src/control_flow/branch/columns.rs index 2207f6a9..00d6d93a 100644 --- a/crates/core/machine/src/control_flow/branch/columns.rs +++ b/crates/core/machine/src/control_flow/branch/columns.rs @@ -1,13 +1,13 @@ use std::mem::size_of; -use zkm_derive::AlignedBorrow; -use zkm_stark::Word; +use zkm_derive::{AlignedBorrow, PicusAnnotations}; +use zkm_stark::{PicusInfo, Word}; use crate::operations::KoalaBearWordRangeChecker; pub const NUM_BRANCH_COLS: usize = size_of::>(); /// The column layout for branching. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, PicusAnnotations, Default, Debug, Clone, Copy)] #[repr(C)] pub struct BranchColumns { /// The current program counter. @@ -36,11 +36,17 @@ pub struct BranchColumns { pub op_c_value: Word, /// Branch Instructions Selectors. + #[picus(selector)] pub is_beq: T, + #[picus(selector)] pub is_bne: T, + #[picus(selector)] pub is_bltz: T, + #[picus(selector)] pub is_blez: T, + #[picus(selector)] pub is_bgtz: T, + #[picus(selector)] pub is_bgez: T, /// The branching column is equal to: diff --git a/crates/core/machine/src/control_flow/branch/trace.rs b/crates/core/machine/src/control_flow/branch/trace.rs index 7985af23..0823f882 100644 --- a/crates/core/machine/src/control_flow/branch/trace.rs +++ b/crates/core/machine/src/control_flow/branch/trace.rs @@ -9,7 +9,7 @@ use zkm_core_executor::{ events::{BranchEvent, ByteLookupEvent, ByteRecord}, ExecutionRecord, Opcode, Program, }; -use zkm_stark::{air::MachineAir, Word}; +use zkm_stark::{air::MachineAir, PicusInfo, Word}; use crate::{ utils::{next_power_of_two, zeroed_f_vec}, @@ -29,6 +29,10 @@ impl MachineAir for BranchChip { "Branch".to_string() } + fn picus_info(&self) -> PicusInfo { + BranchColumns::::picus_info() + } + fn generate_trace( &self, input: &ExecutionRecord, diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index fa6a5343..f086c780 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -19,7 +19,7 @@ use zkm_core_executor::{ }; use zkm_curves::weierstrass::{bls12_381::Bls12381BaseField, bn254::Bn254BaseField}; use zkm_stark::{ - air::{LookupScope, MachineAir, ZKM_PROOF_NUM_PV_ELTS}, + air::{LookupScope, MachineAir, PicusInfo, ZKM_PROOF_NUM_PV_ELTS}, Chip, LookupKind, StarkGenericConfig, StarkMachine, }; diff --git a/crates/derive/Cargo.toml b/crates/derive/Cargo.toml index 3baacb85..cbec4277 100644 --- a/crates/derive/Cargo.toml +++ b/crates/derive/Cargo.toml @@ -14,4 +14,5 @@ proc-macro = true [dependencies] quote = "1.0" -syn = { version = "1.0", features = ["full"] } +proc-macro2 = "1" +syn = { version = "1.0", features = ["full"] } \ No newline at end of file diff --git a/crates/derive/src/lib.rs b/crates/derive/src/lib.rs index afb7197d..c70b5bd1 100644 --- a/crates/derive/src/lib.rs +++ b/crates/derive/src/lib.rs @@ -23,7 +23,7 @@ // THE SOFTWARE. extern crate proc_macro; - +mod picus_annotations; use proc_macro::TokenStream; use quote::quote; use syn::{ @@ -206,6 +206,14 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } }); + // Calls the underlying chip's `picus_info()` method + let picus_info_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as zkm_stark::air::MachineAir>::picus_info(x) + } + }); + let machine_air = quote! { impl #impl_generics zkm_stark::air::MachineAir for #name #ty_generics #where_clause { type Record = #execution_record_path; @@ -272,6 +280,12 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { #(#local_only_arms,)* } } + + fn picus_info(&self) -> PicusInfo { + match self { + #(#picus_info_arms,)* + } + } } }; @@ -415,3 +429,8 @@ fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option { None } + +#[proc_macro_derive(PicusAnnotations, attributes(picus))] +pub fn picus_annotations_derive(input: TokenStream) -> TokenStream { + picus_annotations::picus_annotations_derive(input) +} diff --git a/crates/derive/src/picus_annotations.rs b/crates/derive/src/picus_annotations.rs new file mode 100644 index 00000000..c5d19669 --- /dev/null +++ b/crates/derive/src/picus_annotations.rs @@ -0,0 +1,271 @@ +use std::collections::HashSet; + +use proc_macro::TokenStream; +use quote::quote; +use syn::Generics; +use syn::{parse_macro_input, DeriveInput}; + +use syn::{ + parse::{Parse, ParseStream}, + parse_quote, + punctuated::Punctuated, + GenericArgument, Path, PathArguments, Result, Token, Type, TypeArray, TypeReference, TypeSlice, +}; + +#[derive(Default, Debug, Clone)] +struct PicusArgs { + input: bool, + output: bool, + selector: bool, +} + +enum Arg { + Input, + Output, + Selector, +} + +impl Parse for Arg { + // parses the arguments for the picus attribute + fn parse(input: ParseStream<'_>) -> Result { + let key: Path = input.parse()?; + + let is = |s: &str| key.is_ident(s); + + if is("input") { + return Ok(Arg::Input); + } + if is("output") { + return Ok(Arg::Output); + } + if is("selector") { + return Ok(Arg::Selector); + } + + Err(syn::Error::new_spanned(key, "unknown key in #[picus(...)]")) + } +} + +fn parse_picus_attr(attr: &syn::Attribute) -> syn::Result> { + // check that the attribute is a picus attribute + if !attr.path.is_ident("picus") { + return Ok(None); + } + // parse the attributes + let items = attr.parse_args_with(Punctuated::::parse_terminated)?; + let mut out = PicusArgs::default(); + for it in items { + match it { + Arg::Input => out.input = true, + Arg::Output => out.output = true, + Arg::Selector => out.selector = true, + } + } + Ok(Some(out)) +} + +// ---------- type substitution: replace *type* params with `u8` ---------- +fn type_params_set(gens: &Generics) -> HashSet { + gens.type_params().map(|tp| tp.ident.clone()).collect() +} + +// column values are determined by computing the offset of the ColStruct when instantiated +// with the u8 parameter. This utility substitutes a type parameter with `u8` so we can calculate offsets. +fn ty_sub_u8(mut ty: Type, type_params: &HashSet) -> Type { + match ty { + Type::Path(ref mut tp) => { + if tp.qself.is_none() && tp.path.segments.len() == 1 { + let seg = &tp.path.segments[0]; + if type_params.contains(&seg.ident) { + return parse_quote!(u8); + } + } + for seg in tp.path.segments.iter_mut() { + if let PathArguments::AngleBracketed(ref mut ab) = seg.arguments { + for arg in ab.args.iter_mut() { + if let GenericArgument::Type(ref mut inner) = arg { + *inner = ty_sub_u8(inner.clone(), type_params); + } + } + } + } + ty + } + Type::Reference(TypeReference { ref mut elem, .. }) => { + **elem = ty_sub_u8((**elem).clone(), type_params); + ty + } + Type::Array(TypeArray { ref mut elem, .. }) + | Type::Slice(TypeSlice { ref mut elem, .. }) => { + **elem = ty_sub_u8((**elem).clone(), type_params); + ty + } + Type::Tuple(ref mut tup) => { + for el in tup.elems.iter_mut() { + *el = ty_sub_u8(el.clone(), type_params); + } + ty + } + _ => ty, + } +} + +// Build Self actual type args; keep lifetimes/consts as-is. +fn concrete_type_args(gens: &Generics) -> proc_macro2::TokenStream { + let args = gens.params.iter().map(|p| match p { + syn::GenericParam::Type(_) => quote!(u8), + syn::GenericParam::Lifetime(lt) => { + let lt = <.lifetime; + quote!(#lt) + } + syn::GenericParam::Const(c) => { + let id = &c.ident; + quote!(#id) + } + }); + quote!(<#(#args),*>) +} + +// impl generics = lifetimes + consts only (type params fixed to u8) +fn impl_generics_without_type_params(gens: &Generics) -> proc_macro2::TokenStream { + let lifetimes = gens.lifetimes().map(|d| d.lifetime.clone()); + let consts = gens.const_params().map(|c| { + let id = &c.ident; + let ty = &c.ty; + quote!(const #id: #ty) + }); + let mut parts: Vec = Vec::new(); + for lt in lifetimes { + parts.push(quote!(#lt)); + } + for c in consts { + parts.push(c); + } + if parts.is_empty() { + quote!() + } else { + quote!(< #(#parts),* >) + } +} + +pub fn picus_annotations_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let ident = input.ident.clone(); + let gens = input.generics.clone(); + + let data = match &input.data { + syn::Data::Struct(s) => s, + _ => { + return syn::Error::new_spanned(&input, "PicusInfoGen only supports structs") + .to_compile_error() + .into() + } + }; + let fields = match &data.fields { + syn::Fields::Named(f) => &f.named, + _ => { + return syn::Error::new_spanned(&input, "PicusInfoGen requires named fields") + .to_compile_error() + .into() + } + }; + + let type_params = type_params_set(&gens); + let impl_gens = impl_generics_without_type_params(&gens); + let self_args = concrete_type_args(&gens); + let self_conc = quote!(#ident #self_args); + + // Per-field code + let mut steps = Vec::new(); + for field in fields.iter() { + let f_ident = field.ident.as_ref().unwrap(); + let f_name = f_ident.to_string(); + // Collect flags + let mut flags = PicusArgs::default(); + for attr in &field.attrs { + if attr.path.is_ident("picus") { + match parse_picus_attr(attr) { + Ok(Some(a)) => { + flags.input |= a.input; + flags.output |= a.output; + flags.selector |= a.selector; + } + Ok(None) => {} + Err(e) => return e.to_compile_error().into(), + } + } + } + + // Field type with all *type* params → u8 + let conc_ty: Type = ty_sub_u8(field.ty.clone(), &type_params); + + // Add name to id map + let push_name = { + quote! { + if width > 0 { + info.name_to_colrange.insert((#f_name).to_string(), (cur, cur+width)); + for x in cur..(cur+width) { + info.col_to_name.insert(x, format!("{}_{}", #f_name, x)); + } + } + } + }; + let push_in = if flags.input { + quote! { if width > 0 { info.input_ranges.push((cur, cur + width, #f_name.to_string())); } } + } else { + quote!() + }; + + let push_out = if flags.output { + quote! { if width > 0 { info.output_ranges.push((cur, cur + width, #f_name.to_string())); } } + } else { + quote!() + }; + + let push_sel = if flags.selector { + quote! { + if width > 0 { + for i in 0..width { + info.selector_indices.push((cur + i, #f_name.to_string())); + } + } + + } + } else { + quote!() + }; + // If the field name is "is_real" then add that mark it in PicusInfo + let push_is_real = if f_name == "is_real" { + quote! { + if width > 0 { + info.is_real_index = Some(cur); + } + } + } else { + quote!() + }; + + steps.push(quote! {{ + let width: usize = ::core::mem::size_of::<#conc_ty>(); + #push_name + #push_in + #push_out + #push_sel + #push_is_real + cur += width; + }}); + } + + let expanded = quote! { + // Implement on the concrete instantiation where *type* params are `u8` + impl #impl_gens #self_conc { + pub fn picus_info() -> PicusInfo { + let mut info = PicusInfo::default(); + let mut cur: usize = 0; // 1 column == 1 byte + #(#steps)* + info + } + } + }; + expanded.into() +} diff --git a/crates/picus/Cargo.toml b/crates/picus/Cargo.toml new file mode 100644 index 00000000..2cbd865e --- /dev/null +++ b/crates/picus/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "zkm-picus" +version = { workspace = true } +edition = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +keywords = { workspace = true } +categories = { workspace = true } + +[dependencies] +zkm-core-machine = { workspace = true } +zkm-core-executor = {workspace = true} +zkm-stark = {workspace = true} +p3-air = { workspace = true } +p3-matrix = { workspace = true } +p3-field = {workspace = true} +p3-koala-bear = {workspace = true} + +clap = { version = "4.5.7", features = ["derive", "env"] } diff --git a/crates/picus/README.md b/crates/picus/README.md new file mode 100644 index 00000000..389a4992 --- /dev/null +++ b/crates/picus/README.md @@ -0,0 +1,17 @@ +# Usage Instructions +This document describes how to use the Picus translator +## Build +From the `picus/` directory just run: + +``` +cargo build +``` + +## Run +To extract the AddSub chip run the following command from the top level directory: + +``` +./target/debug/zkm-picus --chip AddSub +``` + +This will produce a file called `AddSub.picus` inside of the directory `picus_output`. The directory can be overriden by setting the environment variable `PICUS_OUT_DIR` \ No newline at end of file diff --git a/crates/picus/src/lib.rs b/crates/picus/src/lib.rs new file mode 100644 index 00000000..df976d8f --- /dev/null +++ b/crates/picus/src/lib.rs @@ -0,0 +1,5 @@ +pub mod opcode_spec; +pub mod pcl; +pub mod picus_builder; + +use pcl::*; diff --git a/crates/picus/src/main.rs b/crates/picus/src/main.rs new file mode 100644 index 00000000..5ef6ba46 --- /dev/null +++ b/crates/picus/src/main.rs @@ -0,0 +1,171 @@ +use std::{collections::BTreeMap, path::PathBuf}; + +use clap::{Parser, ValueHint}; +use p3_air::{Air, BaseAir}; +use zkm_core_machine::MipsAir; +use zkm_picus::{ + pcl::{ + initialize_fresh_var_ctr, set_field_modulus, set_picus_names, Felt, PicusModule, + PicusProgram, + }, + picus_builder::PicusBuilder, +}; +use zkm_stark::{Chip, MachineAir}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(long, help = "Chip name to compile")] + pub chip: Option, + + /// Directory to write the extracted Picus program(s). + /// + /// Can be overridden with PICUS_OUT_DIR. + #[arg( + long = "picus-out-dir", + value_name = "DIR", + value_hint = ValueHint::DirPath, + env = "PICUS_OUT_DIR", + default_value = "picus_out" + )] + + /// Directory to write the extracted Picus program(s). + /// + /// Can be overridden with PICUS_OUT_DIR. + pub picus_out_dir: PathBuf, +} + +/// Analyze a single chip and process all its deferred sub-chip tasks. +/// This replaces direct recursion in `MessageBuilder::send()`. +fn analyze_chip<'chips, A>( + chip: &'chips Chip, + chips: &'chips [Chip], + picus_builder: Option<&mut PicusBuilder<'chips, A>>, +) -> (PicusModule, BTreeMap) +where + A: MachineAir + BaseAir + Air>, +{ + println!("Analyzing chip: {}", chip.name()); + + let builder = if let Some(builder) = picus_builder { + builder + } else { + &mut PicusBuilder::new(chip, PicusModule::new(chip.name()), chips, None, None) + }; + chip.air.eval(builder); + + // Process deferred tasks recursively + while let Some(task) = builder.concrete_pending_tasks.pop() { + let target_chip = builder.get_chip(&task.chip_name); + println!("Target chip: {:?}", &task.chip_name); + let target_picus_info = target_chip.picus_info(); + + let mut sub_builder = PicusBuilder::new( + target_chip, + PicusModule::new(task.chip_name.clone()), + builder.chips, + Some(task.main_vars.clone()), + Some(task.multiplicity.clone()), + ); + + let (mut sub_module, aux_modules) = + analyze_chip(target_chip, builder.chips, Some(&mut sub_builder)); + // Merge submodules + builder.aux_modules.extend(aux_modules.into_iter()); + + sub_module.apply_multiplier(task.multiplicity); + // partially evaluate + + let selector_col = target_picus_info.name_to_colrange.get(&task.selector).unwrap().0; + let mut env = BTreeMap::new(); + // Set `is_real = 1` if it is set in `picus_info` + if let Some(id) = target_picus_info.is_real_index { + env.insert(id, 1); + } + env.insert(selector_col, 1); + for (other_selector_col, _) in &target_picus_info.selector_indices { + if selector_col == *other_selector_col { + continue; + } + env.insert(*other_selector_col, 0); + } + let updated_picus_module = sub_module.partial_eval(&env); + println!("Updated module: {updated_picus_module}"); + builder.picus_module.constraints.extend_from_slice(&updated_picus_module.constraints); + builder.picus_module.calls.extend_from_slice(&updated_picus_module.calls); + builder.picus_module.postconditions.extend_from_slice(&sub_module.postconditions); + } + + (builder.picus_module.clone(), builder.aux_modules.clone()) +} + +fn main() { + let args = Args::parse(); + + if args.chip.is_none() { + panic!("Chip name must be provided!"); + } + + let chip_name = args.chip.unwrap(); + let chips = MipsAir::::chips(); + + // Get the chip + let chip = chips + .iter() + .find(|c| c.name() == chip_name) + .unwrap_or_else(|| panic!("No chip found named {}", chip_name.clone())); + // get the picus info for the chip + let picus_info = chip.picus_info(); + // set the var -> readable name mapping + set_picus_names(picus_info.col_to_name.clone()); + // set base col number for creating fresh values + initialize_fresh_var_ctr(chip.width() + 1); + + // Set the field modulus for the Picus program: + let koala_prime = 0x7f000001; + let _ = set_field_modulus(koala_prime); + + // Initialize the Picus program + let mut picus_program = PicusProgram::new(koala_prime); + + // Build the Picus program which will have a single module with the chip constraints + println!("Generating Picus program for {} chip.....", chip.name()); + let (picus_module, mut aux_modules) = analyze_chip(chip, &chips, None); + picus_program.add_modules(&mut aux_modules); + // At this point, we've built a module directly from the constraints. However, this isn't super amenable to verification + // because the selectors introduce a lot of nonlinearity. So what we do instead is generate distinct Picus modules + // each of which correspond to a selector being enabled. The selectors are mutually exclusive. + let mut selector_modules = BTreeMap::new(); + + if picus_info.selector_indices.is_empty() { + panic!("PicusBuilder needs at least one selector to be enabled!") + } + println!("Applying selectors program....."); + println!("PicusInfo: {:?}", picus_info.clone()); + for (selector_col, _) in &picus_info.selector_indices { + let mut env = BTreeMap::new(); + // Set `is_real = 1` if it is set in `picus_info` + if let Some(id) = picus_info.is_real_index { + env.insert(id, 1); + } + env.insert(*selector_col, 1); + for (other_selector_col, _) in &picus_info.selector_indices { + if selector_col == other_selector_col { + continue; + } + env.insert(*other_selector_col, 0); + } + // We generate a new Picus module by partially evaluating our original Picus module with respect + // to the environment map. + let updated_picus_module = picus_module.partial_eval(&env); + selector_modules.insert(updated_picus_module.name.clone(), updated_picus_module); + } + + picus_program.add_modules(&mut selector_modules); + let res = + picus_program.write_to_path(args.picus_out_dir.join(format!("{}.picus", chip.name()))); + if res.is_err() { + panic!("Failed to write picus file: {res:?}"); + } + println!("Successfully extracted Picus program"); +} diff --git a/crates/picus/src/opcode_spec.rs b/crates/picus/src/opcode_spec.rs new file mode 100644 index 00000000..abf98828 --- /dev/null +++ b/crates/picus/src/opcode_spec.rs @@ -0,0 +1,76 @@ +use zkm_core_executor::Opcode; + +/// Picus specification for the Instruction opcode. +#[derive(Clone, Debug, Default)] +#[allow(dead_code)] +pub struct OpcodeSpec { + /// Selector + pub selector: &'static str, + /// Chip + pub chip: &'static str, + /// Maps the argument to column name in corresponding chip. + pub arg_to_colname: &'static [(IndexSlice, &'static str)], +} + +/// A selection of indices inside `values`. +#[derive(Clone, Copy, Debug)] +#[allow(dead_code)] +pub enum IndexSlice { + /// A continuous half-open range [start, end). If end is `usize::MAX` then + /// it represents [start, ``values.len()``) + Range { start: usize, end: usize }, + /// A single position + Single(usize), +} + +/// The top level function which declares and retrieves the spec for a given opcode. +pub fn spec_for(kind: Opcode) -> OpcodeSpec { + use IndexSlice::*; + match kind { + Opcode::ADD => OpcodeSpec { + selector: "is_add", + chip: "AddSub", + arg_to_colname: &[ + (Single(2), "pc"), + (Single(3), "next_pc"), + (Range { start: 6, end: 10 }, "add_operation"), + (Range { start: 10, end: 14 }, "operand_1"), + (Range { start: 14, end: 18 }, "operand_2"), + ], + }, + Opcode::SUB => OpcodeSpec { + selector: "is_sub", + chip: "AddSub", + arg_to_colname: &[ + (Single(2), "pc"), + (Single(3), "next_pc"), + (Range { start: 6, end: 10 }, "add_operation"), + (Range { start: 10, end: 14 }, "operand_1"), + (Range { start: 14, end: 18 }, "operand_2"), + ], + }, + Opcode::SRL => OpcodeSpec { + selector: "is_srl", + chip: "ShiftRight", + arg_to_colname: &[ + (Single(2), "pc"), + (Single(3), "next_pc"), + (Range { start: 6, end: 10 }, "a"), + (Range { start: 10, end: 14 }, "b"), + (Range { start: 14, end: 18 }, "c"), + ], + }, + Opcode::SLT => OpcodeSpec { + selector: "is_slt", + chip: "Lt", + arg_to_colname: &[ + (Single(2), "pc"), + (Single(3), "next_pc"), + (Range { start: 6, end: 10 }, "a"), + (Range { start: 10, end: 14 }, "b"), + (Range { start: 14, end: 18 }, "c"), + ], + }, + _ => panic!("Unimplemented opcode {kind:#?}"), + } +} diff --git a/crates/picus/src/pcl/expr.rs b/crates/picus/src/pcl/expr.rs new file mode 100644 index 00000000..d54b8532 --- /dev/null +++ b/crates/picus/src/pcl/expr.rs @@ -0,0 +1,655 @@ +use std::{ + collections::HashMap, + fmt::{self, Display, Formatter}, + iter::{Product, Sum}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, OnceLock, RwLock, + }, +}; + +/// Mapping from column ids to variable names. This mapping should be derived in the `PicusInfo` struct +static PICUS_NAMES_GLOBAL: OnceLock>> = OnceLock::new(); + +/// Maintains col indices for fresh variables during the course of extraction +static FRESH_VAR_CTR: OnceLock = OnceLock::new(); +pub fn set_picus_names(map: HashMap) { + let _ = PICUS_NAMES_GLOBAL.set(RwLock::new(map)); +} + +// Get or initialize the fresh var counter +fn ctr() -> &'static AtomicUsize { + FRESH_VAR_CTR.get_or_init(|| AtomicUsize::new(0)) +} + +// set the fresh counter val to something +pub fn initialize_fresh_var_ctr(val: usize) { + let _ = FRESH_VAR_CTR.set(AtomicUsize::new(val)); +} + +pub fn fresh_picus_var_id() -> usize { + let cur_var = ctr().load(Ordering::Relaxed); + ctr().store(cur_var + 1, Ordering::Relaxed); + cur_var +} + +pub fn fresh_picus_var() -> PicusAtom { + PicusAtom::new_var(fresh_picus_var_id()) +} + +// update the counter +pub fn fresh_picus_expr() -> PicusExpr { + PicusExpr::Var(fresh_picus_var_id()) +} + +use p3_field::{FieldAlgebra, PrimeField32}; + +/// Global, thread-safe holder for the PCL prime field modulus. +/// +/// This is initialized exactly once via [`set_field_modulus`]. Arithmetic +/// that combines only constants will be reduced modulo this value when set. +static FIELD_MODULUS: OnceLock> = OnceLock::new(); +pub type Felt = p3_koala_bear::KoalaBear; + +/// Sets the field modulus for PCL +pub fn set_field_modulus(p: u64) -> Result<(), u64> { + // set only once; returns Err(p) if already set + FIELD_MODULUS.set(Arc::new(p)).map_err(|arc| Arc::try_unwrap(arc).unwrap_or_else(|a| *a)) +} + +/// Get PCL field modulus +pub fn current_modulus() -> Option { + FIELD_MODULUS.get().map(|a| **a) +} + +/// Given an integer reduce it into the field +pub fn reduce_mod(c: i64) -> u64 { + if let Some(p) = current_modulus() { + (c % (p as i64)) as u64 + } else { + c as u64 + } +} + +/// Arithmetic expressions over the Picus constraint language (PCL). +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub enum PicusExpr { + /// Constant field element. We use a `u64` to be safe because the prime is 31 bits and we don't want to deal with + /// underflows or overflows + Const(u64), + /// Variable identified by `(name, index, tag)`, printed as `name_index_tag`. NOTE: Tag might + /// be droppable + Var(usize), + /// Add. + Add(Box, Box), + /// Sub. + Sub(Box, Box), + /// Mul + Mul(Box, Box), + /// Div (probably can delete) + Div(Box, Box), + /// Unary negation. + Neg(Box), + /// Exponentiation + Pow(u64, Box), +} + +impl Default for PicusExpr { + fn default() -> Self { + PicusExpr::Const(0) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub enum PicusAtom { + Const(u64), + Var(usize), +} + +impl PicusAtom { + pub fn new_var(id: usize) -> Self { + Self::Var(id) + } +} + +impl Display for PicusAtom { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Const(c) => write!(f, "{c}"), + Self::Var(id) => { + if let Some(lock) = PICUS_NAMES_GLOBAL.get() { + if let Some(name) = lock.read().unwrap().get(id) { + return f.write_str(name); + } + } + write!(f, "v{id}") + } + } + } +} + +impl From for PicusExpr { + fn from(value: PicusAtom) -> Self { + match value { + PicusAtom::Const(c) => PicusExpr::Const(c), + PicusAtom::Var(id) => PicusExpr::Var(id), + } + } +} + +impl From for PicusExpr { + fn from(value: Felt) -> Self { + PicusExpr::Const(value.as_canonical_u32().into()) + } +} + +impl Add for PicusAtom { + type Output = PicusExpr; + + fn add(self, rhs: Felt) -> Self::Output { + PrimeField32::as_canonical_u32(&rhs).into() + } +} + +impl Add for PicusAtom { + type Output = PicusExpr; + + fn add(self, rhs: PicusAtom) -> Self::Output { + PicusExpr::Add(Box::new(self.into()), Box::new(rhs.into())) + } +} + +impl Add for PicusAtom { + type Output = PicusExpr; + + fn add(self, rhs: PicusExpr) -> Self::Output { + let left_expr: PicusExpr = self.into(); + left_expr + rhs + } +} + +impl Sub for PicusAtom { + type Output = PicusExpr; + + fn sub(self, rhs: Felt) -> Self::Output { + let self_expr: PicusExpr = self.into(); + self_expr - rhs + } +} + +impl Sub for PicusAtom { + type Output = PicusExpr; + + fn sub(self, rhs: PicusAtom) -> Self::Output { + let self_expr: PicusExpr = self.into(); + let rhs_expr: PicusExpr = rhs.into(); + self_expr - rhs_expr + } +} + +impl Sub for PicusAtom { + type Output = PicusExpr; + + fn sub(self, rhs: PicusExpr) -> Self::Output { + let self_expr: PicusExpr = self.into(); + self_expr - rhs + } +} + +impl Mul for PicusAtom { + type Output = PicusExpr; + + fn mul(self, rhs: PicusAtom) -> Self::Output { + let self_expr: PicusExpr = self.into(); + let rhs_expr: PicusExpr = rhs.into(); + self_expr * rhs_expr + } +} + +impl Mul for PicusAtom { + type Output = PicusExpr; + + fn mul(self, rhs: Felt) -> Self::Output { + let self_expr: PicusExpr = self.into(); + self_expr * rhs + } +} + +impl Mul for PicusAtom { + type Output = PicusExpr; + + fn mul(self, rhs: PicusExpr) -> Self::Output { + let self_expr: PicusExpr = self.into(); + self_expr * rhs + } +} + +impl Sum for PicusExpr { + fn sum>(iter: I) -> Self { + let mut output: PicusExpr = 0.into(); + for item in iter { + output += item; + } + output + } +} + +impl Product for PicusExpr { + fn product>(iter: I) -> Self { + let mut output: PicusExpr = 1.into(); + for item in iter { + output *= item; + } + output + } +} + +impl PicusExpr { + /// Approximate tree size (number of nodes). + /// + /// Useful as a heuristic for introducing temporary variables (e.g., to keep + /// expressions small for solvers). `Pow` is counted as 1 by design. + #[must_use] + pub fn size(&self) -> usize { + match self { + Self::Const(_) | Self::Var(_) | Self::Pow(_, _) => 1, + Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::Div(a, b) => { + 1 + a.size() + b.size() + } + Self::Neg(a) => 1 + a.size(), + } + } + /// Helper to construct a `Var` with a column index. + pub fn var(idx: usize) -> Self { + PicusExpr::Var(idx) + } + #[must_use] + /// Convenience for exponentiating by a non-negative `u32` power. + pub fn pow(self, k: u32) -> Self { + PicusExpr::Pow(k.into(), Box::new(self)) + } + /// Returns `true` iff this is exactly the constant zero. + #[inline] + #[must_use] + pub fn is_const_zero(&self) -> bool { + matches!(self, PicusExpr::Const(c) if *c == 0) + } +} + +macro_rules! impl_from_ints { + ($($t:ty),* $(,)?) => {$( + impl From<$t> for PicusExpr { + fn from(v: $t) -> Self { + PicusExpr::Const(v as u64) + } + } + )*} +} + +impl_from_ints!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize); + +/// Pointwise addition with light constant folding. +/// +/// - If both sides are constant, the sum is reduced modulo the current field (if set). +/// - Adding zero returns the other side. +/// - Otherwise, constructs `Add(lhs, rhs)`. +impl Add for PicusExpr { + type Output = PicusExpr; + fn add(self, rhs: PicusExpr) -> Self::Output { + let lhs = self.clone(); + match (lhs.clone(), rhs.clone()) { + (PicusExpr::Const(c_1), PicusExpr::Const(c_2)) => { + (reduce_mod((c_1 + c_2) as i64)).into() + } + (PicusExpr::Const(c), _) => { + if c == 0 { + rhs + } else { + PicusExpr::Add(Box::new(lhs), Box::new(rhs)) + } + } + (_, PicusExpr::Const(c)) => { + if c == 0 { + lhs + } else { + PicusExpr::Add(Box::new(lhs), Box::new(rhs)) + } + } + _ => PicusExpr::Add(Box::new(lhs), Box::new(rhs)), + } + } +} + +impl Add for PicusExpr { + type Output = PicusExpr; + + fn add(self, rhs: Felt) -> Self::Output { + let rhs_expr: Self = rhs.into(); + self + rhs_expr + } +} + +impl Add for PicusExpr { + type Output = PicusExpr; + + fn add(self, rhs: PicusAtom) -> Self::Output { + let rhs_expr: Self = rhs.into(); + self + rhs_expr + } +} + +impl AddAssign for PicusExpr { + fn add_assign(&mut self, rhs: PicusExpr) { + *self = self.clone() + rhs; + } +} + +/// Pointwise subtraction with light constant folding. +/// +/// - If both sides are constant, the difference is reduced modulo the current field (if set). +/// - Subtracting zero returns the left-hand side. +/// - Otherwise, constructs `Sub(lhs, rhs)`. +impl Sub for PicusExpr { + type Output = PicusExpr; + fn sub(self, rhs: PicusExpr) -> Self::Output { + let lhs = self.clone(); + match (lhs.clone(), rhs.clone()) { + (PicusExpr::Const(c_1), PicusExpr::Const(c_2)) => { + reduce_mod((c_1 as i64) - (c_2 as i64)).into() + } + (_, PicusExpr::Const(c)) => { + if c == 0 { + lhs + } else { + PicusExpr::Sub(Box::new(self), Box::new(rhs)) + } + } + _ => PicusExpr::Sub(Box::new(self), Box::new(rhs)), + } + } +} + +impl Sub for PicusExpr { + type Output = PicusExpr; + + fn sub(self, rhs: Felt) -> Self::Output { + let rhs_expr: Self = rhs.into(); + self - rhs_expr + } +} + +impl Sub for PicusExpr { + type Output = PicusExpr; + + fn sub(self, rhs: PicusAtom) -> Self::Output { + let rhs_expr: Self = rhs.into(); + self - rhs_expr + } +} + +impl SubAssign for PicusExpr { + fn sub_assign(&mut self, rhs: PicusExpr) { + *self = self.clone() - rhs; + } +} + +/// Unary negation with constant folding. +/// +/// - If the input is a constant, returns the additive inverse reduced modulo the current field (if +/// set). Otherwise constructs `Neg`. +impl Neg for PicusExpr { + type Output = PicusExpr; + fn neg(self) -> Self::Output { + let lhs = self.clone(); + match lhs.clone() { + PicusExpr::Const(c) => reduce_mod((current_modulus().unwrap() - c) as i64).into(), + _ => PicusExpr::Neg(Box::new(lhs)), + } + } +} + +/// Pointwise multiplication with light constant folding and scalar routing. +/// +/// - If either side is a constant, routes to the `(PicusExpr * Integer)` impl to share logic. +/// - Otherwise constructs `Mul(lhs, rhs)`. +impl Mul for PicusExpr { + type Output = PicusExpr; + fn mul(self, rhs: PicusExpr) -> Self::Output { + let lhs = self.clone(); + match (lhs.clone(), rhs.clone()) { + (PicusExpr::Const(c), _) => rhs * c, + (_, PicusExpr::Const(c)) => lhs * c, + _ => PicusExpr::Mul(Box::new(lhs), Box::new(rhs)), + } + } +} + +impl Mul for PicusExpr { + type Output = PicusExpr; + + fn mul(self, rhs: Felt) -> Self::Output { + let rhs_expr: PicusExpr = rhs.into(); + self * rhs_expr + } +} + +impl Mul for PicusExpr { + type Output = PicusExpr; + + fn mul(self, rhs: PicusAtom) -> Self::Output { + let rhs_expr: PicusExpr = rhs.into(); + self * rhs_expr + } +} + +impl MulAssign for PicusExpr { + fn mul_assign(&mut self, rhs: PicusExpr) { + *self = self.clone() * rhs; + } +} + +/// Scalar multiplication with constant folding. +/// +/// - Multiplying by `0` yields `0`. +/// - Multiplying by `1` yields the original expression. +/// - If the left is also a constant, multiply and reduce modulo the current field (if set). +/// - Otherwise constructs `Mul(lhs, Const(rhs))`. +impl Mul for PicusExpr { + type Output = PicusExpr; + fn mul(self, rhs: u64) -> Self::Output { + if rhs == 0 { + return PicusExpr::Const(0); + } + if rhs == 1 { + return self.clone(); + } + let lhs = self.clone(); + match lhs { + PicusExpr::Const(c_1) => reduce_mod((c_1 * rhs) as i64).into(), + _ => PicusExpr::Mul(Box::new(lhs), Box::new(rhs.into())), + } + } +} + +impl FieldAlgebra for PicusExpr { + type F = Felt; + + const ZERO: Self = PicusExpr::Const(0); + + const ONE: Self = PicusExpr::Const(1); + + const TWO: Self = PicusExpr::Const(2); + + const NEG_ONE: Self = PicusExpr::Const(u64::MAX); + + fn from_f(f: Self::F) -> Self { + f.into() + } + + fn from_bool(b: bool) -> Self { + if b { + PicusExpr::Const(1) + } else { + PicusExpr::Const(0) + } + } + + fn from_canonical_u8(n: u8) -> Self { + n.into() + } + + fn from_canonical_u16(n: u16) -> Self { + n.into() + } + + fn from_canonical_u32(n: u32) -> Self { + n.into() + } + + fn from_canonical_u64(n: u64) -> Self { + n.into() + } + + fn from_canonical_usize(n: usize) -> Self { + n.into() + } + + fn from_wrapped_u32(n: u32) -> Self { + n.into() + } + + fn from_wrapped_u64(n: u64) -> Self { + n.into() + } +} + +/// Boolean/relational constraints over `PicusExpr`. +#[derive(Debug, Clone)] +pub enum PicusConstraint { + /// x < y + Lt(Box, Box), + /// x <= y + Leq(Box, Box), + /// x > y + Gt(Box, Box), + /// x >= y + Geq(Box, Box), + /// p => q + Implies(Box, Box), + /// -p + Not(Box), + /// p <=> q + Iff(Box, Box), + /// p && q + And(Box, Box), + /// p || q + Or(Box, Box), + /// Canonical equality-to-zero form: `Eq(e)` represents `e = 0`. + Eq(Box), +} + +impl PicusConstraint { + /// Build an equality constraint `left = right` by moving to zero: + /// returns `Eq(left - right)`. + #[must_use] + pub fn new_equality(left: PicusExpr, right: PicusExpr) -> PicusConstraint { + PicusConstraint::Eq(Box::new(left - right)) + } + + #[must_use] + /// Builds a bit constraint + pub fn new_bit(left: PicusExpr) -> PicusConstraint { + PicusConstraint::Eq(Box::new(left.clone() * (left.clone() - PicusExpr::Const(1u64)))) + } + + /// Build a comparison constraint `left < right` + #[must_use] + pub fn new_lt(left: PicusExpr, right: PicusExpr) -> PicusConstraint { + PicusConstraint::Lt(Box::new(left), Box::new(right)) + } + + /// Build a comparison constraint `left <= right` + #[must_use] + pub fn new_leq(left: PicusExpr, right: PicusExpr) -> PicusConstraint { + PicusConstraint::Leq(Box::new(left), Box::new(right)) + } + + /// Build a comparison constraint `left > right` + #[must_use] + pub fn new_gt(left: PicusExpr, right: PicusExpr) -> PicusConstraint { + PicusConstraint::Gt(Box::new(left), Box::new(right)) + } + + /// Build a comparison constraint `left >= right` + #[must_use] + pub fn new_geq(left: PicusExpr, right: PicusExpr) -> PicusConstraint { + PicusConstraint::Geq(Box::new(left), Box::new(right)) + } + + /// Assumes ``l`` and ``u`` fit into the prime + /// Generates constraints l <= e <= u + #[must_use] + pub fn in_range(e: PicusExpr, l: usize, u: usize) -> Vec { + assert!(l < u); + vec![PicusConstraint::new_geq(e.clone(), l.into()), PicusConstraint::new_leq(e, u.into())] + } + + #[must_use] + pub fn apply_multiplier(&self, multiplier: PicusExpr) -> PicusConstraint { + use PicusConstraint::*; + if let PicusExpr::Const(1) = multiplier { + return self.clone(); + } + match self { + And(l, r) => { + let new_left = l.apply_multiplier(multiplier.clone()); + let new_right = r.apply_multiplier(multiplier); + PicusConstraint::And(Box::new(new_left), Box::new(new_right)) + } + Lt(l, r) => { + let new_left = multiplier.clone() * (*l.clone()); + let new_right = multiplier.clone() * (*r.clone()); + PicusConstraint::Lt(Box::new(new_left), Box::new(new_right)) + } + Leq(l, r) => { + let new_left = multiplier.clone() * (*l.clone()); + let new_right = multiplier.clone() * (*r.clone()); + PicusConstraint::Leq(Box::new(new_left), Box::new(new_right)) + } + Gt(l, r) => { + let new_left = multiplier.clone() * (*l.clone()); + let new_right = multiplier.clone() * (*r.clone()); + PicusConstraint::Gt(Box::new(new_left), Box::new(new_right)) + } + Geq(l, r) => { + let new_left = multiplier.clone() * (*l.clone()); + let new_right = multiplier.clone() * (*r.clone()); + PicusConstraint::Geq(Box::new(new_left), Box::new(new_right)) + } + Implies(l, r) => { + let new_left = l.apply_multiplier(multiplier.clone()); + let new_right = r.apply_multiplier(multiplier); + PicusConstraint::Implies(Box::new(new_left), Box::new(new_right)) + } + Not(c) => { + let new_c = c.apply_multiplier(multiplier.clone()); + PicusConstraint::Not(Box::new(new_c)) + } + Iff(l, r) => { + let new_left = l.apply_multiplier(multiplier.clone()); + let new_right = r.apply_multiplier(multiplier); + PicusConstraint::Iff(Box::new(new_left), Box::new(new_right)) + } + Or(l, r) => { + let new_left = l.apply_multiplier(multiplier.clone()); + let new_right = r.apply_multiplier(multiplier); + PicusConstraint::Or(Box::new(new_left), Box::new(new_right)) + } + Eq(e) => { + let new_e = multiplier.clone() * (*e.clone()); + PicusConstraint::Eq(Box::new(new_e)) + } + } + } +} diff --git a/crates/picus/src/pcl/mod.rs b/crates/picus/src/pcl/mod.rs new file mode 100644 index 00000000..75f82bed --- /dev/null +++ b/crates/picus/src/pcl/mod.rs @@ -0,0 +1,7 @@ +mod expr; +mod partial_evaluator; +mod program; + +pub use expr::*; +pub use partial_evaluator::*; +pub use program::*; diff --git a/crates/picus/src/pcl/partial_evaluator.rs b/crates/picus/src/pcl/partial_evaluator.rs new file mode 100644 index 00000000..f0545113 --- /dev/null +++ b/crates/picus/src/pcl/partial_evaluator.rs @@ -0,0 +1,262 @@ +use std::collections::BTreeMap; + +use crate::pcl::{current_modulus, reduce_mod, PicusCall, PicusConstraint, PicusExpr}; + +// === Helpers === + +fn mod_reduce_u64(x: u64) -> u64 { + // converting to i64 is fine because the prime is 31 bits the input values will not wrap around + reduce_mod(x as i64) +} + +// performs the inverse of `base` with respect to `current_modulus()` +// this is only sound if `modulus` is under `64` bits +fn mod_pow_u64(mut base: u64, mut exp: u64) -> u64 { + // Fast pow with optional modulus + if let Some(p) = current_modulus() { + base %= p; + let mut acc: u128 = 1; + let mut b: u128 = base as u128; + let m: u128 = p as u128; + while exp > 0 { + if exp & 1 == 1 { + acc = (acc * b) % m; + } + b = (b * b) % m; + exp >>= 1; + } + acc as u64 + } else { + // No modulus set: beware overflow + let mut acc: u128 = 1; + let mut b: u128 = base as u128; + while exp > 0 { + if exp & 1 == 1 { + acc = acc.saturating_mul(b); + } + b = b.saturating_mul(b); + exp >>= 1; + } + acc as u64 + } +} + +// Smart Pow that also folds constants and k=0/1. +fn pow_simplify(k: u64, base: PicusExpr) -> PicusExpr { + match k { + 0 => 1u64.into(), + 1 => base, + _ => match base { + PicusExpr::Const(c) => PicusExpr::Const(mod_pow_u64(c, k)), + other => PicusExpr::Pow(k, Box::new(other)), + }, + } +} + +// === Expression substitution/simplification === +// substitutes variables with constants in `e` from `env` and performs partial evaluation +fn subst_expr(e: &PicusExpr, env: &BTreeMap) -> PicusExpr { + use crate::PicusExpr::*; + match e { + Const(c) => Const(mod_reduce_u64(*c)), + Var(v) => { + if let Some(val) = env.get(v) { + Const(mod_reduce_u64(*val)) + } else { + Var(*v) + } + } + Add(a, b) => subst_expr(a, env) + subst_expr(b, env), + Sub(a, b) => subst_expr(a, env) - subst_expr(b, env), + Mul(a, b) => subst_expr(a, env) * subst_expr(b, env), + Div(a, b) => { + // Optional: try to simplify known constants + let aa = subst_expr(a, env); + let bb = subst_expr(b, env); + match (&aa, &bb) { + (_, Const(1)) => aa, // e / 1 => e + (Const(0), _) => 0u64.into(), // 0 / e => 0 (assuming e ≠ 0; safe algebraically) + _ => Div(Box::new(aa), Box::new(bb)), + } + } + Neg(a) => -subst_expr(a, env), + Pow(k, a) => pow_simplify(*k, subst_expr(a, env)), + } +} + +// === Call substitution/simplification === +/// This function replaces variables in `call` with constants in `env` +/// and then simplifies. +pub fn subst_call(call: &PicusCall, env: &BTreeMap) -> PicusCall { + let mut new_inputs = Vec::new(); + let mut new_outputs = Vec::new(); + for input in &call.inputs { + new_inputs.push(subst_expr(input, env)); + } + for output in &call.outputs { + new_outputs.push(subst_expr(output, env)); + } + PicusCall { inputs: new_inputs, outputs: new_outputs, mod_name: call.mod_name.clone() } +} + +// === Constraint substitution/simplification === +/// This function replaces variables in `c` with constants in `env` +/// and then simplifies. +pub fn subst_constraint( + c: &PicusConstraint, + env: &BTreeMap, +) -> Option { + use PicusConstraint::*; + let keep = |cc: PicusConstraint| Some(cc); + + match c { + Eq(e) => { + let ee = subst_expr(e, env); + // Drop tautologies Eq(0); keep contradictions as Eq(1) + match ee { + PicusExpr::Const(0) => None, + PicusExpr::Const(1) => keep(Eq(Box::new(1u64.into()))), // 1 = 0 (unsat marker) + _ => keep(Eq(Box::new(ee))), + } + } + + Lt(a, b) => { + let aa = subst_expr(a, env); + let bb = subst_expr(b, env); + match (&aa, &bb) { + (PicusExpr::Const(x), PicusExpr::Const(y)) => { + if x < y { + None + } else { + keep(Eq(Box::new(1u64.into()))) + } + } + _ => keep(Lt(Box::new(aa), Box::new(bb))), + } + } + + Leq(a, b) => { + let aa = subst_expr(a, env); + let bb = subst_expr(b, env); + match (&aa, &bb) { + (PicusExpr::Const(x), PicusExpr::Const(y)) => { + if x <= y { + None + } else { + keep(Eq(Box::new(1u64.into()))) + } + } + _ => keep(Leq(Box::new(aa), Box::new(bb))), + } + } + + Gt(a, b) => { + let aa = subst_expr(a, env); + let bb = subst_expr(b, env); + match (&aa, &bb) { + (PicusExpr::Const(x), PicusExpr::Const(y)) => { + if x > y { + None + } else { + keep(Eq(Box::new(1u64.into()))) + } + } + _ => keep(Gt(Box::new(aa), Box::new(bb))), + } + } + + Geq(a, b) => { + let aa = subst_expr(a, env); + let bb = subst_expr(b, env); + match (&aa, &bb) { + (PicusExpr::Const(x), PicusExpr::Const(y)) => { + if x >= y { + None + } else { + keep(Eq(Box::new(1u64.into()))) + } + } + _ => keep(Geq(Box::new(aa), Box::new(bb))), + } + } + + Not(p) => { + // Push inside and simplify: + match subst_constraint(p, env) { + None => Some(Eq(Box::new(1u64.into()))), // not(true) => false + Some(Eq(e)) if matches!(*e, PicusExpr::Const(1)) => None, // not(false) => true + Some(pp) => Some(Not(Box::new(pp))), + } + } + + And(p, q) => { + let pp = subst_constraint(p, env); + let qq = subst_constraint(q, env); + match (pp, qq) { + (None, None) => None, // true && true + (Some(Eq(e)), _) if matches!(*e, PicusExpr::Const(1)) => { + Some(Eq(Box::new(1u64.into()))) + } // false && _ => false + (_, Some(Eq(e))) if matches!(*e, PicusExpr::Const(1)) => { + Some(Eq(Box::new(1u64.into()))) + } + (None, Some(r)) => Some(r), // true && r => r + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Some(And(Box::new(l), Box::new(r))), + } + } + + Or(p, q) => { + let pp = subst_constraint(p, env); + let qq = subst_constraint(q, env); + match (pp, qq) { + (None, _) => None, // true || _ => true + (_, None) => None, + (Some(Eq(e)), r) if matches!(*e, PicusExpr::Const(1)) => r, // false || r => r + (l, Some(Eq(e))) if matches!(*e, PicusExpr::Const(1)) => l, + (Some(l), Some(r)) => Some(Or(Box::new(l), Box::new(r))), + } + } + + Implies(p, q) => { + // p => q ≡ ¬p ∨ q + let np_or_q = Or(Box::new(Not(p.clone())), q.clone()); + subst_constraint(&np_or_q, env) + } + + Iff(p, q) => { + // p <=> q ≡ (p => q) ∧ (q => p) + let p_imp_q = Implies(p.clone(), q.clone()); + let q_imp_p = Implies(q.clone(), p.clone()); + subst_constraint(&And(Box::new(p_imp_q), Box::new(q_imp_p)), env) + } + } +} + +/// Given a collection of constraints `constraints` and a mapping of +/// variables to constants, `partial_evaluate` produces a new set of constraints +/// after substituting those variables with constants and partial evaluating +pub fn partial_evaluate( + constraints: &[PicusConstraint], + env: &BTreeMap, +) -> Vec { + let mut out_constraints = Vec::with_capacity(constraints.len()); + for c in constraints { + if let Some(cc) = subst_constraint(c, env) { + // Optional micro-normalization: if we ever produce Eq(Const(0)) here, drop it + match &cc { + PicusConstraint::Eq(e) if matches!(&**e, PicusExpr::Const(0)) => {} + _ => out_constraints.push(cc), + } + } + } + out_constraints +} + +pub fn partial_evaluate_calls(calls: &[PicusCall], env: &BTreeMap) -> Vec { + let mut out_calls = Vec::with_capacity(calls.len()); + for call in calls { + out_calls.push(subst_call(call, env)) + } + out_calls +} diff --git a/crates/picus/src/pcl/program.rs b/crates/picus/src/pcl/program.rs new file mode 100644 index 00000000..e21b4bae --- /dev/null +++ b/crates/picus/src/pcl/program.rs @@ -0,0 +1,343 @@ +use crate::pcl::{ + expr::{PicusConstraint, PicusExpr}, + partial_evaluate, partial_evaluate_calls, +}; +use std::{ + collections::BTreeMap, + fmt::{self, Display, Formatter}, + fs::File, + io::{self, Write}, + path::Path, +}; + +/// A call to another Picus module (by name). +/// +/// Renders to the PCL s-expression: +/// +/// ```text +/// (call [] []) +/// ``` +/// +/// where both `outputs` and `inputs` are printed using `Display` for `PicusExpr`, +/// enclosed in `[...]` and space-separated. +#[derive(Debug, Clone, Default)] +pub struct PicusCall { + /// Callee module name. This will oftentimes be specialized (e.g., suffixed with constants) + /// by the compiler to facilitate partial evaluation of the callee. + pub mod_name: String, + /// Expressions that *receive* the callee results at the call site. + /// (Printed first in the call s-expression.) + pub outputs: Vec, + /// Expressions that are *passed* to the callee. + /// (Printed last in the call s-expression.) + pub inputs: Vec, +} + +impl PicusCall { + pub fn new(mod_name: String, outputs: &[PicusExpr], inputs: &[PicusExpr]) -> PicusCall { + PicusCall { mod_name, outputs: outputs.into(), inputs: inputs.into() } + } + + pub fn apply_multiplier(&self, multiplier: PicusExpr) -> PicusCall { + let new_inputs: Vec = + self.inputs.iter().map(|x| multiplier.clone() * (*x).clone()).collect(); + PicusCall { + mod_name: self.mod_name.clone(), + outputs: self.outputs.clone(), + inputs: new_inputs, + } + } +} + +/// A single Picus module and its contents. +/// +/// A module has a name, a list of input/output expressions (ports), +/// a set of constraints, optional postconditions, assumptions about +/// determinism, and a list of nested calls to other modules. +/// +/// The textual form emitted by [`PicusModule::dump`] is a sequence +/// of PCL s-expressions wrapped between `(begin-module )` and +/// `(end-module)`. +#[derive(Debug, Clone, Default)] +pub struct PicusModule { + /// Module identifier used in `(begin-module )`. + pub name: String, + /// Module inputs (printed as `(input )`). + pub inputs: Vec, + /// Module outputs (printed as `(output )`). + pub outputs: Vec, + /// Circuit constraints enforced within the module (printed as `(assert )`). + pub constraints: Vec, + /// Constraints to be treated as postconditions (printed as `(post-condition )`). + pub postconditions: Vec, + /// Expressions assumed to be deterministic (printed as `(assume-deterministic )`). + pub assume_deterministic: Vec, + /// Nested calls emitted inside the module body. + pub calls: Vec, +} + +impl PicusModule { + /// Construct an empty Picus module with the given `name`. + #[must_use] + pub fn new(name: String) -> Self { + PicusModule { + name, + inputs: Vec::new(), + outputs: Vec::new(), + constraints: Vec::new(), + postconditions: Vec::new(), + assume_deterministic: Vec::new(), + calls: Vec::new(), + } + } + + /// builds an empty picus module with `num_inputs` inputs and `num_outputs` outputs + pub fn build_empty(name: String, num_inputs: usize, num_outputs: usize) -> Self { + let mut inputs = Vec::with_capacity(num_inputs); + let mut outputs = Vec::with_capacity(num_inputs); + for i in 0..num_inputs { + inputs.push(PicusExpr::Var(i)); + } + for i in 0..num_outputs { + outputs.push(PicusExpr::Var(num_inputs + i)); + } + PicusModule { + name, + inputs, + outputs, + constraints: Vec::new(), + postconditions: Vec::new(), + assume_deterministic: Vec::new(), + calls: Vec::new(), + } + } + + // Applies the multiplier across the constraints + pub fn apply_multiplier(&mut self, multiplier: PicusExpr) { + let mut constraints = Vec::with_capacity(self.constraints.len()); + let mut post_conditions = Vec::with_capacity(self.postconditions.len()); + let mut calls = Vec::with_capacity(self.calls.len()); + for constraint in &self.constraints { + constraints.push(constraint.apply_multiplier(multiplier.clone())); + } + + for call in &self.calls { + calls.push(call.apply_multiplier(multiplier.clone())); + } + + for postcond in &self.postconditions { + post_conditions.push(postcond.apply_multiplier(multiplier.clone())); + } + self.constraints = constraints; + self.postconditions = post_conditions; + self.calls = calls; + } + + #[must_use] + /// Construct a new Picus module by partially evaluating the module's constraints + /// with the given values + pub fn partial_eval(&self, env: &BTreeMap) -> Self { + let mut name = self.name.clone(); + for (k, v) in env { + name += &format!("{k}_{v}"); + } + let constraints = partial_evaluate(&self.constraints, env); + let calls = partial_evaluate_calls(&self.calls, env); + let postconditions = partial_evaluate(&self.postconditions, env); + PicusModule { + name, + inputs: self.inputs.clone(), + outputs: self.outputs.clone(), + constraints, + postconditions, + assume_deterministic: self.assume_deterministic.clone(), + calls, + } + } +} + +impl Display for PicusModule { + /// Serialize this module into a sequence of PCL lines. + /// + /// Output shape: + /// + /// ```text + /// (begin-module ) + /// (input )* + /// (output )* + /// (assert )* + /// (post-condition )* + /// (assume-deterministic )* + /// (call [] [])* + /// (end-module) + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + writeln!(f, "(begin-module {})", self.name)?; + + for inp in &self.inputs { + writeln!(f, "(input {inp})")?; + } + for out in &self.outputs { + writeln!(f, "(output {out})")?; + } + for c in &self.constraints { + writeln!(f, "(assert {c})")?; + } + for c in &self.postconditions { + writeln!(f, "(post-condition {c})")?; + } + for e in &self.assume_deterministic { + writeln!(f, "(assume-deterministic {e})")?; + } + for call in &self.calls { + writeln!(f, "{call}")?; + } + + write!(f, "(end-module)") + } +} + +/// Print a Picus arithmetic expression in PCL s-expression syntax. +/// +/// Examples: +/// +/// - `Const(5)` → `5` +/// - `Var("x",1,0)` → `x_1_0` +/// - `Add(a,b)` → `(+ a b)` +/// - `Neg(e)` → `(- e)` +/// - `Pow(2, e)` → `(pow 2 e)` +impl Display for PicusExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + use PicusExpr::{Add, Const, Div, Mul, Neg, Pow, Sub, Var}; + match self { + Const(v) => write!(f, "{v}"), + Var(id) => write!(f, "x_{id}"), + Add(a, b) => write!(f, "(+ {a} {b})"), + Sub(a, b) => write!(f, "(- {a} {b})"), + Mul(a, b) => write!(f, "(* {a} {b})"), + Div(a, b) => write!(f, "(/ {a} {b})"), + Neg(a) => write!(f, "(- {a})"), + Pow(c, e) => write!(f, "(pow {c} {e})"), + } + } +} + +/// Print a Picus logical/relational constraint in PCL s-expression syntax. +/// +/// Notes: +/// - Equalities are represented canonically as `(= 0)`, i.e., `Eq(e)` means `e = 0`. +/// - Composite forms (`=>`, `<=>`, `&&`, `||`, `!`) print recursively using `Display` on nested +/// constraints/expressions. +impl Display for PicusConstraint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + use PicusConstraint::{And, Eq, Geq, Gt, Iff, Implies, Leq, Lt, Not, Or}; + match self { + Lt(e1, e2) => write!(f, "(< {e1} {e2})"), + Leq(e1, e2) => write!(f, "(<= {e1} {e2})"), + Gt(e1, e2) => write!(f, "(> {e1} {e2})"), + Geq(e1, e2) => write!(f, "(>= {e1} {e2})"), + Eq(e) => write!(f, "(= {e} 0)"), + Implies(c1, c2) => write!(f, "(=> {c1} {c2})"), + Iff(c1, c2) => write!(f, "(<=> {c1} {c2})"), + Not(c) => write!(f, "(! {c})"), + And(c1, c2) => write!(f, "(&& {c1} {c2})"), + Or(c1, c2) => write!(f, "(|| {c1} {c2})"), + } + } +} + +/// Print a `(call ...)` s-expression for a [`PicusCall`]. +/// +/// Uses the `Display` implementation of `PicusExpr` for both output and input +/// vectors via [`write_expr_slice`]. +impl Display for PicusCall { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("(call ")?; + write_expr_slice(f, &self.outputs)?; + write!(f, " {}", self.mod_name)?; + f.write_str(" ")?; + write_expr_slice(f, &self.inputs)?; + f.write_str(")") + } +} + +/// Write a slice of expressions as a bracketed, space-separated list. +/// +/// Example: `[e1 e2 e3]`. +/// +/// This helper relies on `Display` for `PicusExpr`. +fn write_expr_slice(f: &mut Formatter<'_>, exprs: &[PicusExpr]) -> fmt::Result { + f.write_str("[")?; + for (i, e) in exprs.iter().enumerate() { + if i > 0 { + f.write_str(" ")?; + } + write!(f, "{e}")?; + } + f.write_str("]") +} + +/// A complete Picus program: the prime field and an ordered set of modules. +/// +/// The `modules` map is a `BTreeMap` so that serialization is deterministic +/// across runs (keys are emitted in sorted order). +#[derive(Debug, Clone, Default)] +pub struct PicusProgram { + /// Prime modulus for the field in which all arithmetic takes place. + /// It is assumed the value is prime. + prime: u64, + /// All modules in this program, keyed by module name. + modules: BTreeMap, +} + +impl PicusProgram { + /// Create a new empty program over the given prime field. + #[must_use] + pub fn new(prime: u64) -> Self { + PicusProgram { prime, modules: BTreeMap::new() } + } + + /// Move all entries from `modules` into this program. + /// + /// This uses `BTreeMap::append`, transferring ownership of all modules + /// from the argument map and leaving it empty. + pub fn add_modules(&mut self, modules: &mut BTreeMap) { + self.modules.append(modules); + } + + /// Write the serialized program to any `Write` sink. + pub fn write_to(&self, mut w: W) -> io::Result<()> { + write!(w, "{self}") + } + + /// Write the serialized program to `path`, creating parent directories if needed. + pub fn write_to_path>(&self, path: P) -> io::Result<()> { + let path = path.as_ref(); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let mut f = File::create(path)?; + self.write_to(&mut f) + } +} + +/// Serialize the whole program into PCL text. +/// +/// Output begins with `(prime-number

)`, followed by each module’s +/// PCL block separated by a blank line. Module order is stable due to +/// `BTreeMap` key ordering. +impl Display for PicusProgram { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + writeln!(f, "(prime-number {})", self.prime)?; + // Separate modules with a single blank line, deterministic order via BTreeMap. + let mut first = true; + for m in self.modules.values() { + if !first { + writeln!(f)?; + } + first = false; + writeln!(f, "{m}")?; + } + Ok(()) + } +} diff --git a/crates/picus/src/picus_builder.rs b/crates/picus/src/picus_builder.rs new file mode 100644 index 00000000..457f70b5 --- /dev/null +++ b/crates/picus/src/picus_builder.rs @@ -0,0 +1,382 @@ +use std::collections::BTreeMap; + +use crate::{ + opcode_spec::{spec_for, IndexSlice}, + pcl::{ + fresh_picus_expr, fresh_picus_var, fresh_picus_var_id, Felt, PicusAtom, PicusCall, + PicusConstraint, PicusExpr, PicusModule, + }, +}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use p3_matrix::dense::{DenseMatrix, RowMajorMatrix}; +use zkm_core_executor::{ByteOpcode, Opcode}; +use zkm_stark::{AirLookup, Chip, LookupKind, MachineAir, MessageBuilder, ZKM_PROOF_NUM_PV_ELTS}; + +/// Implementation `AirBuilder` which builds Picus programs +#[derive(Clone)] +pub struct PicusBuilder<'chips, A: MachineAir> { + pub preprocessed: RowMajorMatrix, + pub main: RowMajorMatrix, + pub public_values: Vec, + pub picus_module: PicusModule, + pub aux_modules: BTreeMap, + pub chips: &'chips [Chip], + pub extract_modularly: bool, + pub multiplier: PicusExpr, + pub concrete_pending_tasks: Vec, + pub symbolic_pending_tasks: Vec, +} + +#[derive(Clone)] +pub struct ConcretePendingTask { + pub chip_name: String, + pub main_vars: Vec, + pub multiplicity: PicusExpr, + pub selector: String, +} + +#[derive(Clone)] +pub struct SymbolicPendingTask { + pub selector: PicusExpr, + pub multiplicity: PicusExpr, +} + +impl<'chips, A: MachineAir> PicusBuilder<'chips, A> { + /// Constructor for the builder + pub fn new( + chip_to_analyze: &'chips Chip, + picus_module: PicusModule, + chips: &'chips [Chip], + main_vars: Option>, + multiplier: Option, + ) -> Self { + let width = chip_to_analyze.air.width(); + // Initialize the public values. + let public_values = (0..ZKM_PROOF_NUM_PV_ELTS).map(PicusAtom::new_var).collect(); + // Initialize the preprocessed and main traces. + let row: Vec = + (0..chip_to_analyze.preprocessed_width()).map(PicusAtom::new_var).collect(); + let preprocessed = DenseMatrix::new_row(row); + let main = if let Some(vars) = main_vars { + assert_eq!(vars.len(), width); + vars + } else { + (0..width).map(PicusAtom::new_var).collect() + }; + let multiplier = + if let Some(expr) = multiplier { expr.clone() } else { PicusExpr::Const(1) }; + let aux_modules = BTreeMap::new(); + Self { + preprocessed, + main: RowMajorMatrix::new(main, width), + public_values, + picus_module, + aux_modules, + chips, + multiplier, + extract_modularly: false, + concrete_pending_tasks: Vec::new(), + symbolic_pending_tasks: Vec::new(), + } + } + + /// Gets a chip by name or panics if no chip is found. Kept as a slice since the number of chips is small + /// < 200 + pub fn get_chip(&self, name: &str) -> &'chips Chip { + self.chips + .iter() + .find(|c| c.name() == name) + .unwrap_or_else(|| panic!("No chip found named {name}")) + } + + // Picus does not have native support for interactions so we need to convert the interaction + // to Picus constructs. Most byte interactions appear to be range constraints + fn handle_byte_interaction(&mut self, multiplicity: PicusExpr, values: &Vec) { + match values[0] { + PicusExpr::Const(v) => { + if v == (ByteOpcode::U8Range as u64) { + for val in &values[1..] { + if let PicusExpr::Const(v) = val { + assert!(*v < 256); + continue; + } else { + self.picus_module.constraints.push(PicusConstraint::new_lt( + val.clone() * multiplicity.clone(), + 256.into(), + )) + } + } + } else if v == (ByteOpcode::U16Range as u64) { + for val in &values[1..] { + if let PicusExpr::Const(v) = val { + assert!(*v < 65536); + continue; + } else { + self.picus_module.constraints.push(PicusConstraint::new_lt( + val.clone() * multiplicity.clone(), + 65536.into(), + )) + } + } + } else if v == (ByteOpcode::MSB as u64) { + let msb = values[1].clone(); + let bytes = [values[3].clone(), values[4].clone()]; + let picus128_const = PicusExpr::Const(128); + for byte in &bytes { + if let PicusExpr::Const(0) = byte { + continue; + } + let fresh_picus_var: PicusExpr = fresh_picus_expr(); + self.picus_module.constraints.push(PicusConstraint::new_lt( + fresh_picus_var.clone(), + picus128_const.clone(), + )); + self.picus_module.constraints.push(PicusConstraint::Eq(Box::new( + msb.clone() * (msb.clone() - PicusExpr::Const(1)), + ))); + let decomp = + byte.clone() - (msb.clone() * picus128_const.clone() + fresh_picus_var); + self.picus_module.constraints.push(PicusConstraint::Eq(Box::new(decomp))); + } + } else if v == (ByteOpcode::ShrCarry as u64) { + if !self.aux_modules.contains_key("ShrCarry") { + let carry_module = PicusModule::build_empty("ShrCarry".to_string(), 2, 2); + self.aux_modules.insert("ShrCarry".to_string(), carry_module); + } + let shrcarry = PicusCall::new( + "ShrCarry".to_string(), + &[values[1].clone(), values[2].clone()], + &[values[3].clone(), values[4].clone()], + ); + self.picus_module.calls.push(shrcarry); + } else if v == (ByteOpcode::LTU as u64) { + let lt_const = PicusConstraint::new_lt(values[2].clone(), values[3].clone()); + if let PicusExpr::Const(1) = values[1] { + self.picus_module.constraints.push(lt_const); + } else { + let bit_const = PicusConstraint::new_bit(values[1].clone()); + let eq_one = PicusConstraint::new_equality(values[1].clone(), 1.into()); + self.picus_module.constraints.extend_from_slice(&[ + PicusConstraint::Iff(Box::new(eq_one), Box::new(lt_const)), + bit_const, + ]); + } + } else if v == (ByteOpcode::AND as u64) { + println!("values: {values:#?}"); + if let PicusExpr::Const(127) = values[4] { + let var_hi = fresh_picus_expr(); + self.picus_module + .constraints + .push(PicusConstraint::new_lt(values[1].clone(), 128.into())); + self.picus_module + .constraints + .push(PicusConstraint::new_bit(var_hi.clone())); + self.picus_module.constraints.push(PicusConstraint::new_equality( + values[3].clone(), + var_hi * 128 + values[1].clone(), + )); + } + } else { + panic!("Unhandled byte interaction") + } + } + // TODO: It might be fine if the first argument isn't a constant. We need to multiply the values + // in the interaction with the multiplicities + _ => panic!("Byte interaction but first argument isn't a constant"), + } + } + + // The receive instruction interaction is used to determine which columns are inputs/outputs. + // In particular, the following values correspond to inputs and outputs: + // - values[2] -> pc (input) + // - values[3] -> next_pc (output) + // - values[6-9] -> a (output) + // - values[10-13] -> b (input) + // - values[14-17] -> c (input) + // - TODO (Add high and low) + fn handle_receive_instruction(&mut self, multiplicity: PicusExpr, values: &[PicusExpr]) { + // Creating a fresh var because picus outputs need to be variables. + // When performing partial evaluation, + let next_pc_out = fresh_picus_expr(); + let eq_mul = |multiplicity: &PicusExpr, val: &PicusExpr, var: &PicusExpr| { + PicusConstraint::new_equality(var.clone(), val.clone() * multiplicity.clone()) + }; + self.picus_module.outputs.push(next_pc_out.clone()); + self.picus_module.constraints.push(eq_mul(&multiplicity, &values[3], &next_pc_out)); + // If this is a sequential instruction then we can assume next-pc is deterministic as we will check its + // determinism in the CPU chip. Otherwise, we have to prove it is deterministic. The flag for specifying the + // if the instruction is sequential is stored at index 27. + if let PicusExpr::Const(1) = values[27].clone() { + self.picus_module.assume_deterministic.push(next_pc_out); + } + // We need to mark some of the register values as inputs and other values as outputs. + // In particular, the parameters `b` and `c` to `receive_instruction` are inputs and + // parameter `a` is an output. `b` and `c` are at indexes 10-13 and 14-17 in `values` whereas + // `a` is at indexes 6-9. As in the code above, we need to create variables for the outputs since + // Picus requires the inputs and outputs to be variables. + for value in values.iter().take(10).skip(6) { + let a_var = fresh_picus_expr(); + self.picus_module.outputs.push(a_var.clone()); + self.picus_module.constraints.push(eq_mul(&multiplicity, value, &a_var)); + } + for value in values.iter().take(14).skip(10) { + let b_var = fresh_picus_expr(); + self.picus_module.inputs.push(b_var.clone()); + self.picus_module.constraints.push(eq_mul(&multiplicity, value, &b_var)); + } + for value in values.iter().take(18).skip(14) { + let c_var = fresh_picus_expr(); + self.picus_module.inputs.push(c_var.clone()); + self.picus_module.constraints.push(eq_mul(&multiplicity, value, &c_var)); + } + } + + fn get_main_vars_for_call(&mut self, message_values: &[PicusExpr]) -> Option> { + println!("MESSAGE VALUES: {message_values:?}"); + let opcode_spec = match message_values[6].clone() { + PicusExpr::Const(v) => { + assert!(v < Opcode::UNIMPL as u64); + spec_for(Opcode::try_from(v as u8).unwrap()) + } + _ => return None, + }; + let target_chip = self.get_chip(opcode_spec.chip); + let mut target_main_vals: Vec = + (0..target_chip.air.width()).map(|_| fresh_picus_var()).collect(); + + let target_picus_info = target_chip.picus_info(); + println!("Target picus info: {target_picus_info:?}"); + for (slice, name) in opcode_spec.arg_to_colname { + println!("Name: {name}"); + let colrange = target_picus_info.name_to_colrange.get(*name).unwrap(); + match *slice { + IndexSlice::Range { start, end } => { + assert!(colrange.1 - colrange.0 >= end - start); + for i in start..end { + if let PicusExpr::Var(v) = message_values[i].clone() { + target_main_vals[colrange.0 + i - start] = PicusAtom::Var(v); + } else { + let id = fresh_picus_var_id(); + let fresh_var = PicusAtom::Var(id); + self.picus_module.constraints.push(PicusConstraint::new_equality( + PicusExpr::Var(id), + message_values[i].clone(), + )); + target_main_vals[colrange.0 + i - start] = fresh_var; + } + } + } + IndexSlice::Single(col) => { + assert_eq!(colrange.1 - colrange.0, 1); + if let PicusExpr::Var(v) = message_values[col].clone() { + target_main_vals[colrange.0] = PicusAtom::Var(v); + } else { + let fresh_var = fresh_picus_var_id(); + self.picus_module.constraints.push(PicusConstraint::new_equality( + PicusExpr::Var(fresh_var), + message_values[col].clone(), + )); + target_main_vals[colrange.0] = PicusAtom::Var(fresh_var); + } + } + } + } + println!("Target main vals: {target_main_vals:?}"); + Some(target_main_vals) + } +} + +impl<'chips, A: MachineAir> PairBuilder for PicusBuilder<'chips, A> { + fn preprocessed(&self) -> Self::M { + todo!() + } +} + +impl<'chips, A: MachineAir> AirBuilderWithPublicValues for PicusBuilder<'chips, A> { + type PublicVar = PicusAtom; + + fn public_values(&self) -> &[Self::PublicVar] { + todo!() + } +} + +impl<'chips, A: MachineAir> MessageBuilder> for PicusBuilder<'chips, A> { + fn send(&mut self, message: AirLookup, _scope: zkm_stark::LookupScope) { + match message.kind { + LookupKind::Byte => { + self.handle_byte_interaction(message.multiplicity, &message.values); + } + LookupKind::Memory => { + // TODO: fill in + } + LookupKind::Instruction => { + let opcode_spec = match message.values[6].clone() { + PicusExpr::Const(v) => { + assert!(v < Opcode::UNIMPL as u64); + spec_for(Opcode::try_from(v as u8).unwrap()) + } + _ => panic!("Expected opcode val to be a constant: Got: {}", message.values[6]), + }; + let target_chip = self.get_chip(opcode_spec.chip); + println!("OPCODE SPEC: {:?}", opcode_spec.chip); + let main_vars = self.get_main_vars_for_call(&message.values); + if let Some(vars) = main_vars { + self.concrete_pending_tasks.push(ConcretePendingTask { + chip_name: target_chip.name(), + main_vars: vars, + multiplicity: message.multiplicity, + selector: opcode_spec.selector.to_string(), + }); + } else { + self.symbolic_pending_tasks.push(SymbolicPendingTask { + selector: message.values[6].clone(), + multiplicity: message.multiplicity, + }) + } + } + _ => todo!("handle send: {}", message.kind), + } + } + + fn receive(&mut self, message: AirLookup, _scope: zkm_stark::LookupScope) { + // initialize another chip + // call eval with builder? + match message.kind { + LookupKind::Instruction => { + self.handle_receive_instruction(message.multiplicity, &message.values); + } + LookupKind::Memory => { + // TODO: fill in + } + _ => todo!("handle receive: {}", message.kind), + } + } +} + +impl<'chips, A: MachineAir> AirBuilder for PicusBuilder<'chips, A> { + type F = Felt; + type Var = PicusAtom; + type Expr = PicusExpr; + + type M = RowMajorMatrix; + + fn main(&self) -> Self::M { + self.main.clone() + } + + fn is_first_row(&self) -> Self::Expr { + todo!() + } + + fn is_last_row(&self) -> Self::Expr { + todo!() + } + + fn is_transition_window(&self, _size: usize) -> Self::Expr { + todo!() + } + + fn assert_zero>(&mut self, x: I) { + self.picus_module.constraints.push(PicusConstraint::Eq(Box::new(x.into()))) + } +} diff --git a/crates/recursion/core/src/machine.rs b/crates/recursion/core/src/machine.rs index 00e6124a..2898f9f6 100644 --- a/crates/recursion/core/src/machine.rs +++ b/crates/recursion/core/src/machine.rs @@ -3,7 +3,7 @@ use std::ops::{Add, AddAssign}; use hashbrown::HashMap; use p3_field::{extension::BinomiallyExtendable, PrimeField32}; use zkm_stark::{ - air::{LookupScope, MachineAir}, + air::{LookupScope, MachineAir, PicusInfo}, shape::OrderedShape, Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS, }; diff --git a/crates/stark/src/air/machine.rs b/crates/stark/src/air/machine.rs index 1b2942eb..bed1c4ec 100644 --- a/crates/stark/src/air/machine.rs +++ b/crates/stark/src/air/machine.rs @@ -4,7 +4,7 @@ use p3_air::BaseAir; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -use crate::{septic_digest::SepticDigest, MachineRecord}; +use crate::{septic_digest::SepticDigest, MachineRecord, PicusInfo}; pub use zkm_derive::MachineAir; @@ -77,6 +77,15 @@ pub trait MachineAir: BaseAir + 'static + Send + Sync { fn local_only(&self) -> bool { false } + + /// Returns information about Picus annotations on AIR columns. + /// + /// This includes: + /// - Input ranges: columns marked with `#[picus(input)]` + /// - Selector indices: columns marked with `#[picus(selector)]` + fn picus_info(&self) -> PicusInfo { + PicusInfo::default() + } } /// A program that defines the control flow of a machine through a program counter. diff --git a/crates/stark/src/air/mod.rs b/crates/stark/src/air/mod.rs index ab37207c..1e101023 100644 --- a/crates/stark/src/air/mod.rs +++ b/crates/stark/src/air/mod.rs @@ -4,6 +4,7 @@ mod builder; mod extension; mod lookup; mod machine; +mod picus_info; mod polynomial; mod public_values; mod sub_builder; @@ -12,6 +13,7 @@ pub use builder::*; pub use extension::*; pub use lookup::*; pub use machine::*; +pub use picus_info::*; pub use polynomial::*; pub use public_values::*; pub use sub_builder::*; diff --git a/crates/stark/src/air/picus_info.rs b/crates/stark/src/air/picus_info.rs new file mode 100644 index 00000000..940f009e --- /dev/null +++ b/crates/stark/src/air/picus_info.rs @@ -0,0 +1,32 @@ +use std::collections::HashMap; +/// Information about Picus annotations on AIR columns. +#[derive(Debug, Clone, Default)] +pub struct PicusInfo { + /// Column to name mapping. column i will get map to the string "f_i" where f is the field + /// in the column struct that contains column i + pub col_to_name: HashMap, + /// Name to column ranges + pub name_to_colrange: HashMap, + /// Ranges of columns marked as inputs. + /// Each tuple contains (`start_index`, `end_index`, `field_name`) where: + /// - `start_index` is the first column index (inclusive) + /// - `end_index` is the last column index (exclusive) + /// - `field_name` is the name of the field + pub input_ranges: Vec<(usize, usize, String)>, + + /// Ranges of columns marked as outputs. + /// Each tuple contains (`start_index`, `end_index`, `field_name`) where: + /// - `start_index` is the first column index (inclusive) + /// - `end_index` is the last column index (exclusive) + /// - `field_name` is the name of the field + pub output_ranges: Vec<(usize, usize, String)>, + + /// Indices of columns marked as selectors. + /// Each tuple contains (`column_index`, `field_name`) where: + /// - `column_index` is the index of the selector column + /// - `field_name` is the name of the field + pub selector_indices: Vec<(usize, String)>, + + /// Indices of columns marked as `is_real` + pub is_real_index: Option, +} diff --git a/crates/stark/src/chip.rs b/crates/stark/src/chip.rs index cc3f844e..a1093d57 100644 --- a/crates/stark/src/chip.rs +++ b/crates/stark/src/chip.rs @@ -10,7 +10,7 @@ use crate::{ air::{LookupScope, MachineAir, MultiTableAirBuilder, ZKMAirBuilder}, local_permutation_trace_width, lookup::{Lookup, LookupBuilder, LookupKind}, - scoped_lookups, + scoped_lookups, PicusInfo, }; use super::{eval_permutation_constraints, generate_permutation_trace, PROOF_MAX_NUM_PVS}; @@ -247,6 +247,10 @@ where fn local_only(&self) -> bool { self.air.local_only() } + + fn picus_info(&self) -> PicusInfo { + self.air.picus_info() + } } // Implement AIR directly on Chip, evaluating both execution and permutation constraints.