3 changes: 2 additions & 1 deletion src/app/endpoints/query.py
@@ -374,7 +374,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
)

completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

cache_entry = CacheEntry(
query=query_request.query,
response=summary.llm_response,
@@ -383,6 +382,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
started_at=started_at,
completed_at=completed_at,
referenced_documents=referenced_documents if referenced_documents else None,
tool_calls=summary.tool_calls if summary.tool_calls else None,
tool_results=summary.tool_results if summary.tool_results else None,
)

consume_tokens(
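Note: both this non-streaming handler and the streaming path in src/utils/endpoints.py (further down) collapse empty lists to None before building the CacheEntry, so the new columns store SQL NULL rather than an empty JSON array. A minimal sketch of the pattern, with illustrative values:

```python
# Illustrative only: an empty tool-call list is normalized to None so the
# cache column stays NULL instead of holding "[]".
tool_calls: list = []  # e.g. a turn that invoked no tools
normalized = tool_calls if tool_calls else None
assert normalized is None
```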
87 changes: 81 additions & 6 deletions src/cache/postgres_cache.py
@@ -9,8 +9,9 @@
from models.cache_entry import CacheEntry
from models.config import PostgreSQLDatabaseConfiguration
from models.responses import ConversationData, ReferencedDocument
from log import get_logger
from utils.connection_decorator import connection
from utils.types import ToolCallSummary, ToolResultSummary
from log import get_logger

logger = get_logger("cache.postgres_cache")

@@ -32,7 +33,9 @@ class PostgresCache(Cache):
response | text | |
provider | text | |
model | text | |
referenced_documents | jsonb | |
tool_calls | jsonb | |
tool_results | jsonb | |
Indexes:
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
"timestamps" btree (created_at)
@@ -55,6 +58,8 @@ class PostgresCache(Cache):
provider text,
model text,
referenced_documents jsonb,
tool_calls jsonb,
tool_results jsonb,
PRIMARY KEY(user_id, conversation_id, created_at)
);
"""
@@ -75,16 +80,18 @@ class PostgresCache(Cache):
"""

SELECT_CONVERSATION_HISTORY_STATEMENT = """
SELECT query, response, provider, model, started_at, completed_at, referenced_documents
SELECT query, response, provider, model, started_at, completed_at,
referenced_documents, tool_calls, tool_results
FROM cache
WHERE user_id=%s AND conversation_id=%s
ORDER BY created_at
"""

INSERT_CONVERSATION_HISTORY_STATEMENT = """
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
query, response, provider, model, referenced_documents)
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s)
query, response, provider, model, referenced_documents,
tool_calls, tool_results)
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""

QUERY_CACHE_SIZE = """
@@ -220,7 +227,7 @@ def initialize_cache(self, namespace: str) -> None:
self.connection.commit()

@connection
def get(
def get( # pylint: disable=R0914
self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
) -> list[CacheEntry]:
"""Get the value associated with the given key.
@@ -260,6 +267,40 @@ def get(
conversation_id,
e,
)

# Parse tool_calls back into ToolCallSummary objects
tool_calls_data = conversation_entry[7]
tool_calls_obj = None
if tool_calls_data:
try:
tool_calls_obj = [
ToolCallSummary.model_validate(tc) for tc in tool_calls_data
]
except (ValueError, TypeError) as e:
logger.warning(
"Failed to deserialize tool_calls for "
"conversation %s: %s",
conversation_id,
e,
)

# Parse tool_results back into ToolResultSummary objects
tool_results_data = conversation_entry[8]
tool_results_obj = None
if tool_results_data:
try:
tool_results_obj = [
ToolResultSummary.model_validate(tr)
for tr in tool_results_data
]
except (ValueError, TypeError) as e:
logger.warning(
"Failed to deserialize tool_results for "
"conversation %s: %s",
conversation_id,
e,
)

cache_entry = CacheEntry(
query=conversation_entry[0],
response=conversation_entry[1],
@@ -268,6 +309,8 @@ def get(
started_at=conversation_entry[4],
completed_at=conversation_entry[5],
referenced_documents=docs_obj,
tool_calls=tool_calls_obj,
tool_results=tool_results_obj,
)
result.append(cache_entry)

@@ -311,6 +354,36 @@ def insert_or_append(
e,
)

tool_calls_json = None
if cache_entry.tool_calls:
try:
tool_calls_as_dicts = [
tc.model_dump(mode="json") for tc in cache_entry.tool_calls
]
tool_calls_json = json.dumps(tool_calls_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_calls for "
"conversation %s: %s",
conversation_id,
e,
)

tool_results_json = None
if cache_entry.tool_results:
try:
tool_results_as_dicts = [
tr.model_dump(mode="json") for tr in cache_entry.tool_results
]
tool_results_json = json.dumps(tool_results_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_results for "
"conversation %s: %s",
conversation_id,
e,
)

# the whole operation is run in one transaction
with self.connection.cursor() as cursor:
cursor.execute(
@@ -325,6 +398,8 @@
cache_entry.provider,
cache_entry.model,
referenced_documents_json,
tool_calls_json,
tool_results_json,
),
)

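Note on the (de)serialization asymmetry above: insert_or_append() dumps the Pydantic models to a JSON string for the jsonb columns, while get() passes conversation_entry[7] and [8] straight to model_validate(), because psycopg returns jsonb columns already decoded into Python lists and dicts. A minimal round-trip sketch; the ToolCallSummary fields below are illustrative stand-ins, since the real model lives in utils/types.py and is not shown in this diff:

```python
import json

from pydantic import BaseModel


class ToolCallSummary(BaseModel):  # stand-in for utils.types.ToolCallSummary
    id: str
    name: str
    args: dict


calls = [ToolCallSummary(id="tc-1", name="search", args={"q": "quota"})]

# insert_or_append(): list of models -> JSON string bound to the jsonb column.
tool_calls_json = json.dumps([tc.model_dump(mode="json") for tc in calls])

# get(): psycopg hands jsonb back as decoded Python objects, so the code
# calls model_validate() directly; json.loads() here simulates that step.
restored = [ToolCallSummary.model_validate(tc) for tc in json.loads(tool_calls_json)]
assert restored == calls
```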
86 changes: 81 additions & 5 deletions src/cache/sqlite_cache.py
@@ -10,8 +10,9 @@
from models.cache_entry import CacheEntry
from models.config import SQLiteDatabaseConfiguration
from models.responses import ConversationData, ReferencedDocument
from log import get_logger
from utils.connection_decorator import connection
from utils.types import ToolCallSummary, ToolResultSummary
from log import get_logger

logger = get_logger("cache.sqlite_cache")

@@ -34,6 +35,8 @@ class SQLiteCache(Cache):
provider | text | |
model | text | |
referenced_documents | text | |
tool_calls | text | |
tool_results | text | |
Indexes:
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
"cache_key_key" UNIQUE CONSTRAINT, btree (key)
@@ -54,6 +57,8 @@ class SQLiteCache(Cache):
provider text,
model text,
referenced_documents text,
tool_calls text,
tool_results text,
PRIMARY KEY(user_id, conversation_id, created_at)
);
"""
@@ -74,16 +79,18 @@ class SQLiteCache(Cache):
"""

SELECT_CONVERSATION_HISTORY_STATEMENT = """
SELECT query, response, provider, model, started_at, completed_at, referenced_documents
SELECT query, response, provider, model, started_at, completed_at,
referenced_documents, tool_calls, tool_results
FROM cache
WHERE user_id=? AND conversation_id=?
ORDER BY created_at
"""

INSERT_CONVERSATION_HISTORY_STATEMENT = """
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
query, response, provider, model, referenced_documents)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
query, response, provider, model, referenced_documents,
tool_calls, tool_results)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""

QUERY_CACHE_SIZE = """
@@ -187,7 +194,7 @@ def initialize_cache(self) -> None:
self.connection.commit()

@connection
def get(
def get( # pylint: disable=R0914
self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
) -> list[CacheEntry]:
"""Get the value associated with the given key.
@@ -228,6 +235,41 @@ def get(
conversation_id,
e,
)

# Parse tool_calls back into ToolCallSummary objects
tool_calls_json_str = conversation_entry[7]
tool_calls_obj = None
if tool_calls_json_str:
try:
tool_calls_data = json.loads(tool_calls_json_str)
tool_calls_obj = [
ToolCallSummary.model_validate(tc) for tc in tool_calls_data
]
except (json.JSONDecodeError, ValueError) as e:
logger.warning(
"Failed to deserialize tool_calls for "
"conversation %s: %s",
conversation_id,
e,
)

# Parse tool_results back into ToolResultSummary objects
tool_results_json_str = conversation_entry[8]
tool_results_obj = None
if tool_results_json_str:
try:
tool_results_data = json.loads(tool_results_json_str)
tool_results_obj = [
ToolResultSummary.model_validate(tr) for tr in tool_results_data
]
except (json.JSONDecodeError, ValueError) as e:
logger.warning(
"Failed to deserialize tool_results for "
"conversation %s: %s",
conversation_id,
e,
)

cache_entry = CacheEntry(
query=conversation_entry[0],
response=conversation_entry[1],
@@ -236,6 +278,8 @@ def get(
started_at=conversation_entry[4],
completed_at=conversation_entry[5],
referenced_documents=docs_obj,
tool_calls=tool_calls_obj,
tool_results=tool_results_obj,
)
result.append(cache_entry)

@@ -281,6 +325,36 @@ def insert_or_append(
e,
)

tool_calls_json = None
if cache_entry.tool_calls:
try:
tool_calls_as_dicts = [
tc.model_dump(mode="json") for tc in cache_entry.tool_calls
]
tool_calls_json = json.dumps(tool_calls_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_calls for "
"conversation %s: %s",
conversation_id,
e,
)

tool_results_json = None
if cache_entry.tool_results:
try:
tool_results_as_dicts = [
tr.model_dump(mode="json") for tr in cache_entry.tool_results
]
tool_results_json = json.dumps(tool_results_as_dicts)
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize tool_results for "
"conversation %s: %s",
conversation_id,
e,
)

cursor.execute(
self.INSERT_CONVERSATION_HISTORY_STATEMENT,
(
@@ -294,6 +368,8 @@ def insert_or_append(
cache_entry.provider,
cache_entry.model,
referenced_documents_json,
tool_calls_json,
tool_results_json,
),
)

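The SQLite variant mirrors the Postgres one with one difference worth calling out: the columns are plain text, so get() must json.loads() the stored string before validating, and a malformed value is logged and degrades to None instead of failing the whole read. A sketch of that read path, reusing the stand-in ToolCallSummary model from the Postgres note above:

```python
import json

raw = '[{"id": "tc-1", "name": "search", "args": {"q": "quota"}}]'

tool_calls_obj = None
if raw:
    try:
        tool_calls_obj = [
            ToolCallSummary.model_validate(tc) for tc in json.loads(raw)
        ]
    except (json.JSONDecodeError, ValueError):
        # The real code logs a warning and leaves the field as None.
        tool_calls_obj = None
```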
5 changes: 5 additions & 0 deletions src/models/cache_entry.py
@@ -3,6 +3,7 @@
from typing import Optional
from pydantic import BaseModel
from models.responses import ReferencedDocument
from utils.types import ToolCallSummary, ToolResultSummary


class CacheEntry(BaseModel):
@@ -14,6 +15,8 @@ class CacheEntry(BaseModel):
provider: Provider identification
model: Model identification
referenced_documents: List of documents referenced by the response
tool_calls: List of tool calls made during response generation
tool_results: List of tool results from tool calls
"""

query: str
@@ -23,3 +26,5 @@ class CacheEntry(BaseModel):
started_at: str
completed_at: str
referenced_documents: Optional[list[ReferencedDocument]] = None
tool_calls: Optional[list[ToolCallSummary]] = None
tool_results: Optional[list[ToolResultSummary]] = None
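Because both new fields default to None, rows written before this change (and turns that used no tools) still validate cleanly. A minimal usage sketch with illustrative values:

```python
entry = CacheEntry(
    query="How do I raise my quota?",
    response="Open a support ticket.",
    provider="openai",
    model="gpt-4o",
    started_at="2025-01-01T00:00:00Z",
    completed_at="2025-01-01T00:00:05Z",
)
# Older cache rows carry no tool data; both fields simply stay None.
assert entry.tool_calls is None and entry.tool_results is None
```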
2 changes: 2 additions & 0 deletions src/utils/endpoints.py
@@ -806,6 +806,8 @@ async def cleanup_after_streaming(
started_at=started_at,
completed_at=completed_at,
referenced_documents=referenced_documents if referenced_documents else None,
tool_calls=summary.tool_calls if summary.tool_calls else None,
tool_results=summary.tool_results if summary.tool_results else None,
)

store_conversation_into_cache(