Add new triton kernel debug_fill_dropout_rng (#19)
* tritonsrc/dropout_rng: Add new triton kernel debug_fill_dropout_rng

This is for _fill_mem_eff_dropout_mask_
xinyazhang committed Apr 30, 2024
1 parent 457b7bf commit 71bd17f
Showing 13 changed files with 588 additions and 329 deletions.
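A minimal usage sketch of the new debug entry point (helper names come from the test wrappers added in this commit; tensor shape, dtype, device, and import path are assumed):

# Sketch only: assumes test/ (or tritonsrc/) is on the import path and a ROCm/CUDA torch build.
import torch
from attn_torch_function import debug_fill_dropout_rng, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET

# (batch, num_heads, seqlen_q, seqlen_k) buffer, filled in place with the raw uniform RNG values
r = torch.empty((2, 4, 128, 128), device='cuda', dtype=torch.float32)
debug_fill_dropout_rng(r, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET)

# The attention kernels keep an element when its RNG value exceeds dropout_p,
# so the dropout mask for a given dropout_p can be reconstructed as:
keep_mask = r > 0.5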
7 changes: 7 additions & 0 deletions bindings/module.cc
@@ -53,6 +53,13 @@ namespace pyaotriton {
py::arg("philox_offset"),
py::arg("is_causal"),
py::arg("stream") = nullptr);
m.def("debug_fill_dropout_rng",
&aotriton::v2::flash::debug_fill_dropout_rng,
"Flash Attention Debugging Function to get raw RNG numbers used in dropout",
py::arg("q"),
py::arg("philox_seed"),
py::arg("philox_offset"),
py::arg("stream") = nullptr);
}
} // namespace flash

6 changes: 6 additions & 0 deletions include/aotriton/flash.h
@@ -51,6 +51,12 @@ attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
bool is_causal,
aotriton::Stream stream);

hipError_t
debug_fill_dropout_rng(T4 r,
uint64_t philox_seed,
uint64_t philox_offset,
aotriton::Stream stream);

} // aotriton::v2::flash

#endif
13 changes: 12 additions & 1 deletion test/aotriton_flash.py
@@ -1,7 +1,11 @@
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

from pyaotriton.v2.flash import attn_fwd as fa_forward, attn_bwd as fa_backward
from pyaotriton.v2.flash import (
attn_fwd as fa_forward,
attn_bwd as fa_backward,
debug_fill_dropout_rng as fa_debug_fill_dropout_rng,
)
from pyaotriton import T1, T2, T4, DType, Stream

def cast_dtype(dtype):
@@ -68,3 +72,10 @@ def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta,
is_causal,
Stream())
print(f'{err=}')

def debug_fill_dropout_rng(R, philox_seed, philox_offset):
    err = fa_debug_fill_dropout_rng(mk_aotensor(R),
                                    philox_seed,
                                    philox_offset,
                                    Stream())
    print(f'{err=}')
8 changes: 5 additions & 3 deletions test/attn_torch_function.py
@@ -3,9 +3,11 @@
# SPDX-License-Identifier: MIT

import torch
from aotriton_flash import attn_fwd, attn_bwd
from aotriton_flash import attn_fwd, attn_bwd, debug_fill_dropout_rng

VERBOSE=False
DEFAULT_PHILOX_SEED = 0x1BF52
DEFAULT_PHILOX_OFFSET = 0x1D4B42

def is_power_of_two(n: int) -> bool:
return (n & (n - 1) == 0) and n != 0
@@ -53,8 +55,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax
if encoded_softmax is not None:
print(f'{encoded_softmax.shape=} {encoded_softmax.dtype=}')

philox_seed = 114514
philox_offset = 1919810
philox_seed = DEFAULT_PHILOX_SEED
philox_offset = DEFAULT_PHILOX_OFFSET

attn_fwd(q, k, v, b, sm_scale, M, o,
dropout_p, philox_seed, philox_offset, encoded_softmax, causal);
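For reference, the named defaults introduced here (and mirrored in tritonsrc/attn_torch_function.py below) are the old literal values written in hex:

# Quick check that the renamed constants keep the previous values.
assert 0x1BF52 == 114514       # DEFAULT_PHILOX_SEED
assert 0x1D4B42 == 1919810     # DEFAULT_PHILOX_OFFSET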
367 changes: 206 additions & 161 deletions test/test_forward.py

Large diffs are not rendered by default.

27 changes: 24 additions & 3 deletions tritonsrc/attn_torch_function.py
@@ -6,14 +6,17 @@
import torch
import triton
import triton.language as tl
from flash import attn_fwd as bare_attn_fwd
from flash import (
debug_fill_dropout_rng as bare_debug_fill_dropout_rng,
attn_fwd as bare_attn_fwd,
bwd_preprocess as bare_bwd_preprocess,
bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv,
bwd_kernel_dq as bare_bwd_kernel_dq
)

VERBOSE=False
DEFAULT_PHILOX_SEED = 0x1BF52
DEFAULT_PHILOX_OFFSET = 0x1D4B42

def is_power_of_two(n: int) -> bool:
return (n & (n - 1) == 0) and n != 0
@@ -266,8 +269,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax
if encoded_softmax is not None:
print(f'{encoded_softmax.shape=} {encoded_softmax.dtype=}')

philox_seed = 114514
philox_offset = 1919810
philox_seed = DEFAULT_PHILOX_SEED
philox_offset = DEFAULT_PHILOX_OFFSET
if b is None:
b = torch.empty((0,0,0,0), device=q.device, dtype=q.dtype)
BIAS_TYPE = 0
@@ -683,3 +686,21 @@ def backward(ctx, do, _, fwd_tuning_result):
return dq, dk, dv, None if db.numel() == 0 else db, None, None, None, None, None, None, None

attention = _attention.apply

def debug_fill_dropout_rng(dropout_rng, philox_seed, philox_offset):
    BLOCK_M = 64
    BLOCK_N = 32
    BATCH, N_HEADS, seqlen_q, seqlen_k = dropout_rng.size()
    grid_rng = lambda META: (
        triton.cdiv(seqlen_q, META['BLOCK_M']),
        N_HEADS,
        BATCH,
    )
    r = dropout_rng
    bare_debug_fill_dropout_rng[grid_rng](r,
                                          r.stride(0), r.stride(1), r.stride(2), r.stride(3),
                                          seqlen_q, seqlen_k,
                                          philox_seed,
                                          philox_offset,
                                          BLOCK_M, BLOCK_N,
                                          num_stages=1)
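A small sketch of how the launch grid above covers the output (shape values assumed): only seqlen_q is tiled by BLOCK_M across the grid, while seqlen_k is walked in BLOCK_N steps inside the kernel, with boundary_check handling the ragged tail.

# Hedged example: a (BATCH=2, N_HEADS=4, seqlen_q=200, seqlen_k=80) tensor launches a
# (cdiv(200, 64), 4, 2) == (4, 4, 2) grid; each program then loops over seqlen_k=80 in
# BLOCK_N=32 steps.
import triton
assert triton.cdiv(200, 64) == 4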
37 changes: 37 additions & 0 deletions tritonsrc/dropout_rng.py
@@ -0,0 +1,37 @@
#!/usr/bin/env python
# Copyright © 2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import triton
import triton.language as tl
from fwd_kernel import dropout_rng

@triton.jit
def debug_fill_dropout_rng(R,
                           stride_rz, stride_rh, stride_rm, stride_rn,
                           seqlen_q, seqlen_k,
                           philox_seed,
                           philox_offset_base,
                           BLOCK_M: tl.constexpr,
                           BLOCK_N: tl.constexpr,
                           ):
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)  # head index
    off_z = tl.program_id(2)  # batch index
    d_offset = off_h * stride_rh + off_z * stride_rz
    num_h = tl.num_programs(1)
    off_zh = off_z * num_h + off_h * 1
    batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k
    R_block_ptr = tl.make_block_ptr(
        base=R + d_offset,
        shape=(seqlen_q, seqlen_k),
        strides=(stride_rm, stride_rn),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )
    for start_n in range(0, seqlen_k, BLOCK_N):
        philox_offset = batch_philox_offset + start_m * BLOCK_M * seqlen_k + start_n
        rng = dropout_rng(philox_seed, philox_offset, BLOCK_M, BLOCK_N, seqlen_k)
        tl.store(R_block_ptr, rng.to(R_block_ptr.type.element_ty), boundary_check=(0,1))
        R_block_ptr = tl.advance(R_block_ptr, (0, BLOCK_N))
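For reference, the Philox offsets produced by the loop above reduce to a row-major layout per (batch, head) slice, matching dropout_offsets() in fwd_kernel.py with stride = seqlen_k. A minimal NumPy sketch of that layout (the RNG values themselves come from Triton's Philox-based tl.rand and are not reproduced here):

# Offset layout only; element (m, n) of slice (off_z, off_h) gets base + m * seqlen_k + n.
import numpy as np

def reference_offsets(philox_offset_base, off_z, off_h, num_h, seqlen_q, seqlen_k):
    base = philox_offset_base + (off_z * num_h + off_h) * seqlen_q * seqlen_k
    ms = np.arange(seqlen_q)[:, None]
    ns = np.arange(seqlen_k)[None, :]
    return base + ms * seqlen_k + ns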
1 change: 1 addition & 0 deletions tritonsrc/flash.py
@@ -20,3 +20,4 @@
from fwd_kernel import attn_fwd
from bwd_preprocess import bwd_preprocess
from bwd_split_kernel import bwd_kernel_dk_dv, bwd_kernel_dq
from dropout_rng import debug_fill_dropout_rng
8 changes: 4 additions & 4 deletions tritonsrc/fwd_kernel.py
@@ -23,20 +23,20 @@ def max_fn(x, y):
return tl.math.max(x, y)

@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
def dropout_offsets(philox_seed, philox_offset, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]

@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
def dropout_rng(philox_seed, philox_offset, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)

@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_output = dropout_rng(philox_seed, philox_offset, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep

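The change above only drops the unused dropout_p argument from dropout_offsets() and dropout_rng(): the RNG values depend solely on the Philox seed and offsets, and dropout_mask() still thresholds them against dropout_p. A minimal stand-in sketch, with NumPy values in place of tl.rand:

import numpy as np
rng_output = np.random.default_rng(0).random((4, 8))  # stand-in for tl.rand(philox_seed, rng_offsets)
dropout_p = 0.5
rng_keep = rng_output > dropout_p                      # mirrors the comparison in dropout_mask()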