Reuse KV cache of prefixes #484

Draft · wants to merge 11 commits into base: main
6 changes: 6 additions & 0 deletions mii/batching/data_classes.py
@@ -134,6 +134,12 @@ def is_done(self, is_done: bool) -> None:
     def generated_tokens(self) -> List[torch.Tensor]:
         return self._generated_tokens

+    @property
+    def all_tokens(self) -> torch.Tensor:
+        return torch.cat([self.prompt_tokens] +
+                         [t.unsqueeze(0) for t in self.generated_tokens],
+                         dim=0)
+
     @property
     def finish_reason(self) -> GenerationFinishReason:
         return self._finish_reason
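Note (not part of the diff): a minimal standalone sketch of what the new all_tokens property concatenates, assuming prompt_tokens is a 1-D tensor of token ids and generated_tokens is a list of scalar (0-D) tensors, one per generated token. All values are made up.

import torch

# Hypothetical request state: a 3-token prompt and two generated tokens.
prompt_tokens = torch.tensor([101, 2023, 2003])
generated_tokens = [torch.tensor(2307), torch.tensor(102)]

# Same expression as the new property: unsqueeze each scalar to 1-D, then concatenate.
all_tokens = torch.cat([prompt_tokens] +
                       [t.unsqueeze(0) for t in generated_tokens],
                       dim=0)
print(all_tokens)  # tensor([ 101, 2023, 2003, 2307,  102])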
39 changes: 31 additions & 8 deletions mii/batching/ragged_batching.py
@@ -82,6 +82,8 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
         self.socket.setsockopt(zmq.RCVTIMEO, ZMQ_RECV_TIMEOUT)

+        self.enable_prefix_cache = self.inference_engine._config.enable_prefix_cache
+
     @cached_property
     def local_rank(self) -> int:
         return get_accelerator().current_device()
@@ -122,6 +124,9 @@ def generate(self) -> None:
         # 5. Schedule requests while we wait for the forward pass to finish
         self._reset_scheduler_bookkeeping()

+        for r in running_requests:
+            self.inference_engine.update_prefix_cache(r.uid, r.all_tokens)
+
         # 6. Accumulate generated tokens, check completion, and generate output
         for r in running_requests.last_in_prompt:
             r.accumulate_generated_token()
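Note (not part of the diff): update_prefix_cache(r.uid, r.all_tokens) is called here for every running request so the engine can remember which token prefixes already hold KV blocks. The engine-side implementation lives in the companion DeepSpeed change and is not shown in this PR; the sketch below is only a toy, dictionary-based illustration of that kind of bookkeeping. ToyPrefixCache, BLOCK_SIZE, and the explicit block_ids argument are assumptions for the example, not the real API (the real call passes a uid and lets the engine resolve that sequence's blocks itself).

from typing import Dict, List, Tuple

BLOCK_SIZE = 4  # illustrative KV-block granularity

class ToyPrefixCache:
    """Toy mapping from block-aligned token prefixes to their KV block ids."""

    def __init__(self) -> None:
        self._blocks: Dict[Tuple[int, ...], List[int]] = {}

    def update(self, tokens: List[int], block_ids: List[int]) -> None:
        # Register every fully filled, block-aligned prefix of `tokens`.
        for n_blocks in range(1, len(tokens) // BLOCK_SIZE + 1):
            prefix = tuple(tokens[:n_blocks * BLOCK_SIZE])
            self._blocks[prefix] = block_ids[:n_blocks]

    def lookup(self, tokens: List[int]) -> Tuple[int, List[int]]:
        # Longest cached block-aligned prefix of `tokens`, as (hit_length, block_ids).
        for n_blocks in range(len(tokens) // BLOCK_SIZE, 0, -1):
            prefix = tuple(tokens[:n_blocks * BLOCK_SIZE])
            if prefix in self._blocks:
                return n_blocks * BLOCK_SIZE, self._blocks[prefix]
        return 0, []

cache = ToyPrefixCache()
cache.update([7, 7, 3, 9, 5, 5, 1, 2], block_ids=[0, 1])
print(cache.lookup([7, 7, 3, 9, 5, 5, 8]))  # (4, [0]): only the first block is shared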
@@ -274,24 +279,38 @@ def _schedule_prompts(self, requests: List[Request]) -> None:

             max_blocks = free_blocks - self.scheduled_req_blocks

-            if len(r.input_tokens) > 1:
+            input_tokens = r.input_tokens
+            if r.seq_length == 0:
+                cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.input_tokens)
+                input_tokens = input_tokens[cache_hit_length:]
+            else:
+                cache_hit_length = 0
+                block_ids = []
+
+            if len(input_tokens) > 1:
                 # When the KV cache is out of capacity, we release KV cache blocks for a request.
                 # However, we can immediately schedule the request again if we split the request.
                 # So we make sure that we have capacity for the entire prompt (+tokens already generated).
-                req_tokens, _ = self.inference_engine.query(r.uid, len(r.input_tokens), max_blocks)
-                if req_tokens < len(r.input_tokens):
+                req_tokens, _ = self.inference_engine.query(r.uid, len(input_tokens), max_blocks)
+                if req_tokens < len(input_tokens):
                     break

-            req_tokens = min(len(r.input_tokens), max_batch_size)
+            req_tokens = min(len(input_tokens), max_batch_size)
             req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks)

             if req_tokens <= 0:
                 continue

             # Decompose the prompt to fit to the max ragged batch size
-            decomposed = req_tokens < len(r.input_tokens)
-            remaining_tokens = r.input_tokens[req_tokens:]
-            r.input_tokens = r.input_tokens[:req_tokens]
+            if cache_hit_length > 0:
+                self.inference_engine.setup_cached_sequence(r.uid,
+                                                            cache_hit_length,
+                                                            block_ids)
+                r.seq_length = r.seq_length + cache_hit_length
+
+            decomposed = req_tokens < len(input_tokens)
+            remaining_tokens = input_tokens[req_tokens:]
+            r.input_tokens = input_tokens[:req_tokens]
             r.last_in_prompt = not decomposed

             # Schedule the request
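Note (not part of the diff): a worked trace of the new scheduling path with made-up numbers, showing how a prefix hit shortens the prefill work and pre-seeds seq_length. prompt_len, cache_hit_length, and max_batch_size below are illustrative, and the trace assumes query() does not reduce the token count further (enough free blocks are available).

# First scheduling pass for a request, i.e. r.seq_length == 0.
prompt_len = 1000        # len(r.input_tokens)
cache_hit_length = 600   # longest cached prefix reported by lookup_cache()
max_batch_size = 256     # ragged-batch token budget for this step

remaining_prefill = prompt_len - cache_hit_length      # 400 tokens still need prefill
req_tokens = min(remaining_prefill, max_batch_size)    # 256 tokens scheduled this step

decomposed = req_tokens < remaining_prefill            # True: prompt is split across steps
seq_length_after_setup = cache_hit_length              # reused KV blocks cover 600 tokens
print(req_tokens, decomposed, seq_length_after_setup)  # 256 True 600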
@@ -571,11 +590,15 @@ def __call__(self,

         if self.is_rank_0:
             # Rank 0 runs generate() until all responses are returned
-            while uids_running:
+            while uids_running \
+                    or not self.request_queue.empty() \
+                    or self.scheduled_requests.requests_to_flush.uids:
                 self.generate()
                 while not self.result_queues[self.tid].empty():
                     uid, response = self._get_response()
                     outputs.append(response)
+                    # We can't call flush directly because the flush request is broadcast
+                    # to other ranks only after it is taken from the queue
                     self._queue_flush_request(uid)
                     uids_complete_order.append(uid)
                     uids_running.remove(uid)
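Note (not part of the diff): a summary of the engine surface this PR relies on, with signatures inferred from the call sites above. The real definitions live in the companion DeepSpeed engine change and may differ (for example, block_ids might be a tensor rather than a list); treat this as a reader's aid, not the actual interface.

from typing import List, Protocol, Tuple

import torch

class PrefixCachingEngine(Protocol):
    """Inferred engine methods used by ragged_batching.py in this PR."""

    def lookup_cache(self, input_tokens: torch.Tensor) -> Tuple[int, List[int]]:
        """Return (cache_hit_length, block_ids) for the longest cached prefix of input_tokens."""
        ...

    def setup_cached_sequence(self, uid: int, cache_hit_length: int, block_ids: List[int]) -> None:
        """Attach the cached KV blocks to sequence `uid` so prefill can skip cache_hit_length tokens."""
        ...

    def update_prefix_cache(self, uid: int, all_tokens: torch.Tensor) -> None:
        """Record that sequence `uid` now holds KV state for all_tokens (prompt + generated so far)."""
        ...

    def query(self, uid: int, max_request_tokens: int, max_blocks: int) -> Tuple[int, int]:
        """Existing API: the number of tokens and blocks that can be scheduled within the budgets."""
        ...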