fix an assertion failure due to refactoring in PR54 (#69)
Summary:
We move the static_assert to the top-level kernel. After the move, a static_assert failure is caught by the autotuner:
        try:
            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except (OutOfResources, CompileTimeAssertionFailure, PTXASError):
            return [float("inf"), float("inf"), float("inf")]

Prior to this change, the CompileTimeAssertionFailure was somehow not caught by this handler; it was reported and failed the build.
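
For context, a minimal sketch of the intended structure (hypothetical, simplified names and signatures, not the exact kernels in this file): the static_assert lives in the top-level kernel rather than in the shared compute helper, so an invalid config fails at compile time and the autotuner's except clause above can skip it instead of the build failing:

    import triton
    import triton.language as tl

    @triton.jit
    def _attn_fwd_compute(BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr):
        # shared helper: the static_assert no longer lives here
        start_m = tl.program_id(0)  # ... actual attention computation elided

    @triton.jit
    def _attn_fwd(BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr):
        # assert in the top-level (autotuned) kernel, so a bad config surfaces
        # as CompileTimeAssertionFailure inside do_bench and is discarded
        tl.static_assert(BLOCK_N <= HEAD_DIM)
        _attn_fwd_compute(BLOCK_N, HEAD_DIM)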

Verified with: python run.py --op fp8_attention
python run.py --op flash_attention --only triton_tutorial_flash_v2 --num-inputs 1 --metrics tflops --num-inputs 1

Pull Request resolved: #69

Reviewed By: xuzhao9, adamomainz

Differential Revision: D66336174

Pulled By: manman-ren

fbshipit-source-id: 95d29821e6cba45af535b11020aa51424a408789
manman-ren authored and facebook-github-bot committed Nov 22, 2024
1 parent fbdcfcf commit 8f8db26
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions tritonbench/kernels/triton_fused_attention.py
@@ -360,11 +360,11 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]
-    for BN in [128]
+    for BM in [64, 128]
+    for BN in [64, 128]
     for sched in schedList
     for enable_tma in [False]
-    for w in [8]
+    for w in [4, 8]
 ]
 # no WS, with TMA and CompPipe
 configsTma = [
@@ -393,11 +393,11 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]
-    for BN in [128]
+    for BM in [64, 128]
+    for BN in [64, 128]
     for sched in schedList
     for enable_tma in [True]
-    for w in [8]
+    for w in [4, 8]
 ]
 # no TMA, with WS and CompPipe
 configsWS = [
@@ -454,10 +454,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]  # 64, 128]
-    for BN in [128]  # 64, 128]
-    for s in [3]  # , 4, 7]
-    for w in [8]  # 4, 8]
+    for BM in [64, 128]
+    for BN in [64, 128]
+    for s in [3, 4, 7]
+    for w in [4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
@@ -548,7 +548,6 @@ def _attn_fwd_compute(
     ENABLE_TMA: tl.constexpr,
     LOOP_SCHEDULE: tl.constexpr,
 ):
-    tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
     off_z = off_hz // H
@@ -729,7 +728,6 @@ def _attn_fwd_compute_ws(
     ENABLE_TMA: tl.constexpr,
     LOOP_SCHEDULE: tl.constexpr,
 ):
-    tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
     off_z = off_hz // H
@@ -914,6 +912,7 @@ def _attn_fwd_ws(
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute_ws(
         Q,
         K,
@@ -993,6 +992,7 @@ def _attn_fwd(
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,
@@ -1072,6 +1072,7 @@ def _attn_fwd_opt(  # Q, V, desc_k, desc_v, sm_scale, M, Out, #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,
@@ -1151,6 +1152,7 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out, #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,
@@ -1230,6 +1232,7 @@ def _attn_fwd_tma_ws(  # Q, V, desc_k, desc_v, sm_scale, M, Out, #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute_ws(
         Q,
         K,
