Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,12 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912

return HologresIndexConfig

# DB.Pinecone, DB.Chroma, DB.Redis
if self == DB.Chroma:
from .chroma.config import ChromaIndexConfig

return ChromaIndexConfig

# DB.Pinecone, DB.Redis
return EmptyDBCaseConfig


Expand Down
99 changes: 37 additions & 62 deletions vectordb_bench/backend/clients/chroma/chroma.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from contextlib import contextmanager
from typing import Any

import chromadb

from ..api import DBCaseConfig, VectorDB
from ..api import VectorDB
from .config import ChromaIndexConfig

log = logging.getLogger(__name__)

Expand All @@ -21,35 +21,34 @@ def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
db_case_config: ChromaIndexConfig,
collection_name: str = "VectorDBBenchCollection",
drop_old: bool = False,
**kwargs,
):
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = "example2"
self.collection_name = collection_name

client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
client = chromadb.HttpClient(**db_config)
assert client.heartbeat() is not None

if drop_old:
try:
client.reset() # Reset the database
client.reset()
except Exception:
drop_old = False
log.info(f"Chroma client drop_old collection: {self.collection_name}")

@contextmanager
def init(self) -> None:
"""create and destory connections to database.

Examples:
>>> with self.init():
>>> self.insert_embeddings()
"""
# create connection
self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
self.client = None
self.collection = None

self.collection = self.client.get_or_create_collection("example2")
@contextmanager
def init(self):
self.client = chromadb.HttpClient(**self.db_config)
self.collection = self.client.get_or_create_collection(
name=self.collection_name, configuration=self.case_config.index_param()
)
yield
self.client = None
self.collection = None
Expand All @@ -58,62 +57,38 @@ def ready_to_search(self) -> bool:
pass

def optimize(self, data_size: int | None = None):
pass
assert self.collection is not None, "Please call self.init() before"
try:
self.collection.modify(configuration=self.case_config.search_param())
except Exception as e:
log.warning(f"Optimize error: {e}")
raise

def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
**kwargs,
) -> tuple[int, Exception]:
"""Insert embeddings into the database.

Args:
embeddings(list[list[float]]): list of embeddings
metadata(list[int]): list of metadata
kwargs: other arguments

Returns:
tuple[int, Exception]: number of embeddings inserted and exception if any
"""
ids = [str(i) for i in metadata]
metadata = [{"id": int(i)} for i in metadata]
assert self.collection is not None, "Please call self.init() before"
ids = [f"{idx}" for idx in metadata]
metadata = [{"index": mid} for mid in metadata]
try:
if len(embeddings) > 0:
self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadata)
except Exception as e:
log.warning(f"Failed to insert data: error: {e!s}")
log.warning(f"Failed to insert data: {e}")
return 0, e
return len(embeddings), None

return len(metadata), None

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict:
"""Search embeddings from the database.
Args:
embedding(list[float]): embedding to search
k(int): number of results to return
kwargs: other arguments

Returns:
Dict {ids: list[list[int]],
embedding: list[list[float]]
distance: list[list[float]]}
"""
self, query: list[float], k: int = 100, filters: dict | None = None, timeout: int | None = None
) -> list[int]:
assert self.client is not None, "Please call self.init() before"
if filters:
# assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
id_value = filters.get("id")
results = self.collection.query(
query_embeddings=query,
n_results=k,
where={"id": {"$gt": id_value}},
query_embeddings=[query], n_results=k, where={"id": {"$gt": filters.get("id")}}
)
# return list of id's in results
return [int(i) for i in results.get("ids")[0]]
results = self.collection.query(query_embeddings=query, n_results=k)
return [int(i) for i in results.get("ids")[0]]
else:
results = self.collection.query(query_embeddings=[query], n_results=k)
return [int(idx) for idx in results["ids"][0]]
54 changes: 54 additions & 0 deletions vectordb_bench/backend/clients/chroma/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Annotated, Unpack

import click
from pydantic import SecretStr

from vectordb_bench.backend.clients import DB
from vectordb_bench.cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)

DBTYPE = DB.Chroma


class ChromaTypeDict(CommonTypedDict):
host: Annotated[
str,
click.option("--host", type=str, help="Chroma host", default="localhost"),
]
port: Annotated[int, click.option("--port", type=int, help="Chroma port", default=8000)]
m: Annotated[
int,
click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16),
]
ef_construct: Annotated[
int,
click.option("--ef-construct", type=int, help="HNSW efConstruct", default=256),
]
ef_search: Annotated[
int,
click.option("--ef-search", type=int, help="HNSW efSearch", default=256),
]


@cli.command()
@click_parameter_decorators_from_typed_dict(ChromaTypeDict)
def Chroma(**parameters: Unpack[ChromaTypeDict]):
from .config import ChromaConfig, ChromaIndexConfig

run(
db=DBTYPE,
db_config=ChromaConfig(
host=SecretStr(parameters["host"]),
port=parameters["port"],
),
db_case_config=ChromaIndexConfig(
m=parameters["m"],
ef_construct=parameters["ef_construct"],
ef_search=parameters["ef_search"],
),
**parameters,
)
38 changes: 31 additions & 7 deletions vectordb_bench/backend/clients/chroma/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
from pydantic import SecretStr

from ..api import DBConfig
from ..api import DBCaseConfig, DBConfig, MetricType


class ChromaConfig(DBConfig):
password: SecretStr
host: SecretStr
port: int
host: SecretStr = "localhost"
port: int = 8000

def to_dict(self) -> dict:
return {"host": self.host.get_secret_value(), "port": self.port}


class ChromaIndexConfig(ChromaConfig, DBCaseConfig):
metric_type: MetricType = "cosine"
m: int = 16
ef_construct: int = 100
ef_search: int | None = 100

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2"
if self.metric_type == MetricType.IP:
return "ip"
if self.metric_type == MetricType.COSINE:
return "cosine"
raise ValueError("Unsupported metric type: %s" % self.metric_type)

def index_param(self):
return {
"host": self.host.get_secret_value(),
"port": self.port,
"password": self.password.get_secret_value(),
"hnsw": {
"space": self.parse_metric(),
"max_neighbors": self.m,
"ef_construction": self.ef_construct,
"ef_search": self.search_param().get("ef_search", 100),
}
}

def search_param(self) -> dict:
return {"ef_search": self.ef_search}
2 changes: 2 additions & 0 deletions vectordb_bench/cli/vectordbbench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..backend.clients.alloydb.cli import AlloyDBScaNN
from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
from ..backend.clients.chroma.cli import Chroma
from ..backend.clients.clickhouse.cli import Clickhouse
from ..backend.clients.hologres.cli import HologresHGraph
from ..backend.clients.lancedb.cli import LanceDB
Expand Down Expand Up @@ -50,6 +51,7 @@
cli.add_command(QdrantLocal)
cli.add_command(BatchCli)
cli.add_command(S3Vectors)
cli.add_command(Chroma)


if __name__ == "__main__":
Expand Down