Skip to content

Commit 5803ceb

Browse files
committed
Expand dtype changes to SemanticSessionManager
1 parent 8f2e1f5 commit 5803ceb

File tree

4 files changed

+129
-11
lines changed

4 files changed

+129
-11
lines changed

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,13 @@
1919
from redisvl.index import SearchIndex
2020
from redisvl.query import FilterQuery, RangeQuery
2121
from redisvl.query.filter import Tag
22-
from redisvl.utils.utils import validate_vector_dims
22+
from redisvl.utils.utils import deprecated_argument, validate_vector_dims
2323
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
2424

2525

2626
class SemanticSessionManager(BaseSessionManager):
2727

28+
@deprecated_argument("dtype", "vectorizer")
2829
def __init__(
2930
self,
3031
name: str,
@@ -70,16 +71,30 @@ def __init__(
7071
super().__init__(name, session_tag)
7172

7273
prefix = prefix or name
74+
dtype = kwargs.get("dtype")
7375

74-
self._vectorizer = vectorizer or HFTextVectorizer(
75-
model="sentence-transformers/msmarco-distilbert-cos-v5"
76-
)
76+
# Validate a provided vectorizer or set the default
77+
if vectorizer:
78+
if not isinstance(vectorizer, BaseVectorizer):
79+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
80+
if dtype and vectorizer.dtype != dtype:
81+
raise ValueError(
82+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
83+
)
84+
else:
85+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
86+
87+
vectorizer = HFTextVectorizer(
88+
model="sentence-transformers/msmarco-distilbert-cos-v5",
89+
**vectorizer_kwargs,
90+
)
91+
92+
self._vectorizer = vectorizer
7793

7894
self.set_distance_threshold(distance_threshold)
7995

80-
dtype = kwargs.get("dtype", "float32")
8196
schema = SemanticSessionIndexSchema.from_params(
82-
name, prefix, self._vectorizer.dims, dtype
97+
name, prefix, self._vectorizer.dims, vectorizer.dtype
8398
)
8499

85100
self._index = SearchIndex(schema=schema)
@@ -215,7 +230,7 @@ def get_relevant(
215230
num_results=top_k,
216231
return_score=True,
217232
filter_expression=session_filter,
218-
dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
233+
dtype=self._vectorizer.dtype,
219234
)
220235
messages = self._index.query(query)
221236

@@ -341,7 +356,7 @@ def add_messages(
341356
if TOOL_FIELD_NAME in message:
342357
chat_message.tool_call_id = message[TOOL_FIELD_NAME]
343358

344-
chat_messages.append(chat_message.to_dict(dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype)) # type: ignore[union-attr]
359+
chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype))
345360

346361
self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)
347362

tests/integration/test_llmcache.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from collections import namedtuple
44
from time import sleep, time
5+
import warnings
56

67
import pytest
78
from pydantic.v1 import ValidationError
@@ -71,6 +72,13 @@ def cache_with_redis_client(vectorizer, client):
7172
cache_instance._index.delete(True) # Clean up index
7273

7374

75+
@pytest.fixture(autouse=True)
76+
def disable_deprecation_warnings():
77+
with warnings.catch_warnings():
78+
warnings.simplefilter("ignore")
79+
yield
80+
81+
7482
def test_bad_ttl(cache):
7583
with pytest.raises(ValueError):
7684
cache.set_ttl(2.5)
@@ -892,6 +900,7 @@ def test_vectorizer_dtype_mismatch():
892900
name="test_dtype_mismatch",
893901
dtype="float32",
894902
vectorizer=HFTextVectorizer(dtype="float16"),
903+
overwrite=True,
895904
)
896905

897906

@@ -900,15 +909,18 @@ def test_invalid_vectorizer():
900909
SemanticCache(
901910
name="test_invalid_vectorizer",
902911
vectorizer="invalid_vectorizer", # type: ignore
912+
overwrite=True,
903913
)
904914

905915

906916
def test_passes_through_dtype_to_default_vectorizer():
907917
# The default is float32, so we should see float64 if we pass it in.
908-
cache = SemanticCache(name="test_pass_through_dtype)", dtype="float64")
918+
cache = SemanticCache(
919+
name="test_pass_through_dtype", dtype="float64", overwrite=True
920+
)
909921
assert cache._vectorizer.dtype == "float64"
910922

911923

912924
def test_deprecated_dtype_argument():
913925
with pytest.warns(DeprecationWarning):
914-
SemanticCache(name="test_deprecated_dtype", dtype="float32")
926+
SemanticCache(name="test_deprecated_dtype", dtype="float32", overwrite=True)

tests/integration/test_semantic_router.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import os
22
import pathlib
3+
import warnings
34

45
import pytest
56
from redis.exceptions import ConnectionError
67

78
from redisvl.exceptions import RedisModuleVersionError
9+
from redisvl.extensions.llmcache.semantic import SemanticCache
810
from redisvl.extensions.router import SemanticRouter
911
from redisvl.extensions.router.schema import Route, RoutingConfig
1012
from redisvl.redis.connection import compare_versions
13+
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
1114

1215

1316
def get_base_path():
@@ -45,6 +48,13 @@ def semantic_router(client, routes):
4548
router.delete()
4649

4750

51+
@pytest.fixture(autouse=True)
52+
def disable_deprecation_warnings():
53+
with warnings.catch_warnings():
54+
warnings.simplefilter("ignore")
55+
yield
56+
57+
4858
def test_initialize_router(semantic_router):
4959
assert semantic_router.name == "test-router"
5060
assert len(semantic_router.routes) == 2
@@ -303,3 +313,44 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
303313
dtype="float16",
304314
redis_url=redis_url,
305315
)
316+
317+
318+
def test_vectorizer_dtype_mismatch(routes):
319+
with pytest.raises(ValueError):
320+
SemanticRouter(
321+
name="test_dtype_mismatch",
322+
routes=routes,
323+
dtype="float32",
324+
vectorizer=HFTextVectorizer(dtype="float16"),
325+
overwrite=True,
326+
)
327+
328+
329+
def test_invalid_vectorizer(routes):
330+
with pytest.raises(TypeError):
331+
SemanticRouter(
332+
name="test_invalid_vectorizer",
333+
vectorizer="invalid_vectorizer", # type: ignore
334+
overwrite=True,
335+
)
336+
337+
338+
def test_passes_through_dtype_to_default_vectorizer(routes):
339+
# The default is float32, so we should see float64 if we pass it in.
340+
router = SemanticRouter(
341+
name="test_pass_through_dtype",
342+
routes=routes,
343+
dtype="float64",
344+
overwrite=True,
345+
)
346+
assert router.vectorizer.dtype == "float64"
347+
348+
349+
def test_deprecated_dtype_argument(routes):
350+
with pytest.warns(DeprecationWarning):
351+
SemanticRouter(
352+
name="test_deprecated_dtype",
353+
routes=routes,
354+
dtype="float32",
355+
overwrite=True,
356+
)

tests/integration/test_session_manager.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
import warnings
22

33
import pytest
44
from redis.exceptions import ConnectionError
@@ -9,6 +9,7 @@
99
SemanticSessionManager,
1010
StandardSessionManager,
1111
)
12+
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
1213

1314

1415
@pytest.fixture
@@ -26,6 +27,13 @@ def semantic_session(app_name, client):
2627
session.delete()
2728

2829

30+
@pytest.fixture(autouse=True)
31+
def disable_deprecation_warnings():
32+
with warnings.catch_warnings():
33+
warnings.simplefilter("ignore")
34+
yield
35+
36+
2937
# test standard session manager
3038
def test_specify_redis_client(client):
3139
session = StandardSessionManager(name="test_app", redis_client=client)
@@ -579,3 +587,35 @@ def test_bad_dtype_connecting_to_exiting_session(redis_url):
579587
bad_type = SemanticSessionManager(
580588
name="float64 session", dtype="float16", redis_url=redis_url
581589
)
590+
591+
592+
def test_vectorizer_dtype_mismatch():
593+
with pytest.raises(ValueError):
594+
SemanticSessionManager(
595+
name="test_dtype_mismatch",
596+
dtype="float32",
597+
vectorizer=HFTextVectorizer(dtype="float16"),
598+
overwrite=True,
599+
)
600+
601+
602+
def test_invalid_vectorizer():
603+
with pytest.raises(TypeError):
604+
SemanticSessionManager(
605+
name="test_invalid_vectorizer",
606+
vectorizer="invalid_vectorizer", # type: ignore
607+
overwrite=True,
608+
)
609+
610+
611+
def test_passes_through_dtype_to_default_vectorizer():
612+
# The default is float32, so we should see float64 if we pass it in.
613+
cache = SemanticSessionManager(
614+
name="test_pass_through_dtype", dtype="float64", overwrite=True
615+
)
616+
assert cache._vectorizer.dtype == "float64"
617+
618+
619+
def test_deprecated_dtype_argument():
620+
with pytest.warns(DeprecationWarning):
621+
SemanticSessionManager(name="float64 session", dtype="float64", overwrite=True)

0 commit comments

Comments
(0)