Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 90 additions & 52 deletions agent_manager.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import time
import asyncio
import logging
import os
import asyncio
import time
import traceback
import cachetools.func
from datetime import datetime, timedelta
from typing import Any, Dict

import cachetools.func
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from qdrant_retriever import QDrantVectorStoreRetriever
from cohere_rerank import CohereRerank
from generative_memory import GenerativeAgentMemory
from langchain.retrievers import ContextualCompressionRetriever
from langchain.schema import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_qdrant import Qdrant
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from qdrant_client.http.models import PayloadSchemaType
from memory_summarizer import MemorySummarizer
from pydantic import BaseModel

from cohere_rerank import CohereRerank
from document_summarizer import FlexibleDocumentSummarizer
from langchain_openai import ChatOpenAI
from langchain.schema import Document
from datetime import datetime, timedelta
from typing import Any, Dict
from generative_memory import GenerativeAgentMemory
from memory_summarizer import MemorySummarizer
from preferences_resolver import PreferencesResolver
from preferences_updater import PreferencesUpdater
from qdrant_retriever import QDrantVectorStoreRetriever


class MemoryInput(BaseModel):
Expand All @@ -37,7 +37,12 @@ def __str__(self):
return str(self.summary) + self.user_id + self.query + self.conversation_id

def __eq__(self, other):
return self.user_id == other.user_id and self.query == other.query and self.conversation_id == other.conversation_id and self.summary == other.summary
return (
self.user_id == other.user_id
and self.query == other.query
and self.conversation_id == other.conversation_id
and self.summary == other.summary
)

def __hash__(self):
return hash(str(self))
Expand All @@ -64,32 +69,35 @@ def __init__(self, rate_limiter, rate_limiter_sync):
self.QDRANT_URL = os.getenv("QDRANT_URL")
self.rate_limiter = rate_limiter
self.rate_limiter_sync = rate_limiter_sync
self.client = QdrantClient(
url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY)
self.client = QdrantClient(url=self.QDRANT_URL, api_key=self.QDRANT_API_KEY)
self.verbose = True
self.preferences_resolver = PreferencesResolver()
self.preferences_updater = PreferencesUpdater(
self.preferences_resolver, self.verbose)
self.preferences_resolver, self.verbose
)

async def push_memory(self, memory_output: MemoryOutput):
"""Add new memory to the current index for a specific user."""
start = time.time()
memory = self.load(memory_output.api_key, memory_output.user_id)
try:
# update preferences on every exchange but only save summarized memory of a "finished" exchange, reflect on an important summarized memory and then decay memories
asyncio.create_task(memory.pause_to_reflect(
memory_output.dict(), self.preferences_resolver))
asyncio.create_task(
memory.pause_to_reflect(memory_output.dict(), self.preferences_resolver)
)
# asyncio.create_task(self.preferences_updater.update_preferences(ChatOpenAI(openai_api_key=memory_output.api_key,
# model="gpt-4.1-mini", temperature=0), memory_output.query, memory_output.llm_response, memory_output.user_id))
# decay memory by summarizing it continuously until max_summarizations then prune
asyncio.create_task(memory.decay())
except Exception as e:
logging.warning(
f"AgentManager: push_memory exception {e}\n{traceback.format_exc()}")
f"AgentManager: push_memory exception {e}\n{traceback.format_exc()}"
)
finally:
end = time.time()
logging.info(
f"AgentManager: push_memory operation took {end - start} seconds")
f"AgentManager: push_memory operation took {end - start} seconds"
)
return end - start

def create_new_memory_retriever(self, api_key: str, user_id: str):
Expand All @@ -105,32 +113,55 @@ def create_new_memory_retriever(self, api_key: str, user_id: str):
),
)
self.client.create_payload_index(
collection_name, "metadata.extra_index", field_schema=PayloadSchemaType.KEYWORD)
collection_name,
"metadata.extra_index",
field_schema=PayloadSchemaType.KEYWORD,
)
except:
print("AgentManager: loaded from cloud...")
finally:
logging.info(
f"AgentManager: Creating memory store with collection {collection_name}")
vectorstore = Qdrant(self.client, collection_name, OpenAIEmbeddings(
model="text-embedding-3-small", openai_api_key=api_key))
f"AgentManager: Creating memory store with collection {collection_name}"
)
vectorstore = Qdrant(
self.client,
collection_name,
OpenAIEmbeddings(
model="text-embedding-3-small", openai_api_key=api_key
),
)
compressor = CohereRerank()
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=QDrantVectorStoreRetriever(
rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, collection_name=collection_name, client=self.client, vectorstore=vectorstore,
)
base_compressor=compressor,
base_retriever=QDrantVectorStoreRetriever(
rate_limiter=self.rate_limiter,
rate_limiter_sync=self.rate_limiter_sync,
collection_name=collection_name,
client=self.client,
vectorstore=vectorstore,
),
)
return compression_retriever

def create_memory(self, api_key: str, user_id: str):
return GenerativeAgentMemory(
rate_limiter=self.rate_limiter,
llm=ChatOpenAI(openai_api_key=api_key,
model="gpt-4.1-mini", max_tokens=1024),
memory_retriever=self.create_new_memory_retriever(
api_key, user_id),
memory_summarizer=MemorySummarizer(rate_limiter=self.rate_limiter, rate_limiter_sync=self.rate_limiter_sync, flexible_document_summarizer=FlexibleDocumentSummarizer(
ChatOpenAI(openai_api_key=api_key, model="gpt-4.1-mini", temperature=0), verbose=self.verbose), agent_manager=self),
verbose=self.verbose
llm=ChatOpenAI(
openai_api_key=api_key, model="gpt-4.1-mini", max_tokens=1024
),
memory_retriever=self.create_new_memory_retriever(api_key, user_id),
memory_summarizer=MemorySummarizer(
rate_limiter=self.rate_limiter,
rate_limiter_sync=self.rate_limiter_sync,
flexible_document_summarizer=FlexibleDocumentSummarizer(
ChatOpenAI(
openai_api_key=api_key, model="gpt-4.1-mini", temperature=0
),
verbose=self.verbose,
),
agent_manager=self,
),
verbose=self.verbose,
)

@cachetools.func.ttl_cache(maxsize=16384, ttl=36000)
Expand All @@ -139,8 +170,7 @@ def load(self, api_key: str, user_id: str) -> GenerativeAgentMemory:
start = time.time()
memory = self.create_memory(api_key, user_id)
end = time.time()
logging.info(
f"AgentManager: Load operation took {end - start} seconds")
logging.info(f"AgentManager: Load operation took {end - start} seconds")
return memory

def _document_from_scored_point(
Expand All @@ -165,7 +195,8 @@ def get_key_value_document(self, collection_name, key, value) -> Document:
]
)
record, _ = self.client.scroll(
collection_name=collection_name, scroll_filter=filter, limit=1)
collection_name=collection_name, scroll_filter=filter, limit=1
)
if record is not None and len(record) > 0:
return self._document_from_scored_point(
record[0], "page_content", "metadata"
Expand All @@ -175,7 +206,10 @@ def get_key_value_document(self, collection_name, key, value) -> Document:

def load_summary(self, memory_input: MemoryInput) -> Dict[str, str]:
doc = self.get_key_value_document(
f"{memory_input.user_id}_summaries", "metadata.extra_index", memory_input.conversation_id)
f"{memory_input.user_id}_summaries",
"metadata.extra_index",
memory_input.conversation_id,
)
ret = ""
if doc:
ret = self.format_summary_simple(doc)
Expand All @@ -187,8 +221,7 @@ def load_summary(self, memory_input: MemoryInput) -> Dict[str, str]:
async def load_memory(self, memory_input: MemoryInput):
memory = self.load(memory_input.api_key, memory_input.user_id)
return await memory.load_memory_variables(
queries=[memory_input.query],
conversation_id=memory_input.conversation_id
queries=[memory_input.query], conversation_id=memory_input.conversation_id
)

def _time_ago(self, timestamp: float) -> str:
Expand All @@ -209,8 +242,7 @@ def format_summary_simple(self, conversation_summary: Document) -> str:
created_ago = self._time_ago(created_at)

# Extracting the extra_index (conversation_id)
conversation_id = conversation_summary.metadata.get(
"extra_index", "N/A")
conversation_id = conversation_summary.metadata.get("extra_index", "N/A")
return f"(created: {created_ago}, conversation_id: {conversation_id}) {conversation_summary.page_content}"

async def pull_memory(self, memory_input: MemoryInput):
Expand All @@ -222,19 +254,22 @@ async def pull_memory(self, memory_input: MemoryInput):
if memory_input.summary:
if len(memory_input.conversation_id) <= 0:
logging.warning(
f"AgentManager: pull_memory asked for summary but no conversation_id provided!")
f"AgentManager: pull_memory asked for summary but no conversation_id provided!"
)
end = time.time()
return {}, end - start
response = self.load_summary(memory_input)
else:
response = await self.load_memory(memory_input)
except Exception as e:
logging.warning(
f"AgentManager: pull_memory exception {e}\n{traceback.format_exc()}")
f"AgentManager: pull_memory exception {e}\n{traceback.format_exc()}"
)
finally:
end = time.time()
logging.info(
f"AgentManager: pull_memory operation took {end - start} seconds")
f"AgentManager: pull_memory operation took {end - start} seconds"
)
return response, end - start

def clear_collection_with_extra_index(self, collection_name, extra_index) -> None:
Expand All @@ -247,22 +282,25 @@ def clear_collection_with_extra_index(self, collection_name, extra_index) -> Non
)
]
)
self.client.delete(collection_name=collection_name,
points_selector=filter)
self.client.delete(collection_name=collection_name, points_selector=filter)

def clear_conversation(self, clear_memory: ClearMemory):
"""Delete all memories for a specific conversation with a user."""
start = time.time()
try:
self.clear_collection_with_extra_index(
clear_memory.user_id, clear_memory.conversation_id)
clear_memory.user_id, clear_memory.conversation_id
)
self.clear_collection_with_extra_index(
f"{clear_memory.user_id}_summaries", clear_memory.conversation_id)
f"{clear_memory.user_id}_summaries", clear_memory.conversation_id
)
except Exception as e:
logging.warning(
f"AgentManager: clear_conversation exception {e}\n{traceback.format_exc()}")
f"AgentManager: clear_conversation exception {e}\n{traceback.format_exc()}"
)
finally:
end = time.time()
logging.info(
f"AgentManager: clear_conversation operation took {end - start} seconds")
f"AgentManager: clear_conversation operation took {end - start} seconds"
)
return "success", end - start
2 changes: 1 addition & 1 deletion cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel


class CacheClearInput(BaseModel):
cache_types: list
console_key: str

Loading