Skip to content

Commit 9437b68

Browse files
committed
Add Cassandra vector store implementation
1 parent ea46820 commit 9437b68

File tree

7 files changed

+212
-3
lines changed

7 files changed

+212
-3
lines changed

dictionary.txt

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ numpy
6363
pypi
6464
nbformat
6565
semversioner
66+
cassio
6667

6768
# Library Methods
6869
iterrows

graphrag/index/verbs/text/embed/text_embed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def text_embed(
7575
max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
7676
organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
7777
vector_store: # The optional configuration for the vector store
78-
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb
78+
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, cassandra
7979
<...>
8080
```
8181
"""

graphrag/vector_stores/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
from .azure_ai_search import AzureAISearch
77
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
8+
from .cassandra import CassandraVectorStore
89
from .lancedb import LanceDBVectorStore
910
from .typing import VectorStoreFactory, VectorStoreType
1011

1112
__all__ = [
1213
"AzureAISearch",
1314
"BaseVectorStore",
15+
"CassandraVectorStore",
1416
"LanceDBVectorStore",
1517
"VectorStoreDocument",
1618
"VectorStoreFactory",

graphrag/vector_stores/cassandra.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""The Apache Cassandra vector store implementation package."""
5+
6+
from typing import Any
7+
8+
import cassio
9+
from cassandra.cluster import Session
10+
from cassio.table import MetadataVectorCassandraTable
11+
from typing_extensions import override
12+
13+
from graphrag.model.types import TextEmbedder
14+
15+
from .base import (
16+
DEFAULT_VECTOR_SIZE,
17+
BaseVectorStore,
18+
VectorStoreDocument,
19+
VectorStoreSearchResult,
20+
)
21+
22+
23+
class CassandraVectorStore(BaseVectorStore):
    """Vector store backed by an Apache Cassandra table managed through cassio.

    The Cassandra table is named after ``self.collection_name`` and lives in
    the keyspace resolved by :meth:`connect`. The table handle is created
    lazily by :meth:`load_documents`, so :meth:`connect` must be called before
    loading, and documents must be loaded before searching.
    """

    @override
    def connect(
        self,
        *,
        session: Session | None = None,
        keyspace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Connect to the Apache Cassandra database.

        Parameters
        ----------
        session :
            The Cassandra session. If not provided, it is resolved from cassio.
        keyspace :
            The Cassandra keyspace. If not provided, it is resolved from cassio.
        **kwargs :
            Accepted for interface compatibility; not used by this method.
        """
        self.db_connection = cassio.config.check_resolve_session(session)
        self.keyspace = cassio.config.check_resolve_keyspace(keyspace)

    @override
    def load_documents(
        self, documents: list[VectorStoreDocument], overwrite: bool = True
    ) -> None:
        """Write documents into the Cassandra table, optionally replacing it.

        Parameters
        ----------
        documents :
            The documents to store. Documents without a vector are skipped.
        overwrite :
            When true, the existing table is dropped before anything is
            written. With an empty ``documents`` list this leaves the store
            without a table until the next non-empty load.
        """
        if overwrite:
            # NOTE(review): identifiers cannot be bound as CQL parameters, so
            # keyspace/collection_name are interpolated; they must come from
            # trusted configuration.
            self.db_connection.execute(
                f"DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name};"
            )
            # The old handle refers to the table that was just dropped; forget
            # it so that a later load re-creates the table instead of writing
            # into a table that no longer exists.
            self.document_collection = None

        if not documents:
            return

        if not self.document_collection:
            # Size the vector column from the first document that carries a
            # vector, falling back to the default when none does.
            dimension = next(
                (len(doc.vector) for doc in documents if doc.vector),
                DEFAULT_VECTOR_SIZE,
            )
            # Pass the session/keyspace resolved in connect() explicitly so the
            # table uses the same connection as the DROP TABLE above rather
            # than whatever cassio has configured globally.
            self.document_collection = MetadataVectorCassandraTable(
                session=self.db_connection,
                keyspace=self.keyspace,
                table=self.collection_name,
                vector_dimension=dimension,
                primary_key_type="TEXT",
            )

        # Issue every write first, then wait on each one, so the writes are
        # not serialized one after another.
        pending_writes = [
            self.document_collection.put_async(
                row_id=doc.id,
                body_blob=doc.text,
                vector=doc.vector,
                metadata=doc.attributes,
            )
            for doc in documents
            if doc.vector
        ]

        for pending in pending_writes:
            pending.result()

    @override
    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
        """Filtering by ID is not supported by this store.

        Raises
        ------
        NotImplementedError
            Always.
        """
        msg = "Cassandra vector store doesn't support filtering by IDs."
        raise NotImplementedError(msg)

    @override
    def similarity_search_by_vector(
        self, query_embedding: list[float], k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Run an ANN search with the ``"cos"`` metric for the given vector.

        Parameters
        ----------
        query_embedding :
            The query vector.
        k :
            The maximum number of results requested from the table.
        **kwargs :
            Forwarded unchanged to the table's ``metric_ann_search``.

        Returns
        -------
        One search result per returned row; ``score`` is the value the table
        reports under the ``"distance"`` key.
        """
        response = self.document_collection.metric_ann_search(
            vector=query_embedding,
            n=k,
            metric="cos",
            **kwargs,
        )

        return [
            VectorStoreSearchResult(
                document=VectorStoreDocument(
                    id=doc["row_id"],
                    text=doc["body_blob"],
                    vector=doc["vector"],
                    attributes=doc["metadata"],
                ),
                score=doc["distance"],
            )
            for doc in response
        ]

    @override
    def similarity_search_by_text(
        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Embed ``text`` and search for the most similar stored documents.

        Returns an empty list when the embedder yields no embedding for the
        text; otherwise delegates to :meth:`similarity_search_by_vector`.
        """
        query_embedding = text_embedder(text)
        if query_embedding:
            return self.similarity_search_by_vector(
                query_embedding=query_embedding, k=k, **kwargs
            )
        return []

graphrag/vector_stores/typing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from enum import Enum
77
from typing import ClassVar
88

9+
from . import BaseVectorStore, CassandraVectorStore
910
from .azure_ai_search import AzureAISearch
1011
from .lancedb import LanceDBVectorStore
1112

@@ -15,6 +16,7 @@ class VectorStoreType(str, Enum):
1516

1617
LanceDB = "lancedb"
1718
AzureAISearch = "azure_ai_search"
19+
Cassandra = "cassandra"
1820

1921

2022
class VectorStoreFactory:
@@ -30,13 +32,15 @@ def register(cls, vector_store_type: str, vector_store: type):
3032
@classmethod
3133
def get_vector_store(
3234
cls, vector_store_type: VectorStoreType | str, kwargs: dict
33-
) -> LanceDBVectorStore | AzureAISearch:
35+
) -> BaseVectorStore:
3436
"""Get the vector store type from a string."""
3537
match vector_store_type:
3638
case VectorStoreType.LanceDB:
3739
return LanceDBVectorStore(**kwargs)
3840
case VectorStoreType.AzureAISearch:
3941
return AzureAISearch(**kwargs)
42+
case VectorStoreType.Cassandra:
43+
return CassandraVectorStore(**kwargs)
4044
case _:
4145
if vector_store_type in cls.vector_store_types:
4246
return cls.vector_store_types[vector_store_type](**kwargs)

poetry.lock

+79-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ datashaper = "^0.0.49"
5050
azure-search-documents = "^11.4.0"
5151
lancedb = "^0.13.0"
5252

53+
5354
# Async IO
5455
aiolimiter = "^1.1.0"
5556
aiofiles = "^24.1.0"
@@ -87,6 +88,7 @@ azure-identity = "^1.17.1"
8788
json-repair = "^0.28.4"
8889

8990
future = "^1.0.0" # Needed until graspologic fixes their dependency
91+
cassio = "^0.1.9"
9092

9193
[tool.poetry.group.dev.dependencies]
9294
coverage = "^7.6.0"

0 commit comments

Comments
 (0)