
Commit 2de27ba

Merge branch 'main' into feat/RAAE-206/vector-dtypes
2 parents: f0efe5c + f808b80

File tree: 15 files changed, +250 −163 lines

redisvl/extensions/constants.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+"""
+Constants used within the extension classes SemanticCache, BaseSessionManager,
+StandardSessionManager, SemanticSessionManager and SemanticRouter.
+These constants are also used within these classes' corresponding schemas.
+"""
+
+# BaseSessionManager
+ID_FIELD_NAME: str = "entry_id"
+ROLE_FIELD_NAME: str = "role"
+CONTENT_FIELD_NAME: str = "content"
+TOOL_FIELD_NAME: str = "tool_call_id"
+TIMESTAMP_FIELD_NAME: str = "timestamp"
+SESSION_FIELD_NAME: str = "session_tag"
+
+# SemanticSessionManager
+SESSION_VECTOR_FIELD_NAME: str = "vector_field"
+
+# SemanticCache
+REDIS_KEY_FIELD_NAME: str = "key"
+ENTRY_ID_FIELD_NAME: str = "entry_id"
+PROMPT_FIELD_NAME: str = "prompt"
+RESPONSE_FIELD_NAME: str = "response"
+CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
+INSERTED_AT_FIELD_NAME: str = "inserted_at"
+UPDATED_AT_FIELD_NAME: str = "updated_at"
+METADATA_FIELD_NAME: str = "metadata"
+
+# SemanticRouter
+ROUTE_VECTOR_FIELD_NAME: str = "vector"
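These module-level constants replace field-name strings that were previously duplicated as attributes on each extension class. A minimal usage sketch; the import path matches the new file above, while the surrounding code is illustrative only:

# Illustrative only: the constants are plain module-level strings, so call sites
# can share one definition of each field name instead of per-class attributes.
from redisvl.extensions.constants import (
    CACHE_VECTOR_FIELD_NAME,
    ENTRY_ID_FIELD_NAME,
    PROMPT_FIELD_NAME,
)

return_fields = [ENTRY_ID_FIELD_NAME, PROMPT_FIELD_NAME]  # mirrors how SemanticCache builds return_fields
print(CACHE_VECTOR_FIELD_NAME)  # -> "prompt_vector"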

redisvl/extensions/llmcache/base.py

Lines changed: 0 additions & 13 deletions

@@ -1,7 +1,5 @@
 from typing import Any, Dict, List, Optional

-from redisvl.redis.utils import hashify
-

 class BaseLLMCache:
     def __init__(self, ttl: Optional[int] = None):
@@ -79,14 +77,3 @@ async def astore(
         """Async store the specified key-value pair in the cache along with
         metadata."""
         raise NotImplementedError
-
-    def hash_input(self, prompt: str) -> str:
-        """Hashes the input prompt using SHA256.
-
-        Args:
-            prompt (str): Input string to be hashed.
-
-        Returns:
-            str: Hashed string.
-        """
-        return hashify(prompt)
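With hash_input removed here, entry-id hashing now happens inside CacheEntry.generate_id (see the schema.py diff below), which folds the entry's filters into the hash. A minimal sketch of the resulting behavior, assuming CacheEntry accepts the prompt/response/filters keyword arguments implied by this commit:

# Sketch only: constructor arguments are inferred from the fields touched in this
# commit, not confirmed by it.
from redisvl.extensions.llmcache.schema import CacheEntry

plain = CacheEntry(prompt="what is redis?", response="an in-memory data store")
filtered = CacheEntry(
    prompt="what is redis?",
    response="an in-memory data store",
    filters={"user": "alice"},  # hypothetical filterable field
)

# entry_id is now hashify(prompt, filters) instead of hashify(prompt), so the same
# prompt stored under different filters should yield distinct ids.
assert plain.entry_id != filtered.entry_id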

redisvl/extensions/llmcache/schema.py

Lines changed: 13 additions & 6 deletions

@@ -2,6 +2,13 @@

 from pydantic.v1 import BaseModel, Field, root_validator, validator

+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
 from redisvl.redis.utils import array_to_buffer, hashify
 from redisvl.schema import IndexSchema
 from redisvl.utils.utils import current_timestamp, deserialize, serialize
@@ -32,7 +39,7 @@ class CacheEntry(BaseModel):
     def generate_id(cls, values):
         # Ensure entry_id is set
         if not values.get("entry_id"):
-            values["entry_id"] = hashify(values["prompt"])
+            values["entry_id"] = hashify(values["prompt"], values.get("filters"))
         return values

     @validator("metadata")
@@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str):
         return cls(
             index={"name": name, "prefix": prefix},  # type: ignore
             fields=[  # type: ignore
-                {"name": "prompt", "type": "text"},
-                {"name": "response", "type": "text"},
-                {"name": "inserted_at", "type": "numeric"},
-                {"name": "updated_at", "type": "numeric"},
+                {"name": PROMPT_FIELD_NAME, "type": "text"},
+                {"name": RESPONSE_FIELD_NAME, "type": "text"},
+                {"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
+                {"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
                 {
-                    "name": "prompt_vector",
+                    "name": CACHE_VECTOR_FIELD_NAME,
                     "type": "vector",
                     "attrs": {
                         "dims": vector_dims,

redisvl/extensions/llmcache/semantic.py

Lines changed: 31 additions & 37 deletions

@@ -3,6 +3,16 @@

 from redis import Redis

+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    ENTRY_ID_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    METADATA_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    REDIS_KEY_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.extensions.llmcache.schema import (
     CacheEntry,
@@ -15,20 +25,10 @@
 from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
 from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer

-VECTOR_FIELD_NAME = "prompt_vector"
-

 class SemanticCache(BaseLLMCache):
     """Semantic Cache for Large Language Models."""

-    redis_key_field_name: str = "key"
-    entry_id_field_name: str = "entry_id"
-    prompt_field_name: str = "prompt"
-    response_field_name: str = "response"
-    inserted_at_field_name: str = "inserted_at"
-    updated_at_field_name: str = "updated_at"
-    metadata_field_name: str = "metadata"
-
     _index: SearchIndex
     _aindex: Optional[AsyncSearchIndex] = None

@@ -95,12 +95,12 @@ def __init__(
         # Process fields and other settings
         self.set_threshold(distance_threshold)
         self.return_fields = [
-            self.entry_id_field_name,
-            self.prompt_field_name,
-            self.response_field_name,
-            self.inserted_at_field_name,
-            self.updated_at_field_name,
-            self.metadata_field_name,
+            ENTRY_ID_FIELD_NAME,
+            PROMPT_FIELD_NAME,
+            RESPONSE_FIELD_NAME,
+            INSERTED_AT_FIELD_NAME,
+            UPDATED_AT_FIELD_NAME,
+            METADATA_FIELD_NAME,
         ]

         # Create semantic cache schema and index
@@ -137,10 +137,10 @@ def __init__(

         validate_vector_dims(
             vectorizer.dims,
-            self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims,  # type: ignore
+            self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims,  # type: ignore
         )
         self._vectorizer = vectorizer
-        self._dtype = self.index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype  # type: ignore[union-attr]
+        self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype  # type: ignore[union-attr]

     def _modify_schema(
         self,
@@ -150,9 +150,7 @@ def _modify_schema(
         """Modify the base cache schema using the provided filterable fields"""

         if filterable_fields is not None:
-            protected_field_names = set(
-                self.return_fields + [self.redis_key_field_name]
-            )
+            protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
             for filter_field in filterable_fields:
                 field_name = filter_field["name"]
                 if field_name in protected_field_names:
@@ -305,7 +303,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
     def _check_vector_dims(self, vector: List[float]):
         """Checks the size of the provided vector and raises an error if it
         doesn't match the search index vector dimensions."""
-        schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims  # type: ignore
+        schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims  # type: ignore
         validate_vector_dims(len(vector), schema_vector_dims)

     def check(
@@ -368,7 +366,7 @@ def check(

         query = RangeQuery(
             vector=vector,
-            vector_field_name=VECTOR_FIELD_NAME,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
             distance_threshold=distance_threshold,
             num_results=num_results,
@@ -450,7 +448,7 @@ async def acheck(

         query = RangeQuery(
             vector=vector,
-            vector_field_name=VECTOR_FIELD_NAME,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
             distance_threshold=distance_threshold,
             num_results=num_results,
@@ -485,7 +483,7 @@ def _process_cache_results(
             cache_hit_dict = {
                 k: v for k, v in cache_hit_dict.items() if k in return_fields
             }
-            cache_hit_dict[self.redis_key_field_name] = redis_key
+            cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
             cache_hits.append(cache_hit_dict)
         return redis_keys, cache_hits

@@ -547,7 +545,7 @@ def store(
         keys = self._index.load(
             data=[cache_entry.to_dict(self._dtype)],
             ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
         )
         return keys[0]

@@ -611,7 +609,7 @@ async def astore(
         keys = await aindex.load(
             data=[cache_entry.to_dict(self._dtype)],
             ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
         )
         return keys[0]

@@ -635,21 +633,19 @@ def update(self, key: str, **kwargs) -> None:
         for k, v in kwargs.items():

             # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                 raise ValueError(f"{k} is not a valid field within the cache entry")

             # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                 if isinstance(v, dict):
                     kwargs[k] = serialize(v)
                 else:
                     raise TypeError(
                         "If specified, cached metadata must be a dictionary."
                     )

-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

         self._index.client.hset(key, mapping=kwargs)  # type: ignore

@@ -680,21 +676,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
         for k, v in kwargs.items():

             # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                 raise ValueError(f"{k} is not a valid field within the cache entry")

             # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                 if isinstance(v, dict):
                     kwargs[k] = serialize(v)
                 else:
                     raise TypeError(
                         "If specified, cached metadata must be a dictionary."
                     )

-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

         await aindex.load(data=[kwargs], keys=[key])
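The check()/acheck() paths above now reference CACHE_VECTOR_FIELD_NAME rather than the removed module-local VECTOR_FIELD_NAME. A hedged sketch of the range query they build; the embedding, threshold, and result count are placeholders, not library defaults, and the RangeQuery import path is assumed:

# Sketch of the lookup shape used by SemanticCache.check(); values are placeholders.
from redisvl.extensions.constants import (
    CACHE_VECTOR_FIELD_NAME,
    PROMPT_FIELD_NAME,
    RESPONSE_FIELD_NAME,
)
from redisvl.query import RangeQuery  # assumed import path for RangeQuery

query = RangeQuery(
    vector=[0.1, 0.2, 0.3],                     # placeholder embedding
    vector_field_name=CACHE_VECTOR_FIELD_NAME,  # resolves to "prompt_vector"
    return_fields=[PROMPT_FIELD_NAME, RESPONSE_FIELD_NAME],
    distance_threshold=0.2,                     # placeholder threshold
    num_results=1,
)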

redisvl/extensions/router/schema.py

Lines changed: 2 additions & 1 deletion

@@ -3,6 +3,7 @@

 from pydantic.v1 import BaseModel, Field, validator

+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
 from redisvl.schema import IndexSchema


@@ -104,7 +105,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str):
             {"name": "route_name", "type": "tag"},
             {"name": "reference", "type": "text"},
             {
-                "name": "vector",
+                "name": ROUTE_VECTOR_FIELD_NAME,
                 "type": "vector",
                 "attrs": {
                     "algorithm": "flat",

redisvl/extensions/router/semantic.py

Lines changed: 6 additions & 7 deletions

@@ -8,6 +8,7 @@
 from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
 from redis.exceptions import ResponseError

+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
 from redisvl.extensions.router.schema import (
     DistanceAggregationMethod,
     Route,
@@ -28,8 +29,6 @@

 logger = get_logger(__name__)

-VECTOR_FIELD_NAME = "vector"
-

 class SemanticRouter(BaseModel):
     """Semantic Router for managing and querying route vectors."""
@@ -172,7 +171,7 @@ def _add_routes(self, routes: List[Route]):
             reference_vectors = self.vectorizer.embed_many(
                 [reference for reference in route.references],
                 as_buffer=True,
-                dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
+                dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
             )
             # set route references
             for i, reference in enumerate(route.references):
@@ -246,10 +245,10 @@ def _classify_route(
         """Classify to a single route using a vector."""
         vector_range_query = RangeQuery(
             vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
             distance_threshold=distance_threshold,
             return_fields=["route_name"],
-            dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
+            dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
         )

         aggregate_request = self._build_aggregate_request(
@@ -299,10 +298,10 @@ def _classify_multi_route(
         """Classify to multiple routes, up to max_k (int), using a vector."""
         vector_range_query = RangeQuery(
             vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
             distance_threshold=distance_threshold,
             return_fields=["route_name"],
-            dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
+            dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype,  # type: ignore[union-attr]
         )
         aggregate_request = self._build_aggregate_request(
             vector_range_query, aggregation_method, max_k
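The router applies the same substitution: ROUTE_VECTOR_FIELD_NAME replaces both the literal "vector" and the removed module-level constant. A hedged sketch of the classification query shape; the embedding, threshold, and dtype are placeholders (in the diff the dtype is read from the index schema), and the RangeQuery import path is assumed:

# Sketch of the route-classification query built in _classify_route(); values are
# placeholders, and the dtype is hard-coded here only for illustration.
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.query import RangeQuery  # assumed import path for RangeQuery

vector_range_query = RangeQuery(
    vector=[0.05, 0.12, 0.31],                  # placeholder embedding
    vector_field_name=ROUTE_VECTOR_FIELD_NAME,  # resolves to "vector"
    distance_threshold=0.5,                     # placeholder threshold
    return_fields=["route_name"],
    dtype="float32",                            # the diff reads this from the index schema
)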

redisvl/extensions/session_manager/base_session.py

Lines changed: 8 additions & 9 deletions

@@ -1,16 +1,15 @@
 from typing import Any, Dict, List, Optional, Union

+from redisvl.extensions.constants import (
+    CONTENT_FIELD_NAME,
+    ROLE_FIELD_NAME,
+    TOOL_FIELD_NAME,
+)
 from redisvl.extensions.session_manager.schema import ChatMessage
 from redisvl.utils.utils import create_uuid


 class BaseSessionManager:
-    id_field_name: str = "entry_id"
-    role_field_name: str = "role"
-    content_field_name: str = "content"
-    tool_field_name: str = "tool_call_id"
-    timestamp_field_name: str = "timestamp"
-    session_field_name: str = "session_tag"

     def __init__(
         self,
@@ -107,11 +106,11 @@ def _format_context(
                 context.append(chat_message.content)
             else:
                 chat_message_dict = {
-                    self.role_field_name: chat_message.role,
-                    self.content_field_name: chat_message.content,
+                    ROLE_FIELD_NAME: chat_message.role,
+                    CONTENT_FIELD_NAME: chat_message.content,
                 }
                 if chat_message.tool_call_id is not None:
-                    chat_message_dict[self.tool_field_name] = chat_message.tool_call_id
+                    chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id

                 context.append(chat_message_dict)  # type: ignore
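On the session-manager side, _format_context now keys its message dicts by the shared constants. A small sketch of the resulting dict; the chat_message stand-in below only mimics the attributes referenced in the diff and is not the library's ChatMessage model:

# Sketch of the per-message dict assembled by _format_context(); chat_message is a
# stand-in object with the attributes the diff touches.
from types import SimpleNamespace

from redisvl.extensions.constants import (
    CONTENT_FIELD_NAME,
    ROLE_FIELD_NAME,
    TOOL_FIELD_NAME,
)

chat_message = SimpleNamespace(role="user", content="hello", tool_call_id=None)

chat_message_dict = {
    ROLE_FIELD_NAME: chat_message.role,        # {"role": "user", ...}
    CONTENT_FIELD_NAME: chat_message.content,
}
if chat_message.tool_call_id is not None:
    chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id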

0 commit comments
