Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Run clippy
run: rustup update; cargo clippy --all-targets -- -D warnings
run: sudo apt-get install protobuf-compiler; rustup update; cargo clippy --all-targets -- -D warnings

fmt:
name: Fmt
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ egglog = "1.0.0"
egglog-ast = "1.0.0"
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
tracing = "0.1.43"
tracing-appender = "0.2.4"
tracing-perfetto-sdk-layer = "0.13.0"
tracing-perfetto-sdk-schema = "0.13.0"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
paste = "1.0.15"
pretty-duration = "0.1.1"
anyhow = "1.0"
Expand Down
125 changes: 63 additions & 62 deletions crates/luminal_cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub struct CudaRuntime {
custom_state: FxHashMap<String, CustomState>,
exec_graph: StableGraph<ExecutableKernel, (), Directed>,
node_to_exec: FxHashMap<NodeIndex, NodeIndex>,
timings: Vec<(Vec<SMEvent>, u64)>,
}

impl CudaRuntime {
Expand Down Expand Up @@ -229,6 +230,65 @@ impl CudaRuntime {
}
}
}

/// Merges the device timings recorded by this runtime into an existing Perfetto trace file.
///
/// The trace at `file_path` is read and decoded, and the `(timestamp, clock id)` of every
/// `"megakernel"` slice-begin track event is collected in ascending order. The n-th of those
/// host events anchors the n-th entry of `self.timings`. Each run's events are split into
/// chunks of `EVENTS_PER_TRACK`; every chunk gets its own track, and each pair of consecutive
/// events in a chunk becomes one slice labelled after the first event of the pair. The
/// extended trace is re-encoded and written back to `file_path`.
///
/// # Panics
///
/// Panics if the file cannot be read, decoded, encoded or written, or if the trace holds
/// fewer `"megakernel"` slice-begin events than there are recorded runs.
pub fn record_cuda_perfetto_trace(&self, file_path: impl AsRef<std::path::Path>) {
    // Number of consecutive timing entries that belong to one track.
    // NOTE(review): presumably mirrors the per-SM slot count of the device buffer — confirm.
    const EVENTS_PER_TRACK: usize = 1000;

    let file_path = file_path.as_ref();
    let ops = <crate::block::Ops as IntoBlockOp>::into_vec();
    let data = std::fs::read(file_path).unwrap_or_else(|e| {
        panic!("failed to read perfetto trace {}: {e}", file_path.display())
    });
    let mut trace = tracing_perfetto_sdk_schema::Trace::decode(data.as_slice())
        .unwrap_or_else(|e| {
            panic!("failed to decode perfetto trace {}: {e}", file_path.display())
        });

    let host_start_times: Vec<(u64, u32)> = trace
        .packet
        .iter()
        .filter_map(|p| match &p.data {
            Some(tracing_perfetto_sdk_schema::trace_packet::Data::TrackEvent(TrackEvent {
                name_field: Some(tracing_perfetto_sdk_schema::track_event::NameField::Name(s)),
                r#type: ty,
                ..
            })) if s == "megakernel"
                && *ty
                    == Some(
                        tracing_perfetto_sdk_schema::track_event::Type::SliceBegin as i32,
                    ) =>
            {
                Some((p.timestamp?, p.timestamp_clock_id?))
            }
            _ => None,
        })
        .sorted()
        .collect_vec();

    // Fail up front with a clear message instead of an index-out-of-bounds mid-loop.
    assert!(
        host_start_times.len() >= self.timings.len(),
        "perfetto trace has {} \"megakernel\" slice-begin events but {} runs were recorded",
        host_start_times.len(),
        self.timings.len(),
    );

    let mut extra_packets = Vec::new();
    for ((device_timings, device_start_time), (host_time, host_clock_id)) in
        self.timings.iter().zip(host_start_times.iter().copied())
    {
        for (sm, sm_timings) in device_timings.chunks(EVENTS_PER_TRACK).enumerate() {
            let mut builder = ManualTrackBuilder::new(sm as u32, host_time, host_clock_id);
            for pair in sm_timings.windows(2) {
                let (current, next) = (&pair[0], &pair[1]);
                // A zero start on the following entry ends this track's events
                // (presumably the first unused slot — confirm against the kernel side).
                // Checked before building the label so no work is done for a dropped entry.
                if next.start == 0 {
                    break;
                }
                // Event codes 0 and 1 are built-in; higher codes index into the block ops.
                let op_label = match current.event as usize {
                    0 => "Issue".to_string(),
                    1 => "Wait".to_string(),
                    op => ops[op - 2].term().0,
                };
                builder.push_slice(
                    &op_label,
                    current.start - *device_start_time,
                    next.start - *device_start_time,
                    host_time,
                    host_clock_id,
                );
            }
            extra_packets.extend(builder.into_packets());
        }
    }

    trace.packet.extend(extra_packets);
    let mut buf = Vec::with_capacity(trace.encoded_len());
    trace
        .encode(&mut buf)
        .unwrap_or_else(|e| panic!("failed to encode perfetto trace: {e}"));
    std::fs::write(file_path, buf).unwrap_or_else(|e| {
        panic!("failed to write perfetto trace {}: {e}", file_path.display())
    });
}
}

pub trait ToCudaBuffer {
Expand Down Expand Up @@ -263,7 +323,7 @@ impl Runtime for CudaRuntime {
FxHashMap<String, CustomState>,
);
type Data = Box<dyn ToCudaBuffer>;
type ExecReturn = Vec<(Vec<SMEvent>, u64)>;
type ExecReturn = ();

fn initialize((ctx, stream, custom_state): Self::CompileArg) -> Self {
Self {
Expand All @@ -274,6 +334,7 @@ impl Runtime for CudaRuntime {
custom_state: custom_state,
exec_graph: StableGraph::default(),
node_to_exec: FxHashMap::default(),
timings: vec![],
}
}

Expand Down Expand Up @@ -587,7 +648,7 @@ impl Runtime for CudaRuntime {
}
}
}
timings
self.timings.extend(timings);
}

fn set_data(&mut self, id: impl ToId, data: Self::Data) {
Expand Down Expand Up @@ -795,66 +856,6 @@ pub fn allocate_input_buffers(
buffers
}

/// Merges recorded device timings into an existing Perfetto trace file.
///
/// The trace at `file_path` is read and decoded, and the `(timestamp, clock id)` of every
/// `"megakernel"` slice-begin track event is collected in ascending order. The n-th of those
/// host events anchors the n-th entry of `timings`. Each run's events are split into chunks
/// of `EVENTS_PER_TRACK`; every chunk gets its own track, and each pair of consecutive events
/// in a chunk becomes one slice labelled after the first event of the pair (`"Issue"` for
/// code 0, `"Wait"` for code 1, otherwise the term of `ops[code - 2]`). The extended trace is
/// re-encoded and written back to `file_path`.
///
/// # Panics
///
/// Panics if the file cannot be read, decoded, encoded or written, or if the trace holds
/// fewer `"megakernel"` slice-begin events than there are entries in `timings`.
pub fn record_exec_timings_to_file(
    timings: &[(Vec<SMEvent>, u64)],
    ops: &[Arc<Box<dyn BlockOp>>],
    file_path: &str,
) {
    // Number of consecutive timing entries that belong to one track.
    // NOTE(review): presumably mirrors the per-SM slot count of the device buffer — confirm.
    const EVENTS_PER_TRACK: usize = 1000;

    let data = std::fs::read(file_path)
        .unwrap_or_else(|e| panic!("failed to read perfetto trace {file_path}: {e}"));
    let mut trace = tracing_perfetto_sdk_schema::Trace::decode(data.as_slice())
        .unwrap_or_else(|e| panic!("failed to decode perfetto trace {file_path}: {e}"));

    let host_start_times: Vec<(u64, u32)> = trace
        .packet
        .iter()
        .filter_map(|p| match &p.data {
            Some(tracing_perfetto_sdk_schema::trace_packet::Data::TrackEvent(TrackEvent {
                name_field: Some(tracing_perfetto_sdk_schema::track_event::NameField::Name(s)),
                r#type: ty,
                ..
            })) if s == "megakernel"
                && *ty
                    == Some(tracing_perfetto_sdk_schema::track_event::Type::SliceBegin as i32) =>
            {
                Some((p.timestamp?, p.timestamp_clock_id?))
            }
            _ => None,
        })
        .sorted()
        .collect_vec();

    // Fail up front with a clear message instead of an index-out-of-bounds mid-loop.
    assert!(
        host_start_times.len() >= timings.len(),
        "perfetto trace has {} \"megakernel\" slice-begin events but {} runs were recorded",
        host_start_times.len(),
        timings.len(),
    );

    let mut extra_packets = Vec::new();
    for ((device_timings, device_start_time), (host_time, host_clock_id)) in
        timings.iter().zip(host_start_times.iter().copied())
    {
        for (sm, sm_timings) in device_timings.chunks(EVENTS_PER_TRACK).enumerate() {
            let mut builder = ManualTrackBuilder::new(sm as u32, host_time, host_clock_id);
            for pair in sm_timings.windows(2) {
                let (current, next) = (&pair[0], &pair[1]);
                // A zero start on the following entry ends this track's events
                // (presumably the first unused slot — confirm against the kernel side).
                // Checked before building the label so no work is done for a dropped entry.
                if next.start == 0 {
                    break;
                }
                let op_label = match current.event as usize {
                    0 => "Issue".to_string(),
                    1 => "Wait".to_string(),
                    op => ops[op - 2].term().0,
                };
                builder.push_slice(
                    &op_label,
                    current.start - *device_start_time,
                    next.start - *device_start_time,
                    host_time,
                    host_clock_id,
                );
            }
            extra_packets.extend(builder.into_packets());
        }
    }

    trace.packet.extend(extra_packets);
    let mut buf = Vec::with_capacity(trace.encoded_len());
    trace
        .encode(&mut buf)
        .unwrap_or_else(|e| panic!("failed to encode perfetto trace: {e}"));
    std::fs::write(file_path, buf)
        .unwrap_or_else(|e| panic!("failed to write perfetto trace {file_path}: {e}"));
}

struct ManualTrackBuilder {
packets: Vec<schema::TracePacket>,
track_uuid: u64,
Expand Down
7 changes: 1 addition & 6 deletions examples/llama/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,4 @@ luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda = { path = "../../crates/luminal_cuda" }
itertools = "0.12.1"
tokenizers = "0.15.2"
tracing = "0.1.43"
tracing-subscriber = {version="0.3", features=["env-filter"]}
tracing-perfetto-sdk-layer = "0.13.0"
tracing-perfetto-sdk-schema = "0.13.0"
tracing-perfetto-sdk-sys = "0.13.0"
tracing-appender = "0.2.4"
tracing = "0.1.43"
63 changes: 11 additions & 52 deletions examples/llama/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,16 @@ use luminal::{
op::DType,
prelude::FxHashMap,
};
use luminal_cuda::{
block::IntoBlockOp,
runtime::{record_exec_timings_to_file, CudaRuntime, CustomState},
};
use luminal_cuda::runtime::{CudaRuntime, CustomState};
use model::*;
use std::{fs::File, io::Write, time::Duration};
use std::io::Write;
use tokenizers::Tokenizer;
use tracing::{span, Level};
use tracing_appender::non_blocking;
use tracing_perfetto_sdk_layer::NativeLayer;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

fn main() {
// Set up tracing
let file = File::create("trace.pftrace").unwrap();
let (writer, _guard) = non_blocking(file);
let layer = NativeLayer::from_config(trace_config(), writer)
.build()
.unwrap();
let filter = EnvFilter::builder()
.parse(format!("{}=trace,luminal=trace", env!("CARGO_PKG_NAME")))
.unwrap();
let layer_handle = layer.clone();
tracing_subscriber::registry()
.with(filter)
.with(layer)
let trace_session = luminal::trace::new()
.perfetto("trace.pftrace")
.env_filter(format!("{}=trace,luminal=trace", env!("CARGO_PKG_NAME")))
.init();

let max_seq_len = 4096;
Expand Down Expand Up @@ -88,7 +71,6 @@ fn main() {
print!("{input_sentence}");
std::io::stdout().flush().unwrap();

let mut timings = vec![];
let mut prev_seq = 0;
let mut benchmarker = Benchmarker::new(756., 2_000.); // H100 specs
for i in 0..gen_tokens {
Expand Down Expand Up @@ -121,7 +103,7 @@ fn main() {
}

benchmarker.start_iteration(seq_len, prev_seq);
timings.extend(runtime.execute(&cx.dyn_map));
runtime.execute(&cx.dyn_map);
let logits_data = runtime.get_f32(logits);

let sample_span = span!(Level::INFO, "sample");
Expand All @@ -134,18 +116,12 @@ fn main() {
}
println!();

trace_session.stop();
benchmarker.report();

layer_handle
.flush(Duration::from_secs(5), Duration::from_secs(5))
.unwrap();
layer_handle.stop().unwrap();
drop(_guard);
record_exec_timings_to_file(
&timings,
&<luminal_cuda::block::Ops as IntoBlockOp>::into_vec(),
"trace.pftrace",
);
// Dump cuda trace to timeline
if let Some(path) = trace_session.perfetto_path {
runtime.record_cuda_perfetto_trace(path);
}
}

#[tracing::instrument(skip_all)]
Expand All @@ -164,20 +140,3 @@ fn sample(logits: &[f32], vocab_size: usize) -> Vec<u32> {
})
.collect()
}

fn trace_config() -> tracing_perfetto_sdk_schema::TraceConfig {
tracing_perfetto_sdk_schema::TraceConfig {
buffers: vec![tracing_perfetto_sdk_schema::trace_config::BufferConfig {
size_kb: Some(4096),
..Default::default()
}],
data_sources: vec![tracing_perfetto_sdk_schema::trace_config::DataSource {
config: Some(tracing_perfetto_sdk_schema::DataSourceConfig {
name: Some("rust_tracing".into()),
..Default::default()
}),
..Default::default()
}],
..Default::default()
}
}
48 changes: 32 additions & 16 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use egraph_serialize::{ClassId, NodeId};
use itertools::Itertools;
use petgraph::{Direction, stable_graph::StableGraph, visit::EdgeRef};
use rustc_hash::{FxHashMap, FxHashSet};
use tracing::info;

pub type LLIRGraph = StableGraph<LLIROp, (), petgraph::Directed>;
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, Dependency>;
Expand Down Expand Up @@ -198,24 +199,29 @@ impl Graph {
let print = std::env::var("SEARCH")
.map(|s| s == "1")
.unwrap_or_default();
let limit_reached = llir_graphs.len() == limit;
let start = std::time::Instant::now();
if print {
println!(
"{}",
format!(
"---- Searching through {}{} graphs ----",
llir_graphs.len().to_string().bold(),
if llir_graphs.len() == limit {
"[limit]"
} else {
""
}
if limit_reached { "[limit]" } else { "" }
)
.cyan()
);
}
runtime.compile(llir_graphs.last().unwrap());
if print {
info!(
target: "luminal::search",
graphs = llir_graphs.len(),
limit,
limit_reached,
duration_ms = start.elapsed().as_millis() as u64,
"search completed"
);
println!(
"{}",
format!(
Expand Down Expand Up @@ -392,17 +398,22 @@ fn run_egglog(
.unwrap_or_default()
{
println!("{}", "---- Egglog Rule Matches ----".green());
println!(
"{}",
egraph
.get_overall_run_report()
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
.map(|(k, v)| format!("{k}: {v}"))
.join("\n")
.green()
);
let mut rule_lines = Vec::new();
for (rule, matches) in egraph
.get_overall_run_report()
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
{
info!(
target: "luminal::egglog",
rule = %rule,
matches = *matches,
"rule matches"
);
rule_lines.push(format!("{rule}: {matches}"));
}
println!("{}", rule_lines.join("\n").green());
println!(
"{}",
format!(
Expand All @@ -411,6 +422,11 @@ fn run_egglog(
)
.green()
);
info!(
target: "luminal::egglog",
duration_ms = start.elapsed().as_millis() as u64,
"egglog run completed"
);
}

let (sort, value) = egraph.eval_expr(&var!(root)).unwrap();
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod hl_ops;
pub mod op;
pub mod serialized_egraph;
pub mod shape;
pub mod trace;
pub mod utils;
pub mod visualization;

Expand Down
Loading