Skip to content

Commit 5587a71

Browse files
authored
Merge pull request #2 from Mat-O-Lab/rag_agent
seperate literature rag agent
2 parents 1e088e7 + 3944808 commit 5587a71

5 files changed

Lines changed: 138 additions & 52 deletions

File tree

ckanext/chat/bot/agent.py

Lines changed: 116 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,15 @@
1616
from ckan.model.resource import Resource
1717
from openai import AsyncAzureOpenAI
1818
from openai.resources.embeddings import Embeddings
19-
from pydantic import (
20-
BaseModel,
21-
HttpUrl,
22-
ValidationError,
23-
computed_field,
24-
root_validator,
25-
)
19+
from pydantic import (BaseModel, ConfigDict, HttpUrl, ValidationError,
20+
computed_field, root_validator)
2621
from pydantic_ai import Agent, RunContext
27-
from pydantic_ai.exceptions import (
28-
AgentRunError,
29-
FallbackExceptionGroup,
30-
ModelHTTPError,
31-
ModelRetry,
32-
UnexpectedModelBehavior,
33-
UsageLimitExceeded,
34-
)
35-
from pydantic_ai.messages import (
36-
ModelMessagesTypeAdapter,
37-
ModelRequest,
38-
ModelResponse,
39-
TextPart,
40-
UserPromptPart,
41-
)
22+
from pydantic_ai.exceptions import (AgentRunError, FallbackExceptionGroup,
23+
ModelHTTPError, ModelRetry,
24+
UnexpectedModelBehavior,
25+
UsageLimitExceeded)
26+
from pydantic_ai.messages import (ModelMessagesTypeAdapter, ModelRequest,
27+
ModelResponse, TextPart, UserPromptPart)
4228
from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings
4329
from pydantic_ai.providers.openai import OpenAIProvider
4430
from pydantic_ai.usage import UsageLimits
@@ -206,7 +192,8 @@ def download_file(url: str, headers: dict = None, verify: bool = True):
206192
azure_endpoint=toolkit.config.get(
207193
"ckanext.chat.completion_url", "https://your.chat.api"
208194
),
209-
api_version="2024-02-15-preview",
195+
# api_version="2024-02-15-preview",
196+
api_version="2024-06-01",
210197
api_key=toolkit.config.get("ckanext.chat.api_token", "your-api-token"),
211198
)
212199
deployment = toolkit.config.get("ckanext.chat.deployment", "gpt-4-vision-preview")
@@ -218,6 +205,7 @@ def download_file(url: str, headers: dict = None, verify: bool = True):
218205
# provider=OpenAIProvider(base_url=toolkit.config.get("ckanext.chat.completion_url", "https://ollama.local/v1"))
219206
# )
220207

208+
221209
@dataclass
222210
class Deps:
223211
user_id: str
@@ -258,8 +246,9 @@ def init_dynamic_models():
258246
"- You *must* use `get_action_info` on any action you want to run to understand the action and its arguments. After ur instrcuted to run the action immediately to try it.\n"
259247
"- For update or patch actions, always present the proposed changes to the user and ask for explicit confirmation before proceeding.\n"
260248
"- When turning off SSL verification in resource downloads (by setting `ssl_verify=False`), notify the user and request confirmation before proceeding.\n"
261-
"- For general dataset searches and overviews, prioritize using action nameed `package_search`. Run the package_search action with an parameter q="", to fetch all datasets.\n"
262-
"- For more detailed document searches, try `rag_search` first; if it indicates that the milvus client is not set up, switch to `package_search`.\n"
249+
"- For general dataset searches and overviews, prioritize using action nameed `package_search`. Run the package_search action with an parameter q="
250+
", to fetch all datasets.\n"
251+
"- For more detailed document searches, try `literature_search` first; if it indicates that the milvus client is not set up, switch to `package_search`.\n"
263252
"- Ensure you select the appropriate tool based on the user's request and the available capabilities.\n\n"
264253
"Your Toolset:\n\n"
265254
"1. **List CKAN Actions:**\n"
@@ -283,11 +272,11 @@ def init_dynamic_models():
283272
"- **Purpose:** Retrieves file content from CKAN or external sources, with options for partial content retrieval using token parameters.\n"
284273
"- **When to Use:** To fetch the contents of a file resource. If SSL verification is to be disabled (i.e., `ssl_verify=False`), notify the user and ask for confirmation before proceeding.\n\n"
285274
"6. **Retrieve Documents:**\n"
286-
"- **Function:** `rag_search(search_query: List[str]) -> List[RagHit]`\n"
287-
"- **Purpose:** Performs a vector search on document chunks using a list of search strings.\n"
275+
"- **Function:** `literature_search: str) -> List[str]`\n"
276+
"- **Purpose:** Performs a literature search on documents by the question you ask. Mention the number of hits you want to have as return.\n"
288277
"- **When to Use:**\n"
289278
" - For in-depth document searches.\n"
290-
" - If `rag_search` indicates that the milvus client is not set up, then use `package_search` instead.\n"
279+
" - If `literature_search` indicates that the milvus client is not set up, then use `package_search` instead.\n"
291280
" - For general dataset searches or overviews, prefer `package_search`.\n"
292281
)
293282

@@ -297,7 +286,79 @@ def init_dynamic_models():
297286
deps_type=Deps,
298287
system_prompt="".join(system_prompt),
299288
retries=3,
300-
#model_settings=OpenAIModelSettings(openai_reasoning_effort= "low")
289+
# model_settings=OpenAIModelSettings(openai_reasoning_effort= "low")
290+
)
291+
292+
293+
# --------------------- Vector & RAG Models ---------------------
294+
295+
from datetime import datetime
296+
from uuid import UUID
297+
298+
299+
class MyBaseModel(BaseModel):
300+
model_config = ConfigDict(
301+
from_attributes=True, # allows .model_dump to work with ORM-style objects
302+
json_encoders={
303+
UUID: str,
304+
HttpUrl: str,
305+
datetime: lambda dt: dt.isoformat(), # or str(dt)
306+
},
307+
)
308+
309+
310+
class VectorMeta(MyBaseModel):
311+
id: int
312+
chunk_id: Optional[int] = None
313+
chunks: Optional[HttpUrl] = None
314+
dataset_id: Optional[str] = None
315+
dataset_url: Optional[HttpUrl] = None
316+
groups: Optional[list[str]] = None
317+
private: Optional[str] = None
318+
resource_id: Optional[str] = None
319+
source: Optional[HttpUrl] = None
320+
view_url: Optional[list[HttpUrl]] = None
321+
322+
323+
class RagHit(BaseModel):
324+
id: int
325+
distance: Optional[float] = None
326+
title: Optional[str] = None
327+
summary: Optional[str] = None
328+
entity: VectorMeta
329+
330+
331+
class LitSearchResult(BaseModel):
332+
search_str: Optional[list[str]] = None
333+
results: Optional[list] = None
334+
error: Optional[list[str]] = None
335+
336+
337+
rag_prompt = (
338+
"Role:\n\n"
339+
"You are an assistant doing literature search be rephrasig questions and looking up a vector store to a CKAN software instance that must execute tool commands and assess their success or failure. Do not provide endless examples; instead focus on running tools and reasoning based on their outputs and execute steps in your chain of tought right away. Reduce Thinking output to a minimum.\n"
340+
"Key Guidelines:\n\n"
341+
"- when rephasing questions make sure to stay close to the context of the original input as much as possible. Search strings musst consist of 3 words minimum.\n"
342+
"- use the `rag_search` tool find hits for chunks of literature by passing a list of search strings.\n"
343+
"- beside the results also return the phrases you used for the search in the search_str field as list of strings.\n\n"
344+
"- for all hits try to access the text and create a joint result object per distinct source. it should include the title of the source, a summary why its relevant and the rest of the vector metadata."
345+
# "- to any hit you retrieve try to access the text and add the title of the document to a title field of the hit, by either retrieving it from the VectorMeta or the text and create markdown link in the form [title](url to document).\n\n"
346+
# "- to any hit you retrieve add a summary of its relevants to the summary field of the hit in less then 500 string length. If the hit does not provide the desired context, discard the hit!\n\n"
347+
"- make sure that u get at least 5 good results for the context that is in question by running the search again if no number of results is requested.\n\n"
348+
"- any error occuring return to the error field as strings.\n\n"
349+
"Your Toolset:\n\n"
350+
"1. **Retrieve Documents:**\n"
351+
"- **Function:** `rag_search(search_query: List[str]) -> List[RagHit]`\n"
352+
"- **Purpose:** Performs a vector search on document chunks using a list of search strings.\n"
353+
)
354+
355+
rag_agent = Agent(
356+
model=model,
357+
deps_type=Deps,
358+
output_type=LitSearchResult,
359+
system_prompt="".join(rag_prompt),
360+
retries=3,
361+
# model_settings=OpenAIModelSettings(openai_reasoning_effort= "low")
301362
)
302363

303364

@@ -340,8 +401,8 @@ def repl(match):
340401
class RouteModel(BaseModel):
341402
endpoint: str
342403
rule: str
343-
methods: List[str]
344-
variables: Optional[List] = []
404+
methods: Optional[list[str]] = []
405+
variables: Optional[list] = []
345406
full_url_pattern: Optional[str]
346407

347408
@root_validator(pre=True)
@@ -644,11 +705,11 @@ class VectorMeta(BaseModel):
644705
chunks: Optional[HttpUrl] = None
645706
dataset_id: Optional[str] = None
646707
dataset_url: Optional[HttpUrl] = None
647-
groups: Optional[List[str]] = None
708+
groups: Optional[list[str]] = None
648709
private: Optional[str] = None
649710
resource_id: Optional[str] = None
650711
source: Optional[HttpUrl] = None
651-
view_url: Optional[List[HttpUrl]] = None
712+
view_url: Optional[list[HttpUrl]] = None
652713

653714

654715
class RagHit(BaseModel):
@@ -657,7 +718,13 @@ class RagHit(BaseModel):
657718
entity: VectorMeta
658719

659720

660-
@agent.tool
721+
class LitSearchResult(BaseModel):
722+
search_str: Optional[list[str]] = None
723+
hits: Optional[list[RagHit]] = None
724+
error: Optional[str] = None
725+
726+
727+
@rag_agent.tool
661728
async def rag_search(
662729
ctx: RunContext[Deps], search_query: List[str], limit: int = 3
663730
) -> List[RagHit]:
@@ -672,7 +739,7 @@ async def rag_search(
672739
List[RagHit]: List of RagHit instances as a reult of rag search. the object provided a distance attribute with the metrics of similarity and an entity attribute containing the meta data of the vector entity in store.
673740
"""
674741
if not ctx.deps.milvus_client or not ctx.deps.embeddings:
675-
return "The Milvus Client was not setup properly, no rag_search supported in the moment."
742+
return "The Milvus Client was not setup properly, no rag_search supported in the moment."
676743
else:
677744
emb_r = await ctx.deps.embeddings.create(
678745
input=search_query,
@@ -691,12 +758,24 @@ async def rag_search(
691758
if search_res:
692759
hits = []
693760
for i in range(len(query_vectors)):
694-
hits += [RagHit(**item) for item in search_res[i]]
695-
return [hit.json() for hit in hits]
761+
hit = [RagHit(**item) for item in search_res[i]]
762+
log.debug(hit)
763+
hits += hit
764+
return hits
696765
else:
697766
return []
698767

699768

769+
@agent.tool
770+
async def literature_search(ctx: RunContext[Deps], search_question: str) -> list[str]:
771+
r = await rag_agent.run(
772+
f"{search_question}.",
773+
deps=ctx.deps,
774+
)
775+
# log.debug(r.data)
776+
return r.data.json()
777+
778+
700779
def get_user_token(user_id: str) -> Optional[str]:
701780
user = CKANmodel.User.get(user_reference=user_id)
702781
context = {

ckanext/chat/tests/test_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def test_some_endpoint(app):
4646
@pytest.mark.ckan_config("ckanext.myext.some_key", "some_value")
4747
def test_some_action():
4848
pass
49-
"""
49+
"""

ckanext/chat/views.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77
from flask import Blueprint, current_app, jsonify, request
88
from flask.views import MethodView
99

10-
from ckanext.chat.bot.agent import (
11-
Deps,
12-
async_agent_response,
13-
exception_to_model_response,
14-
user_input_to_model_request,
15-
)
10+
from ckanext.chat.bot.agent import (Deps, async_agent_response,
11+
exception_to_model_response,
12+
user_input_to_model_request)
1613
from ckanext.chat.helpers import service_available
1714

1815
blueprint = Blueprint("chat", __name__)
@@ -53,7 +50,8 @@ def get(self):
5350
},
5451
)
5552

56-
from pydantic_ai.messages import TextPart, ModelMessage
53+
54+
from pydantic_ai.messages import ModelMessage, TextPart
5755

5856
# Assuming 'response' is an instance of ModelResponse
5957

@@ -77,9 +75,17 @@ def ask():
7775
# Now response is guaranteed to have new_messages() if no exception occurred.
7876
# Ensure new_messages() is awaited in the sync wrapper if it's async
7977
messages = response.new_messages()
80-
#remove empty text responses parts
81-
[[ message.parts.remove(part) for part in message.parts if isinstance(part, TextPart) and part.content==""] for message in messages]
78+
# remove empty text responses parts
79+
[
80+
[
81+
message.parts.remove(part)
82+
for part in message.parts
83+
if isinstance(part, TextPart) and part.content == ""
84+
]
85+
for message in messages
86+
]
8287
return jsonify({"response": messages})
88+
8389
except Exception as e:
8490
user_promt = user_input_to_model_request(user_input)
8591
error_response = exception_to_model_response(e)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pydantic-ai-slim[openai,logfire]==0.0.47
1+
pydantic-ai-slim[openai,logfire]==0.2.15
22
openai
33
logfire[httpx]
44
nest_asyncio

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# -*- coding: utf-8 -*-
22
from codecs import open # To use a consistent encoding
3-
from os import path, environ
3+
from os import environ, path
44

5-
from setuptools import find_packages, setup # Always prefer setuptools over distutils
5+
from setuptools import ( # Always prefer setuptools over distutils
6+
find_packages, setup)
67

78
here = path.abspath(path.dirname(__file__))
89

@@ -11,15 +12,15 @@
1112
with open(path.join(here, "README.md"), encoding="utf-8") as f:
1213
long_description = f.read()
1314

14-
with open(path.join(here,"requirements.txt")) as f:
15+
with open(path.join(here, "requirements.txt")) as f:
1516
requirements = f.read().splitlines()
1617

1718
setup(
1819
name="""ckanext-chat""",
1920
# Versions should comply with PEP440. For a discussion on single-sourcing
2021
# the version across setup.py and the project code, see
2122
# http://packaging.python.org/en/latest/tutorial.html#version
22-
version=environ.get('VERSION', '0.0.0'),
23+
version=environ.get("VERSION", "0.0.0"),
2324
description="""Extension adds a pydantic ai chat interface to CKAN, that can run actions with user aware context.""",
2425
long_description=long_description,
2526
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)