@@ -12,7 +12,9 @@ use burn_tensor::Shape;
12
12
use cubecl:: linalg:: matmul:: components;
13
13
use cubecl:: linalg:: matmul:: components:: tile:: accelerated:: Accelerated ;
14
14
use cubecl:: linalg:: matmul:: components:: MatmulProblem ;
15
- use cubecl:: linalg:: matmul:: kernels:: matmul:: { MatmulSelector , StandardSelector } ;
15
+ use cubecl:: linalg:: matmul:: kernels:: matmul:: {
16
+ MatmulSelector , PipelinedSelector , SpecializedSelector , StandardSelector ,
17
+ } ;
16
18
use cubecl:: linalg:: matmul:: kernels:: { MatmulAvailabilityError , MatmulLaunchError } ;
17
19
use cubecl:: linalg:: tensor:: { matrix_layout, MatrixLayout } ;
18
20
use cubecl:: { client:: ComputeClient , prelude:: * } ;
@@ -26,30 +28,65 @@ use crate::fusion::on_write::{
26
28
27
29
use super :: args:: FusedMatmulInputLaunch ;
28
30
use super :: spec:: FusedMatmulSpec ;
31
+ use super :: tune:: fused_matmul_autotune;
29
32
30
- #[ derive( new) ]
31
33
/// Fuse matmul operation followed by elemwise operations into a single kernel.
32
34
pub struct MatmulOptimization < R : JitRuntime > {
33
35
trace : FuseOnWriteTrace ,
34
36
trace_fallback : FuseOnWriteTrace ,
35
- client : ComputeClient < R :: Server , R :: Channel > ,
36
- device : R :: Device ,
37
- len : usize ,
38
- matmul : FusedMatmul ,
37
+ pub ( crate ) client : ComputeClient < R :: Server , R :: Channel > ,
38
+ pub ( crate ) device : R :: Device ,
39
+ pub ( crate ) len : usize ,
40
+ pub ( crate ) matmul_standard : FusedMatmul ,
41
+ pub ( crate ) matmul_pipelined : FusedMatmul ,
42
+ pub ( crate ) matmul_specialized : FusedMatmul ,
39
43
}
40
44
41
45
#[ derive( Serialize , Deserialize , Debug ) ]
42
46
/// State for the [matrix optimization](MatmulOptimizationState).
43
47
pub struct MatmulOptimizationState {
44
48
trace : FuseOnWriteTrace ,
45
49
trace_fallback : FuseOnWriteTrace ,
46
- matmul : FusedMatmul ,
50
+ matmul_standard : FusedMatmul ,
51
+ matmul_pipelined : FusedMatmul ,
52
+ matmul_specialized : FusedMatmul ,
47
53
len : usize ,
48
54
}
49
55
50
56
impl < R : JitRuntime > MatmulOptimization < R > {
57
+ pub fn new (
58
+ trace : FuseOnWriteTrace ,
59
+ trace_fallback : FuseOnWriteTrace ,
60
+ client : ComputeClient < R :: Server , R :: Channel > ,
61
+ device : R :: Device ,
62
+ len : usize ,
63
+ matmul : FusedMatmul ,
64
+ ) -> Self {
65
+ let mut matmul_standard = matmul. clone ( ) ;
66
+ let mut matmul_specialized = matmul. clone ( ) ;
67
+ let mut matmul_pipelined = matmul;
68
+
69
+ matmul_standard. selector = FusedMatmulSelector :: Standard ;
70
+ matmul_specialized. selector = FusedMatmulSelector :: Specialized ;
71
+ matmul_pipelined. selector = FusedMatmulSelector :: Pipelined ;
72
+
73
+ Self {
74
+ trace,
75
+ trace_fallback,
76
+ client,
77
+ device,
78
+ len,
79
+ matmul_standard,
80
+ matmul_pipelined,
81
+ matmul_specialized,
82
+ }
83
+ }
51
84
/// Execute the optimization.
52
85
pub fn execute < BT : BoolElement > ( & mut self , context : & mut Context < ' _ , JitFusionHandle < R > > ) {
86
+ #[ cfg( feature = "autotune" ) ]
87
+ fused_matmul_autotune :: < R , BT > ( self , context) ;
88
+
89
+ #[ cfg( not( feature = "autotune" ) ) ]
53
90
if self . execute_fused :: < BT > ( context) . is_err ( ) {
54
91
self . execute_fallback :: < BT > ( context) ;
55
92
}
@@ -68,7 +105,9 @@ impl<R: JitRuntime> MatmulOptimization<R> {
68
105
len : state. len ,
69
106
client : R :: client ( device) ,
70
107
device : device. clone ( ) ,
71
- matmul : state. matmul . clone ( ) ,
108
+ matmul_standard : state. matmul_standard . clone ( ) ,
109
+ matmul_specialized : state. matmul_specialized . clone ( ) ,
110
+ matmul_pipelined : state. matmul_pipelined . clone ( ) ,
72
111
}
73
112
}
74
113
@@ -77,21 +116,51 @@ impl<R: JitRuntime> MatmulOptimization<R> {
77
116
MatmulOptimizationState {
78
117
trace : self . trace . clone ( ) ,
79
118
trace_fallback : self . trace_fallback . clone ( ) ,
80
- matmul : self . matmul . clone ( ) ,
119
+ matmul_standard : self . matmul_standard . clone ( ) ,
120
+ matmul_specialized : self . matmul_specialized . clone ( ) ,
121
+ matmul_pipelined : self . matmul_pipelined . clone ( ) ,
81
122
len : self . len ,
82
123
}
83
124
}
84
125
85
- fn execute_fused < BT : BoolElement > (
86
- & mut self ,
126
+ pub fn execute_standard_fused < BT : BoolElement > (
127
+ & self ,
87
128
context : & mut Context < ' _ , JitFusionHandle < R > > ,
88
129
) -> Result < ( ) , FusedMatmulError > {
89
- self . trace
90
- . run :: < R , BT , FusedMatmul > ( & self . client , & self . device , context, & self . matmul )
130
+ self . trace . run :: < R , BT , FusedMatmul > (
131
+ & self . client ,
132
+ & self . device ,
133
+ context,
134
+ & self . matmul_standard ,
135
+ )
91
136
}
92
137
93
- fn execute_fallback < BT : BoolElement > ( & mut self , context : & mut Context < ' _ , JitFusionHandle < R > > ) {
94
- match self . matmul . lhs . precision ( ) {
138
+ pub fn execute_specialized_fused < BT : BoolElement > (
139
+ & self ,
140
+ context : & mut Context < ' _ , JitFusionHandle < R > > ,
141
+ ) -> Result < ( ) , FusedMatmulError > {
142
+ self . trace . run :: < R , BT , FusedMatmul > (
143
+ & self . client ,
144
+ & self . device ,
145
+ context,
146
+ & self . matmul_specialized ,
147
+ )
148
+ }
149
+
150
+ pub fn execute_pipelined_fused < BT : BoolElement > (
151
+ & self ,
152
+ context : & mut Context < ' _ , JitFusionHandle < R > > ,
153
+ ) -> Result < ( ) , FusedMatmulError > {
154
+ self . trace . run :: < R , BT , FusedMatmul > (
155
+ & self . client ,
156
+ & self . device ,
157
+ context,
158
+ & self . matmul_pipelined ,
159
+ )
160
+ }
161
+
162
+ pub fn execute_fallback < BT : BoolElement > ( & self , context : & mut Context < ' _ , JitFusionHandle < R > > ) {
163
+ match self . matmul_standard . lhs . precision ( ) {
95
164
ElemwisePrecision :: F32 => self . run_fallback :: < BT , f32 > ( context) ,
96
165
ElemwisePrecision :: F16 => self . run_fallback :: < BT , f16 > ( context) ,
97
166
ElemwisePrecision :: BF16 => self . run_fallback :: < BT , bf16 > ( context) ,
@@ -100,13 +169,25 @@ impl<R: JitRuntime> MatmulOptimization<R> {
100
169
}
101
170
102
171
fn run_fallback < BT : BoolElement , EG : FloatElement > (
103
- & mut self ,
172
+ & self ,
104
173
context : & mut Context < ' _ , JitFusionHandle < R > > ,
105
174
) {
106
175
let ( out_tensor, out_desc) = {
107
- let lhs = context. tensors . get ( & self . matmul . op . lhs . id ) . unwrap ( ) . clone ( ) ;
108
- let rhs = context. tensors . get ( & self . matmul . op . rhs . id ) . unwrap ( ) . clone ( ) ;
109
- let out = context. tensors . get ( & self . matmul . op . out . id ) . unwrap ( ) . clone ( ) ;
176
+ let lhs = context
177
+ . tensors
178
+ . get ( & self . matmul_standard . op . lhs . id )
179
+ . unwrap ( )
180
+ . clone ( ) ;
181
+ let rhs = context
182
+ . tensors
183
+ . get ( & self . matmul_standard . op . rhs . id )
184
+ . unwrap ( )
185
+ . clone ( ) ;
186
+ let out = context
187
+ . tensors
188
+ . get ( & self . matmul_standard . op . out . id )
189
+ . unwrap ( )
190
+ . clone ( ) ;
110
191
111
192
let lhs_handle = context. handles . get_handle ( & lhs. id , & TensorStatus :: ReadOnly ) ;
112
193
let rhs_handle = context. handles . get_handle ( & rhs. id , & TensorStatus :: ReadOnly ) ;
@@ -136,12 +217,21 @@ impl<R: JitRuntime> MatmulOptimization<R> {
136
217
}
137
218
}
138
219
220
+ #[ derive( Default , Clone , Serialize , Deserialize , Debug ) ]
221
+ pub enum FusedMatmulSelector {
222
+ #[ default]
223
+ Standard ,
224
+ Pipelined ,
225
+ Specialized ,
226
+ }
227
+
139
228
#[ derive( new, Clone , Serialize , Deserialize , Debug ) ]
140
229
pub struct FusedMatmul {
141
230
lhs : Arg ,
142
231
rhs : Arg ,
143
232
out : Arg ,
144
- op : BinaryOperationDescription ,
233
+ pub ( crate ) op : BinaryOperationDescription ,
234
+ pub ( crate ) selector : FusedMatmulSelector ,
145
235
}
146
236
147
237
#[ derive( Debug ) ]
@@ -261,15 +351,43 @@ impl FusedMatmul {
261
351
}
262
352
} ;
263
353
264
- match matmul_launch_kernel :: < R , EG , StandardSelector < Accelerated > > (
265
- client,
266
- FusedMatmulInputLaunch :: new ( inputs, config, & self . lhs , & self . rhs , & self . out ) ,
267
- outputs,
268
- problem,
269
- plane_size,
270
- ) {
271
- Ok ( _) => Ok ( ( ) ) ,
272
- Err ( err) => Err ( FusedMatmulError :: LaunchError ( err) ) ,
354
+ match self . selector {
355
+ FusedMatmulSelector :: Standard => {
356
+ match matmul_launch_kernel :: < R , EG , StandardSelector < Accelerated > > (
357
+ client,
358
+ FusedMatmulInputLaunch :: new ( inputs, config, & self . lhs , & self . rhs , & self . out ) ,
359
+ outputs,
360
+ problem,
361
+ plane_size,
362
+ ) {
363
+ Ok ( _) => Ok ( ( ) ) ,
364
+ Err ( err) => Err ( FusedMatmulError :: LaunchError ( err) ) ,
365
+ }
366
+ }
367
+ FusedMatmulSelector :: Pipelined => {
368
+ match matmul_launch_kernel :: < R , EG , PipelinedSelector < Accelerated > > (
369
+ client,
370
+ FusedMatmulInputLaunch :: new ( inputs, config, & self . lhs , & self . rhs , & self . out ) ,
371
+ outputs,
372
+ problem,
373
+ plane_size,
374
+ ) {
375
+ Ok ( _) => Ok ( ( ) ) ,
376
+ Err ( err) => Err ( FusedMatmulError :: LaunchError ( err) ) ,
377
+ }
378
+ }
379
+ FusedMatmulSelector :: Specialized => {
380
+ match matmul_launch_kernel :: < R , EG , SpecializedSelector < Accelerated > > (
381
+ client,
382
+ FusedMatmulInputLaunch :: new ( inputs, config, & self . lhs , & self . rhs , & self . out ) ,
383
+ outputs,
384
+ problem,
385
+ plane_size,
386
+ ) {
387
+ Ok ( _) => Ok ( ( ) ) ,
388
+ Err ( err) => Err ( FusedMatmulError :: LaunchError ( err) ) ,
389
+ }
390
+ }
273
391
}
274
392
}
275
393
}
0 commit comments