|
15 | 15 | AttentionMetadataBuilder, |
16 | 16 | CommonAttentionMetadata) |
17 | 17 | from vllm.v1.kv_cache_interface import AttentionSpec |
| 18 | +from vllm import envs |
| 19 | + |
| 20 | +logger = init_logger(__name__) |
18 | 21 |
|
19 | 22 | _PARTITION_SIZE_ROCM = 256 |
20 | 23 |
|
21 | | -if current_platform.is_rocm(): |
| 24 | +if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER: |
| 25 | + VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = ( |
| 26 | + envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE |
| 27 | + ) |
| 28 | + VLLM_USE_AITER_TRITON_ROPE = envs.VLLM_USE_AITER_TRITON_ROPE |
| 29 | + if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: |
| 30 | + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache |
| 31 | +else: |
| 32 | + VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False |
| 33 | + VLLM_USE_AITER_TRITON_ROPE = False |
| 34 | + |
| 35 | +if current_platform.is_rocm(): |
22 | 36 | import aiter |
23 | 37 |
|
24 | 38 | from vllm.triton_utils import tl, triton |
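Reviewer note, not part of the patch: the module-level gating added above reduces to the small predicate sketched below. The helper name is hypothetical; the flag names follow the env attributes used in the diff.

```python
# Hedged sketch of the flag gating introduced above: the fused RoPE +
# KV-cache kernel is only imported/used when the platform is ROCm, AITER is
# enabled, and the dedicated fused-RoPE toggle is set; on every other
# configuration both module-level flags collapse to False.
def fused_rope_cache_enabled(is_rocm: bool,
                             rocm_use_aiter: bool,
                             use_fused_rope_zeros_kv_cache: bool) -> bool:
    """Mirror of the module-level branch; names follow the env flags above."""
    return is_rocm and rocm_use_aiter and use_fused_rope_zeros_kv_cache


assert fused_rope_cache_enabled(True, True, True) is True
assert fused_rope_cache_enabled(True, False, True) is False   # AITER off
assert fused_rope_cache_enabled(False, True, True) is False   # not ROCm
```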
@@ -209,8 +223,6 @@ def flash_attn_varlen_func_fake( |
209 | 223 | flash_attn_varlen_func_fake, |
210 | 224 | dispatch_key=current_platform.dispatch_key) |
211 | 225 |
|
212 | | -logger = init_logger(__name__) |
213 | | - |
214 | 226 |
|
215 | 227 | @dataclass |
216 | 228 | class AiterFlashAttentionMetadata: |
@@ -430,6 +442,7 @@ def forward( |
430 | 442 | attn_metadata: AiterFlashAttentionMetadata, |
431 | 443 | output: Optional[torch.Tensor] = None, |
432 | 444 | output_scale: Optional[torch.Tensor] = None, |
| 445 | + positions: Optional[torch.Tensor] = None,
433 | 446 | output_block_scale: Optional[torch.Tensor] = None, |
434 | 447 | ) -> torch.Tensor: |
435 | 448 | """Forward pass with AiterFlashAttention. |
@@ -469,24 +482,70 @@ def forward( |
469 | 482 |
|
470 | 483 | num_actual_tokens = attn_metadata.num_actual_tokens |
471 | 484 | key_cache, value_cache = kv_cache.unbind(0) |
472 | | - if self.kv_sharing_target_layer_name is None: |
473 | | - # Reshape the input keys and values and store them in the cache. |
474 | | - # Skip this if sharing KV cache with an earlier attention layer. |
475 | | - # NOTE(woosuk): Here, key and value are padded while slot_mapping is |
476 | | - # not padded. However, we don't need to do key[:num_actual_tokens] |
477 | | - # and value[:num_actual_tokens] because the reshape_and_cache_flash |
478 | | - # op uses the slot_mapping's shape to determine the number of |
479 | | - # actual tokens. |
480 | | - torch.ops._C_cache_ops.reshape_and_cache_flash( |
481 | | - key, |
482 | | - value, |
483 | | - key_cache, |
484 | | - value_cache, |
485 | | - attn_metadata.slot_mapping, |
486 | | - self.kv_cache_dtype, |
487 | | - layer._k_scale, |
488 | | - layer._v_scale, |
| 485 | + if positions is not None and query.shape[0] <= 256: |
| 486 | + assert ( |
| 487 | + self.kv_sharing_target_layer_name is None |
| 488 | + ), "self.kv_sharing_target_layer_name cannot be None" |
| 489 | + assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" |
| 490 | + cos_sin_cache = self.rotary_emb.cos_sin_cache |
| 491 | + is_neox = self.rotary_emb.is_neox_style |
| 492 | + cos, sin = cos_sin_cache.chunk(2, dim=-1) |
| 493 | + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") |
| 494 | + if is_fp8_kv_cache: |
| 495 | + key_cache_og_dtype = key_cache.dtype |
| 496 | + value_cache_og_dtype = value_cache.dtype |
| 497 | + key_cache = key_cache.view(self.fp8_dtype) |
| 498 | + value_cache = value_cache.view(self.fp8_dtype) |
| 499 | + query, key, key_cache, value_cache, output = ( |
| 500 | + fused_qk_rope_reshape_and_cache( |
| 501 | + query, |
| 502 | + key, |
| 503 | + value, |
| 504 | + key_cache, |
| 505 | + value_cache, |
| 506 | + attn_metadata.slot_mapping, |
| 507 | + positions, |
| 508 | + cos, |
| 509 | + sin, |
| 510 | + layer._k_scale, |
| 511 | + layer._v_scale, |
| 512 | + is_neox, |
| 513 | + flash_layout=True, |
| 514 | + apply_scale=is_fp8_kv_cache, |
| 515 | + offs=None, |
| 516 | + q_out=query, |
| 517 | + k_out=key, |
| 518 | + output_zeros=True, |
| 519 | + zeros_out=output, |
| 520 | + ) |
489 | 521 | ) |
| 522 | + if is_fp8_kv_cache: |
| 523 | + key_cache = key_cache.view(key_cache_og_dtype) |
| 524 | + value_cache = value_cache.view(value_cache_og_dtype) |
| 525 | + else: |
| 526 | + if positions is not None: |
| 527 | + if VLLM_USE_AITER_TRITON_ROPE: |
| 528 | + query, key = self.rotary_emb.forward_cuda(positions, query, key) |
| 529 | + else: |
| 530 | + query, key = self.rotary_emb(positions, query, key) |
| 531 | + if self.kv_sharing_target_layer_name is None: |
| 532 | + # Reshape the input keys and values and store them in the cache. |
| 533 | + # Skip this if sharing KV cache with an earlier attention layer. |
| 534 | + # NOTE(woosuk): Here, key and value are padded while slot_mapping is |
| 535 | + # not padded. However, we don't need to do key[:num_actual_tokens] |
| 536 | + # and value[:num_actual_tokens] because the reshape_and_cache_flash |
| 537 | + # op uses the slot_mapping's shape to determine the number of |
| 538 | + # actual tokens. |
| 539 | + torch.ops._C_cache_ops.reshape_and_cache_flash( |
| 540 | + key, |
| 541 | + value, |
| 542 | + key_cache, |
| 543 | + value_cache, |
| 544 | + attn_metadata.slot_mapping, |
| 545 | + self.kv_cache_dtype, |
| 546 | + layer._k_scale, |
| 547 | + layer._v_scale, |
| 548 | + ) |
490 | 549 |
|
491 | 550 | if self.kv_cache_dtype.startswith("fp8"): |
492 | 551 | if current_platform.is_fp8_fnuz(): |
|
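Reviewer summary, illustrative only and not code from the patch: the rewritten `forward` body dispatches between a fused path and the previous split path roughly as sketched below. `fused_rope_and_cache`, `apply_rope`, and `write_kv_cache` are hypothetical stand-ins for the aiter Triton kernel, `self.rotary_emb`, and `torch.ops._C_cache_ops.reshape_and_cache_flash`.

```python
# Hedged sketch of the new dispatch in forward(): decode-sized batches with
# position ids take one fused Triton kernel; everything else falls back to
# separate RoPE and reshape_and_cache_flash calls.
from typing import Callable


def choose_kv_write_path(
    num_query_tokens: int,
    positions_available: bool,
    kv_shared_with_earlier_layer: bool,
    fused_rope_and_cache: Callable[[], None],
    apply_rope: Callable[[], None],
    write_kv_cache: Callable[[], None],
) -> str:
    """Return which of the two code paths the new forward() body would take."""
    if positions_available and num_query_tokens <= 256:
        # Fast path: one Triton kernel applies RoPE to q/k, writes k/v into
        # the KV cache, and zero-fills the output buffer. The real code also
        # asserts that the KV cache is not shared with an earlier layer.
        fused_rope_and_cache()
        return "fused"
    if positions_available:
        apply_rope()  # RoPE as a separate op (Triton or default rotary_emb)
    if not kv_shared_with_earlier_layer:
        write_kv_cache()  # original reshape_and_cache_flash behaviour
    return "split"


def _noop() -> None:
    pass


# Tiny smoke test of the dispatch conditions with no-op stand-ins.
assert choose_kv_write_path(128, True, False, _noop, _noop, _noop) == "fused"
assert choose_kv_write_path(1024, True, False, _noop, _noop, _noop) == "split"
```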