diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 79a6f964a..19e670fdf 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -477,7 +477,12 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912 return HologresIndexConfig - # DB.Pinecone, DB.Chroma, DB.Redis + if self == DB.Chroma: + from .chroma.config import ChromaIndexConfig + + return ChromaIndexConfig + + # DB.Pinecone, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 7f2cd2f1c..6942f4fc9 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -1,10 +1,10 @@ import logging from contextlib import contextmanager -from typing import Any import chromadb -from ..api import DBCaseConfig, VectorDB +from ..api import VectorDB +from .config import ChromaIndexConfig log = logging.getLogger(__name__) @@ -21,35 +21,34 @@ def __init__( self, dim: int, db_config: dict, - db_case_config: DBCaseConfig, + db_case_config: ChromaIndexConfig, + collection_name: str = "VectorDBBenchCollection", drop_old: bool = False, **kwargs, ): self.db_config = db_config self.case_config = db_case_config - self.collection_name = "example2" + self.collection_name = collection_name - client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) + client = chromadb.HttpClient(**db_config) assert client.heartbeat() is not None + if drop_old: try: - client.reset() # Reset the database + client.reset() except Exception: drop_old = False log.info(f"Chroma client drop_old collection: {self.collection_name}") - @contextmanager - def init(self) -> None: - """create and destory connections to database. - - Examples: - >>> with self.init(): - >>> self.insert_embeddings() - """ - # create connection - self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) + self.client = None + self.collection = None - self.collection = self.client.get_or_create_collection("example2") + @contextmanager + def init(self): + self.client = chromadb.HttpClient(**self.db_config) + self.collection = self.client.get_or_create_collection( + name=self.collection_name, configuration=self.case_config.index_param() + ) yield self.client = None self.collection = None @@ -58,62 +57,38 @@ def ready_to_search(self) -> bool: pass def optimize(self, data_size: int | None = None): - pass + assert self.collection is not None, "Please call self.init() before" + try: + self.collection.modify(configuration=self.case_config.search_param()) + except Exception as e: + log.warning(f"Optimize error: {e}") + raise def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], - **kwargs: Any, + **kwargs, ) -> tuple[int, Exception]: - """Insert embeddings into the database. - - Args: - embeddings(list[list[float]]): list of embeddings - metadata(list[int]): list of metadata - kwargs: other arguments - - Returns: - tuple[int, Exception]: number of embeddings inserted and exception if any - """ - ids = [str(i) for i in metadata] - metadata = [{"id": int(i)} for i in metadata] + assert self.collection is not None, "Please call self.init() before" + ids = [f"{idx}" for idx in metadata] + metadata = [{"index": mid} for mid in metadata] try: - if len(embeddings) > 0: - self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata) + self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadata) except Exception as e: - log.warning(f"Failed to insert data: error: {e!s}") + log.warning(f"Failed to insert data: {e}") return 0, e - return len(embeddings), None + + return len(metadata), None def search_embedding( - self, - query: list[float], - k: int = 100, - filters: dict | None = None, - timeout: int | None = None, - **kwargs: Any, - ) -> dict: - """Search embeddings from the database. - Args: - embedding(list[float]): embedding to search - k(int): number of results to return - kwargs: other arguments - - Returns: - Dict {ids: list[list[int]], - embedding: list[list[float]] - distance: list[list[float]]} - """ + self, query: list[float], k: int = 100, filters: dict | None = None, timeout: int | None = None + ) -> list[int]: + assert self.client is not None, "Please call self.init() before" if filters: - # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} - id_value = filters.get("id") results = self.collection.query( - query_embeddings=query, - n_results=k, - where={"id": {"$gt": id_value}}, + query_embeddings=[query], n_results=k, where={"id": {"$gt": filters.get("id")}} ) - # return list of id's in results - return [int(i) for i in results.get("ids")[0]] - results = self.collection.query(query_embeddings=query, n_results=k) - return [int(i) for i in results.get("ids")[0]] + else: + results = self.collection.query(query_embeddings=[query], n_results=k) + return [int(idx) for idx in results["ids"][0]] diff --git a/vectordb_bench/backend/clients/chroma/cli.py b/vectordb_bench/backend/clients/chroma/cli.py new file mode 100644 index 000000000..414381914 --- /dev/null +++ b/vectordb_bench/backend/clients/chroma/cli.py @@ -0,0 +1,54 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB +from vectordb_bench.cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + +DBTYPE = DB.Chroma + + +class ChromaTypeDict(CommonTypedDict): + host: Annotated[ + str, + click.option("--host", type=str, help="Chroma host", default="localhost"), + ] + port: Annotated[int, click.option("--port", type=int, help="Chroma port", default=8000)] + m: Annotated[ + int, + click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16), + ] + ef_construct: Annotated[ + int, + click.option("--ef-construct", type=int, help="HNSW efConstruct", default=256), + ] + ef_search: Annotated[ + int, + click.option("--ef-search", type=int, help="HNSW efSearch", default=256), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ChromaTypeDict) +def Chroma(**parameters: Unpack[ChromaTypeDict]): + from .config import ChromaConfig, ChromaIndexConfig + + run( + db=DBTYPE, + db_config=ChromaConfig( + host=SecretStr(parameters["host"]), + port=parameters["port"], + ), + db_case_config=ChromaIndexConfig( + m=parameters["m"], + ef_construct=parameters["ef_construct"], + ef_search=parameters["ef_search"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/chroma/config.py b/vectordb_bench/backend/clients/chroma/config.py index af34cf513..4c1967a12 100644 --- a/vectordb_bench/backend/clients/chroma/config.py +++ b/vectordb_bench/backend/clients/chroma/config.py @@ -1,16 +1,40 @@ from pydantic import SecretStr -from ..api import DBConfig +from ..api import DBCaseConfig, DBConfig, MetricType class ChromaConfig(DBConfig): - password: SecretStr - host: SecretStr - port: int + host: SecretStr = "localhost" + port: int = 8000 def to_dict(self) -> dict: + return {"host": self.host.get_secret_value(), "port": self.port} + + +class ChromaIndexConfig(ChromaConfig, DBCaseConfig): + metric_type: MetricType = "cosine" + m: int = 16 + ef_construct: int = 100 + ef_search: int | None = 100 + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "l2" + if self.metric_type == MetricType.IP: + return "ip" + if self.metric_type == MetricType.COSINE: + return "cosine" + raise ValueError("Unsupported metric type: %s" % self.metric_type) + + def index_param(self): return { - "host": self.host.get_secret_value(), - "port": self.port, - "password": self.password.get_secret_value(), + "hnsw": { + "space": self.parse_metric(), + "max_neighbors": self.m, + "ef_construction": self.ef_construct, + "ef_search": self.search_param().get("ef_search", 100), + } } + + def search_param(self) -> dict: + return {"ef_search": self.ef_search} diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 83dab74f6..debaaa281 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.chroma.cli import Chroma from ..backend.clients.clickhouse.cli import Clickhouse from ..backend.clients.hologres.cli import HologresHGraph from ..backend.clients.lancedb.cli import LanceDB @@ -50,6 +51,7 @@ cli.add_command(QdrantLocal) cli.add_command(BatchCli) cli.add_command(S3Vectors) +cli.add_command(Chroma) if __name__ == "__main__":