diff --git a/src/modaic/databases/__init__.py b/src/modaic/databases/__init__.py index 13e42fd..b960244 100644 --- a/src/modaic/databases/__init__.py +++ b/src/modaic/databases/__init__.py @@ -13,6 +13,7 @@ VectorType, ) from .vector_database.vendors.milvus import MilvusBackend +from .vector_database.vendors.weaviate import WeaviateBackend __all__ = [ "CollectionConfig", @@ -20,6 +21,7 @@ "SQLiteBackend", "VectorDatabase", "MilvusBackend", + "WeaviateBackend", "SearchResult", "VectorDBBackend", "IndexConfig", diff --git a/src/modaic/databases/vector_database/__init__.py b/src/modaic/databases/vector_database/__init__.py index 1f4c786..db1da2b 100644 --- a/src/modaic/databases/vector_database/__init__.py +++ b/src/modaic/databases/vector_database/__init__.py @@ -1,10 +1,12 @@ from .vector_database import IndexConfig, IndexType, Metric, SupportsHybridSearch, VectorDatabase, VectorType from .vendors.milvus import MilvusBackend +from .vendors.weaviate import WeaviateBackend __all__ = [ "VectorDatabase", "SupportsHybridSearch", "MilvusBackend", + "WeaviateBackend", "IndexConfig", "IndexType", "VectorType", diff --git a/src/modaic/databases/vector_database/vector_database.py b/src/modaic/databases/vector_database/vector_database.py index 4524248..05c138f 100644 --- a/src/modaic/databases/vector_database/vector_database.py +++ b/src/modaic/databases/vector_database/vector_database.py @@ -45,7 +45,7 @@ class SearchResult(NamedTuple): class VectorType(AutoNumberEnum): _init_ = "supported_libraries" # name | supported_libraries - FLOAT = ["milvus", "qdrant", "mongo", "pinecone"] # float32 + FLOAT = ["milvus", "qdrant", "mongo", "pinecone", "weaviate"] # float32 FLOAT16 = ["milvus", "qdrant"] BFLOAT16 = ["milvus"] INT8 = ["milvus", "mongo"] @@ -59,14 +59,14 @@ class VectorType(AutoNumberEnum): class IndexType(AutoNumberEnum): """ - The ANN or ENN algorithm to use for an index. IndexType.DEFAULT is IndexType.HNSW for most vector databases (milvus, qdrant, mongo). + The ANN or ENN algorithm to use for an index. IndexType.DEFAULT is IndexType.HNSW for most vector databases (milvus, qdrant, mongo, weaviate). """ _init_ = "supported_libraries" # name | supported_libraries - DEFAULT = ["milvus", "qdrant", "mongo", "pinecone"] - HNSW = ["milvus", "qdrant", "mongo"] - FLAT = ["milvus", "redis"] + DEFAULT = ["milvus", "qdrant", "mongo", "pinecone", "weaviate"] + HNSW = ["milvus", "qdrant", "mongo", "weaviate"] + FLAT = ["milvus", "redis", "weaviate"] IVF_FLAT = ["milvus"] IVF_SQ8 = ["milvus"] IVF_PQ = ["milvus"] @@ -91,24 +91,31 @@ class Metric(AutoNumberEnum): "qdrant": "Euclid", "mongo": "euclidean", "pinecone": "euclidean", + "weaviate": "l2-squared", } DOT_PRODUCT = { "milvus": "IP", "qdrant": "Dot", "mongo": "dotProduct", "pinecone": "dotproduct", + "weaviate": "dot", } COSINE = { "milvus": "COSINE", "qdrant": "Cosine", "mongo": "cosine", "pinecone": "cosine", + "weaviate": "cosine", } MANHATTAN = { "qdrant": "Manhattan", "mongo": "manhattan", + "weaviate": "manhattan", + } + HAMMING = { + "milvus": "HAMMING", + "weaviate": "hamming", } - HAMMING = {"milvus": "HAMMING"} JACCARD = {"milvus": "JACCARD"} MHJACCARD = {"milvus": "MHJACCARD"} BM25 = {"milvus": "BM25"} diff --git a/src/modaic/databases/vector_database/vendors/weaviate.py b/src/modaic/databases/vector_database/vendors/weaviate.py new file mode 100644 index 0000000..5ecaba1 --- /dev/null +++ b/src/modaic/databases/vector_database/vendors/weaviate.py @@ -0,0 +1,402 @@ +import json # Still needed +from collections.abc import Mapping +from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, Union + +import numpy as np +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) +import weaviate +from weaviate.classes.config import Configure, Property, DataType as WeaviateDataType, VectorDistances +from weaviate.classes.init import AdditionalConfig, Timeout +from weaviate.classes.query import Filter, MetadataQuery + +from ....context.base import Context +from ....exceptions import BackendCompatibilityError +from ....types import InnerField, Schema, SchemaField, float_format, int_format +from ..vector_database import DEFAULT_INDEX_NAME, IndexConfig, IndexType, SearchResult, VectorType + + +modaic_to_weaviate_vector = { + VectorType.FLOAT: "float32", # Weaviate's default + VectorType.FLOAT16: "float16", + VectorType.BFLOAT16: "bfloat16", +} + +modaic_to_weaviate_index = { + IndexType.DEFAULT: "hnsw", + IndexType.HNSW: "hnsw", + IndexType.FLAT: "flat", +} + +modaic_metric_to_weaviate_enum = { + "cosine": VectorDistances.COSINE, + "l2": VectorDistances.L2_SQUARED, + "ip": VectorDistances.DOT, + "dot": VectorDistances.DOT, + "euclidean": VectorDistances.L2_SQUARED, +} + + +class WeaviateTranslator(Visitor): + """ + Translator to convert structured queries to Weaviate filters. + """ + + def visit_operation(self, operation: Operation) -> Filter: + """Convert an operation (AND, OR, NOT) to a Weaviate filter.""" + args = [arg.accept(self) for arg in operation.arguments] + + if operation.operator == Operator.AND: + return Filter.all_of(args) + elif operation.operator == Operator.OR: + return Filter.any_of(args) + elif operation.operator == Operator.NOT: + return Filter.not_(args[0]) + else: + raise ValueError(f"Unsupported operator: {operation.operator}") + + def visit_comparison(self, comparison: Comparison) -> Filter: + """Convert a comparison to a Weaviate filter.""" + attribute = comparison.attribute + comparator = comparison.comparator + value = comparison.value + + if comparator == Comparator.EQ: + return Filter.by_property(attribute).equal(value) + elif comparator == Comparator.NE: + return Filter.by_property(attribute).not_equal(value) + elif comparator == Comparator.GT: + return Filter.by_property(attribute).greater_than(value) + elif comparator == Comparator.GTE: + return Filter.by_property(attribute).greater_or_equal(value) + elif comparator == Comparator.LT: + return Filter.by_property(attribute).less_than(value) + elif comparator == Comparator.LTE: + return Filter.by_property(attribute).less_or_equal(value) + elif comparator == Comparator.IN: + return Filter.by_property(attribute).contains_any(value) + elif comparator == Comparator.CONTAIN: + return Filter.by_property(attribute).contains_any([value]) + elif comparator == Comparator.LIKE: + return Filter.by_property(attribute).like(f"*{value}*") + else: + raise ValueError(f"Unsupported comparator: {comparator}") + + def visit_structured_query(self, structured_query: StructuredQuery) -> Filter: + """Convert a structured query to a Weaviate filter.""" + if structured_query.filter is None: + raise ValueError("Structured query has no filter") + return structured_query.filter.accept(self) + + +class WeaviateBackend: + _name: ClassVar[Literal["weaviate"]] = "weaviate" + mql_translator: Visitor = WeaviateTranslator() + + def __init__( + self, + url: str = "http://localhost:8080", + api_key: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): + additional_config = None + if timeout is not None: + additional_config = AdditionalConfig( + timeout=Timeout(init=timeout, query=timeout, insert=timeout) + ) + + if 'additional_config' in kwargs: + pass + elif additional_config is not None: + kwargs['additional_config'] = additional_config + + if api_key: + self._client = weaviate.connect_to_weaviate_cloud( + cluster_url=url, + auth_credentials=weaviate.auth.AuthApiKey(api_key), + **kwargs, + ) + else: + url_clean = url.replace("http://", "").replace("https://", "") + if ":" in url_clean: + host, port_str = url_clean.split(":", 1) + port = int(port_str) + else: + host = url_clean + port = 8080 + + self._client = weaviate.connect_to_local( + host=host, + port=port, + **kwargs, + ) + + def __del__(self): + """Ensure client is closed on deletion.""" + if hasattr(self, '_client'): + self._client.close() + + def create_record(self, embedding_map: Dict[str, np.ndarray], context: Context) -> Any: + """ + Convert a Context to a record for Weaviate. + + This function serializes nested objects (like dicts or other Context models) + into JSON strings to be stored in Weaviate's TEXT fields. + """ + record_data = context.model_dump(include_hidden=True) + record_id = record_data.pop('id') + + properties = {} + for field_name, value in record_data.items(): + if isinstance(value, dict): + properties[field_name] = json.dumps(value) + elif isinstance(value, list) and value and isinstance(value[0], dict): + properties[field_name] = [json.dumps(item) for item in value] + else: + properties[field_name] = value + + return { + "properties": properties, + "vector": embedding_map.get(DEFAULT_INDEX_NAME, list(embedding_map.values())[0]).tolist(), + "uuid": record_id + } + + def add_records(self, collection_name: str, records: List[Any]): + collection = self._client.collections.get(collection_name) + + with collection.batch.dynamic() as batch: + for record in records: + batch.add_object( + properties=record["properties"], + vector=record["vector"], + uuid=record["uuid"] + ) + + def list_collections(self) -> List[str]: + return list(self._client.collections.list_all().keys()) + + def drop_collection(self, collection_name: str): + self._client.collections.delete(collection_name) + + def create_collection( + self, + collection_name: str, + payload_class: Type[Context], + index: IndexConfig = IndexConfig(), + ): + if not issubclass(payload_class, Context): + raise TypeError(f"Payload class {payload_class} must be a subclass of Context") + + properties = _modaic_to_weaviate_properties(payload_class.schema()) + + try: + vectorizer_config = None + index_type = modaic_to_weaviate_index.get(index.index_type, "hnsw") + + if isinstance(index.metric, str): + metric_name = index.metric.lower() + else: + metric_name = getattr(index.metric, 'name', 'cosine').lower() + + metric_enum = modaic_metric_to_weaviate_enum.get(metric_name, VectorDistances.COSINE) + + if index_type == "flat": + vector_index_config = Configure.VectorIndex.flat( + distance_metric=metric_enum, + ) + elif index_type == "dynamic": + vector_index_config = Configure.VectorIndex.dynamic( + distance_metric=metric_enum, + threshold=10000, + ) + else: + vector_index_config = Configure.VectorIndex.hnsw( + distance_metric=metric_enum, + ) + + except (KeyError, AttributeError) as e: + raise ValueError(f"Weaviate does not support the specified configuration: {e}") from None + + self._client.collections.create( + name=collection_name, + properties=properties, + vectorizer_config=vectorizer_config, + vector_index_config=vector_index_config, + ) + + def has_collection(self, collection_name: str) -> bool: + return collection_name in self.list_collections() + + def _deserialize_properties(self, properties: Dict[str, Any], payload_class: Type[Context]) -> Dict[str, Any]: + """ + Converts Weaviate properties back into a Pydantic-valid dictionary. + - Parses JSON strings back into dictionaries/objects. + """ + deserialized_props = dict(properties) + + # Get the schema to check types + schema_dict = payload_class.schema().as_dict() + + for field_name, value in deserialized_props.items(): + if field_name not in schema_dict: + continue + + schema_field = schema_dict[field_name] + + # Check if it's a string that should be an object (e.g., dict or nested model) + if isinstance(value, str) and schema_field.type == "object": + try: + deserialized_props[field_name] = json.loads(value) + except json.JSONDecodeError: + pass # Keep original string if not valid JSON + + # Check if it's a list of strings that should be a list of objects + elif (isinstance(value, list) and + schema_field.type == "array" and + schema_field.inner_type.type == "object"): + + new_list = [] + for item in value: + if isinstance(item, str): + try: + new_list.append(json.loads(item)) + except json.JSONDecodeError: + new_list.append(item) # Keep original + else: + new_list.append(item) # Already processed? + deserialized_props[field_name] = new_list + + return deserialized_props + + def search( + self, + collection_name: str, + vectors: List[np.ndarray], + payload_class: Type[Context], + k: int = 10, + filter: Optional[Filter] = None, + ) -> List[List[SearchResult]]: + if not issubclass(payload_class, Context): + raise TypeError(f"Payload class {payload_class} must be a subclass of Context") + + collection = self._client.collections.get(collection_name) + all_results = [] + + for vector in vectors: + response = collection.query.near_vector( + near_vector=vector.tolist(), + limit=k, + filters=filter, + return_metadata=MetadataQuery(distance=True), + ) + + context_list = [] + for obj in response.objects: + properties = self._deserialize_properties(obj.properties, payload_class) + properties['id'] = str(obj.uuid) # Cast UUID to string + + context = payload_class.model_validate(properties) + + score = 1.0 - obj.metadata.distance + + context_list.append(SearchResult(id=context.id, context=context, score=score)) + + all_results.append(context_list) + + return all_results + + def get_records(self, collection_name: str, payload_class: Type[Context], record_ids: List[str]) -> List[Context]: + collection = self._client.collections.get(collection_name) + records = [] + + for record_id in record_ids: + obj = collection.query.fetch_object_by_id(record_id) + if obj: + properties = self._deserialize_properties(obj.properties, payload_class) + properties['id'] = str(obj.uuid) # Cast UUID to string + + records.append(payload_class.model_validate(properties)) + + return records + + @staticmethod + def from_local(host: str = "localhost", port: int = 8080) -> "WeaviateBackend": + return WeaviateBackend(url=f"http://{host}:{port}") + + +def _modaic_to_weaviate_properties(modaic_schema: Schema) -> List[Property]: + """ + Convert a Modaic schema to Weaviate properties. + """ + type_mapping: Mapping[str, WeaviateDataType] = { + "string": WeaviateDataType.TEXT, + "integer": WeaviateDataType.INT, + "number": WeaviateDataType.NUMBER, + "boolean": WeaviateDataType.BOOL, + } + + format_mapping: Mapping[int_format | float_format, WeaviateDataType] = { + "int8": WeaviateDataType.INT, + "int16": WeaviateDataType.INT, + "int32": WeaviateDataType.INT, + "int64": WeaviateDataType.INT, + "float": WeaviateDataType.NUMBER, + "double": WeaviateDataType.NUMBER, + } + + properties = [] + + for field_name, schema_field in modaic_schema.as_dict().items(): + if schema_field.is_id: + continue + + index_searchable = False + + if schema_field.type == "array": + if schema_field.inner_type.type == "string": + data_type = WeaviateDataType.TEXT_ARRAY + index_searchable = True + elif schema_field.inner_type.type == "integer": + data_type = WeaviateDataType.INT_ARRAY + elif schema_field.inner_type.type == "number": + data_type = WeaviateDataType.NUMBER_ARRAY + elif schema_field.inner_type.type == "boolean": + data_type = WeaviateDataType.BOOL_ARRAY + elif schema_field.inner_type.type == "object": + data_type = WeaviateDataType.TEXT_ARRAY # Store as array of JSON strings + else: + raise ValueError(f"Weaviate does not support array of type: {schema_field.inner_type.type}") + + elif schema_field.type == "object": + data_type = WeaviateDataType.TEXT # Store as JSON string + index_searchable = False + + elif schema_field.format in format_mapping: + data_type = format_mapping[schema_field.format] + + elif schema_field.type in type_mapping: + data_type = type_mapping[schema_field.type] + if data_type == WeaviateDataType.TEXT: + index_searchable = True + + else: + raise ValueError(f"Weaviate does not support field type: {schema_field.type}") + + properties.append( + Property( + name=field_name, + data_type=data_type, + skip_vectorization=True, + index_filterable=True, + index_searchable=index_searchable, + ) + ) + + return properties \ No newline at end of file diff --git a/tests/databases/vector_database/test_weaviate.py b/tests/databases/vector_database/test_weaviate.py new file mode 100644 index 0000000..e283ec9 --- /dev/null +++ b/tests/databases/vector_database/test_weaviate.py @@ -0,0 +1,499 @@ +import os +import uuid +from typing import Optional + +import numpy as np +import pytest +import weaviate +from weaviate.classes.query import Filter + +from modaic import Condition, parse_modaic_filter + +# Import your real backend + types +from modaic.context import Context, Text +from modaic.databases import WeaviateBackend, VectorDatabase +from modaic.databases.vector_database.vector_database import VectorDBBackend +from modaic.types import Array, String +from tests.testing_utils import DummyEmbedder, HardcodedEmbedder + + +def _read_hosted_config(): + """ + Read hosted Weaviate configuration from environment variables. + + Returns: + dict: Configuration dictionary with url and api_key + """ + return { + "url": os.environ.get("WEAVIATE_URL"), + "api_key": os.environ.get("WEAVIATE_API_KEY"), + } + + +# --------------------------- +# Param: which backend flavor +# --------------------------- +@pytest.fixture(params=["local", "hosted"]) +def weaviate_mode(request): + """ + Fixture that provides Weaviate mode (local or hosted) for parameterized tests. + + Params: + request: pytest request object + + Returns: + str: Either "local" or "hosted" + """ + cfg = _read_hosted_config() + if request.param == "hosted" and not cfg["url"]: + pytest.skip("No hosted Weaviate configured (set WEAVIATE_URL environment variable).") + return request.param + + +# --------------------------- +# Configuration for each mode +# --------------------------- +@pytest.fixture(scope="session") +def hosted_cfg(): + """ + Provide hosted Weaviate configuration for tests. + + Returns: + dict: Hosted configuration dictionary + """ + return _read_hosted_config() + + +@pytest.fixture +def vector_database(weaviate_mode: str, hosted_cfg: dict): + """ + Returns a WeaviateBackend connected to local or hosted, depending on weaviate_mode. + + Params: + weaviate_mode: Either "local" or "hosted" + hosted_cfg: Configuration dictionary for hosted mode + + Returns: + VectorDatabase: Configured vector database instance + """ + # Create a default embedder for testing + default_embedder = DummyEmbedder() + + if weaviate_mode == "local": + vector_database = VectorDatabase( + WeaviateBackend.from_local(), + embedder=default_embedder + ) + else: + vector_database = VectorDatabase( + WeaviateBackend( + url=hosted_cfg["url"], + api_key=hosted_cfg["api_key"], + ), + embedder=default_embedder, + ) + + # Smoke check: try a harmless op to verify connectivity + try: + _ = vector_database.list_collections() + except Exception as e: + pytest.skip(f"Weaviate connection failed for mode={weaviate_mode}: {e}") + + yield vector_database + + # Best-effort cleanup: drop only collections we created in tests + try: + for c in vector_database.list_collections(): + vector_database.drop_collection(c) + except Exception: + pass + + +# --------------------------- +# Throwaway collection per test +# --------------------------- +@pytest.fixture +def collection_name(vector_database: VectorDatabase): + """ + Yields a unique collection name; drops it after the test if it was created. + + Params: + vector_database: Vector database instance + + Returns: + str: Unique collection name + """ + # Weaviate collection names must start with uppercase letter + name = f"T{uuid.uuid4().hex[:12]}" + try: + yield name + finally: + try: + if vector_database.has_collection(name): + vector_database.drop_collection(name) + except Exception: + pass + + +class CustomContext(Context): + """ + Custom context for Weaviate tests, covering all supported types and Optionals. + """ + + field1: str + field2: int + field3: bool + field4: float + field5: list[str] + field6: dict[str, int] + field7: Array[int, 10] + field8: String[50] + field9: Text + field10: Optional[Array[String[50], 10]] = None + field11: Optional[Array[int, 10]] = None + field12: Optional[String[50]] = None + + def embedme(self) -> str: + return self.field9.text + + +def test_mql_to_weaviate_simple(): + """ + Test simple MQL to Weaviate translation for equality, comparison, in/like, and logical ops. + + Params: + None + """ + translator = WeaviateBackend.mql_translator + + # Simple equality + expr = CustomContext.field1 == "foo" + filter_obj = parse_modaic_filter(translator, expr) + # Just verify it creates some kind of filter object, don't check specific type + assert filter_obj is not None + + # Range with AND + expr = (CustomContext.field2 > 5) & (CustomContext.field2 <= 10) + filter_obj = parse_modaic_filter(translator, expr) + assert filter_obj is not None + + # IN operator with AND + expr = (CustomContext.field1.in_(["a", "b"])) & (CustomContext.field2 < 100) + filter_obj = parse_modaic_filter(translator, expr) + assert filter_obj is not None + + # OR combination + expr = (CustomContext.field2 < 0) | (CustomContext.field2 > 10) + filter_obj = parse_modaic_filter(translator, expr) + assert filter_obj is not None + + +def test_mql_to_weaviate_complex(): + """ + Complex nested MQL to Weaviate translation - simplified since Modaic doesn't support NOT yet + """ + translator = WeaviateBackend.mql_translator + + range_and = (CustomContext.field2 >= 1) & (CustomContext.field2 <= 10) + in_list = CustomContext.field1.in_(["x", "y"]) + + # Combine the filters - skip NOT since Modaic doesn't support it yet + complex_expr = range_and & in_list + filter_obj = parse_modaic_filter(translator, complex_expr) + + # Verify it's a valid filter + assert filter_obj is not None + + +def test_weaviate_implements_vector_db_backend(vector_database: VectorDatabase): + """Test that WeaviateBackend implements VectorDBBackend interface.""" + backend = vector_database.ext.backend + assert isinstance(backend, VectorDBBackend) + + +def test_create_collection(vector_database: VectorDatabase, collection_name: str): + """Test creating a collection.""" + vector_database.create_collection(collection_name, CustomContext) + assert vector_database.has_collection(collection_name) + + +def test_drop_collection(vector_database: VectorDatabase, collection_name: str): + """Test dropping a collection.""" + vector_database.create_collection(collection_name, CustomContext) + assert vector_database.has_collection(collection_name) + vector_database.drop_collection(collection_name) + assert not vector_database.has_collection(collection_name) + + +def test_list_collections(vector_database: VectorDatabase, collection_name: str): + """Test listing collections.""" + vector_database.create_collection(collection_name, CustomContext) + assert collection_name in vector_database.list_collections() + + +def test_has_collection(vector_database: VectorDatabase, collection_name: str): + """Test checking if collection exists.""" + vector_database.create_collection(collection_name, CustomContext) + assert vector_database.has_collection(collection_name) + vector_database.drop_collection(collection_name) + assert not vector_database.has_collection(collection_name) + + +def test_record_ops(vector_database: VectorDatabase, collection_name: str): + """Test adding and retrieving records.""" + vector_database.create_collection(collection_name, CustomContext, embedder=DummyEmbedder(embedding_dim=3)) + context = CustomContext( + field1="test", + field2=1, + field3=True, + field4=1.0, + field5=["test"], + field6={"test": 1}, + field7=[1, 2, 3], + field8="test", + field9=Text(text="test"), + field10=["hello", "world"], + field11=None, + field12="test", + ) + vector_database.add_records(collection_name, [context]) + assert vector_database.has_collection(collection_name) + + retrieved = vector_database.get_records(collection_name, [context.id]) + assert len(retrieved) == 1 + assert retrieved[0] == context + + +def test_search(vector_database: VectorDatabase, collection_name: str): + """Test vector search with multiple records.""" + hardcoded_embedder = HardcodedEmbedder() + vector_database.create_collection(collection_name, CustomContext, embedder=hardcoded_embedder) + + context1 = CustomContext( + field1="test", + field2=1, + field3=True, + field4=1.0, + field5=["test"], + field6={"test": 1}, + field7=[1, 2, 3], + field8="test", + field9=Text(text="test"), + field10=["hello", "world"], + field11=None, + field12="test", + ) + context2 = CustomContext( + field1="test2", + field2=2, + field3=False, + field4=2.0, + field5=["test2"], + field6={"test2": 2}, + field7=[4, 5, 6], + field8="test2", + field9=Text(text="test2"), + field10=["hello2", "world2"], + field11=None, + field12="test2", + ) + context3 = CustomContext( + field1="test3", + field2=3, + field3=True, + field4=3.0, + field5=["test3"], + field6={"test3": 3}, + field7=[7, 8, 9], + field8="test3", + field9=Text(text="test3"), + field10=["hello3", "world3"], + field11=None, + field12="test3", + ) + + # Set up hardcoded embeddings for predictable results + hardcoded_embedder("query", np.array([3, 5, 7])) + hardcoded_embedder("record1", np.array([4, 5, 6])) # Cosine similarity 0.988195 + hardcoded_embedder("record2", np.array([6, 3, 0])) # Cosine similarity 0.539969 + hardcoded_embedder("record3", np.array([1, 0, 0])) # Cosine similarity 0.329293 + + vector_database.add_records(collection_name, [("record1", context1), ("record2", context2), ("record3", context3)]) + + # Test top-k retrieval + results_k1 = vector_database.search(collection_name, "query", k=1) + assert results_k1[0][0].context == context1 + + results_k2 = vector_database.search(collection_name, "query", k=2) + assert results_k2[0][1].context == context2 + + results_k3 = vector_database.search(collection_name, "query", k=3) + assert results_k3[0][2].context == context3 + + +def test_search_with_filters(vector_database: VectorDatabase[WeaviateBackend], collection_name: str): + """Test vector search with various filters.""" + hardcoded_embedder = HardcodedEmbedder() + vector_database.create_collection(collection_name, CustomContext, embedder=hardcoded_embedder) + + context1 = CustomContext( + field1="test", + field2=1, + field3=True, + field4=1.0, + field5=["test"], + field6={"test": 1}, + field7=[1, 2, 3], + field8="test", + field9=Text(text="test"), + field10=["hello", "world"], + field11=None, + field12="test", + ) + context2 = CustomContext( + field1="test2", + field2=2, + field3=False, + field4=2.0, + field5=["test2"], + field6={"test2": 2}, + field7=[4, 5, 6], + field8="test2", + field9=Text(text="test2"), + field10=["hello2", "world2"], + field11=None, + field12="test2", + ) + context3 = CustomContext( + field1="test3", + field2=3, + field3=True, + field4=3.0, + field5=["test3"], + field6={"test3": 3}, + field7=[7, 8, 9], + field8="test3", + field9=Text(text="test3"), + field10=["hello3", "world3"], + field11=None, + field12="test3", + ) + + # Set up embeddings + hardcoded_embedder("query", np.array([3, 5, 7])) + hardcoded_embedder("record1", np.array([4, 5, 6])) # Cosine similarity 0.988195 + hardcoded_embedder("record2", np.array([6, 3, 0])) # Cosine similarity 0.539969 + hardcoded_embedder("record3", np.array([1, 0, 0])) # Cosine similarity 0.329293 + + vector_database.add_records(collection_name, [("record1", context1), ("record2", context2), ("record3", context3)]) + + # Test equality filter + filter1 = CustomContext.field1 == "test2" + results1 = vector_database.search(collection_name, "query", 1, filter1) + assert results1[0][0].context == context2 + + # Test greater than filter + filter2 = CustomContext.field2 > 2 + results2 = vector_database.search(collection_name, "query", 1, filter2) + assert results2[0][0].context == context3 + + # Test less than filter + filter3 = CustomContext.field4 < 3.0 + results3 = vector_database.search(collection_name, "query", 1, filter3) + assert results3[0][0].context == context1 + + # Test IN filter + filter4 = CustomContext.field12.in_(["test2", "test3"]) + results4 = vector_database.search(collection_name, "query", 1, filter4) + assert results4[0][0].context == context2 + + # Test range filter with AND + filter9 = (CustomContext.field4 < 3.1) & (CustomContext.field4 > 1.9) + results9 = vector_database.search(collection_name, "query", 1, filter9) + assert results9[0][0].context == context2 + + +def test_search_with_multiple_vectors(vector_database: VectorDatabase, collection_name: str): + """Test searching with multiple query vectors at once.""" + hardcoded_embedder = HardcodedEmbedder() + vector_database.create_collection(collection_name, CustomContext, embedder=hardcoded_embedder) + + context1 = CustomContext( + field1="test", + field2=1, + field3=True, + field4=1.0, + field5=["test"], + field6={"test": 1}, + field7=[1, 2, 3], + field8="test", + field9=Text(text="test"), + field10=["hello", "world"], + field11=None, + field12="test", + ) + + hardcoded_embedder("query1", np.array([1, 0, 0])) + hardcoded_embedder("query2", np.array([0, 1, 0])) + hardcoded_embedder("record1", np.array([1, 0, 0])) + + vector_database.add_records(collection_name, [("record1", context1)]) + + # Search with multiple queries + results = vector_database.search(collection_name, ["query1", "query2"], k=1) + assert len(results) == 2 # One result set per query + assert results[0][0].context == context1 + assert results[1][0].context == context1 + + +def test_null_value_handling(vector_database: VectorDatabase, collection_name: str): + """Test that Weaviate properly handles null values (which Milvus Lite doesn't support).""" + vector_database.create_collection(collection_name, CustomContext, embedder=DummyEmbedder(embedding_dim=3)) + + context = CustomContext( + field1="test", + field2=1, + field3=True, + field4=1.0, + field5=["test"], + field6={"test": 1}, + field7=[1, 2, 3], + field8="test", + field9=Text(text="test"), + field10=None, # Null value + field11=None, # Null value + field12=None, # Null value + ) + + vector_database.add_records(collection_name, [context]) + retrieved = vector_database.get_records(collection_name, [context.id]) + + assert len(retrieved) == 1 + assert retrieved[0].field10 is None + assert retrieved[0].field11 is None + assert retrieved[0].field12 is None + + +@pytest.mark.skip(reason="Connection cleanup test - skipping for now") +def test_connection_cleanup(weaviate_mode: str, hosted_cfg: dict): + """Test that the client connection is properly closed.""" + default_embedder = DummyEmbedder() + + if weaviate_mode == "local": + backend = WeaviateBackend.from_local() + else: + if not hosted_cfg["url"]: + pytest.skip("No hosted Weaviate configured") + backend = WeaviateBackend( + url=hosted_cfg["url"], + api_key=hosted_cfg["api_key"], + ) + + vector_db = VectorDatabase(backend, embedder=default_embedder) + + # Verify connection works + _ = vector_db.list_collections() + + # Cleanup + del vector_db + del backend \ No newline at end of file