Skip to content

Commit

Permalink
Add new Vulkan and WebGpu types
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jan 23, 2025
1 parent a92e347 commit 5ec1949
Show file tree
Hide file tree
Showing 16 changed files with 75 additions and 44 deletions.
3 changes: 2 additions & 1 deletion crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
vulkan = ["wgpu", "burn-wgpu/spirv"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]
Expand Down
6 changes: 6 additions & 0 deletions crates/burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

#[cfg(feature = "webgpu")]
pub use burn_wgpu::WebGpu;

#[cfg(feature = "vulkan")]
pub use burn_wgpu::Vulkan;

#[cfg(feature = "cuda")]
pub use burn_cuda as cuda;

Expand Down
9 changes: 8 additions & 1 deletion crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
exclusive-memory-only = ["cubecl/exclusive-memory-only"]
fusion = ["burn-fusion", "burn-jit/fusion"]
spirv = ["cubecl/wgpu-spirv"]
std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]

# Backends
webgpu = ["cubecl-wgsl"]
vulkan = ["cubecl-spirv"]

# Compilers
cubecl-wgsl = []
cubecl-spirv = ["cubecl/wgpu-spirv"]

[dependencies]
cubecl = { workspace = true, features = ["wgpu"] }

Expand Down
57 changes: 35 additions & 22 deletions crates/burn-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
pub use burn_jit::{BoolElement, FloatElement, IntElement};
pub use cubecl::flex32;
pub use cubecl::ir::CubeDim;
pub use cubecl::wgpu::*;

pub type Wgsl = cubecl::wgpu::WgslCompiler;
#[cfg(feature = "spirv")]
pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler;
pub use cubecl::wgpu::{
init_device, init_setup, init_setup_async, MemoryConfiguration, RuntimeOptions, WgpuDevice,
WgpuResource, WgpuRuntime, WgpuSetup, WgpuStorage,
};
// Vulkan and WebGpu would have conflicting type names
pub mod graphics {
pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
}

#[cfg(feature = "spirv")]
type Compiler = SpirV;
#[cfg(feature = "spirv")]
type Bool = u8;
#[cfg(not(feature = "spirv"))]
type Compiler = Wgsl;
#[cfg(not(feature = "spirv"))]
type Bool = u32;
#[cfg(feature = "cubecl-spirv")]
pub use cubecl::wgpu::spirv::SpirvCompiler;
#[cfg(feature = "cubecl-wgsl")]
pub use cubecl::wgpu::WgslCompiler;

#[cfg(feature = "fusion")]
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
Expand All @@ -44,14 +44,14 @@ type Bool = u32;
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
/// It's also possible to use an existing wgpu device, by using `init_device`.
///
/// # Notes
///
Expand All @@ -60,7 +60,7 @@ type Bool = u32;
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
pub type Wgpu<F = f32, I = i32, B = u32, C = cubecl::wgpu::WgslCompiler> =
burn_fusion::Fusion<JitBackend<cubecl::wgpu::WgpuRuntime<C>, F, I, B>>;

#[cfg(not(feature = "fusion"))]
Expand All @@ -79,14 +79,14 @@ pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
/// It's also possible to use an existing wgpu device, by using `init_device`.
///
/// # Notes
///
Expand All @@ -95,20 +95,33 @@ pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
pub type Wgpu<F = f32, I = i32, B = u32, C = cubecl::wgpu::WgslCompiler> =
JitBackend<cubecl::wgpu::WgpuRuntime<C>, F, I, B>;

#[cfg(feature = "vulkan")]
/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.
pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B, cubecl::wgpu::spirv::VkSpirvCompiler>;

#[cfg(feature = "webgpu")]
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B, WgslCompiler>;

#[cfg(test)]
mod tests {
use burn_jit::JitBackend;
#[cfg(feature = "spirv")]
#[cfg(feature = "vulkan")]
pub use half::f16;
pub type TestRuntime = cubecl::wgpu::WgpuRuntime<super::Compiler>;

#[cfg(feature = "cubecl-spirv")]
type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler;
#[cfg(not(feature = "cubecl-spirv"))]
type Compiler = cubecl::wgpu::WgslCompiler;
pub type TestRuntime = cubecl::wgpu::WgpuRuntime<Compiler>;

// Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
// breaks a lot of tests from precision issues
#[cfg(feature = "spirv")]
#[cfg(feature = "vulkan")]
burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
#[cfg(not(feature = "spirv"))]
#[cfg(not(feature = "vulkan"))]
burn_jit::testgen_all!([f32], [i32], [u32]);
}
1 change: 1 addition & 0 deletions crates/burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ server = ["burn-core/server"]
tch = ["burn-core/tch"]
wgpu = ["burn-core/wgpu"]
vulkan = ["burn-core/vulkan"]
webgpu = ["burn-core/webgpu"]

# Network utils
network = ["burn-core/network"]
Expand Down
1 change: 1 addition & 0 deletions crates/burn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
//! - `vision`: Enables vision datasets (MnistDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler
//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
//! - `cuda`: Makes available the CUDA backend
//! - `hip`: Makes available the HIP backend
Expand Down
4 changes: 2 additions & 2 deletions examples/custom-renderer/examples/custom-renderer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu};

fn main() {
custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());
custom_renderer::run::<Autodiff<WebGpu>>(WgpuDevice::default());
}
2 changes: 1 addition & 1 deletion examples/custom-training-loop/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version.workspace = true

[dependencies]
burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]}
burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]}
guide = {path = "../guide"}

# Serialization
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{Autodiff, WebGpu};

fn main() {
custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
custom_training_loop::run::<Autodiff<WebGpu>>(Default::default());
}
2 changes: 1 addition & 1 deletion examples/guide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ version.workspace = true
default = ["burn/default"]

[dependencies]
burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]}
burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]}

# Serialization
log = {workspace = true}
Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/bin/infer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use burn::{backend::Wgpu, data::dataset::Dataset};
use burn::{backend::WebGpu, data::dataset::Dataset};
use guide::inference;

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyBackend = WebGpu<f32, i32>;

let device = burn::backend::wgpu::WgpuDevice::default();

Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/bin/print.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use burn::backend::Wgpu;
use burn::backend::WebGpu;
use guide::model::ModelConfig;

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyBackend = WebGpu<f32, i32>;

let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/bin/train.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::{
backend::{Autodiff, Wgpu},
backend::{Autodiff, WebGpu},
data::dataset::Dataset,
optim::AdamConfig,
};
Expand All @@ -10,7 +10,7 @@ use guide::{
};

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyBackend = WebGpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

// Create a default Wgpu device
Expand Down
6 changes: 3 additions & 3 deletions examples/image-classification-web/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use burn::{
tensor::activation::softmax,
};

use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::wgpu::{graphics::AutoGraphicsApi, WebGpu, WgpuDevice};
use burn_candle::Candle;

use serde::Serialize;
Expand All @@ -37,8 +37,8 @@ pub enum ModelType {
/// The model is loaded to the NdArray backend
WithNdArrayBackend(Model<NdArray<f32>>),

/// The model is loaded to the Wgpu backend
WithWgpuBackend(Model<Wgpu<f32, i32>>),
/// The model is loaded to the WebGpu backend
WithWgpuBackend(Model<WebGpu<f32, i32>>),
}

/// The image is 224x224 pixels with 3 channels (RGB)
Expand Down
6 changes: 3 additions & 3 deletions examples/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ publish = false
version.workspace = true

[features]
default = ["wgpu"]
default = ["webgpu"]
cuda = ["burn/cuda"]
wgpu = ["burn/wgpu"]
vulkan = ["wgpu", "burn/vulkan"]
webgpu = ["burn/webgpu"]
vulkan = ["burn/vulkan"]
ndarray = ["burn/ndarray"]

[dependencies]
Expand Down
6 changes: 4 additions & 2 deletions examples/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ pub fn start() {
burn::server::start::<burn::backend::NdArray>(Default::default(), port);
} else if #[cfg(feature = "cuda")]{
burn::server::start::<burn::backend::Cuda>(Default::default(), port);
} else if #[cfg(feature = "wgpu")] {
burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
} else if #[cfg(feature = "webgpu")] {
burn::server::start::<burn::backend::WebGpu>(Default::default(), port);
} else if #[cfg(feature = "vulkan")] {
burn::server::start::<burn::backend::Vulkan>(Default::default(), port);
} else {
panic!("No backend selected, can't start server on port {port}");
}
Expand Down

0 comments on commit 5ec1949

Please sign in to comment.