[V1] LoRA Support #10957
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
vllm/v1/engine/detokenizer.py
Outdated
tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)
vllm/v1/worker/gpu_model_runner.py
Outdated
@@ -602,269 +633,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
if batch_size <= size:
    return size
return None
Refactor: moved CachedRequestState and InputBatch to input_batch.py. It looked like a good refactor to reduce file size. In this PR it lets both gpu_model_runner.py and lora_model_runner_mixin.py import these data structures from input_batch.py.
vllm/v1/worker/input_batch.py
Outdated
max_num_logprobs=self.max_num_logprobs,
)

def make_lora_inputs(self, num_scheduled_tokens: np.array) \
Added for LoRA
Thanks for doing this! Left a few early comments. Will look into more details later.
vllm/v1/core/scheduler.py
Outdated
if self.lora_config:
    requested_loras = \
        set(req.lora_request.lora_int_id \
            for req in scheduled_running_reqs \
            if req.lora_request and \
            req.lora_request.lora_int_id > 0)
    assert len(requested_loras) <= self.lora_config.max_loras
Can we cache this state and incrementally update it whenever new request joins or finishes?
I explored this a bit. Tracking additions to and deletions from the running queue is hard in the current code, because the updates happen in more than one place (new requests, finished requests, and requests moving between the running and preempted states). One way is to replace each append/remove/pop with
self.running.<operation>()
if lora_config:
    update_active_loras()
A better way is to subclass list and update the active LoRAs after any create, update, or delete operation. That is a considerable change. I believe we can do it after some profiling shows how costly this code is.
For the moment, I think this localized update is nicer, as it doesn't introduce a bunch of if self.lora_config checks.
Is there a better way I am missing?
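For illustration, the list-subclass idea could look roughly like the sketch below. It is not code from this PR: FakeRequest, LoRATrackingList, and active_loras are hypothetical names, and only append/remove/pop are covered, so any other mutation of the running queue (extend, slicing, clear) would need the same treatment.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    """Hypothetical stand-in for a scheduler request."""
    request_id: str
    lora_int_id: Optional[int] = None  # None or 0 means "no LoRA"


class LoRATrackingList(list):
    """list subclass that reference-counts the LoRA ids of its elements.

    Assumption: the queue is only mutated through append/remove/pop.
    """

    def __init__(self) -> None:
        super().__init__()
        self.lora_refcount: Counter = Counter()

    def _track(self, req: FakeRequest, delta: int) -> None:
        if not req.lora_int_id:
            return
        self.lora_refcount[req.lora_int_id] += delta
        if self.lora_refcount[req.lora_int_id] <= 0:
            del self.lora_refcount[req.lora_int_id]

    def append(self, req: FakeRequest) -> None:
        super().append(req)
        self._track(req, +1)

    def remove(self, req: FakeRequest) -> None:
        super().remove(req)
        self._track(req, -1)

    def pop(self, index: int = -1) -> FakeRequest:
        req = super().pop(index)
        self._track(req, -1)
        return req

    @property
    def active_loras(self) -> set:
        # LoRA ids referenced by at least one request in the list.
        return set(self.lora_refcount)


running = LoRATrackingList()
running.append(FakeRequest("a", lora_int_id=1))
running.append(FakeRequest("b", lora_int_id=2))
running.pop()  # removes request "b", so LoRA 2 is no longer active
```

The trade-off is the one raised above: the bookkeeping becomes O(1) per queue operation, at the cost of a custom container that every code path touching the queue has to respect.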
vllm/v1/worker/input_batch.py
Outdated
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
    req_lora_mapping.repeat(num_scheduled_tokens))

active_lora_ids: set[int] = set(np.unique(req_lora_mapping))
active_lora_requests: set[LoRARequest] = \
    set({lr for lr in self.lora_requests \
         if lr.lora_int_id in active_lora_ids})
# Update lora requests
self.lora_requests = active_lora_requests

return prompt_lora_mapping, token_lora_mapping, self.lora_requests
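As a reading aid for the snippet above, here is a small standalone example of how the per-request mapping expands into a per-token mapping with np.repeat. The batch values are made up, and .tolist() is used only so the results are plain Python ints; the PR code builds the tuples directly from the NumPy arrays.

```python
import numpy as np

# Hypothetical batch: three requests using LoRA ids 1, 0 (no LoRA) and 2,
# with 2, 1 and 3 tokens scheduled in this step.
req_lora_mapping = np.array([1, 0, 2], dtype=np.int32)
num_scheduled_tokens = np.array([2, 1, 3], dtype=np.int32)

# One entry per request.
prompt_lora_mapping = tuple(req_lora_mapping.tolist())

# One entry per scheduled token: each request's LoRA id is repeated
# once for every token scheduled for that request.
token_lora_mapping = tuple(
    req_lora_mapping.repeat(num_scheduled_tokens).tolist())

# Distinct LoRA ids present in the batch (0 included when some request
# has no LoRA).
active_lora_ids = set(np.unique(req_lora_mapping).tolist())
```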
How does this work with punica kernels?
We always use the punica SGMV kernel, as set in
lora_mapping = LoRAMapping(token_lora_mapping,
The SGMV kernel code path merges sequences that have the same LoRA ID together in compute_meta (line 28 at commit 7406274):
def compute_meta(
I'll profile with both the SGMV and BGMV kernels and choose the better one. For now, SGMV looked like a good default/placeholder.
Regarding V0 LoRA: SGMV implements grouped GEMM, which performs better in the prefill stage. BGMV implements grouped GEMV, which is better optimized for the decoding stage. If only one can be chosen, SGMV is likely more suitable.
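To make the GEMV/GEMM distinction concrete, here is a NumPy reference of the arithmetic only. It is not the punica kernels and says nothing about their performance; the shapes, the names lora_a/lora_b, and both helper functions are illustrative assumptions. The per-token version does one matrix-vector product per token (the decode-friendly shape), while the segmented version first merges contiguous tokens that share a LoRA id and does one matrix-matrix product per segment (the prefill-friendly shape).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, rank, num_loras = 8, 4, 3

# One (A, B) pair per LoRA id. Shapes are illustrative only.
lora_a = rng.standard_normal((num_loras, hidden, rank))
lora_b = rng.standard_normal((num_loras, rank, hidden))

x = rng.standard_normal((6, hidden))         # 6 tokens
token_lora_ids = np.array([0, 0, 0, 2, 2, 1])  # LoRA id per token


def per_token_gemv(x: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """One vector-matrix product chain per token."""
    out = np.empty_like(x)
    for t, i in enumerate(ids):
        out[t] = x[t] @ lora_a[i] @ lora_b[i]
    return out


def segmented_gemm(x: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """One matrix-matrix product chain per run of equal LoRA ids."""
    out = np.empty_like(x)
    # Segment starts are the positions where the LoRA id changes.
    starts = np.flatnonzero(np.r_[True, ids[1:] != ids[:-1]])
    ends = np.r_[starts[1:], len(ids)]
    for s, e in zip(starts, ends):
        i = ids[s]
        out[s:e] = x[s:e] @ lora_a[i] @ lora_b[i]
    return out
```

Both functions compute the same values; they differ only in how the work is batched, which is where the prefill versus decode trade-off comes from.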
d21df49 to 797dab2
This pull request has merge conflicts that must be resolved before it can be
d4d70cc to 550da53
51ef92a to 3200ed4
This pull request has merge conflicts that must be resolved before it can be
3200ed4 to 48e9185
logits = lm_head.linear_method.apply(lm_head,
                                     hidden_states,
                                     bias=embedding_bias)
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
Refactor: introduce _gather_logits(), which LogitsProcessorWithLoRA also uses.
return [request.lora_request.lora_int_id]


def generate_block_hash_extra_keys(
Refactor for using prefix caching with LoRA.
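The reason a LoRA id belongs in the block hash can be shown with a toy example. This is a sketch under assumptions, not the PR's generate_block_hash_extra_keys: block_hash and lora_extra_keys are hypothetical helpers, and SHA-256 over a repr is chosen here only for determinism. The point is that two requests with identical tokens but different adapters must not share cached KV blocks, so the adapter id is mixed into the hash as an extra key.

```python
import hashlib
from typing import List, Optional, Sequence


def lora_extra_keys(lora_int_id: Optional[int]) -> List[int]:
    """Hypothetical helper: the LoRA id as an extra hash key, if any."""
    return [lora_int_id] if lora_int_id else []


def block_hash(parent_hash: Optional[str],
               token_ids: Sequence[int],
               extra_keys: Sequence[int] = ()) -> str:
    """Toy block hash chained on the parent block's hash.

    The extra keys are part of the hashed payload, so the same tokens
    hash differently under different LoRA adapters.
    """
    payload = repr((parent_hash, tuple(token_ids), tuple(extra_keys)))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


tokens = [101, 7592, 2088, 102]
base = block_hash(None, tokens, lora_extra_keys(None))
with_lora_1 = block_hash(None, tokens, lora_extra_keys(1))
with_lora_2 = block_hash(None, tokens, lora_extra_keys(2))
```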
del hidden_states, logits
self.encoder_cache.clear()
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
Set up num_scheduled_tokens for initializing LoRA in profile_run. @ywang96, will this change interfere with the multimodal setup above? Can you point me to a test or command I can run to confirm that it works? Thanks.
bump.
I'd like some review on this part please.
@varun-sundar-rabindranath Hey, sorry for the delayed review, but this should be okay since you're just moving self.encoder_cache.clear() later.
v1/core LGTM
This pull request has merge conflicts that must be resolved before it can be
d04d56d to 4fc158c
tests/lora/test_minicpmv.py
Outdated
# test in a package
pass


@pytest.mark.xfail(
minicpmv does not support V1 yet; see https://docs.vllm.ai/en/latest/models/supported_models.html#id3
I see. Thanks for the callout 👍. I was hoping to catch these errors when the PR goes /ready.
b57ca04 to 5fc59ef
6576b44 to 6e81bd8
@varun-sundar-rabindranath Thanks for doing this! The code looks very clean to me. Left some minor comments and questions. Please take a look!
@@ -182,6 +181,14 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
nit: why don't we cache this state and update it incrementally?
Keeping a cached requested_loras was more invasive and quite cumbersome. Incremental updates would require touching this state whenever a request moves into or out of the running queue. The running queue is updated in many places in the file, and tacking an update to requested_loras onto each of them was cumbersome and seemed bug-prone.
The idea was to keep the LoRA changes localized and to optimize later if necessary.
vllm/v1/worker/gpu_input_batch.py
Outdated
# only update request_lora_mapping. Defer the updates
# to lora_requests to prepare_lora_inputs.
Why do we do so? I think we can maintain an inverse index like Dict: Lora_id --> Set[Request id]?
I did not want to introduce too many data structures, and lora_requests was used only in prepare_lora_inputs.
I have updated the code to include lora_id_to_lora_request and lora_id_to_request_ids dicts to track removals properly. This is probably better for consistency / guarantees.
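For readers following along, the bookkeeping with the two dicts might look like the sketch below. Only the dict names lora_id_to_lora_request and lora_id_to_request_ids come from this thread; the LoRAIndex class, its methods, and FakeLoRARequest are hypothetical and simplified, not the code in the PR.

```python
from dataclasses import dataclass
from typing import Dict, Set


@dataclass(frozen=True)
class FakeLoRARequest:
    """Hypothetical stand-in for LoRARequest (hashable, so it fits in a set)."""
    lora_name: str
    lora_int_id: int


class LoRAIndex:
    """Inverse index from LoRA id to the requests that use it."""

    def __init__(self) -> None:
        self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
        self.lora_id_to_lora_request: Dict[int, FakeLoRARequest] = {}

    def add_request(self, req_id: str, lora_request: FakeLoRARequest) -> None:
        lora_id = lora_request.lora_int_id
        self.lora_id_to_request_ids.setdefault(lora_id, set()).add(req_id)
        self.lora_id_to_lora_request[lora_id] = lora_request

    def remove_request(self, req_id: str, lora_id: int) -> None:
        req_ids = self.lora_id_to_request_ids.get(lora_id)
        if req_ids is None:
            return
        req_ids.discard(req_id)
        # Drop the LoRA entirely once its last request is gone.
        if not req_ids:
            del self.lora_id_to_request_ids[lora_id]
            del self.lora_id_to_lora_request[lora_id]

    def active_lora_requests(self) -> Set[FakeLoRARequest]:
        return set(self.lora_id_to_lora_request.values())
```

With this shape, removing a request is a constant-time update and the set of active LoRA requests never needs to be recomputed from the whole batch.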
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
3ba33fd to 7037c91
Changes:
Benchmarks:
Machine: 1xA100
V1 throughput: 2.42 requests/s, 1225.95 total tokens/s, 628.29 output tokens/s
V0 throughput: 5.95 requests/s, 3021.90 total tokens/s, 1548.71 output tokens/s
The performance gap between V0 and V1 is due to CUDA Graphs. Refer to the benchmarks in reference PR #11613.