add WITH_TMA
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Nov 27, 2024
1 parent 1a612c5 commit de3628b
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions tritonbench/kernels/triton_fused_attention.py
@@ -24,6 +24,7 @@
 HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
 WITH_COMPPIPE = os.getenv("ENABLE_COMPPIPE")
 PEEL_LAST = os.getenv("PEEL_LAST_ITER")
+WITH_TMA = os.getenv("WITH_TMA")
 
 if HAS_TMA_DESC:
 print(
@@ -333,6 +334,7 @@ def _attn_fwd_inner_ws(
 schedList = ["default", "FA_firstDot", "FA_secondDot"] if WITH_COMPPIPE else ["default"]
 # TODO: incorrect result with PEEL_LAST + FA_firstDot + WarpSpec + TMA
 schedList = ["FA_secondDot"] if PEEL_LAST else schedList
+tmaList = [True] if WITH_TMA else [False]
 # no WS, no TMA, with CompPipe
 configsOpt = [
 (
Expand Down Expand Up @@ -493,7 +495,7 @@ def _attn_fwd_inner_ws(
for BM in [128]
for BN in [128]
for sched in schedList
for enable_tma in [True]
for enable_tma in tmaList
for enable_ws in [True]
for w in [4]
for buf in [2]
@@ -511,7 +513,6 @@ def _attn_fwd_inner_ws(
 "ENABLE_TMA": enable_tma,
 "LOOP_SCHEDULE": sched,
 "GRID_MULTIPLE": mult,
-"GRID_GROUP": gr,
 },
 num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
 num_warps=w,
@@ -528,18 +529,16 @@ def _attn_fwd_inner_ws(
 "ENABLE_TMA": enable_tma,
 "LOOP_SCHEDULE": sched,
 "GRID_MULTIPLE": mult,
-"GRID_GROUP": gr,
 },
 num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
 num_warps=w,
 )
 )
 for BM in [128]
 for BN in [128]
-for mult in [1, 2, 4, 8, 16]
-for gr in [0, 1]
+for mult in [1] #, 2, 4, 8, 16]
 for sched in schedList
-for enable_tma in [True]
+for enable_tma in tmaList
 for enable_ws in [True]
 for w in [4]
 for buf in [2]
@@ -1384,7 +1383,6 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
 LOOP_SCHEDULE: tl.constexpr,
 ENABLE_WS: tl.constexpr,
 GRID_MULTIPLE: tl.constexpr,
-GRID_GROUP: tl.constexpr,
 ):
 tl.static_assert(BLOCK_N <= HEAD_DIM)
 # original grid
@@ -1401,12 +1399,11 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
 
 tile_idx = prog_id
 for _ in range(0, tiles_per_sm):
-if GRID_GROUP == 0:
-pid = tile_idx // (Z * H)
-off_hz = tile_idx % (Z * H) # tl.program_id(1)
-else:
-pid = tile_idx % n_tile_num
-off_hz = tile_idx // n_tile_num
+# This has much better cache locality than
+# pid = tile_idx // (Z * H)
+# off_hz = tile_idx % (Z * H) # tl.program_id(1)
+pid = tile_idx % n_tile_num
+off_hz = tile_idx // n_tile_num
 _attn_fwd_compute_ws(
 Q,
 K,
@@ -1876,7 +1873,6 @@ def grid_tma(META):
 )
 
 NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
-print("NUM_SMS: ", NUM_SMS)
 
 def grid_tma_persistent(META):
 if META["ENABLE_TMA"] == False:
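
The new WITH_TMA toggle relies on `os.getenv` returning `None` when the variable is unset, so `tmaList` only contains `True` once `WITH_TMA` is exported in the environment. Below is a minimal, self-contained sketch of the same gating pattern; the names `demo_tma_list` and `demo_configs` are illustrative and not taken from this file:

```python
import os

# None unless the variable is exported, e.g. `WITH_TMA=1 python ...`
WITH_TMA = os.getenv("WITH_TMA")

# Only sweep enable_tma=True when the env var is set, mirroring
# `tmaList = [True] if WITH_TMA else [False]` in the diff above.
demo_tma_list = [True] if WITH_TMA else [False]

demo_configs = [
    {"BLOCK_M": bm, "BLOCK_N": bn, "ENABLE_TMA": enable_tma}
    for bm in [128]
    for bn in [128]
    for enable_tma in demo_tma_list
]
print(demo_configs)
```

Exporting the variable flips the swept `ENABLE_TMA` value; leaving it unset keeps the non-TMA configs.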

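The persistent kernel now always uses `pid = tile_idx % n_tile_num` and `off_hz = tile_idx // n_tile_num`, the mapping the in-code comment credits with much better cache locality, instead of choosing between two mappings via the removed `GRID_GROUP` constexpr. A small plain-Python sketch of how consecutive tile indices land under each mapping; the values of `Z`, `H`, and `n_tile_num` are made up for illustration:

```python
Z, H = 2, 2          # batch size, number of heads (illustrative)
n_tile_num = 4       # query-block tiles per (batch, head) (illustrative)

for tile_idx in range(8):
    # removed mapping: consecutive tiles hop across (batch, head) pairs
    old = (tile_idx // (Z * H), tile_idx % (Z * H))    # (pid, off_hz)
    # retained mapping: consecutive tiles stay within one (batch, head) pair
    new = (tile_idx % n_tile_num, tile_idx // n_tile_num)
    print(tile_idx, old, new)
```

With the retained mapping, successive iterations of the persistent loop keep `off_hz` fixed and walk `pid` through adjacent query blocks of the same head, which is the cache-locality benefit the kernel comment describes.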