diff --git a/Cargo.lock b/Cargo.lock index 58191ae8bc..de9444c5ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1143,12 +1143,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -1581,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1596,13 +1590,17 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ + "bytemuck", "derive-new 0.6.0", + "derive_more 1.0.0", "embassy-futures", "futures-lite", "getrandom", + "half", "log", + "num-traits", "portable-atomic", "rand", "serde", @@ -1613,11 +1611,12 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "bytemuck", "cubecl-common", + "cubecl-ir", "cubecl-macros", "cubecl-runtime", "derive-new 0.6.0", @@ -1633,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1686,10 +1685,23 @@ dependencies = [ "libc", ] +[[package]] +name = "cubecl-ir" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +dependencies = [ + "cubecl-common", + "float-ord", + "half", + "num-traits", + "serde", + "type_hash", +] + [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1713,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "darling", @@ -1716,10 +1728,10 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", - "cubecl-core", + "cubecl-ir", "float-ord", "log", "num", @@ -1732,7 +1744,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,11 +1754,11 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "async-channel", "async-lock", - "cfg_aliases 0.2.1", + "cfg_aliases", "cubecl-common", "derive-new 0.6.0", "dirs", @@ -1764,7 +1776,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1779,13 +1791,13 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "ash", "async-channel", "bytemuck", "cfg-if", - "cfg_aliases 0.2.1", + "cfg_aliases", "cubecl-common", "cubecl-core", "cubecl-runtime", @@ -2835,9 +2847,9 @@ dependencies = [ [[package]] name = "glow" -version = "0.14.2" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483" +checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08" dependencies = [ "js-sys", "slotmap", @@ -3907,6 +3919,21 @@ dependencies = [ "paste", ] +[[package]] +name = "metal" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f569fb946490b5743ad69813cb19629130ce9374034abe31614a36402d18f99e" +dependencies = [ + "bitflags 2.8.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "mime" version = "0.3.17" @@ -4015,22 +4042,23 @@ dependencies = [ [[package]] name = "naga" -version = "23.1.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" +checksum = "e380993072e52eef724eddfcde0ed013b0c023c3f0417336ed041aa9f076994e" dependencies = [ "arrayvec", "bit-set", "bitflags 2.8.0", - "cfg_aliases 0.1.1", + "cfg_aliases", "codespan-reporting", "hexf-parse", "indexmap", "log", "rustc-hash 1.1.0", "spirv 0.3.0+sdk-1.3.268.0", + "strum", "termcolor", - "thiserror 1.0.69", + "thiserror 2.0.11", "unicode-xid", ] @@ -4624,6 +4652,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "os_info" version = "3.9.2" @@ -5687,7 +5724,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ - "cfg_aliases 0.2.1", + "cfg_aliases", "libc", "once_cell", "socket2", @@ -7583,6 +7620,37 @@ dependencies = [ "rustc-hash 1.1.0", ] +[[package]] +name = "type_hash" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" +dependencies = [ + "type_hash_core", + "type_hash_macros", +] + +[[package]] +name = "type_hash_core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" +dependencies = [ + "fnv", +] + +[[package]] +name = "type_hash_macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" +dependencies = [ + "either", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "typenum" version = "1.17.0" @@ -8020,12 +8088,13 @@ dependencies = [ [[package]] name = "wgpu" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a" +checksum = "e41253fc7b660735e2a2d9a58c563f2a047d3cc3445293d8f4095538c9e8afbe" dependencies = [ "arrayvec", - "cfg_aliases 0.1.1", + "bitflags 2.8.0", + "cfg_aliases", "document-features", "js-sys", "log", @@ -8045,14 +8114,14 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" +checksum = "82a39b8842dc9ffcbe34346e3ab6d496b32a47f6497e119d762c97fcaae3cb37" dependencies = [ "arrayvec", "bit-vec", "bitflags 2.8.0", - "cfg_aliases 0.1.1", + "cfg_aliases", "document-features", "indexmap", "log", @@ -8063,16 +8132,16 @@ dependencies = [ "raw-window-handle", "rustc-hash 1.1.0", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wgpu-hal", "wgpu-types", ] [[package]] name = "wgpu-hal" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821" +checksum = "5a782e5056b060b0b4010881d1decddd059e44f2ecd01e2db2971b48ad3627e5" dependencies = [ "android_system_properties", "arrayvec", @@ -8081,7 +8150,7 @@ dependencies = [ "bitflags 2.8.0", "block", "bytemuck", - "cfg_aliases 0.1.1", + "cfg_aliases", "core-graphics-types", "glow", "glutin_wgl_sys", @@ -8093,11 +8162,12 @@ dependencies = [ "libc", "libloading", "log", - "metal 0.29.0", + "metal 0.31.0", "naga", "ndk-sys", "objc", "once_cell", + "ordered-float", "parking_lot 0.12.3", "profiling", "range-alloc", @@ -8105,7 +8175,7 @@ dependencies = [ "renderdoc-sys", "rustc-hash 1.1.0", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wasm-bindgen", "web-sys", "wgpu-types", @@ -8115,12 +8185,13 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "23.0.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" +checksum = "50ac044c0e76c03a0378e7786ac505d010a873665e2d51383dcff8dd227dc69c" dependencies = [ "bitflags 2.8.0", "js-sys", + "log", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index 22ed0b2644..f731d063a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,7 +101,7 @@ ratatui = "0.29.0" # WGPU stuff text_placeholder = "0.5.1" -wgpu = "23.0.0" +wgpu = "24.0.0" # Benchmarks and Burnbench arboard = "3.4.1" @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs index 375be97b4e..fbec64c648 100644 --- a/backend-comparison/benches/matmul_fused.rs +++ b/backend-comparison/benches/matmul_fused.rs @@ -1,5 +1,9 @@ use backend_comparison::persistence::save; -use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor}; +use burn::tensor::{ + activation::{gelu, relu}, + backend::Backend, + Distribution, Shape, Tensor, +}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -14,7 +18,7 @@ impl Benchmark for MatmulBenchmark { type Args = (Tensor, Tensor, Tensor); fn name(&self) -> String { - "matmul_bias_relu".into() + "matmul_relu_bias_gelu".into() } fn shapes(&self) -> Vec> { @@ -23,7 +27,7 @@ impl Benchmark for MatmulBenchmark { fn execute(&self, (lhs, rhs, bias): Self::Args) { let bias = bias.unsqueeze(); - relu(lhs.matmul(rhs) + bias); + gelu(relu(lhs.matmul(rhs)) + bias); } fn prepare(&self) -> Self::Args { diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 5f5621cc51..e55afc904d 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -68,7 +68,7 @@ We are using ndarray, so we just need to define the NdArray backend as usual use burn::{backend::NdArray, tensor::Tensor}; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; ``` Then inside the `main` function add @@ -76,7 +76,7 @@ Then inside the `main` function add use your_model::Model; // Get a default device for the backend -let device = BackendDeice::default(); +let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 5e221f887f..df7040f835 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -18,7 +18,7 @@ std = [] async = [] # Require std [dependencies] -burn-common = { path = "../burn-common", version = "0.17.0" } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index f554518430..ade8d64db7 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "135"] //! The core crate of Burn. diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index d5e1ee9e38..0f9c75fc94 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -59,6 +59,84 @@ pub(crate) struct OperationConverter { scalar_u8: Vec, } +/// Fork of a [context](Context) which owns its data. +pub struct ContextOwned { + tensors: HashMap, + handles: HandleContainer, + scalar_f32: Vec, + scalar_f16: Vec, + scalar_bf16: Vec, + scalar_i64: Vec, + scalar_i32: Vec, + scalar_i16: Vec, + scalar_i8: Vec, + scalar_u64: Vec, + scalar_u32: Vec, + scalar_u16: Vec, + scalar_u8: Vec, +} + +impl ContextOwned { + /// Convert into [context](Context). + pub fn as_context(&mut self) -> Context<'_, H> { + Context { + tensors: &mut self.tensors, + handles: &mut self.handles, + scalar_f32: &self.scalar_f32, + scalar_f16: &self.scalar_f16, + scalar_bf16: &self.scalar_bf16, + scalar_i64: &self.scalar_i64, + scalar_i32: &self.scalar_i32, + scalar_i16: &self.scalar_i16, + scalar_i8: &self.scalar_i8, + scalar_u64: &self.scalar_u64, + scalar_u32: &self.scalar_u32, + scalar_u16: &self.scalar_u16, + scalar_u8: &self.scalar_u8, + } + } + + /// Fork the context again. + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + +impl Context<'_, H> { + /// Fork the context into an [owned context](ContextOwned). + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + pub(crate) trait RelativeOps { /// Convert (usually an [`OperationDescription`]) to a relative form. /// diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index e8e4d82659..48587a1bf9 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -125,20 +125,16 @@ impl FusionRuntime for FusionJitRuntime { fn optimizations( device: R::Device, ) -> Vec>> { - let mut optimizations: Vec>> = - vec![Box::new(ElementWiseBuilder::::new( + vec![ + Box::new(ElementWiseBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - ))]; - - if cfg!(feature = "fusion-experimental") { - optimizations.push(Box::new(MatmulBuilder::::new( + )), + Box::new(MatmulBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - ))); - } - - optimizations + )), + ] } } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index 986332914f..f197237819 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -47,7 +47,13 @@ impl OptimizationBuilder> for MatmulBuilder let rhs = self.builder.input_unhandled(&op.rhs); let out = self.builder.output_unhandled(&op.out); - self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone())); + self.matmul = Some(FusedMatmul::new( + lhs, + rhs, + out, + op.clone(), + Default::default(), + )); } else { self.builder.close(); } diff --git a/crates/burn-jit/src/fusion/matmul/mod.rs b/crates/burn-jit/src/fusion/matmul/mod.rs index 1afeef9c88..cddec5983a 100644 --- a/crates/burn-jit/src/fusion/matmul/mod.rs +++ b/crates/burn-jit/src/fusion/matmul/mod.rs @@ -2,3 +2,4 @@ pub(crate) mod args; pub(crate) mod builder; pub(crate) mod optimization; pub(crate) mod spec; +pub(crate) mod tune; diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index d0cd8749ad..9a020df62c 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -12,7 +12,9 @@ use burn_tensor::Shape; use cubecl::linalg::matmul::components; use cubecl::linalg::matmul::components::tile::accelerated::Accelerated; use cubecl::linalg::matmul::components::MatmulProblem; -use cubecl::linalg::matmul::kernels::matmul::{MatmulSelector, StandardSelector}; +use cubecl::linalg::matmul::kernels::matmul::{ + MatmulSelector, PipelinedSelector, SpecializedSelector, StandardSelector, +}; use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; use cubecl::linalg::tensor::{matrix_layout, MatrixLayout}; use cubecl::{client::ComputeClient, prelude::*}; @@ -26,16 +28,18 @@ use crate::fusion::on_write::{ use super::args::FusedMatmulInputLaunch; use super::spec::FusedMatmulSpec; +use super::tune::fused_matmul_autotune; -#[derive(new)] /// Fuse matmul operation followed by elemwise operations into a single kernel. pub struct MatmulOptimization { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - client: ComputeClient, - device: R::Device, - len: usize, - matmul: FusedMatmul, + pub(crate) client: ComputeClient, + pub(crate) device: R::Device, + pub(crate) len: usize, + pub(crate) matmul_standard: FusedMatmul, + pub(crate) matmul_pipelined: FusedMatmul, + pub(crate) matmul_specialized: FusedMatmul, } #[derive(Serialize, Deserialize, Debug)] @@ -43,13 +47,46 @@ pub struct MatmulOptimization { pub struct MatmulOptimizationState { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - matmul: FusedMatmul, + matmul_standard: FusedMatmul, + matmul_pipelined: FusedMatmul, + matmul_specialized: FusedMatmul, len: usize, } impl MatmulOptimization { + pub fn new( + trace: FuseOnWriteTrace, + trace_fallback: FuseOnWriteTrace, + client: ComputeClient, + device: R::Device, + len: usize, + matmul: FusedMatmul, + ) -> Self { + let mut matmul_standard = matmul.clone(); + let mut matmul_specialized = matmul.clone(); + let mut matmul_pipelined = matmul; + + matmul_standard.selector = FusedMatmulSelector::Standard; + matmul_specialized.selector = FusedMatmulSelector::Specialized; + matmul_pipelined.selector = FusedMatmulSelector::Pipelined; + + Self { + trace, + trace_fallback, + client, + device, + len, + matmul_standard, + matmul_pipelined, + matmul_specialized, + } + } /// Execute the optimization. pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + #[cfg(feature = "autotune")] + fused_matmul_autotune::(self, context); + + #[cfg(not(feature = "autotune"))] if self.execute_fused::(context).is_err() { self.execute_fallback::(context); } @@ -68,7 +105,9 @@ impl MatmulOptimization { len: state.len, client: R::client(device), device: device.clone(), - matmul: state.matmul.clone(), + matmul_standard: state.matmul_standard.clone(), + matmul_specialized: state.matmul_specialized.clone(), + matmul_pipelined: state.matmul_pipelined.clone(), } } @@ -77,21 +116,51 @@ impl MatmulOptimization { MatmulOptimizationState { trace: self.trace.clone(), trace_fallback: self.trace_fallback.clone(), - matmul: self.matmul.clone(), + matmul_standard: self.matmul_standard.clone(), + matmul_specialized: self.matmul_specialized.clone(), + matmul_pipelined: self.matmul_pipelined.clone(), len: self.len, } } - fn execute_fused( - &mut self, + pub fn execute_standard_fused( + &self, context: &mut Context<'_, JitFusionHandle>, ) -> Result<(), FusedMatmulError> { - self.trace - .run::(&self.client, &self.device, context, &self.matmul) + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_standard, + ) } - fn execute_fallback(&mut self, context: &mut Context<'_, JitFusionHandle>) { - match self.matmul.lhs.precision() { + pub fn execute_specialized_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_specialized, + ) + } + + pub fn execute_pipelined_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_pipelined, + ) + } + + pub fn execute_fallback(&self, context: &mut Context<'_, JitFusionHandle>) { + match self.matmul_standard.lhs.precision() { ElemwisePrecision::F32 => self.run_fallback::(context), ElemwisePrecision::F16 => self.run_fallback::(context), ElemwisePrecision::BF16 => self.run_fallback::(context), @@ -100,13 +169,25 @@ impl MatmulOptimization { } fn run_fallback( - &mut self, + &self, context: &mut Context<'_, JitFusionHandle>, ) { let (out_tensor, out_desc) = { - let lhs = context.tensors.get(&self.matmul.op.lhs.id).unwrap().clone(); - let rhs = context.tensors.get(&self.matmul.op.rhs.id).unwrap().clone(); - let out = context.tensors.get(&self.matmul.op.out.id).unwrap().clone(); + let lhs = context + .tensors + .get(&self.matmul_standard.op.lhs.id) + .unwrap() + .clone(); + let rhs = context + .tensors + .get(&self.matmul_standard.op.rhs.id) + .unwrap() + .clone(); + let out = context + .tensors + .get(&self.matmul_standard.op.out.id) + .unwrap() + .clone(); let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly); let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly); @@ -136,12 +217,21 @@ impl MatmulOptimization { } } +#[derive(Default, Clone, Serialize, Deserialize, Debug)] +pub enum FusedMatmulSelector { + #[default] + Standard, + Pipelined, + Specialized, +} + #[derive(new, Clone, Serialize, Deserialize, Debug)] pub struct FusedMatmul { lhs: Arg, rhs: Arg, out: Arg, - op: BinaryOperationDescription, + pub(crate) op: BinaryOperationDescription, + pub(crate) selector: FusedMatmulSelector, } #[derive(Debug)] @@ -261,15 +351,43 @@ impl FusedMatmul { } }; - match matmul_launch_kernel::>( - client, - FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), - outputs, - problem, - plane_size, - ) { - Ok(_) => Ok(()), - Err(err) => Err(FusedMatmulError::LaunchError(err)), + match self.selector { + FusedMatmulSelector::Standard => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Pipelined => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Specialized => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } } } } diff --git a/crates/burn-jit/src/fusion/matmul/tune.rs b/crates/burn-jit/src/fusion/matmul/tune.rs new file mode 100644 index 0000000000..0f6e42c486 --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/tune.rs @@ -0,0 +1,133 @@ +use crate::{ + fusion::{ + tune::{TuneContext, TuneInput}, + JitFusionHandle, + }, + kernel::matmul::MatmulAutotuneKey, + BoolElement, JitRuntime, JitTuneId, +}; +use burn_fusion::stream::Context; +use cubecl::{ + tune::{local_tuner, LocalTuner, TunableSet}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use super::optimization::MatmulOptimization; + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +pub struct FusedMatmulAutotuneKey { + matmul_key: MatmulAutotuneKey, + #[autotune(anchor)] + num_ops_fused: usize, +} + +/// Executes autotune on matmul operations +pub fn fused_matmul_autotune( + optimization: &MatmulOptimization, + context: &mut Context>, +) { + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key::, input_gen::) + .with_tunable(tune_standard_fused::) + .with_tunable(tune_specialized_fused::) + .with_tunable(tune_pipelined_fused::) + .with_tunable(tune_fallback::); + + TUNER.execute( + &JitTuneId::new::(&optimization.device), + &optimization.client, + &tunables, + TuneInput::new(context, optimization), + ); +} + +pub(crate) fn create_key( + input: &TuneInput>, +) -> FusedMatmulAutotuneKey { + let opt = input.optimization(); + let context = match input.context() { + TuneContext::Original(context) => context, + TuneContext::Fork(_) => panic!("Not supported when generating key"), + }; + + let lhs = context.tensors.get(&opt.matmul_standard.op.lhs.id).unwrap(); + let rhs = context.tensors.get(&opt.matmul_standard.op.rhs.id).unwrap(); + let out = context.tensors.get(&opt.matmul_standard.op.out.id).unwrap(); + + let key = MatmulAutotuneKey::from_shape( + &lhs.shape.clone().into(), + &rhs.shape.clone().into(), + out.dtype, + ); + FusedMatmulAutotuneKey::new(key, opt.len) +} + +fn input_gen( + _key: &FusedMatmulAutotuneKey, + input: &TuneInput>, +) -> TuneInput> { + input.clone() +} + +fn tune_standard_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_standard_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_standard_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_specialized_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_specialized_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_specialized_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_pipelined_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_pipelined_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_pipelined_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_fallback( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_fallback::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_fallback::(&mut context_owned.as_context()) + } + }; + + Ok(()) +} diff --git a/crates/burn-jit/src/fusion/mod.rs b/crates/burn-jit/src/fusion/mod.rs index 4c44770b4e..96e1704964 100644 --- a/crates/burn-jit/src/fusion/mod.rs +++ b/crates/burn-jit/src/fusion/mod.rs @@ -3,5 +3,6 @@ mod base; pub(crate) mod elemwise; pub(crate) mod matmul; pub(crate) mod on_write; +pub(crate) mod tune; pub use base::*; diff --git a/crates/burn-jit/src/fusion/tune.rs b/crates/burn-jit/src/fusion/tune.rs new file mode 100644 index 0000000000..8c45f93bb0 --- /dev/null +++ b/crates/burn-jit/src/fusion/tune.rs @@ -0,0 +1,108 @@ +use super::JitFusionHandle; +use crate::JitRuntime; +use burn_fusion::stream::{Context, ContextOwned}; + +/// Fusion context used when tuning kernels. +/// +/// Either the original context is returned or a fork of the original. +/// The fork is only given when performing autotuning, and not when actually performing the +/// operation. +pub enum TuneContext<'a, R: JitRuntime> { + Original(&'a mut Context<'a, JitFusionHandle>), + Fork(Box>>), +} + +/// Fusion input wrapper containing the context and the optimization. +/// +/// # Safety +/// +/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions +/// are made based on its behavior. +pub struct TuneInput { + context: UnsafeTuneContext, + optimization: *const O, +} + +/// Unsafe wrapper around the context. +/// +/// # Safety +/// +/// The wrapper removes the context lifetime. +/// +/// For it to be correct, the context must not be used after the invocation of the +/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are +/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find +/// the best kernel to use, which can be async. +enum UnsafeTuneContext { + Original(*mut Context<'static, JitFusionHandle>), + Fork(Box>>), +} + +unsafe impl Send for UnsafeTuneContext {} +unsafe impl Send for TuneInput {} + +impl TuneInput { + /// Create a new autotune input from the [context](Context) and an optimization. + pub fn new(context: &mut Context>, optimization: &O) -> Self { + let context = UnsafeTuneContext::new(context); + // We can erase the lifetime for the same reason we do with the context. + let optimization = core::ptr::from_ref(optimization); + + Self { + context, + optimization, + } + } + + /// Retrieve the [autotune context](TuneContext) for the current input. + pub fn context(&self) -> TuneContext<'static, R> { + self.context.get() + } + + /// Retrieve the optimization for the current input. + pub fn optimization(&self) -> &O { + unsafe { self.optimization.as_ref().unwrap() } + } +} + +impl UnsafeTuneContext { + fn new(context: &mut Context<'_, JitFusionHandle>) -> Self { + let ptr = core::ptr::from_mut(context); + + // It is necessary for the lifetime. + #[allow(clippy::unnecessary_cast)] + Self::Original(ptr as *mut Context<'static, _>) + } + + fn get(&self) -> TuneContext<'static, R> { + match self { + UnsafeTuneContext::Original(ptr) => { + TuneContext::Original(unsafe { ptr.as_mut().unwrap() }) + } + UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())), + } + } +} + +impl Clone for TuneInput { + fn clone(&self) -> Self { + Self { + context: self.context.clone(), + optimization: self.optimization, + } + } +} + +impl Clone for UnsafeTuneContext { + fn clone(&self) -> Self { + let context = match self { + UnsafeTuneContext::Original(ptr) => { + let context: &mut Context<'static, JitFusionHandle> = + unsafe { ptr.as_mut().unwrap() }; + context.fork() + } + UnsafeTuneContext::Fork(context) => context.fork(), + }; + UnsafeTuneContext::Fork(Box::new(context)) + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 7f9914989a..6b738ab988 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -99,7 +99,7 @@ fn im2col_kernel( #[cfg(not(test))] pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); + let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { diff --git a/crates/burn-jit/src/kernel/matmul/tune/key.rs b/crates/burn-jit/src/kernel/matmul/tune/key.rs index d25cce3023..44cb079399 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/key.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/key.rs @@ -22,7 +22,7 @@ pub struct MatmulAutotuneKey { } impl MatmulAutotuneKey { - fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { + pub(crate) fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { let ndims = lhs_shape.num_dims(); let m = lhs_shape.dims[ndims - 2]; let k = lhs_shape.dims[ndims - 1]; diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 660ae2f6fd..fd23cd2e2d 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -15,7 +15,8 @@ pub use mask::*; pub(crate) use unary_float::*; pub(crate) use unary_numeric::*; -pub use cubecl::{Kernel, PLANE_DIM_APPROX}; +pub use burn_common::PLANE_DIM_APPROX; +pub use cubecl::Kernel; /// Convolution kernels pub mod conv; diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 54e50468fb..cfdf3319fe 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -1,5 +1,6 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use cubecl::{prelude::*, Compiler, ExecutionMode, KernelId}; +use burn_common::ExecutionMode; +use cubecl::{prelude::*, Compiler, KernelId}; use super::SourceTemplate; diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 167cf88c1a..111649ab25 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -43,7 +43,7 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] } diff --git a/crates/burn-router/src/lib.rs b/crates/burn-router/src/lib.rs index 644f65ee67..773235f781 100644 --- a/crates/burn-router/src/lib.rs +++ b/crates/burn-router/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "138"] //! Burn multi-backend router. diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 318912b2f7..7428408292 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -32,7 +32,7 @@ std = [ [dependencies] burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } -cubecl = { workspace = true, optional = true, default-features = true } +cubecl = { workspace = true, optional = true, default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } colored = { workspace = true, optional = true } diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 85e18ec444..dce51f5ee2 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -26,6 +26,23 @@ pub struct HandleContainer { pub handles_orphan: Vec, } +impl HandleContainer { + /// Fork the container, useful for autotune. + pub fn fork(&self) -> Self { + let mut handles = HashMap::with_capacity(self.handles.len()); + + for (id, handle) in self.handles.iter() { + handles.insert(*id, handle.clone()); + } + + Self { + handles, + counter: self.counter, + handles_orphan: self.handles_orphan.clone(), + } + } +} + impl core::fmt::Debug for HandleContainer { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("HandleContainer") @@ -37,6 +54,7 @@ impl core::fmt::Debug for HandleContainer { } /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state +#[derive(Clone)] pub enum Handle { /// No [tensor handle](ReprBackend::Handle) has been created yet NotInit, diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 3e9b2820a5..c11854fcaf 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -12,7 +12,7 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; -pub use cubecl::ir::CubeDim; +pub use cubecl::CubeDim; pub use cubecl::wgpu::{ init_device, init_setup, init_setup_async, MemoryConfiguration, RuntimeOptions, WgpuDevice, diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs index 8e8360febe..44c5b1dabc 100644 --- a/examples/guide/src/bin/infer.rs +++ b/examples/guide/src/bin/infer.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "131"] use burn::{backend::WebGpu, data::dataset::Dataset}; use guide::inference; diff --git a/examples/image-classification-web/src/lib.rs b/examples/image-classification-web/src/lib.rs index 3881123eaf..3d528f2e9d 100644 --- a/examples/image-classification-web/src/lib.rs +++ b/examples/image-classification-web/src/lib.rs @@ -1,4 +1,5 @@ #![cfg_attr(not(test), no_std)] +#![recursion_limit = "135"] pub mod model; pub mod web; diff --git a/examples/raspberry-pi-pico/src/bin/main.rs b/examples/raspberry-pi-pico/src/bin/main.rs index 1b7f6acdf0..a502a8193e 100644 --- a/examples/raspberry-pi-pico/src/bin/main.rs +++ b/examples/raspberry-pi-pico/src/bin/main.rs @@ -10,7 +10,7 @@ use embassy_rp as _; use embedded_alloc::Heap; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; #[global_allocator] static HEAP: Heap = Heap::empty(); @@ -25,7 +25,7 @@ async fn main(_spawner: Spawner) { } // Get a default device for the backend - let device = BackendDeice::default(); + let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); @@ -47,7 +47,7 @@ async fn main(_spawner: Spawner) { } } -fn run_model<'a>(model: &Model, device: &BackendDeice, input: f32) -> Tensor { +fn run_model<'a>(model: &Model, device: &BackendDevice, input: f32) -> Tensor { // Define the tensor let input = Tensor::::from_floats([[input]], &device); diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index 20ce328192..014a5e2cf5 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "141"] + pub fn start() { let port = std::env::var("REMOTE_BACKEND_PORT") .map(|port| match port.parse::() { diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 9dbb9f5233..9a9cab44bd 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "256"] + use burn::{ nn::transformer::TransformerEncoderConfig, optim::{decay::WeightDecayConfig, AdamConfig}, diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index 490ed3b97e..027eb76122 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -1,6 +1,6 @@ use text_classification::DbPediaDataset; -use burn::tensor::backend::AutodiffBackend; +use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] @@ -8,7 +8,7 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", @@ -34,24 +34,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; @@ -61,35 +55,29 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - tch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>(WgpuDevice::default()); } }