@@ -55,9 +55,15 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
55
55
if ( ! outputShape ) {
56
56
throw new Error ( "Can't use gemm on the given tensors" ) ;
57
57
}
58
+ const tileSize = 16 ;
59
+ const numTileN = Math . ceil ( N / tileSize ) ;
60
+ const numTileM = Math . ceil ( M / tileSize ) ;
61
+ // TODO: Find the condition when directly use the naive one.
62
+ const useShared = true ;
63
+
58
64
const outputSize = ShapeUtil . size ( outputShape ) ;
59
65
const programUniforms : ProgramUniform [ ] = [
60
- { type : DataType . uint32 , data : outputSize } ,
66
+ { type : DataType . uint32 , data : useShared ? numTileN : outputSize } ,
61
67
{ type : DataType . uint32 , data : M } ,
62
68
{ type : DataType . uint32 , data : N } ,
63
69
{ type : DataType . uint32 , data : K } ,
@@ -130,6 +136,159 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
130
136
}` ;
131
137
} ;
132
138
139
+ const getShaderSourceShared = ( shaderHelper : ShaderHelper ) => {
140
+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims ) ;
141
+ const b = inputVariable ( 'b' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims ) ;
142
+ let c : IndicesHelper | null = null ;
143
+ const variables = [ a , b ] ;
144
+ if ( inputs . length === 3 ) {
145
+ c = inputVariable ( 'c' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length ) ;
146
+ variables . push ( c ) ;
147
+ }
148
+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length ) ;
149
+ variables . push ( output ) ;
150
+ const uniforms : UniformsArrayType = [
151
+ { name : 'num_tile_n' , type : 'u32' } ,
152
+ { name : 'M' , type : 'u32' } ,
153
+ { name : 'N' , type : 'u32' } ,
154
+ { name : 'K' , type : 'u32' } ,
155
+ { name : 'alpha' , type : 'f32' } ,
156
+ { name : 'beta' , type : 'f32' } ,
157
+ ] ;
158
+
159
+ let calcResult = '' ;
160
+ let fillWorkgroupMemory = '' ;
161
+ if ( attributes . transA && attributes . transB ) {
162
+ fillWorkgroupMemory = `
163
+ var col = tile_row_start + local_id.x;
164
+ var row = k_start + local_id.y;
165
+ if (col < uniforms.M && row < uniforms.K) {
166
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
167
+ } else {
168
+ tile_a[local_id.y][local_id.x] = ${ a . type . value } (0);
169
+ }
170
+
171
+ col = k_start + local_id.x;
172
+ row = tile_col_start + local_id.y;
173
+ if (col < uniforms.K && row < uniforms.N) {
174
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
175
+ } else {
176
+ tile_b[local_id.y][local_id.x] = ${ b . type . value } (0);
177
+ }
178
+ ` ;
179
+ calcResult = `value += tile_a[k][local_id.y] * tile_b[local_id.x][k];` ;
180
+ } else if ( attributes . transA && ! attributes . transB ) {
181
+ fillWorkgroupMemory = `
182
+ var col = tile_row_start + local_id.x;
183
+ var row = k_start + local_id.y;
184
+ if (col < uniforms.M && row < uniforms.K) {
185
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col];
186
+ } else {
187
+ tile_a[local_id.y][local_id.x] = ${ a . type . value } (0);
188
+ }
189
+
190
+ col = tile_col_start + local_id.x;
191
+ row = k_start + local_id.y;
192
+ if (col < uniforms.N && row < uniforms.K) {
193
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
194
+ } else {
195
+ tile_b[local_id.y][local_id.x] = ${ b . type . value } (0);
196
+ }
197
+ ` ;
198
+ calcResult = `value += tile_a[k][local_id.y] * tile_b[k][local_id.x];` ;
199
+ } else if ( ! attributes . transA && attributes . transB ) {
200
+ fillWorkgroupMemory = `
201
+ var col = k_start + local_id.x;
202
+ var row = tile_row_start + local_id.y;
203
+ if (col < uniforms.K && row < uniforms.M) {
204
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
205
+ } else {
206
+ tile_a[local_id.y][local_id.x] = ${ a . type . value } (0);
207
+ }
208
+
209
+ col = k_start + local_id.x;
210
+ row = tile_col_start + local_id.y;
211
+ if (col < uniforms.K && row < uniforms.N) {
212
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col];
213
+ } else {
214
+ tile_b[local_id.y][local_id.x] = ${ b . type . value } (0);
215
+ }
216
+ ` ;
217
+ calcResult = `value += tile_a[local_id.y][k] * tile_b[local_id.x][k];` ;
218
+ } else if ( ! attributes . transA && ! attributes . transB ) {
219
+ fillWorkgroupMemory = `
220
+ var col = k_start + local_id.x;
221
+ var row = tile_row_start + local_id.y;
222
+ if (col < uniforms.K && row < uniforms.M) {
223
+ tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col];
224
+ } else {
225
+ tile_a[local_id.y][local_id.x] = ${ a . type . value } (0);
226
+ }
227
+
228
+ col = tile_col_start + local_id.x;
229
+ row = k_start + local_id.y;
230
+ if (col < uniforms.N && row < uniforms.K) {
231
+ tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col];
232
+ } else {
233
+ tile_b[local_id.y][local_id.x] = ${ b . type . value } (0);
234
+ }
235
+ ` ;
236
+ calcResult = `value += tile_a[local_id.y][k] * tile_b[k][local_id.x];` ;
237
+ }
238
+
239
+ const calculateAlpha = attributes . alpha === 1 ? '' : 'value *= uniforms.alpha;' ;
240
+
241
+ return `
242
+ ${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...variables ) }
243
+ var<workgroup> tile_a: array<array<${ a . type . storage } , ${ tileSize } >, ${ tileSize } >;
244
+ var<workgroup> tile_b: array<array<${ b . type . storage } , ${ tileSize } >, ${ tileSize } >;
245
+ ${ shaderHelper . mainStart ( [ tileSize , tileSize , 1 ] ) }
246
+ let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${ tileSize } ;
247
+ let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${ tileSize } ;
248
+ let num_tiles = (uniforms.K - 1) / ${ tileSize } + 1;
249
+ var k_start = 0u;
250
+ var value = ${ output . type . value } (0);
251
+ for (var t: u32 = 0u; t < num_tiles; t++) {
252
+ ${ fillWorkgroupMemory }
253
+ k_start = k_start + ${ tileSize } ;
254
+ workgroupBarrier();
255
+
256
+ for (var k: u32 = 0u; k < ${ tileSize } ; k++) {
257
+ ${ calcResult }
258
+ }
259
+ workgroupBarrier();
260
+ }
261
+
262
+ ${ calculateAlpha }
263
+ let m = tile_row_start + local_id.y;
264
+ let n = tile_col_start + local_id.x;
265
+ ${ ( ( ) => {
266
+ if ( c != null ) {
267
+ return `let cOffset = ${ c . broadcastedIndicesToOffset ( 'vec2(m, n)' , output ) } ; value += ${
268
+ output . type . value
269
+ } (uniforms.beta) * ${ c . getByOffset ( 'cOffset' ) } ;`;
270
+ }
271
+ return '' ;
272
+ } ) ( ) }
273
+ if (m < uniforms.M && n < uniforms.N) {
274
+ output[m * uniforms.N + n] = value;
275
+ }
276
+ }` ;
277
+ } ;
278
+
279
+ if ( useShared ) {
280
+ return {
281
+ name : 'GemmShared' ,
282
+ shaderCache : { hint : `${ attributes . cacheKey } ` , inputDependencies } ,
283
+ getRunData : ( ) => ( {
284
+ outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
285
+ dispatchGroup : { x : numTileN * numTileM } ,
286
+ programUniforms,
287
+ } ) ,
288
+ getShaderSource : getShaderSourceShared ,
289
+ } ;
290
+ }
291
+
133
292
return {
134
293
name : 'Gemm' ,
135
294
shaderCache : { hint : `${ attributes . cacheKey } ` , inputDependencies } ,
0 commit comments