Commit

Merge branch 'main' into chore/feature-flags

laggui committed Jan 23, 2025
2 parents 5ec1949 + e40c69b commit d839984
Showing 29 changed files with 656 additions and 124 deletions.
157 changes: 114 additions & 43 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Cargo.toml
@@ -101,7 +101,7 @@ ratatui = "0.29.0"

 # WGPU stuff
 text_placeholder = "0.5.1"
-wgpu = "23.0.0"
+wgpu = "24.0.0"

 # Benchmarks and Burnbench
 arboard = "3.4.1"
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
10 changes: 7 additions & 3 deletions backend-comparison/benches/matmul_fused.rs
@@ -1,5 +1,9 @@
 use backend_comparison::persistence::save;
-use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
+use burn::tensor::{
+    activation::{gelu, relu},
+    backend::Backend,
+    Distribution, Shape, Tensor,
+};
 use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;

@@ -14,7 +18,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     type Args = (Tensor<B, D>, Tensor<B, D>, Tensor<B, 1>);

     fn name(&self) -> String {
-        "matmul_bias_relu".into()
+        "matmul_relu_bias_gelu".into()
     }

     fn shapes(&self) -> Vec<Vec<usize>> {
@@ -23,7 +27,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {

     fn execute(&self, (lhs, rhs, bias): Self::Args) {
         let bias = bias.unsqueeze();
-        relu(lhs.matmul(rhs) + bias);
+        gelu(relu(lhs.matmul(rhs)) + bias);
     }

     fn prepare(&self) -> Self::Args {
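The renamed benchmark matches the op chain it now executes: matmul, then relu, then the bias add, then gelu. As a standalone sketch of that chain outside the benchmark harness (the NdArray backend and the shapes are illustrative assumptions, not part of this diff):

```rust
use burn::backend::NdArray;
use burn::tensor::{
    activation::{gelu, relu},
    Distribution, Tensor,
};

fn main() {
    // Assumes the `ndarray` backend feature is enabled.
    type B = NdArray<f32>;
    let device = Default::default();

    // Illustrative shapes: (64, 256) x (256, 32), with a bias of length 32.
    let lhs = Tensor::<B, 2>::random([64, 256], Distribution::Default, &device);
    let rhs = Tensor::<B, 2>::random([256, 32], Distribution::Default, &device);
    let bias = Tensor::<B, 1>::random([32], Distribution::Default, &device);

    // The chain measured by the updated benchmark: matmul -> relu -> +bias -> gelu.
    let out = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
    println!("output dims: {:?}", out.dims());
}
```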
4 changes: 2 additions & 2 deletions burn-book/src/advanced/no-std.md
@@ -68,15 +68,15 @@ We are using ndarray, so we just need to define the NdArray backend as usual
 use burn::{backend::NdArray, tensor::Tensor};

 type Backend = NdArray<f32>;
-type BackendDeice = <Backend as burn::tensor::backend::Backend>::Device;
+type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;
 ```

 Then inside the `main` function add
 ```rs
 use your_model::Model;

 // Get a default device for the backend
-let device = BackendDeice::default();
+let device = BackendDevice::default();

 // Create a new model and load the state
 let model: Model<Backend> = Model::default();
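With the typo fixed, the book's two snippets combine into roughly this minimal sketch (the `Model` type normally comes from the reader's `your_model` crate, so it is left as a comment to keep the sketch self-contained):

```rust
use burn::backend::NdArray;

type Backend = NdArray<f32>;
type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;

fn main() {
    // Get a default device for the backend.
    let device = BackendDevice::default();

    // With a burn-import generated model in scope, loading follows the book:
    // let model: Model<Backend> = Model::default();
    let _ = device;
}
```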
2 changes: 1 addition & 1 deletion crates/burn-autodiff/Cargo.toml
@@ -18,7 +18,7 @@ std = []
 async = [] # Require std

 [dependencies]
-burn-common = { path = "../burn-common", version = "0.17.0" }
+burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
 burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
 burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }

1 change: 1 addition & 0 deletions crates/burn-core/src/lib.rs
@@ -1,6 +1,7 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 #![warn(missing_docs)]
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
+#![recursion_limit = "135"]

 //! The core crate of Burn.
78 changes: 78 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
@@ -59,6 +59,84 @@ pub(crate) struct OperationConverter {
     scalar_u8: Vec<u8>,
 }

+/// Fork of a [context](Context) which owns its data.
+pub struct ContextOwned<H> {
+    tensors: HashMap<TensorId, TensorDescription>,
+    handles: HandleContainer<H>,
+    scalar_f32: Vec<f32>,
+    scalar_f16: Vec<f16>,
+    scalar_bf16: Vec<bf16>,
+    scalar_i64: Vec<i64>,
+    scalar_i32: Vec<i32>,
+    scalar_i16: Vec<i16>,
+    scalar_i8: Vec<i8>,
+    scalar_u64: Vec<u64>,
+    scalar_u32: Vec<u32>,
+    scalar_u16: Vec<u16>,
+    scalar_u8: Vec<u8>,
+}
+
+impl<H: Clone> ContextOwned<H> {
+    /// Convert into [context](Context).
+    pub fn as_context(&mut self) -> Context<'_, H> {
+        Context {
+            tensors: &mut self.tensors,
+            handles: &mut self.handles,
+            scalar_f32: &self.scalar_f32,
+            scalar_f16: &self.scalar_f16,
+            scalar_bf16: &self.scalar_bf16,
+            scalar_i64: &self.scalar_i64,
+            scalar_i32: &self.scalar_i32,
+            scalar_i16: &self.scalar_i16,
+            scalar_i8: &self.scalar_i8,
+            scalar_u64: &self.scalar_u64,
+            scalar_u32: &self.scalar_u32,
+            scalar_u16: &self.scalar_u16,
+            scalar_u8: &self.scalar_u8,
+        }
+    }
+
+    /// Fork the context again.
+    pub fn fork(&self) -> ContextOwned<H> {
+        ContextOwned {
+            tensors: self.tensors.clone(),
+            handles: self.handles.fork(),
+            scalar_f32: self.scalar_f32.clone(),
+            scalar_f16: self.scalar_f16.clone(),
+            scalar_bf16: self.scalar_bf16.clone(),
+            scalar_i64: self.scalar_i64.clone(),
+            scalar_i32: self.scalar_i32.clone(),
+            scalar_i16: self.scalar_i16.clone(),
+            scalar_i8: self.scalar_i8.clone(),
+            scalar_u64: self.scalar_u64.clone(),
+            scalar_u32: self.scalar_u32.clone(),
+            scalar_u16: self.scalar_u16.clone(),
+            scalar_u8: self.scalar_u8.clone(),
+        }
+    }
+}
+
+impl<H: Clone> Context<'_, H> {
+    /// Fork the context into an [owned context](ContextOwned).
+    pub fn fork(&self) -> ContextOwned<H> {
+        ContextOwned {
+            tensors: self.tensors.clone(),
+            handles: self.handles.fork(),
+            scalar_f32: self.scalar_f32.clone(),
+            scalar_f16: self.scalar_f16.clone(),
+            scalar_bf16: self.scalar_bf16.clone(),
+            scalar_i64: self.scalar_i64.clone(),
+            scalar_i32: self.scalar_i32.clone(),
+            scalar_i16: self.scalar_i16.clone(),
+            scalar_i8: self.scalar_i8.clone(),
+            scalar_u64: self.scalar_u64.clone(),
+            scalar_u32: self.scalar_u32.clone(),
+            scalar_u16: self.scalar_u16.clone(),
+            scalar_u8: self.scalar_u8.clone(),
+        }
+    }
+}
+
 pub(crate) trait RelativeOps {
     /// Convert (usually an [`OperationDescription`]) to a relative form.
     ///
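`ContextOwned` lets fusion code detach from a borrowed `Context`: `Context::fork` deep-clones the tensor map and scalar buffers and forks the handle container, and `as_context` later re-borrows the owned copy as an ordinary `Context`. A hypothetical usage sketch built only on those two methods (the helper functions and exact import path are assumptions for illustration, not part of the diff):

```rust
use burn_fusion::stream::{Context, ContextOwned};

// Hypothetical helper: detach an owned fork from a borrowed context so it
// can be stored or moved, e.g. kept around while evaluating autotune candidates.
fn detach<H: Clone>(ctx: &Context<'_, H>) -> ContextOwned<H> {
    // Deep-clones tensors/scalars and forks the handle container.
    ctx.fork()
}

// Hypothetical helper: re-borrow the owned fork as a plain `Context` and run
// operations against it later.
fn replay<H: Clone>(owned: &mut ContextOwned<H>) {
    let ctx = owned.as_context();
    let _ = ctx; // execute captured operations here
}
```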
16 changes: 6 additions & 10 deletions crates/burn-jit/src/fusion/base.rs
@@ -125,20 +125,16 @@ impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
     fn optimizations(
         device: R::Device,
     ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
-        let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
-            vec![Box::new(ElementWiseBuilder::<R>::new(
+        vec![
+            Box::new(ElementWiseBuilder::<R>::new(
                 device.clone(),
                 BT::as_elem_native_unchecked().into(),
-            ))];
-
-        if cfg!(feature = "fusion-experimental") {
-            optimizations.push(Box::new(MatmulBuilder::<R>::new(
+            )),
+            Box::new(MatmulBuilder::<R>::new(
                 device.clone(),
                 BT::as_elem_native_unchecked().into(),
-            )));
-        }
-
-        optimizations
+            )),
+        ]
     }
 }

8 changes: 7 additions & 1 deletion crates/burn-jit/src/fusion/matmul/builder.rs
@@ -47,7 +47,13 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
             let rhs = self.builder.input_unhandled(&op.rhs);
             let out = self.builder.output_unhandled(&op.out);

-            self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone()));
+            self.matmul = Some(FusedMatmul::new(
+                lhs,
+                rhs,
+                out,
+                op.clone(),
+                Default::default(),
+            ));
         } else {
             self.builder.close();
         }
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/matmul/mod.rs
@@ -2,3 +2,4 @@ pub(crate) mod args;
 pub(crate) mod builder;
 pub(crate) mod optimization;
 pub(crate) mod spec;
+pub(crate) mod tune;
(Diffs for the remaining changed files are not shown.)