Skip to content

Commit

Permalink
Feature(MInference): add triton-based decoding in case flash_attn is …
Browse files Browse the repository at this point in the history
…not available (#35)

* add triton-based decoding in case flash_attn is not available
* Feature(MInference): remove flash_attn dependency

Co-authored-by: Huiqiang Jiang <[email protected]>
  • Loading branch information
liyucheng09 and iofu728 authored Jul 15, 2024
1 parent a880a6e commit 50d17d9
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ _Now, you can process **1M context 10x faster in a single A100** using Long-cont
### Requirements

- Torch
- FlashAttention-2
- FlashAttention-2 (Optional)
- Triton == 2.1.0

To get started with MInference, simply install it using pip:
Expand Down
10 changes: 8 additions & 2 deletions minference/modules/inf_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
from typing import Optional, Tuple

import torch
from flash_attn import flash_attn_func
from transformers.modeling_outputs import CausalLMOutput

try:
from flash_attn import flash_attn_func as dense_decoding_func
except ImportError:
from ..ops.flash_attn_triton import (
_flash_attn_triton_decoding as dense_decoding_func,
)

from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention


Expand Down Expand Up @@ -1084,7 +1090,7 @@ def forward(
h_v = h_v.transpose(1, 2)

# (batch_size, seqlen, nheads, headdim)
o = flash_attn_func(h_q, h_k, h_v, causal=True)
o = dense_decoding_func(h_q, h_k, h_v)

o = o.reshape(batch_size, len_q, dim_head * num_heads)
o = attention_out(o)
Expand Down
11 changes: 8 additions & 3 deletions minference/modules/minference_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@

if _is_package_available("vllm"):
try:
from vllm.attention.backends.flash_attn import *
from vllm.attention.ops.paged_attn import PagedAttention
except:
warnings.warn("Only support 'vllm==0.4.1'. Please update your vllm version.")
warnings.warn("Only support 'vllm>=0.4.0'. Please update your vllm version.")

from ..ops.block_sparse_flash_attention import block_sparse_attention
from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
from ..ops.streaming_kernel import streaming_forward, streaming_forward2
from .snap_kv import *

try:
from flash_attn import flash_attn_func
except ImportError:
from ..ops.flash_attn_triton import _flash_attn_triton_decoding as flash_attn_func

last_q = 64
arange = torch.arange(last_q, device="cuda")
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
Expand Down Expand Up @@ -1113,7 +1118,7 @@ def forward_vllm_043(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
attn_metadata,
kv_scale: float,
layer_idx: int,
) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion minference/ops/block_sparse_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func

# from flash_attn import flash_attn_varlen_func
# import pycuda.autoprimaryctx
# from pycuda.compiler import SourceModule

Expand Down
Loading

0 comments on commit 50d17d9

Please sign in to comment.