diff --git a/README.md b/README.md
index 7ec3fba..52bb906 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ _Now, you can process **1M context 10x faster in a single A100** using Long-cont
 ### Requirements
 
 - Torch
-- FlashAttention-2
+- FlashAttention-2 (Optional)
 - Triton == 2.1.0
 
 To get started with MInference, simply install it using pip:
diff --git a/minference/modules/inf_llm.py b/minference/modules/inf_llm.py
index f247dec..52c7d63 100644
--- a/minference/modules/inf_llm.py
+++ b/minference/modules/inf_llm.py
@@ -6,9 +6,15 @@
 from typing import Optional, Tuple
 
 import torch
-from flash_attn import flash_attn_func
 from transformers.modeling_outputs import CausalLMOutput
 
+try:
+    from flash_attn import flash_attn_func as dense_decoding_func
+except ImportError:
+    from ..ops.flash_attn_triton import (
+        _flash_attn_triton_decoding as dense_decoding_func,
+    )
+
 from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention
@@ -1084,7 +1090,7 @@ def forward(
 
     h_v = h_v.transpose(1, 2)  # (batch_size, seqlen, nheads, headdim)
 
-    o = flash_attn_func(h_q, h_k, h_v, causal=True)
+    o = dense_decoding_func(h_q, h_k, h_v)
 
     o = o.reshape(batch_size, len_q, dim_head * num_heads)
     o = attention_out(o)
diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index 0ef502e..6db50cb 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -12,15 +12,20 @@
 if _is_package_available("vllm"):
     try:
-        from vllm.attention.backends.flash_attn import *
+        from vllm.attention.ops.paged_attn import PagedAttention
     except:
-        warnings.warn("Only support 'vllm==0.4.1'. Please update your vllm version.")
+        warnings.warn("Only support 'vllm>=0.4.0'. Please update your vllm version.")
 
 from ..ops.block_sparse_flash_attention import block_sparse_attention
 from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
 from ..ops.streaming_kernel import streaming_forward, streaming_forward2
 from .snap_kv import *
 
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    from ..ops.flash_attn_triton import _flash_attn_triton_decoding as flash_attn_func
+
 last_q = 64
 arange = torch.arange(last_q, device="cuda")
 LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
@@ -1113,7 +1118,7 @@ def forward_vllm_043(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: torch.Tensor,
-    attn_metadata: FlashAttentionMetadata,
+    attn_metadata,
     kv_scale: float,
     layer_idx: int,
 ) -> torch.Tensor:
diff --git a/minference/ops/block_sparse_flash_attention.py b/minference/ops/block_sparse_flash_attention.py
index 17a6a8f..b651189 100644
--- a/minference/ops/block_sparse_flash_attention.py
+++ b/minference/ops/block_sparse_flash_attention.py
@@ -5,8 +5,8 @@
 import torch
 import triton
 import triton.language as tl
-from flash_attn import flash_attn_varlen_func
+# from flash_attn import flash_attn_varlen_func
 
 # import pycuda.autoprimaryctx
 # from pycuda.compiler import SourceModule
diff --git a/minference/ops/flash_attn_triton.py b/minference/ops/flash_attn_triton.py
new file mode 100644
index 0000000..e9ac3ba
--- /dev/null
+++ b/minference/ops/flash_attn_triton.py
@@ -0,0 +1,339 @@
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.heuristics(
+    {
+        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
+        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
+    }
+)
+@triton.jit
+def _fwd_kernel(
+    Q,
+    K,
+    V,
+    Bias,
+    Out,
+    Lse,
+    TMP,
+    softmax_scale,
+    stride_qb,
+    stride_qh,
+    stride_qm,
+    stride_kb,
+    stride_kh,
+    stride_kn,
+    stride_vb,
+    stride_vh,
+    stride_vn,
+    stride_bb,
+    stride_bh,
+    stride_bm,
+    stride_ob,
+    stride_oh,
+    stride_om,
+    nheads,
+    seqlen_q,
+    seqlen_k,
+    seqlen_q_rounded,
+    headdim,
+    CACHE_KEY_SEQLEN_Q,
+    CACHE_KEY_SEQLEN_K,
+    BIAS_TYPE: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
+    BLOCK_HEADDIM: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+    EVEN_HEADDIM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+    off_hb = tl.program_id(1)
+    off_b = off_hb // nheads
+    off_h = off_hb % nheads
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_HEADDIM)
+    q_ptrs = (
+        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
+    )
+    k_ptrs = (
+        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
+    )
+    v_ptrs = (
+        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
+    )
+    if BIAS_TYPE == "vector":
+        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
+    elif BIAS_TYPE == "matrix":
+        b_ptrs = (
+            Bias
+            + off_b * stride_bb
+            + off_h * stride_bh
+            + (offs_m[:, None] * stride_bm + offs_n[None, :])
+        )
+    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
+    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    if EVEN_M & EVEN_N:
+        if EVEN_HEADDIM:
+            q = tl.load(q_ptrs)
+        else:
+            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+    else:
+        if EVEN_HEADDIM:
+            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
+        else:
+            q = tl.load(
+                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
+            )
+    # loop over k, v and update accumulator
+    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+    for start_n in range(0, end_n, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        # -- compute qk ----
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
+            if EVEN_HEADDIM:
+                k = tl.load(k_ptrs + start_n * stride_kn)
+            else:
+                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
+        else:
+            if EVEN_HEADDIM:
+                k = tl.load(
+                    k_ptrs + start_n * stride_kn,
+                    mask=(start_n + offs_n)[:, None] < seqlen_k,
+                    other=0.0,
+                )
+            else:
+                k = tl.load(
+                    k_ptrs + start_n * stride_kn,
+                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other=0.0,
+                )
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk += tl.dot(q, tl.trans(k))
+        # Trying to combine the two masks seem to make the result wrong
+        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
+            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
+        if IS_CAUSAL:
+            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
+        if BIAS_TYPE != "none":
+            if BIAS_TYPE == "vector":
+                if EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(
+                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
+                    ).to(tl.float32)
+                bias = bias[None, :]
+            elif BIAS_TYPE == "matrix":
+                if EVEN_M & EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(
+                        b_ptrs + start_n,
+                        mask=(offs_m[:, None] < seqlen_q)
+                        & ((start_n + offs_n)[None, :] < seqlen_k),
+                        other=0.0,
+                    ).to(tl.float32)
+            qk = qk * softmax_scale + bias
+            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
+            p = tl.exp(qk - m_ij[:, None])
+        else:
+            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+            p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        l_ij = tl.sum(p, 1)
+
+        # scale acc_o
+        acc_o_scale = tl.exp(m_i - m_ij)
+
+        # # -- update output accumulator --
+        tl.store(t_ptrs, acc_o_scale)
+        acc_o_scale = tl.load(t_ptrs)
+        acc_o = acc_o * acc_o_scale[:, None]
+        # update acc_o
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
+            if EVEN_HEADDIM:
+                v = tl.load(v_ptrs + start_n * stride_vn)
+            else:
+                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
+        else:
+            if EVEN_HEADDIM:
+                v = tl.load(
+                    v_ptrs + start_n * stride_vn,
+                    mask=(start_n + offs_n)[:, None] < seqlen_k,
+                    other=0.0,
+                )
+            else:
+                v = tl.load(
+                    v_ptrs + start_n * stride_vn,
+                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other=0.0,
+                )
+        p = p.to(v.dtype)
+        acc_o += tl.dot(p, v)
+
+        # -- update statistics
+        m_i = m_ij
+        l_i_new = tl.exp(lse_i - m_ij) + l_ij
+        lse_i = m_ij + tl.log(l_i_new)
+
+    o_scale = tl.exp(m_i - lse_i)
+    tl.store(t_ptrs, o_scale)
+    o_scale = tl.load(t_ptrs)
+    acc_o = acc_o * o_scale[:, None]
+    start_m = tl.program_id(0)
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
+    tl.store(lse_ptrs, lse_i)
+    offs_d = tl.arange(0, BLOCK_HEADDIM)
+    out_ptrs = (
+        Out
+        + off_b * stride_ob
+        + off_h * stride_oh
+        + (offs_m[:, None] * stride_om + offs_d[None, :])
+    )
+    if EVEN_M:
+        if EVEN_HEADDIM:
+            tl.store(out_ptrs, acc_o)
+        else:
+            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
+    else:
+        if EVEN_HEADDIM:
+            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
+        else:
+            tl.store(
+                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
+            )
+
+
+def _flash_attn_triton_decoding(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
+    # shape constraints
+    batch, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, _, _ = k.shape
+    assert k.shape == (batch, seqlen_k, nheads, d)
+    assert v.shape == (batch, seqlen_k, nheads, d)
+    assert d <= 128, "FlashAttention only support head dimensions up to 128"
+    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
+    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
+    assert q.is_cuda and k.is_cuda and v.is_cuda
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(d)
+    bias = None
+
+    has_bias = bias is not None
+    bias_type = "none"
+    if has_bias:
+        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.is_cuda
+        assert bias.dim() == 4
+        if bias.stride(-1) != 1:
+            bias = bias.contiguous()
+        if bias.shape[2:] == (1, seqlen_k):
+            bias_type = "vector"
+        elif bias.shape[2:] == (seqlen_q, seqlen_k):
+            bias_type = "matrix"
+        else:
+            raise RuntimeError(
+                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
+            )
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
+    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
+
+    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
+    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+    o = torch.empty_like(q)
+
+    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+    BLOCK = 128
+    num_warps = 4 if d <= 64 else 8
+    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
+    _fwd_kernel[grid](
+        q,
+        k,
+        v,
+        bias,
+        o,
+        lse,
+        tmp,
+        softmax_scale,
+        q.stride(0),
+        q.stride(2),
+        q.stride(1),
+        k.stride(0),
+        k.stride(2),
+        k.stride(1),
+        v.stride(0),
+        v.stride(2),
+        v.stride(1),
+        *bias_strides,
+        o.stride(0),
+        o.stride(2),
+        o.stride(1),
+        nheads,
+        seqlen_q,
+        seqlen_k,
+        seqlen_q_rounded,
+        d,
+        seqlen_q // 32,
+        seqlen_k // 32,  # key for triton cache (limit number of compilations)
+        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
+        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+        bias_type,
+        causal,
+        BLOCK_HEADDIM,
+        BLOCK_M=BLOCK,
+        BLOCK_N=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return o  # , lse, softmax_scale
+
+
+def torch_decoding(q, k, v):
+    # Pure-PyTorch reference used only to test the Triton kernel above.
+    def repeat_kv(kv, num_groups):
+        # Expand KV heads for grouped-query attention: KV head i must serve the
+        # contiguous block of query heads [i * num_groups, (i + 1) * num_groups).
+        return kv.repeat_interleave(num_groups, dim=0)
+
+    # q: [bsz, num_heads, q_len, head_dim]
+    # k,v: [bsz, num_kv_heads, kv_len, head_dim]
+    bsz, num_heads, q_len, head_dim = q.size()
+    _, num_kv_heads, kv_len, _ = k.size()
+    num_groups = num_heads // num_kv_heads
+    # The final reshape below assumes single-batch, single-token decoding.
+    assert bsz == 1 and q_len == 1
+    q = q.view(num_heads, q_len, head_dim)
+    k = k.view(num_kv_heads, kv_len, head_dim)
+    v = v.view(num_kv_heads, kv_len, head_dim)
+
+    k = repeat_kv(k, num_groups)
+    v = repeat_kv(v, num_groups)
+
+    o = torch.bmm(q, k.transpose(1, 2))
+    o = o / (head_dim ** 0.5)
+    o = torch.nn.functional.softmax(o, dim=-1, dtype=torch.float32).to(q.dtype)
+    o = torch.bmm(o, v)
+    return o.transpose(0, 1).view(1, num_heads, head_dim)
+
+
+if __name__ == "__main__":
+    from flash_attn import flash_attn_func
+    q = torch.randn(1, 32, 1, 128, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(1, 32, 20, 128, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(1, 32, 20, 128, device="cuda", dtype=torch.bfloat16)
+
+    fa_o = flash_attn_func(
+        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
+        0.0, softmax_scale=None, causal=True
+    )
+    t_o = torch_decoding(q, k, v)[:, None, ...]
+
+    print('testing flash_attn')
+    triton.testing.assert_close(fa_o, t_o)
+
+    print('testing flash_attn_triton_decoding')
+    # _flash_attn_triton_decoding returns only the output tensor (lse and
+    # softmax_scale are commented out of its return statement).
+    fad_o = _flash_attn_triton_decoding(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
+    triton.testing.assert_close(fad_o, t_o)
+    print(fad_o.shape)
diff --git a/minference/version.py b/minference/version.py
index d8ee7fc..b96d89f 100644
--- a/minference/version.py
+++ b/minference/version.py
@@ -8,7 +8,7 @@
 _PATCH = "4"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
-_SUFFIX = ".post2"
+_SUFFIX = ".post3"
 
 VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
 VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
diff --git a/setup.py b/setup.py
index 52307d9..6854d3e 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,6 @@
     "accelerate",
     "torch",
     "triton",
-    "flash_attn",
 ]
 QUANLITY_REQUIRES = [
     "black==21.4b0",