-
Notifications
You must be signed in to change notification settings - Fork 710
[Feat] Single Batch Overlap (SBO): Overlaping of Down GEMM with Combine Send #183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]>
Co-authored-by: Zqy11 <[email protected]> Co-authored-by: AniZpZ <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I will have a more detailed check a bit later)
__threadfence(); | ||
|
||
if (threadIdx.x == 0) { | ||
atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw I am still a bit worried about this atomicAdd... (the code location issue in sgl-project/sglang#9660 (comment) is already solved and no problem)
EDIT: oh I see the __threadfence
that is added in the new code compared with old. then curious whether the following approach will work and whether it is faster or not: remove threadfence + but make atomicAdd a released
ordering, similar to my naive attempt here https://github.com/flashinfer-ai/flashinfer/pull/1569/files#diff-26b7ee95d08a959cf95f3a5c1719b5e00a2b0bc596227967de8e0caf74aefdcaR95 (warn again - my impl there has not been tested e2e b/c the e2e code is not ready)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will conduct further research and testing on these suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried using atom.add.release.gpu.global.s32
instead of __threadfence + atomicAdd
. Bench_kineto results showed some performance benefits:
__threadfence + atomicAdd
:
Testing m-grouped masked GEMM:
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
> Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D, enable_overlap=False): 347 us | 216 TFLOPS | 142 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D, enable_overlap=False): 159 us | 215 TFLOPS | 213 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D, enable_overlap=False): 348 us | 178 TFLOPS | 216 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D, enable_overlap=False): 159 us | 191 TFLOPS | 291 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D, enable_overlap=False): 348 us | 174 TFLOPS | 384 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D, enable_overlap=False): 217 us | 146 TFLOPS | 353 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D, enable_overlap=False): 405 us | 153 TFLOPS | 1201 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D, enable_overlap=False): 172 us | 164 TFLOPS | 1455 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=False): 256 us | 118 TFLOPS | 1865 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=False): 127 us | 115 TFLOPS | 1911 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D, enable_overlap=True): 351 us | 194 TFLOPS | 135 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D, enable_overlap=True): 124 us | 174 TFLOPS | 216 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D, enable_overlap=True): 351 us | 179 TFLOPS | 215 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D, enable_overlap=True): 165 us | 183 TFLOPS | 281 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D, enable_overlap=True): 351 us | 165 TFLOPS | 378 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D, enable_overlap=True): 165 us | 159 TFLOPS | 444 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D, enable_overlap=True): 362 us | 159 TFLOPS | 1342 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D, enable_overlap=True): 215 us | 148 TFLOPS | 1177 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=True): 261 us | 109 TFLOPS | 1829 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=True): 135 us | 108 TFLOPS | 1798 GB/s
atom.add.release.gpu.global.s32
:
Testing m-grouped masked GEMM:
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
> Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D, enable_overlap=False): 347 us | 216 TFLOPS | 142 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D, enable_overlap=False): 159 us | 215 TFLOPS | 213 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D, enable_overlap=False): 348 us | 178 TFLOPS | 216 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D, enable_overlap=False): 159 us | 191 TFLOPS | 291 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D, enable_overlap=False): 348 us | 174 TFLOPS | 384 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D, enable_overlap=False): 217 us | 146 TFLOPS | 353 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D, enable_overlap=False): 405 us | 153 TFLOPS | 1202 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D, enable_overlap=False): 172 us | 164 TFLOPS | 1456 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=False): 256 us | 118 TFLOPS | 1866 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=False): 127 us | 115 TFLOPS | 1909 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=4096, k=7168, 1D2D, enable_overlap=True): 350 us | 195 TFLOPS | 136 GB/s
> Perf (num_groups=1, expected_m_per_group=1024, n=7168, k=2048, 1D2D, enable_overlap=True): 123 us | 175 TFLOPS | 217 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=4096, k=7168, 1D2D, enable_overlap=True): 349 us | 180 TFLOPS | 216 GB/s
> Perf (num_groups=2, expected_m_per_group= 512, n=7168, k=2048, 1D2D, enable_overlap=True): 163 us | 184 TFLOPS | 283 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=4096, k=7168, 1D2D, enable_overlap=True): 350 us | 165 TFLOPS | 379 GB/s
> Perf (num_groups=4, expected_m_per_group= 256, n=7168, k=2048, 1D2D, enable_overlap=True): 164 us | 160 TFLOPS | 448 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=4096, k=7168, 1D2D, enable_overlap=True): 359 us | 160 TFLOPS | 1351 GB/s
> Perf (num_groups=16, expected_m_per_group= 64, n=7168, k=2048, 1D2D, enable_overlap=True): 209 us | 152 TFLOPS | 1207 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=True): 259 us | 110 TFLOPS | 1844 GB/s
> Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=True): 132 us | 111 TFLOPS | 1846 GB/s
However, after some research, I concluded that release semantics ensure that all memory writes initiated by the same thread executing the atomic instruction before the atomic instruction are visible to other threads that subsequently observe the results of the atomic operation through an acquire operation. In other words, the guarantee of release semantics is bound to the thread executing the atomic operation. However, the thread initiating the TMA operation and the thread executing the write signal are not necessarily the same, so I'm concerned that this may cause problems.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My naive understanding is the following (please correct me if I am wrong!):
the thread initiating the TMA operation and the thread executing the write signal are not necessarily the same
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
ensures all math threads sync at this point, thus we know all tma_store_wait calls have finished at this point, thus when we do a atom.add.release.gpu.global.s32
on thread 0, we already ensure tma_store_wait from thread 0,1,2,3,... all done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we may apply cuda::atomic_ref and fetch_add instead of atomicAdd @Sulfur6
|
||
if constexpr (kEnableOverlap) { | ||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { | ||
cute::tma_store_wait<0>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one more naive worry: tma_store_wait
seems to correspond to cp.async.bulk.wait_group.read
(src: https://github.com/NVIDIA/cutlass/blob/76c96b0be35cb263debe3e3d8418b80911a544ab/include/cute/arch/copy_sm90_tma.hpp#L1251), but it seems that we need cp.async.bulk.wait_group
(no ".read"), otherwise the semantics (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-wait-group) is that, "the tma store has done reading from source, but the Writes being made visible to the executing thread
may not have been done".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you are right. we may apply "asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory")" instead of tma_store_wait here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking forward to the fix
1. Motivation
The optimization effect of Two-Batch Overlap (TBO) is suboptimal for the Decode phase on low-compute-power cards (i.e., H20). This is due to two main factors: First, on the Hopper architecture, the WGMMA block_m is 64. Consequently, when TBO is enabled with a small Decode batch size, the MLP GEMM suffers from redundant computations. A positive throughput gain is only observed at larger batch sizes (e.g., 64, 128). Second, at these larger batch sizes, low-compute-power cards like the H20 fail to meet the SLA guarantees for TPOT/ITL.
Therefore, it is necessary to find a solution that can improve Decode throughput even with small batch sizes. Single Batch Overlap (SBO) presents itself as a viable solution.
We implement SBO for DeepSeek v3/R1 by modifying DeepEP and DeepGEMM, including the overlap of Shared Expert and Dispatch Recv, as well as the overlap of Down GEMM with Combine Send.
The overlap of Down GEMM with Combine Send is implemented by modifying SGlang, DeepEP and DeepGEMM, with the detailed implementation available in the PRs below:
We also conducted integration and evaluation in SGLang: sgl-project/sglang#9660.
Since the latest version of SGLang depends on the branch https://github.com/sgl-project/DeepGEMM/tree/sgl, you should not use this branch when starting SGLang. Instead, you should use the branch developed based on the sgl branch https://github.com/Sulfur6/DeepGEMM/tree/sbo.v2.sgl.
2. Overlap Design
SBO implements two overlap for the MoE layers of DeepSeek-V3/R1. One is to overlap the Shared Expert computation with the Dispatch Recv communication, and the other is to overlap the Down GEMM computation with the Combine Send communication.


The interaction between Down GEMM and Combine Send is structured as a producer-consumer model synchronized by signals. For each local expert, a signal unit is allocated for every block_m tokens. The Down GEMM computes the results for these block_m tokens and atomically increments the signaling unit after completing a portion of the work. The Combine Send polls this signaling unit. Once the value reaches a threshold, it sends the corresponding block_m tokens.
3. Modifications
m_grouped_fp8_gemm_nt_signal
Python interface to support overlapping Down GEMM with Combine Send.SM90FP8SignalGemm1D2DRuntime
andsm90_m_grouped_fp8_gemm_signal_1d2d
to support Signal Down GEMM in SM90.sm90_fp8_signal_gemm_1d2d_impl
kernel usesatomicAdd
to write signal after the corresponding block_m tokens are computed.4. Evaluation
We integrated the modified DeepEP and DeepGEMM into SGLang for performance evaluation.
4.1. Experiment Setup
4.2. Performance Evaluation
4.3. Accuracy Tests
4.4. Repro Script
Please refer to sgl-project/sglang#9660.