Skip to content

Commit ab0e093

Browse files
Merge branch 'main' into feat/RAAE-127/voyage-ai-integration
2 parents 8b1b520 + f808b80 commit ab0e093

File tree

17 files changed

+4279
-149
lines changed

17 files changed

+4279
-149
lines changed

doctests/data/query_vector.json

Lines changed: 3952 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# RediSearch index schema used by the query_vector doctest
# (loaded via IndexSchema.from_yaml in doctests/query_vector.py).
version: '0.1.0'

index:
  name: idx:bicycle
  prefix: bicycle        # keys matching "bicycle:*" are indexed
  storage_type: json

fields:
  - name: description
    type: text
  - name: description_embeddings
    type: vector
    attrs:
      algorithm: flat              # brute-force (exact) nearest-neighbor search
      dims: 384                    # embedding size of all-MiniLM-L6-v2
      distance_metric: cosine
      datatype: float32

doctests/query_vector.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# EXAMPLE: query_vector
# HIDE_START
import json
import warnings
import redis
import numpy as np
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery, VectorQuery
from redisvl.schema import IndexSchema
from sentence_transformers import SentenceTransformer


def embed_text(model, text):
    # Encode `text` and serialize to a float32 byte string -- the wire
    # format RediSearch expects for vector query parameters.
    return np.array(model.encode(text)).astype(np.float32).tobytes()

r = redis.Redis(decode_responses=True)

# Silence the transformers FutureWarning about clean_up_tokenization_spaces
# so the doctest output stays clean.
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces.*")
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# create index
schema = IndexSchema.from_yaml('data/query_vector_idx.yaml')
index = SearchIndex(schema, r)
index.create(overwrite=True, drop=True)

# load data
with open("data/query_vector.json") as f:
    bicycles = json.load(f)
index.load(bicycles)
# HIDE_END

# STEP_START vector1
query = "Bike for small kids"
query_vector = embed_text(model, query)
print(query_vector[:10]) # >>> b'\x02=c=\x93\x0e\xe0=aC'

# KNN query: return the 3 nearest descriptions to the query embedding.
vquery = VectorQuery(
    vector=query_vector,
    vector_field_name="description_embeddings",
    num_results=3,
    return_score=True,
    return_fields=["description"]
)
res = index.query(vquery)
print(res) # >>> [{'id': 'bicycle:6b702e8b...', 'vector_distance': '0.399111807346', 'description': 'Kids want...
# REMOVE_START
assert len(res) == 3
# REMOVE_END
# STEP_END

# STEP_START vector2
# Range query: return every document within cosine distance 0.5.
vquery = RangeQuery(
    vector=query_vector,
    vector_field_name="description_embeddings",
    distance_threshold=0.5,
    return_score=True
).return_fields("description").dialect(2)
res = index.query(vquery)
print(res) # >>> [{'id': 'bicycle:6bcb1bb4...', 'vector_distance': '0.399111807346', 'description': 'Kids want...
# REMOVE_START
assert len(res) == 2
# REMOVE_END
# STEP_END

# REMOVE_START
# destroy index and data
index.delete(drop=True)
# REMOVE_END

redisvl/extensions/constants.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
"""
Constants used within the extension classes SemanticCache, BaseSessionManager,
StandardSessionManager, SemanticSessionManager and SemanticRouter.

These constants are also used within these classes' corresponding schemas.
"""

# BaseSessionManager -- field names shared by all session-manager entries.
ID_FIELD_NAME: str = "entry_id"
ROLE_FIELD_NAME: str = "role"
CONTENT_FIELD_NAME: str = "content"
TOOL_FIELD_NAME: str = "tool_call_id"
TIMESTAMP_FIELD_NAME: str = "timestamp"
SESSION_FIELD_NAME: str = "session_tag"

# SemanticSessionManager -- vector field used for semantic session search.
SESSION_VECTOR_FIELD_NAME: str = "vector_field"

# SemanticCache -- field names of a cache entry hash/JSON document.
REDIS_KEY_FIELD_NAME: str = "key"
ENTRY_ID_FIELD_NAME: str = "entry_id"
PROMPT_FIELD_NAME: str = "prompt"
RESPONSE_FIELD_NAME: str = "response"
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
INSERTED_AT_FIELD_NAME: str = "inserted_at"
UPDATED_AT_FIELD_NAME: str = "updated_at"
METADATA_FIELD_NAME: str = "metadata"

# SemanticRouter -- vector field used for route-reference matching.
ROUTE_VECTOR_FIELD_NAME: str = "vector"

redisvl/extensions/llmcache/base.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Any, Dict, List, Optional
22

3-
from redisvl.redis.utils import hashify
4-
53

64
class BaseLLMCache:
75
def __init__(self, ttl: Optional[int] = None):
@@ -79,14 +77,3 @@ async def astore(
7977
"""Async store the specified key-value pair in the cache along with
8078
metadata."""
8179
raise NotImplementedError
82-
83-
def hash_input(self, prompt: str) -> str:
84-
"""Hashes the input prompt using SHA256.
85-
86-
Args:
87-
prompt (str): Input string to be hashed.
88-
89-
Returns:
90-
str: Hashed string.
91-
"""
92-
return hashify(prompt)

redisvl/extensions/llmcache/schema.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
from pydantic.v1 import BaseModel, Field, root_validator, validator
44

5+
from redisvl.extensions.constants import (
6+
CACHE_VECTOR_FIELD_NAME,
7+
INSERTED_AT_FIELD_NAME,
8+
PROMPT_FIELD_NAME,
9+
RESPONSE_FIELD_NAME,
10+
UPDATED_AT_FIELD_NAME,
11+
)
512
from redisvl.redis.utils import array_to_buffer, hashify
613
from redisvl.schema import IndexSchema
714
from redisvl.utils.utils import current_timestamp, deserialize, serialize
@@ -32,7 +39,7 @@ class CacheEntry(BaseModel):
3239
def generate_id(cls, values):
3340
# Ensure entry_id is set
3441
if not values.get("entry_id"):
35-
values["entry_id"] = hashify(values["prompt"])
42+
values["entry_id"] = hashify(values["prompt"], values.get("filters"))
3643
return values
3744

3845
@validator("metadata")
@@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
110117
return cls(
111118
index={"name": name, "prefix": prefix}, # type: ignore
112119
fields=[ # type: ignore
113-
{"name": "prompt", "type": "text"},
114-
{"name": "response", "type": "text"},
115-
{"name": "inserted_at", "type": "numeric"},
116-
{"name": "updated_at", "type": "numeric"},
120+
{"name": PROMPT_FIELD_NAME, "type": "text"},
121+
{"name": RESPONSE_FIELD_NAME, "type": "text"},
122+
{"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
123+
{"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
117124
{
118-
"name": "prompt_vector",
125+
"name": CACHE_VECTOR_FIELD_NAME,
119126
"type": "vector",
120127
"attrs": {
121128
"dims": vector_dims,

redisvl/extensions/llmcache/semantic.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33

44
from redis import Redis
55

6+
from redisvl.extensions.constants import (
7+
CACHE_VECTOR_FIELD_NAME,
8+
ENTRY_ID_FIELD_NAME,
9+
INSERTED_AT_FIELD_NAME,
10+
METADATA_FIELD_NAME,
11+
PROMPT_FIELD_NAME,
12+
REDIS_KEY_FIELD_NAME,
13+
RESPONSE_FIELD_NAME,
14+
UPDATED_AT_FIELD_NAME,
15+
)
616
from redisvl.extensions.llmcache.base import BaseLLMCache
717
from redisvl.extensions.llmcache.schema import (
818
CacheEntry,
@@ -19,15 +29,6 @@
1929
class SemanticCache(BaseLLMCache):
2030
"""Semantic Cache for Large Language Models."""
2131

22-
redis_key_field_name: str = "key"
23-
entry_id_field_name: str = "entry_id"
24-
prompt_field_name: str = "prompt"
25-
response_field_name: str = "response"
26-
vector_field_name: str = "prompt_vector"
27-
inserted_at_field_name: str = "inserted_at"
28-
updated_at_field_name: str = "updated_at"
29-
metadata_field_name: str = "metadata"
30-
3132
_index: SearchIndex
3233
_aindex: Optional[AsyncSearchIndex] = None
3334

@@ -94,12 +95,12 @@ def __init__(
9495
# Process fields and other settings
9596
self.set_threshold(distance_threshold)
9697
self.return_fields = [
97-
self.entry_id_field_name,
98-
self.prompt_field_name,
99-
self.response_field_name,
100-
self.inserted_at_field_name,
101-
self.updated_at_field_name,
102-
self.metadata_field_name,
98+
ENTRY_ID_FIELD_NAME,
99+
PROMPT_FIELD_NAME,
100+
RESPONSE_FIELD_NAME,
101+
INSERTED_AT_FIELD_NAME,
102+
UPDATED_AT_FIELD_NAME,
103+
METADATA_FIELD_NAME,
103104
]
104105

105106
# Create semantic cache schema and index
@@ -133,7 +134,7 @@ def __init__(
133134

134135
validate_vector_dims(
135136
vectorizer.dims,
136-
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
137+
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
137138
)
138139
self._vectorizer = vectorizer
139140

@@ -145,9 +146,7 @@ def _modify_schema(
145146
"""Modify the base cache schema using the provided filterable fields"""
146147

147148
if filterable_fields is not None:
148-
protected_field_names = set(
149-
self.return_fields + [self.redis_key_field_name]
150-
)
149+
protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
151150
for filter_field in filterable_fields:
152151
field_name = filter_field["name"]
153152
if field_name in protected_field_names:
@@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
300299
def _check_vector_dims(self, vector: List[float]):
301300
"""Checks the size of the provided vector and raises an error if it
302301
doesn't match the search index vector dimensions."""
303-
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
302+
schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore
304303
validate_vector_dims(len(vector), schema_vector_dims)
305304

306305
def check(
@@ -363,7 +362,7 @@ def check(
363362

364363
query = RangeQuery(
365364
vector=vector,
366-
vector_field_name=self.vector_field_name,
365+
vector_field_name=CACHE_VECTOR_FIELD_NAME,
367366
return_fields=self.return_fields,
368367
distance_threshold=distance_threshold,
369368
num_results=num_results,
@@ -444,7 +443,7 @@ async def acheck(
444443

445444
query = RangeQuery(
446445
vector=vector,
447-
vector_field_name=self.vector_field_name,
446+
vector_field_name=CACHE_VECTOR_FIELD_NAME,
448447
return_fields=self.return_fields,
449448
distance_threshold=distance_threshold,
450449
num_results=num_results,
@@ -479,7 +478,7 @@ def _process_cache_results(
479478
cache_hit_dict = {
480479
k: v for k, v in cache_hit_dict.items() if k in return_fields
481480
}
482-
cache_hit_dict[self.redis_key_field_name] = redis_key
481+
cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
483482
cache_hits.append(cache_hit_dict)
484483
return redis_keys, cache_hits
485484

@@ -541,7 +540,7 @@ def store(
541540
keys = self._index.load(
542541
data=[cache_entry.to_dict()],
543542
ttl=ttl,
544-
id_field=self.entry_id_field_name,
543+
id_field=ENTRY_ID_FIELD_NAME,
545544
)
546545
return keys[0]
547546

@@ -605,7 +604,7 @@ async def astore(
605604
keys = await aindex.load(
606605
data=[cache_entry.to_dict()],
607606
ttl=ttl,
608-
id_field=self.entry_id_field_name,
607+
id_field=ENTRY_ID_FIELD_NAME,
609608
)
610609
return keys[0]
611610

@@ -629,21 +628,19 @@ def update(self, key: str, **kwargs) -> None:
629628
for k, v in kwargs.items():
630629

631630
# Make sure the item is in the index schema
632-
if k not in set(
633-
self._index.schema.field_names + [self.metadata_field_name]
634-
):
631+
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
635632
raise ValueError(f"{k} is not a valid field within the cache entry")
636633

637634
# Check for metadata and deserialize
638-
if k == self.metadata_field_name:
635+
if k == METADATA_FIELD_NAME:
639636
if isinstance(v, dict):
640637
kwargs[k] = serialize(v)
641638
else:
642639
raise TypeError(
643640
"If specified, cached metadata must be a dictionary."
644641
)
645642

646-
kwargs.update({self.updated_at_field_name: current_timestamp()})
643+
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
647644

648645
self._index.client.hset(key, mapping=kwargs) # type: ignore
649646

@@ -674,21 +671,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
674671
for k, v in kwargs.items():
675672

676673
# Make sure the item is in the index schema
677-
if k not in set(
678-
self._index.schema.field_names + [self.metadata_field_name]
679-
):
674+
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
680675
raise ValueError(f"{k} is not a valid field within the cache entry")
681676

682677
# Check for metadata and deserialize
683-
if k == self.metadata_field_name:
678+
if k == METADATA_FIELD_NAME:
684679
if isinstance(v, dict):
685680
kwargs[k] = serialize(v)
686681
else:
687682
raise TypeError(
688683
"If specified, cached metadata must be a dictionary."
689684
)
690685

691-
kwargs.update({self.updated_at_field_name: current_timestamp()})
686+
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
692687

693688
await aindex.load(data=[kwargs], keys=[key])
694689

redisvl/extensions/router/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic.v1 import BaseModel, Field, validator
55

6+
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
67
from redisvl.schema import IndexInfo, IndexSchema
78

89

@@ -104,7 +105,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
104105
{"name": "route_name", "type": "tag"},
105106
{"name": "reference", "type": "text"},
106107
{
107-
"name": "vector",
108+
"name": ROUTE_VECTOR_FIELD_NAME,
108109
"type": "vector",
109110
"attrs": {
110111
"algorithm": "flat",

redisvl/extensions/router/semantic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
1010

11+
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
1112
from redisvl.extensions.router.schema import (
1213
DistanceAggregationMethod,
1314
Route,
@@ -226,7 +227,7 @@ def _classify_route(
226227
"""Classify to a single route using a vector."""
227228
vector_range_query = RangeQuery(
228229
vector=vector,
229-
vector_field_name="vector",
230+
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
230231
distance_threshold=distance_threshold,
231232
return_fields=["route_name"],
232233
)
@@ -278,7 +279,7 @@ def _classify_multi_route(
278279
"""Classify to multiple routes, up to max_k (int), using a vector."""
279280
vector_range_query = RangeQuery(
280281
vector=vector,
281-
vector_field_name="vector",
282+
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
282283
distance_threshold=distance_threshold,
283284
return_fields=["route_name"],
284285
)

0 commit comments

Comments
 (0)