diff --git a/crates/luminal_cuda/Cargo.toml b/crates/luminal_cuda/Cargo.toml index 248d889d..fd43b9a7 100644 --- a/crates/luminal_cuda/Cargo.toml +++ b/crates/luminal_cuda/Cargo.toml @@ -10,7 +10,6 @@ license = "MIT OR Apache-2.0" [dependencies] luminal = { path = "../.." } cudarc = {version="0.17.3", features=["cuda-12080"]} -rustc-hash = "2.1.1" as-any = "0.3.2" itertools = "0.12.1" fixedbitset = "0.5.7" @@ -19,3 +18,6 @@ tracing = "0.1.43" tracing-perfetto-sdk-schema = "0.13.0" prost = "0.14.1" half = "2.7.1" +pretty-duration = "0.1.1" +bytemuck = "1.24.0" +memmap2 = "0.9.9" diff --git a/crates/luminal_cuda/src/block/mod.rs b/crates/luminal_cuda/src/block/mod.rs index 44f92f4d..c86b3d55 100644 --- a/crates/luminal_cuda/src/block/mod.rs +++ b/crates/luminal_cuda/src/block/mod.rs @@ -2,8 +2,7 @@ mod ops; pub use ops::*; use cudarc::driver::CudaStream; -use luminal::{shape::Expression, utils::EgglogOp}; -use rustc_hash::FxHashMap; +use luminal::{prelude::FxHashMap, shape::Expression, utils::EgglogOp}; use std::fmt::Debug; use crate::runtime::CustomState; diff --git a/crates/luminal_cuda/src/block/ops.rs b/crates/luminal_cuda/src/block/ops.rs index 28405d44..5c8c8e5e 100644 --- a/crates/luminal_cuda/src/block/ops.rs +++ b/crates/luminal_cuda/src/block/ops.rs @@ -5,15 +5,13 @@ use cudarc::driver::{CudaStream, DevicePtr}; use itertools::Itertools; use luminal::{ graph::{extract_expr, extract_expr_list}, - prelude::ENodeId, + prelude::*, serialized_egraph::SerializedEGraph, - shape::Expression, utils::{ - flatten_mul_strides, CStructBuilder, EgglogOp, LLIROp, + flatten_mul_strides, EgglogOp, LLIROp, OpParam::{self, *}, }, }; -use rustc_hash::FxHashMap; use crate::block::BlockOp; @@ -94,11 +92,15 @@ impl EgglogOp for RowAdd { impl BlockOp for RowAdd { fn launch_range(&self) -> Vec { - self.range.clone() + if self.range.is_empty() { + vec![1.into()] + } else { + self.range.clone() + } } fn output_size(&self) -> Expression { - self.range.iter().copied().product::() * self.row_width + self.range.iter().copied().product::().max(1) * self.row_width } fn consumer_barriers_seperate(&self) -> Vec> { @@ -128,7 +130,7 @@ impl BlockOp for RowAdd { _: &CudaStream, expressions: &FxHashMap, ) -> Vec { - CStructBuilder::new() + CStruct::new() .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.b_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.out_stride)]) @@ -270,7 +272,7 @@ impl BlockOp for RowSwishMul { _: &CudaStream, expressions: &FxHashMap, ) -> Vec { - CStructBuilder::new() + CStruct::new() .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.b_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) @@ -491,7 +493,7 @@ impl BlockOp for RowRMSNorm { _: &CudaStream, expressions: &FxHashMap, ) -> Vec { - CStructBuilder::new() + CStruct::new() .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&self.row_width]) @@ -623,7 +625,7 @@ impl BlockOp for RowRope { _: &CudaStream, expressions: &FxHashMap, ) -> Vec { - CStructBuilder::new() + CStruct::new() .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)]) .int(expressions[&self.row_width]) @@ -938,7 +940,7 @@ impl BlockOp for TileMatmul { m_pos_stride[self.range.len() - 2] = 1.into(); let mut n_pos_stride = vec![0.into(); self.range.len()]; n_pos_stride[self.range.len() - 1] = 1.into(); - CStructBuilder::new() + CStruct::new() .ints( &self .untiled_range @@ -1277,7 +1279,7 @@ impl BlockOp for GQAAttention { group_pos_stride[self.range.len() - 2] = 1.into(); let mut head_pos_stride = vec![0.into(); self.range.len()]; head_pos_stride[self.range.len() - 3] = 1.into(); - CStructBuilder::new() + CStruct::new() .int(expressions[&self.head_dim]) .int(expressions[&self.cur_seq]) .int(expressions[&self.kv_row_stride]) @@ -1316,3 +1318,107 @@ impl BlockOp for GQAAttention { ] } } + +#[derive(Debug)] +pub struct CStruct { + buf: Vec, + max_align: usize, +} + +impl Default for CStruct { + fn default() -> Self { + Self { + buf: Vec::new(), + max_align: 1, + } + } +} + +impl CStruct { + pub fn new() -> Self { + Self::default() + } + + fn align_to(&mut self, align: usize) { + self.max_align = self.max_align.max(align); + + let len = self.buf.len(); + let rem = len % align; + if rem != 0 { + let pad = align - rem; + self.buf.extend(std::iter::repeat_n(0u8, pad)); + } + } + + pub fn int(mut self, v: i32) -> Self { + self.align_to(4); + self.buf.extend_from_slice(&v.to_ne_bytes()); + self + } + + pub fn ints(mut self, vs: &[i32]) -> Self { + self.align_to(4); + for &v in vs { + self.buf.extend_from_slice(&v.to_ne_bytes()); + } + self + } + + pub fn float(mut self, v: f32) -> Self { + self.align_to(4); + self.buf.extend_from_slice(&v.to_ne_bytes()); + self + } + + pub fn floats(mut self, vs: &[f32]) -> Self { + self.align_to(4); + for &v in vs { + self.buf.extend_from_slice(&v.to_ne_bytes()); + } + self + } + + pub fn bool(mut self, v: bool) -> Self { + self.align_to(1); + self.buf.push(if v { 1 } else { 0 }); + self + } + + pub fn ptr_const_f32(mut self, p: *const f32) -> Self { + let ptr_size = std::mem::size_of::(); // usually 8 + let ptr_align = ptr_size; + self.align_to(ptr_align); + + let addr = p as usize; + let bytes = addr.to_ne_bytes(); + + self.buf.extend_from_slice(&bytes[..ptr_size]); + self + } + + pub fn ptr_mut_f32(self, p: *mut f32) -> Self { + self.ptr_const_f32(p as *const f32) + } + + /// Pad the struct size to a multiple of max_align. + pub fn finish_struct(mut self) -> Vec { + let align = self.max_align; + if align > 1 { + let len = self.buf.len(); + let rem = len % align; + if rem != 0 { + let pad = align - rem; + self.buf.extend(std::iter::repeat_n(0u8, pad)); + } + } + self.buf + } + + /// Insert a raw byte field (e.g., another struct). + /// `align` must be the alignment of the nested struct. + pub fn bytes(mut self, align: usize, data: &[u8]) -> Self { + self.align_to(align); + self.buf.extend_from_slice(data); + self + } +} diff --git a/crates/luminal_cuda/src/kernel/mod.rs b/crates/luminal_cuda/src/kernel/mod.rs index 96639afd..befe751f 100644 --- a/crates/luminal_cuda/src/kernel/mod.rs +++ b/crates/luminal_cuda/src/kernel/mod.rs @@ -2,9 +2,8 @@ use std::sync::Arc; -use cudarc::driver::{CudaContext, CudaFunction, CudaSlice, CudaStream}; -use luminal::{shape::Expression, utils::EgglogOp}; -use rustc_hash::FxHashMap; +use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream}; +use luminal::prelude::*; pub mod ops; pub use ops::Ops; @@ -16,6 +15,7 @@ pub trait KernelOp: EgglogOp { stream: &Arc, ) -> ( CudaFunction, + Arc, String, (Expression, Expression, Expression), (Expression, Expression, Expression), diff --git a/crates/luminal_cuda/src/kernel/ops.rs b/crates/luminal_cuda/src/kernel/ops.rs index ae4ae97c..3a56c78e 100644 --- a/crates/luminal_cuda/src/kernel/ops.rs +++ b/crates/luminal_cuda/src/kernel/ops.rs @@ -1,24 +1,20 @@ use std::sync::Arc; +use crate::{cuda_dtype, kernel::KernelOp}; use cudarc::{ - driver::{CudaContext, CudaFunction, CudaSlice, CudaStream}, + driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream}, nvrtc::{compile_ptx, CompileOptions}, }; use itertools::Itertools; use luminal::{ graph::{extract_dtype, extract_expr, extract_expr_list}, - op::DType, - prelude::ENodeId, + prelude::*, serialized_egraph::SerializedEGraph, - shape::Expression, utils::{ flatten_mul_strides, EgglogOp, LLIROp, OpParam::{self, *}, }, }; -use rustc_hash::{FxHashMap, FxHashSet}; - -use crate::{cuda_dtype, kernel::KernelOp}; pub type Ops = (KernelAdd, KernelMul, KernelIota, KernelGather); @@ -85,6 +81,7 @@ impl KernelOp for KernelAdd { stream: &Arc, ) -> ( CudaFunction, + Arc, String, (Expression, Expression, Expression), (Expression, Expression, Expression), @@ -125,6 +122,7 @@ extern \"C\" {{ .collect(); ( func, + module, kernel, ( self.out_shape.iter().copied().product::(), @@ -205,6 +203,7 @@ impl KernelOp for KernelMul { stream: &Arc, ) -> ( CudaFunction, + Arc, String, (Expression, Expression, Expression), (Expression, Expression, Expression), @@ -245,6 +244,7 @@ extern \"C\" {{ .collect(); ( func, + module, kernel, ( self.out_shape.iter().copied().product::(), @@ -328,6 +328,7 @@ impl KernelOp for KernelGather { stream: &Arc, ) -> ( CudaFunction, + Arc, String, (Expression, Expression, Expression), (Expression, Expression, Expression), @@ -370,6 +371,7 @@ extern \"C\" {{ .collect(); ( func, + module, kernel, (self.out_shape.iter().copied().product(), 1.into(), 1.into()), (1.into(), 1.into(), 1.into()), @@ -438,6 +440,7 @@ impl KernelOp for KernelIota { stream: &Arc, ) -> ( CudaFunction, + Arc, String, (Expression, Expression, Expression), (Expression, Expression, Expression), @@ -468,6 +471,7 @@ extern \"C\" {{ .collect(); ( func, + module, kernel, (self.range, 1.into(), 1.into()), (1.into(), 1.into(), 1.into()), diff --git a/crates/luminal_cuda/src/runtime.rs b/crates/luminal_cuda/src/runtime.rs index 2ed1bb1a..8e3079db 100644 --- a/crates/luminal_cuda/src/runtime.rs +++ b/crates/luminal_cuda/src/runtime.rs @@ -4,14 +4,18 @@ use crate::{ }; use cudarc::{ driver::{ - CudaContext, CudaFunction, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, - LaunchConfig, PushKernelArg, ValidAsZeroBits, + CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, + DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits, }, nvrtc::{compile_ptx_with_opts, CompileOptions}, }; use fixedbitset::FixedBitSet; use itertools::Itertools; -use luminal::{op::Output, prelude::ToId}; +use luminal::{ + graph::Graph, + op::Output, + prelude::{FxHashMap, FxHashSet, ToId}, +}; use luminal::{ prelude::{ petgraph::{ @@ -24,18 +28,20 @@ use luminal::{ }, utils::flatten_z_strides, }; +use memmap2::MmapOptions; use prost::Message as _; -use rustc_hash::{FxHashMap, FxHashSet}; use safetensors::SafeTensors; use std::{ collections::VecDeque, ffi::c_void, fmt::Debug, + fs::File, hash::{DefaultHasher, Hash, Hasher}, io::Read, iter::once, ptr::{null, null_mut}, sync::Arc, + time::Duration, }; use tracing::{span, Level}; use tracing_perfetto_sdk_schema::{ @@ -61,6 +67,7 @@ pub enum ExecutableKernel { n_barriers: Expression, work_queue: Vec, node_to_task_index: FxHashMap, + module: Arc, }, Kernel { kernel: CudaFunction, @@ -71,9 +78,33 @@ pub enum ExecutableKernel { inputs: Vec, output: NodeIndex, constants: FxHashMap>, + module: Arc, }, } +impl Drop for ExecutableKernel { + fn drop(&mut self) { + match self { + ExecutableKernel::Megakernel { + interpreter_constants, + .. + } => { + // Prevent Drop of CudaSlice (likely calls cuMemFree). + let m = std::mem::take(interpreter_constants); + for (_k, v) in m { + std::mem::forget(v); + } + } + ExecutableKernel::Kernel { constants, .. } => { + let m = std::mem::take(constants); + for (_k, v) in m { + std::mem::forget(v); + } + } + } + } +} + impl Debug for ExecutableKernel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -88,11 +119,12 @@ impl Debug for ExecutableKernel { } pub struct CudaRuntime { + pub hlir_buffers: FxHashMap>, pub buffers: FxHashMap>, pub llir_graph: luminal::graph::LLIRGraph, cuda_stream: Arc, cuda_context: Arc, - custom_state: FxHashMap, + pub custom_state: FxHashMap, exec_graph: StableGraph, node_to_exec: FxHashMap, timings: Vec<(Vec, u64)>, @@ -100,26 +132,19 @@ pub struct CudaRuntime { impl CudaRuntime { #[tracing::instrument(skip_all)] - pub fn load_safetensors(&mut self, file_path: &str) { - let file = std::fs::read(file_path).unwrap(); - let st = SafeTensors::deserialize(&file).unwrap(); - for node in self.llir_graph.node_indices().collect_vec() { - if let Some(Input { label, .. }) = self.llir_graph[node].to_op::() { + pub fn load_safetensors(&mut self, cx: &Graph, file_path: &str) { + let f = File::open(file_path).unwrap(); + let mmap = unsafe { MmapOptions::new().map(&f).unwrap() }; + let st = SafeTensors::deserialize(&mmap).unwrap(); + for node in cx.graph.node_indices() { + if let Some(Input { label, .. }) = cx.graph[node].as_any().downcast_ref::() { if let Ok(tensor) = st.tensor(label) { match tensor.dtype() { - safetensors::Dtype::BF16 => { - let data: Vec = tensor - .data() - .chunks_exact(2) - .map(|chunk| { - half::bf16::from_le_bytes([chunk[0], chunk[1]]).to_f32() - }) - .collect(); - self.buffers.insert( - node, - data.to_cuda_buffer(&self.cuda_context, &self.cuda_stream), - ); - self.register_buffer(node); + safetensors::Dtype::F32 => { + let bytes = tensor.data(); + let f32s: &[f32] = bytemuck::cast_slice(bytes); + let dev = f32s.to_cuda_buffer(&self.cuda_context, &self.cuda_stream); + self.hlir_buffers.insert(node, dev); } dtype => unimplemented!("{dtype} loading not supported yet"), } @@ -305,6 +330,16 @@ impl ToCudaBuffer for Vec { } } +impl ToCudaBuffer for &[f32] { + fn to_cuda_buffer(&self, _: &Arc, stream: &Arc) -> CudaSlice { + stream + .memcpy_stod(unsafe { + std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4) + }) + .unwrap() + } +} + impl ToCudaBuffer for Vec { fn to_cuda_buffer(&self, _: &Arc, stream: &Arc) -> CudaSlice { stream @@ -324,9 +359,11 @@ impl Runtime for CudaRuntime { ); type Data = Box; type ExecReturn = (); + type ProfileMetric = Duration; fn initialize((ctx, stream, custom_state): Self::CompileArg) -> Self { Self { + hlir_buffers: FxHashMap::default(), buffers: FxHashMap::default(), cuda_stream: stream, cuda_context: ctx, @@ -339,9 +376,18 @@ impl Runtime for CudaRuntime { } #[tracing::instrument(skip_all)] - fn compile(&mut self, llir_graph: &LLIRGraph) { + fn load_llir(&mut self, llir_graph: &LLIRGraph) { + self.exec_graph.clear(); + // clear kv cache + for (_, s) in &mut self.custom_state { + if let CustomState::KVCache(layers) = s { + for (k, v) in layers { + self.cuda_stream.memset_zeros(k).unwrap(); + self.cuda_stream.memset_zeros(v).unwrap(); + } + } + } let block_ops = ::into_vec(); - let block_ops_in_graph = llir_graph .node_indices() .filter(|n| llir_graph[*n].to_dialect::().is_some()) @@ -395,7 +441,7 @@ impl Runtime for CudaRuntime { .chain(once(0.into())) .chain(once(1.into())) .collect::>(); - let (interpreter, expressions, interpreter_constants) = compile_interpreter( + let (interpreter, module, expressions, interpreter_constants) = compile_interpreter( &self.cuda_context, &self.cuda_stream, &block_ops, @@ -469,6 +515,7 @@ impl Runtime for CudaRuntime { } let exec_node = exec_graph.add_node(ExecutableKernel::Megakernel { interpreter, + module, interpreter_constants, n_barriers, work_queue: tasks, @@ -481,8 +528,9 @@ impl Runtime for CudaRuntime { // Add kernels for kernel in llir_graph.node_indices() { if let Some(kernel_op) = llir_graph[kernel].to_dialect::() { - let (kernel_function, code, grid, tb, shared_mem, constants) = + let (kernel_function, module, code, grid, tb, shared_mem, constants) = kernel_op.compile(&self.cuda_context, &self.cuda_stream); + self.cuda_stream.synchronize().unwrap(); let inputs = llir_graph .edges_directed(kernel, Direction::Incoming) .sorted_by_key(|e| e.id()) @@ -492,6 +540,7 @@ impl Runtime for CudaRuntime { kernel, exec_graph.add_node(ExecutableKernel::Kernel { kernel: kernel_function, + module, code, launch_grid: grid, launch_threadblock: tb, @@ -524,13 +573,50 @@ impl Runtime for CudaRuntime { self.node_to_exec = node_to_exec; } + #[tracing::instrument(skip_all)] + fn profile( + &mut self, + llir_graph: &LLIRGraph, + dyn_map: &FxHashMap, + ) -> (Self::ProfileMetric, String) { + self.buffers.clear(); + self.load_llir(llir_graph); + self.allocate_intermediate_buffers(dyn_map); + let start = std::time::Instant::now(); + self.execute(dyn_map); + self.timings.clear(); + ( + start.elapsed(), + pretty_duration::pretty_duration(&start.elapsed(), None), + ) + } + #[tracing::instrument(skip_all)] fn execute(&mut self, dyn_map: &FxHashMap) -> Self::ExecReturn { + for (hlir_node, llir_node) in self + .llir_graph + .node_indices() + .filter_map(|n| { + if let Some(Input { node, .. }) = self.llir_graph[n].to_op::() { + Some((*node, n)) + } else { + None + } + }) + .collect_vec() + { + self.buffers.insert( + llir_node, + self.hlir_buffers[&NodeIndex::new(hlir_node)].clone(), + ); + self.register_buffer(llir_node); + } let mut timings = vec![]; for exec_node in toposort(&self.exec_graph, None).unwrap() { match &mut self.exec_graph[exec_node] { ExecutableKernel::Kernel { kernel, + module: _, code: _, launch_grid, launch_threadblock, @@ -580,6 +666,10 @@ impl Runtime for CudaRuntime { work_queue, .. } => { + let sm_count = self + .cuda_context + .attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT) + .unwrap(); let span = span!(Level::INFO, "megakernel_setup"); let _entered = span.enter(); // Upload queue, barriers and program counter @@ -590,10 +680,15 @@ impl Runtime for CudaRuntime { let d_tasks = self.cuda_stream.memcpy_stod(work_queue).unwrap(); let d_head = self.cuda_stream.memcpy_stod(&[0i32]).unwrap(); let queue_lock = self.cuda_stream.memcpy_stod(&[0i32]).unwrap(); - // Set up timing buffer (start_time_u64,[[event_start_u64,event_type_i32 for sm_event in sm[:1000] for sm in sms[:114]]) - let timing_buffer = - self.cuda_stream.alloc_zeros::(114 * 1000).unwrap(); - let start_time = self.cuda_stream.alloc_zeros::(114).unwrap(); + // Set up timing buffer (start_time_u64,[[event_start_u64,event_type_i32 for sm_event in sm[:1000] for sm in sms[:sm_count]]) + let timing_buffer = self + .cuda_stream + .alloc_zeros::(sm_count as usize * 1000) + .unwrap(); + let start_time = self + .cuda_stream + .alloc_zeros::(sm_count as usize) + .unwrap(); // Set up dyn dims for (dyn_dim, val) in dyn_map { @@ -606,17 +701,21 @@ impl Runtime for CudaRuntime { } } + let shared_mem_max = self + .cuda_context + .attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR) + .unwrap(); + interpreter.set_attribute( cudarc::driver::sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - 100_000, + shared_mem_max / 2, // Half shared mem, half L2 ).unwrap(); // Launch kernel - let num_sm = 114; let cfg = LaunchConfig { - grid_dim: (num_sm, 1, 1), // One block per SM - block_dim: (1024, 1, 1), // 512 threads per block - shared_mem_bytes: 100_000, // 40 kb + grid_dim: (sm_count as u32, 1, 1), // One block per SM + block_dim: (1024, 1, 1), // 1024 threads (32 warps) per block + shared_mem_bytes: (shared_mem_max / 2) as u32, }; let mut lb = self.cuda_stream.launch_builder(interpreter); let n_tasks = work_queue.len() as i32; @@ -652,23 +751,10 @@ impl Runtime for CudaRuntime { } fn set_data(&mut self, id: impl ToId, data: Self::Data) { - let id = id.to_id(); - let id = self - .llir_graph - .node_indices() - .find(|n| { - if let Some(Input { node, .. }) = self.llir_graph[*n].to_op::() { - *node == id.index() - } else { - false - } - }) - .unwrap(); - self.buffers.insert( - id, + self.hlir_buffers.insert( + id.to_id(), data.to_cuda_buffer(&self.cuda_context, &self.cuda_stream), ); - self.register_buffer(id); } } @@ -1100,6 +1186,7 @@ pub fn compile_interpreter( expressions: &FxHashSet, ) -> ( CudaFunction, + Arc, FxHashMap, FxHashMap>, ) { @@ -1182,6 +1269,7 @@ pub fn compile_interpreter( }, ) .unwrap(); + cuda_stream.synchronize().unwrap(); let module = cuda_ctx.load_module(ptx).unwrap(); let func = module.load_function("worker_kernel").unwrap(); let constants = constants @@ -1195,7 +1283,8 @@ pub fn compile_interpreter( ) }) .collect(); - (func, expression_map, constants) + cuda_stream.synchronize().unwrap(); + (func, module, expression_map, constants) } #[allow(unused)] diff --git a/crates/luminal_cuda/src/tests.rs b/crates/luminal_cuda/src/tests.rs index 66edb629..dd663416 100644 --- a/crates/luminal_cuda/src/tests.rs +++ b/crates/luminal_cuda/src/tests.rs @@ -1,6 +1,5 @@ use cudarc::driver::CudaContext; use luminal::prelude::*; -use rustc_hash::FxHashMap; use crate::runtime::CudaRuntime; @@ -14,13 +13,11 @@ pub fn cuda_test() { ctx.bind_to_thread().unwrap(); let stream = ctx.default_stream(); cx.build_search_space::(); - let mut rt = cx.search( - CudaRuntime::initialize((ctx, stream, FxHashMap::default())), - 10_000, - ); - rt.set_data(input.id, Box::new(vec![0., 1., 2., 3., 4.])); + let mut rt = CudaRuntime::initialize((ctx, stream, FxHashMap::default())); + rt.set_data(input, Box::new(vec![0., 1., 2., 3., 4.])); + rt = cx.search(rt, 10); rt.allocate_intermediate_buffers(&cx.dyn_map); rt.execute(&cx.dyn_map); - let out = rt.get_f32(output.id); + let out = rt.get_f32(output); assert_eq!(out, vec![0., 2., 4., 6., 8.]); } diff --git a/examples/llama/Cargo.toml b/examples/llama/Cargo.toml index 8c54753c..6feecd8a 100644 --- a/examples/llama/Cargo.toml +++ b/examples/llama/Cargo.toml @@ -12,4 +12,6 @@ luminal_cuda = { path = "../../crates/luminal_cuda" } luminal_tracing = {path="../../crates/luminal_tracing"} itertools = "0.12.1" tokenizers = "0.15.2" -tracing = "0.1.43" \ No newline at end of file +tracing = "0.1.43" +memmap2 = "0.9.9" +bytemuck = "1.24.0" diff --git a/examples/llama/prompts/asimov.txt b/examples/llama/prompts/asimov.txt deleted file mode 100644 index 7b86c8ce..00000000 --- a/examples/llama/prompts/asimov.txt +++ /dev/null @@ -1,12 +0,0 @@ -<|begin_of_text|># Three Laws of Robotics - -**The Three Laws of Robotics** (often shortened to **The Three Laws** or **Asimov's Laws**) are a set of rules devised by science fiction author Isaac Asimov, which were to be followed by robots in several of his stories. The rules were introduced in his 1942 short story "Runaround" (included in the 1950 collection I, Robot), although similar restrictions had been implied in earlier stories. - -## The Laws - -The Three Laws, presented to be from the fictional "Handbook of Robotics, 56th Edition, 2058 A.D.", are: - - The First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm. - - The Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law. - - The Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law. - -## Explination \ No newline at end of file diff --git a/examples/llama/prompts/merge_sort.txt b/examples/llama/prompts/merge_sort.txt deleted file mode 100644 index 06ef09e1..00000000 --- a/examples/llama/prompts/merge_sort.txt +++ /dev/null @@ -1,5 +0,0 @@ -<|begin_of_text|><|start_header_id|>system<|end_header_id|> - -You are a helpful AI assistant<|eot_id|><|start_header_id|>user<|end_header_id|> - -Please write a python implementation of merge sort.<|eot_id|><|start_header_id|>assistant<|end_header_id|> diff --git a/examples/llama/prompts/shakespeare.txt b/examples/llama/prompts/shakespeare.txt deleted file mode 100644 index 84b0f821..00000000 --- a/examples/llama/prompts/shakespeare.txt +++ /dev/null @@ -1,207 +0,0 @@ -<|begin_of_text|> -## SCENE VII. The forest. -A table set out. Enter DUKE SENIOR, AMIENS, and Lords like outlaws - -### DUKE SENIOR -I think he be transform'd into a beast; -For I can no where find him like a man. - -### First Lord -My lord, he is but even now gone hence: -Here was he merry, hearing of a song. - -### DUKE SENIOR -If he, compact of jars, grow musical, -We shall have shortly discord in the spheres. -Go, seek him: tell him I would speak with him. -Enter JAQUES - -### First Lord -He saves my labour by his own approach. - -### DUKE SENIOR -Why, how now, monsieur! what a life is this, -That your poor friends must woo your company? -What, you look merrily! - -### JAQUES -A fool, a fool! I met a fool i' the forest, -A motley fool; a miserable world! -As I do live by food, I met a fool -Who laid him down and bask'd him in the sun, -And rail'd on Lady Fortune in good terms, -In good set terms and yet a motley fool. -'Good morrow, fool,' quoth I. 'No, sir,' quoth he, -'Call me not fool till heaven hath sent me fortune:' -And then he drew a dial from his poke, -And, looking on it with lack-lustre eye, -Says very wisely, 'It is ten o'clock: -Thus we may see,' quoth he, 'how the world wags: -'Tis but an hour ago since it was nine, -And after one hour more 'twill be eleven; -And so, from hour to hour, we ripe and ripe, -And then, from hour to hour, we rot and rot; -And thereby hangs a tale.' When I did hear -The motley fool thus moral on the time, -My lungs began to crow like chanticleer, -That fools should be so deep-contemplative, -And I did laugh sans intermission -An hour by his dial. O noble fool! -A worthy fool! Motley's the only wear. - -### DUKE SENIOR -What fool is this? - -### JAQUES -O worthy fool! One that hath been a courtier, -And says, if ladies be but young and fair, -They have the gift to know it: and in his brain, -Which is as dry as the remainder biscuit -After a voyage, he hath strange places cramm'd -With observation, the which he vents -In mangled forms. O that I were a fool! -I am ambitious for a motley coat. - -### DUKE SENIOR -Thou shalt have one. - -### JAQUES -It is my only suit; -Provided that you weed your better judgments -Of all opinion that grows rank in them -That I am wise. I must have liberty -Withal, as large a charter as the wind, -To blow on whom I please; for so fools have; -And they that are most galled with my folly, -They most must laugh. And why, sir, must they so? -The 'why' is plain as way to parish church: -He that a fool doth very wisely hit -Doth very foolishly, although he smart, -Not to seem senseless of the bob: if not, -The wise man's folly is anatomized -Even by the squandering glances of the fool. -Invest me in my motley; give me leave -To speak my mind, and I will through and through -Cleanse the foul body of the infected world, -If they will patiently receive my medicine. - -### DUKE SENIOR -Fie on thee! I can tell what thou wouldst do. - -### JAQUES -What, for a counter, would I do but good? - -### DUKE SENIOR -Most mischievous foul sin, in chiding sin: -For thou thyself hast been a libertine, -As sensual as the brutish sting itself; -And all the embossed sores and headed evils, -That thou with licence of free foot hast caught, -Wouldst thou disgorge into the general world. - -### JAQUES -Why, who cries out on pride, -That can therein tax any private party? -Doth it not flow as hugely as the sea, -Till that the weary very means do ebb? -What woman in the city do I name, -When that I say the city-woman bears -The cost of princes on unworthy shoulders? -Who can come in and say that I mean her, -When such a one as she such is her neighbour? -Or what is he of basest function -That says his bravery is not of my cost, -Thinking that I mean him, but therein suits -His folly to the mettle of my speech? -There then; how then? what then? Let me see wherein -My tongue hath wrong'd him: if it do him right, -Then he hath wrong'd himself; if he be free, -Why then my taxing like a wild-goose flies, -Unclaim'd of any man. But who comes here? -Enter ORLANDO, with his sword drawn - -### ORLANDO -Forbear, and eat no more. - -### JAQUES -Why, I have eat none yet. - -### ORLANDO -Nor shalt not, till necessity be served. - -### JAQUES -Of what kind should this cock come of? - -### DUKE SENIOR -Art thou thus bolden'd, man, by thy distress, -Or else a rude despiser of good manners, -That in civility thou seem'st so empty? - -### ORLANDO -You touch'd my vein at first: the thorny point -Of bare distress hath ta'en from me the show -Of smooth civility: yet am I inland bred -And know some nurture. But forbear, I say: -He dies that touches any of this fruit -Till I and my affairs are answered. - -### JAQUES -An you will not be answered with reason, I must die. - -### DUKE SENIOR -What would you have? Your gentleness shall force -More than your force move us to gentleness. - -### ORLANDO -I almost die for food; and let me have it. - -### DUKE SENIOR -Sit down and feed, and welcome to our table. - -### ORLANDO -Speak you so gently? Pardon me, I pray you: -I thought that all things had been savage here; -And therefore put I on the countenance -Of stern commandment. But whate'er you are -That in this desert inaccessible, -Under the shade of melancholy boughs, -Lose and neglect the creeping hours of time -If ever you have look'd on better days, -If ever been where bells have knoll'd to church, -If ever sat at any good man's feast, -If ever from your eyelids wiped a tear -And know what 'tis to pity and be pitied, -Let gentleness my strong enforcement be: -In the which hope I blush, and hide my sword. - -### DUKE SENIOR -True is it that we have seen better days, -And have with holy bell been knoll'd to church -And sat at good men's feasts and wiped our eyes -Of drops that sacred pity hath engender'd: -And therefore sit you down in gentleness -And take upon command what help we have -That to your wanting may be minister'd. - -### ORLANDO -Then but forbear your food a little while, -Whiles, like a doe, I go to find my fawn -And give it food. There is an old poor man, -Who after me hath many a weary step -Limp'd in pure love: till he be first sufficed, -Oppress'd with two weak evils, age and hunger, -I will not touch a bit. - -### DUKE SENIOR -Go find him out, -And we will nothing waste till you return. - -### ORLANDO -I thank ye; and be blest for your good comfort! -Exit - -### DUKE SENIOR -Thou seest we are not all alone unhappy: -This wide and universal theatre -Presents more woeful pageants than the scene -Wherein we play in. diff --git a/examples/llama/setup/setup.py b/examples/llama/setup/setup.py index 70b60007..7464d500 100644 --- a/examples/llama/setup/setup.py +++ b/examples/llama/setup/setup.py @@ -13,18 +13,16 @@ import json from pathlib import Path +import torch from huggingface_hub import hf_hub_download, list_repo_files from safetensors import safe_open from safetensors.torch import save_file def download_model_files(repo_id: str, output_dir: Path): - """Download model files from Hugging Face Hub.""" - print(f"Listing files in {repo_id}...") all_files = list_repo_files(repo_id) - # Filter for files we need: tokenizer.json and all .safetensors files files_to_download = [] for file in all_files: if ( @@ -47,52 +45,56 @@ def download_model_files(repo_id: str, output_dir: Path): print("All files downloaded successfully!") -def combine_safetensors(model_dir: Path): - """Combine sharded safetensors files into a single file.""" - - # Check if combined file already exists +def combine_and_convert_safetensors_to_fp32(model_dir: Path): + """ + Combine sharded safetensors into a single file, converting tensors to FP32 on the fly. + Outputs: model_combined_fp32.safetensors + """ output_path = model_dir / "model_combined.safetensors" if output_path.exists(): - print(f"Combined safetensors file already exists at {output_path}") - print("Skipping combination step.") + print(f"FP32 combined safetensors file already exists at {output_path}") + print("Skipping combine+convert step.") return - # Load the index index_path = model_dir / "model.safetensors.index.json" + if not index_path.exists(): + raise FileNotFoundError(f"Missing index file: {index_path}") + with open(index_path, "r") as f: index = json.load(f) - # Collect all tensors - all_tensors = {} weight_map = index.get("weight_map", {}) - - # Get unique shard files shard_files = sorted(set(weight_map.values())) - print(f"Loading {len(shard_files)} shard files...") + # Stream through shards; convert tensors to fp32 as we read. + all_tensors = {} + + print(f"Loading {len(shard_files)} shard files (converting to fp32)...") for shard_file in shard_files: shard_path = model_dir / shard_file print(f" Loading {shard_file}...") - with safe_open(shard_path, framework="pt", device="cpu") as f: - for key in f.keys(): - all_tensors[key] = f.get_tensor(key) + with safe_open(shard_path, framework="pt", device="cpu") as sf: + for key in sf.keys(): + t = sf.get_tensor(key) - # Save combined file - print(f"Saving combined model to {output_path}...") - save_file(all_tensors, output_path) + # Convert float dtypes to fp32; keep non-floats as-is (e.g., int tensors, masks). + if torch.is_floating_point(t) and t.dtype != torch.float32: + t = t.to(dtype=torch.float32) - print(f"Combined model saved successfully to {output_path}") + all_tensors[key] = t + + print(f"Saving combined FP32 model to {output_path}...") + save_file(all_tensors, output_path) + print(f"Combined FP32 model saved successfully to {output_path}") if __name__ == "__main__": script_dir = Path(__file__).parent repo_id = "NousResearch/Meta-Llama-3-8B-Instruct" - # Download files from Hugging Face Hub download_model_files(repo_id, script_dir) - # Combine safetensors files - print("\nCombining safetensors files...") - combine_safetensors(script_dir) + print("\nCombining + converting safetensors to FP32...") + combine_and_convert_safetensors_to_fp32(script_dir) print("\nDone!") diff --git a/examples/llama/src/benchmark.rs b/examples/llama/src/benchmark.rs index ca9101f6..f5435abf 100644 --- a/examples/llama/src/benchmark.rs +++ b/examples/llama/src/benchmark.rs @@ -7,7 +7,7 @@ pub struct Benchmarker { peak_tflops: f64, peak_gbps: f64, start_generation: Instant, - ttft: Option, + ttft: Duration, decode_durations: Vec, seq_lengths: Vec<(usize, usize)>, current_iter_start: Option, @@ -20,7 +20,7 @@ impl Benchmarker { peak_tflops, peak_gbps, start_generation: Instant::now(), - ttft: None, + ttft: Duration::default(), decode_durations: vec![], seq_lengths: vec![], current_iter_start: None, @@ -38,7 +38,7 @@ impl Benchmarker { if let Some(start) = self.current_iter_start.take() { let duration = start.elapsed(); if iteration == 0 { - self.ttft = Some(duration); + self.ttft = duration; } else { self.decode_durations.push(duration); } @@ -52,11 +52,6 @@ impl Benchmarker { .decode_durations .iter() .fold(Duration::ZERO, |acc, value| acc + *value); - let tpot = if self.decode_durations.is_empty() { - None - } else { - Some(decode_total / self.decode_durations.len() as u32) - }; let (total_flops, total_bytes) = self .seq_lengths @@ -68,37 +63,26 @@ impl Benchmarker { let achieved_tflops = total_flops as f64 / total_elapsed.as_secs_f64() / 1e12; let achieved_gbps = total_bytes as f64 / total_elapsed.as_secs_f64() / 1e9; - let mfu = if self.peak_tflops > 0.0 { - Some(achieved_tflops / self.peak_tflops) - } else { - None - }; - let mbu = if self.peak_gbps > 0.0 { - Some(achieved_gbps / self.peak_gbps) - } else { - None - }; println!("Benchmark results:"); - if let Some(ttft) = self.ttft { - println!(" TTFT: {:.2} ms", ttft.as_secs_f64() * 1e3); - } - if let Some(tpot) = tpot { - println!(" TPOT: {:.2} ms", tpot.as_secs_f64() * 1e3); + println!(" TTFT: {:.2} ms", self.ttft.as_secs_f64() * 1e3); + if !self.decode_durations.is_empty() { + println!( + " TPOT: {:.2} ms", + (decode_total / self.decode_durations.len() as u32).as_secs_f64() * 1e3 + ); } println!( " Achieved: {:.2} TFLOP/s, {:.2} GB/s", achieved_tflops, achieved_gbps ); - if let Some(mfu) = mfu { - println!(" MFU (est): {:.1}%", mfu * 100.0); - } else { - println!(" MFU (est): N/A (set LUMINAL_PEAK_TFLOPS)"); - } - if let Some(mbu) = mbu { - println!(" MBU (est): {:.1}%", mbu * 100.0); - } else { - println!(" MBU (est): N/A (set LUMINAL_PEAK_BW_GBPS)"); - } + println!( + " MFU (est): {:.2}%", + (achieved_tflops / self.peak_tflops) * 100.0 + ); + println!( + " MBU (est): {:.2}%", + (achieved_gbps / self.peak_gbps) * 100.0 + ); } } diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs index c8cd14d7..726a6795 100644 --- a/examples/llama/src/main.rs +++ b/examples/llama/src/main.rs @@ -37,6 +37,8 @@ fn main() { let ctx = luminal_cuda::cudarc::driver::CudaContext::new(0).unwrap(); ctx.bind_to_thread().unwrap(); + ctx.set_flags(luminal_cuda::cudarc::driver::sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC) + .unwrap(); let stream = ctx.default_stream(); println!("Allocating KV Cache..."); @@ -62,11 +64,17 @@ fn main() { println!("Building E-Graph..."); cx.build_search_space::(); - println!("Compiling..."); - let mut runtime = cx.search(CudaRuntime::initialize((ctx, stream, custom_state)), 10_000); - + let mut runtime = CudaRuntime::initialize((ctx, stream.clone(), custom_state)); println!("Loading weights..."); - runtime.load_safetensors("setup/model_combined.safetensors"); + runtime.load_safetensors(&cx, "setup/model_combined.safetensors"); + + println!("Compiling..."); + cx.set_dyn_dim('s', 1); + cx.set_dyn_dim('p', 0); + // inputs for search + runtime.set_data(input, Box::new(vec![1_i32])); + runtime.set_data(token_ids, Box::new(vec![0_i32])); + runtime = cx.search(runtime, 5); print!("{input_sentence}"); std::io::stdout().flush().unwrap(); diff --git a/src/egglog_utils/base.egg b/src/egglog_utils/base.egg index 1985e22f..5ae1799b 100644 --- a/src/egglog_utils/base.egg +++ b/src/egglog_utils/base.egg @@ -46,7 +46,7 @@ ;(rewrite (MAdd a b) (MAdd b a) :ruleset expr) ; communativity leads to some explosions ;(rewrite (MMul a b) (MMul b a) :ruleset expr) -(rewrite (MAdd (MAdd a b) c) (MAdd a (MAdd b c)) :ruleset expr) +;(rewrite (MAdd (MAdd a b) c) (MAdd a (MAdd b c)) :ruleset expr) ; explodes weirdly, see no_explode test in symbolic.rs ;(rewrite (MMul (MMul a b) c) (MMul a (MMul b c)) :ruleset expr) (rewrite (MAdd (MNum a) (MNum b)) (MNum (+ a b)) :ruleset expr) @@ -70,7 +70,7 @@ (rewrite (MAnd (MNum a) (MNum b)) (MNum (& a b)) :ruleset expr) (rewrite (MFloat -1.0) (MNum -1) :ruleset expr) (rewrite (MNum -1) (MFloat -1.0) :ruleset expr) -(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0)) :ruleset expr) ; why does this explode??? +;(rewrite (MDiv (MMul ?x (MNum ?a)) (MNum ?b)) (MMul ?x (MNum (/ ?a ?b))) :when ((< ?b ?a) (= (% ?a ?b) 0)) :ruleset expr) ; why does this explode??? (rewrite (MAdd a (MNum 0)) a :ruleset expr) (rule ((= ?e (MMul ?a (MNum 1)))) ((union ?e ?a)) :ruleset expr) diff --git a/src/graph.rs b/src/graph.rs index fee05e79..157958e5 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -6,6 +6,8 @@ use crate::{ }; use std::{ any::TypeId, + fmt::Debug, + io::Write, ops::{Deref, DerefMut}, sync::Arc, }; @@ -196,41 +198,79 @@ impl Graph { self.ops.as_ref().unwrap(), limit, ); - let print = std::env::var("SEARCH") - .map(|s| s == "1") - .unwrap_or_default(); - let limit_reached = llir_graphs.len() == limit; + let n_graphs = llir_graphs.len(); let start = std::time::Instant::now(); - if print { - println!( - "{}", + let mut best_graph = StableGraph::default(); + let mut best_metric: Option = None; + let total = llir_graphs.len(); + let bar_width = 24; + + let progress_bar = |i| { + let head = ((i as f32 / total as f32) * bar_width as f32) + .clamp(0.0, bar_width as f32) + .floor() as usize; + let bar = if head == 0 { + format!("[>{}]", " ".repeat(bar_width - 1)) + } else if head >= bar_width { + format!("[{}>]", "=".repeat(bar_width)) + } else { format!( - "---- Searching through {}{} graphs ----", - llir_graphs.len().to_string().bold(), - if limit_reached { "[limit]" } else { "" } + "[{}>{}]", + "=".repeat(head), + " ".repeat(bar_width - head - 1) ) - .cyan() - ); - } - runtime.compile(llir_graphs.last().unwrap()); - if print { - info!( - target: "luminal::search", - graphs = llir_graphs.len(), - limit, - limit_reached, - duration_ms = start.elapsed().as_millis() as u64, - "search completed" + }; + print!( + "\r\x1b[2K {:>6} {bar} {i}/{total}", + "Searching".cyan().bold(), ); + std::io::stdout().flush().unwrap(); + }; + + // Search loop + for (i, llir_graph) in llir_graphs.into_iter().enumerate() { + progress_bar(i + 1); + let (new_metric, display_metric) = runtime.profile(&llir_graph, &self.dyn_map); + let mut new_best = false; + if let Some(old_metric) = &best_metric { + if old_metric.gt(&new_metric) { + best_metric = Some(new_metric); + best_graph = llir_graph; + new_best = true; + } + } else { + best_metric = Some(new_metric); + best_graph = llir_graph; + new_best = true; + } + print!("\r\x1b[2K"); // clear line + std::io::stdout().flush().unwrap(); println!( - "{}", - format!( - "---- Search Took {} ----", - pretty_duration::pretty_duration(&start.elapsed(), None).bold() - ) - .cyan() + " {:>6} Graph {}: {}", + "Searched".green().bold(), + i + 1, + if new_best { + display_metric.bold().green().to_string() + } else { + display_metric + } ); } + + info!( + target: "luminal::search", + graphs = n_graphs, + limit, + limit_reached = n_graphs >= limit, + duration_ms = start.elapsed().as_millis() as u64, + "search completed" + ); + println!( + " {:>6} {n_graphs} graphs in {}", + "Searched".green().bold(), + pretty_duration::pretty_duration(&start.elapsed(), None) + ); + runtime.load_llir(&best_graph); runtime } } @@ -240,10 +280,16 @@ pub trait Runtime { type CompileArg; type Data; type ExecReturn; + type ProfileMetric: PartialOrd + Clone + Debug; fn initialize(arg: Self::CompileArg) -> Self; - fn compile(&mut self, llir_graph: &LLIRGraph); + fn load_llir(&mut self, llir_graph: &LLIRGraph); fn set_data(&mut self, id: impl ToId, data: Self::Data); fn execute(&mut self, dyn_map: &FxHashMap) -> Self::ExecReturn; + fn profile( + &mut self, + llir_graph: &LLIRGraph, + dyn_map: &FxHashMap, + ) -> (Self::ProfileMetric, String); } impl Deref for Graph { @@ -393,41 +439,36 @@ fn run_egglog( panic!("Failed to run:\n{s}\nError: {e}"); } } - if std::env::var("SEARCH") - .map(|s| s == "1") - .unwrap_or_default() + println!("{}", "---- Egglog Rule Matches ----".green()); + let mut rule_lines = Vec::new(); + for (rule, matches) in egraph + .get_overall_run_report() + .num_matches_per_rule + .iter() + .filter(|(k, _)| !k.contains("(")) { - println!("{}", "---- Egglog Rule Matches ----".green()); - let mut rule_lines = Vec::new(); - for (rule, matches) in egraph - .get_overall_run_report() - .num_matches_per_rule - .iter() - .filter(|(k, _)| !k.contains("(")) - { - info!( - target: "luminal::egglog", - rule = %rule, - matches = *matches, - "rule matches" - ); - rule_lines.push(format!("{rule}: {matches}")); - } - println!("{}", rule_lines.join("\n").green()); - println!( - "{}", - format!( - "---- Egglog Took {} ----", - pretty_duration::pretty_duration(&start.elapsed(), None).bold() - ) - .green() - ); info!( target: "luminal::egglog", - duration_ms = start.elapsed().as_millis() as u64, - "egglog run completed" + rule = %rule, + matches = *matches, + "rule matches" ); + rule_lines.push(format!("{rule}: {matches}")); } + println!("{}", rule_lines.join("\n").green()); + println!( + "{}", + format!( + "---- Egglog Took {} ----", + pretty_duration::pretty_duration(&start.elapsed(), None).bold() + ) + .green() + ); + info!( + target: "luminal::egglog", + duration_ms = start.elapsed().as_millis() as u64, + "egglog run completed" + ); let (sort, value) = egraph.eval_expr(&var!(root)).unwrap(); let s = egraph.serialize(egglog::SerializeConfig { diff --git a/src/op.rs b/src/op.rs index 37d4433e..9768c9fb 100644 --- a/src/op.rs +++ b/src/op.rs @@ -1458,6 +1458,7 @@ impl Runtime for NativeRuntime { type CompileArg = (); type Data = NativeData; type ExecReturn = (); + type ProfileMetric = usize; fn initialize(_: Self::CompileArg) -> Self { Self { @@ -1466,7 +1467,15 @@ impl Runtime for NativeRuntime { } } - fn compile(&mut self, llir_graph: &LLIRGraph) { + fn profile( + &mut self, + _: &LLIRGraph, + _: &FxHashMap, + ) -> (Self::ProfileMetric, String) { + (0, "0 ms".to_string()) + } + + fn load_llir(&mut self, llir_graph: &LLIRGraph) { // Extract nativeop graph let mut graph = StableGraph::new(); for node in llir_graph.node_weights() { diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index f372160d..778954e0 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -996,6 +996,7 @@ mod tests { assert_eq!(expr.simplify().len(), 7); } + #[ignore] // ignore until we can add back in associativity #[test] fn test_simple_div() { let w = Expression::from('w'); @@ -1011,6 +1012,7 @@ mod tests { assert!(!(a + 1).egglog_equal(a + 2)); } + #[ignore] // ignore until we can add back in associativity #[test] fn test_other() { let z = Expression::from('z'); @@ -1024,6 +1026,7 @@ mod tests { assert!(x.len() <= 27); // Should be 21 if we can re-enable mul-div-associative-rev } + #[ignore] // ignore until we can add back in associativity #[test] fn test_final() { let z = Expression::from('z'); @@ -1054,4 +1057,18 @@ mod tests { simplified.exec(&env).unwrap() ); } + + #[test] + fn test_no_explode() { + let r = Expression::new(vec![ + Term::Num(1), + Term::Num(8), + Term::Num(32), + Term::Div, + Term::Num(27), + Term::Add, + Term::Add, + ]); + r.simplify(); + } } diff --git a/src/utils.rs b/src/utils.rs index afd8a3a8..4cd20518 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -79,111 +79,6 @@ impl DialectOpTrait for DialectOp {} pub trait DialectOpTrait: AsAny + Debug {} -#[derive(Debug)] -pub struct CStructBuilder { - buf: Vec, - max_align: usize, -} - -impl Default for CStructBuilder { - fn default() -> Self { - Self { - buf: Vec::new(), - max_align: 1, - } - } -} - -#[allow(unused)] -impl CStructBuilder { - pub fn new() -> Self { - Self::default() - } - - fn align_to(&mut self, align: usize) { - self.max_align = self.max_align.max(align); - - let len = self.buf.len(); - let rem = len % align; - if rem != 0 { - let pad = align - rem; - self.buf.extend(std::iter::repeat_n(0u8, pad)); - } - } - - pub fn int(mut self, v: i32) -> Self { - self.align_to(4); - self.buf.extend_from_slice(&v.to_ne_bytes()); - self - } - - pub fn ints(mut self, vs: &[i32]) -> Self { - self.align_to(4); - for &v in vs { - self.buf.extend_from_slice(&v.to_ne_bytes()); - } - self - } - - pub fn float(mut self, v: f32) -> Self { - self.align_to(4); - self.buf.extend_from_slice(&v.to_ne_bytes()); - self - } - - pub fn floats(mut self, vs: &[f32]) -> Self { - self.align_to(4); - for &v in vs { - self.buf.extend_from_slice(&v.to_ne_bytes()); - } - self - } - - pub fn bool(mut self, v: bool) -> Self { - self.align_to(1); - self.buf.push(if v { 1 } else { 0 }); - self - } - - pub fn ptr_const_f32(mut self, p: *const f32) -> Self { - let ptr_size = std::mem::size_of::(); // usually 8 - let ptr_align = ptr_size; - self.align_to(ptr_align); - - let addr = p as usize; - let bytes = addr.to_ne_bytes(); - - self.buf.extend_from_slice(&bytes[..ptr_size]); - self - } - - pub fn ptr_mut_f32(self, p: *mut f32) -> Self { - self.ptr_const_f32(p as *const f32) - } - - /// Pad the struct size to a multiple of max_align. - pub fn finish_struct(mut self) -> Vec { - let align = self.max_align; - if align > 1 { - let len = self.buf.len(); - let rem = len % align; - if rem != 0 { - let pad = align - rem; - self.buf.extend(std::iter::repeat_n(0u8, pad)); - } - } - self.buf - } - - /// Insert a raw byte field (e.g., another struct). - /// `align` must be the alignment of the nested struct. - pub fn bytes(mut self, align: usize, data: &[u8]) -> Self { - self.align_to(align); - self.buf.extend_from_slice(data); - self - } -} - pub enum OpParam { EList, Expr,