
Commit

Add lora support
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Varun Sundar Rabindranath committed Dec 17, 2024
1 parent f9ecbb1 commit 569fb69
Showing 19 changed files with 401 additions and 40 deletions.
17 changes: 17 additions & 0 deletions tests/lora/conftest.py
@@ -269,3 +269,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
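The test modules below mark V1-incompatible cases with `@pytest.mark.skip_v1`, which the fixture above looks up via `get_closest_marker("skip_v1")`. If the marker is not already declared in the project's pytest configuration, it can be registered in this conftest; a minimal sketch (hypothetical placement — the repository may declare it in pytest.ini or pyproject.toml instead):

```python
# Hypothetical registration of the custom marker used by the fixture above;
# the repository may already declare it in its pytest configuration instead.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "skip_v1: skip this test when running against the vLLM V1 engine")
```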
9 changes: 9 additions & 0 deletions tests/lora/test_baichuan.py
@@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
@@ -62,6 +70,7 @@ def test_baichuan_lora(baichuan_lora_files):
assert output2[i] == expected_lora_output[i]


@pytest.mark.skip_v1
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
num_gpus_available, fully_sharded):
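As the comment in the wrapper fixture notes, the per-file `v1` fixture can be promoted to a package-level conftest.py so that every test in the package runs under both engines without repeating the boilerplate. A minimal sketch of that promotion (hypothetical — this commit keeps the wrapper in each test module):

```python
# tests/lora/conftest.py -- hypothetical package-level promotion of the
# per-file wrapper used by the test modules in this commit.
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Autouse at the package level: every test in tests/lora is collected
    # once per engine parameter of run_with_both_engines_lora.
    pass
```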
12 changes: 12 additions & 0 deletions tests/lora/test_chatglm3_tp.py
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
@@ -64,6 +74,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -85,6 +96,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
@@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
12 changes: 12 additions & 0 deletions tests/lora/test_llama_tp.py
@@ -1,5 +1,6 @@
from typing import List

import pytest
import ray

import vllm
@@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

@@ -111,6 +120,7 @@ def get_num_gpu_blocks_no_lora():
"less when using lora than when not using lora")


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4(sql_lora_files):
@@ -126,6 +136,7 @@ def test_llama_lora_tp4(sql_lora_files):
generate_and_test(llm, sql_lora_files)


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
@@ -142,6 +153,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
generate_and_test(llm, sql_lora_files)


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
8 changes: 8 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
8 changes: 8 additions & 0 deletions tests/lora/test_minicpmv.py
@@ -56,6 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
10 changes: 10 additions & 0 deletions tests/lora/test_phi.py
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI;
# otherwise, the lora-test will fail due to CUDA OOM.
9 changes: 9 additions & 0 deletions tests/lora/test_quant_model.py
@@ -68,6 +68,14 @@ def format_prompt_tuples(prompt):
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
@@ -163,6 +171,7 @@ def expect_match(output, expected_output):
cleanup_dist_env_and_memory()


@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
model):
15 changes: 10 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
@@ -85,8 +85,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:

# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
metadata_hash = (None if not request.lora_request else
request.lora_request.lora_int_id)
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
request.all_token_ids,
parent_hash=metadata_hash)

for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@@ -377,12 +380,14 @@ def _cache_full_blocks(
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
metadata_hash = (None if request.lora_request is None else
request.lora_request.lora_int_id)
parent_hash = metadata_hash
if prev_block is not None:
# Previous block must have a block hash because it must be
# a full, cached block.
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
parent_hash = prev_block.block_hash.hash_value

for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
Expand All @@ -397,9 +402,9 @@ def _cache_full_blocks(
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
block_hash = hash_block_tokens(parent_hash, block_tokens)

# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value
parent_hash = block_hash.hash_value
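A minimal sketch, with stand-in request and block types, of how the hash chain in `_cache_full_blocks` is now seeded: the LoRA integer ID (or None without an adapter) when there is no previous cached block, and the previous block's hash value otherwise:

```python
# Sketch only: FakeLoRARequest / FakeRequest / FakeBlockHash are stand-ins
# for the real vLLM types, used to illustrate how the seed hash is chosen.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeLoRARequest:
    lora_int_id: int


@dataclass
class FakeRequest:
    lora_request: Optional[FakeLoRARequest]


@dataclass
class FakeBlockHash:
    hash_value: int


def initial_parent_hash(request: FakeRequest,
                        prev_block_hash: Optional[FakeBlockHash]) -> Optional[int]:
    # No previous cached block: seed with the LoRA ID (or None without LoRA),
    # so blocks produced under different adapters never share a hash chain.
    metadata_hash = (None if request.lora_request is None else
                     request.lora_request.lora_int_id)
    if prev_block_hash is None:
        return metadata_hash
    # Otherwise continue the existing chain from the previous full block.
    return prev_block_hash.hash_value


assert initial_parent_hash(FakeRequest(None), None) is None
assert initial_parent_hash(FakeRequest(FakeLoRARequest(3)), None) == 3
assert initial_parent_hash(FakeRequest(FakeLoRARequest(3)), FakeBlockHash(42)) == 42
```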
35 changes: 18 additions & 17 deletions vllm/v1/core/kv_cache_utils.py
@@ -159,52 +159,53 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
return ret


def hash_block_tokens(parent_block_hash: Optional[int],
def hash_block_tokens(parent_hash: Optional[int],
curr_block_token_ids: Sequence[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
TODO: Support arbitrary metadata so that we could support more
features such as LoRA adapter.
"""Computes a hash value corresponding to the contents of a block, in
the context of the contents of the preceding block(s) and maybe also
some metadata. The hash value is used for prefix caching. We use LRU
cache for this function to avoid recomputing hash values for the same
block contents.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
parent_hash: The hash of the parent block if this is not the
first block. If it is the first block, parent hash could
be None or be the hash of some relevant metadata.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
return BlockHashType(hash((parent_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids))


def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]:
def hash_request_tokens(
block_size: int,
token_ids: Sequence[int],
parent_hash: Optional[int] = None) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
parent_hash: Seed hash value. For example, when using LoRA this is
the hash value of the LoRA ID.
Returns:
The list of computed hash values.
"""
ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
block_hash = hash_block_tokens(parent_hash, block_token_ids)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
parent_hash = block_hash.hash_value
return ret
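To see the effect of the new `parent_hash` seed, here is a small self-contained sketch that mirrors the two functions above (with `BlockHashType` replaced by a NamedTuple stand-in): hashing the same token sequence with no seed and with a LoRA ID as the seed produces disjoint block-hash chains, so prefix-cache entries are not shared across adapters:

```python
# Self-contained sketch mirroring hash_block_tokens / hash_request_tokens;
# BlockHashType here is a stand-in NamedTuple, not the real vLLM class.
from typing import List, NamedTuple, Optional, Sequence


class BlockHashType(NamedTuple):
    hash_value: int
    token_ids: tuple


def hash_block_tokens(parent_hash: Optional[int],
                      curr_block_token_ids: Sequence[int]) -> BlockHashType:
    return BlockHashType(hash((parent_hash, *curr_block_token_ids)),
                         tuple(curr_block_token_ids))


def hash_request_tokens(block_size: int,
                        token_ids: Sequence[int],
                        parent_hash: Optional[int] = None) -> List[BlockHashType]:
    ret = []
    for start in range(0, len(token_ids), block_size):
        block_token_ids = token_ids[start:start + block_size]
        if len(block_token_ids) < block_size:
            break  # never hash a partially filled block
        block_hash = hash_block_tokens(parent_hash, block_token_ids)
        ret.append(block_hash)
        parent_hash = block_hash.hash_value
    return ret


tokens = list(range(32))
base = hash_request_tokens(16, tokens)                   # no LoRA
lora_7 = hash_request_tokens(16, tokens, parent_hash=7)  # e.g. lora_int_id=7

# Same tokens, different seed: every block hash in the chain differs, so
# cached blocks computed under one adapter are not reused for another.
assert all(a.hash_value != b.hash_value for a, b in zip(base, lora_7))
```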
24 changes: 22 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -5,6 +5,7 @@

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams
@@ -32,8 +33,6 @@ def __init__(
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."

# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -173,6 +172,14 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
requested_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras

# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
Expand All @@ -184,6 +191,17 @@ def schedule(self) -> "SchedulerOutput":
break

request = self.waiting[0]

# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
req_lora_id = request.lora_request.lora_int_id
if len(requested_loras) == self.lora_config.max_loras and (
req_lora_id not in requested_loras):
# cannot schedule
break
requested_loras.add(req_lora_id)

# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
@@ -520,6 +538,7 @@ class NewRequestData:
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
lora_request: Optional[LoRARequest]

@classmethod
def from_request(
Expand All @@ -537,6 +556,7 @@ def from_request(
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
lora_request=request.lora_request,
)
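A minimal sketch, with stand-in request types, of the admission rule the scheduler now applies to WAITING requests: a request is only scheduled if its adapter is already active in the current step or an additional LoRA slot is still free under `max_loras`:

```python
# Sketch only: FakeLoRARequest / FakeRequest stand in for the real vLLM types.
from dataclasses import dataclass
from typing import Optional, Set


@dataclass
class FakeLoRARequest:
    lora_int_id: int


@dataclass
class FakeRequest:
    lora_request: Optional[FakeLoRARequest]


def can_schedule(request: FakeRequest, requested_loras: Set[int],
                 max_loras: int) -> bool:
    """Mirror of the max_loras check: scheduling the request must not push
    the number of distinct active adapters beyond max_loras."""
    if request.lora_request is None:
        return True  # no adapter needed, always admissible
    lora_id = request.lora_request.lora_int_id
    if lora_id in requested_loras:
        return True  # adapter already active in this step
    return len(requested_loras) < max_loras


active = {1, 2}  # adapters of the already-scheduled running requests
assert can_schedule(FakeRequest(FakeLoRARequest(2)), active, max_loras=2)
assert not can_schedule(FakeRequest(FakeLoRARequest(3)), active, max_loras=2)
assert can_schedule(FakeRequest(None), active, max_loras=2)
```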


