*
**************************************************************************************************/
- // Inspired by: sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
+ // Inspired by: sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
#pragma once
#include " cutlass/cutlass.h"
+ #include "cutlass/detail/blockwise_scale_layout.hpp"
#include " cutlass/gemm/dispatch_policy.hpp"
#include " cute/algorithm/functional.hpp"
#include " cute/atom/mma_atom.hpp"
#include " cute/algorithm/gemm.hpp"
- #include " cute/tensor_predicate.hpp"
#include " cute/numeric/arithmetic_tuple.hpp"
- #include " cutlass/detail/blockwise_scale_layout.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -325,14 +324,17 @@ struct CollectiveMma<
Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB, tSFBsSFB(_,_,_,_0{}).stride());
// Since the scale granularity along K is a multiple of BLK_K, we do not have to consider whether it is OOB
+ // Only a few threads participate in copying and clearing scale factors in shared memory. Because
+ // the scale factor is broadcast across certain dimensions, multiple threads end up accessing
+ // the same location in shared memory.
bool load_sfa = thread_idx < cute::min(32, ScaleMsPerTile);
+ bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
auto residue_sf = cute::shape_div(residue_mnk,
ResidueMNK{ScaleGranularityM, ScaleGranularityN, ScaleGranularityK});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tSFApSFA); ++i) {
  tSFApSFA(i) = load_sfa && elem_less(get<0, 1>(tSFAcSFA_compact(i)), get<0>(residue_sf));
}
- bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tSFBpSFB); ++i) {
  tSFBpSFB(i) = load_sfb && elem_less(get<0, 1>(tSFBcSFB_compact(i)), get<1>(residue_sf));
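
The predicate setup in this hunk combines a thread-participation mask (load_sfa / load_sfb) with a per-element residue test. Below is a minimal standalone sketch of that pattern in plain C++; it is not the CUTLASS/CuTe API, and all names (kScaleMsPerTile, kResidueM, kThreads) are hypothetical stand-ins for the tensors and extents used above.

// Standalone model of the predication pattern (plain C++, hypothetical
// names; not the CUTLASS/CuTe API). Each thread first decides whether it
// participates at all, then ANDs that with a per-element
// "coordinate < residue" test, mirroring elem_less(coord, residue).
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kScaleMsPerTile = 4;  // hypothetical scale-factor extent per tile
  constexpr int kResidueM       = 3;  // hypothetical valid extent (partial tile)
  constexpr int kThreads        = 8;  // hypothetical thread count in the group

  for (int thread_idx = 0; thread_idx < kThreads; ++thread_idx) {
    // Mirrors: bool load_sfa = thread_idx < cute::min(32, ScaleMsPerTile);
    bool load_sfa = thread_idx < std::min(32, kScaleMsPerTile);
    for (int i = 0; i < kScaleMsPerTile; ++i) {
      // Mirrors: tSFApSFA(i) = load_sfa && elem_less(coord(i), residue);
      bool pred = load_sfa && (i < kResidueM);
      std::printf("thread %d, elem %d -> load: %d\n", thread_idx, i, pred);
    }
  }
  return 0;
}
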
@@ -345,8 +347,13 @@ struct CollectiveMma<
// Clear the smem tiles to account for predicated-off loads
clear(tAsA);
clear(tBsB);
- clear(tSFAsSFA);
- clear(tSFBsSFB);
+ // Only a few threads participate in copying and clearing scale factors in shared memory.
+ if (load_sfa) {
+   clear(tSFAsSFA);
+ }
+ if (load_sfb) {
+   clear(tSFBsSFB);
+ }
// Start async loads, no k-residue handling needed
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) {
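
For context on the second hunk: the scale-factor layout broadcasts one shared-memory slot to many threads, so an unguarded clear would issue many redundant writes to the same location. The sketch below models that aliasing in plain C++ under hypothetical names (kThreads, kScaleSlots); it illustrates the design choice and is not the CUTLASS implementation.

// Standalone model of the guarded clear (plain C++, hypothetical names; not
// the CUTLASS/CuTe API). Many threads alias the same broadcast slot, so only
// the first kScaleSlots threads write, matching the if (load_sfa) guard.
#include <cstdio>
#include <vector>

int main() {
  constexpr int kThreads    = 8;  // hypothetical thread count in the group
  constexpr int kScaleSlots = 2;  // hypothetical distinct smem scale slots
  std::vector<float> smem_sf(kScaleSlots, -1.0f);  // stand-in for the smem tile
  int writes = 0;

  for (int thread_idx = 0; thread_idx < kThreads; ++thread_idx) {
    bool load_sf = thread_idx < kScaleSlots;  // mirrors load_sfa / load_sfb
    int  slot    = thread_idx % kScaleSlots;  // broadcast: threads alias slots
    if (load_sf) {  // guarded clear: exactly one writer per distinct slot
      smem_sf[slot] = 0.0f;
      ++writes;
    }
  }
  std::printf("slots cleared: %d, writes issued: %d\n", kScaleSlots, writes);
  return 0;
}

In this model, dropping the guard would make all kThreads threads write the same kScaleSlots locations, which is what the guarded clear in the diff avoids.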