Bugfix: Check cache keys as prefix to prompt tokens
abetlen committed Apr 25, 2023
1 parent b75fa96 commit d484c56
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions llama_cpp/llama.py
@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
 
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
+
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict() # NOTE: Currently limit to one cache entry.
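With this hunk, a LlamaCache lookup succeeds whenever a cached key is a prefix of the requested token sequence, preferring the longest such key. The following is a minimal sketch of that behavior only, with plain ints standing in for llama_cpp.llama_token values, a string standing in for LlamaState, and a __setitem__ tail that is assumed because it is not shown in this hunk:

# Sketch of the prefix-matching lookup above (not the library class itself).
from typing import Dict, List, Optional, Sequence, Tuple

class PrefixCache:
    def __init__(self) -> None:
        self.cache_state: Dict[Tuple[int, ...], str] = dict()

    def _sorted_keys(self) -> List[Tuple[int, ...]]:
        # Longest cached key first, so the most specific prefix wins.
        return sorted(self.cache_state.keys(), key=len, reverse=True)

    def _find_key(self, key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
        for k in self._sorted_keys():
            if key[: len(k)] == k:  # cached key is a prefix of the lookup key
                return k
        return None

    def __getitem__(self, key: Sequence[int]) -> Optional[str]:
        _key = self._find_key(tuple(key))
        return None if _key is None else self.cache_state[_key]

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: str) -> None:
        self.cache_state = dict()  # one-entry limit, as noted in the diff
        self.cache_state[tuple(key)] = value  # assumed tail, not shown in this hunk

cache = PrefixCache()
cache[[1, 2, 3]] = "state-after-short-prompt"
assert [1, 2, 3, 4, 5] in cache              # longer prompt extends a cached prefix
assert cache[[1, 2, 3, 4, 5]] == "state-after-short-prompt"
assert [9, 9] not in cache                   # no cached key is a prefix here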
@@ -295,7 +314,7 @@ def generate(
         if (
             reset
             and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
            if self.verbose:
                print("generate cache hit", file=sys.stderr)
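The tuple() conversions fix the comparison itself: self.eval_tokens is a collections.deque, and a deque never compares equal to a list even when the elements match, so with a plain list of prompt tokens the old check was always False and the "generate cache hit" path was unreachable. A quick illustration, with ints standing in for tokens:

from collections import deque

eval_tokens = deque([1, 2, 3])  # like self.eval_tokens
tokens = [1, 2, 3, 4]           # like the incoming prompt tokens

print(eval_tokens == tokens[: len(eval_tokens)])                # False: deque vs list
print(tuple(eval_tokens) == tuple(tokens[: len(eval_tokens)]))  # True: element-wise match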
@@ -438,6 +457,8 @@ def _create_completion(
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)
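Because __contains__ now does prefix matching, this guard only logs a miss and saves state for prompts that no cached key covers. A hedged end-to-end usage sketch, assuming a local GGML model at a hypothetical path and the library's Llama.set_cache() API; behavior beyond what this diff shows (e.g. reloading the saved state) is not asserted here:

# Hedged usage sketch: exercising the prefix cache through the public API.
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)  # hypothetical path
llm.set_cache(LlamaCache())

# First request: no cached key is a prefix of these prompt tokens, so the
# verbose "cache miss" line is printed and the state is saved under them.
llm("Q: What is the capital of France? A:", max_tokens=16)

# A follow-up prompt extending the first is reported as cached (prefix match),
# provided tokenization keeps the earlier prompt's tokens as a prefix,
# so the saved state is not overwritten by this guard.
llm("Q: What is the capital of France? A: Paris is the capital. Q:", max_tokens=16)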
