
Commit 2f1f148

k50112113 authored and dllehr-amd committed

add VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE support for VLLM_ROCM_USE_AITER_MHA

1 parent c28077e · commit 2f1f148

File tree

5 files changed: +87 −27 lines
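
All of the behaviour touched by this commit is gated by environment variables read through vllm.envs. As a rough usage sketch only: the VLLM_ROCM_USE_AITER_TRITON_* flags exist in this ROCm/AITER fork (they appear in the diffs below), and the model name is just an illustrative choice, not part of the diff.

```python
import os

# Flags referenced in this commit; values are picked up by vllm.envs.
os.environ["VLLM_ROCM_USE_AITER"] = "1"                                   # enable the AITER code paths
os.environ["VLLM_ROCM_USE_AITER_MHA"] = "1"                               # AITER flash-attention backend
os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE"] = "1"  # fused RoPE + KV-cache write

from vllm import LLM  # imported after setting the flags so the env reads see them

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model choice
print(llm.generate(["Hello"])[0].outputs[0].text)
```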


vllm/attention/layer.py

Lines changed: 2 additions & 1 deletion
@@ -527,8 +527,9 @@ def unified_attention_with_output(
     kv_cache = self.kv_cache[forward_context.virtual_engine]
 
     from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
+    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
     from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
-    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterFlashAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
         # fusing RoPE with flushing kv_cache operation
         assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None and positions is not None, f"rotary_emb not found in {self.impl=} and positions cannot be None"
         self.impl.forward(self,
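
The only change here is that AiterFlashAttentionImpl joins TritonAttentionImpl and AiterMLAImpl as a backend eligible for the fused path, so positions can be handed straight to the impl. A minimal sketch of that check as a standalone helper (wants_fused_rope is a hypothetical name, not part of the diff):

```python
from typing import Any

def wants_fused_rope(impl: Any, flag_enabled: bool) -> bool:
    """True when the attention impl should take the fused RoPE + zeroed-KV-cache
    path; mirrors the widened isinstance check in the hunk above."""
    # Imported lazily, as layer.py does, to avoid import cycles at module load.
    from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
    from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
    return flag_enabled and isinstance(
        impl, (TritonAttentionImpl, AiterFlashAttentionImpl, AiterMLAImpl))
```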

vllm/model_executor/models/gpt_oss.py

Lines changed: 3 additions & 2 deletions
@@ -39,7 +39,7 @@
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM
-    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
     if VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
         from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
@@ -51,7 +51,8 @@
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
-logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM=}")
 
5758

vllm/model_executor/models/llama.py

Lines changed: 3 additions & 2 deletions
@@ -63,13 +63,14 @@
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
     from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
-    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
-logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")
 
 
 class LlamaMLP(nn.Module):

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 79 additions & 20 deletions
@@ -15,10 +15,24 @@
     AttentionMetadataBuilder,
     CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm import envs
+
+logger = init_logger(__name__)
 
 _PARTITION_SIZE_ROCM = 256
 
-if current_platform.is_rocm():
+if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+    )
+    VLLM_USE_AITER_TRITON_ROPE = envs.VLLM_USE_AITER_TRITON_ROPE
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
+else:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
+    VLLM_USE_AITER_TRITON_ROPE = False
+
+if current_platform.is_rocm():
     import aiter
 
     from vllm.triton_utils import tl, triton
@@ -209,8 +223,6 @@ def flash_attn_varlen_func_fake(
     flash_attn_varlen_func_fake,
     dispatch_key=current_platform.dispatch_key)
 
-logger = init_logger(__name__)
-
 
 @dataclass
 class AiterFlashAttentionMetadata:
@@ -430,6 +442,7 @@ def forward(
         attn_metadata: AiterFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        positions: torch.Tensor = None,
         output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with AiterFlashAttention.
@@ -469,24 +482,70 @@ def forward(
 
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+        if positions is not None and query.shape[0] <= 256:
+            assert (
+                self.kv_sharing_target_layer_name is None
+            ), "self.kv_sharing_target_layer_name cannot be None"
+            assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
+            cos_sin_cache = self.rotary_emb.cos_sin_cache
+            is_neox = self.rotary_emb.is_neox_style
+            cos, sin = cos_sin_cache.chunk(2, dim=-1)
+            is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+            if is_fp8_kv_cache:
+                key_cache_og_dtype = key_cache.dtype
+                value_cache_og_dtype = value_cache.dtype
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            query, key, key_cache, value_cache, output = (
+                fused_qk_rope_reshape_and_cache(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    positions,
+                    cos,
+                    sin,
+                    layer._k_scale,
+                    layer._v_scale,
+                    is_neox,
+                    flash_layout=True,
+                    apply_scale=is_fp8_kv_cache,
+                    offs=None,
+                    q_out=query,
+                    k_out=key,
+                    output_zeros=True,
+                    zeros_out=output,
+                )
             )
+            if is_fp8_kv_cache:
+                key_cache = key_cache.view(key_cache_og_dtype)
+                value_cache = value_cache.view(value_cache_og_dtype)
+        else:
+            if positions is not None:
+                if VLLM_USE_AITER_TRITON_ROPE:
+                    query, key = self.rotary_emb.forward_cuda(positions, query, key)
+                else:
+                    query, key = self.rotary_emb(positions, query, key)
+            if self.kv_sharing_target_layer_name is None:
+                # Reshape the input keys and values and store them in the cache.
+                # Skip this if sharing KV cache with an earlier attention layer.
+                # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+                # not padded. However, we don't need to do key[:num_actual_tokens]
+                # and value[:num_actual_tokens] because the reshape_and_cache_flash
+                # op uses the slot_mapping's shape to determine the number of
+                # actual tokens.
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         if self.kv_cache_dtype.startswith("fp8"):
             if current_platform.is_fp8_fnuz():
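
Note that the fused path is only taken for decode-sized batches (`positions is not None and query.shape[0] <= 256`); larger prefill batches fall back to a separate RoPE call plus `reshape_and_cache_flash`. Conceptually, the single `fused_qk_rope_reshape_and_cache` launch folds three steps into one kernel: rotate q/k, scatter k/v into the paged cache, and zero the output buffer. A rough, unfused PyTorch sketch of those steps follows; the shapes, neox-style rotation over the full head dim, and the flash cache layout are assumptions for illustration, and this is not the AITER kernel's API.

```python
import torch

def unfused_rope_reshape_and_zero(query, key, value, key_cache, value_cache,
                                  slot_mapping, positions, cos, sin, output):
    """Illustrative equivalent of the fused kernel: (1) apply RoPE to q/k,
    (2) write k/v into a flash-layout paged cache, (3) zero the output buffer."""

    def rotate_neox(x, cos_p, sin_p):
        # neox style: split the head dim into halves and rotate the pair
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos_p - x2 * sin_p,
                          x2 * cos_p + x1 * sin_p), dim=-1)

    cos_p = cos[positions].unsqueeze(1)  # [num_tokens, 1, head_dim // 2]
    sin_p = sin[positions].unsqueeze(1)
    query = rotate_neox(query, cos_p, sin_p)  # [num_tokens, num_heads, head_dim]
    key = rotate_neox(key, cos_p, sin_p)      # [num_tokens, num_kv_heads, head_dim]

    # flash layout: cache viewed as [num_blocks * block_size, num_kv_heads, head_dim],
    # with slot_mapping holding the flat slot index of each token
    key_cache.view(-1, *key_cache.shape[-2:])[slot_mapping] = key
    value_cache.view(-1, *value_cache.shape[-2:])[slot_mapping] = value

    output.zero_()  # the "zeros_out" part of the fused op
    return query, key, key_cache, value_cache, output
```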

vllm/v1/attention/backends/triton_attn.py

Lines changed: 0 additions & 2 deletions
@@ -295,8 +295,6 @@ def forward(
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
         positions: torch.Tensor = None,
-        cos_sin_cache: torch.Tensor = None,
-        is_neox: bool = False,
         output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
