diff --git a/test/test_forward.py b/test/test_forward.py index 8a93d82..bae5c65 100644 --- a/test/test_forward.py +++ b/test/test_forward.py @@ -51,6 +51,15 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask print(f'{av[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') return av, attn_weight +def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None, device=None): + """ Clones the query, key, and value tensors and moves them to the specified dtype. """ + if dtype is None: + dtype = query.dtype + query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad) + key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad) + value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad) + return query_ref, key_ref, value_ref + ''' Flash Attention is batch operator that evaluates sm(QK')V Q = batch_size x ... x seqlen_q x head_size @@ -72,8 +81,8 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask @pytest.mark.parametrize('D_HEAD', [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256]) # @pytest.mark.parametrize('seqlen_q', [16,32,64,128,256,512,1024]) # @pytest.mark.parametrize('seqlen_k', [16,32,64,128,256,512,1024]) -@pytest.mark.parametrize('seqlen_q', [4,8,16,17,32,64,128,143,256,512,1024,2048]) -@pytest.mark.parametrize('seqlen_k', [4,8,16,23,32,64,128,256,512,587,1024,2048]) +@pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) +@pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) # @pytest.mark.parametrize('seqlen_q', [32, 128]) # @pytest.mark.parametrize('seqlen_k', [32, 128]) @pytest.mark.parametrize('causal', [False, True]) @@ -81,7 +90,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask # @pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) -@pytest.mark.parametrize('storage_flip', [True, False]) +@pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip): if causal and seqlen_q != seqlen_k: @@ -160,11 +169,16 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr ref_out = torch.matmul(p, v) ''' return_encoded_softmax = dropout_p > 0.0 + higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 + REF_DEVICE=None + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=higher_precision_dtype, device=REF_DEVICE) + def TO(ref_tensor): + return ref_tensor.to(device=q.device, dtype=dtype) tri_out, encoded_softmax, _ = attention(q, k, v, causal, sm_scale, dropout_p, return_encoded_softmax) dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None # assert torch.allclose(dropout_mask, dropout_mask_naive) - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, dropout_p=dropout_p, is_causal=causal, scale=sm_scale, @@ -192,17 +206,18 @@ def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dr ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) else: ATOL = 1e-2 * max(1.0, (seqlen_q 
+ seqlen_k + D_HEAD) / 128.0) - print(f'Using ATOL={ATOL}') - is_allclose = torch.allclose(ref_out, tri_out, atol=ATOL, rtol=0) + RTOL = 0.0 + print(f'Using ATOL={ATOL} RTOL={RTOL}') + is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) if not is_allclose: import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) + err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) print(f'{err_idx=}') print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') # if not is_allclose: if False: import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) + err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) print(f'{tri_out[0][0][0][:]=}') print(f'{ref_out[0][0][0][:]=}') print(f'{mref_out[0][0][0][:]=}') diff --git a/tritonsrc/fwd_kernel.py b/tritonsrc/fwd_kernel.py index ce706d7..68da243 100644 --- a/tritonsrc/fwd_kernel.py +++ b/tritonsrc/fwd_kernel.py @@ -41,107 +41,73 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): return rng_keep @triton.jit -def _attn_fwd_inner( +def attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - bias_ptr, start_m, - seqlen_k, - seqlen_k_faligned, - # FIXME: The usage of seqlen_k and seqlen_k_faligned was compromised, and - # the fix is not straightforward. - # dropout_seqlen_k was added as a quick fix. - dropout_seqlen_k, + seqlen_q, + seqlen_k_low, + seqlen_k_high, + k_padded, dropout_p, + dropout_seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - STAGE: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, + CAUSAL: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + pre_load_v: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_BLOCK: tl.constexpr, + MARGINAL_BLOCK: tl.constexpr, # MARGINAL_BLOCK = CAUSAL or k_padded PADDED_HEAD: tl.constexpr, ): - # range of values handled by this stage - if STAGE == 1: # "Solid" blocks of Causal masks - lo, hi = 0, min(seqlen_k, start_m * BLOCK_M) - elif STAGE == 2: # "Semi-solid", or "Transition" block of Causal mask - # Must use BLOCK_M, because the starting position of semi-solid block - # is determined by start_m * BLOCK_M - lo, hi = start_m * BLOCK_M, min(seqlen_k, start_m * BLOCK_M + BLOCK_M) - # lo = tl.multiple_of(lo, BLOCK_M) + lo, hi = seqlen_k_low, seqlen_k_high + if MARGINAL_BLOCK: K_block_ptr = tl.advance(K_block_ptr, (0, lo)) V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo)) - tl.static_assert(PADDED_BLOCK == False, 'STAGE=2 should not be used with PADDED_BLOCK=True') - # So here, we are computing the elements for that last irregular block. - # In the loop, we will mask the elements of BLOCK_N that do not exist. 
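# --- Editor's illustration (not part of the patch): `k_ranges` below is a
# hypothetical plain-Python mirror of the range computation the new attn_fwd
# performs before each attn_fwd_inner call, replacing the STAGE/PADDED_BLOCK
# branches removed here. Variable names follow the kernel; the helper itself
# is a reader's sketch, not code used anywhere in the repository.
def k_ranges(start_m, seqlen_k, BLOCK_M, BLOCK_N, causal):
    # Floor-align seqlen_k to BLOCK_N, as the new attn_fwd prologue does.
    if seqlen_k < BLOCK_N:
        k_padded, seqlen_k_faligned = True, 0
    elif seqlen_k % BLOCK_N:
        k_padded, seqlen_k_faligned = True, seqlen_k - seqlen_k % BLOCK_N
    else:
        k_padded, seqlen_k_faligned = False, seqlen_k
    # First attn_fwd_inner call: "solid" blocks, MARGINAL_BLOCK=False, no masking.
    hi1 = min(seqlen_k_faligned, start_m * BLOCK_M) if causal else seqlen_k_faligned
    solid = (0, hi1)
    # Second call: "marginal" blocks (causal diagonal and/or ragged K tail),
    # MARGINAL_BLOCK=True, taken only when CAUSAL or k_padded.
    marginal = None
    if causal or k_padded:
        hi2 = min(seqlen_k, start_m * BLOCK_M + BLOCK_M) if causal else seqlen_k
        marginal = (hi1, hi2)
    return solid, marginal

# Example: BLOCK_M = BLOCK_N = 64, seqlen_k = 587, causal, query block 3
# -> solid (0, 192), marginal (192, 256); the ragged tail 576..587 is visited
#    by later query blocks once start_m * BLOCK_M reaches 576.
print(k_ranges(3, 587, 64, 64, True))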
- elif PADDED_BLOCK: - lo, hi = seqlen_k, seqlen_k + BLOCK_N - # lo = tl.multiple_of(lo, BLOCK_N) - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo)) - if bias_ptr is not None: - if bias_ptr.type.element_ty.is_block(): - bias_ptr = tl.advance(bias_ptr, (0, lo)) - else: - bias_ptr += lo - # causal = False - else: - lo, hi = 0, seqlen_k # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): - ''' - if STAGE == 1 or STAGE == 3: - start_n = tl.multiple_of(start_n, BLOCK_N) - ''' - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - if (PADDED_BLOCK or STAGE == 2) or PADDED_HEAD: - k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero") + # -- compute qk ---- + # MARGINAL_BLOCK serves as a compile-time switch for first attn_fwd_inner calls to "solid" blocks + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero") + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") else: - k = tl.load(K_block_ptr, boundary_check=(0,), padding_option="zero") - if PRE_LOAD_V: - if (PADDED_BLOCK or STAGE == 2) or PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + if PADDED_HEAD: + k = tl.load(K_block_ptr, boundary_check=(0,), padding_option="zero") else: - v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") - # -- compute qk ---- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if STAGE == 2: - mask = OFFS_M[:, None] >= (start_n + OFFS_N[None, :]) - qk = tl.where(mask, qk, float("-inf")) - if PADDED_BLOCK: - boundary_m = tl.full([BLOCK_M], seqlen_k_faligned, dtype=tl.int32) - size_n = start_n + OFFS_N[None,:] - mask = size_n < boundary_m[:,None] - qk = tl.where(mask, qk, float("-inf")) - qk += tl.dot(q, k) - if bias_ptr is not None: - if PADDED_BLOCK: - if bias_ptr.type.element_ty.is_block(): - bias = tl.load(bias_ptr,boundary_check=(1,), padding_option="zero") + k = tl.load(K_block_ptr) + if pre_load_v: + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") else: - size_n = start_n + OFFS_N - boundary_n = tl.full([BLOCK_N], seqlen_k_faligned, dtype=tl.float32) - bias_padding = tl.full([BLOCK_N], 0, dtype=tl.float32) - bias = tl.load(bias_ptr, mask=size_n < boundary_n, other=bias_padding) + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") else: - bias = tl.load(bias_ptr) - # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. 
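# --- Editor's illustration (not part of the patch): a NumPy check of the
# log2(e) trick the removed comment above refers to. The kernel evaluates
# exp2(log2(e) * x) in place of exp(x), which is why sm_scale (and, before
# this change dropped bias support, the bias as well) is pre-multiplied by
# 1.44269504089. Plain NumPy, illustration only.
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 8))          # stand-in for a q @ k' tile
bias = rng.standard_normal((4, 8))
sm_scale, LOG2E = 1.2, 1.44269504089

ref = np.exp(scores * sm_scale + bias)
via_exp2 = np.exp2(scores * (sm_scale * LOG2E) + bias * LOG2E)
assert np.allclose(ref, via_exp2)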
- qk += (bias * 1.44269504089) - + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") + else: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if MARGINAL_BLOCK: + if CAUSAL: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = tl.where(mask, qk, float("-inf")) + if k_padded: + boundary_m = tl.full([BLOCK_M], seqlen_k_high, dtype=tl.int32) + size_n = start_n + offs_n[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -149,7 +115,7 @@ def _attn_fwd_inner( l_ij = tl.sum(p, 1) # Note about the conflicts of Flash attention algorithm and PyTorch's CUDA implementation # PyTorch needs to return softmax(qk) (dropout mask encoded in sign bits) - # While Flash attention paper compute the dropout AFTER exp2(qk- m_ij) + # While Flash attention paper computer the dropout AFTER exp2(qk- m_ij) if ENABLE_DROPOUT: philox_offset = batch_philox_offset + start_m * BLOCK_M * dropout_seqlen_k + start_n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, dropout_seqlen_k) @@ -157,15 +123,23 @@ def _attn_fwd_inner( tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), boundary_check=(0,1)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty), boundary_check=(0,1)) + tl.store(encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + boundary_check=(0,1)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not PRE_LOAD_V: - if (PADDED_BLOCK or STAGE == 2) or PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + if not pre_load_v: + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") else: - v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") + else: + v = tl.load(V_block_ptr) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i @@ -173,102 +147,75 @@ def _attn_fwd_inner( acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - if bias_ptr.type.element_ty.is_block(): - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - else: - bias_ptr += BLOCK_N if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i + @triton.jit def attn_fwd( - Q, K, V, B, sm_scale, M, Out, + Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, - stride_bz, stride_bh, stride_bm, stride_bn, - cu_seqlens_q, cu_seqlens_k, - max_seqlens_q, max_seqlens_k, - head_dim_q, head_dim_k, + seqlen_q, + seqlen_k, + head_dim, dropout_p, philox_seed, philox_offset_base, encoded_softmax, - VARLEN: tl.constexpr, STAGE: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, + pre_load_v: tl.constexpr, ENABLE_DROPOUT: 
tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, - BIAS_TYPE: tl.constexpr, - PADDED_HEAD: tl.constexpr, # Cannot be inferred by AOT Compiler + PADDED_HEAD: tl.constexpr, ): - # FIXME: MQA should be num_heads_q != num_heads_k. - is_mqa = head_dim_q != head_dim_k start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index num_h = tl.num_programs(1) num_z = tl.num_programs(2) - if VARLEN: - cu_seqlens_q_this = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_next = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_next - cu_seqlens_q_this - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_this = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_next = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_next - cu_seqlens_k_this + if start_m * BLOCK_M + BLOCK_M > seqlen_q: + q_padded = True else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = max_seqlens_q - seqlen_k = max_seqlens_k - if is_mqa: - off_h_k = off_h_q % head_dim_k - else: - off_h_k = off_h_q - need_padding = False + q_padded = False + k_padded = True if seqlen_k < BLOCK_N: - need_padding = True - extra_tokens_n = BLOCK_N - seqlen_k seqlen_k_faligned = 0 # floor aligned elif seqlen_k % BLOCK_N: - need_padding = True extra_tokens_n = seqlen_k % BLOCK_N seqlen_k_faligned = seqlen_k - extra_tokens_n else: + k_padded = False seqlen_k_faligned = seqlen_k - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_offset = off_h * stride_qh + off_z * stride_qz Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, - shape=(seqlen_q, head_dim_q), + shape=(seqlen_q, head_dim), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_offset = off_h * stride_kh + off_z * stride_kz K_block_ptr = tl.make_block_ptr( base=K + k_offset, - shape=(head_dim_k, seqlen_k), + shape=(head_dim, seqlen_k), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_offset = off_h * stride_vh + off_z * stride_vz V_block_ptr = tl.make_block_ptr( base=V + v_offset, - shape=(seqlen_k, head_dim_k), + shape=(seqlen_k, head_dim), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -277,21 +224,6 @@ def attn_fwd( # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - off_zh = off_z * num_h + off_h_q * 1 - if BIAS_TYPE != 0: - if BIAS_TYPE == 1: - bias_ptr = B + off_h_q * stride_bh + offs_n - elif BIAS_TYPE == 2: - bias_ptr = tl.make_block_ptr( - base=B + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 @@ -301,19 +233,25 @@ def attn_fwd( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504089 # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs - q = tl.load(Q_block_ptr, 
boundary_check=(0,1), padding_option="zero") + if q_padded: + if PADDED_HEAD: + q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + else: + if PADDED_HEAD: + q = tl.load(Q_block_ptr, boundary_check=(1,), padding_option="zero") + else: + q = tl.load(Q_block_ptr) q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - # stage 1: off-band - # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + # For causal = True, STAGE = 3 and attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and attn_fwd_inner gets 3 as its STAGE + off_zh = off_z * num_h + off_h * 1 if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. In - # this case, we return an invalid pointer so indicate the mask is not valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.make_block_ptr( base=encoded_softmax + off_zh * seqlen_q * seqlen_k, @@ -325,53 +263,52 @@ def attn_fwd( ) else: encoded_softmax_block_ptr = 0 - if STAGE & 1: - # equal to N_CTX if N_CTX is already a multiple of block_M - if seqlen_k >= BLOCK_N: - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - bias_ptr, - start_m, seqlen_k_faligned, seqlen_k_faligned, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, - PRE_LOAD_V, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_BLOCK=False, - PADDED_HEAD=PADDED_HEAD, - ) - tl.debug_barrier() - if need_padding: - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - bias_ptr, - start_m, seqlen_k_faligned, seqlen_k, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, - PRE_LOAD_V, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - PADDED_BLOCK=True, - PADDED_HEAD=PADDED_HEAD, - ) - # stage 2: on-band - if STAGE & 2: + + if STAGE == 3: + CAUSAL = True + else: + CAUSAL = False + # Stage 1: off-band (for causal) or non-boundary (for irregular seqlen_k) blocks + if CAUSAL: + # Causal = True + seqlen_k_low = 0 + seqlen_k_high = min(seqlen_k_faligned, start_m * BLOCK_M) + else: + # Causal = False + seqlen_k_low = 0 + seqlen_k_high = seqlen_k_faligned + acc, l_i, m_i = attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_q, seqlen_k_low, seqlen_k_high, False, + dropout_p, seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + False, offs_m, offs_n, + pre_load_v, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + MARGINAL_BLOCK=False, + PADDED_HEAD=PADDED_HEAD, + ) + # Stage 2: on-band or boundary blocks + if CAUSAL or k_padded: + seqlen_k_low = seqlen_k_high + if CAUSAL: + seqlen_k_high = min(seqlen_k, start_m * BLOCK_M + BLOCK_M) + else: + seqlen_k_high = seqlen_k # barrier makes it easier for compielr to schedule the # two loops independently tl.debug_barrier() - acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i = attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - bias_ptr, - start_m, seqlen_k, seqlen_k_faligned, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, 
encoded_softmax_block_ptr, + start_m, seqlen_q, seqlen_k_low, seqlen_k_high, k_padded, + dropout_p, seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, - PRE_LOAD_V, + CAUSAL, offs_m, offs_n, + pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, - PADDED_BLOCK=False, + MARGINAL_BLOCK=True, PADDED_HEAD=PADDED_HEAD, ) # epilogue @@ -379,10 +316,10 @@ def attn_fwd( acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) - m_ptrs = M + off_zh * max_seqlens_q + offs_m + m_ptrs = M + off_zh * seqlen_q + offs_m # Check for last block_M - overflow_size = (start_m * BLOCK_M + BLOCK_M) - seqlen_q - if overflow_size > 0: + if q_padded: + overflow_size = (start_m * BLOCK_M + BLOCK_M) - seqlen_q boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) # This is a > check because mask being 0 blocks the store. m_ptrs_mask = boundary > tl.arange(0, BLOCK_M) @@ -390,13 +327,22 @@ def attn_fwd( else: tl.store(m_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + o_offset = off_h * stride_oh + off_z * stride_oz O_block_ptr = tl.make_block_ptr( base=Out + o_offset, - shape=(seqlen_q, head_dim_q), + shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) # Don't exceed shape, makes sure padding isn't put in output. + if q_padded: + if PADDED_HEAD: + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + else: + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,)) + else: + if PADDED_HEAD: + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(1,)) + else: + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) diff --git a/v2python/rules/flash/attn_fwd.py b/v2python/rules/flash/attn_fwd.py index 4df5f08..434d027 100644 --- a/v2python/rules/flash/attn_fwd.py +++ b/v2python/rules/flash/attn_fwd.py @@ -5,54 +5,49 @@ class attn_fwd(FlashKernel): ARGUMENTS = [ - 'Q', 'K', 'V', 'B', 'sm_scale', 'M', 'Out', + 'Q', 'K', 'V', 'sm_scale', 'M', 'Out', 'stride_qz', 'stride_qh', 'stride_qm', 'stride_qk', 'stride_kz', 'stride_kh', 'stride_kn', 'stride_kk', 'stride_vz', 'stride_vh', 'stride_vk', 'stride_vn', 'stride_oz', 'stride_oh', 'stride_om', 'stride_on', - 'stride_bz', 'stride_bh', 'stride_bm', 'stride_bn', - 'cu_seqlens_q', 'cu_seqlens_k', - 'seqlen_q', 'seqlen_k', # Note: they were renamed to max_seqlens_q/k respectively, we kept it untouched for backward compatibility with tuning database - 'head_dim_q', 'head_dim_k', + 'seqlen_q', 'seqlen_k', + 'head_dim', 'dropout_p', 'philox_seed', 'philox_offset_base', 'encoded_softmax', - 'VARLEN', # tl.constexpr starts here - 'STAGE', + 'STAGE', # tl.constexpr starts here 'BLOCK_M', 'BLOCK_DMODEL', 'BLOCK_N', - 'pre_load_v', # TODO: kernel uses PRE_LOAD_V. 
We use this to keep backward compatibility + 'pre_load_v', 'ENABLE_DROPOUT', 'RETURN_ENCODED_SOFTMAX', - 'BIAS_TYPE', 'PADDED_HEAD', ] TENSOR_STRIDE_INPUTS = { 'Q' : select_pattern(ARGUMENTS, 'stride_q'), 'K' : select_pattern(ARGUMENTS, 'stride_k'), 'V' : select_pattern(ARGUMENTS, 'stride_v'), - 'B' : select_pattern(ARGUMENTS, 'stride_b'), 'Out' : select_pattern(ARGUMENTS, 'stride_o'), } TYPE_CHOICES = { - frozenset(['Q', 'K', 'V', 'B', 'Out', 'encoded_softmax']) : ['*fp16:16', '*bf16:16'], + frozenset(['Q', 'K', 'V', 'Out', 'encoded_softmax']) : ['*fp16:16', '*bf16:16'], frozenset(['sm_scale']) : ['fp32'], frozenset(['M']) : ['*fp32:16'], - frozenset(['cu_seqlens_q', 'cu_seqlens_k']) : ['*u32:16'], - frozenset(['seqlen_q', 'seqlen_k', 'head_dim_q', 'head_dim_k']) : ['i32'], + # frozenset(select_pattern(ARGUMENTS, 'stride_', trim=1)) : ['u64'], + # frozenset(select_pattern(ARGUMENTS, 'stride_', trim=1)) : ['u64'], + frozenset(['seqlen_q', 'seqlen_k']) : ['i32'], + frozenset(['head_dim']) : ['u64'], frozenset(['dropout_p']) : ['fp32'], frozenset(['philox_seed']) : ['u64'], frozenset(['philox_offset_base']) : ['u32'], } FEAT_CHOICES = { - frozenset(['VARLEN']) : [False], # TODO: support varlen frozenset(['STAGE']) : [1, 3], frozenset(['BLOCK_DMODEL']) : [16, 32, 64, 128, 256], frozenset(['ENABLE_DROPOUT']) : [True, False], frozenset(['RETURN_ENCODED_SOFTMAX']) : [True, False], - frozenset(['BIAS_TYPE']) : [0], # TODO: support bias frozenset(['PADDED_HEAD']) : [True, False], } PERF_CHOICES = { @@ -62,14 +57,11 @@ class attn_fwd(FlashKernel): } TENSOR_RANKS = { '_default' : 4, - 'M' : 2, - 'cu_seqlens_q' : 1, - 'cu_seqlens_k' : 1, + 'M': 2, } EXPECTED_IDENTICAL_TENSOR_STRIDES = [ # Not needed stride_o* exist ] - # LAUNCHER_PARAMETERS is not used LAUNCHER_PARAMETERS = [ 'Q', 'K', 'V', 'sm_scale', 'M', 'Out', # Basic functions 'dropout_p', 'philox_seed', 'philox_offset', 'encoded_softmax', # dropout @@ -83,7 +75,6 @@ class attn_fwd(FlashKernel): 'seqlen_k' : BinningLessOrEqual, 'STAGE' : BinningExact, } - # List of functionals that are not fully tuned in the tuning database # First element of the tuple is name. Second is the value to use instead PARTIALLY_TUNED_FUNCTIONALS = [('RETURN_ENCODED_SOFTMAX', False), ('PADDED_HEAD', None)] diff --git a/v2src/flash/attn_fwd.cc b/v2src/flash/attn_fwd.cc index 5e5de98..c90bc2b 100644 --- a/v2src/flash/attn_fwd.cc +++ b/v2src/flash/attn_fwd.cc @@ -6,7 +6,6 @@ #include #include #include -#include #ifdef NDEBUG #define AOTRITON_VERBOSE 0 @@ -50,40 +49,30 @@ attn_fwd(T4 q, #endif return grid; }; - T1 empty_tensor_1(0, {0}, {0}, aotriton::kUInt32); - T4 empty_tensor_4(0, {0,0,0,0}, {0,0,0,0}, q.dtype()); int seqlen_q = q.size(2); int seqlen_k = k.size(2); - int head_dim_q = q.size(3); - int head_dim_k = k.size(3); + int head_size = q.size(3); + int head_dim_rounded = std::max(16, aotriton::bit_ceil(head_size)); // Requires C++ 20 - int head_dim_rounded = aotriton::bit_ceil(head_dim_q); - // Also requires C++ 20 AttnFwdParams params = { .Q = &q, .K = &k, .V = &v, - .B = &empty_tensor_4, .Out = &out, .encoded_softmax = &encoded_softmax, .sm_scale = sm_scale, .M = &softmax_lse, - .cu_seqlens_q = &empty_tensor_1, - .cu_seqlens_k = &empty_tensor_1, .seqlen_q = seqlen_q, .seqlen_k = seqlen_k, - .head_dim_q = head_dim_q, - .head_dim_k = head_dim_k, + .head_dim = static_cast(head_size), .dropout_p = dropout_p, .philox_seed = philox_seed, .philox_offset_base = static_cast(philox_offset), - .VARLEN = false, .STAGE = is_causal ? 
kUseCausalBits : kNoCausalBits, .BLOCK_DMODEL = head_dim_rounded, .ENABLE_DROPOUT = dropout_p > 0.0, .RETURN_ENCODED_SOFTMAX = bool(encoded_softmax), - .BIAS_TYPE = 0, - .PADDED_HEAD = head_dim_rounded != head_dim_q, + .PADDED_HEAD = head_dim_rounded != head_size, }; AttnFwdContext context; context.grid_calculator = grid_calculator;
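Editor's note: the sketch below is a plain-Python mirror of the host-side head-dim handling added above; it is not part of the patch. `round_head_dim` is a hypothetical helper, `std::bit_ceil` is emulated with integer bit tricks (assumed equivalent for head_size >= 1), and the D_HEAD list comes from the test parametrization in test_forward.py.

# Mirrors: head_dim_rounded = std::max(16, aotriton::bit_ceil(head_size));
#          PADDED_HEAD      = head_dim_rounded != head_size;
def round_head_dim(head_size: int):
    head_dim_rounded = max(16, 1 << (head_size - 1).bit_length())  # bit_ceil for head_size >= 1
    return head_dim_rounded, head_dim_rounded != head_size

# Every non-power-of-two D_HEAD in the test sweep (21, 72, 96, 160, 192, 203)
# gets PADDED_HEAD=True, and so does 8, which is rounded up to the floor of 16.
for d in (8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256):
    print(d, round_head_dim(d))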