Skip to content

Commit 6cb3e08

Browse files
authored
Revert: broken agent completion by #9631 (#9760)
### What problem does this PR solve? Revert broken agent completion by #9631. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
1 parent 986b9cb commit 6cb3e08

File tree

2 files changed

+71
-73
lines changed

2 files changed

+71
-73
lines changed

api/apps/sdk/session.py

Lines changed: 62 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import json
1717
import re
1818
import time
19+
1920
import tiktoken
2021
from flask import Response, jsonify, request
22+
2123
from agent.canvas import Canvas
2224
from api import settings
2325
from api.db import LLMType, StatusEnum
@@ -27,7 +29,8 @@
2729
from api.db.services.canvas_service import completion as agent_completion
2830
from api.db.services.conversation_service import ConversationService, iframe_completion
2931
from api.db.services.conversation_service import completion as rag_completion
30-
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
32+
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
33+
from api.db.services.document_service import DocumentService
3134
from api.db.services.knowledgebase_service import KnowledgebaseService
3235
from api.db.services.llm_service import LLMBundle
3336
from api.db.services.search_service import SearchService
@@ -37,7 +40,7 @@
3740
from rag.app.tag import label_question
3841
from rag.prompts import chunks_format
3942
from rag.prompts.prompt_template import load_prompt
40-
from rag.prompts.prompts import cross_languages, keyword_extraction
43+
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
4144

4245

4346
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@@ -81,10 +84,10 @@ def create_agent_session(tenant_id, agent_id):
8184
if not isinstance(cvs.dsl, str):
8285
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
8386

84-
session_id=get_uuid()
87+
session_id = get_uuid()
8588
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
8689
canvas.reset()
87-
90+
8891
cvs.dsl = json.loads(str(canvas))
8992
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
9093
API4ConversationService.save(**conv)
@@ -442,26 +445,46 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
442445
def agent_completions(tenant_id, agent_id):
443446
req = request.json
444447

448+
ans = {}
449+
if req.get("stream", True):
450+
451+
def generate():
452+
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
453+
if isinstance(answer, str):
454+
try:
455+
ans = json.loads(answer[5:]) # remove "data:"
456+
except Exception:
457+
continue
458+
459+
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
460+
continue
461+
462+
yield answer
463+
464+
yield "data:[DONE]\n\n"
445465

446466
if req.get("stream", True):
447-
resp = Response(agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req), mimetype="text/event-stream")
467+
resp = Response(generate(), mimetype="text/event-stream")
448468
resp.headers.add_header("Cache-control", "no-cache")
449469
resp.headers.add_header("Connection", "keep-alive")
450470
resp.headers.add_header("X-Accel-Buffering", "no")
451471
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
452472
return resp
453-
result = {}
473+
474+
full_content = ""
454475
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
455476
try:
456-
ans = json.loads(answer[5:]) # remove "data:"
457-
if not result:
458-
result = ans.copy()
459-
else:
460-
result["data"]["answer"] += ans["data"]["answer"]
461-
result["data"]["reference"] = ans["data"].get("reference", [])
477+
ans = json.loads(answer[5:])
478+
479+
if ans["event"] == "message":
480+
full_content += ans["data"]["content"]
481+
482+
if ans.get("data", {}).get("reference", None):
483+
ans["data"]["content"] = full_content
484+
return get_result(data=ans)
462485
except Exception as e:
463-
return get_error_data_result(str(e))
464-
return result
486+
return get_result(data=f"**ERROR**: {str(e)}")
487+
return get_result(data=ans)
465488

466489

467490
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@@ -556,10 +579,7 @@ def list_agent_session(tenant_id, agent_id):
556579
if message_num != 0 and messages[message_num]["role"] != "user":
557580
chunk_list = []
558581
# Add boundary and type checks to prevent KeyError
559-
if (chunk_num < len(conv["reference"]) and
560-
conv["reference"][chunk_num] is not None and
561-
isinstance(conv["reference"][chunk_num], dict) and
562-
"chunks" in conv["reference"][chunk_num]):
582+
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
563583
chunks = conv["reference"][chunk_num]["chunks"]
564584
for chunk in chunks:
565585
# Ensure chunk is a dictionary before calling get method
@@ -860,15 +880,7 @@ def begin_inputs(agent_id):
860880
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
861881

862882
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
863-
return get_result(
864-
data={
865-
"title": cvs.title,
866-
"avatar": cvs.avatar,
867-
"inputs": canvas.get_component_input_form("begin"),
868-
"prologue": canvas.get_prologue(),
869-
"mode": canvas.get_mode()
870-
}
871-
)
883+
return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
872884

873885

874886
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@@ -908,7 +920,7 @@ def stream():
908920
return resp
909921

910922

911-
@manager.route("/searchbots/retrieval_test", methods=['POST']) # noqa: F821
923+
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
912924
@validate_request("kb_id", "question")
913925
def retrieval_test_embedded():
914926
token = request.headers.get("Authorization").split()
@@ -938,18 +950,30 @@ def retrieval_test_embedded():
938950
if not tenant_id:
939951
return get_error_data_result(message="permission denined.")
940952

953+
if req.get("search_id", ""):
954+
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
955+
meta_data_filter = search_config.get("meta_data_filter", {})
956+
metas = DocumentService.get_meta_by_kbs(kb_ids)
957+
if meta_data_filter.get("method") == "auto":
958+
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
959+
filters = gen_meta_filter(chat_mdl, metas, question)
960+
doc_ids.extend(meta_filter(metas, filters))
961+
if not doc_ids:
962+
doc_ids = None
963+
elif meta_data_filter.get("method") == "manual":
964+
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
965+
if not doc_ids:
966+
doc_ids = None
967+
941968
try:
942969
tenants = UserTenantService.query(user_id=tenant_id)
943970
for kb_id in kb_ids:
944971
for tenant in tenants:
945-
if KnowledgebaseService.query(
946-
tenant_id=tenant.tenant_id, id=kb_id):
972+
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
947973
tenant_ids.append(tenant.tenant_id)
948974
break
949975
else:
950-
return get_json_result(
951-
data=False, message='Only owner of knowledgebase authorized for this operation.',
952-
code=settings.RetCode.OPERATING_ERROR)
976+
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
953977

954978
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
955979
if not e:
@@ -969,17 +993,11 @@ def retrieval_test_embedded():
969993
question += keyword_extraction(chat_mdl, question)
970994

971995
labels = label_question(question, [kb])
972-
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
973-
similarity_threshold, vector_similarity_weight, top,
974-
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
975-
rank_feature=labels
976-
)
996+
ranks = settings.retrievaler.retrieval(
997+
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
998+
)
977999
if use_kg:
978-
ck = settings.kg_retrievaler.retrieval(question,
979-
tenant_ids,
980-
kb_ids,
981-
embd_mdl,
982-
LLMBundle(kb.tenant_id, LLMType.CHAT))
1000+
ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
9831001
if ck["content_with_weight"]:
9841002
ranks["chunks"].insert(0, ck)
9851003

@@ -990,8 +1008,7 @@ def retrieval_test_embedded():
9901008
return get_json_result(data=ranks)
9911009
except Exception as e:
9921010
if str(e).find("not_found") > 0:
993-
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
994-
code=settings.RetCode.DATA_ERROR)
1011+
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
9951012
return server_error_response(e)
9961013

9971014

api/db/services/canvas_service.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -135,24 +135,6 @@ def accessible(cls, canvas_id, tenant_id):
135135
return True
136136

137137

138-
def structure_answer(conv, ans, message_id, session_id):
139-
if not conv:
140-
return ans
141-
content = ""
142-
if ans["event"] == "message":
143-
if ans["data"].get("start_to_think") is True:
144-
content = "<think>"
145-
elif ans["data"].get("end_to_think") is True:
146-
content = "</think>"
147-
else:
148-
content = ans["data"]["content"]
149-
150-
reference = ans["data"].get("reference")
151-
result = {"id": message_id, "session_id": session_id, "answer": content}
152-
if reference:
153-
result["reference"] = [reference]
154-
return result
155-
156138
def completion(tenant_id, agent_id, session_id=None, **kwargs):
157139
query = kwargs.get("query", "") or kwargs.get("question", "")
158140
files = kwargs.get("files", [])
@@ -196,14 +178,13 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
196178
})
197179
txt = ""
198180
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
199-
ans = structure_answer(conv, ans, message_id, session_id)
200-
txt += ans["answer"]
201-
if ans.get("answer") or ans.get("reference"):
202-
yield "data:" + json.dumps({"code": 0, "data": ans},
203-
ensure_ascii=False) + "\n\n"
181+
ans["session_id"] = session_id
182+
if ans["event"] == "message":
183+
txt += ans["data"]["content"]
184+
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
204185

205186
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
206-
conv.reference.append(canvas.get_reference())
187+
conv.reference = canvas.get_reference()
207188
conv.errors = canvas.error
208189
conv.dsl = str(canvas)
209190
conv = conv.to_dict()
@@ -232,9 +213,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
232213
except Exception as e:
233214
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
234215
continue
235-
if not ans["data"]["answer"]:
216+
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
236217
continue
237-
content_piece = ans["data"]["answer"]
218+
content_piece = ans["data"]["content"]
238219
completion_tokens += len(tiktokenenc.encode(content_piece))
239220

240221
yield "data: " + json.dumps(
@@ -279,9 +260,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
279260
):
280261
if isinstance(ans, str):
281262
ans = json.loads(ans[5:])
282-
if not ans["data"]["answer"]:
263+
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
283264
continue
284-
all_content += ans["data"]["answer"]
265+
all_content += ans["data"]["content"]
285266

286267
completion_tokens = len(tiktokenenc.encode(all_content))
287268

0 commit comments

Comments
 (0)