@@ -325,14 +325,17 @@ struct CollectiveMma<
     Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB, tSFBsSFB(_,_,_,_0{}).stride());

     // Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB
+    // Only a few threads participate in copying and clearing scale factors in shared memory. Because
+    // the scale factor is broadcast across certain dimensions, multiple threads end up accessing
+    // the same location in shared memory.
     bool load_sfa = thread_idx < cute::min(32, ScaleMsPerTile);
+    bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
     auto residue_sf = cute::shape_div(residue_mnk,
                                       ResidueMNK{ScaleGranularityM, ScaleGranularityN, ScaleGranularityK});
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < size(tSFApSFA); ++i) {
       tSFApSFA(i) = load_sfa && elem_less(get<0,1>(tSFAcSFA_compact(i)), get<0>(residue_sf));
     }
-    bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < size(tSFBpSFB); ++i) {
       tSFBpSFB(i) = load_sfb && elem_less(get<0,1>(tSFBcSFB_compact(i)), get<1>(residue_sf));
@@ -345,8 +348,13 @@ struct CollectiveMma<
     // Clear the smem tiles to account for predicated off loads
     clear(tAsA);
     clear(tBsB);
-    clear(tSFAsSFA);
-    clear(tSFBsSFB);
+    // Only a few threads participate in copying and clearing scale factors in shared memory.
+    if (load_sfa) {
+      clear(tSFAsSFA);
+    }
+    if (load_sfb) {
+      clear(tSFBsSFB);
+    }
     // Start async loads, no k-residue handling needed
     CUTLASS_PRAGMA_UNROLL
     for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) {
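
For context on the guarded clears above, here is a minimal, self-contained CUDA sketch of the same pattern. It is not the CUTLASS collective itself: names such as ScaleMsPerTile, load_sf, and copy_scale_factors are illustrative assumptions, and the layout is heavily simplified. The point it shows is that only the threads selected by the load predicate clear and fill the shared scale-factor tile, so non-participating threads never write to smem locations that the broadcast layout maps onto other threads' data.

// Simplified stand-in for the predicated scale-factor copy/clear (not the CUTLASS code).
#include <cuda_runtime.h>
#include <cstdio>

constexpr int ScaleMsPerTile  = 4;    // assumed: scale factors per CTA tile along M
constexpr int ThreadsPerBlock = 128;

__global__ void copy_scale_factors(const float* gmem_sf, float* out, int valid_count) {
  __shared__ float smem_sf[ScaleMsPerTile];

  // Only a few threads participate; the rest must not touch the shared tile, since the
  // broadcast layout would otherwise let them overwrite elements owned by other threads.
  bool load_sf = threadIdx.x < min(32, ScaleMsPerTile);

  if (load_sf) {
    // Clear first so predicated-off (out-of-bounds) elements read back as zero ...
    smem_sf[threadIdx.x] = 0.0f;
    // ... then copy only the in-bounds scale factors.
    if (threadIdx.x < valid_count) {
      smem_sf[threadIdx.x] = gmem_sf[threadIdx.x];
    }
  }
  __syncthreads();

  // All threads can now read the scale factors; here they are just written back out.
  if (threadIdx.x < ScaleMsPerTile) {
    out[threadIdx.x] = smem_sf[threadIdx.x];
  }
}

int main() {
  float h_sf[ScaleMsPerTile] = {1.f, 2.f, 3.f, 4.f};
  float *d_sf, *d_out;
  cudaMalloc(&d_sf, sizeof(h_sf));
  cudaMalloc(&d_out, sizeof(h_sf));
  cudaMemcpy(d_sf, h_sf, sizeof(h_sf), cudaMemcpyHostToDevice);

  copy_scale_factors<<<1, ThreadsPerBlock>>>(d_sf, d_out, /*valid_count=*/3);

  float h_out[ScaleMsPerTile];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g\n", v);   // expected: 1 2 3 0

  cudaFree(d_sf);
  cudaFree(d_out);
  return 0;
}

In the actual change the same idea appears twice: the load predicates tSFApSFA/tSFBpSFB include load_sfa/load_sfb, and the smem clears are wrapped in those same predicates, so the elements that get cleared and the elements that get copied are always owned by the same threads.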