Skip to content

Commit

Permalink
[Feature] reduce fuse on read (#2870)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Mar 6, 2025
1 parent f98cc0b commit f106148
Show file tree
Hide file tree
Showing 56 changed files with 2,781 additions and 1,019 deletions.
35 changes: 19 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,13 @@ portable-atomic = { version = "1.11.0" }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e4fa42bebc3348b8912854298f3ec8e4d2d23529" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e4fa42bebc3348b8912854298f3ec8e4d2d23529" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubecl-std = { path = "../cubecl/crates/cubecl-std", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
Expand Down
26 changes: 26 additions & 0 deletions backend-comparison/benches/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use burn_common::benchmark::{run_benchmark, Benchmark};

/// Which reduction a `ReduceBenchmark` run exercises.
///
/// The `*Fused` variants run the same reduction, but prepend a short
/// element-wise chain (`+ 5`, `log`, `tanh`, `* 3`) before reducing, so the
/// backend's reduce-on-read fusion path is measured as well.
enum Instruction {
    // Index of the minimum along the given axis.
    ArgMin(usize),
    // Same as `ArgMin`, preceded by the element-wise chain above.
    ArgMinFused(usize),
    // Sum along the given axis.
    SumDim(usize),
    // Same as `SumDim`, preceded by the element-wise chain above.
    SumDimFused(usize),
    // Full sum over every element (no axis argument).
    Sum,
}

Expand Down Expand Up @@ -41,6 +43,20 @@ impl<B: Backend> Benchmark for ReduceBenchmark<B> {
Instruction::SumDim(axis) => {
self.tensor.clone().sum_dim(axis);
}
Instruction::SumDimFused(axis) => {
let tensor = self.tensor.clone() + 5;
let tensor = tensor.log();
let tensor = tensor.tanh();
let tensor = tensor * 3;
tensor.sum_dim(axis);
}
Instruction::ArgMinFused(axis) => {
let tensor = self.tensor.clone() + 5;
let tensor = tensor.log();
let tensor = tensor.tanh();
let tensor = tensor * 3;
tensor.argmin(axis);
}
Instruction::Sum => {
self.tensor.clone().sum();
}
Expand All @@ -50,7 +66,9 @@ impl<B: Backend> Benchmark for ReduceBenchmark<B> {
/// Human-readable benchmark identifier: `reduce-<op>-<axis>`, with a
/// `-fused` suffix for the fused variants and `-full` for the whole-tensor
/// sum.
fn name(&self) -> String {
    // Normalize every variant to (operation label, optional axis, fused?).
    let (op, axis, fused) = match self.instruction {
        Instruction::ArgMin(axis) => ("argmin", Some(axis), false),
        Instruction::ArgMinFused(axis) => ("argmin", Some(axis), true),
        Instruction::SumDim(axis) => ("sum", Some(axis), false),
        Instruction::SumDimFused(axis) => ("sum", Some(axis), true),
        Instruction::Sum => ("sum", None, false),
    };
    match (axis, fused) {
        (None, _) => format!("reduce-{op}-full"),
        (Some(axis), false) => format!("reduce-{op}-{axis}"),
        (Some(axis), true) => format!("reduce-{op}-{axis}-fused"),
    }
}
Expand Down Expand Up @@ -78,11 +96,19 @@ fn bench<B: Backend>(
Instruction::ArgMin(axis),
device.clone(),
));
benchmarks.push(ReduceBenchmark::<B>::new(
Instruction::ArgMinFused(axis),
device.clone(),
));

benchmarks.push(ReduceBenchmark::<B>::new(
Instruction::SumDim(axis),
device.clone(),
));
benchmarks.push(ReduceBenchmark::<B>::new(
Instruction::SumDimFused(axis),
device.clone(),
));
}

benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));
Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ macro_rules! bench_on_backend {
};
($fn_name:ident) => {
use std::env;
backend_comparison::init_log().unwrap();
// backend_comparison::init_log().unwrap();

let args: Vec<String> = env::args().collect();
let url = backend_comparison::get_sharing_url(&args);
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-core/src/nn/norm/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ mod tests {
[-0.3428, 0.7970, 1.1845],
],
]);
output.to_data().assert_approx_eq(&expected, 3);
output.to_data().assert_approx_eq(&expected, 2);
}

#[test]
Expand All @@ -261,15 +261,15 @@ mod tests {
.expect("gamma should not be None")
.val()
.to_data()
.assert_approx_eq(&TensorData::ones::<f32, _>([6]), 3);
.assert_approx_eq(&TensorData::ones::<f32, _>([6]), 2);

module
.beta
.as_ref()
.expect("beta should not be None")
.val()
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>([6]), 3);
.assert_approx_eq(&TensorData::zeros::<f32, _>([6]), 2);

let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
Expand Down Expand Up @@ -313,7 +313,7 @@ mod tests {
[-1.0903, -0.0419, -1.3623],
],
]);
output.to_data().assert_approx_eq(&expected, 3);
output.to_data().assert_approx_eq(&expected, 2);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cubecl-fusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ burn-ir = { path = "../burn-ir", version = "0.17.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
"cubecl",
] }
cubecl = { workspace = true, features = ["linalg"] }
cubecl = { workspace = true, features = ["linalg", "reduce"] }

half = { workspace = true }
serde = { workspace = true }
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-cubecl-fusion/src/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::marker::PhantomData;

use crate::reduce::optimization::{ReduceOptimization, ReduceOptimizationState};

use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};

Expand All @@ -18,6 +20,7 @@ pub enum CubeOptimization<R: Runtime> {
ElementWise(ElemwiseOptimization<R>),
/// Matrix multiplication optimization.
Matmul(MatmulOptimization<R>),
Reduce(ReduceOptimization<R>),
}

/// Fusion optimization state type for cubecl.
Expand All @@ -29,6 +32,7 @@ pub enum CubeOptimizationState {
ElementWise(ElemwiseOptimizationState),
/// Matrix multiplication optimization state.
Matmul(MatmulOptimizationState),
Reduce(ReduceOptimizationState),
}

pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
Expand Down
9 changes: 5 additions & 4 deletions crates/burn-cubecl-fusion/src/elemwise/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use burn_fusion::OptimizationBuilder;
use cubecl::Runtime;

use crate::{
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
shared::{builder::FuseBuilder, ir::ElemwisePrecision, settings::FuseSettings},
CubeOptimization,
};

use super::optimization::ElemwiseOptimization;

/// Fused element wise operations that are normally memory bound.
pub struct ElementWiseBuilder<R: Runtime> {
builder: FuseOnWriteBuilder,
builder: FuseBuilder,
device: R::Device,
}

Expand All @@ -21,13 +21,14 @@ impl<R: Runtime> ElementWiseBuilder<R> {
let max_bindings = props.hardware_properties().max_bindings;

Self {
builder: FuseOnWriteBuilder::new(
builder: FuseBuilder::new(
max_bindings,
bool_precision,
FuseSettings {
broadcast: true,
output_shape_updates: true,
inplace: true,
inplace: false,
vectorization: true,
},
),
device,
Expand Down
Loading

0 comments on commit f106148

Please sign in to comment.