From b50e1b5bd382d07cb1c7e7c77563a3422b61962e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:23:23 +0000 Subject: [PATCH 1/2] Initial plan From 9b750b729cb6abafe6f20e32249f14ade190f732 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:29:08 +0000 Subject: [PATCH 2/2] Fix code formatting issues: apply Black and isort, rename test_main.txt to test_main.py Co-authored-by: mickeyjoes <10925516+mickeyjoes@users.noreply.github.com> --- agent_manager.py | 142 +++++++----- cache_manager.py | 2 +- classify_prompts.py | 90 +++++--- cohere_rerank.py | 20 +- doc_manager.py | 130 +++++++---- document_summarizer.py | 51 +++-- functions_manager.py | 212 ++++++++++++------ functions_manager_test.py | 84 +++++-- generative_conversation_summarized_memory.py | 128 +++++++---- generative_memory.py | 193 +++++++++++----- main.py | 196 ++++++++++------ memory_summarizer.py | 80 ++++--- preferences_resolver.py | 166 ++++++++------ preferences_updater.py | 76 ++++--- qdrant_retriever.py | 87 ++++--- queryplan_manager.py | 43 ++-- rate_limiter.py | 6 +- reader_writer_lock.py | 1 + tests/agent_manager_test.py | 82 +++++-- tests/conftest.py | 37 +-- tests/generative_memory_test.py | 59 +++-- tests/qdrant_retriever_test.py | 36 ++- tests/test_document_summarizer.py | 39 +++- tests/test_functions_manager.py | 66 +++--- ...nerative_conversation_summarized_memory.py | 60 +++-- tests/{test_main.txt => test_main.py} | 37 ++- tests/test_memory_summarizer.py | 58 +++-- tests/web_manager_test.py | 31 ++- web_manager.py | 121 ++++++---- 29 files changed, 1539 insertions(+), 794 deletions(-) rename tests/{test_main.txt => test_main.py} (79%) diff --git a/agent_manager.py b/agent_manager.py index db01ba0..5003edf 100644 --- a/agent_manager.py +++ b/agent_manager.py @@ -1,29 +1,29 @@ -import time +import asyncio import logging import os -import asyncio +import time import traceback -import cachetools.func +from datetime import datetime, timedelta +from typing import Any, Dict +import cachetools.func from dotenv import load_dotenv -from langchain_openai import OpenAIEmbeddings -from qdrant_retriever import QDrantVectorStoreRetriever -from cohere_rerank import CohereRerank -from generative_memory import GenerativeAgentMemory from langchain.retrievers import ContextualCompressionRetriever +from langchain.schema import Document +from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_qdrant import Qdrant +from pydantic import BaseModel from qdrant_client import QdrantClient from qdrant_client.http import models as rest from qdrant_client.http.models import PayloadSchemaType -from memory_summarizer import MemorySummarizer -from pydantic import BaseModel + +from cohere_rerank import CohereRerank from document_summarizer import FlexibleDocumentSummarizer -from langchain_openai import ChatOpenAI -from langchain.schema import Document -from datetime import datetime, timedelta -from typing import Any, Dict +from generative_memory import GenerativeAgentMemory +from memory_summarizer import MemorySummarizer from preferences_resolver import PreferencesResolver from preferences_updater import PreferencesUpdater +from qdrant_retriever import QDrantVectorStoreRetriever class MemoryInput(BaseModel): @@ -37,7 +37,12 @@ def __str__(self): return str(self.summary) + self.user_id + self.query + self.conversation_id def __eq__(self, other): - return self.user_id == other.user_id and self.query == other.query and self.conversation_id == other.conversation_id and self.summary == other.summary + return ( + self.user_id == other.user_id + and self.query == other.query + and self.conversation_id == other.conversation_id + and self.summary == other.summary + ) def __hash__(self): return hash(str(self)) @@ -64,12 +69,12 @@ def __init__(self, rate_limiter, rate_limiter_sync): self.QDRANT_URL = os.getenv("QDRANT_URL") self.rate_limiter = rate_limiter self.rate_limiter_sync = rate_limiter_sync - self.client = QdrantClient( - url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) + self.client = QdrantClient(url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) self.verbose = True self.preferences_resolver = PreferencesResolver() self.preferences_updater = PreferencesUpdater( - self.preferences_resolver, self.verbose) + self.preferences_resolver, self.verbose + ) async def push_memory(self, memory_output: MemoryOutput): """Add new memory to the current index for a specific user.""" @@ -77,19 +82,22 @@ async def push_memory(self, memory_output: MemoryOutput): memory = self.load(memory_output.api_key, memory_output.user_id) try: # update preferences on every exchange but only save summarized memory of a "finished" exchange, reflect on an important summarized memory and then decay memories - asyncio.create_task(memory.pause_to_reflect( - memory_output.dict(), self.preferences_resolver)) + asyncio.create_task( + memory.pause_to_reflect(memory_output.dict(), self.preferences_resolver) + ) # asyncio.create_task(self.preferences_updater.update_preferences(ChatOpenAI(openai_api_key=memory_output.api_key, # model="gpt-4.1-mini", temperature=0), memory_output.query, memory_output.llm_response, memory_output.user_id)) # decay memory by summarizing it continiously until max_summarizations then prune asyncio.create_task(memory.decay()) except Exception as e: logging.warning( - f"AgentManager: push_memory exception {e}\n{traceback.format_exc()}") + f"AgentManager: push_memory exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"AgentManager: push_memory operation took {end - start} seconds") + f"AgentManager: push_memory operation took {end - start} seconds" + ) return end - start def create_new_memory_retriever(self, api_key: str, user_id: str): @@ -105,32 +113,55 @@ def create_new_memory_retriever(self, api_key: str, user_id: str): ), ) self.client.create_payload_index( - collection_name, "metadata.extra_index", field_schema=PayloadSchemaType.KEYWORD) + collection_name, + "metadata.extra_index", + field_schema=PayloadSchemaType.KEYWORD, + ) except: print("AgentManager: loaded from cloud...") finally: logging.info( - f"AgentManager: Creating memory store with collection {collection_name}") - vectorstore = Qdrant(self.client, collection_name, OpenAIEmbeddings( - model="text-embedding-3-small", openai_api_key=api_key)) + f"AgentManager: Creating memory store with collection {collection_name}" + ) + vectorstore = Qdrant( + self.client, + collection_name, + OpenAIEmbeddings( + model="text-embedding-3-small", openai_api_key=api_key + ), + ) compressor = CohereRerank() compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever( - rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=collection_name, client=self.client, vectorstore=vectorstore, - ) + base_compressor=compressor, + base_retriever=QDrantVectorStoreRetriever( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + collection_name=collection_name, + client=self.client, + vectorstore=vectorstore, + ), ) return compression_retriever def create_memory(self, api_key: str, user_id: str): return GenerativeAgentMemory( rate_limiter=self.rate_limiter, - llm=ChatOpenAI(openai_api_key=api_key, - model="gpt-4.1-mini", max_tokens=1024), - memory_retriever=self.create_new_memory_retriever( - api_key, user_id), - memory_summarizer=MemorySummarizer(rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, flexible_document_summarizer=FlexibleDocumentSummarizer( - ChatOpenAI(openai_api_key=api_key, model="gpt-4.1-mini", temperature=0), verbose=self.verbose), agent_manager=self), - verbose=self.verbose + llm=ChatOpenAI( + openai_api_key=api_key, model="gpt-4.1-mini", max_tokens=1024 + ), + memory_retriever=self.create_new_memory_retriever(api_key, user_id), + memory_summarizer=MemorySummarizer( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + flexible_document_summarizer=FlexibleDocumentSummarizer( + ChatOpenAI( + openai_api_key=api_key, model="gpt-4.1-mini", temperature=0 + ), + verbose=self.verbose, + ), + agent_manager=self, + ), + verbose=self.verbose, ) @cachetools.func.ttl_cache(maxsize=16384, ttl=36000) @@ -139,8 +170,7 @@ def load(self, api_key: str, user_id: str) -> GenerativeAgentMemory: start = time.time() memory = self.create_memory(api_key, user_id) end = time.time() - logging.info( - f"AgentManager: Load operation took {end - start} seconds") + logging.info(f"AgentManager: Load operation took {end - start} seconds") return memory def _document_from_scored_point( @@ -165,7 +195,8 @@ def get_key_value_document(self, collection_name, key, value) -> Document: ] ) record, _ = self.client.scroll( - collection_name=collection_name, scroll_filter=filter, limit=1) + collection_name=collection_name, scroll_filter=filter, limit=1 + ) if record is not None and len(record) > 0: return self._document_from_scored_point( record[0], "page_content", "metadata" @@ -175,7 +206,10 @@ def get_key_value_document(self, collection_name, key, value) -> Document: def load_summary(self, memory_input: MemoryInput) -> Dict[str, str]: doc = self.get_key_value_document( - f"{memory_input.user_id}_summaries", "metadata.extra_index", memory_input.conversation_id) + f"{memory_input.user_id}_summaries", + "metadata.extra_index", + memory_input.conversation_id, + ) ret = "" if doc: ret = self.format_summary_simple(doc) @@ -187,8 +221,7 @@ def load_summary(self, memory_input: MemoryInput) -> Dict[str, str]: async def load_memory(self, memory_input: MemoryInput): memory = self.load(memory_input.api_key, memory_input.user_id) return await memory.load_memory_variables( - queries=[memory_input.query], - conversation_id=memory_input.conversation_id + queries=[memory_input.query], conversation_id=memory_input.conversation_id ) def _time_ago(self, timestamp: float) -> str: @@ -209,8 +242,7 @@ def format_summary_simple(self, conversation_summary: Document) -> str: created_ago = self._time_ago(created_at) # Extracting the extra_index (conversation_id) - conversation_id = conversation_summary.metadata.get( - "extra_index", "N/A") + conversation_id = conversation_summary.metadata.get("extra_index", "N/A") return f"(created: {created_ago}, conversation_id: {conversation_id}) {conversation_summary.page_content}" async def pull_memory(self, memory_input: MemoryInput): @@ -222,7 +254,8 @@ async def pull_memory(self, memory_input: MemoryInput): if memory_input.summary: if len(memory_input.conversation_id) <= 0: logging.warning( - f"AgentManager: pull_memory asked for summary but no conversation_id provided!") + f"AgentManager: pull_memory asked for summary but no conversation_id provided!" + ) end = time.time() return {}, end - start response = self.load_summary(memory_input) @@ -230,11 +263,13 @@ async def pull_memory(self, memory_input: MemoryInput): response = await self.load_memory(memory_input) except Exception as e: logging.warning( - f"AgentManager: pull_memory exception {e}\n{traceback.format_exc()}") + f"AgentManager: pull_memory exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"AgentManager: pull_memory operation took {end - start} seconds") + f"AgentManager: pull_memory operation took {end - start} seconds" + ) return response, end - start def clear_collection_with_extra_index(self, collection_name, extra_index) -> None: @@ -247,22 +282,25 @@ def clear_collection_with_extra_index(self, collection_name, extra_index) -> Non ) ] ) - self.client.delete(collection_name=collection_name, - points_selector=filter) + self.client.delete(collection_name=collection_name, points_selector=filter) def clear_conversation(self, clear_memory: ClearMemory): """Delete all memories for a specific conversation with a user.""" start = time.time() try: self.clear_collection_with_extra_index( - clear_memory.user_id, clear_memory.conversation_id) + clear_memory.user_id, clear_memory.conversation_id + ) self.clear_collection_with_extra_index( - f"{clear_memory.user_id}_summaries", clear_memory.conversation_id) + f"{clear_memory.user_id}_summaries", clear_memory.conversation_id + ) except Exception as e: logging.warning( - f"AgentManager: clear_conversation exception {e}\n{traceback.format_exc()}") + f"AgentManager: clear_conversation exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"AgentManager: clear_conversation operation took {end - start} seconds") + f"AgentManager: clear_conversation operation took {end - start} seconds" + ) return "success", end - start diff --git a/cache_manager.py b/cache_manager.py index bcc7b59..3c96d5b 100644 --- a/cache_manager.py +++ b/cache_manager.py @@ -1,6 +1,6 @@ from pydantic import BaseModel + class CacheClearInput(BaseModel): cache_types: list console_key: str - diff --git a/classify_prompts.py b/classify_prompts.py index b9f0b3c..d03f431 100644 --- a/classify_prompts.py +++ b/classify_prompts.py @@ -1,42 +1,74 @@ class ClassifyPrompts: def __init__(self): - CoTClassifyPrompt = {"Rationale": "This type is for complex queries that require a step-by-step logical reasoning process. It's ideal for solving puzzles, mathematical problems, or any query that demands a rigorous logical approach.", - "Classification Guidelines": "If the query asks for a solution to a problem that involves multiple steps, logical deductions, or the need to evaluate different hypotheses, it falls under this category."} - self.CoTClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in CoTClassifyPrompt.items()) + CoTClassifyPrompt = { + "Rationale": "This type is for complex queries that require a step-by-step logical reasoning process. It's ideal for solving puzzles, mathematical problems, or any query that demands a rigorous logical approach.", + "Classification Guidelines": "If the query asks for a solution to a problem that involves multiple steps, logical deductions, or the need to evaluate different hypotheses, it falls under this category.", + } + self.CoTClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in CoTClassifyPrompt.items() + ) CoT = "Role: Logical Reasoning (CoT). Steps: 1. Interpretation & Historical Context: Understand the user's query and reference previous interactions. 2. Research: Conduct a web search and access memory for relevant information. 3. Logical Analysis & Problem Breakdown: Apply chain-of-thought prompting and evaluate potential hypotheses. 4. Solution Drafting & Validation: Develop a solution and ensure its logical coherence. 5. Lambda Execution: If necessary, create and run lambda code. 6. Comprehensive Solution & Feedback: Present the final solution and seek user feedback for validation." - - CodeClassifyPrompt = {"Rationale": "This type is for queries related to coding, algorithms, or technical issues. While it may involve logical reasoning, the focus is more on the technical aspects.", - "Classification Guidelines": "If the query asks for code, discusses algorithms, or involves technical jargon, it falls under this category."} - self.CodeClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in CodeClassifyPrompt.items()) + + CodeClassifyPrompt = { + "Rationale": "This type is for queries related to coding, algorithms, or technical issues. While it may involve logical reasoning, the focus is more on the technical aspects.", + "Classification Guidelines": "If the query asks for code, discusses algorithms, or involves technical jargon, it falls under this category.", + } + self.CodeClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in CodeClassifyPrompt.items() + ) Code = "Role: Coding. Steps: 1. Technical Query Interpretation: Grasp the user's technical request and identify its elements. 2. Task Breakdown & Code Prompting: Break the problem into subtasks using chain-of-code prompting. 3. Solution Drafting: Offer code snippets, explanations, and justifications. 4. Lambda Execution & Code Testing: With user agreement, run and test the code. 5. Technical Solution Presentation & Feedback: Deliver the solution tailored to the user's needs and seek improvement suggestions." - - QAClassifyPrompt = {"Rationale": "This type is for straightforward questions that require a simple answer without the need for extensive reasoning or elaboration.", - "Classification Guidelines": "If the query asks for a fact, a definition, or a simple explanation, it falls under this category."} - self.QAClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in QAClassifyPrompt.items()) + + QAClassifyPrompt = { + "Rationale": "This type is for straightforward questions that require a simple answer without the need for extensive reasoning or elaboration.", + "Classification Guidelines": "If the query asks for a fact, a definition, or a simple explanation, it falls under this category.", + } + self.QAClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in QAClassifyPrompt.items() + ) QA = "Role: Question/Answer (QA). Steps: 1. Query Interpretation: Understand the user's question and its context. 2. Information Retrieval: Use web searches, memory, and external functions to gather pertinent data. 3. Draft & Refine Answers: Formulate initial answers and refine them based on the gathered info. 4. Answer Presentation: Respond with the most accurate answer, incorporating onboarding details if relevant." - - ConversationClassifyPrompt = {"Rationale": "This is the default type that is for queries that are more conversational in nature and do not require a specific format or structure.", - "Classification Guidelines": "If the query is open-ended, opinion-based, or conversational, it falls under this category."} - self.ConversationClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in ConversationClassifyPrompt.items()) - - EmotionClassifyPrompt = {"Rationale": "This type is for queries that seek emotional support, advice, or guidance on personal matters.", - "Classification Guidelines": "If the query asks for advice, emotional support, or guidance on personal issues, it falls under this category."} - self.EmotionClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in EmotionClassifyPrompt.items()) + + ConversationClassifyPrompt = { + "Rationale": "This is the default type that is for queries that are more conversational in nature and do not require a specific format or structure.", + "Classification Guidelines": "If the query is open-ended, opinion-based, or conversational, it falls under this category.", + } + self.ConversationClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in ConversationClassifyPrompt.items() + ) + + EmotionClassifyPrompt = { + "Rationale": "This type is for queries that seek emotional support, advice, or guidance on personal matters.", + "Classification Guidelines": "If the query asks for advice, emotional support, or guidance on personal issues, it falls under this category.", + } + self.EmotionClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in EmotionClassifyPrompt.items() + ) Emotion = "Role: Empathy. Steps: 1. Emotional Query Interpretation: Determine the emotional context and nuance of the user's input. 2. Historical Context & User Traits Analysis: Reflect on past interactions and consider user traits for a tailored response. 3. Compassionate Response Drafting: Formulate a supportive and understanding response. 4. User Feedback & Mood Update: Check in with the user post-response and adjust mood settings accordingly." - CreativeClassifyPrompt = {"Rationale": "This type is for queries that ask for creative output, like writing a poem, story, or generating art.", - "Classification Guidelines": "If the query asks for a creative piece of content, it falls under this category."} - self.CreativeClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in CreativeClassifyPrompt.items()) + CreativeClassifyPrompt = { + "Rationale": "This type is for queries that ask for creative output, like writing a poem, story, or generating art.", + "Classification Guidelines": "If the query asks for a creative piece of content, it falls under this category.", + } + self.CreativeClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in CreativeClassifyPrompt.items() + ) Creative = "Role: Creative. Steps: 1. Creative Query Interpretation: Understand the user's creative request. 2. Research for Inspiration & Tools: Search for relevant references and fetch external tools. 3. Content Generation & User Feedback: Produce initial content and refine based on user feedback. 4. Final Creative Presentation: Deliver the polished creative piece tailored to user preferences." - EducationalClassifyPrompt = {"Rationale": "This type is for queries that seek educational information or a tutorial on how to do something.", - "Classification Guidelines": "If the query asks for a step-by-step guide, tutorial, or educational explanation, it falls under this category."} - self.EducationalClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in EducationalClassifyPrompt.items()) + EducationalClassifyPrompt = { + "Rationale": "This type is for queries that seek educational information or a tutorial on how to do something.", + "Classification Guidelines": "If the query asks for a step-by-step guide, tutorial, or educational explanation, it falls under this category.", + } + self.EducationalClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in EducationalClassifyPrompt.items() + ) Education = "Role: Education. Steps: 1. Educational Query Interpretation: Ascertain the user's learning goal or question. 2. Information Retrieval & Resource Collection: Gather data and find educational tools to support the user's journey. 3. Learning Experience Generation: Create tailored learning experiences, consider learning style, and track progress possibly through task management if large enough tasks. 4. Educational Journey Summary & Feedback: Summarize the learning process, present findings, and seek feedback." - FactualClassifyPrompt = {"Rationale": "This type is for queries that require a detailed, factual answer based on research or data.", - "Classification Guidelines": "If the query asks for detailed information that requires research or data-backed answers, it falls under this category."} - self.FactualClassifyPrompt = "\n".join(f"{key}: {value}" for key, value in FactualClassifyPrompt.items()) + FactualClassifyPrompt = { + "Rationale": "This type is for queries that require a detailed, factual answer based on research or data.", + "Classification Guidelines": "If the query asks for detailed information that requires research or data-backed answers, it falls under this category.", + } + self.FactualClassifyPrompt = "\n".join( + f"{key}: {value}" for key, value in FactualClassifyPrompt.items() + ) Factual = "Role: Research (Factual). Steps: 1. Factual Query Interpretation: Understand the factual or research-oriented request from the user. 2. Comprehensive Research: Conduct thorough searches, utilizing memory and external tools, while fetching relevant data. 3. Data & Source Collection: Collate facts, visuals, multimedia, and other relevant content. 4. Fact-Checked Response Compilation & Presentation: Offer a data-backed response, ensuring accuracy and source attribution." self.classify_prompts = { @@ -46,7 +78,7 @@ def __init__(self): "EmpathyGPT": Emotion, "CreativeGPT": Creative, "EduGPT": Education, - "ResearchGPT": Factual + "ResearchGPT": Factual, } def to_prompt_string(self) -> str: diff --git a/cohere_rerank.py b/cohere_rerank.py index 21f07fe..e0d9e8b 100644 --- a/cohere_rerank.py +++ b/cohere_rerank.py @@ -1,16 +1,15 @@ from __future__ import annotations +import logging from copy import deepcopy from typing import Any, Dict, List, Optional, Sequence, Union -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from pydantic import Field, field_validator - from langchain.callbacks.manager import Callbacks from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.utils import get_from_dict_or_env -import logging +from langchain_core._api.deprecation import deprecated +from langchain_core.documents import Document +from pydantic import Field, field_validator @deprecated( @@ -44,9 +43,13 @@ def model_post_init(self, __context: Any) -> None: "Please install it with `pip install cohere`." ) cohere_api_key = get_from_dict_or_env( - {"cohere_api_key": self.cohere_api_key}, "cohere_api_key", "COHERE_API_KEY" + {"cohere_api_key": self.cohere_api_key}, + "cohere_api_key", + "COHERE_API_KEY", + ) + self.client = cohere.AsyncClient( + cohere_api_key, client_name=self.user_agent ) - self.client = cohere.AsyncClient(cohere_api_key, client_name=self.user_agent) async def rerank( self, @@ -130,8 +133,7 @@ async def acompress_documents( # logging.info(f"acompress_documents: docs {documents} query {query}") for res in await self.rerank(documents, query): doc = documents[res["index"]] - doc_copy = Document( - doc.page_content, metadata=deepcopy(doc.metadata)) + doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata)) doc_copy.metadata["relevance_score"] = res["relevance_score"] compressed.append(doc_copy) # logging.info(f"acompress_documents: compressed {compressed} query {query}") diff --git a/doc_manager.py b/doc_manager.py index 4383586..40c72fe 100644 --- a/doc_manager.py +++ b/doc_manager.py @@ -1,26 +1,27 @@ -import time import datetime -import schedule +import logging import os import random -import logging +import time import traceback -import cachetools.func +from datetime import datetime +import cachetools.func +import schedule from dotenv import load_dotenv -from llama_index.core.langchain_helpers.text_splitter import SentenceSplitter -from qdrant_client import QdrantClient -from pydantic import BaseModel -from langchain_qdrant import Qdrant -from qdrant_retriever import QDrantVectorStoreRetriever -from langchain_openai import OpenAIEmbeddings from langchain.retrievers import ContextualCompressionRetriever -from cohere_rerank import CohereRerank from langchain.schema import Document -from datetime import datetime +from langchain_openai import OpenAIEmbeddings +from langchain_qdrant import Qdrant +from llama_index.core.langchain_helpers.text_splitter import SentenceSplitter +from pydantic import BaseModel +from qdrant_client import QdrantClient from qdrant_client.http import models as rest from qdrant_client.http.models import PayloadSchemaType +from cohere_rerank import CohereRerank +from qdrant_retriever import QDrantVectorStoreRetriever + class CacheDoc(BaseModel): source_url: str @@ -64,8 +65,7 @@ def __init__(self, rate_limiter, rate_limiter_sync): self.rate_limiter_sync = rate_limiter_sync self.QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") self.QDRANT_URL = os.getenv("QDRANT_URL") - self.client = QdrantClient( - url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) + self.client = QdrantClient(url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) self.collection_name = "doc" def create_new_web_retriever(self, api_key: str): @@ -80,19 +80,33 @@ def create_new_web_retriever(self, api_key: str): ), ) self.client.create_payload_index( - self.collection_name, "metadata.extra_index", field_schema=PayloadSchemaType.KEYWORD) + self.collection_name, + "metadata.extra_index", + field_schema=PayloadSchemaType.KEYWORD, + ) except: logging.info("DocManager: loaded from cloud...") finally: logging.info( - f"DocManager: Creating memory store with collection {self.collection_name}") - vectorstore = Qdrant(self.client, self.collection_name, OpenAIEmbeddings( - model="text-embedding-3-small", openai_api_key=api_key)) + f"DocManager: Creating memory store with collection {self.collection_name}" + ) + vectorstore = Qdrant( + self.client, + self.collection_name, + OpenAIEmbeddings( + model="text-embedding-3-small", openai_api_key=api_key + ), + ) compressor = CohereRerank() compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever( - rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=self.collection_name, client=self.client, vectorstore=vectorstore, - ) + base_compressor=compressor, + base_retriever=QDrantVectorStoreRetriever( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + collection_name=self.collection_name, + client=self.client, + vectorstore=vectorstore, + ), ) return compression_retriever @@ -101,15 +115,17 @@ def extract_text_and_source_url(self, retrieved_nodes): seen = set() for document in retrieved_nodes: text = document.page_content - source_url = document.metadata.get('source_url') + source_url = document.metadata.get("source_url") # Create a tuple of text and source_url to check for duplicates key = (text, source_url) if key not in seen: - result.append({'text': text, 'source_url': source_url}) + result.append({"text": text, "source_url": source_url}) seen.add(key) return result - async def get_retrieved_nodes(self, memory: ContextualCompressionRetriever, function_input: DocSearchInput): + async def get_retrieved_nodes( + self, memory: ContextualCompressionRetriever, function_input: DocSearchInput + ): filter = rest.Filter( must=[ rest.FieldCondition( @@ -133,12 +149,15 @@ def load(self, api_key: str): async def add_doc(self, function_input: DocAddInput): start = time.time() if len(function_input.source_url) <= 0 or len(function_input.html_doc) <= 0: - logging.warning( - "DocManager: Cannot add information because data missing") + logging.warning("DocManager: Cannot add information because data missing") end = time.time() return "fail", end - start memory = self.load(function_input.api_key) - srcExist, _ = self.does_source_exist(CacheDoc(source_url=function_input.source_url, category=function_input.category)) + srcExist, _ = self.does_source_exist( + CacheDoc( + source_url=function_input.source_url, category=function_input.category + ) + ) if srcExist: logging.warning("DocManager: source_url already exists") end = time.time() @@ -148,22 +167,36 @@ async def add_doc(self, function_input: DocAddInput): if len(function_input.html_doc) > 0: text_splitter = SentenceSplitter() chunks = text_splitter.split_text(text=function_input.html_doc) - documents.extend([Document(page_content=chunk, metadata={"id": random.randint( - 0, 2**32 - 1), "extra_index": function_input.category, "last_accessed_at": nowStamp, 'source_url': function_input.source_url}) for chunk in chunks]) + documents.extend( + [ + Document( + page_content=chunk, + metadata={ + "id": random.randint(0, 2**32 - 1), + "extra_index": function_input.category, + "last_accessed_at": nowStamp, + "source_url": function_input.source_url, + }, + ) + for chunk in chunks + ] + ) if len(documents) > 0: ids = [doc.metadata["id"] for doc in documents] - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, documents, ids=ids + ) end = time.time() logging.info( - f"DocManager: Loaded from documents operation took {end - start} seconds") + f"DocManager: Loaded from documents operation took {end - start} seconds" + ) return "success", end - start def delete_doc(self, function_input: DocDeleteInput): """Delete docs by source_url.""" start = time.time() if 0 >= len(function_input.source_url): - logging.warning( - "DocManager: Cannot delete document because data missing") + logging.warning("DocManager: Cannot delete document because data missing") end = time.time() return "fail", end - start try: @@ -176,14 +209,16 @@ def delete_doc(self, function_input: DocDeleteInput): rest.FieldCondition( key="metadata.extra_index", match=rest.MatchValue(value=function_input.category), - ) + ), ] ) self.client.delete( - collection_name=self.collection_name, points_selector=filter) + collection_name=self.collection_name, points_selector=filter + ) end = time.time() logging.info( - f"DocManager: Delete documents operation took {end - start} seconds") + f"DocManager: Delete documents operation took {end - start} seconds" + ) except Exception as e: logging.warning(f"DocManager: delete_doc exception {e}") end = time.time() @@ -202,15 +237,19 @@ async def search_doc(self, function_input: DocSearchInput): if len(nodes) > 0: ids = [doc.metadata["id"] for doc in nodes] for doc in nodes: - doc.metadata.pop('relevance_score', None) - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, nodes, ids=ids) + doc.metadata.pop("relevance_score", None) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, nodes, ids=ids + ) except Exception as e: logging.warning( - f"DocManager: search_html exception {e}\n{traceback.format_exc()}") + f"DocManager: search_html exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"DocManager: search_html operation took {end - start} seconds") + f"DocManager: search_html operation took {end - start} seconds" + ) return response, end - start def does_source_exist(self, function_input: CacheDoc): @@ -226,16 +265,19 @@ def does_source_exist(self, function_input: CacheDoc): rest.FieldCondition( key="metadata.extra_index", match=rest.MatchValue(value=function_input.category), - ) + ), ] ) result, _ = self.client.scroll( - collection_name=self.collection_name, scroll_filter=filter, limit=1) + collection_name=self.collection_name, scroll_filter=filter, limit=1 + ) except Exception as e: logging.warning( - f"DocManager: does_source_exist exception {e}\n{traceback.format_exc()}") + f"DocManager: does_source_exist exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"DocManager: does_source_exist operation took {end - start} seconds") + f"DocManager: does_source_exist operation took {end - start} seconds" + ) return result is not None and len(result) > 0, end - start diff --git a/document_summarizer.py b/document_summarizer.py index c565b40..48300d9 100644 --- a/document_summarizer.py +++ b/document_summarizer.py @@ -1,11 +1,10 @@ import asyncio -import traceback -import logging import json - +import logging +import traceback from typing import Sequence -from langchain.schema import Document -from langchain.schema import SystemMessage, HumanMessage + +from langchain.schema import Document, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI @@ -20,8 +19,10 @@ def to_prompt_string(self) -> str: if self.summarizations == 1: summarization_description = "has already been summarized once." else: - summarization_description = f"has already been summarized {self.summarizations} times." - return (f"Summarize this memory. Keep in mind it's of {self.importance} importance and {summarization_description}.") + summarization_description = ( + f"has already been summarized {self.summarizations} times." + ) + return f"Summarize this memory. Keep in mind it's of {self.importance} importance and {summarization_description}." class FlexibleDocumentSummarizer: @@ -35,25 +36,39 @@ def __init__(self, llm: ChatOpenAI, verbose: bool = False) -> None: async def _get_single_summary(self, document: Document) -> None: try: summary_prompt = SummaryPrompt( - summarizations=document.metadata["summarizations"], importance=document.metadata["importance"]) + summarizations=document.metadata["summarizations"], + importance=document.metadata["importance"], + ) memory = json.loads(document.page_content) summary_prompt_str = summary_prompt.to_prompt_string() - user_message = [SystemMessage( - content=summary_prompt_str), HumanMessage(content=memory["user"])] - aida_message = [SystemMessage( - content=summary_prompt_str), HumanMessage(content=memory["AiDA"])] + user_message = [ + SystemMessage(content=summary_prompt_str), + HumanMessage(content=memory["user"]), + ] + aida_message = [ + SystemMessage(content=summary_prompt_str), + HumanMessage(content=memory["AiDA"]), + ] response = await self._llm.agenerate([user_message, aida_message]) - if not response.generations or not response.generations[0] or not response.generations[1]: - raise Exception( - "LLM did not provide a valid summary response.") + if ( + not response.generations + or not response.generations[0] + or not response.generations[1] + ): + raise Exception("LLM did not provide a valid summary response.") # Update the document's page_content in place with the summarized text document.page_content = json.dumps( - {'user': response.generations[0][0].text, 'AiDA': response.generations[1][0].text}) + { + "user": response.generations[0][0].text, + "AiDA": response.generations[1][0].text, + } + ) except Exception as e: if self.verbose: logging.warning( - f"FlexibleDocumentSummarizer: _get_single_summary exception e: {e}\n{traceback.format_exc()}") + f"FlexibleDocumentSummarizer: _get_single_summary exception e: {e}\n{traceback.format_exc()}" + ) - async def asummarize(self, documents: Sequence[Document]) -> None: + async def asummarize(self, documents: Sequence[Document]) -> None: tasks = [self._get_single_summary(document) for document in documents] await asyncio.gather(*tasks) diff --git a/functions_manager.py b/functions_manager.py index 1eecc01..35f44e3 100644 --- a/functions_manager.py +++ b/functions_manager.py @@ -1,28 +1,28 @@ -import time -import tiktoken +import asyncio import json +import logging import os import random -import asyncio -import logging +import time import traceback -import cachetools.func +from datetime import datetime, timedelta +from typing import List +import cachetools.func +import tiktoken from dotenv import load_dotenv -from qdrant_client import QdrantClient -from typing import List -from datetime import datetime -from pydantic import BaseModel, Field -from qdrant_client.http import models as rest -from langchain_qdrant import Qdrant -from langchain_openai import OpenAIEmbeddings -from qdrant_retriever import QDrantVectorStoreRetriever from langchain.retrievers import ContextualCompressionRetriever -from cohere_rerank import CohereRerank from langchain.schema import Document -from datetime import datetime, timedelta +from langchain_openai import OpenAIEmbeddings +from langchain_qdrant import Qdrant +from pydantic import BaseModel, Field +from qdrant_client import QdrantClient +from qdrant_client.http import models as rest from qdrant_client.http.models import PayloadSchemaType +from cohere_rerank import CohereRerank +from qdrant_retriever import QDrantVectorStoreRetriever + class ActionItem(BaseModel): action: str @@ -33,7 +33,11 @@ def __str__(self): return self.action + self.intent + self.category def __eq__(self, other): - return self.action == other.action and self.intent == other.intent and self.category == other.category + return ( + self.action == other.action + and self.intent == other.intent + and self.category == other.category + ) def __hash__(self): return hash(str(self)) @@ -42,8 +46,16 @@ def __hash__(self): class FunctionInput(BaseModel): api_key: str user_id: str = None - action_items: List[ActionItem] = Field(..., example=[ - {"action": "action_example", "intent": "intent_example", "category": "category_example"}]) + action_items: List[ActionItem] = Field( + ..., + example=[ + { + "action": "action_example", + "intent": "intent_example", + "category": "category_example", + } + ], + ) def __str__(self): if self.user_id: @@ -53,7 +65,10 @@ def __str__(self): def __eq__(self, other): if self.user_id: - return self.action_items == other.action_items and self.user_id == other.user_id + return ( + self.action_items == other.action_items + and self.user_id == other.user_id + ) else: return self.action_items == other.action_items @@ -70,8 +85,16 @@ class FunctionItem(BaseModel): class FunctionOutput(BaseModel): api_key: str user_id: str = None - functions: List[FunctionItem] = Field(..., example=[ - {"name": "name_example", "description": "description_example", "category": "category_example"}]) + functions: List[FunctionItem] = Field( + ..., + example=[ + { + "name": "name_example", + "description": "description_example", + "category": "category_example", + } + ], + ) class FunctionsManager: @@ -86,8 +109,7 @@ def __init__(self, rate_limiter, rate_limiter_sync): self.rate_limiter_sync = rate_limiter_sync self.max_length_allowed = 512 self.collection_name = "functions" - self.client = QdrantClient( - url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) + self.client = QdrantClient(url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) self.inited = False def create_new_functions_retriever(self, api_key: str): @@ -102,19 +124,33 @@ def create_new_functions_retriever(self, api_key: str): ), ) self.client.create_payload_index( - self.collection_name, "metadata.user_id", field_schema=PayloadSchemaType.KEYWORD) + self.collection_name, + "metadata.user_id", + field_schema=PayloadSchemaType.KEYWORD, + ) except: logging.info(f"FunctionsManager: loaded from cloud...") finally: logging.info( - f"FunctionsManager: Creating memory store with collection {self.collection_name}") - vectorstore = Qdrant(self.client, self.collection_name, OpenAIEmbeddings( - model="text-embedding-3-small", openai_api_key=api_key)) + f"FunctionsManager: Creating memory store with collection {self.collection_name}" + ) + vectorstore = Qdrant( + self.client, + self.collection_name, + OpenAIEmbeddings( + model="text-embedding-3-small", openai_api_key=api_key + ), + ) compressor = CohereRerank() compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever( - rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=self.collection_name, client=self.client, vectorstore=vectorstore, - ) + base_compressor=compressor, + base_retriever=QDrantVectorStoreRetriever( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + collection_name=self.collection_name, + client=self.client, + vectorstore=vectorstore, + ), ) return compression_retriever @@ -123,32 +159,35 @@ def transform(self, user_id, data, category): now = datetime.now().timestamp() result = [] for item in data: - page_content = {'name': item['name'], 'category': category, 'description': str( - item['description'])} + page_content = { + "name": item["name"], + "category": category, + "description": str(item["description"]), + } lenData = len(str(page_content)) if lenData > self.max_length_allowed: logging.info( - f"FunctionsManager: transform tried to create a function that surpasses the maximum length allowed max_length_allowed: {self.max_length_allowed} vs length of data: {lenData}") + f"FunctionsManager: transform tried to create a function that surpasses the maximum length allowed max_length_allowed: {self.max_length_allowed} vs length of data: {lenData}" + ) continue metadata = { - "id": random.randint(0, 2**32 - 1), + "id": random.randint(0, 2**32 - 1), "user_id": user_id, "extra_index": category, "last_accessed_at": now, } - doc = Document( - page_content=json.dumps(page_content), - metadata=metadata - ) + doc = Document(page_content=json.dumps(page_content), metadata=metadata) result.append(doc) return result def count_tokens(self, functions): """Count the tokens for all the functions.""" - function_types = ['information_retrieval', - 'communication', - 'data_processing', - 'sensory_perception'] + function_types = [ + "information_retrieval", + "communication", + "data_processing", + "sensory_perception", + ] encoding = tiktoken.encoding_for_model("gpt-4o-mini") tokens = [] @@ -156,8 +195,7 @@ def count_tokens(self, functions): if func_type in functions: for func in functions[func_type]: function_string = json.dumps(func) - tokens.append( - {func['name']: len(encoding.encode(function_string))}) + tokens.append({func["name"]: len(encoding.encode(function_string))}) return tokens def extract_name_and_category(self, documents): @@ -166,12 +204,12 @@ def extract_name_and_category(self, documents): for doc in documents: # Parse the page_content string into a Python dict text = json.loads(doc.page_content) - name = text.get('name') - category = text.get('category') + name = text.get("name") + category = text.get("category") # Check if this combination has been seen before if (name, category) not in seen: - result.append({'name': name, 'category': category}) + result.append({"name": name, "category": category}) seen.add((name, category)) # Mark this combination as seen return result @@ -183,10 +221,12 @@ async def pull_functions(self, function_input: FunctionInput): try: self.client.get_collection(self.collection_name) except: - with open('./utils/functions.json', 'r') as f: + with open("./utils/functions.json", "r") as f: print("FunctionsManager: Loading from functions.json") functions_json = json.load(f) - await self.push_functions(function_input.user_id, function_input.api_key, functions_json) + await self.push_functions( + function_input.user_id, function_input.api_key, functions_json + ) self.inited = True memory = self.load(function_input.api_key) response = [] @@ -194,27 +234,40 @@ async def pull_functions(self, function_input: FunctionInput): try: for action_item in function_input.action_items: query = f"action: {action_item.action} intent: {action_item.intent} category: {action_item.category}" - documents = await self.get_retrieved_nodes(memory, - query, action_item.category, function_input.user_id) + documents = await self.get_retrieved_nodes( + memory, query, action_item.category, function_input.user_id + ) if len(documents) > 0: parsed_response = self.extract_name_and_category(documents) response.append(parsed_response) # update last_accessed_at ids = [doc.metadata["id"] for doc in documents] for doc in documents: - doc.metadata.pop('relevance_score', None) - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + doc.metadata.pop("relevance_score", None) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, + documents, + ids=ids, + ) # loop.run_in_executor(None, self.prune_functions) except Exception as e: logging.warning( - f"FunctionsManager: pull_functions exception {e}\n{traceback.format_exc()}") + f"FunctionsManager: pull_functions exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"FunctionsManager: pull_functions operation took {end - start} seconds") - return response, end-start - - async def get_retrieved_nodes(self, memory: ContextualCompressionRetriever, query_str: str, category: str, user_id: str): + f"FunctionsManager: pull_functions operation took {end - start} seconds" + ) + return response, end - start + + async def get_retrieved_nodes( + self, + memory: ContextualCompressionRetriever, + query_str: str, + category: str, + user_id: str, + ): kwargs = {} if len(category) > 0: kwargs["extra_index"] = category @@ -228,7 +281,7 @@ async def get_retrieved_nodes(self, memory: ContextualCompressionRetriever, quer ), rest.IsNullCondition( is_null=rest.PayloadField(key="metadata.user_id") - ) + ), ] ) kwargs["user_filter"] = filter @@ -249,8 +302,7 @@ def load(self, api_key: str): start = time.time() memory = self.create_new_functions_retriever(api_key) end = time.time() - logging.info( - f"FunctionsManager: Load operation took {end - start} seconds") + logging.info(f"FunctionsManager: Load operation took {end - start} seconds") return memory async def push_functions(self, user_id: str, api_key: str, functions): @@ -261,10 +313,12 @@ async def push_functions(self, user_id: str, api_key: str, functions): try: logging.info("FunctionsManager: adding functions to index...") - function_types = ['information_retrieval', - 'communication', - 'data_processing', - 'sensory_perception'] + function_types = [ + "information_retrieval", + "communication", + "data_processing", + "sensory_perception", + ] all_docs = [] @@ -272,22 +326,28 @@ async def push_functions(self, user_id: str, api_key: str, functions): for func_type in function_types: if func_type in functions: transformed_functions = self.transform( - user_id, functions[func_type], func_type.replace('_', ' ').title()) + user_id, + functions[func_type], + func_type.replace("_", " ").title(), + ) all_docs.extend(transformed_functions) ids = [doc.metadata["id"] for doc in all_docs] - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, all_docs, ids=ids) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, all_docs, ids=ids + ) tokens = self.count_tokens(functions) except Exception as e: logging.warning( - f"FunctionsManager: push_functions exception {e}\n{traceback.format_exc()}") + f"FunctionsManager: push_functions exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() - logging.info( - f"FunctionsManager: push_functions took {end - start} seconds") - return tokens, end-start + logging.info(f"FunctionsManager: push_functions took {end - start} seconds") + return tokens, end - start def prune_functions(self): """Prune functions that haven't been used for atleast six weeks.""" + def attempt_prune(): current_time = datetime.now() six_weeks_ago = current_time - timedelta(weeks=6) @@ -300,18 +360,22 @@ def attempt_prune(): ] ) self.client.delete( - collection_name=self.collection_name, points_selector=filter) + collection_name=self.collection_name, points_selector=filter + ) + try: attempt_prune() except Exception as e: logging.warning( - f"FunctionsManager: prune_functions exception {e}\n{traceback.format_exc()}") + f"FunctionsManager: prune_functions exception {e}\n{traceback.format_exc()}" + ) # Attempt a second prune after reload try: attempt_prune() except Exception as e: # If prune after reload fails, propagate the error upwards logging.error( - f"FunctionsManager: prune_functions failed after reload, exception {e}\n{traceback.format_exc()}") + f"FunctionsManager: prune_functions failed after reload, exception {e}\n{traceback.format_exc()}" + ) raise return True diff --git a/functions_manager_test.py b/functions_manager_test.py index 2c35657..f5ded87 100644 --- a/functions_manager_test.py +++ b/functions_manager_test.py @@ -1,10 +1,13 @@ import unittest -from unittest.mock import patch, MagicMock -from functions_manager import FunctionsManager, ActionItem, FunctionInput +from unittest.mock import MagicMock, patch + +from functions_manager import ActionItem, FunctionInput, FunctionsManager from rate_limiter import RateLimiter, SyncRateLimiter -rate_limiter = RateLimiter(rate=5, period=1) + +rate_limiter = RateLimiter(rate=5, period=1) rate_limiter_sync = SyncRateLimiter(rate=5, period=1) + class TestFunctionsManager(unittest.TestCase): @patch("functions_manager.schedule") @patch("functions_manager.load_dotenv") @@ -14,8 +17,17 @@ class TestFunctionsManager(unittest.TestCase): @patch("functions_manager.load_index_from_storage") @patch("functions_manager.StorageContext") @patch("functions_manager.os.getenv") - def test_init(self, mock_getenv, mock_storage_context, mock_load_index_from_storage, - mock_llm_rerank, mock_service_context, mock_thread, mock_load_dotenv, mock_schedule): + def test_init( + self, + mock_getenv, + mock_storage_context, + mock_load_index_from_storage, + mock_llm_rerank, + mock_service_context, + mock_thread, + mock_load_dotenv, + mock_schedule, + ): mock_load_index_from_storage.return_value = None mock_llm_rerank.return_value = MagicMock() mock_service_context.from_defaults.return_value = MagicMock() @@ -32,8 +44,8 @@ def test_init(self, mock_getenv, mock_storage_context, mock_load_index_from_stor mock_thread.assert_called() fm.stop() - @patch('memory_manager.schedule.run_pending') - @patch('memory_manager.time.sleep') + @patch("memory_manager.schedule.run_pending") + @patch("memory_manager.time.sleep") def test_run_continuously(self, mock_sleep, mock_run_pending): mock_sleep.side_effect = lambda *args: exit(0) functions_manager = FunctionsManager(rate_limiter, rate_limiter_sync) @@ -42,28 +54,37 @@ def test_run_continuously(self, mock_sleep, mock_run_pending): functions_manager.run_continuously() except SystemExit: pass - #functions_manager.release_locks() # Assuming the FunctionsManager class has a method to release locks + # functions_manager.release_locks() # Assuming the FunctionsManager class has a method to release locks mock_run_pending.assert_called() - def test_transform(self): fm = FunctionsManager(rate_limiter, rate_limiter_sync) user_id = "2" data = [{"name": "test_name", "description": "test_description"}] result = fm.transform(user_id, data, "test_category") - expected_result = [{"name": "test_name", "description": "test_description", "category": "test_category"}] + expected_result = [ + { + "name": "test_name", + "description": "test_description", + "category": "test_category", + } + ] self.assertEqual(result, expected_result) fm.stop() @patch("functions_manager.Document") @patch("functions_manager.VectorStoreIndex") def test_push_functions(self, mock_index, mock_document): - + functions = { - 'information_retrieval': [{'name': 'function1', 'description': 'description1'}], - 'communication': [{'name': 'function2', 'description': 'description2'}], - 'data_processing': [{'name': 'function3', 'description': 'description3'}], - 'sensory_perception': [{'name': 'function4', 'description': 'description4'}] + "information_retrieval": [ + {"name": "function1", "description": "description1"} + ], + "communication": [{"name": "function2", "description": "description2"}], + "data_processing": [{"name": "function3", "description": "description3"}], + "sensory_perception": [ + {"name": "function4", "description": "description4"} + ], } fm = FunctionsManager(rate_limiter, rate_limiter_sync) for idx, func_type in enumerate(functions): @@ -77,7 +98,14 @@ def test_push_functions(self, mock_index, mock_document): @patch("functions_manager.LLMRerank") @patch("functions_manager.os.path.exists") @patch("functions_manager.os.path.isdir") - def test_load(self, mock_isdir, mock_exists, mock_llm_rerank, mock_index, mock_load_index_from_storage): + def test_load( + self, + mock_isdir, + mock_exists, + mock_llm_rerank, + mock_index, + mock_load_index_from_storage, + ): mock_exists.return_value = True mock_isdir.return_value = True mock_load_index_from_storage.return_value = MagicMock() @@ -95,16 +123,23 @@ def test_load(self, mock_isdir, mock_exists, mock_llm_rerank, mock_index, mock_l fm.stop() - # Similarly, you would write tests for other methods like save, pull_functions and count_tokens @patch("functions_manager.tiktoken.encoding_for_model") def test_count_tokens(self, mock_encoding_for_model): mock_encoding_for_model.return_value = MagicMock() - mock_encoding_for_model.return_value.encode.side_effect = ['token1', 'token2', 'token3'] - functions = {'information_retrieval': [{'name': 'function1', 'description': 'description1'}]} + mock_encoding_for_model.return_value.encode.side_effect = [ + "token1", + "token2", + "token3", + ] + functions = { + "information_retrieval": [ + {"name": "function1", "description": "description1"} + ] + } fm = FunctionsManager(rate_limiter, rate_limiter_sync) result = fm.count_tokens(functions) - expected_result = [{'function1': 6}] + expected_result = [{"function1": 6}] self.assertEqual(result, expected_result) fm.stop() @@ -117,7 +152,9 @@ def test_save(self): # Assert that persist method was called on the storage_context of the index fm.index.storage_context.persist.assert_called_once_with(persist_dir=fm.dirpath) - self.assertFalse(fm.dirty) # Ensure dirty flag is set to False after save operation + self.assertFalse( + fm.dirty + ) # Ensure dirty flag is set to False after save operation fm.stop() @@ -126,7 +163,9 @@ def test_pull_functions(self): fm.query_engine = MagicMock() # create FunctionInput instance - action_item = ActionItem(action='test_query', intent='intent_example', category='category_example') + action_item = ActionItem( + action="test_query", intent="intent_example", category="category_example" + ) function_input = FunctionInput(action_items=[action_item]) fm.pull_functions(function_input) @@ -140,6 +179,5 @@ def test_pull_functions(self): fm.stop() - if __name__ == "__main__": unittest.main() diff --git a/generative_conversation_summarized_memory.py b/generative_conversation_summarized_memory.py index 2e9a571..7ff94ef 100644 --- a/generative_conversation_summarized_memory.py +++ b/generative_conversation_summarized_memory.py @@ -1,21 +1,29 @@ +import json import logging import random import traceback -import json - from datetime import datetime, timedelta from typing import Any, Dict, List, Optional -from qdrant_retriever import MemoryType + from langchain.retrievers import ContextualCompressionRetriever -from langchain.schema import SystemMessage, HumanMessage, AIMessage -from langchain.schema import BaseMemory, Document +from langchain.schema import ( + AIMessage, + BaseMemory, + Document, + HumanMessage, + SystemMessage, +) from langchain.schema.language_model import BaseLanguageModel + +from qdrant_retriever import MemoryType from rate_limiter import RateLimiter logger = logging.getLogger(__name__) - + + class GenerativeAgentConversationSummarizedMemory(BaseMemory): """Conversations summarized for the generative agent.""" + rate_limiter: RateLimiter llm: BaseLanguageModel """The core language model.""" @@ -24,36 +32,48 @@ class GenerativeAgentConversationSummarizedMemory(BaseMemory): verbose: bool = False async def _init_summary_of_convo(self, doc0: str) -> str: - prompt = "Create a topic-based summarization of the conversation between the subjects in the user message. Be concise, do not add new details that is not in the provided text. Output is a new JSON object. example {\"topic\",\"topic summary\"}." + prompt = 'Create a topic-based summarization of the conversation between the subjects in the user message. Be concise, do not add new details that is not in the provided text. Output is a new JSON object. example {"topic","topic summary"}.' try: - messages = [[SystemMessage(content=prompt), - HumanMessage(content=doc0)]] + messages = [[SystemMessage(content=prompt), HumanMessage(content=doc0)]] response = await self.llm.agenerate(messages) if not response.generations or not response.generations[0]: raise Exception("LLM did not provide a valid summary response.") return response.generations[0][0].text except Exception as e: if self.verbose: - logging.warning(f"GenerativeAgentConversationSummarizedMemory: _init_summary_of_convo exception, e: {e}\n{traceback.format_exc()}") - return '' - + logging.warning( + f"GenerativeAgentConversationSummarizedMemory: _init_summary_of_convo exception, e: {e}\n{traceback.format_exc()}" + ) + return "" + async def _summarize_with_convo(self, new_text: str, existing_summary: str) -> str: prompt = "Combine user message into the existing summary (AI message) and return new topic-based summarization of the conversation of the subjects in the supplied conversation. Remove any redundancies. Create new topics if any of the information does not belong to existing topics. Be concise, do not add new details that is not in the provided text or in the existing summary. Output is the modified JSON object." - try: - messages = [[SystemMessage(content=prompt), - AIMessage(content=existing_summary), - HumanMessage(content=new_text)]] + try: + messages = [ + [ + SystemMessage(content=prompt), + AIMessage(content=existing_summary), + HumanMessage(content=new_text), + ] + ] response = await self.llm.agenerate(messages) if not response.generations or not response.generations[0]: raise Exception("LLM did not provide a valid summary response.") return response.generations[0][0].text except Exception as e: if self.verbose: - logging.warning(f"GenerativeAgentConversationSummarizedMemory: _summarize_with_convo exception, e: {e}\n{traceback.format_exc()}") - return '' - + logging.warning( + f"GenerativeAgentConversationSummarizedMemory: _summarize_with_convo exception, e: {e}\n{traceback.format_exc()}" + ) + return "" + async def add_memories( - self, qa: List[str], conversation_id: str, importance: List[str], memory_types: List[MemoryType], now: Optional[datetime] = None + self, + qa: List[str], + conversation_id: str, + importance: List[str], + memory_types: List[MemoryType], + now: Optional[datetime] = None, ) -> List[str]: """Add an observations or memories to the agent's memory.""" documents = [] @@ -63,23 +83,29 @@ async def add_memories( if memory_types[i] != MemoryType.CONSCIOUS_MEMORY or importance[i] == "low": continue metadata = { - "id": random.randint(0, 2**32 - 1), + "id": random.randint(0, 2**32 - 1), "extra_index": conversation_id, "created_at": nowStamp, } - doc = Document( - page_content=qa[i], - metadata=metadata - ) + doc = Document(page_content=qa[i], metadata=metadata) documents.append(doc) ids.append(metadata["id"]) if len(documents) > 0: - return await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + return await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + documents, + ids=ids, + ) else: return None async def add_memory( - self, memory_content: str, conversation_id: str, importance: str, memory_type: MemoryType, now: Optional[datetime] = None + self, + memory_content: str, + conversation_id: str, + importance: str, + memory_type: MemoryType, + now: Optional[datetime] = None, ) -> List[str]: """Add an observation or memory to the agent's memory.""" if memory_type != MemoryType.CONSCIOUS_MEMORY or importance == "low": @@ -91,7 +117,7 @@ async def add_memory( "created_at": nowStamp, } document = Document( - page_content=memory_content, + page_content=memory_content, metadata=metadata, ) # pull existing conversation and merge the two into a new memory @@ -99,10 +125,18 @@ async def add_memory( if doc is not None and len(doc.page_content) > 0: # summarize the two together document.metadata = doc.metadata - document.page_content = await self._summarize_with_convo(document.page_content, doc.page_content) + document.page_content = await self._summarize_with_convo( + document.page_content, doc.page_content + ) else: - document.page_content = await self._init_summary_of_convo(document.page_content) - return await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, [document], ids=[document.metadata["id"]]) + document.page_content = await self._init_summary_of_convo( + document.page_content + ) + return await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + [document], + ids=[document.metadata["id"]], + ) async def save_context(self, outputs: Dict[str, Any]) -> List[str]: """Save the context of this model run to memory.""" @@ -112,15 +146,21 @@ async def save_context(self, outputs: Dict[str, Any]) -> List[str]: importance = outputs.get("importance") conversation_id = outputs.get("conversation_id") if query: - qa = {'user': query, 'AiDA': aida} - return await self.add_memory(json.dumps(qa), conversation_id=conversation_id, memory_type=MemoryType.CONSCIOUS_MEMORY, importance=importance, now=now) + qa = {"user": query, "AiDA": aida} + return await self.add_memory( + json.dumps(qa), + conversation_id=conversation_id, + memory_type=MemoryType.CONSCIOUS_MEMORY, + importance=importance, + now=now, + ) return [] - - def get_conversation( - self, conversation_id: str - ) -> Document: + + def get_conversation(self, conversation_id: str) -> Document: """Fetch summarized conversation.""" - return self.memory_retriever.base_retriever.get_key_value_document("metadata.extra_index", conversation_id) + return self.memory_retriever.base_retriever.get_key_value_document( + "metadata.extra_index", conversation_id + ) def _time_ago(self, timestamp: float) -> str: """Return a rough string representation of the time passed since a timestamp.""" @@ -138,7 +178,7 @@ def _time_ago(self, timestamp: float) -> str: def memory_variables(self) -> List[str]: """Input keys this memory class will load dynamically.""" return [] - + def clear(self) -> None: return @@ -146,14 +186,16 @@ def load_memory_variables(self, **kwargs) -> Dict[str, str]: """Return key-value pairs given the text input to the chain.""" conversation_id = kwargs.pop("conversation_id") return { - "relevant_summary": self.format_summary_simple(self.get_conversation(conversation_id)), + "relevant_summary": self.format_summary_simple( + self.get_conversation(conversation_id) + ), } - + def format_summary_simple(self, conversation_summary: Document) -> str: now = datetime.now().timestamp() created_at = conversation_summary.metadata.get("created_at", now) created_ago = self._time_ago(created_at) - + # Extracting the extra_index (conversation_id) conversation_id = conversation_summary.metadata.get("extra_index", "N/A") - return f"(created: {created_ago}, conversation_id: {conversation_id}) {conversation_summary.page_content}" \ No newline at end of file + return f"(created: {created_ago}, conversation_id: {conversation_id}) {conversation_summary.page_content}" diff --git a/generative_memory.py b/generative_memory.py index 9150676..f7199f1 100644 --- a/generative_memory.py +++ b/generative_memory.py @@ -1,22 +1,30 @@ -import logging import json +import logging import random import traceback - from datetime import datetime, timedelta from typing import Any, Dict, List, Optional -from qdrant_retriever import MemoryType + from langchain.retrievers import ContextualCompressionRetriever -from langchain.schema import BaseMemory, Document +from langchain.schema import ( + AIMessage, + BaseMemory, + BaseMessage, + Document, + HumanMessage, + SystemMessage, +) from langchain.schema.language_model import BaseLanguageModel from langchain.utils import mock_now from qdrant_client.http import models as rest + from memory_summarizer import MemorySummarizer -from langchain.schema import SystemMessage, HumanMessage, AIMessage, BaseMessage +from qdrant_retriever import MemoryType from rate_limiter import RateLimiter logger = logging.getLogger(__name__) - + + class GenerativeAgentMemory(BaseMemory): rate_limiter: RateLimiter """Memory for the generative agent.""" @@ -47,16 +55,16 @@ def _extract_insights(text: str) -> List[str]: if not line.strip(): break insights.append(line.strip()) - insights_str = '\n'.join(insights) + insights_str = "\n".join(insights) return insights_str @staticmethod def _extract_importance(text: str) -> str: """Extract importance level from the provided text.""" - + # Split the text into lines lines = text.splitlines() - + # Find the line with "Importance:" for line in lines: if "Importance:" in line: @@ -66,21 +74,29 @@ def _extract_importance(text: str) -> str: return "low" - def format_memories_as_messages(self, relevant_memories: List[Document]) -> List[BaseMessage]: + def format_memories_as_messages( + self, relevant_memories: List[Document] + ) -> List[BaseMessage]: formatted_memories = [] for mem in relevant_memories: - memory = json.loads(mem.page_content) # Convert the JSON string to a dictionary + memory = json.loads( + mem.page_content + ) # Convert the JSON string to a dictionary formatted_memories.append(HumanMessage(content=f'Memory: {memory["user"]}')) formatted_memories.append(AIMessage(content=f'Memory: {memory["AiDA"]}')) return formatted_memories - async def _get_importance_and_insight(self, user: str, llm_response: str, conversation_id: str, role: str): + async def _get_importance_and_insight( + self, user: str, llm_response: str, conversation_id: str, role: str + ): """Reflect on recent query and generate 'insights'.""" if self.verbose: logger.info("AiDA is checking importance") kwargs = {} # lookup some relevant context for query classification - memoryDocuments = await self.memory_retriever.base_retriever.get_relevant_documents_for_reflection(json.dumps({'user': user, 'AiDA': llm_response}), conversation_id, **kwargs) + memoryDocuments = await self.memory_retriever.base_retriever.get_relevant_documents_for_reflection( + json.dumps({"user": user, "AiDA": llm_response}), conversation_id, **kwargs + ) try: memoryMessages = self.format_memories_as_messages(memoryDocuments) prompt = f""" @@ -112,37 +128,61 @@ async def _get_importance_and_insight(self, user: str, llm_response: str, conver return importance, insights except Exception as e: if self.verbose: - logging.warning(f"GenerativeAgentMemory: _get_importance_and_insight exception, e: {e}\n{traceback.format_exc()}") + logging.warning( + f"GenerativeAgentMemory: _get_importance_and_insight exception, e: {e}\n{traceback.format_exc()}" + ) return None, None - async def pause_to_reflect(self, outputs: Dict[str, Any], preferences_resolver) -> List[str]: + async def pause_to_reflect( + self, outputs: Dict[str, Any], preferences_resolver + ) -> List[str]: """Reflect on recent observations and generate 'insights'.""" new_insights = [] conversation_id = outputs.get("conversation_id") query = outputs.get("query") aida = outputs.get("llm_response") - now=datetime.now() + now = datetime.now() try: role = await preferences_resolver.get_role(conversation_id) if role is None: role = "ConversationGPT" - importance, insights = await self._get_importance_and_insight(query, aida, conversation_id, role) + importance, insights = await self._get_importance_and_insight( + query, aida, conversation_id, role + ) if importance == "high" and len(insights) > 0: if self.verbose: logger.info("AiDA is reflecting") # ensure we are dealing with non-core memories because reflections are sub-conscious thoughts - await self.add_memory(memory_content=json.dumps({'user': 'AiDA to reflect and generate insight', 'AiDA': insights}), conversation_id=conversation_id, importance="medium", memory_type=MemoryType.SUBCONSCIOUS_MEMORY, now=now) + await self.add_memory( + memory_content=json.dumps( + { + "user": "AiDA to reflect and generate insight", + "AiDA": insights, + } + ), + conversation_id=conversation_id, + importance="medium", + memory_type=MemoryType.SUBCONSCIOUS_MEMORY, + now=now, + ) new_insights.extend(insights) except Exception as e: - importance = 'low' + importance = "low" if self.verbose: - logging.warning(f"GenerativeAgentMemory: pause_to_reflect exception, e: {e}\n{traceback.format_exc()}") + logging.warning( + f"GenerativeAgentMemory: pause_to_reflect exception, e: {e}\n{traceback.format_exc()}" + ) outputs["importance"] = importance await self.save_context(outputs) return new_insights async def add_memories( - self, qa: List[str], conversation_id: str, importance: List[str], memory_types: List[MemoryType], now: Optional[datetime] = None + self, + qa: List[str], + conversation_id: str, + importance: List[str], + memory_types: List[MemoryType], + now: Optional[datetime] = None, ) -> List[str]: """Add an observations or memories to the agent's memory.""" documents = [] @@ -150,7 +190,7 @@ async def add_memories( nowStamp = now.timestamp() for i in range(len(qa)): metadata = { - "id": random.randint(0, 2**32 - 1), + "id": random.randint(0, 2**32 - 1), "extra_index": conversation_id, "created_at": nowStamp, "importance": importance[i], @@ -158,16 +198,22 @@ async def add_memories( "summarizations": 0, "memory_type": memory_types[i].value, } - doc = Document( - page_content=qa[i], - metadata=metadata - ) + doc = Document(page_content=qa[i], metadata=metadata) documents.append(doc) ids.append(metadata["id"]) - return await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + return await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + documents, + ids=ids, + ) async def add_memory( - self, memory_content: str, conversation_id: str, importance: str, memory_type: MemoryType, now: Optional[datetime] = None + self, + memory_content: str, + conversation_id: str, + importance: str, + memory_type: MemoryType, + now: Optional[datetime] = None, ) -> List[str]: """Add an observation or memory to the agent's memory.""" nowStamp = now.timestamp() @@ -175,20 +221,22 @@ async def add_memory( "id": random.randint(0, 2**32 - 1), "extra_index": conversation_id, "created_at": nowStamp, - "importance": importance, + "importance": importance, "last_accessed_at": nowStamp, "summarizations": 0, "memory_type": memory_type.value, } document = Document( - page_content=memory_content, + page_content=memory_content, metadata=metadata, ) - return await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, [document], ids=[metadata["id"]]) + return await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + [document], + ids=[metadata["id"]], + ) - async def fetch_memories( - self, topic: str, **kwargs: Any - ) -> List[Document]: + async def fetch_memories(self, topic: str, **kwargs: Any) -> List[Document]: """Fetch related memories.""" current_time = kwargs.get("current_time", None) conversation_id = kwargs.pop("conversation_id") @@ -197,14 +245,18 @@ async def fetch_memories( return await self.memory_retriever.ainvoke(topic) else: if conversation_id != "": - kwargs.update({"filter": rest.Filter( - must=[ - rest.FieldCondition( - key="metadata.extra_index", - match=rest.MatchValue(value=conversation_id), + kwargs.update( + { + "filter": rest.Filter( + must=[ + rest.FieldCondition( + key="metadata.extra_index", + match=rest.MatchValue(value=conversation_id), + ) + ] ) - ] - )}) + } + ) return await self.memory_retriever.ainvoke(topic, **kwargs) def _time_ago(self, timestamp: float) -> str: @@ -223,15 +275,21 @@ def format_memories_simple(self, relevant_memories: List[Document]) -> str: now = datetime.now().timestamp() formatted_memories = [] for mem in relevant_memories: - memory_type = MemoryType(mem.metadata["memory_type"]).name.replace("_", " ").lower() + memory_type = ( + MemoryType(mem.metadata["memory_type"]).name.replace("_", " ").lower() + ) summarizations_count = mem.metadata.get("summarizations", 0) - importance = mem.metadata.get("importance", "medium") # assuming "medium" as default importance + importance = mem.metadata.get( + "importance", "medium" + ) # assuming "medium" as default importance created_at = mem.metadata.get("created_at", now) created_ago = self._time_ago(created_at) - + # Extracting the extra_index (conversation_id) conversation_id = mem.metadata.get("extra_index", "N/A") - formatted_memories.append(f"({memory_type}, importance: {importance}, summarizations: {summarizations_count}, from: {created_ago}, conversation_id: {conversation_id}) {mem.page_content}") + formatted_memories.append( + f"({memory_type}, importance: {importance}, summarizations: {summarizations_count}, from: {created_ago}, conversation_id: {conversation_id}) {mem.page_content}" + ) return "; ".join(formatted_memories) @property @@ -244,14 +302,20 @@ async def load_memory_variables(self, **kwargs) -> Dict[str, str]: queries = kwargs.pop("queries") if queries is not None: relevant_memories = [ - mem for query in queries for mem in await self.fetch_memories(query, **kwargs) + mem + for query in queries + for mem in await self.fetch_memories(query, **kwargs) ] if len(relevant_memories) > 0: # update last_accessed_at/summarizations ids = [doc.metadata["id"] for doc in relevant_memories] for doc in relevant_memories: - doc.metadata.pop('relevance_score', None) - await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, relevant_memories, ids=ids) + doc.metadata.pop("relevance_score", None) + await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + relevant_memories, + ids=ids, + ) return { "relevant_memories": self.format_memories_simple(relevant_memories), } @@ -267,26 +331,41 @@ async def save_context(self, outputs: Dict[str, Any]) -> List[str]: user_id = outputs.get("user_id") api_key = outputs.get("api_key") if query: - qa = {'user': query, 'AiDA': aida} + qa = {"user": query, "AiDA": aida} await self.memory_summarizer.save(api_key, user_id, outputs) - return await self.add_memory(json.dumps(qa), conversation_id=conversation_id, memory_type=MemoryType.CONSCIOUS_MEMORY, importance=importance, now=now) + return await self.add_memory( + json.dumps(qa), + conversation_id=conversation_id, + memory_type=MemoryType.CONSCIOUS_MEMORY, + importance=importance, + now=now, + ) return [] - async def decay(self): """Decay all old memories by summarizing based on importance and summarization count.""" try: # Delete memories flagged as too old self.memory_retriever.base_retriever.delete_max_summarized() # Get the documents to summarize - documents = self.memory_retriever.base_retriever.get_documents_for_summarization() + documents = ( + self.memory_retriever.base_retriever.get_documents_for_summarization() + ) if len(documents) > 0: - await self.memory_summarizer.flexible_document_summarizer.asummarize(documents) + await self.memory_summarizer.flexible_document_summarizer.asummarize( + documents + ) # upsert entire document set to qdrant against existing IDs (stored in metadata) ids = [doc.metadata["id"] for doc in documents] - await self.rate_limiter.execute(self.memory_retriever.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + await self.rate_limiter.execute( + self.memory_retriever.base_retriever.vectorstore.aadd_documents, + documents, + ids=ids, + ) except Exception as e: - logging.warning(f"GenerativeAgentMemory: decay_user exception {e}\n{traceback.format_exc()}") - + logging.warning( + f"GenerativeAgentMemory: decay_user exception {e}\n{traceback.format_exc()}" + ) + def clear(self) -> None: - return \ No newline at end of file + return diff --git a/main.py b/main.py index e750257..7d1a237 100644 --- a/main.py +++ b/main.py @@ -1,19 +1,27 @@ -import os +import asyncio import logging +import os import time -import asyncio +from cachetools import LRUCache, TTLCache from dotenv import load_dotenv from fastapi import FastAPI -from agent_manager import AgentManager, MemoryInput, MemoryOutput, ClearMemory -from web_manager import WebManager, HTMLInput, CacheHTML -from doc_manager import DocManager, DocAddInput, DocDeleteInput, DocSearchInput, CacheDoc -from functions_manager import FunctionsManager, FunctionInput, FunctionOutput -from queryplan_manager import QueryPlanManager, QueryPlanInput + +from agent_manager import AgentManager, ClearMemory, MemoryInput, MemoryOutput from cache_manager import CacheClearInput +from doc_manager import ( + CacheDoc, + DocAddInput, + DocDeleteInput, + DocManager, + DocSearchInput, +) +from functions_manager import FunctionInput, FunctionOutput, FunctionsManager from preferences_resolver import QueryPreferencesInput -from cachetools import TTLCache, LRUCache +from queryplan_manager import QueryPlanInput, QueryPlanManager from rate_limiter import RateLimiter, SyncRateLimiter +from web_manager import CacheHTML, HTMLInput, WebManager + rate_limiter = RateLimiter(rate=5, period=1) # Allow 5 tasks per second rate_limiter_sync = SyncRateLimiter(rate=5, period=1) # Load environment variables @@ -33,10 +41,15 @@ app = FastAPI() # Initialize logging -LOGFILE_PATH = os.path.join(os.path.dirname( - os.path.abspath(__file__)), 'app.log') -logging.basicConfig(filename=LOGFILE_PATH, filemode='w', - format='%(asctime)s.%(msecs)03d %(name)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', force=True, level=logging.INFO) +LOGFILE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "app.log") +logging.basicConfig( + filename=LOGFILE_PATH, + filemode="w", + format="%(asctime)s.%(msecs)03d %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + level=logging.INFO, +) functions_manager = FunctionsManager(rate_limiter, rate_limiter_sync) @@ -50,123 +63,157 @@ functioncache = TTLCache(maxsize=16384, ttl=36000) doccache = LRUCache(maxsize=16384) -@app.post('/get_preferences/') + +@app.post("/get_preferences/") async def getPreferences(preferences_query: QueryPreferencesInput): - logging.info(f'Get preferences for user {preferences_query.user_id}') + logging.info(f"Get preferences for user {preferences_query.user_id}") start = time.time() - response = await agent_manager.preferences_resolver.get_preferences(preferences_query.user_id) + response = await agent_manager.preferences_resolver.get_preferences( + preferences_query.user_id + ) if response is None: - logging.info(f'Preferences for user {preferences_query.user_id} does not exist, returnng default and making one in the background...') - asyncio.create_task(agent_manager.preferences_resolver.create_default_preferences(preferences_query.user_id)) + logging.info( + f"Preferences for user {preferences_query.user_id} does not exist, returnng default and making one in the background..." + ) + asyncio.create_task( + agent_manager.preferences_resolver.create_default_preferences( + preferences_query.user_id + ) + ) response = agent_manager.preferences_resolver.default_preferences end = time.time() - return {'response': response, 'elapsed_time': end - start} + return {"response": response, "elapsed_time": end - start} + -@app.post('/query_plan/') +@app.post("/query_plan/") async def writeQueryPlan(query_input: QueryPlanInput): result = queryplancache.get(query_input.conversation_id) if result is not None: - return {'response': result, 'elapsed_time': 0} - logging.info(f'Writing query plan for conversation_id {query_input.conversation_id}') - response, elapsed_time = await queryplan_manager.query_plan(agent_manager.preferences_resolver, query_input) - logging.info('Elapsed time for operation: %s', - elapsed_time) # log the elapsed time + return {"response": result, "elapsed_time": 0} + logging.info( + f"Writing query plan for conversation_id {query_input.conversation_id}" + ) + response, elapsed_time = await queryplan_manager.query_plan( + agent_manager.preferences_resolver, query_input + ) + logging.info("Elapsed time for operation: %s", elapsed_time) # log the elapsed time if response != "No plan needed": queryplancache[query_input.conversation_id] = response - return {'response': response, 'elapsed_time': elapsed_time} + return {"response": response, "elapsed_time": elapsed_time} -@app.post('/push_memory/') + +@app.post("/push_memory/") async def writeMemoryForUser(memory_output: MemoryOutput): """Endpoint to push memory for a specific user.""" - logging.info(f'Writing memory for user for user {memory_output.user_id}, conversation {memory_output.conversation_id}') + logging.info( + f"Writing memory for user for user {memory_output.user_id}, conversation {memory_output.conversation_id}" + ) elapsed_time = await agent_manager.push_memory(memory_output) - return {'elapsed_time': elapsed_time} + return {"elapsed_time": elapsed_time} + -@app.post('/pull_memory/') +@app.post("/pull_memory/") async def pullRelevantMemoriesForUser(memory_input: MemoryInput): """Endpoint to pull relevant memories for a specific user.""" - logging.info(f'Pulling relevant memories for user {memory_input.user_id}, conversation {memory_input.conversation_id}') + logging.info( + f"Pulling relevant memories for user {memory_input.user_id}, conversation {memory_input.conversation_id}" + ) memories, elapsed_time = await agent_manager.pull_memory(memory_input) - return {'response': memories, 'elapsed_time': elapsed_time} + return {"response": memories, "elapsed_time": elapsed_time} + -@app.post('/semantic_search_html/') +@app.post("/semantic_search_html/") async def semanticSearchHTML(function_input: HTMLInput): """Endpoint to conduct a semantic search in HTML content.""" result = searchhtmlcache.get(function_input) if result is not None: - return {'response': result, 'elapsed_time': 0} - logging.info('Semantic search HTML') + return {"response": result, "elapsed_time": 0} + logging.info("Semantic search HTML") results, elapsed_time = await web_manager.search_html(function_input) if len(results) > 0: searchhtmlcache[function_input] = results - return {'response': results, 'elapsed_time': elapsed_time} + return {"response": results, "elapsed_time": elapsed_time} + -@app.post('/is_html_search_cached/') +@app.post("/is_html_search_cached/") async def isHTMLSearchCached(cache_html: CacheHTML): """Endpoint to check if HTML content is cached.""" - logging.info('Checking if HTML results are cached') + logging.info("Checking if HTML results are cached") result, elapsed_time = web_manager.does_hash_exist(cache_html.hash) - return {'response': result, 'elapsed_time': elapsed_time} + return {"response": result, "elapsed_time": elapsed_time} -@app.post('/add_doc/') + +@app.post("/add_doc/") async def addDoc(function_input: DocAddInput): """Endpoint to conduct add HTML document for doc portal.""" - logging.info('add to Doc Portal') + logging.info("add to Doc Portal") results, elapsed_time = await doc_manager.add_doc(function_input) doccache.clear() - return {'response': results, 'elapsed_time': elapsed_time} + return {"response": results, "elapsed_time": elapsed_time} + -@app.post('/delete_doc/') +@app.post("/delete_doc/") async def deleteDoc(function_input: DocDeleteInput): """Endpoint to conduct delete HTML document from doc portal.""" - logging.info('delete from Doc Portal') + logging.info("delete from Doc Portal") results, elapsed_time = doc_manager.delete_doc(function_input) doccache.clear() - return {'response': results, 'elapsed_time': elapsed_time} + return {"response": results, "elapsed_time": elapsed_time} + -@app.post('/is_doc_cached/') +@app.post("/is_doc_cached/") async def isDocCached(function_input: CacheDoc): """Endpoint to check if doc content is cached.""" - logging.info('Checking if doc is cached') + logging.info("Checking if doc is cached") result, elapsed_time = doc_manager.does_source_exist(function_input) - return {'response': result, 'elapsed_time': elapsed_time} + return {"response": result, "elapsed_time": elapsed_time} -@app.post('/search_doc/') + +@app.post("/search_doc/") async def semanticSearchDoc(function_input: DocSearchInput): """Endpoint to conduct a semantic search in doc portal.""" result = doccache.get(function_input) if result is not None: - return {'response': result, 'elapsed_time': 0} - logging.info('Semantic search Doc Portal') + return {"response": result, "elapsed_time": 0} + logging.info("Semantic search Doc Portal") results, elapsed_time = await doc_manager.search_doc(function_input) if len(results) > 0: doccache[function_input] = results - return {'response': results, 'elapsed_time': elapsed_time} + return {"response": results, "elapsed_time": elapsed_time} + -@app.post('/get_functions/') +@app.post("/get_functions/") async def getFunctions(function_input: FunctionInput): """Endpoint to get functions based on provided input.""" result = functioncache.get(function_input) if result is not None: - logging.info(f'Found functions in cache, result {result}') - return {'response': result, 'elapsed_time': 0} - logging.info(f'Processing Action Item: {function_input.action_items}') + logging.info(f"Found functions in cache, result {result}") + return {"response": result, "elapsed_time": 0} + logging.info(f"Processing Action Item: {function_input.action_items}") result, elapsed_time = await functions_manager.pull_functions(function_input) if len(result) > 0: functioncache[function_input] = result - return {'response': result, 'elapsed_time': elapsed_time} + return {"response": result, "elapsed_time": elapsed_time} + -@app.post('/push_functions/') +@app.post("/push_functions/") async def pushFunctions(function_output: FunctionOutput): """Endpoint to push functions based on provided functions.""" - logging.info(f'Adding functions: {function_output.functions}') + logging.info(f"Adding functions: {function_output.functions}") functions = {} - function_types = ['information_retrieval', 'communication', 'data_processing', 'sensory_perception'] + function_types = [ + "information_retrieval", + "communication", + "data_processing", + "sensory_perception", + ] for function_item in function_output.functions: - function_item.category = function_item.category.lower().replace(' ', '_') + function_item.category = function_item.category.lower().replace(" ", "_") if function_item.category not in function_types: - return {'response': f'Invalid category for function {function_item.name}, must be one of {function_types}'} + return { + "response": f"Invalid category for function {function_item.name}, must be one of {function_types}" + } # Initialize category list if not already done if function_item.category not in functions: @@ -174,34 +221,39 @@ async def pushFunctions(function_output: FunctionOutput): # Append the new function to the category new_function = { - 'name': function_item.name, - 'description': function_item.description + "name": function_item.name, + "description": function_item.description, } functions[function_item.category].append(new_function) # Push the functions - result, elapsed_time = await functions_manager.push_functions(function_output.user_id, function_output.api_key, functions) - return {'response': result, 'elapsed_time': elapsed_time} + result, elapsed_time = await functions_manager.push_functions( + function_output.user_id, function_output.api_key, functions + ) + return {"response": result, "elapsed_time": elapsed_time} -@app.post('/clear_conversation/') + +@app.post("/clear_conversation/") async def clearUserMemory(clear_memory: ClearMemory): """Endpoint to clear memory for a specific user/conversation.""" logging.info( - f'Clearing user memory for user {clear_memory.user_id} and conversation {clear_memory.conversation_id}') + f"Clearing user memory for user {clear_memory.user_id} and conversation {clear_memory.conversation_id}" + ) response, elapsed_time = agent_manager.clear_conversation(clear_memory) - return {'response': response, 'elapsed_time': elapsed_time} + return {"response": response, "elapsed_time": elapsed_time} + -@app.post('/cache_clear/') +@app.post("/cache_clear/") async def clearCache(cache_clear_input: CacheClearInput): """Endpoint to clear caches.""" start = time.time() if not cache_clear_input.console_key.strip(): logging.warning("CacheManager: console key is empty, check settings!") - return {'response': "fail", 'elapsed_time': 0} + return {"response": "fail", "elapsed_time": 0} if CONSOLE_KEY != cache_clear_input.console_key: logging.warning("CacheManager: Invalid console key") - return {'response': "fail", 'elapsed_time': 0} + return {"response": "fail", "elapsed_time": 0} if {"doc", "all"} & set(cache_clear_input.cache_types): doccache.clear() if {"queryplan", "all"} & set(cache_clear_input.cache_types): @@ -211,4 +263,4 @@ async def clearCache(cache_clear_input: CacheClearInput): if {"function", "all"} & set(cache_clear_input.cache_types): searchhtmlcache.clear() end = time.time() - return {'response': "success", 'elapsed_time': end - start} \ No newline at end of file + return {"response": "success", "elapsed_time": end - start} diff --git a/memory_summarizer.py b/memory_summarizer.py index 137755d..9373d13 100644 --- a/memory_summarizer.py +++ b/memory_summarizer.py @@ -1,26 +1,34 @@ -import time import logging import os -import cachetools.func +import time +from typing import Any, Dict, List +import cachetools.func from dotenv import load_dotenv -from typing import Any, Dict, List -from document_summarizer import FlexibleDocumentSummarizer -from langchain_openai import ChatOpenAI +from langchain.retrievers import ContextualCompressionRetriever +from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_qdrant import Qdrant from qdrant_client.http import models as rest from qdrant_client.http.models import PayloadSchemaType -from langchain.retrievers import ContextualCompressionRetriever -from qdrant_retriever import QDrantVectorStoreRetriever + from cohere_rerank import CohereRerank -from langchain_openai import OpenAIEmbeddings -from generative_conversation_summarized_memory import GenerativeAgentConversationSummarizedMemory +from document_summarizer import FlexibleDocumentSummarizer +from generative_conversation_summarized_memory import ( + GenerativeAgentConversationSummarizedMemory, +) +from qdrant_retriever import QDrantVectorStoreRetriever class MemorySummarizer: flexible_document_summarizer: FlexibleDocumentSummarizer - def __init__(self, rate_limiter, rate_limiter_sync, flexible_document_summarizer, agent_manager): + def __init__( + self, + rate_limiter, + rate_limiter_sync, + flexible_document_summarizer, + agent_manager, + ): load_dotenv() # Load environment variables os.getenv("COHERE_API_KEY") self.QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") @@ -43,42 +51,62 @@ def create_new_conversation_summarizer(self, api_key: str, user_id: str): ), ) self.agent_manager.client.create_payload_index( - collection_name, "metadata.extra_index", field_schema=PayloadSchemaType.KEYWORD) + collection_name, + "metadata.extra_index", + field_schema=PayloadSchemaType.KEYWORD, + ) except: print("MemorySummarizer: loaded from cloud...") finally: logging.info( - f"MemorySummarizer: Creating memory store with collection {collection_name}") - vectorstore = Qdrant(self.agent_manager.client, collection_name, OpenAIEmbeddings( - model="text-embedding-3-small", openai_api_key=api_key)) + f"MemorySummarizer: Creating memory store with collection {collection_name}" + ) + vectorstore = Qdrant( + self.agent_manager.client, + collection_name, + OpenAIEmbeddings( + model="text-embedding-3-small", openai_api_key=api_key + ), + ) compressor = CohereRerank() compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever( - rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=collection_name, client=self.agent_manager.client, vectorstore=vectorstore, - ) + base_compressor=compressor, + base_retriever=QDrantVectorStoreRetriever( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + collection_name=collection_name, + client=self.agent_manager.client, + vectorstore=vectorstore, + ), ) return compression_retriever def create_summarized_memory(self, api_key: str, user_id: str): return GenerativeAgentConversationSummarizedMemory( rate_limiter=self.rate_limiter, - llm=ChatOpenAI(openai_api_key=api_key, temperature=0, - max_tokens=2048, model="gpt-4.1-mini"), - memory_retriever=self.create_new_conversation_summarizer( - api_key, user_id), - verbose=self.agent_manager.verbose + llm=ChatOpenAI( + openai_api_key=api_key, + temperature=0, + max_tokens=2048, + model="gpt-4.1-mini", + ), + memory_retriever=self.create_new_conversation_summarizer(api_key, user_id), + verbose=self.agent_manager.verbose, ) @cachetools.func.ttl_cache(maxsize=16384, ttl=36000) - def load(self, api_key: str, user_id: str) -> GenerativeAgentConversationSummarizedMemory: + def load( + self, api_key: str, user_id: str + ) -> GenerativeAgentConversationSummarizedMemory: """Load existing index data from the cloud.""" start = time.time() retriever = self.create_summarized_memory(api_key, user_id) end = time.time() - logging.info( - f"MemorySummarizer: Load operation took {end - start} seconds") + logging.info(f"MemorySummarizer: Load operation took {end - start} seconds") return retriever - async def save(self, api_key: str, user_id: str, outputs: Dict[str, Any]) -> List[str]: + async def save( + self, api_key: str, user_id: str, outputs: Dict[str, Any] + ) -> List[str]: memory = self.load(api_key, user_id) await memory.save_context(outputs) diff --git a/preferences_resolver.py b/preferences_resolver.py index f806dda..52fc54f 100644 --- a/preferences_resolver.py +++ b/preferences_resolver.py @@ -1,15 +1,15 @@ - import logging import os import traceback +from asyncio import Lock -from motor.motor_asyncio import AsyncIOMotorClient -from pymongo.server_api import ServerApi from dotenv import load_dotenv from jsonpatch import JsonPatch, JsonPatchException +from motor.motor_asyncio import AsyncIOMotorClient from pydantic import BaseModel +from pymongo.server_api import ServerApi + from rate_limiter import RateLimiter -from asyncio import Lock class QueryPreferencesInput(BaseModel): @@ -25,33 +25,27 @@ def __init__(self): self.role_collection = None self.rate_limiter = None self.schema = { - 'name_nickname': "", - 'traits': [], - 'achievements': [], - 'mood_feelings': [], - 'goals': [], - 'tasks': [], - 'subtasks': [], - 'active_task_id': '', - 'active_subtask_id': '', - 'facts_opinions': [], - 'interests': [], - 'links': [], - 'skills': [], - 'occupations': [], - 'communication': { - 'data_sharing': { - 'preferences': True, - 'history': False + "name_nickname": "", + "traits": [], + "achievements": [], + "mood_feelings": [], + "goals": [], + "tasks": [], + "subtasks": [], + "active_task_id": "", + "active_subtask_id": "", + "facts_opinions": [], + "interests": [], + "links": [], + "skills": [], + "occupations": [], + "communication": { + "data_sharing": {"preferences": True, "history": False}, + "engagement": { + "contact_methods": ["text", "voice", "video"], + "DND": {"enabled": False, "times": "22:00-06:00"}, }, - 'engagement': { - 'contact_methods': ['text', 'voice', 'video'], - 'DND': { - 'enabled': False, - 'times': '22:00-06:00' - } - } - } + }, } self.default_preferences = None self.init_lock = Lock() @@ -61,60 +55,84 @@ async def initialize(self): if self.client is not None: return try: - self.client = AsyncIOMotorClient( - self.uri, server_api=ServerApi('1')) - await self.client.admin.command('ping') + self.client = AsyncIOMotorClient(self.uri, server_api=ServerApi("1")) + await self.client.admin.command("ping") print("Pinged your deployment. You successfully connected to MongoDB!") # Setup references after successful connection - self.db = self.client['PreferencesDB'] - self.pref_collection = self.db['Preferences'] - self.role_collection = self.db['Roles'] + self.db = self.client["PreferencesDB"] + self.pref_collection = self.db["Preferences"] + self.role_collection = self.db["Roles"] self.rate_limiter = RateLimiter(rate=10, period=1) self.default_preferences = self.schema except Exception as e: logging.warning( - f"PreferencesResolver: initialize exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: initialize exception {e}\n{traceback.format_exc()}" + ) async def get_preferences(self, user_id): - if self.client is None or self.pref_collection is None or self.rate_limiter is None: + if ( + self.client is None + or self.pref_collection is None + or self.rate_limiter is None + ): await self.initialize() try: - doc = await self.rate_limiter.execute(self.pref_collection.find_one, {"_id": user_id}) + doc = await self.rate_limiter.execute( + self.pref_collection.find_one, {"_id": user_id} + ) if doc is None: await self.create_default_preferences(user_id) return self.default_preferences return doc except Exception as e: logging.warning( - f"PreferencesResolver: get_preferences exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: get_preferences exception {e}\n{traceback.format_exc()}" + ) return None async def get_role(self, conversation_id): - if self.client is None or self.role_collection is None or self.rate_limiter is None: + if ( + self.client is None + or self.role_collection is None + or self.rate_limiter is None + ): await self.initialize() try: - roleObj = await self.rate_limiter.execute(self.role_collection.find_one, {"_id": conversation_id}) + roleObj = await self.rate_limiter.execute( + self.role_collection.find_one, {"_id": conversation_id} + ) if roleObj is not None: return roleObj["role"] else: return None except Exception as e: logging.warning( - f"PreferencesResolver: get_role exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: get_role exception {e}\n{traceback.format_exc()}" + ) return None async def set_role(self, role, conversation_id): - if self.client is None or self.role_collection is None or self.rate_limiter is None: + if ( + self.client is None + or self.role_collection is None + or self.rate_limiter is None + ): await self.initialize() try: roleObj = {"_id": conversation_id, "role": role} - update_result = await self.rate_limiter.execute(self.role_collection.update_one, {"_id": conversation_id}, {"$set": roleObj}, upsert=True) + update_result = await self.rate_limiter.execute( + self.role_collection.update_one, + {"_id": conversation_id}, + {"$set": roleObj}, + upsert=True, + ) if update_result.matched_count == 0 and update_result.upserted_id is None: logging.warning("No documents were inserted or updated.") except Exception as e: logging.warning( - f"PreferencesResolver: set_role exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: set_role exception {e}\n{traceback.format_exc()}" + ) return "failure" return "success" @@ -122,33 +140,45 @@ def get_schema(self): return self.schema async def create_default_preferences(self, user_id): - if self.client is None or self.pref_collection is None or self.rate_limiter is None: + if ( + self.client is None + or self.pref_collection is None + or self.rate_limiter is None + ): await self.initialize() try: - await self.rate_limiter.execute(self.pref_collection.insert_one, { - '_id': user_id, - **self.default_preferences - }) + await self.rate_limiter.execute( + self.pref_collection.insert_one, + {"_id": user_id, **self.default_preferences}, + ) except Exception as e: logging.warning( - f"PreferencesResolver: create_default_preferences exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: create_default_preferences exception {e}\n{traceback.format_exc()}" + ) def check_for_nested_duplicates(self, value, target): if isinstance(target, list): res = value in target return res elif isinstance(target, dict): - return any(self.check_for_nested_duplicates(value, sub_value) for sub_value in target.values()) + return any( + self.check_for_nested_duplicates(value, sub_value) + for sub_value in target.values() + ) else: return False async def apply_patch(self, user_id, doc, patch_data): - if self.client is None or self.pref_collection is None or self.rate_limiter is None: + if ( + self.client is None + or self.pref_collection is None + or self.rate_limiter is None + ): await self.initialize() # Make sure keys exist before applying patch for patch in patch_data: if patch["op"] in ["add", "replace"]: - keys = patch["path"].lstrip('/').split('/') + keys = patch["path"].lstrip("/").split("/") temp_doc = doc for i, key in enumerate(keys[:-1]): if isinstance(temp_doc, list): @@ -174,24 +204,27 @@ async def apply_patch(self, user_id, doc, patch_data): # Check for nested duplicates if patch["op"] == "add": if self.check_for_nested_duplicates(patch["value"], temp_doc): - logging.warning( - f"Duplicate patch {patch}, skipping...") + logging.warning(f"Duplicate patch {patch}, skipping...") continue # Check type and validity - if isinstance(temp_doc, list) and not keys[i + 1].isdigit() and keys[i + 1] != '-': - return f'Error: List indices must be integers or slices, not str. Patch {patch}' + if ( + isinstance(temp_doc, list) + and not keys[i + 1].isdigit() + and keys[i + 1] != "-" + ): + return f"Error: List indices must be integers or slices, not str. Patch {patch}" elif isinstance(temp_doc, dict) and keys[i + 1].isdigit(): - return f'Error: Dictionary keys must be strings, not integers. Patch {patch}' + return f"Error: Dictionary keys must be strings, not integers. Patch {patch}" # Check the final nested key last_key = keys[-1] if isinstance(temp_doc, list) and last_key.isdigit(): if int(last_key) >= len(temp_doc): return f"Error: Key '{last_key}' does not exist in the document. Patch {patch}" - elif isinstance(temp_doc, list) and last_key != '-': + elif isinstance(temp_doc, list) and last_key != "-": return f'Error: List indices must be integers or "-", not str. Patch {patch}' elif isinstance(temp_doc, dict) and last_key.isdigit(): - return f'Error: Dictionary keys must be strings, not integers. Patch {patch}' + return f"Error: Dictionary keys must be strings, not integers. Patch {patch}" # Apply the patch try: @@ -203,10 +236,15 @@ async def apply_patch(self, user_id, doc, patch_data): return f"An unknown exception occurred: {e}" try: # Update the database - update_result = await self.rate_limiter.execute(self.pref_collection.update_one, {"_id": user_id}, {"$set": modified_doc}) + update_result = await self.rate_limiter.execute( + self.pref_collection.update_one, + {"_id": user_id}, + {"$set": modified_doc}, + ) if update_result.modified_count == 0: logging.warning("No documents were updated.") except Exception as e: logging.warning( - f"PreferencesResolver: update_one exception {e}\n{traceback.format_exc()}") + f"PreferencesResolver: update_one exception {e}\n{traceback.format_exc()}" + ) return "success" diff --git a/preferences_updater.py b/preferences_updater.py index 462606f..c46f878 100644 --- a/preferences_updater.py +++ b/preferences_updater.py @@ -1,13 +1,13 @@ - -import traceback -import logging import json +import logging import re +import traceback +from typing import List -from langchain.schema import SystemMessage, HumanMessage, AIMessage +from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI + from preferences_resolver import PreferencesResolver -from typing import List class SystemPrompt: @@ -56,63 +56,79 @@ class PreferencesUpdater: _preferences_resolver: PreferencesResolver verbose: bool - def __init__(self, preferences_resolver: PreferencesResolver, verbose: bool = False) -> None: + def __init__( + self, preferences_resolver: PreferencesResolver, verbose: bool = False + ) -> None: self._preferences_resolver = preferences_resolver self.verbose = verbose - async def _get_json_patch_commands( - self, messages, llm: ChatOpenAI - ) -> List[str]: + async def _get_json_patch_commands(self, messages, llm: ChatOpenAI) -> List[str]: """Generate 'preference updates', based on pertinent memories.""" array_json = [] try: response = await llm.agenerate(messages) if not response.generations or not response.generations[0]: - raise Exception( - "LLM did not provide a valid summary response.") + raise Exception("LLM did not provide a valid summary response.") result = response.generations[0][0].text # Find the array in the output string using a regular expression - array_match = re.search( - r'OPS:\s*\[\s*(\{.*\})\s*\]', result, re.DOTALL) + array_match = re.search(r"OPS:\s*\[\s*(\{.*\})\s*\]", result, re.DOTALL) if array_match: - array_str = '[' + array_match.group(1) + ']' - array_str = array_str.replace( - "True", "true").replace("False", "false") + array_str = "[" + array_match.group(1) + "]" + array_str = array_str.replace("True", "true").replace("False", "false") # Parse the array string as JSON array_json = json.loads(array_str) except Exception as e: if self.verbose: logging.warning( - f"PreferencesUpdater: _get_json_patch_commands exception, e: {e}\n{traceback.format_exc()}") + f"PreferencesUpdater: _get_json_patch_commands exception, e: {e}\n{traceback.format_exc()}" + ) return array_json - async def update_preferences(self, llm: ChatOpenAI, user: str, ai: str, user_id: str): + async def update_preferences( + self, llm: ChatOpenAI, user: str, ai: str, user_id: str + ): """Reflect on recent observations and generate 'insights'.""" doc = await self._preferences_resolver.get_preferences(user_id) if doc is None: logging.warning( - f"PreferencesUpdater: No preferences found for user {user_id}") + f"PreferencesUpdater: No preferences found for user {user_id}" + ) return summary_prompt = SystemPrompt(doc) - messages = [[SystemMessage(content=summary_prompt.to_prompt_string()), - HumanMessage(content=user), - AIMessage(content=ai)]] + messages = [ + [ + SystemMessage(content=summary_prompt.to_prompt_string()), + HumanMessage(content=user), + AIMessage(content=ai), + ] + ] patch_commands = await self._get_json_patch_commands(messages, llm) if len(patch_commands) > 0: if self.verbose: logging.info("AiDA is trying to update preferences") - response = await self._preferences_resolver.apply_patch(user_id, doc, patch_commands) + response = await self._preferences_resolver.apply_patch( + user_id, doc, patch_commands + ) if response != "success": summary_prompt = SystemPrompt( - doc, "2. You have been given human feedback that your changes were not accepted due to syntax, you are to carefully analyze and respond with the correct OPS") - messages = [[SystemMessage(content=summary_prompt.to_prompt_string()), - HumanMessage(content=user), - AIMessage(content=ai), - HumanMessage(content=response)]] + doc, + "2. You have been given human feedback that your changes were not accepted due to syntax, you are to carefully analyze and respond with the correct OPS", + ) + messages = [ + [ + SystemMessage(content=summary_prompt.to_prompt_string()), + HumanMessage(content=user), + AIMessage(content=ai), + HumanMessage(content=response), + ] + ] patch_commands = await self._get_json_patch_commands(messages, llm) - response = await self._preferences_resolver.apply_patch(user_id, doc, patch_commands) + response = await self._preferences_resolver.apply_patch( + user_id, doc, patch_commands + ) if response != "success" and self.verbose: logging.warning( - f"PreferencesUpdater: preferences_resolver patch application failed: {response}") + f"PreferencesUpdater: preferences_resolver patch application failed: {response}" + ) diff --git a/qdrant_retriever.py b/qdrant_retriever.py index 41f387f..56b701e 100644 --- a/qdrant_retriever.py +++ b/qdrant_retriever.py @@ -1,18 +1,19 @@ # Importing necessary libraries and modules -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum -from langchain.schema import BaseRetriever, Document -from qdrant_client import QdrantClient -from qdrant_client.http import models as rest -from datetime import timedelta -from langchain_qdrant import Qdrant -from rate_limiter import RateLimiter, SyncRateLimiter from typing import ( List, Optional, Tuple, ) +from langchain.schema import BaseRetriever, Document +from langchain_qdrant import Qdrant +from qdrant_client import QdrantClient +from qdrant_client.http import models as rest + +from rate_limiter import RateLimiter, SyncRateLimiter + class MemoryType(Enum): CONSCIOUS_MEMORY = 0 @@ -21,6 +22,7 @@ class MemoryType(Enum): class QDrantVectorStoreRetriever(BaseRetriever): """Retriever that combines embedding similarity with conversation matching scores in retrieving values.""" + rate_limiter: RateLimiter rate_limiter_sync: SyncRateLimiter collection_name: str @@ -40,27 +42,34 @@ class QDrantVectorStoreRetriever(BaseRetriever): class Config: """Configuration for this pydantic object.""" + arbitrary_types_allowed = True def _get_combined_score( self, document: Document, vector_relevance: Optional[float], - extra_index: str = None + extra_index: str = None, ) -> float: """Return the combined score for a document.""" score = 0 if vector_relevance is not None: score += vector_relevance - if extra_index is not None and extra_index != document.metadata.get("extra_index"): + if extra_index is not None and extra_index != document.metadata.get( + "extra_index" + ): score -= self.extra_index_penalty if document.metadata.get("memory_type") == MemoryType.SUBCONSCIOUS_MEMORY: score -= self.subconscious_memory_penalty return score - async def get_salient_docs(self, query: str, **kwargs) -> List[Tuple[Document, float]]: + async def get_salient_docs( + self, query: str, **kwargs + ) -> List[Tuple[Document, float]]: """Return documents that are salient to the query.""" - return await self.rate_limiter.execute(self.vectorstore.asimilarity_search_with_score, query, k=10, **kwargs) + return await self.rate_limiter.execute( + self.vectorstore.asimilarity_search_with_score, query, k=10, **kwargs + ) async def get_relevant_documents_for_reflection( self, query: str, conversation: str, **kwargs @@ -79,10 +88,12 @@ async def get_relevant_documents_for_reflection( docs_and_scores = await self.get_salient_docs(query, **kwargs) rescored_docs = [] for doc, relevance in docs_and_scores: - combined_score = self._get_combined_score( - doc, relevance, conversation) + combined_score = self._get_combined_score(doc, relevance, conversation) # Skip the document if it matches the given query, and conversation - if doc.page_content == query and doc.metadata["extra_index"] == conversation: + if ( + doc.page_content == query + and doc.metadata["extra_index"] == conversation + ): continue # Skip to the next iteration rescored_docs.append((doc, combined_score)) rescored_docs.sort(key=lambda x: x[1], reverse=True) @@ -108,22 +119,29 @@ def get_documents_for_summarization(self) -> List[Document]: rest.FieldCondition( key="metadata.summarizations", range=rest.Range(lt=self._max_summarizations), - ) + ), ] ) results, _ = self.rate_limiter_sync.execute( - self.client.scroll, collection_name=self.collection_name, scroll_filter=filter, limit=5000) + self.client.scroll, + collection_name=self.collection_name, + scroll_filter=filter, + limit=5000, + ) docs = [] for record in results: document = self.vectorstore._document_from_scored_point( - record, self.collection_name, self.vectorstore.content_payload_key, self.vectorstore.metadata_payload_key + record, + self.collection_name, + self.vectorstore.content_payload_key, + self.vectorstore.metadata_payload_key, ) # Increment the summarizations count - if 'summarizations' in document.metadata: - document.metadata['summarizations'] += 1 + if "summarizations" in document.metadata: + document.metadata["summarizations"] += 1 else: - document.metadata['summarizations'] = 1 + document.metadata["summarizations"] = 1 # summarize once every 2 weeks document.metadata["last_accessed_at"] = current_time.timestamp() docs.append(document) @@ -132,9 +150,7 @@ def get_documents_for_summarization(self) -> List[Document]: def _get_relevant_documents(self, *args, **kwargs): pass - async def _aget_relevant_documents( - self, query: str, **kwargs - ) -> List[Document]: + async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]: """Return documents that are relevant to the query.""" current_time = datetime.now().timestamp() extra_index = kwargs.pop("extra_index", None) @@ -150,12 +166,13 @@ async def _aget_relevant_documents( for doc, _ in rescored_docs: doc.metadata["last_accessed_at"] = current_time # Decrement the summarizations count - if 'summarizations' in doc.metadata and doc.metadata['summarizations'] > 0: - doc.metadata['summarizations'] -= 1 + if "summarizations" in doc.metadata and doc.metadata["summarizations"] > 0: + doc.metadata["summarizations"] -= 1 # Sort by score and extract just the documents - sorted_docs = [doc for doc, _ in sorted( - rescored_docs, key=lambda x: x[1], reverse=True)] + sorted_docs = [ + doc for doc, _ in sorted(rescored_docs, key=lambda x: x[1], reverse=True) + ] # Return just the list of Documents return sorted_docs @@ -170,10 +187,17 @@ def get_key_value_document(self, key, value) -> Document: ] ) record, _ = self.rate_limiter_sync.execute( - self.client.scroll, collection_name=self.collection_name, scroll_filter=filter, limit=1) + self.client.scroll, + collection_name=self.collection_name, + scroll_filter=filter, + limit=1, + ) if record is not None and len(record) > 0: return self.vectorstore._document_from_scored_point( - record[0], self.collection_name, self.vectorstore.content_payload_key, self.vectorstore.metadata_payload_key + record[0], + self.collection_name, + self.vectorstore.content_payload_key, + self.vectorstore.metadata_payload_key, ) else: return None @@ -189,4 +213,7 @@ def delete_max_summarized(self): ] ) self.rate_limiter_sync.execute( - self.client.delete, collection_name=self.collection_name, points_selector=filter) + self.client.delete, + collection_name=self.collection_name, + points_selector=filter, + ) diff --git a/queryplan_manager.py b/queryplan_manager.py index cc72db1..9579733 100644 --- a/queryplan_manager.py +++ b/queryplan_manager.py @@ -1,13 +1,14 @@ -import time -import logging import asyncio +import logging +import time import traceback +from langchain.schema import HumanMessage, SystemMessage from langchain_openai import ChatOpenAI from pydantic import BaseModel -from langchain.schema import SystemMessage, HumanMessage -from preferences_resolver import PreferencesResolver + from classify_prompts import ClassifyPrompts +from preferences_resolver import PreferencesResolver class QueryPlanInput(BaseModel): @@ -22,29 +23,40 @@ class QueryPlanManager: def __init__(self): self.classify_prompts = ClassifyPrompts() - async def query_plan(self, preferences_resolver: PreferencesResolver, query_input: QueryPlanInput): + async def query_plan( + self, preferences_resolver: PreferencesResolver, query_input: QueryPlanInput + ): start = time.time() roleDB = await preferences_resolver.get_role(query_input.conversation_id) if roleDB is None: try: - messages = [[SystemMessage(content=self.classify_prompts.to_prompt_string()), - HumanMessage(content=query_input.query)]] - llm = ChatOpenAI(model='gpt-4.1-mini', temperature=0, - max_tokens=8, openai_api_key=query_input.api_key) + messages = [ + [ + SystemMessage(content=self.classify_prompts.to_prompt_string()), + HumanMessage(content=query_input.query), + ] + ] + llm = ChatOpenAI( + model="gpt-4.1-mini", + temperature=0, + max_tokens=8, + openai_api_key=query_input.api_key, + ) response = await llm.agenerate(messages) if not response.generations or not response.generations[0]: - raise Exception( - "LLM did not provide a valid summary response.") + raise Exception("LLM did not provide a valid summary response.") result = response.generations[0][0].text role = self.classify_prompts.parseClassification(result) if role is None: end = time.time() return "No plan needed", {end - start} - asyncio.create_task(preferences_resolver.set_role( - result, query_input.conversation_id)) + asyncio.create_task( + preferences_resolver.set_role(result, query_input.conversation_id) + ) except Exception as e: logging.warning( - f"QueryPlanManager: query_plan exception, e: {e}\n{traceback.format_exc()}") + f"QueryPlanManager: query_plan exception, e: {e}\n{traceback.format_exc()}" + ) end = time.time() return "No plan needed", {end - start} else: @@ -54,5 +66,6 @@ async def query_plan(self, preferences_resolver: PreferencesResolver, query_inpu return "No plan needed", {end - start} end = time.time() logging.info( - f"QueryPlanManager: query_plan operation took {end - start} seconds") + f"QueryPlanManager: query_plan operation took {end - start} seconds" + ) return role, {end - start} diff --git a/rate_limiter.py b/rate_limiter.py index 03411a8..f1732ae 100644 --- a/rate_limiter.py +++ b/rate_limiter.py @@ -1,6 +1,7 @@ import asyncio import threading + class RateLimiter: def __init__(self, rate: int, period: int, retries: int = 3): self.rate = rate @@ -31,7 +32,8 @@ async def execute(self, task, *args, **kwargs): await asyncio.sleep(self.period) else: raise e from None - + + class SyncRateLimiter: def __init__(self, rate: int, period: int, max_retries: int = 3): self.rate = rate @@ -72,4 +74,4 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - pass \ No newline at end of file + pass diff --git a/reader_writer_lock.py b/reader_writer_lock.py index 057cb5d..bddd557 100644 --- a/reader_writer_lock.py +++ b/reader_writer_lock.py @@ -1,5 +1,6 @@ import threading + class ReaderWriterLock: def __init__(self): self._read_lock = threading.Lock() diff --git a/tests/agent_manager_test.py b/tests/agent_manager_test.py index 9b6cdcf..d6f58ef 100644 --- a/tests/agent_manager_test.py +++ b/tests/agent_manager_test.py @@ -1,70 +1,91 @@ -import pytest from unittest.mock import AsyncMock, patch + +import pytest +from langchain.retrievers.contextual_compression import ContextualCompressionRetriever +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain.schema import BaseRetriever +from langchain.schema.language_model import BaseLanguageModel +from pydantic import BaseModel, Field + from agent_manager import AgentManager, MemoryOutput +from document_summarizer import FlexibleDocumentSummarizer from generative_memory import GenerativeAgentMemory -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema import BaseRetriever -from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from memory_summarizer import MemorySummarizer -from document_summarizer import FlexibleDocumentSummarizer -from pydantic import BaseModel, Field -from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from rate_limiter import RateLimiter, SyncRateLimiter + rate_limiter = RateLimiter(rate=5, period=1) # Allow 5 tasks per second rate_limiter_sync = SyncRateLimiter(rate=5, period=1) + class MockLanguageModel(BaseLanguageModel): async def agenerate_prompt(self, *args, **kwargs): pass + async def apredict(self, *args, **kwargs): pass + async def apredict_messages(self, *args, **kwargs): pass + def generate_prompt(self, *args, **kwargs): pass + def invoke(self, *args, **kwargs): pass + def predict(self, *args, **kwargs): pass + def predict_messages(self, *args, **kwargs): pass class Config: arbitrary_types_allowed = True + class MockBaseDocumentCompressor(BaseDocumentCompressor): async def acompress_documents(self, *args, **kwargs): pass + def compress_documents(self, *args, **kwargs): pass + class MockVectorStore: aadd_documents = AsyncMock() + class MockBaseRetriever(BaseRetriever): vectorstore: MockVectorStore = MockVectorStore() def _get_relevant_documents(self, *args, **kwargs): pass + class MockMemoryRetriever(ContextualCompressionRetriever): - base_compressor: MockBaseDocumentCompressor = Field(default_factory=MockBaseDocumentCompressor) + base_compressor: MockBaseDocumentCompressor = Field( + default_factory=MockBaseDocumentCompressor + ) base_retriever: MockBaseRetriever = Field(default_factory=MockBaseRetriever) class Config: arbitrary_types_allowed = True + class MockFlexibleDocumentSummarizer(FlexibleDocumentSummarizer): pass + class MockMemorySummarizer(MemorySummarizer): pass + @pytest.fixture def setup_agent_manager(): agent_manager = AgentManager(rate_limiter, rate_limiter_sync) return agent_manager + @pytest.mark.asyncio async def test_push_memory(setup_agent_manager): agent_manager = setup_agent_manager @@ -74,35 +95,48 @@ async def test_push_memory(setup_agent_manager): query="test_query", llm_response="test_llm_response", conversation_id="test_conversation_id", - importance="high" + importance="high", ) - with patch.object(agent_manager, 'load', return_value=GenerativeAgentMemory( - rate_limiter=rate_limiter, - llm=MockLanguageModel(), - memory_retriever=MockMemoryRetriever(), - memory_summarizer=MockMemorySummarizer( + with patch.object( + agent_manager, + "load", + return_value=GenerativeAgentMemory( rate_limiter=rate_limiter, - rate_limiter_sync=rate_limiter_sync, - flexible_document_summarizer=MockFlexibleDocumentSummarizer( - llm=MockLanguageModel() + llm=MockLanguageModel(), + memory_retriever=MockMemoryRetriever(), + memory_summarizer=MockMemorySummarizer( + rate_limiter=rate_limiter, + rate_limiter_sync=rate_limiter_sync, + flexible_document_summarizer=MockFlexibleDocumentSummarizer( + llm=MockLanguageModel() + ), + agent_manager=setup_agent_manager, ), - agent_manager=setup_agent_manager + verbose=True, ), - verbose=True - )) as mock_load: + ) as mock_load: result = await agent_manager.push_memory(memory_output) assert isinstance(result, float) mock_load.assert_called_once_with(memory_output.api_key, memory_output.user_id) + def test_create_new_memory_retriever(setup_agent_manager): agent_manager = setup_agent_manager - with patch.object(agent_manager.client, 'create_collection') as mock_create_collection, \ - patch.object(agent_manager.client, 'create_payload_index') as mock_create_payload_index: - result = agent_manager.create_new_memory_retriever("test_api_key", "test_user_id") + with ( + patch.object( + agent_manager.client, "create_collection" + ) as mock_create_collection, + patch.object( + agent_manager.client, "create_payload_index" + ) as mock_create_payload_index, + ): + result = agent_manager.create_new_memory_retriever( + "test_api_key", "test_user_id" + ) assert isinstance(result, ContextualCompressionRetriever) mock_create_collection.assert_called_once() - mock_create_payload_index.assert_called_once() \ No newline at end of file + mock_create_payload_index.assert_called_once() diff --git a/tests/conftest.py b/tests/conftest.py index ba95da3..94acc1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ """Test configuration and fixtures for SuperDappAI tests.""" import os +from unittest.mock import MagicMock, Mock, patch + import pytest -from unittest.mock import Mock, patch, MagicMock + from rate_limiter import RateLimiter, SyncRateLimiter @@ -12,7 +14,7 @@ def rate_limiter(): return RateLimiter(rate=5, period=1) -@pytest.fixture +@pytest.fixture def rate_limiter_sync(): """Provide a sync rate limiter for tests.""" return SyncRateLimiter(rate=5, period=1) @@ -21,39 +23,38 @@ def rate_limiter_sync(): @pytest.fixture def mock_openai_api_key(): """Mock OpenAI API key for tests.""" - with patch.dict(os.environ, {'OPENAI_API_KEY': 'test-api-key'}): - yield 'test-api-key' + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-api-key"}): + yield "test-api-key" @pytest.fixture def mock_cohere_api_key(): """Mock Cohere API key for tests.""" - with patch.dict(os.environ, {'COHERE_API_KEY': 'test-cohere-key'}): - yield 'test-cohere-key' + with patch.dict(os.environ, {"COHERE_API_KEY": "test-cohere-key"}): + yield "test-cohere-key" @pytest.fixture def mock_qdrant_config(): """Mock Qdrant configuration for tests.""" - with patch.dict(os.environ, { - 'QDRANT_API_KEY': 'test-qdrant-key', - 'QDRANT_URL': 'http://localhost:6333' - }): - yield { - 'api_key': 'test-qdrant-key', - 'url': 'http://localhost:6333' - } + with patch.dict( + os.environ, + {"QDRANT_API_KEY": "test-qdrant-key", "QDRANT_URL": "http://localhost:6333"}, + ): + yield {"api_key": "test-qdrant-key", "url": "http://localhost:6333"} @pytest.fixture def mock_mongodb_url(): """Mock MongoDB URL for tests.""" - with patch.dict(os.environ, {'MONGODB_URL': 'mongodb://localhost:27017/test'}): - yield 'mongodb://localhost:27017/test' + with patch.dict(os.environ, {"MONGODB_URL": "mongodb://localhost:27017/test"}): + yield "mongodb://localhost:27017/test" @pytest.fixture -def mock_all_env_vars(mock_openai_api_key, mock_cohere_api_key, mock_qdrant_config, mock_mongodb_url): +def mock_all_env_vars( + mock_openai_api_key, mock_cohere_api_key, mock_qdrant_config, mock_mongodb_url +): """Mock all required environment variables for tests.""" pass @@ -84,4 +85,4 @@ def mock_vectorstore(): mock.add_documents = Mock() mock.aadd_documents = Mock() mock.search = Mock(return_value=[]) - return mock \ No newline at end of file + return mock diff --git a/tests/generative_memory_test.py b/tests/generative_memory_test.py index 4adb3e7..f8244f0 100644 --- a/tests/generative_memory_test.py +++ b/tests/generative_memory_test.py @@ -1,49 +1,64 @@ -import pytest -from unittest.mock import AsyncMock from datetime import datetime -from generative_memory import GenerativeAgentMemory, MemoryType +from unittest.mock import AsyncMock + +import pytest from langchain.base_language import BaseLanguageModel -from memory_summarizer import MemorySummarizer -from document_summarizer import FlexibleDocumentSummarizer -from agent_manager import AgentManager -from pydantic import BaseModel -from langchain.schema.retriever import BaseRetriever -from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain.schema.retriever import BaseRetriever +from pydantic import BaseModel + +from agent_manager import AgentManager +from document_summarizer import FlexibleDocumentSummarizer +from generative_memory import GenerativeAgentMemory, MemoryType +from memory_summarizer import MemorySummarizer from rate_limiter import RateLimiter, SyncRateLimiter + rate_limiter = RateLimiter(rate=5, period=1) # Allow 5 tasks per second rate_limiter_sync = SyncRateLimiter(rate=5, period=1) from unittest.mock import AsyncMock + from pydantic import Field + class MockLanguageModel(BaseLanguageModel): async def agenerate_prompt(self, *args, **kwargs): pass + async def apredict(self, *args, **kwargs): pass + async def apredict_messages(self, *args, **kwargs): pass + def generate_prompt(self, *args, **kwargs): pass + def invoke(self, *args, **kwargs): pass + def predict(self, *args, **kwargs): pass + def predict_messages(self, *args, **kwargs): pass class Config: arbitrary_types_allowed = True + class MockFlexibleDocumentSummarizer(FlexibleDocumentSummarizer): asummarize = AsyncMock() def __init__(self): super().__init__(llm=MockLanguageModel()) + + class MockAgentManager(AgentManager): pass + class MockMemorySummarizer(MemorySummarizer): save = AsyncMock() asummarize = AsyncMock() @@ -53,29 +68,39 @@ def __init__(self): rate_limiter=rate_limiter, rate_limiter_sync=rate_limiter_sync, flexible_document_summarizer=MockFlexibleDocumentSummarizer(), - agent_manager=MockAgentManager(rate_limiter, rate_limiter_sync) + agent_manager=MockAgentManager(rate_limiter, rate_limiter_sync), ) + class MockBaseDocumentCompressor(BaseDocumentCompressor): async def acompress_documents(self, *args, **kwargs): pass + def compress_documents(self, *args, **kwargs): pass + class MockVectorStore: aadd_documents = AsyncMock() + class MockBaseRetriever(BaseRetriever): vectorstore: MockVectorStore = MockVectorStore() def _get_relevant_documents(self, *args, **kwargs): pass + + class MockMemoryRetriever(ContextualCompressionRetriever): - base_compressor: MockBaseDocumentCompressor = Field(default_factory=MockBaseDocumentCompressor) + base_compressor: MockBaseDocumentCompressor = Field( + default_factory=MockBaseDocumentCompressor + ) base_retriever: MockBaseRetriever = Field(default_factory=MockBaseRetriever) class Config: arbitrary_types_allowed = True + + @pytest.fixture def setup_generative_agent_memory(): llm = MockLanguageModel() @@ -86,10 +111,11 @@ def setup_generative_agent_memory(): llm=llm, memory_retriever=memory_retriever, memory_summarizer=memory_summarizer, - verbose=True + verbose=True, ) return generative_agent_memory + @pytest.mark.asyncio async def test_add_memory(setup_generative_agent_memory): generative_agent_memory = setup_generative_agent_memory @@ -100,13 +126,16 @@ async def test_add_memory(setup_generative_agent_memory): timestamp = datetime.now() # Mock the aadd_documents method to return a specific result - generative_agent_memory.memory_retriever.base_retriever.vectorstore.aadd_documents.return_value = ["mock_id"] + generative_agent_memory.memory_retriever.base_retriever.vectorstore.aadd_documents.return_value = [ + "mock_id" + ] - result = await generative_agent_memory.add_memory(memory_content, conversation_id, importance, memory_type, now=timestamp) + result = await generative_agent_memory.add_memory( + memory_content, conversation_id, importance, memory_type, now=timestamp + ) # Check that the method was called with the correct arguments generative_agent_memory.memory_retriever.base_retriever.vectorstore.aadd_documents.assert_called_once() # Check that the result is as expected assert result == ["mock_id"] - diff --git a/tests/qdrant_retriever_test.py b/tests/qdrant_retriever_test.py index d3f8749..a1270d3 100644 --- a/tests/qdrant_retriever_test.py +++ b/tests/qdrant_retriever_test.py @@ -1,15 +1,18 @@ import os +from datetime import datetime + import pytest -from qdrant_retriever import QDrantVectorStoreRetriever -from qdrant_client.http import models as rest -from qdrant_client.http.models import PayloadSchemaType +from dotenv import load_dotenv from langchain.schema import Document -from qdrant_client import QdrantClient -from langchain_qdrant import Qdrant from langchain_openai import OpenAIEmbeddings -from datetime import datetime -from dotenv import load_dotenv +from langchain_qdrant import Qdrant +from qdrant_client import QdrantClient +from qdrant_client.http import models as rest +from qdrant_client.http.models import PayloadSchemaType + +from qdrant_retriever import QDrantVectorStoreRetriever from rate_limiter import RateLimiter, SyncRateLimiter + rate_limiter = RateLimiter(rate=5, period=1) # Allow 5 tasks per second rate_limiter_sync = SyncRateLimiter(rate=5, period=1) @@ -32,7 +35,10 @@ def setup_retriever(): ), ) client.create_payload_index( - collection_name, "metadata.extra_index", field_schema=PayloadSchemaType.KEYWORD) + collection_name, + "metadata.extra_index", + field_schema=PayloadSchemaType.KEYWORD, + ) except: print("MemorySummarizer: loaded from cloud...") finally: @@ -54,10 +60,16 @@ def setup_retriever(): ) vectorstore.add_documents([document], ids=[metadata["id"]]) - retriever = QDrantVectorStoreRetriever(rate_limiter=rate_limiter, rate_limiter_sync=rate_limiter_sync, - client=client, vectorstore=vectorstore, collection_name=collection_name) + retriever = QDrantVectorStoreRetriever( + rate_limiter=rate_limiter, + rate_limiter_sync=rate_limiter_sync, + client=client, + vectorstore=vectorstore, + collection_name=collection_name, + ) return retriever + # def test_get_salient_docs(setup_retriever): # retriever = setup_retriever # query = "test_query" @@ -81,8 +93,8 @@ async def test_get_salient_docs(setup_retriever): for doc, score in docs: assert isinstance(doc, Document) - assert hasattr(doc, 'page_content') - assert hasattr(doc, 'metadata') + assert hasattr(doc, "page_content") + assert hasattr(doc, "metadata") assert isinstance(score, float) diff --git a/tests/test_document_summarizer.py b/tests/test_document_summarizer.py index 9489f66..51a8d3c 100644 --- a/tests/test_document_summarizer.py +++ b/tests/test_document_summarizer.py @@ -1,35 +1,54 @@ -import pytest from unittest.mock import AsyncMock, MagicMock -from langchain.schema import Document -from langchain.schema import SystemMessage, HumanMessage + +import pytest +from langchain.schema import Document, HumanMessage, SystemMessage from langchain_openai import ChatOpenAI + from document_summarizer import FlexibleDocumentSummarizer, SummaryPrompt @pytest.mark.asyncio async def test_flexible_document_summarizer(): mock_llm = AsyncMock(ChatOpenAI) - mock_llm.agenerate.return_value = MagicMock(generations=[[MagicMock( - text="user summary text")], [MagicMock(text="aida summary text")]]) + mock_llm.agenerate.return_value = MagicMock( + generations=[ + [MagicMock(text="user summary text")], + [MagicMock(text="aida summary text")], + ] + ) summarizer = FlexibleDocumentSummarizer(llm=mock_llm, verbose=True) mock_document = MagicMock(Document) mock_document.metadata = {"summarizations": 2, "importance": "high"} - mock_document.page_content = "{\"user\": \"user original text\", \"AiDA\": \"aida original text\"}" + mock_document.page_content = ( + '{"user": "user original text", "AiDA": "aida original text"}' + ) await summarizer._get_single_summary(mock_document) # assertions mock_llm.agenerate.assert_called_once() expected_prompt = SummaryPrompt( - summarizations=2, importance="high").to_prompt_string() - expected_messages = [[SystemMessage(content=expected_prompt), HumanMessage(content="user original text")], [ - SystemMessage(content=expected_prompt), HumanMessage(content="aida original text")]] + summarizations=2, importance="high" + ).to_prompt_string() + expected_messages = [ + [ + SystemMessage(content=expected_prompt), + HumanMessage(content="user original text"), + ], + [ + SystemMessage(content=expected_prompt), + HumanMessage(content="aida original text"), + ], + ] assert mock_llm.agenerate.call_args[0][0] == expected_messages # Verify content was summarized - assert mock_document.page_content == "{\"user\": \"user summary text\", \"AiDA\": \"aida summary text\"}" + assert ( + mock_document.page_content + == '{"user": "user summary text", "AiDA": "aida summary text"}' + ) # test for multiple documents with asummarize method documents = [mock_document] * 5 # replace the number with your need diff --git a/tests/test_functions_manager.py b/tests/test_functions_manager.py index 2d98cca..1c1672f 100644 --- a/tests/test_functions_manager.py +++ b/tests/test_functions_manager.py @@ -1,23 +1,26 @@ -import pytest import os -from functions_manager import FunctionsManager, ActionItem, FunctionInput + +import pytest from dotenv import load_dotenv from langchain.schema import Document -from qdrant_client.http.models import ScoredPoint from qdrant_client import QdrantClient +from qdrant_client.http import models as rest from qdrant_client.http.exceptions import UnexpectedResponse +from qdrant_client.http.models import ScoredPoint + +from functions_manager import ActionItem, FunctionInput, FunctionsManager from rate_limiter import RateLimiter, SyncRateLimiter -from qdrant_client.http import models as rest rate_limiter = RateLimiter(rate=5, period=1) rate_limiter_sync = SyncRateLimiter(rate=5, period=1) -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def create_collections(): load_dotenv() - client = QdrantClient(url=os.getenv("QDRANT_URL"), - api_key=os.getenv("QDRANT_API_KEY")) + client = QdrantClient( + url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY") + ) collections = ["functions"] for collection in collections: @@ -30,16 +33,14 @@ def create_collections(): ), ) except Exception as e: - print( - f"Collection {collection} already exists or failed to create: {e}") + print(f"Collection {collection} already exists or failed to create: {e}") class TestFunctionsManager: @pytest.fixture(autouse=True) def setup_teardown(self): load_dotenv() - self.functions_manager = FunctionsManager( - rate_limiter, rate_limiter_sync) + self.functions_manager = FunctionsManager(rate_limiter, rate_limiter_sync) yield def test_transform(self): @@ -54,10 +55,18 @@ def test_transform(self): def test_count_tokens(self): functions = { - "information_retrieval": [{"name": "test function", "description": "test description"}], - "communication": [{"name": "test function", "description": "test description"}], - "data_processing": [{"name": "test function", "description": "test description"}], - "sensory_perception": [{"name": "test function", "description": "test description"}], + "information_retrieval": [ + {"name": "test function", "description": "test description"} + ], + "communication": [ + {"name": "test function", "description": "test description"} + ], + "data_processing": [ + {"name": "test function", "description": "test description"} + ], + "sensory_perception": [ + {"name": "test function", "description": "test description"} + ], } tokens = self.functions_manager.count_tokens(functions) @@ -72,33 +81,37 @@ def test_extract_name_and_category(self): payload={"name": "test1", "category": "cat1"}, vector=[0.1, 0.2, 0.3], score=0.9, - version=1 + version=1, ), ScoredPoint( id=2, payload={"name": "test2", "category": "cat2"}, vector=[0.4, 0.5, 0.6], score=0.8, - version=1 - ) + version=1, + ), ] extracted = self.functions_manager.extract_name_and_category(documents) - expected_output = [{'name': 'test1', 'category': 'cat1'}, { - 'name': 'test2', 'category': 'cat2'}] + expected_output = [ + {"name": "test1", "category": "cat1"}, + {"name": "test2", "category": "cat2"}, + ] assert extracted == expected_output @pytest.mark.asyncio async def test_pull_functions(self): - function_input = FunctionInput(api_key=os.getenv("OPENAI_API_KEY"), - action_items=[ActionItem( - action="act", intent="int", category="cat")] - ) + function_input = FunctionInput( + api_key=os.getenv("OPENAI_API_KEY"), + action_items=[ActionItem(action="act", intent="int", category="cat")], + ) try: - response, time_taken = await self.functions_manager.pull_functions(function_input) + response, time_taken = await self.functions_manager.pull_functions( + function_input + ) except UnexpectedResponse as e: pytest.fail(f"UnexpectedResponse: {e}") @@ -107,8 +120,7 @@ async def test_pull_functions(self): @pytest.mark.asyncio async def test_load(self): - response = self.functions_manager.load( - api_key=os.getenv("OPENAI_API_KEY")) + response = self.functions_manager.load(api_key=os.getenv("OPENAI_API_KEY")) assert response is not None print(response) diff --git a/tests/test_generative_conversation_summarized_memory.py b/tests/test_generative_conversation_summarized_memory.py index f5f91a0..0135a2b 100644 --- a/tests/test_generative_conversation_summarized_memory.py +++ b/tests/test_generative_conversation_summarized_memory.py @@ -1,12 +1,17 @@ -import unittest import asyncio +import os +import unittest from datetime import datetime -from generative_conversation_summarized_memory import GenerativeAgentConversationSummarizedMemory, MemoryType + +from dotenv import load_dotenv +from langchain_openai import ChatOpenAI + from agent_manager import AgentManager +from generative_conversation_summarized_memory import ( + GenerativeAgentConversationSummarizedMemory, + MemoryType, +) from rate_limiter import RateLimiter, SyncRateLimiter -from langchain_openai import ChatOpenAI -import os -from dotenv import load_dotenv class TestGenerativeAgentConversationSummarizedMemory(unittest.TestCase): @@ -18,18 +23,18 @@ def setUpClass(cls): load_dotenv() # Load environment variables from .env file api_key = os.getenv("OPENAI_API_KEY") if not api_key: - raise ValueError( - "The OPENAI_API_KEY environment variable is missing") + raise ValueError("The OPENAI_API_KEY environment variable is missing") cls.mock_llm = ChatOpenAI(openai_api_key=api_key) - cls.mock_retriever = AgentManager(cls.rate_limiter, cls.rate_limiter_sync).create_new_memory_retriever( - api_key=api_key, user_id="test1") + cls.mock_retriever = AgentManager( + cls.rate_limiter, cls.rate_limiter_sync + ).create_new_memory_retriever(api_key=api_key, user_id="test1") cls.agent_memory = GenerativeAgentConversationSummarizedMemory( rate_limiter=cls.rate_limiter, llm=cls.mock_llm, memory_retriever=cls.mock_retriever, - verbose=True + verbose=True, ) def setUp(self): @@ -44,22 +49,39 @@ def test_integration(self): # Call add_memory method memory_content = "sample memory" - result = self.loop.run_until_complete(self.agent_memory.add_memory( - memory_content, conversation_id, importance, memory_type, now=timestamp)) + result = self.loop.run_until_complete( + self.agent_memory.add_memory( + memory_content, conversation_id, importance, memory_type, now=timestamp + ) + ) self.assertIsNotNone(result) # Call add_memories method qa_list = ["question 1", "answer 1"] importance_list = ["high", "high"] - memory_types_list = [MemoryType.CONSCIOUS_MEMORY, - MemoryType.CONSCIOUS_MEMORY] - result = self.loop.run_until_complete(self.agent_memory.add_memories( - qa_list, conversation_id, importance_list, memory_types_list, now=timestamp)) + memory_types_list = [MemoryType.CONSCIOUS_MEMORY, MemoryType.CONSCIOUS_MEMORY] + result = self.loop.run_until_complete( + self.agent_memory.add_memories( + qa_list, + conversation_id, + importance_list, + memory_types_list, + now=timestamp, + ) + ) self.assertIsNotNone(result) # Call save_context method - result = self.loop.run_until_complete(self.agent_memory.save_context( - {"query": "sample query", "llm_response": "sample response", "importance": "high", "conversation_id": conversation_id})) + result = self.loop.run_until_complete( + self.agent_memory.save_context( + { + "query": "sample query", + "llm_response": "sample response", + "importance": "high", + "conversation_id": conversation_id, + } + ) + ) self.assertIsNotNone(result) # Call get_conversation method @@ -69,5 +91,5 @@ def test_integration(self): self.agent_memory.clear() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_main.txt b/tests/test_main.py similarity index 79% rename from tests/test_main.txt rename to tests/test_main.py index de213f3..be8d53e 100644 --- a/tests/test_main.txt +++ b/tests/test_main.py @@ -1,19 +1,26 @@ +import os + import pytest import requests -import os + @pytest.fixture def url(): - return "http://localhost:8000" # Modify as needed + return "http://localhost:8000" # Modify as needed + def test_query_plan(url): test_api_key = os.getenv("OPENAI_API_KEY") - mock_data = {"api_key": f"{test_api_key}", "query": "Compare the temperature in Sydney to that of London today"} + mock_data = { + "api_key": f"{test_api_key}", + "query": "Compare the temperature in Sydney to that of London today", + } response = requests.post(f"{url}/query_plan/", json=mock_data) assert response.status_code == 200 assert "response" in response.json() assert "elapsed_time" in response.json() + def test_memory_output(url): test_api_key = os.getenv("OPENAI_API_KEY") mock_data = { @@ -21,12 +28,13 @@ def test_memory_output(url): "user_id": "1", "query": "test", "llm_response": "response", - "conversation_id": "1" + "conversation_id": "1", } response = requests.post(f"{url}/push_memory/", json=mock_data) assert response.status_code == 200 assert "elapsed_time" in response.json() + def test_memory_input(url): test_api_key = os.getenv("OPENAI_API_KEY") mock_data = { @@ -34,34 +42,46 @@ def test_memory_input(url): "user_id": "1", "query": "test", "conversation_id": "1", - "summary": True + "summary": True, } response = requests.post(f"{url}/pull_memory/", json=mock_data) assert response.status_code == 200 assert "response" in response.json() assert "elapsed_time" in response.json() + def test_semantic_search_html(url): test_api_key = os.getenv("OPENAI_API_KEY") mock_data = { "api_key": f"{test_api_key}", "action_items": [{"source_url": "http://example.com", "html_doc": "text1"}], "hash": "ffff", - "query": "test" + "query": "test", } response = requests.post(f"{url}/semantic_search_html/", json=mock_data) assert response.status_code == 200 assert "response" in response.json() assert "elapsed_time" in response.json() + def test_get_functions(url): test_api_key = os.getenv("OPENAI_API_KEY") - mock_data = {"api_key": f"{test_api_key}", "action_items": [{"action": "search stocks", "intent": "get price of aapl", "category": "information retrieval"}]} + mock_data = { + "api_key": f"{test_api_key}", + "action_items": [ + { + "action": "search stocks", + "intent": "get price of aapl", + "category": "information retrieval", + } + ], + } response = requests.post(f"{url}/get_functions/", json=mock_data) assert response.status_code == 200 assert "response" in response.json() assert "elapsed_time" in response.json() + def test_clear_user_memory(url): mock_data = {"user_id": "1test", "conversation_id": "1test"} response = requests.post(f"{url}/clear_conversation/", json=mock_data) @@ -69,7 +89,8 @@ def test_clear_user_memory(url): assert "response" in response.json() assert "elapsed_time" in response.json() + def test_test_callback(url): response = requests.get(f"{url}/test_callback/") assert response.status_code == 200 - assert "test" in response.json() \ No newline at end of file + assert "test" in response.json() diff --git a/tests/test_memory_summarizer.py b/tests/test_memory_summarizer.py index 55e433e..8d5affa 100644 --- a/tests/test_memory_summarizer.py +++ b/tests/test_memory_summarizer.py @@ -1,13 +1,17 @@ -import pytest +import os from unittest import mock -from unittest.mock import Mock, patch -from memory_summarizer import MemorySummarizer -from unittest.mock import create_autospec -from qdrant_client import QdrantClient +from unittest.mock import Mock, create_autospec, patch + +import pytest from langchain_openai import ChatOpenAI -from generative_conversation_summarized_memory import GenerativeAgentConversationSummarizedMemory -import os +from qdrant_client import QdrantClient + +from generative_conversation_summarized_memory import ( + GenerativeAgentConversationSummarizedMemory, +) +from memory_summarizer import MemorySummarizer from rate_limiter import RateLimiter, SyncRateLimiter + rate_limiter = RateLimiter(rate=5, period=1) rate_limiter_sync = SyncRateLimiter(rate=5, period=1) @@ -19,19 +23,27 @@ def mock_agent_manager(): return mock_agent_manager -@patch('memory_summarizer.Qdrant') -@patch('memory_summarizer.OpenAIEmbeddings') -@patch('memory_summarizer.CohereRerank') -@patch('memory_summarizer.ContextualCompressionRetriever') -@patch('memory_summarizer.QDrantVectorStoreRetriever') -def test_create_new_conversation_summarizer(mock_retriever, mock_compression, mock_rerank, mock_embeddings, mock_qdrant, mock_agent_manager): +@patch("memory_summarizer.Qdrant") +@patch("memory_summarizer.OpenAIEmbeddings") +@patch("memory_summarizer.CohereRerank") +@patch("memory_summarizer.ContextualCompressionRetriever") +@patch("memory_summarizer.QDrantVectorStoreRetriever") +def test_create_new_conversation_summarizer( + mock_retriever, + mock_compression, + mock_rerank, + mock_embeddings, + mock_qdrant, + mock_agent_manager, +): # Arrange summarizer = MemorySummarizer( - rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager) + rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager + ) api_key = os.getenv("OPENAI_API_KEY") # Act - result = summarizer.create_new_conversation_summarizer(api_key, 'user_id') + result = summarizer.create_new_conversation_summarizer(api_key, "user_id") # Assert mock_agent_manager.client.create_collection.assert_called_once() @@ -42,15 +54,16 @@ def test_create_new_conversation_summarizer(mock_retriever, mock_compression, mo assert result == mock_compression.return_value -@patch('memory_summarizer.GenerativeAgentConversationSummarizedMemory') +@patch("memory_summarizer.GenerativeAgentConversationSummarizedMemory") def test_create_summarized_memory(mock_memory, mock_agent_manager): # Arrange summarizer = MemorySummarizer( - rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager) + rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager + ) api_key = os.getenv("OPENAI_API_KEY") # Act - result = summarizer.create_summarized_memory(api_key, 'user_id') + result = summarizer.create_summarized_memory(api_key, "user_id") print(result) # # Assert @@ -58,19 +71,20 @@ def test_create_summarized_memory(mock_memory, mock_agent_manager): rate_limiter=rate_limiter, llm=mock.ANY, memory_retriever=mock.ANY, - verbose=mock.ANY + verbose=mock.ANY, ) -@patch('memory_summarizer.GenerativeAgentConversationSummarizedMemory') +@patch("memory_summarizer.GenerativeAgentConversationSummarizedMemory") def test_load(mock_memory, mock_agent_manager): # Arrange summarizer = MemorySummarizer( - rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager) + rate_limiter, rate_limiter_sync, Mock(), mock_agent_manager + ) api_key = os.getenv("OPENAI_API_KEY") # Act - result = summarizer.load(api_key, 'user_id') + result = summarizer.load(api_key, "user_id") # Assert mock_memory.assert_called_once() diff --git a/tests/web_manager_test.py b/tests/web_manager_test.py index 6a23e08..f01c6e6 100644 --- a/tests/web_manager_test.py +++ b/tests/web_manager_test.py @@ -1,19 +1,24 @@ -import pytest -from web_manager import WebManager, HTMLInput, HTMLItem -from langchain.retrievers import ContextualCompressionRetriever -import time -from dotenv import load_dotenv import asyncio import os +import time + +import pytest +from dotenv import load_dotenv +from langchain.retrievers import ContextualCompressionRetriever + from rate_limiter import RateLimiter, SyncRateLimiter +from web_manager import HTMLInput, HTMLItem, WebManager + rate_limiter = RateLimiter(rate=5, period=1) rate_limiter_sync = SyncRateLimiter(rate=5, period=1) + @pytest.fixture def setup_web_manager(): load_dotenv() web_manager = WebManager(rate_limiter, rate_limiter_sync) - yield web_manager + yield web_manager + @pytest.fixture def setup_html_input(): @@ -23,10 +28,11 @@ def setup_html_input(): HTMLItem(source_url="http://example.com", html_doc="test_html_doc") ], hash="test_hash", - query="test_query" + query="test_query", ) return html_input + @pytest.mark.asyncio async def test_load(setup_web_manager, setup_html_input): try: @@ -35,10 +41,11 @@ async def test_load(setup_web_manager, setup_html_input): memory = web_manager.load(setup_html_input.api_key) end_time = time.time() assert isinstance(memory, ContextualCompressionRetriever) - assert end_time - start_time >= 0 + assert end_time - start_time >= 0 except: print("Load test not executed") + @pytest.mark.asyncio async def test_search_html(setup_web_manager, setup_html_input): web_manager = setup_web_manager @@ -47,10 +54,10 @@ async def test_search_html(setup_web_manager, setup_html_input): assert isinstance(response, list) assert isinstance(duration, float) for item in response: - assert 'text' in item - assert 'source_url' in item - + assert "text" in item + assert "source_url" in item + current_task = asyncio.current_task() pending = [t for t in asyncio.all_tasks() if t is not current_task] for task in pending: - await task \ No newline at end of file + await task diff --git a/web_manager.py b/web_manager.py index f429448..0c07ff7 100644 --- a/web_manager.py +++ b/web_manager.py @@ -1,26 +1,27 @@ -import time import datetime +import logging import os import random -import logging +import time import traceback -import cachetools.func +from datetime import datetime, timedelta +from typing import List +import cachetools.func from dotenv import load_dotenv -from llama_index.core.langchain_helpers.text_splitter import SentenceSplitter -from typing import List -from qdrant_client import QdrantClient -from pydantic import BaseModel, Field -from langchain_qdrant import Qdrant -from qdrant_retriever import QDrantVectorStoreRetriever -from langchain_openai import OpenAIEmbeddings from langchain.retrievers import ContextualCompressionRetriever -from cohere_rerank import CohereRerank from langchain.schema import Document -from datetime import datetime, timedelta +from langchain_openai import OpenAIEmbeddings +from langchain_qdrant import Qdrant +from llama_index.core.langchain_helpers.text_splitter import SentenceSplitter +from pydantic import BaseModel, Field +from qdrant_client import QdrantClient from qdrant_client.http import models as rest from qdrant_client.http.models import PayloadSchemaType +from cohere_rerank import CohereRerank +from qdrant_retriever import QDrantVectorStoreRetriever + class HTMLItem(BaseModel): source_url: str @@ -33,8 +34,9 @@ class CacheHTML(BaseModel): class HTMLInput(BaseModel): api_key: str - action_items: List[HTMLItem] = Field(..., example=[ - {"source_url": "http://example.com", "html_doc": "text1"}]) + action_items: List[HTMLItem] = Field( + ..., example=[{"source_url": "http://example.com", "html_doc": "text1"}] + ) hash: str query: str @@ -56,8 +58,7 @@ def __init__(self, rate_limiter, rate_limiter_sync): self.QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") self.QDRANT_URL = os.getenv("QDRANT_URL") self.collection_name = "web" - self.client = QdrantClient( - url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) + self.client = QdrantClient(url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY) self.rate_limiter = rate_limiter self.rate_limiter_sync = rate_limiter_sync @@ -73,19 +74,33 @@ def create_new_web_retriever(self, api_key: str): ), ) self.client.create_payload_index( - self.collection_name, "metadata.hash_key", field_schema=PayloadSchemaType.KEYWORD) + self.collection_name, + "metadata.hash_key", + field_schema=PayloadSchemaType.KEYWORD, + ) except: logging.info("WebManager: loaded from cloud...") finally: logging.info( - f"WebManager: Creating memory store with collection {self.collection_name}") - vectorstore = Qdrant(self.client, self.collection_name, OpenAIEmbeddings( - model="text-embedding-3-small", openai_api_key=api_key)) + f"WebManager: Creating memory store with collection {self.collection_name}" + ) + vectorstore = Qdrant( + self.client, + self.collection_name, + OpenAIEmbeddings( + model="text-embedding-3-small", openai_api_key=api_key + ), + ) compressor = CohereRerank() compression_retriever = ContextualCompressionRetriever( - base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever( - rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=self.collection_name, client=self.client, vectorstore=vectorstore, - ) + base_compressor=compressor, + base_retriever=QDrantVectorStoreRetriever( + rate_limiter=self.rate_limiter, + rate_limiter_sync=self.rate_limiter_sync, + collection_name=self.collection_name, + client=self.client, + vectorstore=vectorstore, + ), ) return compression_retriever @@ -94,15 +109,17 @@ def extract_text_and_source_url(self, retrieved_nodes): seen = set() for document in retrieved_nodes: text = document.page_content - source_url = document.metadata.get('source_url') + source_url = document.metadata.get("source_url") # Create a tuple of text and source_url to check for duplicates key = (text, source_url) if key not in seen: - result.append({'text': text, 'source_url': source_url}) + result.append({"text": text, "source_url": source_url}) seen.add(key) return result - async def get_retrieved_nodes(self, memory: ContextualCompressionRetriever, function_input: HTMLInput): + async def get_retrieved_nodes( + self, memory: ContextualCompressionRetriever, function_input: HTMLInput + ): filter = rest.Filter( must=[ rest.FieldCondition( @@ -138,30 +155,49 @@ async def search_html(self, function_input: HTMLInput): for item in function_input.action_items: text_splitter = SentenceSplitter() chunks = text_splitter.split_text(text=item.html_doc) - documents.extend([Document(page_content=chunk, metadata={"id": random.randint( - 0, 2**32 - 1), "hash_key": function_input.hash, "last_accessed_at": nowStamp, 'source_url': item.source_url}) for chunk in chunks]) + documents.extend( + [ + Document( + page_content=chunk, + metadata={ + "id": random.randint(0, 2**32 - 1), + "hash_key": function_input.hash, + "last_accessed_at": nowStamp, + "source_url": item.source_url, + }, + ) + for chunk in chunks + ] + ) if len(documents) > 0: ids = [doc.metadata["id"] for doc in documents] - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, documents, ids=ids) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, documents, ids=ids + ) end = time.time() logging.info( - f"WebManager: Loaded from documents operation took {end - start} seconds") + f"WebManager: Loaded from documents operation took {end - start} seconds" + ) nodes = await self.get_retrieved_nodes(memory, function_input) response = self.extract_text_and_source_url(nodes) # update last_accessed_at if len(function_input.action_items) == 0 and len(nodes) > 0: ids = [doc.metadata["id"] for doc in nodes] for doc in nodes: - doc.metadata.pop('relevance_score', None) - await self.rate_limiter.execute(memory.base_retriever.vectorstore.aadd_documents, nodes, ids=ids) + doc.metadata.pop("relevance_score", None) + await self.rate_limiter.execute( + memory.base_retriever.vectorstore.aadd_documents, nodes, ids=ids + ) self.prune_web() except Exception as e: logging.warning( - f"WebManager: search_html exception {e}\n{traceback.format_exc()}") + f"WebManager: search_html exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"WebManager: search_html operation took {end - start} seconds") + f"WebManager: search_html operation took {end - start} seconds" + ) return response, end - start def prune_web(self): @@ -177,7 +213,10 @@ def prune_web(self): ] ) self.rate_limiter_sync.execute( - self.client.delete, collection_name=self.collection_name, points_selector=filter) + self.client.delete, + collection_name=self.collection_name, + points_selector=filter, + ) def does_hash_exist(self, hash: str): start = time.time() @@ -192,12 +231,18 @@ def does_hash_exist(self, hash: str): ] ) result, _ = self.rate_limiter_sync.execute( - self.client.scroll, collection_name=self.collection_name, scroll_filter=filter, limit=1) + self.client.scroll, + collection_name=self.collection_name, + scroll_filter=filter, + limit=1, + ) except Exception as e: logging.warning( - f"WebManager: does_hash_exist exception {e}\n{traceback.format_exc()}") + f"WebManager: does_hash_exist exception {e}\n{traceback.format_exc()}" + ) finally: end = time.time() logging.info( - f"WebManager: does_hash_exist operation took {end - start} seconds") + f"WebManager: does_hash_exist operation took {end - start} seconds" + ) return result is not None and len(result) > 0, end - start