33
44from redis import Redis
55
6+ from redisvl .extensions .constants import (
7+ CACHE_VECTOR_FIELD_NAME ,
8+ ENTRY_ID_FIELD_NAME ,
9+ INSERTED_AT_FIELD_NAME ,
10+ METADATA_FIELD_NAME ,
11+ PROMPT_FIELD_NAME ,
12+ REDIS_KEY_FIELD_NAME ,
13+ RESPONSE_FIELD_NAME ,
14+ UPDATED_AT_FIELD_NAME ,
15+ )
616from redisvl .extensions .llmcache .base import BaseLLMCache
717from redisvl .extensions .llmcache .schema import (
818 CacheEntry ,
1525from redisvl .utils .utils import current_timestamp , serialize , validate_vector_dims
1626from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
1727
18- VECTOR_FIELD_NAME = "prompt_vector"
19-
2028
2129class SemanticCache (BaseLLMCache ):
2230 """Semantic Cache for Large Language Models."""
2331
24- redis_key_field_name : str = "key"
25- entry_id_field_name : str = "entry_id"
26- prompt_field_name : str = "prompt"
27- response_field_name : str = "response"
28- inserted_at_field_name : str = "inserted_at"
29- updated_at_field_name : str = "updated_at"
30- metadata_field_name : str = "metadata"
31-
3232 _index : SearchIndex
3333 _aindex : Optional [AsyncSearchIndex ] = None
3434
@@ -95,12 +95,12 @@ def __init__(
9595 # Process fields and other settings
9696 self .set_threshold (distance_threshold )
9797 self .return_fields = [
98- self . entry_id_field_name ,
99- self . prompt_field_name ,
100- self . response_field_name ,
101- self . inserted_at_field_name ,
102- self . updated_at_field_name ,
103- self . metadata_field_name ,
98+ ENTRY_ID_FIELD_NAME ,
99+ PROMPT_FIELD_NAME ,
100+ RESPONSE_FIELD_NAME ,
101+ INSERTED_AT_FIELD_NAME ,
102+ UPDATED_AT_FIELD_NAME ,
103+ METADATA_FIELD_NAME ,
104104 ]
105105
106106 # Create semantic cache schema and index
@@ -137,10 +137,10 @@ def __init__(
137137
138138 validate_vector_dims (
139139 vectorizer .dims ,
140- self ._index .schema .fields [VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
140+ self ._index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
141141 )
142142 self ._vectorizer = vectorizer
143- self ._dtype = self .index .schema .fields [VECTOR_FIELD_NAME ].attrs .datatype # type: ignore[union-attr]
143+ self ._dtype = self .index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .datatype # type: ignore[union-attr]
144144
145145 def _modify_schema (
146146 self ,
@@ -150,9 +150,7 @@ def _modify_schema(
150150 """Modify the base cache schema using the provided filterable fields"""
151151
152152 if filterable_fields is not None :
153- protected_field_names = set (
154- self .return_fields + [self .redis_key_field_name ]
155- )
153+ protected_field_names = set (self .return_fields + [REDIS_KEY_FIELD_NAME ])
156154 for filter_field in filterable_fields :
157155 field_name = filter_field ["name" ]
158156 if field_name in protected_field_names :
@@ -305,7 +303,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
def _check_vector_dims(self, vector: List[float]):
    """Validate a query vector against the index schema.

    Looks up the configured dimensionality of the cache's vector field
    and raises if the provided vector's length does not match it.

    Args:
        vector (List[float]): The embedding to validate.

    Raises:
        ValueError: If the vector length differs from the schema's
            configured dimensions (raised by ``validate_vector_dims``).
    """
    # Dimensions come from the vector field attrs on the index schema.
    expected_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims  # type: ignore
    validate_vector_dims(len(vector), expected_dims)
310308
311309 def check (
@@ -368,7 +366,7 @@ def check(
368366
369367 query = RangeQuery (
370368 vector = vector ,
371- vector_field_name = VECTOR_FIELD_NAME ,
369+ vector_field_name = CACHE_VECTOR_FIELD_NAME ,
372370 return_fields = self .return_fields ,
373371 distance_threshold = distance_threshold ,
374372 num_results = num_results ,
@@ -450,7 +448,7 @@ async def acheck(
450448
451449 query = RangeQuery (
452450 vector = vector ,
453- vector_field_name = VECTOR_FIELD_NAME ,
451+ vector_field_name = CACHE_VECTOR_FIELD_NAME ,
454452 return_fields = self .return_fields ,
455453 distance_threshold = distance_threshold ,
456454 num_results = num_results ,
@@ -485,7 +483,7 @@ def _process_cache_results(
485483 cache_hit_dict = {
486484 k : v for k , v in cache_hit_dict .items () if k in return_fields
487485 }
488- cache_hit_dict [self . redis_key_field_name ] = redis_key
486+ cache_hit_dict [REDIS_KEY_FIELD_NAME ] = redis_key
489487 cache_hits .append (cache_hit_dict )
490488 return redis_keys , cache_hits
491489
@@ -547,7 +545,7 @@ def store(
547545 keys = self ._index .load (
548546 data = [cache_entry .to_dict (self ._dtype )],
549547 ttl = ttl ,
550- id_field = self . entry_id_field_name ,
548+ id_field = ENTRY_ID_FIELD_NAME ,
551549 )
552550 return keys [0 ]
553551
@@ -611,7 +609,7 @@ async def astore(
611609 keys = await aindex .load (
612610 data = [cache_entry .to_dict (self ._dtype )],
613611 ttl = ttl ,
614- id_field = self . entry_id_field_name ,
612+ id_field = ENTRY_ID_FIELD_NAME ,
615613 )
616614 return keys [0 ]
617615
@@ -635,21 +633,19 @@ def update(self, key: str, **kwargs) -> None:
635633 for k , v in kwargs .items ():
636634
637635 # Make sure the item is in the index schema
638- if k not in set (
639- self ._index .schema .field_names + [self .metadata_field_name ]
640- ):
636+ if k not in set (self ._index .schema .field_names + [METADATA_FIELD_NAME ]):
641637 raise ValueError (f"{ k } is not a valid field within the cache entry" )
642638
643639 # Check for metadata and deserialize
644- if k == self . metadata_field_name :
640+ if k == METADATA_FIELD_NAME :
645641 if isinstance (v , dict ):
646642 kwargs [k ] = serialize (v )
647643 else :
648644 raise TypeError (
649645 "If specified, cached metadata must be a dictionary."
650646 )
651647
652- kwargs .update ({self . updated_at_field_name : current_timestamp ()})
648+ kwargs .update ({UPDATED_AT_FIELD_NAME : current_timestamp ()})
653649
654650 self ._index .client .hset (key , mapping = kwargs ) # type: ignore
655651
@@ -680,21 +676,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
680676 for k , v in kwargs .items ():
681677
682678 # Make sure the item is in the index schema
683- if k not in set (
684- self ._index .schema .field_names + [self .metadata_field_name ]
685- ):
679+ if k not in set (self ._index .schema .field_names + [METADATA_FIELD_NAME ]):
686680 raise ValueError (f"{ k } is not a valid field within the cache entry" )
687681
688682 # Check for metadata and deserialize
689- if k == self . metadata_field_name :
683+ if k == METADATA_FIELD_NAME :
690684 if isinstance (v , dict ):
691685 kwargs [k ] = serialize (v )
692686 else :
693687 raise TypeError (
694688 "If specified, cached metadata must be a dictionary."
695689 )
696690
697- kwargs .update ({self . updated_at_field_name : current_timestamp ()})
691+ kwargs .update ({UPDATED_AT_FIELD_NAME : current_timestamp ()})
698692
699693 await aindex .load (data = [kwargs ], keys = [key ])
700694
0 commit comments