From b0e4eea5131e488297346ec9d31136653d771662 Mon Sep 17 00:00:00 2001 From: Bartosz Pietrzak Date: Sun, 28 Sep 2025 23:59:49 +0200 Subject: [PATCH 1/2] Add Chroma cli to VectorDBBench --- .../backend/clients/chroma/chroma.py | 122 +++++++++--------- vectordb_bench/backend/clients/chroma/cli.py | 55 ++++++++ .../backend/clients/chroma/config.py | 42 +++++- vectordb_bench/cli/vectordbbench.py | 3 + 4 files changed, 154 insertions(+), 68 deletions(-) create mode 100644 vectordb_bench/backend/clients/chroma/cli.py diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 7f2cd2f1c..3cd6cacde 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -1,10 +1,9 @@ import logging from contextlib import contextmanager -from typing import Any import chromadb -from ..api import DBCaseConfig, VectorDB +from ..api import VectorDB log = logging.getLogger(__name__) @@ -16,104 +15,103 @@ class ChromaClient(VectorDB): To change to running in process, modify the HttpClient() in __init__() and init(). """ - def __init__( self, dim: int, db_config: dict, - db_case_config: DBCaseConfig, + db_case_config, + collection_name: str = "VectorDBBenchCollection", drop_old: bool = False, - **kwargs, + **kwargs ): self.db_config = db_config self.case_config = db_case_config - self.collection_name = "example2" + self.collection_name = collection_name - client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) + client = chromadb.HttpClient(**db_config) assert client.heartbeat() is not None + if drop_old: try: - client.reset() # Reset the database + client.reset() except Exception: drop_old = False - log.info(f"Chroma client drop_old collection: {self.collection_name}") + log.info("Chroma client drop_old collection: " + + f"{self.collection_name}") - @contextmanager - def init(self) -> None: - """create and destory connections to database. - - Examples: - >>> with self.init(): - >>> self.insert_embeddings() - """ - # create connection - self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) - - self.collection = self.client.get_or_create_collection("example2") - yield self.client = None self.collection = None + @contextmanager + def init(self): + try: + self.client = chromadb.HttpClient( + host=self.db_config.get("host", "localhost"), + port=self.db_config.get("port", 8000) + ) + + self.collection = self.client.get_or_create_collection( + name=self.collection_name, + configuration=self.case_config.index_param() + ) + yield + self.client = None + self.collection = None + except Exception as e: + log.error(f"Failed to initialize Chroma client: {e}") + raise e + def ready_to_search(self) -> bool: pass def optimize(self, data_size: int | None = None): - pass + assert self.collection is not None, "Please call self.init() before" + try: + self.collection.modify( + configuration=self.case_config.search_param() + ) + except Exception as e: + log.warning(f"Optimize error: {e}") + raise e def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], - **kwargs: Any, + **kwargs, ) -> tuple[int, Exception]: - """Insert embeddings into the database. - - Args: - embeddings(list[list[float]]): list of embeddings - metadata(list[int]): list of metadata - kwargs: other arguments - - Returns: - tuple[int, Exception]: number of embeddings inserted and exception if any - """ - ids = [str(i) for i in metadata] - metadata = [{"id": int(i)} for i in metadata] + assert self.collection is not None, "Please call self.init() before" + ids = [f"{idx}" for idx in metadata] + metadata = [{"index": mid} for mid in metadata] try: - if len(embeddings) > 0: - self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata) + self.collection.add( + ids=ids, + embeddings=embeddings, + metadatas=metadata + ) except Exception as e: - log.warning(f"Failed to insert data: error: {e!s}") + log.info(f"Failed to insert data: {e}") return 0, e - return len(embeddings), None + + return len(metadata), None def search_embedding( self, query: list[float], k: int = 100, filters: dict | None = None, - timeout: int | None = None, - **kwargs: Any, - ) -> dict: - """Search embeddings from the database. - Args: - embedding(list[float]): embedding to search - k(int): number of results to return - kwargs: other arguments - - Returns: - Dict {ids: list[list[int]], - embedding: list[list[float]] - distance: list[list[float]]} - """ + timeout: int | None = None + ) -> list[int]: + assert self.client is not None, "Please call self.init() before" if filters: - # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} - id_value = filters.get("id") results = self.collection.query( - query_embeddings=query, + query_embeddings=[query], n_results=k, - where={"id": {"$gt": id_value}}, + where={"id": {"$gt": filters.get("id")}} + ) + else: + results = self.collection.query( + query_embeddings=[query], + n_results=k ) - # return list of id's in results - return [int(i) for i in results.get("ids")[0]] - results = self.collection.query(query_embeddings=query, n_results=k) - return [int(i) for i in results.get("ids")[0]] + return [int(idx) for idx in results['ids'][0]] diff --git a/vectordb_bench/backend/clients/chroma/cli.py b/vectordb_bench/backend/clients/chroma/cli.py new file mode 100644 index 000000000..a1975adf1 --- /dev/null +++ b/vectordb_bench/backend/clients/chroma/cli.py @@ -0,0 +1,55 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB +from vectordb_bench.cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + + +DBTYPE = DB.Chroma + + +class ChromaTypeDict(CommonTypedDict): + host: Annotated[ + str, + click.option("--host", type=str, help="Chroma host") + ] + port: Annotated[ + int, + click.option("--port", type=int, help="Chroma port", default=8000) + ] + m: Annotated[ + int, + click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16) + ] + ef_construct: Annotated[ + int, + click.option("--ef-construct", type=int, help="HNSW efConstruct", default=100) + ] + ef_search: Annotated[ + int, + click.option("--ef-search", type=int, help="HNSW efSearch", default=100) + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ChromaTypeDict) +def Chroma(**parameters: Unpack[ChromaTypeDict]): + from .config import ChromaConfig, ChromaIndexConfig + run( + db=DBTYPE, + db_config=ChromaConfig(host=SecretStr(parameters["host"]), + port=parameters["port"]), + db_case_config=ChromaIndexConfig( + m=parameters["m"], + ef_construct=parameters["ef_construct"], + ef_search=parameters["ef_search"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/chroma/config.py b/vectordb_bench/backend/clients/chroma/config.py index af34cf513..35eb41a72 100644 --- a/vectordb_bench/backend/clients/chroma/config.py +++ b/vectordb_bench/backend/clients/chroma/config.py @@ -1,16 +1,46 @@ from pydantic import SecretStr -from ..api import DBConfig +from ..api import DBConfig, DBCaseConfig, MetricType class ChromaConfig(DBConfig): - password: SecretStr - host: SecretStr - port: int + host: SecretStr = "localhost" + port: int = 8000 def to_dict(self) -> dict: return { "host": self.host.get_secret_value(), - "port": self.port, - "password": self.password.get_secret_value(), + "port": self.port + } + + +class ChromaIndexConfig(ChromaConfig, DBCaseConfig): + metric_type: MetricType = "cosine" + m: int = 16 + ef_construct: int = 100 + ef_search: int | None = 100 + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "l2" + elif self.metric_type == MetricType.IP: + return "ip" + elif self.metric_type == MetricType.COSINE: + return "cosine" + else: + raise ValueError(f"Unsupported metric type: {self.metric_type}") + + def index_param(self): + return { + "hnsw": { + "space": self.parse_metric(), + "max_neighbors": self.m, + "ef_construction": self.ef_construct, + "ef_search": self.search_param().get("ef_search", 100), + } + } + + def search_param(self) -> dict: + return { + "ef_search": self.ef_search } diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 83dab74f6..5331e2962 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.chroma.cli import Chroma from ..backend.clients.clickhouse.cli import Clickhouse from ..backend.clients.hologres.cli import HologresHGraph from ..backend.clients.lancedb.cli import LanceDB @@ -24,6 +25,7 @@ from .batch_cli import BatchCli from .cli import cli + cli.add_command(PgVectorHNSW) cli.add_command(PgVectoRSHNSW) cli.add_command(PgVectoRSIVFFlat) @@ -50,6 +52,7 @@ cli.add_command(QdrantLocal) cli.add_command(BatchCli) cli.add_command(S3Vectors) +cli.add_command(Chroma) if __name__ == "__main__": From 835215acbf039842c3d285928b05ecf332a42d0f Mon Sep 17 00:00:00 2001 From: bpietrzak <> Date: Tue, 30 Sep 2025 21:08:41 +0200 Subject: [PATCH 2/2] Fix lint issues --- vectordb_bench/backend/clients/__init__.py | 7 ++- .../backend/clients/chroma/chroma.py | 63 ++++++------------- vectordb_bench/backend/clients/chroma/cli.py | 21 +++---- .../backend/clients/chroma/config.py | 18 ++---- vectordb_bench/cli/vectordbbench.py | 1 - 5 files changed, 42 insertions(+), 68 deletions(-) diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 79a6f964a..19e670fdf 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -477,7 +477,12 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912 return HologresIndexConfig - # DB.Pinecone, DB.Chroma, DB.Redis + if self == DB.Chroma: + from .chroma.config import ChromaIndexConfig + + return ChromaIndexConfig + + # DB.Pinecone, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 3cd6cacde..6942f4fc9 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -4,6 +4,7 @@ import chromadb from ..api import VectorDB +from .config import ChromaIndexConfig log = logging.getLogger(__name__) @@ -15,14 +16,15 @@ class ChromaClient(VectorDB): To change to running in process, modify the HttpClient() in __init__() and init(). """ + def __init__( self, dim: int, db_config: dict, - db_case_config, + db_case_config: ChromaIndexConfig, collection_name: str = "VectorDBBenchCollection", drop_old: bool = False, - **kwargs + **kwargs, ): self.db_config = db_config self.case_config = db_case_config @@ -36,30 +38,20 @@ def __init__( client.reset() except Exception: drop_old = False - log.info("Chroma client drop_old collection: " - + f"{self.collection_name}") + log.info(f"Chroma client drop_old collection: {self.collection_name}") self.client = None self.collection = None @contextmanager def init(self): - try: - self.client = chromadb.HttpClient( - host=self.db_config.get("host", "localhost"), - port=self.db_config.get("port", 8000) - ) - - self.collection = self.client.get_or_create_collection( - name=self.collection_name, - configuration=self.case_config.index_param() - ) - yield - self.client = None - self.collection = None - except Exception as e: - log.error(f"Failed to initialize Chroma client: {e}") - raise e + self.client = chromadb.HttpClient(**self.db_config) + self.collection = self.client.get_or_create_collection( + name=self.collection_name, configuration=self.case_config.index_param() + ) + yield + self.client = None + self.collection = None def ready_to_search(self) -> bool: pass @@ -67,12 +59,10 @@ def ready_to_search(self) -> bool: def optimize(self, data_size: int | None = None): assert self.collection is not None, "Please call self.init() before" try: - self.collection.modify( - configuration=self.case_config.search_param() - ) + self.collection.modify(configuration=self.case_config.search_param()) except Exception as e: log.warning(f"Optimize error: {e}") - raise e + raise def insert_embeddings( self, @@ -84,34 +74,21 @@ def insert_embeddings( ids = [f"{idx}" for idx in metadata] metadata = [{"index": mid} for mid in metadata] try: - self.collection.add( - ids=ids, - embeddings=embeddings, - metadatas=metadata - ) + self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadata) except Exception as e: - log.info(f"Failed to insert data: {e}") + log.warning(f"Failed to insert data: {e}") return 0, e return len(metadata), None def search_embedding( - self, - query: list[float], - k: int = 100, - filters: dict | None = None, - timeout: int | None = None + self, query: list[float], k: int = 100, filters: dict | None = None, timeout: int | None = None ) -> list[int]: assert self.client is not None, "Please call self.init() before" if filters: results = self.collection.query( - query_embeddings=[query], - n_results=k, - where={"id": {"$gt": filters.get("id")}} + query_embeddings=[query], n_results=k, where={"id": {"$gt": filters.get("id")}} ) else: - results = self.collection.query( - query_embeddings=[query], - n_results=k - ) - return [int(idx) for idx in results['ids'][0]] + results = self.collection.query(query_embeddings=[query], n_results=k) + return [int(idx) for idx in results["ids"][0]] diff --git a/vectordb_bench/backend/clients/chroma/cli.py b/vectordb_bench/backend/clients/chroma/cli.py index a1975adf1..414381914 100644 --- a/vectordb_bench/backend/clients/chroma/cli.py +++ b/vectordb_bench/backend/clients/chroma/cli.py @@ -11,30 +11,26 @@ run, ) - DBTYPE = DB.Chroma class ChromaTypeDict(CommonTypedDict): host: Annotated[ str, - click.option("--host", type=str, help="Chroma host") - ] - port: Annotated[ - int, - click.option("--port", type=int, help="Chroma port", default=8000) + click.option("--host", type=str, help="Chroma host", default="localhost"), ] + port: Annotated[int, click.option("--port", type=int, help="Chroma port", default=8000)] m: Annotated[ int, - click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16) + click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16), ] ef_construct: Annotated[ int, - click.option("--ef-construct", type=int, help="HNSW efConstruct", default=100) + click.option("--ef-construct", type=int, help="HNSW efConstruct", default=256), ] ef_search: Annotated[ int, - click.option("--ef-search", type=int, help="HNSW efSearch", default=100) + click.option("--ef-search", type=int, help="HNSW efSearch", default=256), ] @@ -42,10 +38,13 @@ class ChromaTypeDict(CommonTypedDict): @click_parameter_decorators_from_typed_dict(ChromaTypeDict) def Chroma(**parameters: Unpack[ChromaTypeDict]): from .config import ChromaConfig, ChromaIndexConfig + run( db=DBTYPE, - db_config=ChromaConfig(host=SecretStr(parameters["host"]), - port=parameters["port"]), + db_config=ChromaConfig( + host=SecretStr(parameters["host"]), + port=parameters["port"], + ), db_case_config=ChromaIndexConfig( m=parameters["m"], ef_construct=parameters["ef_construct"], diff --git a/vectordb_bench/backend/clients/chroma/config.py b/vectordb_bench/backend/clients/chroma/config.py index 35eb41a72..4c1967a12 100644 --- a/vectordb_bench/backend/clients/chroma/config.py +++ b/vectordb_bench/backend/clients/chroma/config.py @@ -1,6 +1,6 @@ from pydantic import SecretStr -from ..api import DBConfig, DBCaseConfig, MetricType +from ..api import DBCaseConfig, DBConfig, MetricType class ChromaConfig(DBConfig): @@ -8,10 +8,7 @@ class ChromaConfig(DBConfig): port: int = 8000 def to_dict(self) -> dict: - return { - "host": self.host.get_secret_value(), - "port": self.port - } + return {"host": self.host.get_secret_value(), "port": self.port} class ChromaIndexConfig(ChromaConfig, DBCaseConfig): @@ -23,12 +20,11 @@ class ChromaIndexConfig(ChromaConfig, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "l2" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "ip" - elif self.metric_type == MetricType.COSINE: + if self.metric_type == MetricType.COSINE: return "cosine" - else: - raise ValueError(f"Unsupported metric type: {self.metric_type}") + raise ValueError("Unsupported metric type: %s" % self.metric_type) def index_param(self): return { @@ -41,6 +37,4 @@ def index_param(self): } def search_param(self) -> dict: - return { - "ef_search": self.ef_search - } + return {"ef_search": self.ef_search} diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 5331e2962..debaaa281 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -25,7 +25,6 @@ from .batch_cli import BatchCli from .cli import cli - cli.add_command(PgVectorHNSW) cli.add_command(PgVectoRSHNSW) cli.add_command(PgVectoRSIVFFlat)