Skip to content

Commit 3521a01

Browse files
committed
Use load_sf* flags to limit threads that perform clear sf*.
1 parent 034e486 commit 3521a01

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,17 @@ struct CollectiveMma<
325325
Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB, tSFBsSFB(_,_,_,_0{}).stride());
326326

327327
// Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB
328+
// Only a few threads participate in copying and clearing scale factors in shared memory. Because
329+
// the scale factor is broadcast across certain dimensions, multiple threads end up accessing
330+
// the same location in shared memory.
328331
bool load_sfa = thread_idx < cute::min(32, ScaleMsPerTile);
332+
bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
329333
auto residue_sf = cute::shape_div(residue_mnk,
330334
ResidueMNK{ScaleGranularityM, ScaleGranularityN, ScaleGranularityK});
331335
CUTLASS_PRAGMA_UNROLL
332336
for (int i = 0; i < size(tSFApSFA); ++i) {
333337
tSFApSFA(i) = load_sfa && elem_less(get<0, 1>(tSFAcSFA_compact(i)), get<0>(residue_sf));
334338
}
335-
bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
336339
CUTLASS_PRAGMA_UNROLL
337340
for (int i = 0; i < size(tSFBpSFB); ++i) {
338341
tSFBpSFB(i) = load_sfb && elem_less(get<0, 1>(tSFBcSFB_compact(i)), get<1>(residue_sf));
@@ -345,8 +348,13 @@ struct CollectiveMma<
345348
// Clear the smem tiles to account for predicated off loads
346349
clear(tAsA);
347350
clear(tBsB);
348-
clear(tSFAsSFA);
349-
clear(tSFBsSFB);
351+
// Only a few threads participate in copying and clearing scale factors in shared memory.
352+
if (load_sfa) {
353+
clear(tSFAsSFA);
354+
}
355+
if (load_sfb) {
356+
clear(tSFBsSFB);
357+
}
350358
// Start async loads, no k-residue handling needed
351359
CUTLASS_PRAGMA_UNROLL
352360
for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) {

0 commit comments

Comments
 (0)