diff --git a/Cargo.toml b/Cargo.toml index 741a1415..399a0166 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ serde = { version = "1.0.202", features = ["derive"] } thread_local = "1.1.8" generational-box = "0.5.6" serde_json = "1.0.140" -egg = "0.9.5" egglog = "1.0.0" egglog-ast = "1.0.0" egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]} @@ -36,6 +35,7 @@ paste = "1.0.15" pretty-duration = "0.1.1" anyhow = "1.0" graphviz-rust = { version = "0.9", default-features = false} +lru = "0.16.2" [workspace.package] edition = "2024" diff --git a/examples/visualization/src/main.rs b/examples/visualization/src/main.rs index 3bf4d1f7..694fce02 100644 --- a/examples/visualization/src/main.rs +++ b/examples/visualization/src/main.rs @@ -44,30 +44,9 @@ fn main() { let mut egglog_obj: EGraph = egglog::EGraph::default(); - // setup the rules and datatypes - egglog_obj - .parse_and_run_program(None, luminal::egglog_utils::BASE) - .unwrap(); - egglog_obj - .parse_and_run_program(None, &luminal::egglog_utils::op_defs_string(&ops)) - .unwrap(); - egglog_obj - .parse_and_run_program(None, &luminal::egglog_utils::op_rewrites_string(&ops)) - .unwrap(); - egglog_obj - .parse_and_run_program(None, luminal::egglog_utils::BASE_CLEANUP) - .unwrap(); - egglog_obj - .parse_and_run_program(None, &luminal::egglog_utils::op_cleanups_string(&ops)) - .unwrap(); - - // load the program - egglog_obj.parse_and_run_program(None, &program).unwrap(); - // run the graph - egglog_obj - .parse_and_run_program(None, luminal::egglog_utils::RUN_SCHEDULE) - .unwrap(); + let code = luminal::egglog_utils::full_egglog(&program, &ops, true); + egglog_obj.parse_and_run_program(None, &code).unwrap(); // EGraph Optimization Complete println!("Visualizing EGraph"); diff --git a/src/egglog_utils/base.egg b/src/egglog_utils/base.egg index 5af04dbb..1985e22f 100644 --- a/src/egglog_utils/base.egg +++ b/src/egglog_utils/base.egg @@ -1,5 +1,7 @@ -; -------- 
SYMBOLIC ALGEBRA ------- (ruleset expr) +(ruleset cleanup) + +; -------- SYMBOLIC ALGEBRA ------- (datatype* (Expression (MNum i64) @@ -20,7 +22,6 @@ (MLt Expression Expression) (MFloorTo Expression Expression) (MReplace Expression Expression Expression) - (MAccum String) ) ; eqsort list for vectors of Expression @@ -42,7 +43,7 @@ ) ; ---- Algebraic rewrites ---- -(rewrite (MAdd a b) (MAdd b a) :ruleset expr) +;(rewrite (MAdd a b) (MAdd b a) :ruleset expr) ; commutativity leads to some explosions ;(rewrite (MMul a b) (MMul b a) :ruleset expr) (rewrite (MAdd (MAdd a b) c) (MAdd a (MAdd b c)) :ruleset expr) @@ -58,9 +59,10 @@ ) ( (union ?e (MNum ?prod)) - (delete (MMul (MNum ?a) (MNum ?b))) + (subsume (MMul (MNum ?a) (MNum ?b))) ) -) ; why does this explode??? + :ruleset expr +) (rewrite (MDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) (rewrite (MCeilDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) (rewrite (MMax (MNum a) (MNum b)) (MNum (max a b)) :ruleset expr) @@ -68,17 +70,44 @@ (rewrite (MAnd (MNum a) (MNum b)) (MNum (& a b)) :ruleset expr) (rewrite (MFloat -1.0) (MNum -1) :ruleset expr) (rewrite (MNum -1) (MFloat -1.0) :ruleset expr) -(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0))) ; why does this explode??? +(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0)) :ruleset expr) ; why does this explode??? (rewrite (MAdd a (MNum 0)) a :ruleset expr) (rule ((= ?e (MMul ?a (MNum 1)))) ((union ?e ?a)) :ruleset expr) -;(rewrite (MMul ?a (MNum 0)) (MNum 0) :ruleset expr) ; Why does this cause an infinite loop? 
+(rule ((= ?e (MMul ?a (MNum 0)))) ((union ?e (MNum 0)) (subsume (MMul ?a (MNum 0)))) :ruleset expr) (rewrite (MDiv a (MNum 1)) a :ruleset expr) (rewrite (MMod (MMul ?x ?y) ?y) (MNum 0) :ruleset expr) (rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?y)) :when ((>= ?z ?y) (= 0 (% ?y ?z))) :ruleset expr) (rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?z)) :when ((>= ?y ?z) (= 0 (% ?z ?y))) :ruleset expr) +(rewrite (MDiv (MDiv a b) c) (MDiv a (MMul b c)) :ruleset expr) +(rewrite (MAdd (MDiv a b) c) (MDiv (MAdd a (MMul c b)) b) :ruleset expr) +(rewrite (MAdd a (MSub b a)) b :ruleset expr) +(rewrite (MAdd (MSub b a) a) b :ruleset expr) +(rewrite (MSub a a) (MNum 0) :ruleset expr) +(rewrite + (MAdd (MSub a (MNum ?b)) (MNum ?c)) + (MSub a (MNum (- ?b ?c))) + :ruleset expr +) +(rewrite + (MAdd (MNum ?c) (MSub a (MNum ?b))) + (MSub a (MNum (- ?b ?c))) + :ruleset expr +) +(rewrite + (MSub (MAdd a (MNum ?b)) (MNum ?c)) + (MAdd a (MNum (- ?b ?c))) + :ruleset expr +) +(rewrite + (MSub (MSub a (MNum ?b)) (MNum ?c)) + (MSub a (MNum (+ ?b ?c))) + :ruleset expr +) +(rewrite (MAdd (MMul a b) (MMul a c)) (MMul a (MAdd b c)) :ruleset expr) +(rewrite (MAdd a a) (MMul (MNum 2) a) :ruleset expr) ; ---- Replacement over expressions ---- (rewrite (MReplace ?x ?y ?z) ?z :when ((= ?x ?y)) :ruleset expr) @@ -92,7 +121,6 @@ (rewrite (MReplace (MMax ?a ?b) ?x ?y) (MMax (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) (rewrite (MReplace (MFloorTo ?a ?b) ?x ?y) (MFloorTo (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) (rewrite (MReplace (MNum ?n) ?x ?y) (MNum ?n) :ruleset expr) -(rewrite (MReplace (MAccum ?acc) ?x ?y) (MAccum ?acc) :ruleset expr) (rewrite (MReplace (MVar ?z) ?find ?replace) (MVar ?z) :when ((!= ?find (MVar ?z))) :ruleset expr) (rewrite (MReplace (MIter) ?find ?replace) (MIter) :when ((!= ?find (MIter))) :ruleset expr) @@ -124,11 +152,11 @@ (= ?n_elems (n_elements ?other)) ) ( - (union ?e (ECons (MMul (MIter) ?n_elems) (RowMajor ?other))) 
+ (union ?e (ECons ?n_elems (RowMajor ?other))) ) :ruleset expr ) -(rewrite (RowMajor (ECons ?dim (ENil))) (ECons (MIter) (ENil)) :ruleset expr) +(rewrite (RowMajor (ECons ?dim (ENil))) (ECons (MNum 1) (ENil)) :ruleset expr) (rewrite (MReplaceList (ECons ?expr ?list) ?from ?to) (ECons (MReplace ?expr ?from ?to) (MReplaceList ?list ?from ?to)) :ruleset expr) (rule @@ -170,4 +198,4 @@ (union ?e (ECons ?expr (RemoveNthFromEnd ?list ?ind))) ) :ruleset expr -) \ No newline at end of file +) diff --git a/src/egglog_utils/egglog_template.egg b/src/egglog_utils/egglog_template.egg deleted file mode 100644 index 0675e6b4..00000000 --- a/src/egglog_utils/egglog_template.egg +++ /dev/null @@ -1,239 +0,0 @@ -; -------- SYMBOLIC ALGEBRA ------- -(ruleset expr) -(datatype* - (Expression - (MNum i64) - (MFloat f64) - (MIter) - (MVar String) - (MAdd Expression Expression) - (MSub Expression Expression) - (MMul Expression Expression) - (MCeilDiv Expression Expression) - (MDiv Expression Expression) - (MMod Expression Expression) - (MMin Expression Expression) - (MMax Expression Expression) - (MAnd Expression Expression) - (MOr Expression Expression) - (MGte Expression Expression) - (MLt Expression Expression) - (MFloorTo Expression Expression) - (MReplace Expression Expression Expression) - (MAccum String) - ) - - ; eqsort list for vectors of Expression - (EList - (ECons Expression EList) - (ENil) - (MReplaceList EList Expression Expression) - (ReplaceNthFromEnd EList Expression i64) - (RemoveNthFromEnd EList i64) - (RowMajor EList) - ) - - (DType - (F32) - (F16) - (Bf16) - (Int) - ) -) - -; ---- Algebraic rewrites ---- -(rewrite (MAdd a b) (MAdd b a) :ruleset expr) -;(rewrite (MMul a b) (MMul b a) :ruleset expr) - -(rewrite (MAdd (MAdd a b) c) (MAdd a (MAdd b c)) :ruleset expr) -;(rewrite (MMul (MMul a b) c) (MMul a (MMul b c)) :ruleset expr) - -(rewrite (MAdd (MNum a) (MNum b)) (MNum (+ a b)) :ruleset expr) -(rewrite (MSub (MNum a) (MNum b)) (MNum (- a b)) :ruleset expr) -; 
multiply const folding -(rule - ( - (= ?e (MMul (MNum ?a) (MNum ?b))) - (= ?prod (* ?a ?b)) - ) - ( - (union ?e (MNum ?prod)) - (delete (MMul (MNum ?a) (MNum ?b))) - ) -) ; why does this explode??? -(rewrite (MDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) -(rewrite (MCeilDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) -(rewrite (MMax (MNum a) (MNum b)) (MNum (max a b)) :ruleset expr) -(rewrite (MMin (MNum a) (MNum b)) (MNum (min a b)) :ruleset expr) -(rewrite (MAnd (MNum a) (MNum b)) (MNum (& a b)) :ruleset expr) -(rewrite (MFloat -1.0) (MNum -1) :ruleset expr) -(rewrite (MNum -1) (MFloat -1.0) :ruleset expr) -(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0))) ; why does this explode??? - -(rewrite (MAdd a (MNum 0)) a :ruleset expr) -(rule ((= ?e (MMul ?a (MNum 1)))) ((union ?e ?a)) :ruleset expr) -;(rewrite (MMul ?a (MNum 0)) (MNum 0) :ruleset expr) ; Why does this cause an infinite loop? 
-(rewrite (MDiv a (MNum 1)) a :ruleset expr) -(rewrite (MMod (MMul ?x ?y) ?y) (MNum 0) :ruleset expr) -(rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?y)) - :when ((>= ?z ?y) (= 0 (% ?y ?z))) :ruleset expr) -(rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?z)) - :when ((>= ?y ?z) (= 0 (% ?z ?y))) :ruleset expr) - -; ---- Replacement over expressions ---- -(rewrite (MReplace ?x ?y ?z) ?z :when ((= ?x ?y)) :ruleset expr) -(rewrite (MReplace (MAdd ?a ?b) ?x ?y) (MAdd (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MSub ?a ?b) ?x ?y) (MSub (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MMul ?a ?b) ?x ?y) (MMul (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MDiv ?a ?b) ?x ?y) (MDiv (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MCeilDiv ?a ?b) ?x ?y) (MCeilDiv (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MMod ?a ?b) ?x ?y) (MMod (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MMin ?a ?b) ?x ?y) (MMin (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MMax ?a ?b) ?x ?y) (MMax (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MFloorTo ?a ?b) ?x ?y) (MFloorTo (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) -(rewrite (MReplace (MNum ?n) ?x ?y) (MNum ?n) :ruleset expr) -(rewrite (MReplace (MAccum ?acc) ?x ?y) (MAccum ?acc) :ruleset expr) -(rewrite (MReplace (MVar ?z) ?find ?replace) (MVar ?z) :when ((!= ?find (MVar ?z))) :ruleset expr) -(rewrite (MReplace (MIter) ?find ?replace) (MIter) :when ((!= ?find (MIter))) :ruleset expr) - -; EList helper functions -(function len (EList) i64 :merge new) -(rule ((= ?e (ENil))) ((set (len ?e) 0)) :ruleset expr) -(rule ((= ?e (ECons ?expr ?list)) (= ?prev_len (len ?list))) ((set (len ?e) (+ ?prev_len 1))) :ruleset expr) - -(function nth_from_end (EList i64) Expression :merge 
new) -(rule ((= ?e (ECons ?expr ?list)) (= ?list_len (len ?list))) ((set (nth_from_end ?e ?list_len) ?expr)) :ruleset expr) -(rule ((= ?e (ECons ?expr ?list)) (= ?other_nth (nth_from_end ?list ?n))) ((set (nth_from_end ?e ?n) ?other_nth)) :ruleset expr) - -(function n_elements (EList) Expression :merge new) -(rule ((= ?e (ENil))) ((set (n_elements ?e) (MNum 1))) :ruleset expr) -(rule - ( - (= ?e (ECons ?dim ?other)) - (= ?other_elems (n_elements ?other)) - ) - ((set (n_elements ?e) (MMul ?dim ?other_elems))) - :ruleset expr -) - -(rule - ( - (= ?other (ECons ?other_dim ?other_other)) - (= ?list (ECons ?d ?other)) - (= ?e (RowMajor ?list)) - (= ?n_elems (n_elements ?other)) - ) - ( - (union ?e (ECons ?n_elems (RowMajor ?other))) - ) - :ruleset expr -) -(rewrite (RowMajor (ECons ?dim (ENil))) (ECons (MNum 1) (ENil)) :ruleset expr) - -(rewrite (MReplaceList (ECons ?expr ?list) ?from ?to) (ECons (MReplace ?expr ?from ?to) (MReplaceList ?list ?from ?to)) :ruleset expr) -(rule - ( - (= ?e (ReplaceNthFromEnd (ECons ?expr ?list) ?to ?ind)) - (= ?ind (len ?list)) - ) - ( - (union ?e (ECons ?to ?list)) - ) - :ruleset expr -) -(rule - ( - (= ?e (ReplaceNthFromEnd (ECons ?expr ?list) ?to ?ind)) - (< ?ind (len ?list)) - ) - ( - (union ?e (ECons ?expr (ReplaceNthFromEnd ?list ?to ?ind))) - ) - :ruleset expr -) -(rule - ( - (= ?e (RemoveNthFromEnd (ECons ?expr ?list) ?ind)) - (= ?ind (len ?list)) - ) - ( - (union ?e ?list) - ) - :ruleset expr -) -(rule - ( - (= ?e (RemoveNthFromEnd (ECons ?expr ?list) ?ind)) - (< ?ind (len ?list)) - ) - ( - (union ?e (ECons ?expr (RemoveNthFromEnd ?list ?ind))) - ) - :ruleset expr -) - -(datatype IR - (OutputJoin IR IR) - {ops} -) - -(function dtype (IR) DType :merge new) - -{rewrites} - -; CLEANUP -(ruleset base_cleanup) -(ruleset cleanup) -(rule - ((= ?m (MReplace ?a ?b ?c))) - ((delete (MReplace ?a ?b ?c))) - :ruleset base_cleanup -) -(rule - ((= ?m (MReplaceList ?a ?b ?c))) - ((delete (MReplaceList ?a ?b ?c))) - :ruleset base_cleanup -) 
-(rule - ((= ?m (ReplaceNthFromEnd ?a ?b ?c))) - ((delete (ReplaceNthFromEnd ?a ?b ?c))) - :ruleset base_cleanup -) -(rule - ((= ?m (RemoveNthFromEnd ?a ?b))) - ((delete (RemoveNthFromEnd ?a ?b))) - :ruleset base_cleanup -) -(rule - ((= ?m (len ?x))) - ((delete (len ?x))) - :ruleset base_cleanup -) -(rule - ((= ?m (nth_from_end ?x ?y))) - ((delete (nth_from_end ?x ?y))) - :ruleset base_cleanup -) -(rule - ((= ?m (n_elements ?x))) - ((delete (n_elements ?x))) - :ruleset base_cleanup -) -(rule - ((= ?m (RowMajor ?x))) - ((delete (RowMajor ?x))) - :ruleset base_cleanup -) -{cleanups} - -{program} - -(run-schedule - (repeat 10 - (saturate expr) - (run) - ) - (saturate expr) - (saturate base_cleanup) - (saturate cleanup) -) diff --git a/src/egglog_utils/mod.rs b/src/egglog_utils/mod.rs index 4a2dfaf7..57c11401 100644 --- a/src/egglog_utils/mod.rs +++ b/src/egglog_utils/mod.rs @@ -3,14 +3,14 @@ use itertools::Itertools; use std::{str, sync::Arc}; pub const BASE: &str = include_str!("base.egg"); -pub const RUN_SCHEDULE: &str = include_str!("run_schedule.egg"); pub const BASE_CLEANUP: &str = include_str!("base_cleanup.egg"); -pub const EGGLOG_TEMPLATE: &str = include_str!("egglog_template.egg"); +pub const RUN_SCHEDULE: &str = include_str!("run_schedule.egg"); -pub fn op_defs_string(ops: &[Arc>]) -> String { +fn op_defs_string(ops: &[Arc>]) -> String { format!( " (datatype IR + (OutputJoin IR IR) {} ) (function dtype (IR) DType :merge new) @@ -28,14 +28,13 @@ pub fn op_defs_string(ops: &[Arc>]) -> String { ) } -pub fn op_rewrites_string(ops: &[Arc>]) -> String { +fn op_rewrites_string(ops: &[Arc>]) -> String { ops.iter().flat_map(|o| o.rewrites()).join("\n") } -pub fn op_cleanups_string(ops: &[Arc>]) -> String { +fn op_cleanups_string(ops: &[Arc>]) -> String { format!( " - (ruleset cleanup) {} ", ops.iter() @@ -54,3 +53,21 @@ pub fn op_cleanups_string(ops: &[Arc>]) -> String { .join("\n") ) } + +pub fn full_egglog(program: &str, ops: &[Arc>], cleanup: bool) -> String { + 
let mut code = BASE.to_string(); + code.push_str(&op_defs_string(ops)); + code.push('\n'); + code.push_str(&op_rewrites_string(ops)); + code.push('\n'); + if cleanup { + code.push_str(&op_cleanups_string(ops)); + code.push('\n'); + } + code.push_str(BASE_CLEANUP); + code.push('\n'); + code.push_str(program); + code.push('\n'); + code.push_str(RUN_SCHEDULE); + code +} diff --git a/src/graph.rs b/src/graph.rs index ae190a48..d418d72c 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -379,41 +379,7 @@ fn run_egglog( cleanup: bool, ) -> Result { let mut egraph = egglog::EGraph::default(); - let mut code = egglog_utils::EGGLOG_TEMPLATE.replace("{program}", program); - code = code.replace( - "{ops}", - &ops.iter() - .map(|o| { - let (name, body) = o.term(); - format!( - "({name} {})", - body.into_iter().map(|j| format!("{j:?}")).join(" ") - ) - }) - .join("\n"), - ); - code = code.replace( - "{rewrites}", - &ops.iter().map(|o| o.rewrites().join("\n")).join("\n"), - ); - code = code.replace( - "{cleanups}", - &ops.iter() - .filter(|op| op.cleanup() && cleanup) - .map(|o| { - let (name, body) = o.term(); - let body_terms = (0..body.len()).map(|i| (b'a' + i as u8) as char).join(" "); - format!( - "(rule - ((= ?m ({name} {body_terms}))) - ((delete ({name} {body_terms}))) - :ruleset cleanup - )" - ) - }) - .join("\n"), - ); - + let code = egglog_utils::full_egglog(program, ops, cleanup); let commands = egraph.parser.get_program_from_string(None, &code)?; let start = std::time::Instant::now(); let msgs = egraph.run_program(commands)?; diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index ba8b4043..cc92d936 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -1,25 +1,28 @@ -use egg::*; use generational_box::{AnyStorage, GenerationalBox, Owner, SyncStorage}; +use lru::LruCache; use rustc_hash::FxHashMap; use serde::{Serialize, Serializer}; use std::{ fmt::Debug, hash::Hash, + num::NonZeroUsize, ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, 
BitOrAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, }, - sync::OnceLock, + sync::{Mutex, OnceLock}, }; -use symbolic_expressions::Sexp; + +use crate::{egglog_utils, graph::extract_expr, serialized_egraph::SerializedEGraph}; +use egglog::{prelude::RustSpan, var}; +use egglog_ast::span::Span; type ExprBox = GenerationalBox, SyncStorage>; static EXPR_OWNER: OnceLock> = OnceLock::new(); +static SIMPLIFY_CACHE: OnceLock>> = OnceLock::new(); -pub fn expression_owner() -> &'static Owner { - EXPR_OWNER.get_or_init(SyncStorage::owner) -} +const MAX_CACHED_SIMPLIFICATIONS: usize = 100_000; #[derive(Copy, Clone)] pub struct Expression { @@ -39,14 +42,10 @@ impl Serialize for Expression { impl Expression { pub fn new(terms: Vec) -> Self { Self { - terms: expression_owner().insert(terms), + terms: EXPR_OWNER.get_or_init(SyncStorage::owner).insert(terms), } } - pub fn is_acc(&self) -> bool { - self.terms.read().iter().any(|i| matches!(i, Term::Acc(_))) - } - pub fn is_dynamic(&self) -> bool { self.terms.read().iter().any(|i| { if let Term::Var(v) = i { @@ -101,7 +100,6 @@ pub enum Term { Or, Gte, Lt, - Acc(char), } impl std::fmt::Debug for Term { @@ -121,7 +119,6 @@ impl std::fmt::Debug for Term { Term::Or => write!(f, "||"), Term::Gte => write!(f, ">="), Term::Lt => write!(f, "<"), - Term::Acc(s) => write!(f, "{s}"), } } } @@ -190,6 +187,7 @@ where for<'a> &'a T: Into, { fn eq(&self, other: &T) -> bool { + // Equals-approximation. 
For proper equality checking, use .egglog_equal (more expensive) *self.terms.read() == *other.into().terms.read() } } @@ -209,7 +207,6 @@ impl Debug for Expression { let new_symbol = match term { Term::Num(n) => n.to_string(), Term::Var(c) => c.to_string(), - Term::Acc(c) => format!("Acc({c})"), Term::Max => format!( "max({}, {})", symbols.pop().unwrap(), @@ -232,6 +229,12 @@ impl Debug for Expression { } } +impl std::fmt::Display for Expression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl Expression { pub fn to_egglog(&self) -> String { let mut symbols = vec![]; @@ -239,7 +242,6 @@ impl Expression { let new_symbol = match term { Term::Num(n) => format!("(MNum {n})"), Term::Var(c) => format!("(MVar \"{c}\")"), - Term::Acc(s) => format!("(MAccum \"{s}\")"), Term::Max => format!( "(MMax {} {})", symbols.pop().unwrap(), @@ -268,7 +270,6 @@ impl Expression { let new_symbol = match term { Term::Num(n) => n.to_string(), Term::Var(c) => format!("{}const_{c}", if *c == 'z' { "" } else { "*" }), - Term::Acc(_) => unreachable!(), Term::Max => format!( "max((int){}, (int){})", symbols.pop().unwrap(), @@ -305,35 +306,14 @@ impl Expression { } symbols.pop().unwrap_or_default() } -} - -impl std::fmt::Display for Expression { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl Expression { /// Simplify the expression to its minimal terms + #[tracing::instrument(skip_all)] pub fn simplify(self) -> Self { if self.terms.read().len() == 1 { return self; } - egg_simplify(self, false) - } - - /// Simplify the expression to its minimal terms, using a cache to retrieve / store the simplification - #[allow(clippy::mutable_key_type)] - pub fn simplify_cache(self, cache: &mut FxHashMap) -> Self { - if let Some(s) = cache.get(&self) { - *s - } else { - let simplified = self.simplify(); - cache.insert(self, simplified); - simplified - } + egglog_simplify(self) } - pub 
fn as_num(&self) -> Option { if let Term::Num(n) = self.terms.read()[0] { if self.terms.read().len() == 1 { @@ -342,15 +322,12 @@ impl Expression { } None } - pub fn len(&self) -> usize { self.terms.read().len() } - pub fn is_empty(&self) -> bool { self.len() == 0 } - /// Minimum pub fn min(self, rhs: impl Into) -> Self { let rhs = rhs.into(); @@ -365,7 +342,6 @@ impl Expression { terms.push(Term::Min); Expression::new(terms) } - /// Maximum pub fn max>(self, rhs: E) -> Self { let rhs = rhs.into(); @@ -383,7 +359,6 @@ impl Expression { terms.push(Term::Max); Expression::new(terms) } - /// Greater than or equals pub fn gte>(self, rhs: E) -> Self { let rhs = rhs.into(); @@ -401,7 +376,6 @@ impl Expression { terms.push(Term::Gte); Expression::new(terms) } - /// Ceil Division pub fn ceil_div>(self, rhs: E) -> Self { let rhs = rhs.into(); @@ -410,7 +384,6 @@ impl Expression { terms.push(Term::CeilDiv); Expression::new(terms) } - /// Less than pub fn lt>(self, rhs: E) -> Self { let rhs = rhs.into(); @@ -432,7 +405,6 @@ impl Expression { terms.push(Term::Lt); Expression::new(terms) } - /// Substitute an expression for a variable pub fn substitute(self, var: char, expr: impl Into) -> Self { let mut new_terms = vec![]; @@ -451,9 +423,6 @@ impl Expression { } Expression::new(new_terms) } -} - -impl Expression { /// Evaluate the expression with no variables. Returns Some(value) if no variables are required, otherwise returns None. 
pub fn to_usize(&self) -> Option { self.exec(&FxHashMap::default()) @@ -468,7 +437,6 @@ impl Expression { for term in self.terms.read().iter() { match term { Term::Num(n) => stack.push(*n as i64), - Term::Acc(_) => stack.push(1), Term::Var(_) => stack.push(value as i64), _ => { let a = stack.pop().unwrap(); @@ -492,7 +460,6 @@ impl Expression { for term in self.terms.read().iter() { match term { Term::Num(n) => stack.push(*n as i64), - Term::Acc(_) => stack.push(1), Term::Var(c) => { #[allow(clippy::needless_borrow)] @@ -511,38 +478,6 @@ impl Expression { } stack.pop().map(|i| i as usize) } - /// Evaluate the expression given variables. - pub fn exec_float(&self, variables: &FxHashMap) -> Option { - self.exec_stack_float(variables, &mut Vec::new()) - } - /// Evaluate the expression given variables. This function requires a stack to be given for use as storage - pub fn exec_stack_float( - &self, - variables: &FxHashMap, - stack: &mut Vec, - ) -> Option { - for term in self.terms.read().iter() { - match term { - Term::Num(n) => stack.push(*n as f64), - Term::Acc(_) => stack.push(1.0), - Term::Var(c) => - { - #[allow(clippy::needless_borrow)] - if let Some(n) = variables.get(&c) { - stack.push(*n as f64) - } else { - return None; - } - } - _ => { - let a = stack.pop().unwrap(); - let b = stack.pop().unwrap(); - stack.push(term.as_float_op().unwrap()(a, b)); - } - } - } - stack.pop() - } /// Retrieve all symbols in the expression. pub fn to_symbols(&self) -> Vec { self.terms @@ -554,15 +489,7 @@ impl Expression { }) .collect() } - - /// Check if the '-' variable exists in the expression. 
- pub fn is_unknown(&self) -> bool { - self.terms - .read() - .iter() - .any(|t| matches!(t, Term::Var('-'))) - } - + /// Resolve all known variables from dyn map into real values pub fn resolve_vars(&mut self, dyn_map: &FxHashMap) { for term in self.terms.write().iter_mut() { if let Term::Var(v) = *term @@ -572,6 +499,45 @@ impl Expression { } } } + /// Run proper equality check inside egglog + #[tracing::instrument(skip_all)] + pub fn egglog_equal(self, rhs: impl Into) -> bool { + let lhs_expr = self.to_egglog(); + let rhs_expr = rhs.into().to_egglog(); + let mut program = String::new(); + program.push_str(egglog_utils::BASE); + program.push('\n'); + program.push_str(egglog_utils::BASE_CLEANUP); + program.push('\n'); + program.push_str(&format!("(let expr_lhs {lhs_expr})\n")); + program.push_str(&format!("(let expr_rhs {rhs_expr})\n")); + program.push_str( + "(run-schedule + (saturate expr) + (saturate base_cleanup) + (saturate cleanup) + )", + ); + program.push('\n'); + program.push_str("(check (= expr_lhs expr_rhs))\n"); + + let mut egraph = egglog::EGraph::default(); + let commands = egraph + .parser + .get_program_from_string(None, &program) + .expect("failed to parse egglog program"); + let span = tracing::span!(tracing::Level::INFO, "to_egglog"); + let _entered = span.enter(); + match egraph.run_program(commands) { + Ok(_) => true, + Err(err) => { + if matches!(err, egglog::Error::CheckError(_, _)) { + return false; + } + panic!("failed to run egglog program: {err}"); + } + } + } } impl From for Expression { @@ -951,296 +917,47 @@ impl> BitOrAssign for Expression { } } -define_language! 
{ - enum SimpleLanguage { - Num(i32), - "+" = Add([Id; 2]), - "*" = Mul([Id; 2]), - Symbol(Symbol), - } -} - -fn luminal_to_egg(expr: &Expression) -> RecExpr { - let mut stack = Vec::new(); - - for term in expr.terms.read().iter() { - match term { - Term::Num(_) | Term::Var(_) => { - stack.push(symbolic_expressions::Sexp::String(format!("{term:?}"))) - } - Term::Acc(_) => stack.push(symbolic_expressions::Sexp::String("1".to_string())), - _ => { - let left = stack.pop().unwrap(); - let right = stack.pop().unwrap(); - let subexpr = symbolic_expressions::Sexp::List(vec![ - symbolic_expressions::Sexp::String(format!("{term:?}")), - left, - right, - ]); - stack.push(subexpr); - } - } - } - fn parse_sexp_into( - sexp: &Sexp, - expr: &mut RecExpr, - ) -> Result> { - match sexp { - Sexp::Empty => Err(egg::RecExprParseError::EmptySexp), - Sexp::String(s) => { - let node = L::from_op(s, vec![]).map_err(egg::RecExprParseError::BadOp)?; - Ok(expr.add(node)) - } - Sexp::List(list) if list.is_empty() => Err(egg::RecExprParseError::EmptySexp), - Sexp::List(list) => match &list[0] { - Sexp::Empty => unreachable!("Cannot be in head position"), - list @ Sexp::List(..) => Err(egg::RecExprParseError::HeadList(list.to_owned())), - Sexp::String(op) => { - let arg_ids: Vec = list[1..] 
- .iter() - .map(|s| parse_sexp_into(s, expr)) - .collect::>()?; - let node = L::from_op(op, arg_ids).map_err(egg::RecExprParseError::BadOp)?; - Ok(expr.add(node)) - } - }, - } - } - - let sexp = stack.pop().unwrap(); - let mut expr = RecExpr::default(); - parse_sexp_into(&sexp, &mut expr).unwrap(); - expr -} - -fn egg_to_luminal(expr: RecExpr) -> Expression { - fn create_postfix(expr: &[Math]) -> Vec { - match expr.last().unwrap() { - Math::Num(i) => vec![Term::Num(*i)], - Math::Add([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Add], - ] - .concat(), - Math::Sub([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Sub], - ] - .concat(), - Math::Mul([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Mul], - ] - .concat(), - Math::Div([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Div], - ] - .concat(), - Math::Mod([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Mod], - ] - .concat(), - Math::Min([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Min], - ] - .concat(), - Math::Max([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Max], - ] - .concat(), - Math::And([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::And], - ] - .concat(), - Math::Or([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Or], - ] - .concat(), - Math::LessThan([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - 
create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Lt], - ] - .concat(), - Math::GreaterThanEqual([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), - vec![Term::Gte], - ] - .concat(), - Math::Symbol(s) => vec![Term::Var(s.as_str().chars().next().unwrap())], - } - } - let mut terms = vec![]; - terms.extend(create_postfix(expr.as_ref())); - Expression::new(terms) -} - -type EGraph = egg::EGraph; -type Rewrite = egg::Rewrite; - -define_language! { - enum Math { - Num(i32), - "+" = Add([Id; 2]), - "-" = Sub([Id; 2]), - "*" = Mul([Id; 2]), - "/" = Div([Id; 2]), - "%" = Mod([Id; 2]), - "min" = Min([Id; 2]), - "max" = Max([Id; 2]), - "&&" = And([Id; 2]), - "||" = Or([Id; 2]), - "<" = LessThan([Id; 2]), - ">=" = GreaterThanEqual([Id; 2]), - Symbol(Symbol), - } -} - -#[derive(Default)] -pub struct ConstantFold; -impl Analysis for ConstantFold { - type Data = Option; - - fn make(egraph: &egg::EGraph, enode: &Math) -> Self::Data { - let x = |i: &Id| egraph[*i].data.as_ref().map(|d| *d); - Some(match enode { - Math::Num(c) => *c, - Math::Add([a, b]) => x(a)?.checked_add(x(b)?)?, - Math::Sub([a, b]) => x(a)?.checked_sub(x(b)?)?, - Math::Mul([a, b]) => x(a)?.checked_mul(x(b)?)?, - Math::Div([a, b]) if x(b) != Some(0) => { - let (a, b) = (x(a)?, x(b)?); - if a % b != 0 { - return None; - } else { - a.checked_div(b)? - } - } - Math::Mod([a, b]) => x(a)?.checked_rem(x(b)?)?, - Math::Min([a, b]) => x(a)?.min(x(b)?), - Math::Max([a, b]) => x(a)?.max(x(b)?), - Math::And([a, b]) => (x(a)? != 0 && x(b)? != 0) as i32, - Math::Or([a, b]) => (x(a)? != 0 || x(b)? != 0) as i32, - Math::LessThan([a, b]) => (x(a)? < x(b)?) as i32, - Math::GreaterThanEqual([a, b]) => (x(a)? >= x(b)?) 
as i32, - _ => return None, - }) - } - - fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { - merge_option(to, from, |a, b| { - assert_eq!(*a, b, "Merged non-equal constants"); - DidMerge(false, false) - }) - } - - fn modify(egraph: &mut EGraph, id: Id) { - let data = egraph[id].data; - if let Some(c) = data { - let added = egraph.add(Math::Num(c)); - egraph.union(id, added); - egraph[id].nodes.retain(|n| n.is_leaf()); - - #[cfg(debug_assertions)] - egraph[id].assert_unique_leaves(); - } - } -} - -fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { - let var = var.parse().unwrap(); - move |egraph, _, subst| egraph[subst[var]].data.map(|i| i != 0).unwrap_or(true) -} - -fn is_const_positive(vars: &[&str]) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { - let vars: Vec = vars.iter().map(|i| i.parse().unwrap()).collect::>(); - move |egraph, _, subst| { - vars.iter() - .all(|i| egraph[subst[*i]].data.map(|i| i >= 0).unwrap_or(false)) - } -} - -fn make_rules(lower_bound_zero: bool) -> Vec { - let mut v = vec![ - // Communative properties - rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), - rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), - rewrite!("commute-min"; "(min ?a ?b)" => "(min ?b ?a)"), - rewrite!("commute-max"; "(max ?a ?b)" => "(max ?b ?a)"), - rewrite!("commute-and"; "(&& ?a ?b)" => "(&& ?b ?a)"), - rewrite!("commute-or"; "(|| ?a ?b)" => "(|| ?b ?a)"), - // Associative properties - rewrite!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), - rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), - rewrite!("assoc-div"; "(/ (/ ?a ?b) ?c)" => "(/ ?a (* ?b ?c))"), - rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), - // Distributive - rewrite!("distribute-mul"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), - rewrite!("distribute-div"; "(/ (+ ?a ?b) ?c)" => "(+ (/ ?a ?c) (/ ?b ?c))"), - rewrite!("distribute-max"; "(* ?a (max ?b ?c))" => "(max (* ?a ?b) (* ?a ?c))" if 
is_const_positive(&["?a"])), - rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"), - // Factoring - rewrite!("factor-mul" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), - rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)"), - // Other - rewrite!("div-move-inside"; "(+ (/ ?a ?b) ?c)" => "(/ (+ ?a (* ?c ?b)) ?b)"), - // Simple binary reductions - rewrite!("add-0"; "(+ ?a 0)" => "?a"), - rewrite!("mul-0"; "(* ?a 0)" => "0"), - rewrite!("mul-1"; "(* ?a 1)" => "?a"), - rewrite!("div-1"; "(/ ?a 1)" => "?a"), - rewrite!("div-self"; "(/ ?a ?a)" => "1"), - rewrite!("and-0"; "(&& ?a 0)" => "0"), - rewrite!("and-1"; "(&& ?a 1)" => "?a"), - rewrite!("or-0"; "(|| ?a 0)" => "?a"), - rewrite!("or-1"; "(|| ?a 1)" => "1"), - rewrite!("min-i32-max"; "(min ?a 2147483647)" => "?a"), - rewrite!("max-i32-max"; "(max ?a 2147483647)" => "2147483647"), - rewrite!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1" if is_not_zero("?x")), - rewrite!("add-zero"; "?a" => "(+ ?a 0)"), - rewrite!("mul-one"; "?a" => "(* ?a 1)"), - rewrite!("cancel-sub"; "(- ?a ?a)" => "0"), - rewrite!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")), - rewrite!("dedup-max"; "(max ?a (max ?a ?b))" => "(max ?a ?b)"), - rewrite!("dedup-min"; "(min ?a (min ?a ?b))" => "(min ?a ?b)"), - ]; - if lower_bound_zero { - v.push(rewrite!("max-zero"; "(max ?a 0)" => "?a")); - } - v -} - -fn egg_simplify(e: Expression, lower_bound_zero: bool) -> Expression { - // Convert to egg expression - let expr = luminal_to_egg(&e); - // Simplify - let runner = Runner::default() - // .with_iter_limit(1_000) - // .with_time_limit(std::time::Duration::from_secs(30)) - // .with_node_limit(100_000_000) - .with_expr(&expr) - .run(&make_rules(lower_bound_zero)); - // runner.print_report(); - let extractor = Extractor::new(&runner.egraph, AstSize); - let (_, best) = extractor.find_best(runner.roots[0]); - // Convert back to luminal expression - egg_to_luminal(best) +#[tracing::instrument(skip_all)] +fn 
egglog_simplify(e: Expression) -> Expression { + let cache = SIMPLIFY_CACHE.get_or_init(|| { + Mutex::new(LruCache::::new( + NonZeroUsize::new(MAX_CACHED_SIMPLIFICATIONS).unwrap(), + )) + }); + + if let Some(out) = cache.lock().unwrap().get(&e).copied() { + return out; + } + let expr = e.to_egglog(); + let mut program = String::new(); + program.push_str(egglog_utils::BASE); + program.push('\n'); + program.push_str(egglog_utils::BASE_CLEANUP); + program.push('\n'); + program.push_str(&format!("(let expr_root {expr})\n")); + program.push_str( + "(run-schedule + (saturate expr) + (saturate base_cleanup) + (saturate cleanup) + )", + ); + let mut egraph = egglog::EGraph::default(); + let commands = egraph + .parser + .get_program_from_string(None, &program) + .unwrap(); + egraph.run_program(commands).unwrap(); + let (sort, value) = egraph.eval_expr(&var!("expr_root")).unwrap(); + let serialized = SerializedEGraph::new(&egraph, vec![(sort, value)]); + let simplified = extract_expr( + &serialized, + &serialized.eclasses[serialized.roots.first().unwrap()].1[0], + &mut FxHashMap::default(), + ) + .unwrap_or(e); + cache.lock().unwrap().push(e, simplified); + simplified } #[cfg(test)] @@ -1286,6 +1003,14 @@ mod tests { assert_eq!(s, (w + 7) / 4); } + #[test] + fn test_egglog_equality() { + let a = Expression::from('a'); + let b = Expression::from('b'); + assert!((a + (b - a)).egglog_equal(b)); + assert!(!(a + 1).egglog_equal(a + 2)); + } + #[test] fn test_other() { let z = Expression::from('z'); @@ -1296,7 +1021,7 @@ mod tests { * (-5 + (((9 + (4 * (-5 + ((((((153 + h) / 2) / 2) / 2) / 2) / 2)))) / 2) / 2)))) % 64; let x = o.simplify(); - assert_eq!(x.len(), 23); // Should be 21 if we can re-enable mul-div-associative-rev + assert!(x.len() <= 27); // Should be 21 if we can re-enable mul-div-associative-rev } #[test]