
Commit 925eed9

Pre-commit fix
Signed-off-by: Izabela Irzynska <[email protected]>
1 parent 3a0d9ee commit 925eed9

File tree

1 file changed: +17, -12 lines


tests/unit_tests/prefix_caching/test_prefix_caching.py

Lines changed: 17 additions & 12 deletions
@@ -1,5 +1,4 @@
 import pytest
-import torch
 
 import vllm_gaudi.extension.environment as environment
 
@@ -44,17 +43,19 @@ def get_vllm_config():
     )
     return vllm_config
 
+
 @pytest.fixture
 def model_runner():
     vllm_config = get_vllm_config()
     model_config = vllm_config.model_config
     num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
     head_size = model_config.get_head_size()
-    environment.set_vllm_config(vllm_config)
+    environment.set_vllm_config(vllm_config)
     vllm_config.compilation_config.static_forward_context = {"layer.0": Attention(num_heads, head_size, 0.1)}
     runner = HPUModelRunner(vllm_config, DEVICE)
     return runner
 
+
 def make_new_request(req_id, prompt_token_ids, num_computed_tokens=0):
     return NewRequestData(
         req_id=req_id,
@@ -67,11 +68,13 @@ def make_new_request(req_id, prompt_token_ids, num_computed_tokens=0):
         lora_request=None,
     )
 
-@pytest.mark.parametrize("prompt1, prompt2, num_common_prefix, expected_tokens", [
-    ([1, 2, 3, 4], [1, 2, 3, 4], 4, 0), # full prefix cache hit
-    ([1, 2, 3], [1, 2, 3, 6, 7], 3, 2) # partial prefix cache hit (3 cached, 2 new)
-])
 
+@pytest.mark.parametrize(
+    "prompt1, prompt2, num_common_prefix, expected_tokens",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], 4, 0), # full prefix cache hit
+        ([1, 2, 3], [1, 2, 3, 6, 7], 3, 2) # partial prefix cache hit (3 cached, 2 new)
+    ])
 def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, expected_tokens, dist_init):
     req_id1 = "req1"
     req_id2 = "req2"
@@ -97,8 +100,7 @@ def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, expected_tokens, dist_init):
     assert cached_state.prompt_token_ids == prompt1
     assert cached_state.num_computed_tokens == 0
     assert req_id1 in model_runner.requests
-    assert sched_out1.num_scheduled_tokens[req_id1] == len(prompt1)
-
+    assert sched_out1.num_scheduled_tokens[req_id1] == len(prompt1)
 
     # Second request: full prefix cache hit or partial prefix cache hit
     new_req2 = make_new_request(req_id2, prompt2, num_computed_tokens=num_common_prefix)
@@ -119,13 +121,16 @@ def test_prefix_cache_hits(model_runner, prompt1, prompt2, num_common_prefix, expected_tokens, dist_init):
     cached_state = model_runner.requests[req_id2]
 
     assert cached_state.prompt_token_ids == prompt2
-    assert cached_state.num_computed_tokens == num_common_prefix
+    assert cached_state.num_computed_tokens == num_common_prefix
     assert req_id2 in model_runner.requests
     assert sched_out2.num_scheduled_tokens[req_id2] == expected_tokens
 
-@pytest.mark.parametrize("prompt, cache_first, cache_second", [
-    ([10, 11, 12], 3, 0), # first: all tokens cached, second: cache reset, all tokens need compute
-])
+
+@pytest.mark.parametrize(
+    "prompt, cache_first, cache_second",
+    [
+        ([10, 11, 12], 3, 0), # first: all tokens cached, second: cache reset, all tokens need compute
+    ])
 def test_prefix_cache_reset(model_runner, prompt, cache_first, cache_second, dist_init):
     req_id = "req_reset"
     new_req_1 = make_new_request(req_id, prompt, num_computed_tokens=cache_first)
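Editor's note, not part of the commit: the parametrized cases above rely on a simple piece of arithmetic, namely that a prefix-cache hit leaves len(prompt) - num_common_prefix tokens to schedule. The snippet below is a minimal standalone sketch of that expectation; the helper name is hypothetical and does not correspond to any vllm_gaudi API.

    # Sketch only: mirrors the expected_tokens arithmetic in the test cases above.
    def tokens_left_to_schedule(prompt_token_ids, num_cached_prefix_tokens):
        # A full hit leaves nothing to compute; a partial hit leaves only the uncached suffix.
        return len(prompt_token_ids) - num_cached_prefix_tokens

    assert tokens_left_to_schedule([1, 2, 3, 4], 4) == 0   # full prefix cache hit
    assert tokens_left_to_schedule([1, 2, 3, 6, 7], 3) == 2  # partial hit: 3 cached, 2 new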
