 import pytest
-import torch

 import vllm_gaudi.extension.environment as environment

@@ -44,17 +43,19 @@ def get_vllm_config():
     )
     return vllm_config

+
 @pytest.fixture
 def model_runner():
     vllm_config = get_vllm_config()
     model_config = vllm_config.model_config
     num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
     head_size = model_config.get_head_size()
-    environment.set_vllm_config(vllm_config)
+    environment.set_vllm_config(vllm_config)
     vllm_config.compilation_config.static_forward_context = {"layer.0": Attention(num_heads, head_size, 0.1)}
     runner = HPUModelRunner(vllm_config, DEVICE)
     return runner

+
 def make_new_request(req_id, prompt_token_ids, num_computed_tokens=0):
     return NewRequestData(
         req_id=req_id,
@@ -67,11 +68,13 @@ def make_new_request(req_id, prompt_token_ids, num_computed_tokens=0):
         lora_request=None,
     )

-@pytest.mark.parametrize("prompt1, prompt2, num_common_prefix, expected_tokens", [
-    ([1, 2, 3, 4], [1, 2, 3, 4], 4, 0),  # full prefix cache hit
-    ([1, 2, 3], [1, 2, 3, 6, 7], 3, 2)  # partial prefix cache hit (3 cached, 2 new)
-])

+@pytest.mark.parametrize(
+    "prompt1, prompt2, num_common_prefix, expected_tokens",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], 4, 0),  # full prefix cache hit
+        ([1, 2, 3], [1, 2, 3, 6, 7], 3, 2)  # partial prefix cache hit (3 cached, 2 new)
+    ])
 def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, expected_tokens, dist_init):
     req_id1 = "req1"
     req_id2 = "req2"
@@ -97,8 +100,7 @@ def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, ex
     assert cached_state.prompt_token_ids == prompt1
     assert cached_state.num_computed_tokens == 0
     assert req_id1 in model_runner.requests
-    assert sched_out1.num_scheduled_tokens[req_id1] == len(prompt1)
-
+    assert sched_out1.num_scheduled_tokens[req_id1] == len(prompt1)

     # Second request: full prefix cache hit or partial prefix cache hit
     new_req2 = make_new_request(req_id2, prompt2, num_computed_tokens=num_common_prefix)
@@ -119,13 +121,16 @@ def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, ex
     cached_state = model_runner.requests[req_id2]

     assert cached_state.prompt_token_ids == prompt2
-    assert cached_state.num_computed_tokens == num_common_prefix
+    assert cached_state.num_computed_tokens == num_common_prefix
     assert req_id2 in model_runner.requests
     assert sched_out2.num_scheduled_tokens[req_id2] == expected_tokens

-@pytest.mark.parametrize("prompt, cache_first, cache_second", [
-    ([10, 11, 12], 3, 0),  # first: all tokens cached, second: cache reset, all tokens need compute
-])
+
+@pytest.mark.parametrize(
+    "prompt, cache_first, cache_second",
+    [
+        ([10, 11, 12], 3, 0),  # first: all tokens cached, second: cache reset, all tokens need compute
+    ])
 def test_prefix_cache_reset(model_runner, prompt, cache_first, cache_second, dist_init):
     req_id = "req_reset"
     new_req_1 = make_new_request(req_id, prompt, num_computed_tokens=cache_first)