diff --git a/Cargo.toml b/Cargo.toml index 19b91bc9..d01a0a23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "luminal" version = "0.2.0" -edition = "2024" +edition.workspace = true rust-version = "1.85" description = "Deep learning at the speed of light." license = "MIT OR Apache-2.0" @@ -22,18 +22,23 @@ regex = "1.9.5" rustc-hash = "2.1.1" uuid = { version = "1.7.0", features = ["v4"] } as-any = "0.3.1" -egg = "0.9.5" symbolic_expressions = "5.0.3" serde = { version = "1.0.202", features = ["derive"] } thread_local = "1.1.8" generational-box = "0.5.6" serde_json = "1.0.140" +egg = "0.9.5" egglog = "1.0.0" egglog-ast = "1.0.0" -egraph-serialize = "0.3.0" +egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]} tracing = "0.1.43" paste = "1.0.15" pretty-duration = "0.1.1" +anyhow = "1.0" +graphviz-rust = { version = "0.9", default-features = false} + +[workspace.package] +edition = "2024" [dev-dependencies] candle-core = "0.9.1" @@ -43,13 +48,13 @@ ordered-float = "5.1.0" [workspace] members = [ "examples/llama", - #"examples/*", + "examples/visualization", "crates/luminal_nn", "crates/luminal_cuda", "crates/luminal_training", "docs/company", ] exclude = [ - "examples/yolo_v8", + "crates/luminal_cuda", ] diff --git a/crates/luminal_cuda/src/block/ops.rs b/crates/luminal_cuda/src/block/ops.rs index 453e6e5d..28405d44 100644 --- a/crates/luminal_cuda/src/block/ops.rs +++ b/crates/luminal_cuda/src/block/ops.rs @@ -4,8 +4,9 @@ use super::CustomState; use cudarc::driver::{CudaStream, DevicePtr}; use itertools::Itertools; use luminal::{ - graph::{extract_expr, extract_expr_list, SerializedEGraph}, + graph::{extract_expr, extract_expr_list}, prelude::ENodeId, + serialized_egraph::SerializedEGraph, shape::Expression, utils::{ flatten_mul_strides, CStructBuilder, EgglogOp, LLIROp, diff --git a/crates/luminal_cuda/src/kernel/ops.rs b/crates/luminal_cuda/src/kernel/ops.rs index c5914292..ae4ae97c 100644 --- 
a/crates/luminal_cuda/src/kernel/ops.rs +++ b/crates/luminal_cuda/src/kernel/ops.rs @@ -6,9 +6,10 @@ use cudarc::{ }; use itertools::Itertools; use luminal::{ - graph::{extract_dtype, extract_expr, extract_expr_list, SerializedEGraph}, + graph::{extract_dtype, extract_expr, extract_expr_list}, op::DType, prelude::ENodeId, + serialized_egraph::SerializedEGraph, shape::Expression, utils::{ flatten_mul_strides, EgglogOp, LLIROp, diff --git a/crates/luminal_nn/src/lib.rs b/crates/luminal_nn/src/lib.rs index 6bd79ee3..d6499f3b 100644 --- a/crates/luminal_nn/src/lib.rs +++ b/crates/luminal_nn/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(unused_imports)] + mod activation; pub use activation::*; mod convolution; diff --git a/examples/llama/Cargo.toml b/examples/llama/Cargo.toml index 84233676..8bf1809d 100644 --- a/examples/llama/Cargo.toml +++ b/examples/llama/Cargo.toml @@ -12,7 +12,6 @@ luminal_cuda = { path = "../../crates/luminal_cuda" } itertools = "0.12.1" tokenizers = "0.15.2" tracing = "0.1.43" -rustc-hash = "2.1.1" tracing-subscriber = {version="0.3", features=["env-filter"]} tracing-perfetto-sdk-layer = "0.13.0" tracing-perfetto-sdk-schema = "0.13.0" diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs index faebe16d..5e660763 100644 --- a/examples/llama/src/main.rs +++ b/examples/llama/src/main.rs @@ -4,13 +4,13 @@ use itertools::Itertools; use luminal::{ graph::{Graph, Runtime}, op::DType, + prelude::FxHashMap, }; use luminal_cuda::{ block::IntoBlockOp, runtime::{record_exec_timings_to_file, CudaRuntime, CustomState}, }; use model::*; -use rustc_hash::*; use std::{fs::File, io::Write, time::Duration}; use tokenizers::Tokenizer; use tracing::{span, Level}; diff --git a/examples/visualization/.gitignore b/examples/visualization/.gitignore new file mode 100644 index 00000000..ecaea95c --- /dev/null +++ b/examples/visualization/.gitignore @@ -0,0 +1,3 @@ +*.dot +*.html +*.svg \ No newline at end of file diff --git a/examples/visualization/Cargo.toml 
b/examples/visualization/Cargo.toml new file mode 100644 index 00000000..f59a7e74 --- /dev/null +++ b/examples/visualization/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "visualization" +version = "0.1.0" +edition = "2021" + +[features] + +[dependencies] +anyhow = "1.0" +egglog = "1.0" +egglog-ast = "1.0.0" +egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]} +itertools = "0.12.1" +luminal = { path = "../.." } +luminal_cuda = { path = "../../crates/luminal_cuda" } +luminal_nn = { path = "../../crates/luminal_nn" } +rustc-hash = "2.1" +tokenizers = "0.15.2" +tracing = "0.1.43" +tracing-appender = "0.2.4" +tracing-perfetto-sdk-layer = "0.13.0" +tracing-perfetto-sdk-schema = "0.13.0" +tracing-perfetto-sdk-sys = "0.13.0" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/visualization/README.md b/examples/visualization/README.md new file mode 100644 index 00000000..6143c88f --- /dev/null +++ b/examples/visualization/README.md @@ -0,0 +1,23 @@ +# Visualization in Luminal + +## Design Choices +Luminal produces intermediate files rather than complete visualizations. + +The two primary file types are: +- `.html` files +- `.dot` files + +These files enable interactive viewing, which is often necessary for making visualizations interpretable. + +## VSCode Extensions +We recommend the following extensions for VSCode users. +The integrated nature of these extensions makes viewing these files easy even on remote machines via ssh. + +- `Live Preview` by Microsoft. +- `Graphviz Interactive Preview` by tintinweb + +## Example Provided +In the example program, a simple program is defined. +From this, an HLIR graph is created and visualized. +A saturated EGraph is created and visualized. +Finally, an LLIR graph is extracted and visualized. 
diff --git a/examples/visualization/src/main.rs b/examples/visualization/src/main.rs new file mode 100644 index 00000000..3bf4d1f7 --- /dev/null +++ b/examples/visualization/src/main.rs @@ -0,0 +1,86 @@ +use std::fs; + +use luminal::{ + self, + graph::{hlir_to_egglog, Graph, Runtime}, + prelude::*, + serialized_egraph::SerializedEGraph, + visualization::{ToDot, ToHtml}, +}; +use luminal_cuda::runtime::{CudaRuntime, CustomState}; + +use egglog::{prelude::RustSpan, var, EGraph}; +use egglog_ast::span::Span; +use rustc_hash::FxHashMap; + +fn main() { + // Create a new graph + let mut cx = Graph::new(); + + // Create input tensor using constant values + + let (m, n, k) = (4096, 14336, 9192); + + let a = cx.tensor((m, k)); + let b = cx.tensor((k, n)); + + let _c = a.matmul(b); + + let ctx = luminal_cuda::cudarc::driver::CudaContext::new(0).unwrap(); + ctx.bind_to_thread().unwrap(); + let _stream = ctx.default_stream(); + let _custom_state: FxHashMap = FxHashMap::default(); + + println!("Visualizing HLIR"); + fs::write("HLIR.dot", cx.graph.to_dot().unwrap()).unwrap(); + + println!("Building and Saturating EGraph"); + cx.build_search_space::(); + + let (program, root) = hlir_to_egglog(&cx); + + let mut ops = ::Ops::into_vec(); + ops.extend(::into_vec()); + + let mut egglog_obj: EGraph = egglog::EGraph::default(); + + // setup the rules and datatypes + egglog_obj + .parse_and_run_program(None, luminal::egglog_utils::BASE) + .unwrap(); + egglog_obj + .parse_and_run_program(None, &luminal::egglog_utils::op_defs_string(&ops)) + .unwrap(); + egglog_obj + .parse_and_run_program(None, &luminal::egglog_utils::op_rewrites_string(&ops)) + .unwrap(); + egglog_obj + .parse_and_run_program(None, luminal::egglog_utils::BASE_CLEANUP) + .unwrap(); + egglog_obj + .parse_and_run_program(None, &luminal::egglog_utils::op_cleanups_string(&ops)) + .unwrap(); + + // load the program + egglog_obj.parse_and_run_program(None, &program).unwrap(); + + // run the graph + egglog_obj + 
.parse_and_run_program(None, luminal::egglog_utils::RUN_SCHEDULE) + .unwrap(); + + // EGraph Optimization Complete + println!("Visualizing EGraph"); + // save the egraph visualizations + fs::write("egraph.html", egglog_obj.to_html().unwrap()).unwrap(); + fs::write("egraph.dot", egglog_obj.to_dot().unwrap()).unwrap(); + + let (sort, value) = egglog_obj.eval_expr(&var!(root)).unwrap(); + let s_egraph = SerializedEGraph::new(&egglog_obj, vec![(sort, value)]); + let llir_graphs = egglog_to_llir(&s_egraph, &ops, 100); + + let example_llir_graph = llir_graphs.last().unwrap(); + + println!("Visualizing LLIR Graph"); + fs::write("LLIR.dot", example_llir_graph.to_dot().unwrap()).unwrap(); +} diff --git a/src/egglog_utils/base.egg b/src/egglog_utils/base.egg new file mode 100644 index 00000000..5af04dbb --- /dev/null +++ b/src/egglog_utils/base.egg @@ -0,0 +1,173 @@ +; -------- SYMBOLIC ALGEBRA ------- +(ruleset expr) +(datatype* + (Expression + (MNum i64) + (MFloat f64) + (MIter) + (MVar String) + (MAdd Expression Expression) + (MSub Expression Expression) + (MMul Expression Expression) + (MCeilDiv Expression Expression) + (MDiv Expression Expression) + (MMod Expression Expression) + (MMin Expression Expression) + (MMax Expression Expression) + (MAnd Expression Expression) + (MOr Expression Expression) + (MGte Expression Expression) + (MLt Expression Expression) + (MFloorTo Expression Expression) + (MReplace Expression Expression Expression) + (MAccum String) + ) + + ; eqsort list for vectors of Expression + (EList + (ECons Expression EList) + (ENil) + (MReplaceList EList Expression Expression) + (ReplaceNthFromEnd EList Expression i64) + (RemoveNthFromEnd EList i64) + (RowMajor EList) + ) + + (DType + (F32) + (F16) + (Bf16) + (Int) + ) +) + +; ---- Algebraic rewrites ---- +(rewrite (MAdd a b) (MAdd b a) :ruleset expr) +;(rewrite (MMul a b) (MMul b a) :ruleset expr) + +(rewrite (MAdd (MAdd a b) c) (MAdd a (MAdd b c)) :ruleset expr) +;(rewrite (MMul (MMul a b) c) (MMul a 
(MMul b c)) :ruleset expr) + +(rewrite (MAdd (MNum a) (MNum b)) (MNum (+ a b)) :ruleset expr) +(rewrite (MSub (MNum a) (MNum b)) (MNum (- a b)) :ruleset expr) +; multiply const folding +(rule + ( + (= ?e (MMul (MNum ?a) (MNum ?b))) + (= ?prod (* ?a ?b)) + ) + ( + (union ?e (MNum ?prod)) + (delete (MMul (MNum ?a) (MNum ?b))) + ) +) ; why does this explode??? +(rewrite (MDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) +(rewrite (MCeilDiv (MNum a) (MNum b)) (MNum (/ a b)) :when ((!= 0 b) (= 0 (% a b))) :ruleset expr) +(rewrite (MMax (MNum a) (MNum b)) (MNum (max a b)) :ruleset expr) +(rewrite (MMin (MNum a) (MNum b)) (MNum (min a b)) :ruleset expr) +(rewrite (MAnd (MNum a) (MNum b)) (MNum (& a b)) :ruleset expr) +(rewrite (MFloat -1.0) (MNum -1) :ruleset expr) +(rewrite (MNum -1) (MFloat -1.0) :ruleset expr) +(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0))) ; why does this explode??? + +(rewrite (MAdd a (MNum 0)) a :ruleset expr) +(rule ((= ?e (MMul ?a (MNum 1)))) ((union ?e ?a)) :ruleset expr) +;(rewrite (MMul ?a (MNum 0)) (MNum 0) :ruleset expr) ; Why does this cause an infinite loop? 
+(rewrite (MDiv a (MNum 1)) a :ruleset expr) +(rewrite (MMod (MMul ?x ?y) ?y) (MNum 0) :ruleset expr) +(rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?y)) + :when ((>= ?z ?y) (= 0 (% ?y ?z))) :ruleset expr) +(rewrite (MMod (MMod ?x (MNum ?y)) (MNum ?z)) (MMod ?x (MNum ?z)) + :when ((>= ?y ?z) (= 0 (% ?z ?y))) :ruleset expr) + +; ---- Replacement over expressions ---- +(rewrite (MReplace ?x ?y ?z) ?z :when ((= ?x ?y)) :ruleset expr) +(rewrite (MReplace (MAdd ?a ?b) ?x ?y) (MAdd (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MSub ?a ?b) ?x ?y) (MSub (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MMul ?a ?b) ?x ?y) (MMul (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MDiv ?a ?b) ?x ?y) (MDiv (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MCeilDiv ?a ?b) ?x ?y) (MCeilDiv (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MMod ?a ?b) ?x ?y) (MMod (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MMin ?a ?b) ?x ?y) (MMin (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MMax ?a ?b) ?x ?y) (MMax (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MFloorTo ?a ?b) ?x ?y) (MFloorTo (MReplace ?a ?x ?y) (MReplace ?b ?x ?y)) :ruleset expr) +(rewrite (MReplace (MNum ?n) ?x ?y) (MNum ?n) :ruleset expr) +(rewrite (MReplace (MAccum ?acc) ?x ?y) (MAccum ?acc) :ruleset expr) +(rewrite (MReplace (MVar ?z) ?find ?replace) (MVar ?z) :when ((!= ?find (MVar ?z))) :ruleset expr) +(rewrite (MReplace (MIter) ?find ?replace) (MIter) :when ((!= ?find (MIter))) :ruleset expr) + +; EList helper functions +(function len (EList) i64 :merge new) +(rule ((= ?e (ENil))) ((set (len ?e) 0)) :ruleset expr) +(rule ((= ?e (ECons ?expr ?list)) (= ?prev_len (len ?list))) ((set (len ?e) (+ ?prev_len 1))) :ruleset expr) + +(function nth_from_end (EList i64) Expression :merge 
new) +(rule ((= ?e (ECons ?expr ?list)) (= ?list_len (len ?list))) ((set (nth_from_end ?e ?list_len) ?expr)) :ruleset expr) +(rule ((= ?e (ECons ?expr ?list)) (= ?other_nth (nth_from_end ?list ?n))) ((set (nth_from_end ?e ?n) ?other_nth)) :ruleset expr) + +(function n_elements (EList) Expression :merge new) +(rule ((= ?e (ENil))) ((set (n_elements ?e) (MNum 1))) :ruleset expr) +(rule + ( + (= ?e (ECons ?dim ?other)) + (= ?other_elems (n_elements ?other)) + ) + ((set (n_elements ?e) (MMul ?dim ?other_elems))) + :ruleset expr +) + +(rule + ( + (= ?other (ECons ?other_dim ?other_other)) + (= ?list (ECons ?d ?other)) + (= ?e (RowMajor ?list)) + (= ?n_elems (n_elements ?other)) + ) + ( + (union ?e (ECons (MMul (MIter) ?n_elems) (RowMajor ?other))) + ) + :ruleset expr +) +(rewrite (RowMajor (ECons ?dim (ENil))) (ECons (MIter) (ENil)) :ruleset expr) + +(rewrite (MReplaceList (ECons ?expr ?list) ?from ?to) (ECons (MReplace ?expr ?from ?to) (MReplaceList ?list ?from ?to)) :ruleset expr) +(rule + ( + (= ?e (ReplaceNthFromEnd (ECons ?expr ?list) ?to ?ind)) + (= ?ind (len ?list)) + ) + ( + (union ?e (ECons ?to ?list)) + ) + :ruleset expr +) +(rule + ( + (= ?e (ReplaceNthFromEnd (ECons ?expr ?list) ?to ?ind)) + (< ?ind (len ?list)) + ) + ( + (union ?e (ECons ?expr (ReplaceNthFromEnd ?list ?to ?ind))) + ) + :ruleset expr +) +(rule + ( + (= ?e (RemoveNthFromEnd (ECons ?expr ?list) ?ind)) + (= ?ind (len ?list)) + ) + ( + (union ?e ?list) + ) + :ruleset expr +) +(rule + ( + (= ?e (RemoveNthFromEnd (ECons ?expr ?list) ?ind)) + (< ?ind (len ?list)) + ) + ( + (union ?e (ECons ?expr (RemoveNthFromEnd ?list ?ind))) + ) + :ruleset expr +) \ No newline at end of file diff --git a/src/egglog_utils/base_cleanup.egg b/src/egglog_utils/base_cleanup.egg new file mode 100644 index 00000000..08a44800 --- /dev/null +++ b/src/egglog_utils/base_cleanup.egg @@ -0,0 +1,41 @@ +(ruleset base_cleanup) +(rule + ((= ?m (MReplace ?a ?b ?c))) + ((delete (MReplace ?a ?b ?c))) + :ruleset base_cleanup +) 
+(rule + ((= ?m (MReplaceList ?a ?b ?c))) + ((delete (MReplaceList ?a ?b ?c))) + :ruleset base_cleanup +) +(rule + ((= ?m (ReplaceNthFromEnd ?a ?b ?c))) + ((delete (ReplaceNthFromEnd ?a ?b ?c))) + :ruleset base_cleanup +) +(rule + ((= ?m (RemoveNthFromEnd ?a ?b))) + ((delete (RemoveNthFromEnd ?a ?b))) + :ruleset base_cleanup +) +(rule + ((= ?m (len ?x))) + ((delete (len ?x))) + :ruleset base_cleanup +) +(rule + ((= ?m (nth_from_end ?x ?y))) + ((delete (nth_from_end ?x ?y))) + :ruleset base_cleanup +) +(rule + ((= ?m (n_elements ?x))) + ((delete (n_elements ?x))) + :ruleset base_cleanup +) +(rule + ((= ?m (RowMajor ?x))) + ((delete (RowMajor ?x))) + :ruleset base_cleanup +) \ No newline at end of file diff --git a/src/egglog.egg b/src/egglog_utils/egglog_template.egg similarity index 100% rename from src/egglog.egg rename to src/egglog_utils/egglog_template.egg diff --git a/src/egglog_utils/mod.rs b/src/egglog_utils/mod.rs new file mode 100644 index 00000000..4a2dfaf7 --- /dev/null +++ b/src/egglog_utils/mod.rs @@ -0,0 +1,56 @@ +use crate::utils::EgglogOp; +use itertools::Itertools; +use std::{str, sync::Arc}; + +pub const BASE: &str = include_str!("base.egg"); +pub const RUN_SCHEDULE: &str = include_str!("run_schedule.egg"); +pub const BASE_CLEANUP: &str = include_str!("base_cleanup.egg"); +pub const EGGLOG_TEMPLATE: &str = include_str!("egglog_template.egg"); + +pub fn op_defs_string(ops: &[Arc>]) -> String { + format!( + " + (datatype IR + {} + ) + (function dtype (IR) DType :merge new) + ", + ops.iter() + .map(|o| { + let (name, body) = o.term(); + format!( + "({name} {})", + body.into_iter().map(|j| format!("{j:?}")).join(" ") + ) + }) + .collect::>() + .join("\n") + ) +} + +pub fn op_rewrites_string(ops: &[Arc>]) -> String { + ops.iter().flat_map(|o| o.rewrites()).join("\n") +} + +pub fn op_cleanups_string(ops: &[Arc>]) -> String { + format!( + " + (ruleset cleanup) + {} + ", + ops.iter() + .filter(|op| op.cleanup()) + .map(|o| { + let (name, body) = o.term(); 
+ let body_terms = (0..body.len()).map(|i| (b'a' + i as u8) as char).join(" "); + format!( + "(rule + ((= ?m ({name} {body_terms}))) + ((delete ({name} {body_terms}))) + :ruleset cleanup + )" + ) + }) + .join("\n") + ) +} diff --git a/src/egglog_utils/run_schedule.egg b/src/egglog_utils/run_schedule.egg new file mode 100644 index 00000000..8bc09d4d --- /dev/null +++ b/src/egglog_utils/run_schedule.egg @@ -0,0 +1,9 @@ +(run-schedule + (repeat 10 + (saturate expr) + (run) + ) + (saturate expr) + (saturate base_cleanup) + (saturate cleanup) +) \ No newline at end of file diff --git a/src/graph.rs b/src/graph.rs index 186cd407..7c310209 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,5 +1,7 @@ use crate::{ + egglog_utils, prelude::*, + serialized_egraph::SerializedEGraph, utils::{EgglogOp, IntoEgglogOp, LLIROp}, }; use std::{ @@ -280,7 +282,7 @@ impl NewOp<'_> { } } -fn hlir_to_egglog(graph: &Graph) -> (String, String) { +pub fn hlir_to_egglog(graph: &Graph) -> (String, String) { use std::cmp::Reverse; use std::collections::{BinaryHeap, HashMap}; @@ -369,14 +371,6 @@ pub fn elist_to_egglog(shape: &[Expression]) -> String { } } -#[derive(Debug)] -pub struct SerializedEGraph { - pub enodes: FxHashMap)>, - pub eclasses: FxHashMap)>, - pub node_to_class: FxHashMap, - pub roots: Vec, -} - #[tracing::instrument(skip_all)] fn run_egglog( program: &str, @@ -385,7 +379,7 @@ fn run_egglog( cleanup: bool, ) -> Result { let mut egraph = egglog::EGraph::default(); - let mut code = include_str!("egglog.egg").replace("{program}", program); + let mut code = egglog_utils::EGGLOG_TEMPLATE.replace("{program}", program); code = code.replace( "{ops}", &ops.iter() diff --git a/src/lib.rs b/src/lib.rs index 63771e5d..45b88114 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,12 @@ +pub mod egglog_utils; pub mod graph; pub mod graph_tensor; pub mod hl_ops; pub mod op; +pub mod serialized_egraph; pub mod shape; pub mod utils; +pub mod visualization; #[cfg(test)] pub mod tests; @@ -20,6 
+23,7 @@ pub mod prelude { pub use half::{bf16, f16}; pub use petgraph; pub use petgraph::stable_graph::NodeIndex; + pub use rustc_hash::{FxHashMap, FxHashSet}; pub use tinyvec; } diff --git a/src/op.rs b/src/op.rs index bb0f364f..66fcc469 100644 --- a/src/op.rs +++ b/src/op.rs @@ -5,6 +5,7 @@ use std::{ use crate::{ prelude::*, + serialized_egraph::SerializedEGraph, utils::{ EgglogOp, LLIROp, OpParam::{self, *}, diff --git a/src/serialized_egraph.rs b/src/serialized_egraph.rs new file mode 100644 index 00000000..7fe2a34a --- /dev/null +++ b/src/serialized_egraph.rs @@ -0,0 +1,98 @@ +use crate::prelude::FxHashMap; +use egglog::{ArcSort, EGraph, Value}; +use egraph_serialize::{ClassId, NodeId}; + +#[derive(Debug)] +/// This is a snapshot of an EGraph with Rust native hash maps and sets for enabling more native traversal / algorithm writing. +/// The name comes from the egraph-serialize crate, which returns an ETermDAG; that caused issues, so this is a homebrew semi-static egraph. +pub struct SerializedEGraph { + pub enodes: FxHashMap)>, + pub eclasses: FxHashMap)>, + pub node_to_class: FxHashMap, + pub roots: Vec, +} + +impl SerializedEGraph { + /// This is an opinionated function which does more than strictly take the state of the egglog object. + /// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize + /// produces to a strict egraph, where the children of e-classes are e-nodes. 
+ pub fn new(egraph: &EGraph, root_eclasses: Vec<(ArcSort, Value)>) -> Self { + let s = egraph.serialize(egglog::SerializeConfig { + root_eclasses, + max_functions: None, + include_temporary_functions: false, + max_calls_per_function: None, + }); + // Convert to SerializedEGraph + let mut classes = FxHashMap::default(); + for (node_id, node) in &s.egraph.nodes { + classes + .entry(node.eclass.clone()) + .or_insert(vec![]) + .push(node_id.clone()) + } + let mut s_egraph = SerializedEGraph { + roots: s.egraph.root_eclasses, + node_to_class: s + .egraph + .nodes + .iter() + .map(|(n, enode)| (n.clone(), enode.eclass.clone())) + .collect(), + enodes: s + .egraph + .nodes + .iter() + .map(|(n, enode)| { + ( + n.clone(), + ( + enode.op.clone(), + enode + .children + .iter() + .map(|n| s.egraph.nodes[n].eclass.clone()) + .collect(), + ), + ) + }) + .collect(), + eclasses: s + .egraph + .class_data + .iter() + .map(|(c, eclass)| (c.clone(), (eclass.typ.clone().unwrap(), classes[c].clone()))) + .collect(), + }; + // Strip out all [...] 
enodes + s_egraph.enodes.retain(|_, (label, _)| label != "[...]"); + loop { + let mut to_remove = vec![]; + for (id, (_, children)) in &s_egraph.enodes { + if children.iter().any(|c| { + !s_egraph.eclasses[c] + .1 + .iter() + .any(|n| s_egraph.enodes.contains_key(n)) + }) { + to_remove.push(id.clone()); + } + } + for n in &to_remove { + s_egraph.enodes.remove(n); + } + if to_remove.is_empty() { + break; + } + } + // Correct the eclass mapping + for (_, enodes) in s_egraph.eclasses.values_mut() { + enodes.retain(|n| s_egraph.enodes.contains_key(n)); + } + s_egraph.eclasses.retain(|_, (_, c)| !c.is_empty()); + s_egraph + .node_to_class + .retain(|n, _| s_egraph.enodes.contains_key(n)); + s_egraph + } +} diff --git a/src/utils.rs b/src/utils.rs index d2d00a22..afd8a3a8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,9 +1,9 @@ -use crate::graph::SerializedEGraph; use crate::{ prelude::{ ENodeId, NodeIndex, petgraph::{Directed, prelude::StableGraph}, }, + serialized_egraph::SerializedEGraph, shape::Expression, }; use as_any::{AsAny, Downcast}; diff --git a/src/visualization/egraph_viz_template.html.jinja b/src/visualization/egraph_viz_template.html.jinja new file mode 100644 index 00000000..d5364401 --- /dev/null +++ b/src/visualization/egraph_viz_template.html.jinja @@ -0,0 +1,8 @@ +
+ + \ No newline at end of file diff --git a/src/visualization/mod.rs b/src/visualization/mod.rs new file mode 100644 index 00000000..a664129b --- /dev/null +++ b/src/visualization/mod.rs @@ -0,0 +1,47 @@ +use anyhow::Result; +use egglog::EGraph; +use std::string::String; + +use crate::prelude::petgraph::dot::{Config, Dot}; + +pub trait ToHtml { + fn to_html(&self) -> Result; +} + +pub trait ToDot { + fn to_dot(&self) -> Result; +} + +impl ToHtml for EGraph { + fn to_html(&self) -> Result { + let egraph_as_json = serde_json::to_string_pretty( + &self.serialize(egglog::SerializeConfig::default()).egraph, + )?; + Ok(include_str!("egraph_viz_template.html.jinja") + .replace("{{JSON_TEMPLATE}}", &egraph_as_json)) + } +} + +impl ToDot for EGraph { + fn to_dot(&self) -> Result { + Ok(self + .serialize(egglog::SerializeConfig::default()) + .egraph + .to_dot()) + } +} + +/// Implements `ToDot` for [`LLIRGraph`] and [`HLIRGraph`] +/// TODO: This simple extraction can be improved in the future with support for edge labels +impl ToDot for petgraph::stable_graph::StableGraph +where + N: std::fmt::Debug, + E: std::fmt::Debug, +{ + fn to_dot(&self) -> Result { + Ok(format!( + "{:?}", + Dot::with_config(self, &[Config::EdgeNoLabel]) + )) + } +}