diff --git a/crates/luminal_2_eggplant/.gitignore b/crates/luminal_2_eggplant/.gitignore new file mode 100644 index 00000000..3f686602 --- /dev/null +++ b/crates/luminal_2_eggplant/.gitignore @@ -0,0 +1,2 @@ +/target +*.dot \ No newline at end of file diff --git a/crates/luminal_2_eggplant/Cargo.toml b/crates/luminal_2_eggplant/Cargo.toml new file mode 100644 index 00000000..a9d846c8 --- /dev/null +++ b/crates/luminal_2_eggplant/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "luminal_2_eggplant" +version = "0.1.0" +edition = "2024" +rust-version = "1.89" + +[features] +default = [] +cuda = ["dep:cudarc", "dep:luminal_cuda"] +metal = ["dep:objc2", "dep:objc2-metal", "dep:objc2-foundation"] + +[dependencies] +luminal = { path = "../../" } +luminal_cuda = { path = "../luminal_cuda", optional = true } +cudarc = { version = "0.16.6", features = [ + "f16", + "cuda-12080", +], optional = true } +#metal-rs = { version = "0.28.0", package = "metal", optional=true } +objc2 = { version = "0.6.2", optional = true } +objc2-metal = { version = "0.3.1", optional = true } +itertools = "0.14.0" +urlencoding = "2.1.3" +webbrowser = "1.0.4" +regex = "1.11.1" +serde_json = "1.0.140" +colored = "3.0.0" +generational-box = "0.6.2" +egg = "0.9.5" +symbolic_expressions = "5.0.3" +rustc-hash = "2.1.1" +rand = "0.9.1" +egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", branch = "main" } +indexmap = "2.9.0" +serde = "1.0.219" +ratatui = "0.29.0" +crossterm = "0.29.0" +libc = "0.2.174" +unicode-width = "0.2" +objc2-foundation = { version = "0.3.1", optional = true } +eframe = "0.28" +egui = "0.28" +anyhow = "1.0.99" + +# eggplant related +eggplant = "0.2.6" +derive_more = { version = "2.0.1", features = [ + "deref_mut", + "deref", + "into_iterator", + "debug", +] } +strum = { version = "0.27.2", features = ["strum_macros"] } +strum_macros = "0.27.2" diff --git a/crates/luminal_2_eggplant/src/datatypes.rs b/crates/luminal_2_eggplant/src/datatypes.rs new file mode 100644 index 00000000..fe8697b4 --- /dev/null +++ b/crates/luminal_2_eggplant/src/datatypes.rs @@ -0,0 +1,105 @@ +use eggplant::egglog; +use eggplant::prelude::*; +use serde::Deserialize; +use serde::Serialize; + +// datatype((datatype Expression(MNum i64:args_name "num")(MVar String:args_name "name")(MAdd Expression Expression:args_name "l,r")(MSub Expression Expression:args_name "l,r")(MMul Expression Expression:args_name "l,r")(MDiv Expression Expression:args_name "l,r")(MMod Expression Expression:args_name "l,r")(MMin Expression Expression:args_name "l,r")(MMax Expression Expression:args_name "l,r")(MAnd Expression Expression:args_name "l,r")(MOr Expression Expression:args_name "l,r")(MGte Expression Expression:args_name "l,r")(MLt Expression Expression:args_name "l,r")(MFloorTo Expression Expression:args_name "l,r")(MReplace Expression Expression Expression:args_name "l,r,rpl")(MAccum String:args_name "name")))#[doc = "DSl Generated"] +#[eggplant::dsl] +pub enum Expr { + MNum { num: i64 }, + MVar { name: String }, + MAdd { l: Expr, r: Expr }, + MSub { l: Expr, r: Expr }, + MMul { l: Expr, r: Expr }, + MDiv { l: Expr, r: Expr }, + MMod { l: Expr, r: Expr }, + MMin { l: Expr, r: Expr }, + MMax { l: Expr, r: Expr }, + MAnd { l: Expr, r: Expr }, + MOr { l: Expr, r: Expr }, + MGte { l: Expr, r: Expr }, + MLt { l: Expr, r: Expr }, + MFloorTo { l: Expr, r: Expr }, + MReplace { l: Expr, r: Expr, rpl: Expr }, + MAccum { name: String }, +} + +#[eggplant::base_ty] +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq, Default)] +pub enum UnOp { + Exp2, + Log2, + Sqrt, + Sin, + Recip, + Neg, + #[default] + Unknown, +} +#[eggplant::base_ty] +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq, Default)] +pub enum BinOp { + Add, + Sub, + Mul, + #[default] + Unknown, +} + +#[eggplant::dsl(base=BinOp,base=UnOp)] +enum IR { + GMEM { + name: String, + }, + LoopIn { + ir: IR, + l: Expr, + r: Expr, + }, + LoopOut { + ir: IR, + l: Expr, + r: Expr, + }, + Unary { + op: UnOp, + ir: IR, + }, + Binary { + op: BinOp, + l: IR, + r: IR, + }, + SwapLoops { + ir: IR, + level: i64, + }, + TileLoop { + ir: IR, + level: i64, + }, + MergeLoops { + ir: IR, + level: i64, + }, + TCMatmul { + inp_a: IR, + inp_b: IR, + a_k_stride: Expr, + b_k_stride: Expr, + a_inner_stride: Expr, + b_inner_stride: Expr, + c_inner_stride: Expr, + num_k_loops: Expr, + }, + TiledMatmulInputA { + ir: IR, + num: i64, + expr: Expr, + }, + TiledMatmulInputB { + ir: IR, + num: i64, + expr: Expr, + }, +} diff --git a/crates/luminal_2_eggplant/src/main.rs b/crates/luminal_2_eggplant/src/main.rs new file mode 100644 index 00000000..1030177f --- /dev/null +++ b/crates/luminal_2_eggplant/src/main.rs @@ -0,0 +1,28 @@ +mod datatypes; +mod rules; +mod shortcut; +use datatypes::*; +use eggplant::{prelude::*, tx_rx_vt_pr}; + +tx_rx_vt_pr!(MyTx, MyPatRec); +fn main() { + // let expr: Expr = (MNum::new(4) * 3) + 2; + let expr: Expr = MNum::new(4); + let ruleset = MyTx::new_ruleset("expr"); + rules::add_rules::(ruleset); + expr.commit(); + MyTx::run_ruleset(ruleset, RunConfig::Sat); +} + +#[test] +fn test_const_fold() { + let expr: Expr = MNum::new(3) * MNum::new(4) + MNum::new(2); + expr.commit(); + let ruleset = MyTx::new_ruleset("expr"); + rules::add_rules::(ruleset); + MyTx::run_ruleset(ruleset, RunConfig::Sat); + + let ans: Expr = MNum::new(12); + ans.commit(); + assert!(MyTx::canonical_raw(&expr) == MyTx::canonical_raw(&ans)) +} diff --git a/crates/luminal_2_eggplant/src/rules/basic_expr_rules.rs b/crates/luminal_2_eggplant/src/rules/basic_expr_rules.rs new file mode 100644 index 00000000..7b8bc7d2 --- /dev/null +++ b/crates/luminal_2_eggplant/src/rules/basic_expr_rules.rs @@ -0,0 +1,100 @@ +use crate::datatypes::*; +use eggplant::prelude::*; +use eggplant::wrap::{G, RuleCtx, RuleSetId}; +macro_rules! fold { + ($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => { + T::add_rule( + stringify!($pat_name), + $ruleset, + || { + let l = MNum::query(); + let r = MNum::query(); + let p = $ty::query(&l, &r); + #[eggplant::pat_vars_catch] + struct $pat_name { + l: MNum, + r: MNum, + p: $ty, + } + }, + |ctx, pat| { + let cal = $f(ctx.devalue(pat.l.num), ctx.devalue(pat.r.num)); + let op_value = ctx.insert_m_num(cal); + ctx.union(pat.p, op_value); + }, + ); + }; +} + +macro_rules! commu { + ($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => { + T::add_rule( + stringify!($pat_name), + $ruleset, + || { + let l = Expr::query_leaf(); + let r = Expr::query_leaf(); + let p = $ty::query(&l, &r); + #[eggplant::pat_vars_catch] + struct $pat_name { + l: Expr, + r: Expr, + p: $ty, + } + }, + |ctx, pat| { + let op = $f(ctx, pat.l, pat.r); + ctx.union(pat.p, op); + }, + ); + }; +} + +macro_rules! assoc { + ($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => { + T::add_rule( + stringify!($pat_name), + $ruleset, + || { + let ll = Expr::query_leaf(); + let lr = Expr::query_leaf(); + let l = $ty::query(&ll, &lr); + let r = Expr::query_leaf(); + let p = $ty::query(&l, &r); + #[eggplant::pat_vars_catch] + struct $pat_name { + ll: Expr, + lr: Expr, + r: Expr, + p: $ty, + } + }, + |ctx, pat| { + let r = $f(ctx, pat.lr, pat.r); + let p = $f(ctx, pat.ll, r); + ctx.union(pat.p, p); + }, + ); + }; +} +pub fn assoc(ruleset: RuleSetId) { + assoc!(MAdd, RuleCtx::insert_m_add, AddAssocPat, ruleset); + assoc!(MMul, RuleCtx::insert_m_mul, MulAssocPat, ruleset); +} + +pub fn commu(ruleset: RuleSetId) { + commu!(MAdd, RuleCtx::insert_m_add, AddCommuPat, ruleset); + commu!(MMul, RuleCtx::insert_m_mul, MulCommuPat, ruleset); +} + +pub fn const_fold(ruleset: RuleSetId) { + use std::cmp::*; + use std::ops::*; + + fold!(MAdd, Add::add, AddPat, ruleset); + fold!(MSub, Sub::sub, SubPat, ruleset); + fold!(MMul, Mul::mul, MulPat, ruleset); + fold!(MMax, max, MaxPat, ruleset); + fold!(MMin, min, MinPat, ruleset); + fold!(MAnd, BitAnd::bitand, BitAndPat, ruleset); +} diff --git a/crates/luminal_2_eggplant/src/rules/mod.rs b/crates/luminal_2_eggplant/src/rules/mod.rs new file mode 100644 index 00000000..541c8879 --- /dev/null +++ b/crates/luminal_2_eggplant/src/rules/mod.rs @@ -0,0 +1,7 @@ +use eggplant::wrap::{G, RuleSetId}; +mod basic_expr_rules; +pub fn add_rules(ruleset: RuleSetId) { + basic_expr_rules::commu::(ruleset); + basic_expr_rules::const_fold::(ruleset); + basic_expr_rules::assoc::(ruleset); +} diff --git a/crates/luminal_2_eggplant/src/shortcut.rs b/crates/luminal_2_eggplant/src/shortcut.rs new file mode 100644 index 00000000..4fcb7ed0 --- /dev/null +++ b/crates/luminal_2_eggplant/src/shortcut.rs @@ -0,0 +1,140 @@ +use crate::datatypes::*; +use eggplant::wrap::G; + +macro_rules! cartesian_ops { + (($t1:ty, $t2:ty),$out:tt,$op:tt,$method:ident) => { + impl std::ops::$op<$t2> for $t1 { + type Output = $out; + + fn $method(self, rhs: $t2) -> Self::Output { + $out::new(&self, &rhs) + } + } + }; + (($t1:ty, $t2:ty, $($rest:ty),+),$out:tt,$op:tt,$method:ident) => { + cartesian_ops!(($t1, $t2),$out,$op,$method); + cartesian_ops!(($t2, $t1),$out,$op,$method); + cartesian_ops!(($t1, $t1),$out,$op,$method); + $( cartesian_ops!(($t1, $rest),$out,$op,$method);)* + $( cartesian_ops!(($rest, $t1),$out,$op,$method);)* + cartesian_ops!(($t2, $($rest),*),$out,$op,$method); + }; +} +cartesian_ops!( + ( + MNum, + MVar, + MAdd, + MSub, + MMul, + MDiv, + MMod, + MMin, + MMax, + MAnd, + MOr, + MGte, + MLt, + MFloorTo, + MReplace, + MAccum + ), + MAdd, + Add, + add +); +cartesian_ops!( + ( + MNum, + MVar, + MAdd, + MSub, + MMul, + MDiv, + MMod, + MMin, + MMax, + MAnd, + MOr, + MGte, + MLt, + MFloorTo, + MReplace, + MAccum + ), + MMul, + Mul, + mul +); + +cartesian_ops!( + ( + MNum, + MVar, + MAdd, + MSub, + MMul, + MDiv, + MMod, + MMin, + MMax, + MAnd, + MOr, + MGte, + MLt, + MFloorTo, + MReplace, + MAccum + ), + MDiv, + Div, + div +); + +cartesian_ops!( + ( + MNum, + MVar, + MAdd, + MSub, + MMul, + MDiv, + MMod, + MMin, + MMax, + MAnd, + MOr, + MGte, + MLt, + MFloorTo, + MReplace, + MAccum + ), + MSub, + Sub, + sub +); + +cartesian_ops!( + ( + MNum, + MVar, + MAdd, + MSub, + MMul, + MDiv, + MMod, + MMin, + MMax, + MAnd, + MOr, + MGte, + MLt, + MFloorTo, + MReplace, + MAccum + ), + MMod, + Rem, + rem +);