Skip to content

Commit ab3b26e

Browse files
committed
Use the load_sf* flags to limit which threads clear the sf* shared-memory tiles.
1 parent 9f7badd commit ab3b26e

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,19 @@
2929
*
3030
**************************************************************************************************/
3131

32-
// Inspired by: sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
32+
// Inspired by: sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
3333

3434
#pragma once
3535

3636
#include "cutlass/cutlass.h"
37+
#include "cutlass/detail/blockwise_scale_layout.hpp"
3738
#include "cutlass/gemm/dispatch_policy.hpp"
3839

3940
#include "cute/algorithm/functional.hpp"
4041
#include "cute/atom/mma_atom.hpp"
4142
#include "cute/algorithm/gemm.hpp"
42-
#include "cute/tensor_predicate.hpp"
4343
#include "cute/numeric/arithmetic_tuple.hpp"
4444

45-
#include "cutlass/detail/blockwise_scale_layout.hpp"
4645

4746
/////////////////////////////////////////////////////////////////////////////////////////////////
4847

@@ -325,14 +324,17 @@ struct CollectiveMma<
325324
Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB, tSFBsSFB(_,_,_,_0{}).stride());
326325

327326
// Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB
327+
// Only a few threads participate in copying and clearing scale factors in shared memory. Because
328+
// the scale factor is broadcast across certain dimensions, multiple threads end up accessing
329+
// the same location in shared memory.
328330
bool load_sfa = thread_idx < cute::min(32, ScaleMsPerTile);
331+
bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
329332
auto residue_sf = cute::shape_div(residue_mnk,
330333
ResidueMNK{ScaleGranularityM, ScaleGranularityN, ScaleGranularityK});
331334
CUTLASS_PRAGMA_UNROLL
332335
for (int i = 0; i < size(tSFApSFA); ++i) {
333336
tSFApSFA(i) = load_sfa && elem_less(get<0, 1>(tSFAcSFA_compact(i)), get<0>(residue_sf));
334337
}
335-
bool load_sfb = thread_idx < cute::min(32, ScaleNsPerTile);
336338
CUTLASS_PRAGMA_UNROLL
337339
for (int i = 0; i < size(tSFBpSFB); ++i) {
338340
tSFBpSFB(i) = load_sfb && elem_less(get<0, 1>(tSFBcSFB_compact(i)), get<1>(residue_sf));
@@ -345,8 +347,13 @@ struct CollectiveMma<
345347
// Clear the smem tiles to account for predicated off loads
346348
clear(tAsA);
347349
clear(tBsB);
348-
clear(tSFAsSFA);
349-
clear(tSFBsSFB);
350+
// Only a few threads participate in copying and clearing scale factors in shared memory.
351+
if (load_sfa) {
352+
clear(tSFAsSFA);
353+
}
354+
if (load_sfb) {
355+
clear(tSFBsSFB);
356+
}
350357
// Start async loads, no k-residue handling needed
351358
CUTLASS_PRAGMA_UNROLL
352359
for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) {

0 commit comments

Comments
 (0)