33
44from redis import Redis
55
6+ from redisvl .extensions .constants import (
7+ CACHE_VECTOR_FIELD_NAME ,
8+ ENTRY_ID_FIELD_NAME ,
9+ INSERTED_AT_FIELD_NAME ,
10+ METADATA_FIELD_NAME ,
11+ PROMPT_FIELD_NAME ,
12+ REDIS_KEY_FIELD_NAME ,
13+ RESPONSE_FIELD_NAME ,
14+ UPDATED_AT_FIELD_NAME ,
15+ )
616from redisvl .extensions .llmcache .base import BaseLLMCache
717from redisvl .extensions .llmcache .schema import (
818 CacheEntry ,
1929class SemanticCache (BaseLLMCache ):
2030 """Semantic Cache for Large Language Models."""
2131
22- redis_key_field_name : str = "key"
23- entry_id_field_name : str = "entry_id"
24- prompt_field_name : str = "prompt"
25- response_field_name : str = "response"
26- vector_field_name : str = "prompt_vector"
27- inserted_at_field_name : str = "inserted_at"
28- updated_at_field_name : str = "updated_at"
29- metadata_field_name : str = "metadata"
30-
3132 _index : SearchIndex
3233 _aindex : Optional [AsyncSearchIndex ] = None
3334
@@ -94,12 +95,12 @@ def __init__(
9495 # Process fields and other settings
9596 self .set_threshold (distance_threshold )
9697 self .return_fields = [
97- self . entry_id_field_name ,
98- self . prompt_field_name ,
99- self . response_field_name ,
100- self . inserted_at_field_name ,
101- self . updated_at_field_name ,
102- self . metadata_field_name ,
98+ ENTRY_ID_FIELD_NAME ,
99+ PROMPT_FIELD_NAME ,
100+ RESPONSE_FIELD_NAME ,
101+ INSERTED_AT_FIELD_NAME ,
102+ UPDATED_AT_FIELD_NAME ,
103+ METADATA_FIELD_NAME ,
103104 ]
104105
105106 # Create semantic cache schema and index
@@ -133,7 +134,7 @@ def __init__(
133134
134135 validate_vector_dims (
135136 vectorizer .dims ,
136- self ._index .schema .fields [self . vector_field_name ].attrs .dims , # type: ignore
137+ self ._index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
137138 )
138139 self ._vectorizer = vectorizer
139140
@@ -145,9 +146,7 @@ def _modify_schema(
145146 """Modify the base cache schema using the provided filterable fields"""
146147
147148 if filterable_fields is not None :
148- protected_field_names = set (
149- self .return_fields + [self .redis_key_field_name ]
150- )
149+ protected_field_names = set (self .return_fields + [REDIS_KEY_FIELD_NAME ])
151150 for filter_field in filterable_fields :
152151 field_name = filter_field ["name" ]
153152 if field_name in protected_field_names :
@@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
300299 def _check_vector_dims (self , vector : List [float ]):
301300 """Checks the size of the provided vector and raises an error if it
302301 doesn't match the search index vector dimensions."""
303- schema_vector_dims = self ._index .schema .fields [self . vector_field_name ].attrs .dims # type: ignore
302+ schema_vector_dims = self ._index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .dims # type: ignore
304303 validate_vector_dims (len (vector ), schema_vector_dims )
305304
306305 def check (
@@ -363,7 +362,7 @@ def check(
363362
364363 query = RangeQuery (
365364 vector = vector ,
366- vector_field_name = self . vector_field_name ,
365+ vector_field_name = CACHE_VECTOR_FIELD_NAME ,
367366 return_fields = self .return_fields ,
368367 distance_threshold = distance_threshold ,
369368 num_results = num_results ,
@@ -444,7 +443,7 @@ async def acheck(
444443
445444 query = RangeQuery (
446445 vector = vector ,
447- vector_field_name = self . vector_field_name ,
446+ vector_field_name = CACHE_VECTOR_FIELD_NAME ,
448447 return_fields = self .return_fields ,
449448 distance_threshold = distance_threshold ,
450449 num_results = num_results ,
@@ -479,7 +478,7 @@ def _process_cache_results(
479478 cache_hit_dict = {
480479 k : v for k , v in cache_hit_dict .items () if k in return_fields
481480 }
482- cache_hit_dict [self . redis_key_field_name ] = redis_key
481+ cache_hit_dict [REDIS_KEY_FIELD_NAME ] = redis_key
483482 cache_hits .append (cache_hit_dict )
484483 return redis_keys , cache_hits
485484
@@ -541,7 +540,7 @@ def store(
541540 keys = self ._index .load (
542541 data = [cache_entry .to_dict ()],
543542 ttl = ttl ,
544- id_field = self . entry_id_field_name ,
543+ id_field = ENTRY_ID_FIELD_NAME ,
545544 )
546545 return keys [0 ]
547546
@@ -605,7 +604,7 @@ async def astore(
605604 keys = await aindex .load (
606605 data = [cache_entry .to_dict ()],
607606 ttl = ttl ,
608- id_field = self . entry_id_field_name ,
607+ id_field = ENTRY_ID_FIELD_NAME ,
609608 )
610609 return keys [0 ]
611610
@@ -629,21 +628,19 @@ def update(self, key: str, **kwargs) -> None:
629628 for k , v in kwargs .items ():
630629
631630 # Make sure the item is in the index schema
632- if k not in set (
633- self ._index .schema .field_names + [self .metadata_field_name ]
634- ):
631+ if k not in set (self ._index .schema .field_names + [METADATA_FIELD_NAME ]):
635632 raise ValueError (f"{ k } is not a valid field within the cache entry" )
636633
637634 # Check for metadata and deserialize
638- if k == self . metadata_field_name :
635+ if k == METADATA_FIELD_NAME :
639636 if isinstance (v , dict ):
640637 kwargs [k ] = serialize (v )
641638 else :
642639 raise TypeError (
643640 "If specified, cached metadata must be a dictionary."
644641 )
645642
646- kwargs .update ({self . updated_at_field_name : current_timestamp ()})
643+ kwargs .update ({UPDATED_AT_FIELD_NAME : current_timestamp ()})
647644
648645 self ._index .client .hset (key , mapping = kwargs ) # type: ignore
649646
@@ -674,21 +671,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
674671 for k , v in kwargs .items ():
675672
676673 # Make sure the item is in the index schema
677- if k not in set (
678- self ._index .schema .field_names + [self .metadata_field_name ]
679- ):
674+ if k not in set (self ._index .schema .field_names + [METADATA_FIELD_NAME ]):
680675 raise ValueError (f"{ k } is not a valid field within the cache entry" )
681676
682677 # Check for metadata and deserialize
683- if k == self . metadata_field_name :
678+ if k == METADATA_FIELD_NAME :
684679 if isinstance (v , dict ):
685680 kwargs [k ] = serialize (v )
686681 else :
687682 raise TypeError (
688683 "If specified, cached metadata must be a dictionary."
689684 )
690685
691- kwargs .update ({self . updated_at_field_name : current_timestamp ()})
686+ kwargs .update ({UPDATED_AT_FIELD_NAME : current_timestamp ()})
692687
693688 await aindex .load (data = [kwargs ], keys = [key ])
694689
0 commit comments