
Commit 245fbcd

Authored Jan 22, 2025
Feat/fused matmul tune (#2726)
1 parent b33bd24 commit 245fbcd

14 files changed: +525 −60 lines changed
 

‎Cargo.lock

+14-14
Some generated files are not rendered by default.

‎Cargo.toml

+2-2
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

‎backend-comparison/benches/matmul_fused.rs

+7-3
@@ -1,5 +1,9 @@
 use backend_comparison::persistence::save;
-use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
+use burn::tensor::{
+    activation::{gelu, relu},
+    backend::Backend,
+    Distribution, Shape, Tensor,
+};
 use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;
 
@@ -14,7 +18,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     type Args = (Tensor<B, D>, Tensor<B, D>, Tensor<B, 1>);
 
     fn name(&self) -> String {
-        "matmul_bias_relu".into()
+        "matmul_relu_bias_gelu".into()
     }
 
     fn shapes(&self) -> Vec<Vec<usize>> {
@@ -23,7 +27,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
 
     fn execute(&self, (lhs, rhs, bias): Self::Args) {
         let bias = bias.unsqueeze();
-        relu(lhs.matmul(rhs) + bias);
+        gelu(relu(lhs.matmul(rhs)) + bias);
     }
 
     fn prepare(&self) -> Self::Args {
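
Note: for reference, a minimal standalone sketch of the computation the renamed benchmark now measures. It is not part of the commit; the rank-3 shapes and the function name are illustrative. With the fusion backend, the element-wise ops around the matmul are expected to be compiled into a single fused kernel.

use burn::tensor::{
    activation::{gelu, relu},
    backend::Backend,
    Tensor,
};

// Sketch only: the op chain measured by the benchmark above.
fn fused_chain<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 1>,
) -> Tensor<B, 3> {
    let bias = bias.unsqueeze();
    gelu(relu(lhs.matmul(rhs)) + bias)
}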

‎crates/burn-fusion/src/stream/context.rs

+78
@@ -59,6 +59,84 @@ pub(crate) struct OperationConverter {
     scalar_u8: Vec<u8>,
 }
 
+/// Fork of a [context](Context) which owns its data.
+pub struct ContextOwned<H> {
+    tensors: HashMap<TensorId, TensorDescription>,
+    handles: HandleContainer<H>,
+    scalar_f32: Vec<f32>,
+    scalar_f16: Vec<f16>,
+    scalar_bf16: Vec<bf16>,
+    scalar_i64: Vec<i64>,
+    scalar_i32: Vec<i32>,
+    scalar_i16: Vec<i16>,
+    scalar_i8: Vec<i8>,
+    scalar_u64: Vec<u64>,
+    scalar_u32: Vec<u32>,
+    scalar_u16: Vec<u16>,
+    scalar_u8: Vec<u8>,
+}
+
+impl<H: Clone> ContextOwned<H> {
+    /// Convert into [context](Context).
+    pub fn as_context(&mut self) -> Context<'_, H> {
+        Context {
+            tensors: &mut self.tensors,
+            handles: &mut self.handles,
+            scalar_f32: &self.scalar_f32,
+            scalar_f16: &self.scalar_f16,
+            scalar_bf16: &self.scalar_bf16,
+            scalar_i64: &self.scalar_i64,
+            scalar_i32: &self.scalar_i32,
+            scalar_i16: &self.scalar_i16,
+            scalar_i8: &self.scalar_i8,
+            scalar_u64: &self.scalar_u64,
+            scalar_u32: &self.scalar_u32,
+            scalar_u16: &self.scalar_u16,
+            scalar_u8: &self.scalar_u8,
+        }
+    }
+
+    /// Fork the context again.
+    pub fn fork(&self) -> ContextOwned<H> {
+        ContextOwned {
+            tensors: self.tensors.clone(),
+            handles: self.handles.fork(),
+            scalar_f32: self.scalar_f32.clone(),
+            scalar_f16: self.scalar_f16.clone(),
+            scalar_bf16: self.scalar_bf16.clone(),
+            scalar_i64: self.scalar_i64.clone(),
+            scalar_i32: self.scalar_i32.clone(),
+            scalar_i16: self.scalar_i16.clone(),
+            scalar_i8: self.scalar_i8.clone(),
+            scalar_u64: self.scalar_u64.clone(),
+            scalar_u32: self.scalar_u32.clone(),
+            scalar_u16: self.scalar_u16.clone(),
+            scalar_u8: self.scalar_u8.clone(),
+        }
+    }
+}
+
+impl<H: Clone> Context<'_, H> {
+    /// Fork the context into an [owned context](ContextOwned).
+    pub fn fork(&self) -> ContextOwned<H> {
+        ContextOwned {
+            tensors: self.tensors.clone(),
+            handles: self.handles.fork(),
+            scalar_f32: self.scalar_f32.clone(),
+            scalar_f16: self.scalar_f16.clone(),
+            scalar_bf16: self.scalar_bf16.clone(),
+            scalar_i64: self.scalar_i64.clone(),
+            scalar_i32: self.scalar_i32.clone(),
+            scalar_i16: self.scalar_i16.clone(),
+            scalar_i8: self.scalar_i8.clone(),
+            scalar_u64: self.scalar_u64.clone(),
+            scalar_u32: self.scalar_u32.clone(),
+            scalar_u16: self.scalar_u16.clone(),
+            scalar_u8: self.scalar_u8.clone(),
+        }
+    }
+}
+
 pub(crate) trait RelativeOps {
     /// Convert (usually an [`OperationDescription`]) to a relative form.
     ///
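
Note: the new ContextOwned type exists so autotune can keep a copy of the fusion context alive after the caller's mutable borrow ends. Below is a minimal sketch (not part of the commit) of the intended round trip, assuming only the fork and as_context methods added above; the helper names are illustrative.

use burn_fusion::stream::{Context, ContextOwned};

// Sketch only: fork deep-copies tensor descriptions, handles, and captured scalars.
fn fork_for_tuning<H: Clone>(context: &Context<'_, H>) -> ContextOwned<H> {
    context.fork()
}

// Sketch only: re-borrow the owned data as a regular Context to run a candidate kernel.
fn run_on_fork<H: Clone>(owned: &mut ContextOwned<H>) {
    let _context: Context<'_, H> = owned.as_context();
    // ... execute against the forked data ...
}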

‎crates/burn-jit/src/fusion/base.rs

+6-10
@@ -125,20 +125,16 @@ impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
     fn optimizations(
         device: R::Device,
     ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
-        let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
-            vec![Box::new(ElementWiseBuilder::<R>::new(
+        vec![
+            Box::new(ElementWiseBuilder::<R>::new(
                 device.clone(),
                 BT::as_elem_native_unchecked().into(),
-            ))];
-
-        if cfg!(feature = "fusion-experimental") {
-            optimizations.push(Box::new(MatmulBuilder::<R>::new(
+            )),
+            Box::new(MatmulBuilder::<R>::new(
                 device.clone(),
                 BT::as_elem_native_unchecked().into(),
-            )));
-        }
-
-        optimizations
+            )),
+        ]
     }
 }
 

‎crates/burn-jit/src/fusion/matmul/builder.rs

+7-1
@@ -47,7 +47,13 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
             let rhs = self.builder.input_unhandled(&op.rhs);
             let out = self.builder.output_unhandled(&op.out);
 
-            self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone()));
+            self.matmul = Some(FusedMatmul::new(
+                lhs,
+                rhs,
+                out,
+                op.clone(),
+                Default::default(),
+            ));
         } else {
             self.builder.close();
         }

‎crates/burn-jit/src/fusion/matmul/mod.rs

+1
@@ -2,3 +2,4 @@ pub(crate) mod args;
 pub(crate) mod builder;
 pub(crate) mod optimization;
 pub(crate) mod spec;
+pub(crate) mod tune;

‎crates/burn-jit/src/fusion/matmul/optimization.rs

+147-29
@@ -12,7 +12,9 @@ use burn_tensor::Shape;
 use cubecl::linalg::matmul::components;
 use cubecl::linalg::matmul::components::tile::accelerated::Accelerated;
 use cubecl::linalg::matmul::components::MatmulProblem;
-use cubecl::linalg::matmul::kernels::matmul::{MatmulSelector, StandardSelector};
+use cubecl::linalg::matmul::kernels::matmul::{
+    MatmulSelector, PipelinedSelector, SpecializedSelector, StandardSelector,
+};
 use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError};
 use cubecl::linalg::tensor::{matrix_layout, MatrixLayout};
 use cubecl::{client::ComputeClient, prelude::*};
@@ -26,30 +28,65 @@ use crate::fusion::on_write::{
 
 use super::args::FusedMatmulInputLaunch;
 use super::spec::FusedMatmulSpec;
+use super::tune::fused_matmul_autotune;
 
-#[derive(new)]
 /// Fuse matmul operation followed by elemwise operations into a single kernel.
 pub struct MatmulOptimization<R: JitRuntime> {
     trace: FuseOnWriteTrace,
     trace_fallback: FuseOnWriteTrace,
-    client: ComputeClient<R::Server, R::Channel>,
-    device: R::Device,
-    len: usize,
-    matmul: FusedMatmul,
+    pub(crate) client: ComputeClient<R::Server, R::Channel>,
+    pub(crate) device: R::Device,
+    pub(crate) len: usize,
+    pub(crate) matmul_standard: FusedMatmul,
+    pub(crate) matmul_pipelined: FusedMatmul,
+    pub(crate) matmul_specialized: FusedMatmul,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
 /// State for the [matrix optimization](MatmulOptimizationState).
 pub struct MatmulOptimizationState {
     trace: FuseOnWriteTrace,
     trace_fallback: FuseOnWriteTrace,
-    matmul: FusedMatmul,
+    matmul_standard: FusedMatmul,
+    matmul_pipelined: FusedMatmul,
+    matmul_specialized: FusedMatmul,
     len: usize,
 }
 
 impl<R: JitRuntime> MatmulOptimization<R> {
+    pub fn new(
+        trace: FuseOnWriteTrace,
+        trace_fallback: FuseOnWriteTrace,
+        client: ComputeClient<R::Server, R::Channel>,
+        device: R::Device,
+        len: usize,
+        matmul: FusedMatmul,
+    ) -> Self {
+        let mut matmul_standard = matmul.clone();
+        let mut matmul_specialized = matmul.clone();
+        let mut matmul_pipelined = matmul;
+
+        matmul_standard.selector = FusedMatmulSelector::Standard;
+        matmul_specialized.selector = FusedMatmulSelector::Specialized;
+        matmul_pipelined.selector = FusedMatmulSelector::Pipelined;
+
+        Self {
+            trace,
+            trace_fallback,
+            client,
+            device,
+            len,
+            matmul_standard,
+            matmul_pipelined,
+            matmul_specialized,
+        }
+    }
     /// Execute the optimization.
     pub fn execute<BT: BoolElement>(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
+        #[cfg(feature = "autotune")]
+        fused_matmul_autotune::<R, BT>(self, context);
+
+        #[cfg(not(feature = "autotune"))]
         if self.execute_fused::<BT>(context).is_err() {
            self.execute_fallback::<BT>(context);
         }
@@ -68,7 +105,9 @@ impl<R: JitRuntime> MatmulOptimization<R> {
             len: state.len,
             client: R::client(device),
             device: device.clone(),
-            matmul: state.matmul.clone(),
+            matmul_standard: state.matmul_standard.clone(),
+            matmul_specialized: state.matmul_specialized.clone(),
+            matmul_pipelined: state.matmul_pipelined.clone(),
         }
     }
 
@@ -77,21 +116,51 @@
         MatmulOptimizationState {
             trace: self.trace.clone(),
             trace_fallback: self.trace_fallback.clone(),
-            matmul: self.matmul.clone(),
+            matmul_standard: self.matmul_standard.clone(),
+            matmul_specialized: self.matmul_specialized.clone(),
+            matmul_pipelined: self.matmul_pipelined.clone(),
             len: self.len,
         }
     }
 
-    fn execute_fused<BT: BoolElement>(
-        &mut self,
+    pub fn execute_standard_fused<BT: BoolElement>(
+        &self,
         context: &mut Context<'_, JitFusionHandle<R>>,
     ) -> Result<(), FusedMatmulError> {
-        self.trace
-            .run::<R, BT, FusedMatmul>(&self.client, &self.device, context, &self.matmul)
+        self.trace.run::<R, BT, FusedMatmul>(
+            &self.client,
+            &self.device,
+            context,
+            &self.matmul_standard,
+        )
     }
 
-    fn execute_fallback<BT: BoolElement>(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
-        match self.matmul.lhs.precision() {
+    pub fn execute_specialized_fused<BT: BoolElement>(
+        &self,
+        context: &mut Context<'_, JitFusionHandle<R>>,
+    ) -> Result<(), FusedMatmulError> {
+        self.trace.run::<R, BT, FusedMatmul>(
+            &self.client,
+            &self.device,
+            context,
+            &self.matmul_specialized,
+        )
+    }
+
+    pub fn execute_pipelined_fused<BT: BoolElement>(
+        &self,
+        context: &mut Context<'_, JitFusionHandle<R>>,
+    ) -> Result<(), FusedMatmulError> {
+        self.trace.run::<R, BT, FusedMatmul>(
+            &self.client,
+            &self.device,
+            context,
+            &self.matmul_pipelined,
+        )
+    }
+
+    pub fn execute_fallback<BT: BoolElement>(&self, context: &mut Context<'_, JitFusionHandle<R>>) {
+        match self.matmul_standard.lhs.precision() {
             ElemwisePrecision::F32 => self.run_fallback::<BT, f32>(context),
             ElemwisePrecision::F16 => self.run_fallback::<BT, f16>(context),
             ElemwisePrecision::BF16 => self.run_fallback::<BT, bf16>(context),
@@ -100,13 +169,25 @@ impl<R: JitRuntime> MatmulOptimization<R> {
     }
 
     fn run_fallback<BT: BoolElement, EG: FloatElement>(
-        &mut self,
+        &self,
         context: &mut Context<'_, JitFusionHandle<R>>,
     ) {
        let (out_tensor, out_desc) = {
-            let lhs = context.tensors.get(&self.matmul.op.lhs.id).unwrap().clone();
-            let rhs = context.tensors.get(&self.matmul.op.rhs.id).unwrap().clone();
-            let out = context.tensors.get(&self.matmul.op.out.id).unwrap().clone();
+            let lhs = context
+                .tensors
+                .get(&self.matmul_standard.op.lhs.id)
+                .unwrap()
+                .clone();
+            let rhs = context
+                .tensors
+                .get(&self.matmul_standard.op.rhs.id)
+                .unwrap()
+                .clone();
+            let out = context
+                .tensors
+                .get(&self.matmul_standard.op.out.id)
+                .unwrap()
+                .clone();
 
             let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly);
             let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly);
@@ -136,12 +217,21 @@ impl<R: JitRuntime> MatmulOptimization<R> {
     }
 }
 
+#[derive(Default, Clone, Serialize, Deserialize, Debug)]
+pub enum FusedMatmulSelector {
+    #[default]
+    Standard,
+    Pipelined,
+    Specialized,
+}
+
 #[derive(new, Clone, Serialize, Deserialize, Debug)]
 pub struct FusedMatmul {
     lhs: Arg,
     rhs: Arg,
     out: Arg,
-    op: BinaryOperationDescription,
+    pub(crate) op: BinaryOperationDescription,
+    pub(crate) selector: FusedMatmulSelector,
 }
 
 #[derive(Debug)]
@@ -261,15 +351,43 @@ impl FusedMatmul {
             }
         };
 
-        match matmul_launch_kernel::<R, EG, StandardSelector<Accelerated>>(
-            client,
-            FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
-            outputs,
-            problem,
-            plane_size,
-        ) {
-            Ok(_) => Ok(()),
-            Err(err) => Err(FusedMatmulError::LaunchError(err)),
+        match self.selector {
+            FusedMatmulSelector::Standard => {
+                match matmul_launch_kernel::<R, EG, StandardSelector<Accelerated>>(
+                    client,
+                    FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
+                    outputs,
+                    problem,
+                    plane_size,
+                ) {
+                    Ok(_) => Ok(()),
+                    Err(err) => Err(FusedMatmulError::LaunchError(err)),
+                }
+            }
+            FusedMatmulSelector::Pipelined => {
+                match matmul_launch_kernel::<R, EG, PipelinedSelector<Accelerated>>(
+                    client,
+                    FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
+                    outputs,
+                    problem,
+                    plane_size,
+                ) {
+                    Ok(_) => Ok(()),
+                    Err(err) => Err(FusedMatmulError::LaunchError(err)),
+                }
+            }
+            FusedMatmulSelector::Specialized => {
+                match matmul_launch_kernel::<R, EG, SpecializedSelector<Accelerated>>(
+                    client,
+                    FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out),
+                    outputs,
+                    problem,
+                    plane_size,
+                ) {
+                    Ok(_) => Ok(()),
+                    Err(err) => Err(FusedMatmulError::LaunchError(err)),
+                }
+            }
         }
     }
 }

‎crates/burn-jit/src/fusion/matmul/tune.rs

+133
@@ -0,0 +1,133 @@
+use crate::{
+    fusion::{
+        tune::{TuneContext, TuneInput},
+        JitFusionHandle,
+    },
+    kernel::matmul::MatmulAutotuneKey,
+    BoolElement, JitRuntime, JitTuneId,
+};
+use burn_fusion::stream::Context;
+use cubecl::{
+    tune::{local_tuner, LocalTuner, TunableSet},
+    AutotuneKey,
+};
+use serde::{Deserialize, Serialize};
+
+use super::optimization::MatmulOptimization;
+
+#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
+pub struct FusedMatmulAutotuneKey {
+    matmul_key: MatmulAutotuneKey,
+    #[autotune(anchor)]
+    num_ops_fused: usize,
+}
+
+/// Executes autotune on matmul operations
+pub fn fused_matmul_autotune<R: JitRuntime, BT: BoolElement>(
+    optimization: &MatmulOptimization<R>,
+    context: &mut Context<JitFusionHandle<R>>,
+) {
+    static TUNER: LocalTuner<FusedMatmulAutotuneKey, JitTuneId> = local_tuner!();
+
+    let tunables = TunableSet::new(create_key::<R>, input_gen::<R>)
+        .with_tunable(tune_standard_fused::<R, BT>)
+        .with_tunable(tune_specialized_fused::<R, BT>)
+        .with_tunable(tune_pipelined_fused::<R, BT>)
+        .with_tunable(tune_fallback::<R, BT>);
+
+    TUNER.execute(
+        &JitTuneId::new::<R>(&optimization.device),
+        &optimization.client,
+        &tunables,
+        TuneInput::new(context, optimization),
+    );
+}
+
+pub(crate) fn create_key<R: JitRuntime>(
+    input: &TuneInput<R, MatmulOptimization<R>>,
+) -> FusedMatmulAutotuneKey {
+    let opt = input.optimization();
+    let context = match input.context() {
+        TuneContext::Original(context) => context,
+        TuneContext::Fork(_) => panic!("Not supported when generating key"),
+    };
+
+    let lhs = context.tensors.get(&opt.matmul_standard.op.lhs.id).unwrap();
+    let rhs = context.tensors.get(&opt.matmul_standard.op.rhs.id).unwrap();
+    let out = context.tensors.get(&opt.matmul_standard.op.out.id).unwrap();
+
+    let key = MatmulAutotuneKey::from_shape(
+        &lhs.shape.clone().into(),
+        &rhs.shape.clone().into(),
+        out.dtype,
+    );
+    FusedMatmulAutotuneKey::new(key, opt.len)
+}
+
+fn input_gen<R: JitRuntime>(
+    _key: &FusedMatmulAutotuneKey,
+    input: &TuneInput<R, MatmulOptimization<R>>,
+) -> TuneInput<R, MatmulOptimization<R>> {
+    input.clone()
+}
+
+fn tune_standard_fused<R: JitRuntime, BT: BoolElement>(
+    input: TuneInput<R, MatmulOptimization<R>>,
+) -> Result<(), String> {
+    let optimization = input.optimization();
+    let context = input.context();
+
+    match context {
+        TuneContext::Original(context) => optimization.execute_standard_fused::<BT>(context),
+        TuneContext::Fork(mut context_owned) => {
+            optimization.execute_standard_fused::<BT>(&mut context_owned.as_context())
+        }
+    }
+    .map_err(|e| format!("{e:?}"))
+}
+
+fn tune_specialized_fused<R: JitRuntime, BT: BoolElement>(
+    input: TuneInput<R, MatmulOptimization<R>>,
+) -> Result<(), String> {
+    let optimization = input.optimization();
+    let context = input.context();
+
+    match context {
+        TuneContext::Original(context) => optimization.execute_specialized_fused::<BT>(context),
+        TuneContext::Fork(mut context_owned) => {
+            optimization.execute_specialized_fused::<BT>(&mut context_owned.as_context())
+        }
+    }
+    .map_err(|e| format!("{e:?}"))
+}
+
+fn tune_pipelined_fused<R: JitRuntime, BT: BoolElement>(
+    input: TuneInput<R, MatmulOptimization<R>>,
+) -> Result<(), String> {
+    let optimization = input.optimization();
+    let context = input.context();
+
+    match context {
+        TuneContext::Original(context) => optimization.execute_pipelined_fused::<BT>(context),
+        TuneContext::Fork(mut context_owned) => {
+            optimization.execute_pipelined_fused::<BT>(&mut context_owned.as_context())
+        }
+    }
+    .map_err(|e| format!("{e:?}"))
+}
+
+fn tune_fallback<R: JitRuntime, BT: BoolElement>(
+    input: TuneInput<R, MatmulOptimization<R>>,
+) -> Result<(), String> {
+    let optimization = input.optimization();
+    let context = input.context();
+
+    match context {
+        TuneContext::Original(context) => optimization.execute_fallback::<BT>(context),
+        TuneContext::Fork(mut context_owned) => {
+            optimization.execute_fallback::<BT>(&mut context_owned.as_context())
+        }
+    };
+
+    Ok(())
+}
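
Note: the key pairs the existing MatmulAutotuneKey with the number of fused element-wise operations, and #[autotune(anchor)] keeps the cache from fragmenting over that count. The sketch below (not part of the commit; the helper name is illustrative) shows the anchoring behaviour this is expected to rely on, i.e. rounding up to the next power of two as burn's anchored autotune key fields typically do.

// Sketch only: illustrative helper, not the actual derive-generated code.
fn anchored(num_ops_fused: usize) -> usize {
    num_ops_fused.next_power_of_two()
}

fn main() {
    // Fusion graphs with 5 to 8 element-wise ops would share one autotune entry.
    assert_eq!(anchored(5), 8);
    assert_eq!(anchored(8), 8);
}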

‎crates/burn-jit/src/fusion/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ mod base;
33
pub(crate) mod elemwise;
44
pub(crate) mod matmul;
55
pub(crate) mod on_write;
6+
pub(crate) mod tune;
67

78
pub use base::*;

‎crates/burn-jit/src/fusion/tune.rs

+108
@@ -0,0 +1,108 @@
+use super::JitFusionHandle;
+use crate::JitRuntime;
+use burn_fusion::stream::{Context, ContextOwned};
+
+/// Fusion context used when tuning kernels.
+///
+/// Either the original context is returned or a fork of the original.
+/// The fork is only given when performing autotuning, and not when actually performing the
+/// operation.
+pub enum TuneContext<'a, R: JitRuntime> {
+    Original(&'a mut Context<'a, JitFusionHandle<R>>),
+    Fork(Box<ContextOwned<JitFusionHandle<R>>>),
+}
+
+/// Fusion input wrapper containing the context and the optimization.
+///
+/// # Safety
+///
+/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions
+/// are made based on its behavior.
+pub struct TuneInput<R: JitRuntime, O> {
+    context: UnsafeTuneContext<R>,
+    optimization: *const O,
+}
+
+/// Unsafe wrapper around the context.
+///
+/// # Safety
+///
+/// The wrapper removes the context lifetime.
+///
+/// For it to be correct, the context must not be used after the invocation of the
+/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are
+/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find
+/// the best kernel to use, which can be async.
+enum UnsafeTuneContext<R: JitRuntime> {
+    Original(*mut Context<'static, JitFusionHandle<R>>),
+    Fork(Box<ContextOwned<JitFusionHandle<R>>>),
+}
+
+unsafe impl<R: JitRuntime> Send for UnsafeTuneContext<R> {}
+unsafe impl<R: JitRuntime, O> Send for TuneInput<R, O> {}
+
+impl<R: JitRuntime, O> TuneInput<R, O> {
+    /// Create a new autotune input from the [context](Context) and an optimization.
+    pub fn new(context: &mut Context<JitFusionHandle<R>>, optimization: &O) -> Self {
+        let context = UnsafeTuneContext::new(context);
+        // We can erase the lifetime for the same reason we do with the context.
+        let optimization = core::ptr::from_ref(optimization);
+
+        Self {
+            context,
+            optimization,
+        }
+    }
+
+    /// Retrieve the [autotune context](TuneContext) for the current input.
+    pub fn context(&self) -> TuneContext<'static, R> {
+        self.context.get()
+    }
+
+    /// Retrieve the optimization for the current input.
+    pub fn optimization(&self) -> &O {
+        unsafe { self.optimization.as_ref().unwrap() }
+    }
+}
+
+impl<R: JitRuntime> UnsafeTuneContext<R> {
+    fn new(context: &mut Context<'_, JitFusionHandle<R>>) -> Self {
+        let ptr = core::ptr::from_mut(context);
+
+        // It is necessary for the lifetime.
+        #[allow(clippy::unnecessary_cast)]
+        Self::Original(ptr as *mut Context<'static, _>)
+    }
+
+    fn get(&self) -> TuneContext<'static, R> {
+        match self {
+            UnsafeTuneContext::Original(ptr) => {
+                TuneContext::Original(unsafe { ptr.as_mut().unwrap() })
+            }
+            UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())),
+        }
+    }
+}
+
+impl<R: JitRuntime, O> Clone for TuneInput<R, O> {
+    fn clone(&self) -> Self {
+        Self {
+            context: self.context.clone(),
+            optimization: self.optimization,
+        }
+    }
+}
+
+impl<R: JitRuntime> Clone for UnsafeTuneContext<R> {
+    fn clone(&self) -> Self {
+        let context = match self {
+            UnsafeTuneContext::Original(ptr) => {
+                let context: &mut Context<'static, JitFusionHandle<R>> =
+                    unsafe { ptr.as_mut().unwrap() };
+                context.fork()
+            }
+            UnsafeTuneContext::Fork(context) => context.fork(),
+        };
+        UnsafeTuneContext::Fork(Box::new(context))
+    }
+}
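
Note: the Clone implementations carry the safety argument. Only the original, non-cloned input ever dereferences the raw context pointer, and it does so synchronously inside LocalTuner::execute; every clone made for asynchronous benchmarking immediately forks the context into owned data. A minimal in-crate sketch of that property (hypothetical helper, assumed to live inside burn-jit's fusion module where these private types are visible):

// Sketch only: a cloned TuneInput always hands back an owned fork, never the
// caller's borrowed context.
fn clone_always_forks<R: JitRuntime, O>(input: &TuneInput<R, O>) -> bool {
    matches!(input.clone().context(), TuneContext::Fork(_))
}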

‎crates/burn-jit/src/kernel/matmul/tune/key.rs

+1-1
@@ -22,7 +22,7 @@ pub struct MatmulAutotuneKey {
 }
 
 impl MatmulAutotuneKey {
-    fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
+    pub(crate) fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
        let ndims = lhs_shape.num_dims();
        let m = lhs_shape.dims[ndims - 2];
        let k = lhs_shape.dims[ndims - 1];

‎crates/burn-tensor/src/repr/handle.rs

+18
@@ -26,6 +26,23 @@ pub struct HandleContainer<H> {
     pub handles_orphan: Vec<TensorId>,
 }
 
+impl<H: Clone> HandleContainer<H> {
+    /// Fork the container, useful for autotune.
+    pub fn fork(&self) -> Self {
+        let mut handles = HashMap::with_capacity(self.handles.len());
+
+        for (id, handle) in self.handles.iter() {
+            handles.insert(*id, handle.clone());
+        }
+
+        Self {
+            handles,
+            counter: self.counter,
+            handles_orphan: self.handles_orphan.clone(),
+        }
+    }
+}
+
 impl<H> core::fmt::Debug for HandleContainer<H> {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         f.debug_struct("HandleContainer")
@@ -37,6 +54,7 @@ impl<H> core::fmt::Debug for HandleContainer<H> {
 }
 
 /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state
+#[derive(Clone)]
 pub enum Handle<H> {
     /// No [tensor handle](ReprBackend::Handle) has been created yet
     NotInit,
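
Note: since Handle<H> now derives Clone, the forking loop above is behaviourally equivalent to cloning the map directly. A minimal sketch of that alternative body, shown only as a design note and assuming the same impl block and field types as HandleContainer:

// Sketch only: drop-in alternative body for HandleContainer::fork.
pub fn fork(&self) -> Self {
    Self {
        handles: self.handles.clone(),
        counter: self.counter,
        handles_orphan: self.handles_orphan.clone(),
    }
}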

‎examples/text-classification/examples/ag-news-train.rs

+2
@@ -1,3 +1,5 @@
+#![recursion_limit = "256"]
+
 use burn::{
     nn::transformer::TransformerEncoderConfig,
     optim::{decay::WeightDecayConfig, AdamConfig},
