[V1] LoRA Support #10957
Open

varun-sundar-rabindranath wants to merge 1 commit into vllm-project:main from neuralmagic:varun/v1-lora-support-attempt-2
Changes from all commits
```diff
@@ -169,14 +169,28 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
         return ret
 
 
-def generate_block_hash_extra_keys(
-        request: Request, start_token_idx: int, end_token_idx: int,
-        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
-    """Generate extra keys for the block hash. The extra keys can come from
-    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
-    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
-    indicate a mm input contained in the block and its starting offset in
-    the block tokens.
+def need_extra_keys(request: Request) -> bool:
+    """Check whether the blocks allocated to this request need extra hash keys.
+
+    Args:
+        request (Request): The request.
+
+    Returns:
+        bool: Whether blocks allocated to this request need extra hash keys.
+    """
+
+    # Multimodal requests need to include the MM hash.
+    # LoRA requests need to include the LoRA ID.
+    return bool(request.mm_positions) or (request.lora_request is not None)
+
+
+def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
+                            end_token_idx: int,
+                            start_mm_idx: int) -> Tuple[List[Any], int]:
+    """Generate extra keys related to MultiModal request for block hash
+    computation. For multi-modal inputs, the extra keys are
+    (mm_hash, start_offset) that indicate a mm input contained in the
+    block and its starting offset in the block tokens.
 
     Args:
         request: The request object.
```
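Why the extra keys matter: block hashes are what prefix caching uses to decide that a cached KV block can be reused, so any per-request state that changes a block's contents (a LoRA adapter, a multi-modal input) must be folded into the hash. Below is a minimal, self-contained sketch of the failure mode this prevents; `sketch_hash_block` and `sketch_hash_request` are illustrative stand-ins, not vLLM's actual `BlockHash` machinery.

```python
# Minimal sketch of chained prefix-cache block hashing (illustrative only).
# Each block's hash covers the parent hash, the block's token IDs, and any
# "extra keys" (e.g. a LoRA ID).
from typing import Any, List, Optional, Tuple


def sketch_hash_block(parent: Optional[int], tokens: Tuple[int, ...],
                      extra_keys: Optional[Tuple[Any, ...]]) -> int:
    return hash((parent, tokens, extra_keys))


def sketch_hash_request(token_ids: List[int], block_size: int,
                        lora_id: Optional[int]) -> List[int]:
    extra = (lora_id,) if lora_id is not None else None
    parent, hashes = None, []
    # Hash only full blocks, chaining each hash into the next.
    for i in range(0, len(token_ids) - len(token_ids) % block_size,
                   block_size):
        parent = sketch_hash_block(parent, tuple(token_ids[i:i + block_size]),
                                   extra)
        hashes.append(parent)
    return hashes


prompt = list(range(32))
base = sketch_hash_request(prompt, 16, lora_id=None)
sql_lora = sketch_hash_request(prompt, 16, lora_id=1)
# Same tokens, different adapter -> different block hashes, so KV blocks
# produced under one LoRA are never served to another request without it.
assert base != sql_lora
```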
```diff
@@ -187,10 +201,11 @@ def generate_block_hash_extra_keys(
     Returns:
         A tuple of extra keys and the next multi-modal index.
     """
+    extra_keys: List[Any] = []
 
     mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
     if not mm_positions:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx
 
     if mm_positions and len(mm_positions) != len(mm_hashes):
         raise ValueError(
```
```diff
@@ -203,14 +218,13 @@ def generate_block_hash_extra_keys(
     # range. This usually happens in the late prefill phase and decoding phase.
     if mm_positions[-1]["offset"] + mm_positions[-1][
             "length"] < start_token_idx:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx
 
     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
         assert -start_mm_idx <= len(mm_positions)
         start_mm_idx = len(mm_positions) + start_mm_idx
 
-    extra_keys = []
     curr_mm_idx = start_mm_idx
     while mm_positions and curr_mm_idx < len(mm_positions):
         assert mm_hashes[curr_mm_idx] is not None
```
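A small point on the `None` → empty-list change in the two hunks above: returning `[]` lets the new top-level `generate_block_hash_extra_keys` combine MM and LoRA keys by plain concatenation, with no `None` guards. An illustrative snippet (not from the diff):

```python
# Illustrative only: [] composes by concatenation; None would force a branch.
lora_keys = [7]   # what _gen_lora_extra_hash_keys yields for a LoRA request
mm_keys = []      # _gen_mm_extra_hash_keys now yields [] when no MM input
combined = lora_keys + mm_keys  # always valid
# With the old None return, every caller would instead need something like:
# combined = lora_keys + (mm_keys or [])
```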
```diff
@@ -236,7 +250,50 @@ def generate_block_hash_extra_keys(
         else:
             # This block has not reached the current mm input.
             break
-    return tuple(extra_keys), curr_mm_idx
+    return extra_keys, curr_mm_idx
+
+
+def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
+    """Generate extra keys related to LoRA for block hash computation.
+
+    Args:
+        request: The request object.
+
+    Returns:
+        Return LoRA id of the request if it is a LoRA request. Return empty
+        list otherwise.
+    """
+    if not request.lora_request:
+        return []
+    return [request.lora_request.lora_int_id]
+
+
+def generate_block_hash_extra_keys(
+        request: Request, start_token_idx: int, end_token_idx: int,
+        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
+    """Generate extra keys for the block hash. The extra keys can come from
+    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
+
+    Args:
+        request: The request object.
+        start_token_idx: The start token index of the block.
+        end_token_idx: The end token index of the block.
+        start_mm_idx: The start multi-modal index of the block.
+
+    Returns:
+        A tuple of extra keys and the next multi-modal index.
+    """
+    mm_extra_keys: List[Any]
+    mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
+        request, start_token_idx, end_token_idx, start_mm_idx)
+    lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)
+
+    extra_keys: List[Any] = lora_extra_keys + mm_extra_keys
+
+    if not extra_keys:
+        return None, new_start_mm_idx
+
+    return tuple(extra_keys), new_start_mm_idx
 
 
 def hash_block_tokens(
```

Review comment (on `generate_block_hash_extra_keys`): Refactor for using prefix caching with LoRA.
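The composition puts LoRA keys ahead of MM keys and still returns `None` when a request needs no extras, so hashes for plain text requests are unchanged. A sketch of the three cases on stand-in objects; `FakeRequest` and `FakeLoRARequest` are hypothetical stand-ins for vLLM's `Request` and `LoRARequest`, and the MM keys are simplified to bare hashes without offsets:

```python
# Sketch of the new composition logic on stand-in types (illustrative only).
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple


@dataclass
class FakeLoRARequest:
    lora_int_id: int


@dataclass
class FakeRequest:
    mm_positions: List[dict] = field(default_factory=list)
    mm_hashes: List[str] = field(default_factory=list)
    lora_request: Optional[FakeLoRARequest] = None


def extra_keys_for(req: FakeRequest) -> Optional[Tuple[Any, ...]]:
    # Mirrors generate_block_hash_extra_keys: lora + mm keys, or None.
    lora_keys: List[Any] = ([req.lora_request.lora_int_id]
                            if req.lora_request else [])
    mm_keys: List[Any] = list(req.mm_hashes)  # simplified: no offsets
    keys = lora_keys + mm_keys
    return tuple(keys) if keys else None


assert extra_keys_for(FakeRequest()) is None                    # plain request
assert extra_keys_for(
    FakeRequest(lora_request=FakeLoRARequest(7))) == (7,)       # LoRA only
assert extra_keys_for(
    FakeRequest(mm_hashes=["img0"],
                lora_request=FakeLoRARequest(7))) == (7, "img0")  # both
```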
```diff
@@ -248,9 +305,6 @@ def hash_block_tokens(
     prefix caching. We use LRU cache for this function to avoid recomputing
     hash values for the same block contents.
 
-    TODO: Support arbitrary metadata so that we could support more
-    features such as LoRA adapter.
-
     Args:
         parent_block_hash: The hash of the parent block. None
             if this is the first block.
```
```diff
@@ -279,14 +333,9 @@ def hash_request_tokens(block_size: int,
         The list of computed hash values.
     """
     token_ids = request.all_token_ids
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match.")
 
-    # TODO: Extend this to support other features such as LoRA.
-    need_extra_keys = bool(mm_positions)
-    extra_keys = None
+    req_need_extra_keys = need_extra_keys(request)
+    req_extra_keys = None
     curr_mm_idx = 0
 
     ret = []
```
```diff
@@ -298,13 +347,13 @@ def hash_request_tokens(block_size: int,
         if len(block_token_ids) < block_size:
             break
 
-        # Add extra keys if the block is a multi-modal block.
-        if need_extra_keys:
-            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+        if req_need_extra_keys:
+            # MM and LoRA requests need extra keys for block-hash computation.
+            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                 request, start, end, curr_mm_idx)
 
         block_hash = hash_block_tokens(parent_block_hash_value,
-                                       block_token_ids, extra_keys)
+                                       block_token_ids, req_extra_keys)
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
```
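End to end, the effect of this change is that V1's prefix caching no longer has to be bypassed for LoRA requests. A hedged usage sketch: `LLM`, `SamplingParams`, `LoRARequest`, `enable_lora`, and `enable_prefix_caching` are existing vLLM APIs, while the model name and adapter path are placeholders and `VLLM_USE_V1` selects the V1 engine this PR targets.

```python
import os
os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True,
          enable_prefix_caching=True)

shared_prefix = "You are a helpful assistant. " * 8
outputs = llm.generate(
    [shared_prefix + "Translate to SQL: list all users."],
    SamplingParams(max_tokens=64),
    # Cached blocks for this request are keyed on lora_int_id=1, so a request
    # using a different adapter (or none) will not reuse them.
    lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql_lora"),
)
```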
Review comment: Refactor: introduce `_gather_logits()` that `LogitsProcessorWithLoRA` also uses.
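The `_gather_logits()` helper the comment refers to lives in vLLM's logits-processor layer; below is only a rough sketch of the shape of such a refactor (a hypothetical simplification with a single-GPU stand-in for the tensor-parallel gather), not the PR's actual code:

```python
import torch


class LogitsProcessorSketch:
    """Hypothetical stand-in for vLLM's LogitsProcessor."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # In vLLM this would gather logits across tensor-parallel ranks;
        # single-GPU stand-in: identity.
        return logits

    def _get_logits(self, hidden: torch.Tensor,
                    head: torch.nn.Module) -> torch.Tensor:
        logits = head(hidden)
        logits = self._gather_logits(logits)
        return logits[..., :self.vocab_size]


class LogitsProcessorWithLoRASketch(LogitsProcessorSketch):
    """Hypothetical stand-in for LogitsProcessorWithLoRA."""

    def _get_logits(self, hidden: torch.Tensor,
                    head: torch.nn.Module) -> torch.Tensor:
        logits = head(hidden)
        # Reuse the shared gather instead of duplicating the TP logic here.
        logits = self._gather_logits(logits)
        # ... LoRA-specific extra-vocab and adapter logits would follow ...
        return logits[..., :self.vocab_size]
```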