community[minor]: Add support for metadata indexing policy in Cassandra vector store (langchain-ai#22548)

This PR adds a `metadata_indexing` constructor parameter to the
Cassandra vector store, allowing optional fine-tuning of which metadata
fields are indexed.

This feature is supported by the underlying CassIO library. The
available indexing modes are "all", "none", or deny- and allow-list
based selections.
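
For illustration, here is a minimal sketch (not part of the PR) of how the parameter could be passed, assuming a local Cassandra node, a pre-created keyspace, and placeholder table and field names:

```python
from cassandra.cluster import Cluster
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Cassandra

# Hypothetical local setup: a Cassandra node on localhost and an existing keyspace.
session = Cluster(["127.0.0.1"]).connect()

vstore = Cassandra(
    embedding=FakeEmbeddings(size=16),   # any Embeddings implementation works here
    session=session,
    keyspace="my_keyspace",              # placeholder keyspace name
    table_name="my_vector_table",        # placeholder table name
    # Index only 'field1' and 'field2'; other metadata fields are stored but not indexed.
    metadata_indexing=("allowlist", ["field1", "field2"]),
)
```

With such a policy, similarity searches can filter on `field1`/`field2`, while filtering on a non-indexed field raises a `ValueError` (see the integration test added below).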

The rationale is that in some cases it is advisable to programmatically
exclude portions of the metadata from the index when one knows in
advance they will never be used at search time. This keeps the index
more lightweight and performant and avoids limitations on the length of
_indexed_ strings.

I added an integration test for the feature. I also added the possibility
of running the integration tests against Cassandra on an arbitrary IP
address (e.g. Dockerized), via
`CASSANDRA_CONTACT_POINTS=10.1.1.5,10.1.1.6 poetry run pytest [...]` or
similar.

While I was at it, I added a line to the `.gitignore` since the mypy
_test_ cache was not ignored yet.

My X (Twitter) handle: @rsprrs.
hemidactylus authored Jun 5, 2024
1 parent c3d4126 commit 328d0c9
Showing 3 changed files with 75 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -133,6 +133,7 @@ env.bak/

# mypy
.mypy_cache/
.mypy_cache_test/
.dmypy.json
dmypy.json

24 changes: 22 additions & 2 deletions libs/community/langchain_community/vectorstores/cassandra.py
@@ -59,6 +59,7 @@ def __init__(
*,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
setup_mode: SetupMode = SetupMode.SYNC,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
) -> None:
"""Apache Cassandra(R) for vector-store workloads.
@@ -83,13 +84,24 @@ def __init__(
embedding: Embedding function to use.
session: Cassandra driver session. If not provided, it is resolved from
cassio.
keyspace: Cassandra key space. If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace. If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
setup_mode: mode used to create the Cassandra table (SYNC,
ASYNC or OFF).
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
"""
try:
from cassio.table import MetadataVectorCassandraTable
@@ -125,7 +137,7 @@ def __init__(
keyspace=keyspace,
table=table_name,
vector_dimension=embedding_dimension,
metadata_indexing="all",
metadata_indexing=metadata_indexing,
primary_key_type="TEXT",
skip_provisioning=setup_mode == SetupMode.OFF,
**kwargs,
@@ -885,6 +897,7 @@ def from_texts(
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@@ -915,6 +928,7 @@ def from_texts(
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
)
store.add_texts(
texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
@@ -935,6 +949,7 @@ async def afrom_texts(
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@@ -966,6 +981,7 @@ async def afrom_texts(
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
)
await store.aadd_texts(
texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
@@ -985,6 +1001,7 @@ def from_documents(
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
@@ -1020,6 +1037,7 @@ def from_documents(
batch_size=batch_size,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)

@@ -1036,6 +1054,7 @@ async def afrom_documents(
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
@@ -1071,6 +1090,7 @@ async def afrom_documents(
concurrency=concurrency,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)

57 changes: 52 additions & 5 deletions libs/community/tests/integration_tests/vectorstores/test_cassandra.py
@@ -1,8 +1,10 @@
"""Test Cassandra functionality."""
import asyncio
import os
import time
from typing import List, Optional, Type
from typing import Iterable, List, Optional, Tuple, Type, Union

import pytest
from langchain_core.documents import Document

from langchain_community.vectorstores import Cassandra
@@ -19,13 +21,22 @@ def _vectorstore_from_texts(
metadatas: Optional[List[dict]] = None,
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
drop: bool = True,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
table_name: str = "vector_test_table",
) -> Cassandra:
from cassandra.cluster import Cluster

keyspace = "vector_test_keyspace"
table_name = "vector_test_table"
# get db connection
cluster = Cluster()
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
# ensure keyspace exists
session.execute(
@@ -45,6 +56,7 @@ def _vectorstore_from_texts(
session=session,
keyspace=keyspace,
table_name=table_name,
metadata_indexing=metadata_indexing,
)


@@ -53,13 +65,22 @@ async def _vectorstore_from_texts_async(
metadatas: Optional[List[dict]] = None,
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
drop: bool = True,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
table_name: str = "vector_test_table",
) -> Cassandra:
from cassandra.cluster import Cluster

keyspace = "vector_test_keyspace"
table_name = "vector_test_table"
# get db connection
cluster = Cluster()
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
# ensure keyspace exists
session.execute(
@@ -268,3 +289,29 @@ async def test_cassandra_adelete() -> None:
await asyncio.sleep(0.3)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0


def test_cassandra_metadata_indexing() -> None:
"""Test comparing metadata indexing policies."""
texts = ["foo"]
metadatas = [{"field1": "a", "field2": "b"}]
vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas)
vstore_f1 = _vectorstore_from_texts(
texts,
metadatas=metadatas,
metadata_indexing=("allowlist", ["field1"]),
table_name="vector_test_table_indexing",
)

output_all = vstore_all.similarity_search("bar", k=2)
output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2)
output_f1_no = vstore_f1.similarity_search("bar", filter={"field1": "Z"}, k=2)
assert len(output_all) == 1
assert output_all[0].metadata == metadatas[0]
assert len(output_f1) == 1
assert output_f1[0].metadata == metadatas[0]
assert len(output_f1_no) == 0

with pytest.raises(ValueError):
# "Non-indexed metadata fields cannot be used in queries."
vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
