
Commit 5fc59ef

Author: Varun Sundar Rabindranath (committed)

Add lora support

Signed-off-by: Varun Sundar Rabindranath <[email protected]>

1 parent d53575a, commit 5fc59ef

17 files changed, +453 -66 lines changed

tests/lora/conftest.py (+17)

@@ -298,3 +298,20 @@ def get_model_patched(**kwargs):
 def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
     yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
            model_runner.model)
+
+
+@pytest.fixture(params=[False, True])
+def run_with_both_engines_lora(request):
+    # Automatically runs tests twice, once with V1 and once without
+    use_v1 = request.param
+    # Tests decorated with `@skip_v1` are only run without v1
+    skip_v1 = request.node.get_closest_marker("skip_v1")
+
+    if use_v1:
+        if skip_v1:
+            pytest.skip("Skipping test on vllm V1")
+        with patch('vllm.envs.VLLM_USE_V1', True):
+            yield
+    else:
+        with patch('vllm.envs.VLLM_USE_V1', False):
+            yield
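A test module opts into this fixture with a small autouse wrapper (as the test files below do), and individual tests can be excluded from the V1 pass with the `skip_v1` marker the fixture checks for. A minimal sketch of that usage; the test names are illustrative, and the marker is assumed to be registered in (or at least tolerated by) the suite's pytest configuration:

import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Re-exports the parametrized fixture so every test in this module runs
    # once with VLLM_USE_V1 patched to True and once with it patched to False.
    pass


@pytest.mark.skip_v1
def test_feature_not_yet_on_v1():
    # Hypothetical test: the `skip_v1` marker makes the fixture call
    # pytest.skip() on the V1 pass, so this only runs on the V0 engine.
    ...


def test_feature_on_both_engines():
    # Hypothetical test: unmarked tests run under both engine configurations.
    ...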

tests/lora/test_baichuan.py (+8)

@@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_baichuan_lora(baichuan_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,

tests/lora/test_chatglm3_tp.py (+10)

@@ -1,5 +1,7 @@
 from typing import List
 
+import pytest
+
 import vllm
 from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
@@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @fork_new_process_for_each_test
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,

tests/lora/test_gemma.py (+8)

@@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.xfail(current_platform.is_rocm(),
                    reason="There can be output mismatch on ROCm")
 def test_gemma_lora(gemma_lora_files):

tests/lora/test_llama_tp.py (+9)

@@ -1,5 +1,6 @@
 from typing import List
 
+import pytest
 import ray
 
 import vllm
@@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files):
     print("removing lora")
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @fork_new_process_for_each_test
 def test_llama_lora(sql_lora_files):
 
tests/lora/test_lora_bias_e2e.py

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
2828
return generated_texts
2929

3030

31+
@pytest.fixture(autouse=True)
32+
def v1(run_with_both_engines_lora):
33+
# Simple autouse wrapper to run both engines for each test
34+
# This can be promoted up to conftest.py to run for every
35+
# test in a package
36+
pass
37+
38+
3139
@pytest.mark.parametrize("lora_bias", [True])
3240
@pytest.mark.parametrize("fully_sharded", [True, False])
3341
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):

tests/lora/test_phi.py (+10)

@@ -1,5 +1,7 @@
 from typing import List
 
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_phi2_lora(phi2_lora_files):
     # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
     # Otherwise, the lora-test will fail due to CUDA OOM.

tests/lora/test_quant_model.py (+8)

@@ -68,6 +68,14 @@ def format_prompt_tuples(prompt):
     return generated_texts
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
 def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,

vllm/lora/layers.py (+5 -3)

@@ -14,8 +14,7 @@
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce,
-                              tensor_model_parallel_gather)
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1034,7 +1033,10 @@ def _get_logits(
         logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
+
+        # Gather logits for TP
+        logits = self.base_layer._gather_logits(logits)
+
         if logits is None:
             return None
 
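The point of routing through `base_layer._gather_logits` instead of calling `tensor_model_parallel_gather` directly: `LogitsProcessorWithLoRA` overrides `_get_logits`, so a hard-coded gather would bypass whichever collective the base `LogitsProcessor` selects (see the next file, where gather is skipped on TPU and under `VLLM_USE_V1`). A rough sketch of the delegation shape, with simplified class names and signatures that are not the vLLM ones:

# Rough sketch (simplified, not vLLM source) of the delegation introduced here.
from typing import Callable, Optional

import torch


class BaseLogitsProcessorSketch:
    """Owns the TP gather strategy (gather vs. all-gather)."""

    def __init__(self, gather_fn: Callable[[torch.Tensor],
                                           Optional[torch.Tensor]]):
        self._gather_fn = gather_fn  # chosen once, e.g. from VLLM_USE_V1

    def _gather_logits(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        return self._gather_fn(logits)


class LoRALogitsProcessorSketch:
    """Overrides the logits computation but reuses the base gather strategy."""

    def __init__(self, base_layer: BaseLogitsProcessorSketch):
        self.base_layer = base_layer

    def _get_logits(self, hidden_states: torch.Tensor,
                    weight: torch.Tensor) -> Optional[torch.Tensor]:
        logits = hidden_states @ weight.t()
        # Delegate instead of hard-coding tensor_model_parallel_gather, so
        # whichever collective the base layer selected (gather on V0,
        # all-gather on V1/TPU) is applied to the LoRA logits as well.
        return self.base_layer._gather_logits(logits)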

vllm/model_executor/layers/logits_processor.py (+18 -11)

@@ -43,7 +43,6 @@ def __init__(self,
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-
         self.use_gather = not current_platform.is_tpu(
         ) and not envs.VLLM_USE_V1
 
@@ -78,16 +77,8 @@ def forward(
 
         return logits
 
-    def _get_logits(
-        self,
-        hidden_states: torch.Tensor,
-        lm_head: VocabParallelEmbedding,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
-        # Get the logits for the next tokens.
-        logits = lm_head.linear_method.apply(lm_head,
-                                             hidden_states,
-                                             bias=embedding_bias)
+    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
+        """gather/all-gather the logits tensor across model parallel group."""
         if self.use_gather:
             # None may be returned for rank > 0
             logits = tensor_model_parallel_gather(logits)
@@ -98,6 +89,22 @@ def _get_logits(
             # because XLA requires strict SPMD among all devices. Every device
             # should execute the same operations after gathering the logits.
             logits = tensor_model_parallel_all_gather(logits)
+        return logits
+
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
+        # Get the logits for the next tokens.
+        logits = lm_head.linear_method.apply(lm_head,
+                                             hidden_states,
+                                             bias=embedding_bias)
+
+        # Gather logits for TP
+        logits = self._gather_logits(logits)
+
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
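What the two branches inside `_gather_logits` mean in practice: with `use_gather` (the V0, non-TPU default), only the destination TP rank ends up holding the full logits and the other ranks see `None`; under `VLLM_USE_V1` (or on TPU/XLA), every rank must hold the full logits, so all-gather is used instead. A single-process toy illustrating the difference in outcome, with a list of vocab shards standing in for the tensor-parallel ranks; these helpers are not the actual vLLM collectives:

# Toy, single-process illustration (not vLLM's TP collectives) of what the
# two strategies in _gather_logits produce on each rank.
from typing import List, Optional

import torch


def toy_gather(vocab_shards: List[torch.Tensor],
               dst: int = 0) -> List[Optional[torch.Tensor]]:
    # Like tensor_model_parallel_gather: only rank `dst` receives the
    # concatenated logits; every other rank gets None.
    full = torch.cat(vocab_shards, dim=-1)
    return [full if rank == dst else None for rank in range(len(vocab_shards))]


def toy_all_gather(vocab_shards: List[torch.Tensor]) -> List[torch.Tensor]:
    # Like tensor_model_parallel_all_gather: every rank receives the full
    # logits, which is what the V1 engine (and XLA's SPMD model) requires.
    full = torch.cat(vocab_shards, dim=-1)
    return [full.clone() for _ in vocab_shards]


shards = [torch.randn(2, 8) for _ in range(2)]     # 2 TP ranks, 8 vocab each
assert toy_gather(shards)[1] is None               # rank 1 would return None
assert toy_all_gather(shards)[1].shape == (2, 16)  # every rank has full vocab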

vllm/v1/core/kv_cache_utils.py (+75 -26)

@@ -166,14 +166,28 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
         return ret
 
 
-def generate_block_hash_extra_keys(
-        request: Request, start_token_idx: int, end_token_idx: int,
-        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
-    """Generate extra keys for the block hash. The extra keys can come from
-    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
-    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
-    indicate a mm input contained in the block and its starting offset in
-    the block tokens.
+def need_extra_keys(request: Request) -> bool:
+    """Check whether the blocks allocated to this request need extra hash keys.
+
+    Args:
+        request (Request): The request.
+
+    Returns:
+        bool: Whether blocks allocated to this request need extra hash keys.
+    """
+
+    # Multimodal requests need to include the MM hash.
+    # LoRA requests need to include the LoRA ID.
+    return bool(request.mm_positions) or (request.lora_request is not None)
+
+
+def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
+                            end_token_idx: int,
+                            start_mm_idx: int) -> Tuple[List[Any], int]:
+    """Generate extra keys related to MultiModal request for block hash
+    computation. For multi-modal inputs, the extra keys are
+    (mm_hash, start_offset) that indicate a mm input contained in the
+    block and its starting offset in the block tokens.
 
     Args:
         request: The request object.
@@ -184,10 +198,11 @@ def generate_block_hash_extra_keys(
     Returns:
         A tuple of extra keys and the next multi-modal index.
     """
+    extra_keys: List[Any] = []
 
     mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
     if not mm_positions:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx
 
     if mm_positions and len(mm_positions) != len(mm_hashes):
         raise ValueError(
@@ -200,14 +215,13 @@ def generate_block_hash_extra_keys(
     # range. This usually happens in the late prefill phase and decoding phase.
     if mm_positions[-1]["offset"] + mm_positions[-1][
             "length"] < start_token_idx:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx
 
     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
         assert -start_mm_idx <= len(mm_positions)
         start_mm_idx = len(mm_positions) + start_mm_idx
 
-    extra_keys = []
     curr_mm_idx = start_mm_idx
     while mm_positions and curr_mm_idx < len(mm_positions):
         assert mm_hashes[curr_mm_idx] is not None
@@ -233,7 +247,50 @@ def generate_block_hash_extra_keys(
         else:
             # This block has not reached the current mm input.
             break
-    return tuple(extra_keys), curr_mm_idx
+    return extra_keys, curr_mm_idx
+
+
+def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
+    """Generate extra keys related to LoRA for block hash computation.
+
+    Args:
+        request: The request object.
+
+    Returns:
+        Return LoRA id of the request if it is a LoRA request. Return empty
+        list otherwise.
+    """
+    if not request.lora_request:
+        return []
+    return [request.lora_request.lora_int_id]
+
+
+def generate_block_hash_extra_keys(
+        request: Request, start_token_idx: int, end_token_idx: int,
+        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
+    """Generate extra keys for the block hash. The extra keys can come from
+    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
+
+    Args:
+        request: The request object.
+        start_token_idx: The start token index of the block.
+        end_token_idx: The end token index of the block.
+        start_mm_idx: The start multi-modal index of the block.
+
+    Returns:
+        A tuple of extra keys and the next multi-modal index.
+    """
+    mm_extra_keys: List[Any]
+    mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
+        request, start_token_idx, end_token_idx, start_mm_idx)
+    lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)
+
+    extra_keys: List[Any] = lora_extra_keys + mm_extra_keys
+
+    if not extra_keys:
+        return None, new_start_mm_idx
+
+    return tuple(extra_keys), new_start_mm_idx
 
 
 def hash_block_tokens(
@@ -245,9 +302,6 @@ def hash_block_tokens(
     prefix caching. We use LRU cache for this function to avoid recomputing
     hash values for the same block contents.
 
-    TODO: Support arbitrary metadata so that we could support more
-    features such as LoRA adapter.
-
     Args:
         parent_block_hash: The hash of the parent block. None
             if this is the first block.
@@ -276,14 +330,9 @@ def hash_request_tokens(block_size: int,
         The list of computed hash values.
     """
     token_ids = request.all_token_ids
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match.")
 
-    # TODO: Extend this to support other features such as LoRA.
-    need_extra_keys = bool(mm_positions)
-    extra_keys = None
+    req_need_extra_keys = need_extra_keys(request)
+    req_extra_keys = None
    curr_mm_idx = 0
 
     ret = []
@@ -295,13 +344,13 @@ def hash_request_tokens(block_size: int,
         if len(block_token_ids) < block_size:
             break
 
-        # Add extra keys if the block is a multi-modal block.
-        if need_extra_keys:
-            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+        if req_need_extra_keys:
+            # MM and LoRA requests need extra keys for block-hash computation.
+            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                 request, start, end, curr_mm_idx)
 
         block_hash = hash_block_tokens(parent_block_hash_value,
-                                       block_token_ids, extra_keys)
+                                       block_token_ids, req_extra_keys)
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
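The effect of the new extra keys on prefix caching: blocks with identical token IDs must not be shared across requests that use different LoRA adapters, so the adapter's `lora_int_id` is folded into the block hash alongside any multi-modal keys. A toy stand-in for `hash_block_tokens` (not the vLLM implementation) illustrating that property:

# Toy stand-in (not vLLM source) for hash_block_tokens, showing why the LoRA
# ID has to be part of the block hash: same tokens under different adapters
# must map to different prefix-cache keys.
from typing import Any, NamedTuple, Optional, Tuple


class ToyBlockHash(NamedTuple):
    hash_value: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Tuple[Any, ...]]


def toy_hash_block_tokens(parent_hash: Optional[int],
                          token_ids: Tuple[int, ...],
                          extra_keys: Optional[Tuple[Any, ...]]) -> ToyBlockHash:
    # Same ingredients as the real function: parent hash, block tokens, and
    # the optional extra keys produced by generate_block_hash_extra_keys().
    return ToyBlockHash(hash((parent_hash, token_ids, extra_keys)), token_ids,
                        extra_keys)


tokens = (101, 102, 103, 104)
base = toy_hash_block_tokens(None, tokens, None)    # no adapter, no MM input
lora_a = toy_hash_block_tokens(None, tokens, (1,))  # extra key: lora_int_id 1
lora_b = toy_hash_block_tokens(None, tokens, (2,))  # extra key: lora_int_id 2

# Identical token blocks hashed under different adapters do not collide, so
# the prefix cache keeps their KV blocks separate.
assert len({base.hash_value, lora_a.hash_value, lora_b.hash_value}) == 3

Conversely, `need_extra_keys` keeps plain text-only, non-LoRA requests on the cheaper path with `extra_keys=None`.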
