
Commit b63f214

andylolu2 and jeejeelee authored
[LoRA] LoRA cuda graph specialization (vllm-project#25914)
Signed-off-by: Andy Lo <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent f32bf75 commit b63f214

9 files changed, +122 -34 lines changed


tests/lora/test_chatglm3_tp.py

Lines changed: 15 additions & 6 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import vllm
+import vllm.config
 from vllm.lora.request import LoRARequest
 
 from ..utils import create_new_process_for_each_test, multi_gpu_test
@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
+        max_num_seqs=16,
         max_lora_rank=64,
         trust_remote_code=True,
     )
@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
+        max_num_seqs=16,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=False,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )
 
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
     # more GPU memory causing vLLM to OOM
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=True,
         gpu_memory_utilization=0.85,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):

tests/lora/test_llama_tp.py

Lines changed: 8 additions & 1 deletion
@@ -3,7 +3,10 @@
 import subprocess
 import sys
 
+import pytest
+
 import vllm
+import vllm.config
 from vllm import LLM
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -100,14 +103,18 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
 
 
 @create_new_process_for_each_test()
-def test_llama_lora(sql_lora_files):
+@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
+def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
     llm = vllm.LLM(
         MODEL_PATH,
         tokenizer=sql_lora_files,
         enable_lora=True,
         # also test odd max_num_seqs
         max_num_seqs=13,
         max_loras=4,
+        compilation_config=vllm.config.CompilationConfig(
+            cudagraph_specialize_lora=cudagraph_specialize_lora,
+        ),
     )
     generate_and_test(llm, sql_lora_files)

vllm/config/compilation.py

Lines changed: 8 additions & 0 deletions
@@ -366,6 +366,14 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
     FULL_AND_PIECEWISE instead.
     """
+    cudagraph_specialize_lora: bool = True
+    """Whether to create separate cuda graphs for cases with and without active
+    LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
+    for all cases, incurring the overhead of running LoRA ops even when no
+    adapters are active. Setting this to True will remove this overhead at the
+    cost of increased startup time and slightly higher memory usage.
+    When `enable_lora` is False, this option has no effect.
+    """
 
     use_inductor_graph_partition: bool = False
     """Use inductor graph partition to split the graph at cudagraph_unsafe ops.

vllm/forward_context.py

Lines changed: 7 additions & 1 deletion
@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
+    has_lora: bool = False
+    """
+    Whether this batch has active LoRA adapters.
+    """
 
     @property
     def non_uniform(self) -> "BatchDescriptor":
         """
         Return a non-uniform version of current batch descriptor.
         """
-        return BatchDescriptor(self.num_tokens, uniform_decode=False)
+        return BatchDescriptor(
+            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+        )
 
 
 def _compute_sp_num_tokens(
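Because BatchDescriptor is a NamedTuple, descriptors that differ only in has_lora hash as distinct keys, so the dispatcher can map LoRA and non-LoRA batches of the same size to different cuda graphs. A small illustrative sketch (a standalone re-definition, not the vLLM class itself):

from typing import NamedTuple

class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False
    has_lora: bool = False

with_lora = BatchDescriptor(num_tokens=64, uniform_decode=True, has_lora=True)
without_lora = with_lora._replace(has_lora=False)

# Same size and decode shape, but they key into different captured graphs.
graphs = {with_lora: "lora graph", without_lora: "no-lora graph"}
assert with_lora != without_lora and len(graphs) == 2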

vllm/lora/ops/triton_ops/lora_shrink_op.py

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,8 @@ def _lora_shrink(
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
     assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
 
+    output_tensor.zero_()
+
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
         _get_lora_a_ptr(lora_a_weights, inputs.device)
     )

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 21 additions & 13 deletions
@@ -205,15 +205,18 @@ def add_lora_linear(
 
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
 
-        if buffer is None:
-            r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros(  # type: ignore
-                (len(output_slices), x.size(0), r),
-                dtype=torch.float32,
-                device=x.device,
-            )
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        r = lora_b_stacked[0].size(-1)
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty(
+            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
+        )
+
         self.add_shrink(
             buffer,  # type: ignore
             x,
@@ -260,10 +263,15 @@ def add_lora_logits(
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
 
         lora_shrink(
             x,
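Taken together with the lora_shrink change above, the intermediate shrink buffer is now allocated uninitialized by the wrapper and zeroed inside the op itself, instead of being allocated with torch.zeros (or passed in by the caller). A small sketch of the intended equivalence, with placeholder sizes:

import torch

num_slices, num_tokens, r = 3, 16, 64  # placeholder sizes

# Previous behaviour: allocate a zeroed buffer up front.
old_buffer = torch.zeros((num_slices, num_tokens, r), dtype=torch.float32)

# New behaviour: allocate uninitialized memory, then rely on the shrink op
# calling output_tensor.zero_() before it accumulates partial results.
new_buffer = torch.empty((num_slices, num_tokens, r), dtype=torch.float32)
new_buffer.zero_()  # what _lora_shrink now does internally

assert torch.equal(old_buffer, new_buffer)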

vllm/v1/cudagraph_dispatcher.py

Lines changed: 20 additions & 4 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from itertools import product
 
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
@@ -67,14 +68,27 @@ def initialize_cudagraph_keys(
     ):
         # This should be called only after attention backend is initialized.
 
+        # LoRA activation cases to specialize the cuda graphs on
+        if self.vllm_config.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         # Note: we create all valid keys for cudagraph here but do not
         # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
-            for bs in self.compilation_config.cudagraph_capture_sizes:
+            for bs, has_lora in product(
+                self.compilation_config.cudagraph_capture_sizes, lora_cases
+            ):
                 self.add_cudagraph_key(
                     cudagraph_mode.mixed_mode(),
-                    BatchDescriptor(num_tokens=bs, uniform_decode=False),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=False, has_lora=has_lora
+                    ),
                 )
 
         # if decode cudagraph mode is FULL, and we don't already have mixed
@@ -92,10 +106,12 @@ def initialize_cudagraph_keys(
                 for x in self.compilation_config.cudagraph_capture_sizes
                 if x <= max_num_tokens and x >= uniform_decode_query_len
             ]
-            for bs in cudagraph_capture_sizes_for_decode:
+            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                 self.add_cudagraph_key(
                     CUDAGraphMode.FULL,
-                    BatchDescriptor(num_tokens=bs, uniform_decode=True),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=True, has_lora=has_lora
+                    ),
                 )
         self.keys_initialized = True
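The key space therefore doubles only when a LoRA config is present and cudagraph_specialize_lora is True. A standalone sketch of the enumeration (the capture sizes and flags below are placeholders):

from itertools import product

cudagraph_capture_sizes = [1, 2, 4, 8]  # placeholder
lora_enabled = True
cudagraph_specialize_lora = True

if lora_enabled:
    lora_cases = [True, False] if cudagraph_specialize_lora else [True]
else:
    lora_cases = [False]

keys = list(product(cudagraph_capture_sizes, lora_cases))
print(keys)
# [(1, True), (1, False), (2, True), (2, False),
#  (4, True), (4, False), (8, True), (8, False)]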

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 6 deletions
@@ -8,6 +8,7 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
 
 import numpy as np
@@ -2469,7 +2470,9 @@ def execute_model(
                 num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
             )
             batch_descriptor = BatchDescriptor(
-                num_tokens=num_input_tokens, uniform_decode=uniform_decode
+                num_tokens=num_input_tokens,
+                uniform_decode=uniform_decode,
+                has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
             )
             cudagraph_runtime_mode, batch_descriptor = (
                 self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
@@ -3193,6 +3196,7 @@ def _dummy_run(
         is_profile: bool = False,
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
+        activate_lora: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@ def _dummy_run(
             create_mixed_batch: If True, create a mixed batch with both decode
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
+            activate_lora: If False, dummy_run is performed without LoRAs.
         """
         assert (
             cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ def _dummy_run(
                 attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, remove_lora
+            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
             # Make sure padding doesn't exceed max_num_tokens
            assert num_tokens_after_padding <= self.max_num_tokens
@@ -3411,6 +3416,7 @@ def _dummy_run(
                     BatchDescriptor(
                         num_tokens=num_tokens_after_padding,
                         uniform_decode=uniform_decode,
+                        has_lora=activate_lora and self.lora_config is not None,
                     )
                 )
                 if not is_profile
@@ -3769,10 +3775,21 @@ def freeze_gc():
             start_free_gpu_memory = torch.cuda.mem_get_info()[0]
             cudagraph_mode = self.compilation_config.cudagraph_mode
             assert cudagraph_mode is not None
+
+            if self.lora_config:
+                if self.compilation_config.cudagraph_specialize_lora:
+                    lora_cases = [True, False]
+                else:
+                    lora_cases = [True]
+            else:
+                lora_cases = [False]
+
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                 cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
 
-                compilation_cases = list(reversed(self.cudagraph_batch_sizes))
+                compilation_cases = list(
+                    product(reversed(self.cudagraph_batch_sizes), lora_cases)
+                )
                 self._capture_cudagraphs(
                     compilation_cases,
                     cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -3793,7 +3810,9 @@ def freeze_gc():
                     for x in self.cudagraph_batch_sizes
                     if max_num_tokens >= x >= self.uniform_decode_query_len
                 ]
-                compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
+                compilation_cases_decode = list(
+                    product(reversed(decode_cudagraph_batch_sizes), lora_cases)
+                )
                 self._capture_cudagraphs(
                     compilation_cases=compilation_cases_decode,
                     cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ def freeze_gc():
 
     def _capture_cudagraphs(
         self,
-        compilation_cases: list[int],
+        compilation_cases: list[tuple[int, bool]],
         cudagraph_runtime_mode: CUDAGraphMode,
         uniform_decode: bool,
     ):
@@ -3844,7 +3863,7 @@ def _capture_cudagraphs(
         )
 
         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
+        for num_tokens, activate_lora in compilation_cases:
            # We currently only capture ubatched graphs when its a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ def _capture_cudagraphs(
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
+                    activate_lora=activate_lora,
                 )
             self._dummy_run(
                 num_tokens,
@@ -3883,6 +3903,7 @@ def _capture_cudagraphs(
                 allow_microbatching=allow_microbatching,
                 skip_eplb=True,
                 remove_lora=False,
+                activate_lora=activate_lora,
             )
         self.maybe_remove_all_loras(self.lora_config)
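At runtime, execute_model marks the batch as a LoRA batch whenever any LoRA request is active in the input batch; during capture, each (num_tokens, activate_lora) pair is warmed up and captured separately. A minimal sketch of the runtime side (the request mapping below is a placeholder stand-in for self.input_batch.lora_id_to_lora_request):

from vllm.forward_context import BatchDescriptor

# Placeholder stand-in: maps active LoRA IDs to their requests.
lora_id_to_lora_request = {1: "sql_adapter_request"}

batch_descriptor = BatchDescriptor(
    num_tokens=64,  # padded token count, placeholder
    uniform_decode=True,
    has_lora=len(lora_id_to_lora_request) > 0,
)
# The dispatcher then looks this descriptor up against the keys created in
# initialize_cudagraph_keys().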

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 14 additions & 3 deletions
@@ -120,7 +120,10 @@ def maybe_setup_dummy_loras(
 
     @contextmanager
     def maybe_select_dummy_loras(
-        self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
+        self,
+        lora_config: LoRAConfig | None,
+        num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
     ):
         if lora_config is None:
             yield
@@ -133,7 +136,12 @@ def maybe_select_dummy_loras(
 
         # Make prompt lora mapping
         # Assign LoRA IDs cyclically to simulate a worst-case scenario.
-        prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
+        if activate_lora:
+            prompt_lora_mapping = (
+                np.arange(num_reqs, dtype=np.int32) % num_loras
+            ) + 1
+        else:
+            prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
 
         # Make token lora mapping
         token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@@ -159,11 +167,14 @@ def maybe_dummy_run_with_lora(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
-            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
+            self.maybe_select_dummy_loras(
+                lora_config, num_scheduled_tokens, activate_lora
+            ),
         ):
             yield