Feature(MInference): add triton-based decoding in case flash_attn is not available #35

Merged 2 commits on Jul 15, 2024
README.md: 2 changes (1 addition, 1 deletion)
@@ -38,7 +38,7 @@ _Now, you can process **1M context 10x faster in a single A100** using Long-cont
### Requirements

- Torch
- FlashAttention-2
- FlashAttention-2 (Optional)
- Triton == 2.1.0

To get started with MInference, simply install it using pip:
minference/modules/inf_llm.py: 10 changes (8 additions, 2 deletions)
@@ -6,9 +6,15 @@
from typing import Optional, Tuple

import torch
from flash_attn import flash_attn_func
from transformers.modeling_outputs import CausalLMOutput

try:
    from flash_attn import flash_attn_func as dense_decoding_func
except ImportError:
    from ..ops.flash_attn_triton import (
        _flash_attn_triton_decoding as dense_decoding_func,
    )

from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention


@@ -1084,7 +1090,7 @@ def forward(
h_v = h_v.transpose(1, 2)

# (batch_size, seqlen, nheads, headdim)
o = flash_attn_func(h_q, h_k, h_v, causal=True)
o = dense_decoding_func(h_q, h_k, h_v)

o = o.reshape(batch_size, len_q, dim_head * num_heads)
o = attention_out(o)
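As the hunk shows, the fallback is called with the same `(batch_size, seqlen, nheads, headdim)` layout that `flash_attn_func` expects. Below is a minimal pure-PyTorch sketch of that interface; it is hypothetical and only illustrates the call contract, not the Triton kernel `_flash_attn_triton_decoding` shipped in `minference/ops/flash_attn_triton.py`.

```python
# Hypothetical reference for the expected fallback interface; the actual
# fallback in this PR is the Triton kernel _flash_attn_triton_decoding.
import torch
import torch.nn.functional as F


def dense_decoding_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (batch_size, seqlen, nheads, headdim), matching flash_attn_func."""
    # scaled_dot_product_attention expects (batch, nheads, seqlen, headdim),
    # so transpose on the way in and back out.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    # During decoding the query is typically a single token, which is why the
    # dense_decoding_func(h_q, h_k, h_v) call above can drop causal=True.
    o = F.scaled_dot_product_attention(q, k, v)
    return o.transpose(1, 2)
```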
minference/modules/minference_forward.py: 11 changes (8 additions, 3 deletions)
@@ -12,15 +12,20 @@

if _is_package_available("vllm"):
    try:
        from vllm.attention.backends.flash_attn import *
        from vllm.attention.ops.paged_attn import PagedAttention
    except:
        warnings.warn("Only support 'vllm==0.4.1'. Please update your vllm version.")
        warnings.warn("Only support 'vllm>=0.4.0'. Please update your vllm version.")

from ..ops.block_sparse_flash_attention import block_sparse_attention
from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
from ..ops.streaming_kernel import streaming_forward, streaming_forward2
from .snap_kv import *

try:
    from flash_attn import flash_attn_func
except ImportError:
    from ..ops.flash_attn_triton import _flash_attn_triton_decoding as flash_attn_func

last_q = 64
arange = torch.arange(last_q, device="cuda")
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
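For context on the unchanged lines above: the `LAST_Q_MASK` broadcast builds a causal (lower-triangular) boolean mask over the last 64 query positions. The same trick at a smaller size, as an illustrative sketch:

```python
# Illustrative sketch (not part of the diff): the broadcast comparison yields
# a lower-triangular causal mask; shown here on CPU with last_q = 4.
import torch

last_q = 4
arange = torch.arange(last_q)
mask = arange[None, None, :, None] >= arange[None, None, None, :]
# mask[0, 0] is a 4x4 lower-triangular boolean matrix:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]
```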
@@ -1113,7 +1118,7 @@ def forward_vllm_043(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    attn_metadata,
    kv_scale: float,
    layer_idx: int,
) -> torch.Tensor:
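Untyping `attn_metadata` removes the hard dependency on `FlashAttentionMetadata`, which is only importable when the vllm flash-attn backend loads cleanly. A hypothetical alternative that keeps the hint for static analysis only, sketched with a stub signature:

```python
# Hypothetical sketch: import the vllm type only under TYPE_CHECKING so the
# annotation survives for type checkers without a runtime vllm requirement.
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata


def forward_vllm_043_stub(
    query: torch.Tensor,
    attn_metadata: "FlashAttentionMetadata",
    kv_scale: float,
) -> torch.Tensor:
    # With postponed evaluation of annotations, the hint is never executed at
    # runtime, so a missing vllm backend does not break the module import.
    return query
```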
minference/ops/block_sparse_flash_attention.py: 2 changes (1 addition, 1 deletion)
@@ -5,8 +5,8 @@
import torch
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func

# from flash_attn import flash_attn_varlen_func
# import pycuda.autoprimaryctx
# from pycuda.compiler import SourceModule
