Commit e2347db

[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration (vllm-project#27895)
Signed-off-by: tjtanaa <[email protected]>
1 parent 879a065 commit e2347db

File tree: 2 files changed, +254 -17 lines
New test file

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import asdict
+from typing import NamedTuple
+
+import pytest
+from PIL.Image import Image
+from transformers import AutoProcessor
+
+from vllm import LLM, EngineArgs, SamplingParams
+from vllm.multimodal.utils import encode_image_base64
+
+MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
+
+QUESTION = "What is the content of each image?"
+
+
+class ModelRequestData(NamedTuple):
+    engine_args: EngineArgs
+    prompt: str
+    image_data: list[Image]
+    stop_token_ids: list[int] | None = None
+    chat_template: str | None = None
+    sampling_params: SamplingParams | None = None
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("question", [QUESTION])
+def test_keye_vl(
+    image_assets,
+    question: str,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    image_urls = [
+        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
+    ]
+
+    engine_args = EngineArgs(
+        model=MODEL_NAME,
+        trust_remote_code=True,
+        max_model_len=8192,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                *placeholders,
+                {"type": "text", "text": question},
+            ],
+        },
+    ]
+
+    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+    prompt = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    engine_args = asdict(engine_args) | {"seed": 42}
+    llm = LLM(**engine_args)
+
+    sampling_params = SamplingParams(
+        temperature=0.0, max_tokens=256, stop_token_ids=None
+    )
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {"image": images},
+        },
+        sampling_params=sampling_params,
+    )
+
+    print("-" * 50)
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+        assert len(generated_text) > 10, (
+            f"Generated text is too short: {generated_text}"
+        )
+        print("-" * 50)
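The test above applies the Hugging Face chat template manually and passes raw PIL images through multi_modal_data. For a quick manual smoke check of the same model, the request can also be expressed through vLLM's LLM.chat helper, which applies the chat template internally. The sketch below is editorial and not part of the commit; it assumes LLM.chat accepts OpenAI-style image_url content parts and that a local image.jpg exists.

# Editorial sketch (not part of the commit): one-image smoke check via LLM.chat.
# "image.jpg" is a hypothetical local file; the content-part format assumes
# OpenAI-style "image_url" entries are accepted by vLLM's chat utilities.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import encode_image_base64

image = Image.open("image.jpg")
image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"

llm = LLM(
    model="Kwai-Keye/Keye-VL-8B-Preview",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "What is the content of each image?"},
        ],
    }
]

outputs = llm.chat(messages, SamplingParams(temperature=0.0, max_tokens=256))
print(outputs[0].outputs[0].text)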

vllm/model_executor/models/keye.py

Lines changed: 168 additions & 17 deletions
@@ -17,7 +17,9 @@
 from transformers.utils import torch_int

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (
+    maybe_get_vit_flash_attn_backend,
+)
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@
     PromptUpdate,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (
     MultiModalEmbeddings,
     SupportsLoRA,
+    SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
 )
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()

-    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    if current_platform.is_cuda():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    elif current_platform.is_rocm():
+        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb

     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@@ -398,18 +405,28 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-            torch.get_default_dtype()
-        ):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa=False,
+                attn_backend_override=attn_backend_override,
+            )
+        )

-        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
+        if self.attn_backend not in {
+            _Backend.FLASH_ATTN,
+            _Backend.XFORMERS,
+            _Backend.ROCM_AITER_FA,
+        }:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now."
             )

+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN,
+            _Backend.ROCM_AITER_FA,
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@ def forward(
             self.head_dim,
         )

-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-
+        if self.is_flash_attn_backend:
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -1542,7 +1554,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
     dummy_inputs=KeyeDummyInputsBuilder,
 )
 class KeyeForConditionalGeneration(
-    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
+    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     def _build_projector(
         self,
@@ -1611,3 +1623,142 @@ def _process_video_input(
         return tuple(
             self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
         )
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: list[list[int]] | torch.Tensor,
+        video_grid_thw: list[list[int]] | torch.Tensor,
+        context_len: int = 0,
+        seq_len: int | None = None,
+        second_per_grid_ts: list[float] | None = None,
+        audio_feature_lengths: torch.Tensor | None = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
+        """Get mrope input positions and delta value (Keye series)."""
+
+        def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
+            """
+            Split grid_thw along the t dimension.
+
+            Args:
+                grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
+
+            Returns:
+                List of [1, h, w] rows, repeated t times for each original row.
+            """
+
+            if isinstance(grid_thw, list):
+                grid_thw = torch.tensor(grid_thw, dtype=torch.long)
+
+            if grid_thw.numel() == 0:
+                return []
+
+            t, hw = grid_thw[:, 0], grid_thw[:, 1:]
+            ones = torch.ones_like(hw[:, :1])  # [N,1]
+            out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
+            return out.tolist()
+
+        video_grid_thw = split_thw(video_grid_thw)
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        image_nums = len(image_grid_thw)
+        frame_nums = len(video_grid_thw)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_frames = image_nums, frame_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + frame_nums):
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_frames > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_frames -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                (
+                    torch.arange(llm_grid_t)
+                    .view(-1, 1)
+                    .expand(-1, llm_grid_h * llm_grid_w)
+                )
+                .long()
+                .flatten()
+            )
+
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
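To make the new method's output concrete: text tokens advance all three rotary axes (t, h, w) in lockstep, image tokens enumerate the spatial-merged patch grid offset past the preceding text, and the returned delta records how far the final position differs from the raw token count. The standalone sketch below is editorial and not part of the commit or of vLLM; the token ids and grid sizes are arbitrary placeholders mimicking a prompt with two text tokens, one image whose 4x4 patch grid merges to 2x2, and one trailing text token.

# Editorial sketch (not vLLM code): the 3-axis MRoPE position layout that
# get_mrope_input_positions produces for "2 text tokens, one 2x2 merged image
# grid, 1 text token". Token ids below are arbitrary placeholders.
import torch

spatial_merge_size = 2
IMG, TXT = 9999, 0  # hypothetical placeholder token ids
input_tokens = [TXT, TXT] + [IMG] * 4 + [TXT]
grid_t, grid_h, grid_w = 1, 4, 4  # raw patch grid before spatial merge
llm_h, llm_w = grid_h // spatial_merge_size, grid_w // spatial_merge_size

pos = []
# leading text: all three axes advance together (0, 1)
pos.append(torch.arange(2).view(1, -1).expand(3, -1))
# image block: t/h/w indices over the 1x2x2 merged grid, offset past the text
t_idx = torch.zeros(llm_h * llm_w, dtype=torch.long)
h_idx = torch.arange(llm_h).repeat_interleave(llm_w)
w_idx = torch.arange(llm_w).repeat(llm_h)
pos.append(torch.stack([t_idx, h_idx, w_idx]) + 2)
# trailing text resumes one past the largest position used so far
pos.append(torch.tensor([[4], [4], [4]]))

positions = torch.cat(pos, dim=1)
print(positions)
# tensor([[0, 1, 2, 2, 2, 2, 4],
#         [0, 1, 2, 2, 3, 3, 4],
#         [0, 1, 2, 3, 2, 3, 4]])
print(positions.max().item() + 1 - len(input_tokens))
# -2, i.e. the mrope_position_delta the method would report for this layout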

0 commit comments
