diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 2e8cc66bc9..06aaa7f209 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from docarray import BaseDoc, DocList +from docarray import BaseDoc from docarray.documents import AudioDoc from docarray.typing import AudioUrl, ImageUrl from pydantic import Field, NonNegativeFloat, PositiveInt, conint, conlist, field_validator @@ -133,7 +133,7 @@ class Audio2TextDoc(AudioDoc): class SearchedDoc(BaseDoc): - retrieved_docs: DocList[TextDoc] + retrieved_docs: List[TextDoc] initial_query: str top_n: PositiveInt = 1 @@ -164,7 +164,7 @@ class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): class RerankedDoc(BaseDoc): - reranked_docs: DocList[TextDoc] + reranked_docs: List[TextDoc] initial_query: str @@ -502,10 +502,10 @@ class RerankerParms(BaseDoc): class RAGASParams(BaseDoc): - questions: DocList[TextDoc] - answers: DocList[TextDoc] - docs: DocList[TextDoc] - ground_truths: DocList[TextDoc] + questions: List[TextDoc] + answers: List[TextDoc] + docs: List[TextDoc] + ground_truths: List[TextDoc] class RAGASScores(BaseDoc): @@ -583,7 +583,7 @@ class ImagePath(BaseDoc): class ImagesPath(BaseDoc): - images_path: DocList[ImagePath] + images_path: List[ImagePath] class VideoPath(BaseDoc): diff --git a/comps/cores/storages/__init__.py b/comps/cores/storages/__init__.py index 8ca0972089..e6c0bb1727 100644 --- a/comps/cores/storages/__init__.py +++ b/comps/cores/storages/__init__.py @@ -4,7 +4,7 @@ from .arangodb import ArangoDBStore # from .redisdb import RedisDBStore -# from .mongodb import MongoDBStore +from .mongodb import MongoDBStore def opea_store(name: str, *args, **kwargs): @@ -12,7 +12,7 @@ def opea_store(name: str, *args, **kwargs): return ArangoDBStore(name, *args, **kwargs) # elif name == "redis": # return RedisDBStore(*args, **kwargs) - # elif name == "mongodb": - # 
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from typing import Any

import bson.errors as BsonError
import motor.motor_asyncio as motor
from bson.objectid import ObjectId

from ..common.storage import OpeaStore
from ..mega.logger import CustomLogger

logger = CustomLogger("MongoDBStore")


class MongoDBStore(OpeaStore):
    """MongoDB-backed :class:`OpeaStore` built on the async ``motor`` driver.

    Connection parameters are read from ``config`` (``MONGO_HOST``,
    ``MONGO_PORT``, ``DB_NAME``, ``COLLECTION_NAME``).  Most query methods
    additionally scope documents to ``config["user"]``.
    """

    def __init__(self, name: str, description: str = "", config: dict = {}):
        # NOTE(review): mutable default `config={}` kept for interface
        # compatibility; it is only read here, never mutated.
        super().__init__(name, description, config)
        # User used to scope reads/updates/deletes to a single owner.
        self.user = config.get("user", None)

    def _initialize_db(self) -> None:
        """Initializes the MongoDB database connection and collection.

        Reads host/port/db/collection from ``self.config`` with localhost
        defaults.  NOTE(review): nothing in this module calls this method;
        presumably the OpeaStore base class invokes it during construction
        (the unit tests patch the motor client before building the store) —
        TODO confirm against `..common.storage`.
        """
        MONGO_HOST = self.config.get("MONGO_HOST", "localhost")
        MONGO_PORT = self.config.get("MONGO_PORT", 27017)
        DB_NAME = self.config.get("DB_NAME", "OPEA")
        COLLECTION_NAME = self.config.get("COLLECTION_NAME", "default")
        conn_url = f"mongodb://{MONGO_HOST}:{MONGO_PORT}/"

        try:
            client = motor.AsyncIOMotorClient(conn_url)
            self.db = client[DB_NAME]

        except Exception as e:
            logger.error(e)
            # Preserve the original cause instead of raising a bare Exception().
            raise Exception(e) from e

        self.collection = self.db[COLLECTION_NAME]

    def health_check(self) -> bool:
        """Performs a health check on the MongoDB connection.

        Returns:
            bool: True if the connection is healthy, False otherwise.
        """
        try:
            # count_documents requires a filter argument; {} matches all docs.
            # NOTE(review): with motor this call returns a coroutine that is
            # never awaited, so a down server is not actually detected from
            # this sync method — consider an async `await db.command("ping")`.
            self.collection.count_documents({})
            logger.info("MongoDB Health check succeed!")
            return True
        except Exception as e:
            logger.error(f"MongoDB Health check failed: {e}")
            return False

    async def asave_document(self, doc: dict, **kwargs) -> bool | dict:
        """Stores a new document into the MongoDB collection.

        Args:
            doc: An object exposing ``model_dump`` (pydantic-style); the
                ``doc_id`` field is excluded before insertion.
            **kwargs: Additional arguments for saving the document (unused).

        Returns:
            str: The string form of the inserted document's ``_id``.

        Raises:
            Exception: If the insert fails.
        """
        try:
            inserted_data = await self.collection.insert_one(
                doc.model_dump(by_alias=True, mode="json", exclude={"doc_id"})
            )
            doc_id = str(inserted_data.inserted_id)
            logger.info(f"Inserted document: {doc_id}")
            return doc_id

        except Exception as e:
            logger.error(f"Fail to save document: {e}")
            raise Exception(e)

    async def asave_documents(self, docs: list[dict], **kwargs) -> bool | list:
        """Save multiple documents to the store.

        Args:
            docs (list[dict]): A list of document data to save.
            **kwargs: Additional arguments for saving the documents (unused).

        Returns:
            list[str]: The string ids of the inserted documents, in order.

        Raises:
            Exception: If the bulk insert fails.
        """
        try:
            inserted_data = await self.collection.insert_many(
                doc.model_dump(by_alias=True, mode="json", exclude={"doc_id"}) for doc in docs
            )
            # Return a list of id strings, not str(list), per the contract above.
            doc_ids = [str(_id) for _id in inserted_data.inserted_ids]
            logger.info(f"Inserted documents: {doc_ids}")
            return doc_ids

        except Exception as e:
            logger.error(f"Fail to save document: {e}")
            raise Exception(e)

    async def aupdate_document(self, doc: dict, **kwargs) -> bool | dict:
        """Update a single document in the store.

        Args:
            doc (dict): Must contain ``doc_id`` (hex ObjectId string) and
                ``data`` (an object exposing ``model_dump``); may contain
                ``first_query``.
            **kwargs: Additional arguments for updating the document (unused).

        Returns:
            bool: True when exactly one document was modified.

        Raises:
            KeyError: If ``doc_id`` is not a valid ObjectId.
            Exception: If no document matched / nothing was modified.
        """
        try:
            doc_id = doc.get("doc_id", None)
            _id = ObjectId(doc_id)
            first_query = doc.get("first_query", None)
            # NOTE(review): doc.get("data") may be None, which would raise
            # AttributeError on model_dump below — TODO confirm callers
            # always supply "data".
            data = doc.get("data", None)
            if first_query:
                data = {"data": data.model_dump(by_alias=True, mode="json"), "first_query": first_query}
            else:
                data = {"data": data.model_dump(by_alias=True, mode="json")}

            # NOTE(review): this filters on "data.user" while the read/delete
            # methods filter on "chat_data.user" and aget_documents_by_user on
            # "user" — at most one of these can match the stored schema; verify.
            updated_result = await self.collection.update_one(
                {"_id": _id, "data.user": self.user},
                {"$set": data},
            )

            if updated_result.modified_count == 1:
                logger.info(f"Updated document: {doc_id}")
                return True
            else:
                raise Exception("Not able to update the data.")

        except BsonError.InvalidId as e:
            logger.error(e)
            raise KeyError(e)

        except Exception as e:
            logger.error(e)
            raise Exception(e)

    async def aupdate_documents(self, docs: list[dict], **kwargs) -> bool | dict:
        """Update multiple documents in the store.

        Args:
            docs (list[dict]): The list of documents to update.
            **kwargs: Additional arguments for updating the documents.

        Returns:
            bool: True when every update succeeded.

        Raises:
            KeyError | Exception: Propagated from aupdate_document.
        """
        for doc in docs:
            # aupdate_document returns True or raises; check explicitly rather
            # than with `assert`, which is stripped under `python -O`.
            if not await self.aupdate_document(doc, **kwargs):
                return False
        return True

    async def aget_document_by_id(self, id: str, **kwargs) -> dict | None:
        """Asynchronously retrieve a single document by its unique identifier.

        Args:
            id (str): The unique identifier for the document.
            **kwargs: Additional arguments for retrieving the document (unused).

        Returns:
            dict | None: The document's "data" payload if found, None otherwise.

        Raises:
            KeyError: If ``id`` is not a valid ObjectId.
            Exception: If there is an error while retrieving data.
        """
        try:
            _id = ObjectId(id)
            # NOTE(review): filters on "chat_data.user" — see the schema-key
            # inconsistency flagged in aupdate_document.
            response: dict | None = await self.collection.find_one({"_id": _id, "chat_data.user": self.user})
            if response:
                del response["_id"]
                logger.info(f"Retrieved document: {id}")
                return response["data"]
            return None

        except BsonError.InvalidId as e:
            logger.info(e)
            raise KeyError(e)

        except Exception as e:
            logger.info(e)
            raise Exception(e)

    async def aget_documents_by_ids(self, ids: list[str], **kwargs) -> list[dict]:
        """Asynchronously retrieve multiple documents by their identifiers.

        Args:
            ids (list[str]): The unique identifiers of the documents.
            **kwargs: Additional arguments for retrieving the documents (unused).

        Returns:
            list[dict]: The "data" payloads of the documents that were found
            (missing ids are silently skipped).

        Raises:
            KeyError: If any id is not a valid ObjectId.
            Exception: If there is an error while retrieving data.
        """
        try:
            responses = []
            for id in ids:
                _id = ObjectId(id)
                response: dict | None = await self.collection.find_one({"_id": _id, "chat_data.user": self.user})
                if response:
                    del response["_id"]
                    responses.append(response["data"])
            logger.info(f"Retrieved documents: {responses}")
            return responses

        except BsonError.InvalidId as e:
            logger.info(e)
            raise KeyError(e)

        except Exception as e:
            logger.info(e)
            raise Exception(e)

    async def aget_documents_by_user(self, user: str = None, **kwargs) -> list[dict] | None:
        """Asynchronously retrieve all documents for a specific user.

        Args:
            user (str): The unique identifier for the user; defaults to
                the store's configured user when omitted.
            **kwargs: Additional arguments for retrieving the documents (unused).

        Returns:
            list[dict] | None: The user's documents (with "data" projected out
            and "_id" renamed to "doc_id"), or an empty list.

        Raises:
            Exception: If there is an error while retrieving data.
        """
        try:
            responses = []
            if user is None:
                user = self.user
            # motor's find() returns an AsyncIOMotorCursor synchronously; it
            # must NOT be awaited (only iteration/to_list are async).
            cursor = self.collection.find({"user": user}, {"data": 0})

            async for document in cursor:
                document["doc_id"] = str(document["_id"])
                del document["_id"]
                responses.append(document)
            logger.info(f"Retrieved documents: {responses}")
            return responses

        except Exception as e:
            logger.info(e)
            raise Exception(e)

    async def adelete_document(self, id: str, **kwargs) -> bool:
        """Asynchronously delete a single document from the store.

        Args:
            id (str): The unique identifier for the document.
            **kwargs: Additional arguments for deleting the document (unused).

        Returns:
            bool: True if doc is successfully deleted, False otherwise.

        Raises:
            KeyError: If the provided id is invalid.
            Exception: If any error occurs during the delete process.
        """
        try:
            _id = ObjectId(id)
            result = await self.collection.delete_one({"_id": _id, "chat_data.user": self.user})

            delete_count = result.deleted_count
            logger.info(f"Deleted {delete_count} documents!")

            return delete_count == 1

        except BsonError.InvalidId as e:
            logger.error(e)
            raise KeyError(e)

        except Exception as e:
            logger.error(e)
            raise Exception(e)

    async def adelete_documents(self, ids: list[str], **kwargs) -> bool:
        """Asynchronously delete multiple documents from the store.

        Args:
            ids (list[str]): A list of unique identifiers for the documents.
            **kwargs: Additional arguments for deleting the documents (unused).

        Returns:
            bool: True if every requested doc was deleted, False otherwise.

        Raises:
            KeyError: If any provided id is invalid.
            Exception: If any error occurs during the delete process.
        """
        try:
            # Stored _id values are ObjectIds; matching raw strings with $in
            # would never delete anything, so convert first (this is also what
            # makes BsonError.InvalidId reachable below).
            object_ids = [ObjectId(id) for id in ids]
            result = await self.collection.delete_many({"_id": {"$in": object_ids}, "chat_data.user": self.user})

            delete_count = result.deleted_count
            logger.info(f"Deleted {delete_count} documents!")

            # Success means every requested document was deleted, not just one.
            return delete_count == len(ids)

        except BsonError.InvalidId as e:
            logger.error(e)
            raise KeyError(e)

        except Exception as e:
            logger.error(e)
            raise Exception(e)

    async def asearch(self, key: str, value: Any = None, search_type: str = "exact", **kwargs) -> list[dict]:
        """Asynchronously search for documents via MongoDB text search.

        Args:
            key (str): The text to search for.
            value (Any): Unused; kept for interface compatibility.
            search_type (str): Unused; kept for interface compatibility.
            **kwargs: Additional arguments for the search (unused).

        Returns:
            list[dict]: Up to 5 matches as {"id", "data", "user", "score"},
            best score first.

        Raises:
            Exception: If the search fails.
        """
        try:
            # Ensure a wildcard text index exists (idempotent). create_index
            # is a coroutine in motor, so it must be awaited to actually run.
            await self.collection.create_index([("$**", "text")])
            # find() returns a cursor synchronously in motor — do not await it.
            cursor = self.collection.find({"$text": {"$search": key}}, {"score": {"$meta": "textScore"}})
            sorted_cursor = cursor.sort([("score", {"$meta": "textScore"})])

            # Return a list of top 5 most relevant documents.
            relevant_data = await sorted_cursor.to_list(length=5)

            # Serialize data and return.
            serialized_data = [
                {"id": str(doc["_id"]), "data": doc["data"], "user": doc["user"], "score": doc["score"]}
                for doc in relevant_data
            ]

            logger.info(f"Search results: {serialized_data}")
            return serialized_data

        except Exception as e:
            logger.info(e)
            raise Exception(e)
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from bson.objectid import ObjectId

from comps.cores.storages import opea_store


class DummyDoc:
    """Minimal stand-in for a pydantic model: only model_dump() is exercised."""

    def model_dump(self, **kwargs):
        return {"text": "mock data"}


class MockAsyncCursor:
    """Async-iterable cursor mock for `async for document in cursor` loops."""

    def __init__(self, docs):
        self.docs = docs
        self.index = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index >= len(self.docs):
            raise StopAsyncIteration
        doc = self.docs[self.index]
        self.index += 1
        return doc


class MockSortCursor:
    """Cursor mock supporting the find(...).sort(...).to_list(n) chain."""

    def __init__(self, docs):
        self.docs = docs

    def sort(self, *args, **kwargs):
        return self

    async def to_list(self, length):
        return self.docs


class TestMongoDBStore(unittest.IsolatedAsyncioTestCase):
    """Unit tests for MongoDBStore with the motor client fully mocked out."""

    def setUp(self):
        self.config = {
            "MONGO_HOST": "localhost",
            "MONGO_PORT": 27017,
            "DB_NAME": "test_db",
            "COLLECTION_NAME": "test_collection",
            "user": "test_user",
        }
        # Patch the client where the store looks it up, so construction never
        # touches a real MongoDB server.
        patcher = patch("comps.cores.storages.mongodb.motor.AsyncIOMotorClient")
        self.addCleanup(patcher.stop)
        self.MockClient = patcher.start()

        mock_client = MagicMock()
        mock_db = MagicMock()
        # Collection methods are coroutines in motor, hence AsyncMock — except
        # the ones motor exposes synchronously, re-mocked below.
        mock_collection = AsyncMock()
        # count_documents is invoked synchronously by health_check().
        mock_collection.count_documents = MagicMock(return_value=1)
        # motor's find() returns a cursor synchronously (it is NOT awaited),
        # so it must be a plain MagicMock whose return_value is the cursor.
        mock_collection.find = MagicMock()
        mock_client.__getitem__.return_value = mock_db
        mock_db.__getitem__.return_value = mock_collection
        self.MockClient.return_value = mock_client

        self.store = opea_store(name="mongodb", description="test", config=self.config)
        self.store.collection = mock_collection

    def test_health_check_success(self):
        self.store.collection.count_documents.return_value = 1
        result = self.store.health_check()
        self.assertTrue(result)

    def test_health_check_failure(self):
        self.store.collection.count_documents.side_effect = Exception("failed")
        result = self.store.health_check()
        self.assertFalse(result)

    async def test_asave_document(self):
        mock_id = ObjectId("60dbf3a1fc13ae1a3b000000")
        self.store.collection.insert_one.return_value.inserted_id = mock_id
        result = await self.store.asave_document(DummyDoc())
        self.assertEqual(result, str(mock_id))

    async def test_asave_documents(self):
        mock_id = ObjectId("60dbf3a1fc13ae1a3b000000")
        self.store.collection.insert_many.return_value.inserted_ids = [mock_id]
        docs = [DummyDoc()]
        result = await self.store.asave_documents(docs)
        # The store returns a list of id strings, one per inserted document.
        self.assertEqual(result, [str(mock_id)])

    async def test_aupdate_document(self):
        self.store.collection.update_one.return_value.modified_count = 1
        doc = {"doc_id": str(ObjectId()), "data": DummyDoc()}
        result = await self.store.aupdate_document(doc)
        self.assertTrue(result)

    async def test_aupdate_documents(self):
        self.store.collection.update_one.return_value.modified_count = 1
        docs = [{"doc_id": str(ObjectId()), "data": DummyDoc()}]
        result = await self.store.aupdate_documents(docs)
        self.assertTrue(result)

    async def test_aget_document_by_id(self):
        self.store.collection.find_one.return_value = {"_id": ObjectId(), "data": {"text": "mock"}}
        result = await self.store.aget_document_by_id(str(ObjectId()))
        self.assertEqual(result, {"text": "mock"})

    async def test_aget_documents_by_ids(self):
        mock_id = ObjectId("60dbf3a1fc13ae1a3b000000")
        self.store.collection.find_one.return_value = {"_id": mock_id, "data": {"text": "mock"}}
        result = await self.store.aget_documents_by_ids([str(mock_id)])
        self.assertEqual(result, [{"text": "mock"}])

    async def test_aget_documents_by_user(self):
        mock_docs = [{"_id": ObjectId("60dbf3a1fc13ae1a3b000000"), "user": "test_user"}]
        # find() is sync in motor: return the async-iterable cursor directly.
        self.store.collection.find.return_value = MockAsyncCursor(mock_docs)

        result = await self.store.aget_documents_by_user("test_user")

        self.assertIsInstance(result, list)
        self.assertEqual(result[0]["doc_id"], "60dbf3a1fc13ae1a3b000000")

    async def test_adelete_document(self):
        self.store.collection.delete_one.return_value.deleted_count = 1
        result = await self.store.adelete_document(str(ObjectId()))
        self.assertTrue(result)

    async def test_adelete_documents(self):
        self.store.collection.delete_many.return_value.deleted_count = 1
        result = await self.store.adelete_documents([str(ObjectId())])
        self.assertTrue(result)

    async def test_asearch(self):
        # create_index is a coroutine in motor and is awaited by asearch.
        self.store.collection.create_index = AsyncMock()

        mock_docs = [
            {"_id": ObjectId("60dbf3a1fc13ae1a3b000000"), "data": "mock data", "user": "test_user", "score": 0.9}
        ]
        mock_cursor = MockSortCursor(mock_docs)

        self.store.collection.find.return_value = mock_cursor

        result = await self.store.asearch("prompt", "value")

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["user"], "test_user")
        self.assertIn("score", result[0])


if __name__ == "__main__":
    unittest.main()