16
16
import json
17
17
import re
18
18
import time
19
+
19
20
import tiktoken
20
21
from flask import Response , jsonify , request
22
+
21
23
from agent .canvas import Canvas
22
24
from api import settings
23
25
from api .db import LLMType , StatusEnum
27
29
from api .db .services .canvas_service import completion as agent_completion
28
30
from api .db .services .conversation_service import ConversationService , iframe_completion
29
31
from api .db .services .conversation_service import completion as rag_completion
30
- from api .db .services .dialog_service import DialogService , ask , chat , gen_mindmap
32
+ from api .db .services .dialog_service import DialogService , ask , chat , gen_mindmap , meta_filter
33
+ from api .db .services .document_service import DocumentService
31
34
from api .db .services .knowledgebase_service import KnowledgebaseService
32
35
from api .db .services .llm_service import LLMBundle
33
36
from api .db .services .search_service import SearchService
37
40
from rag .app .tag import label_question
38
41
from rag .prompts import chunks_format
39
42
from rag .prompts .prompt_template import load_prompt
40
- from rag .prompts .prompts import cross_languages , keyword_extraction
43
+ from rag .prompts .prompts import cross_languages , gen_meta_filter , keyword_extraction
41
44
42
45
43
46
@manager .route ("/chats/<chat_id>/sessions" , methods = ["POST" ]) # noqa: F821
@@ -81,10 +84,10 @@ def create_agent_session(tenant_id, agent_id):
81
84
if not isinstance (cvs .dsl , str ):
82
85
cvs .dsl = json .dumps (cvs .dsl , ensure_ascii = False )
83
86
84
- session_id = get_uuid ()
87
+ session_id = get_uuid ()
85
88
canvas = Canvas (cvs .dsl , tenant_id , agent_id )
86
89
canvas .reset ()
87
-
90
+
88
91
cvs .dsl = json .loads (str (canvas ))
89
92
conv = {"id" : session_id , "dialog_id" : cvs .id , "user_id" : user_id , "message" : [{"role" : "assistant" , "content" : canvas .get_prologue ()}], "source" : "agent" , "dsl" : cvs .dsl }
90
93
API4ConversationService .save (** conv )
@@ -442,26 +445,46 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
442
445
def agent_completions (tenant_id , agent_id ):
443
446
req = request .json
444
447
448
+ ans = {}
449
+ if req .get ("stream" , True ):
450
+
451
+ def generate ():
452
+ for answer in agent_completion (tenant_id = tenant_id , agent_id = agent_id , ** req ):
453
+ if isinstance (answer , str ):
454
+ try :
455
+ ans = json .loads (answer [5 :]) # remove "data:"
456
+ except Exception :
457
+ continue
458
+
459
+ if ans .get ("event" ) != "message" or not ans .get ("data" , {}).get ("reference" , None ):
460
+ continue
461
+
462
+ yield answer
463
+
464
+ yield "data:[DONE]\n \n "
445
465
446
466
if req .get ("stream" , True ):
447
- resp = Response (agent_completion ( tenant_id = tenant_id , agent_id = agent_id , ** req ), mimetype = "text/event-stream" )
467
+ resp = Response (generate ( ), mimetype = "text/event-stream" )
448
468
resp .headers .add_header ("Cache-control" , "no-cache" )
449
469
resp .headers .add_header ("Connection" , "keep-alive" )
450
470
resp .headers .add_header ("X-Accel-Buffering" , "no" )
451
471
resp .headers .add_header ("Content-Type" , "text/event-stream; charset=utf-8" )
452
472
return resp
453
- result = {}
473
+
474
+ full_content = ""
454
475
for answer in agent_completion (tenant_id = tenant_id , agent_id = agent_id , ** req ):
455
476
try :
456
- ans = json .loads (answer [5 :]) # remove "data:"
457
- if not result :
458
- result = ans .copy ()
459
- else :
460
- result ["data" ]["answer" ] += ans ["data" ]["answer" ]
461
- result ["data" ]["reference" ] = ans ["data" ].get ("reference" , [])
477
+ ans = json .loads (answer [5 :])
478
+
479
+ if ans ["event" ] == "message" :
480
+ full_content += ans ["data" ]["content" ]
481
+
482
+ if ans .get ("data" , {}).get ("reference" , None ):
483
+ ans ["data" ]["content" ] = full_content
484
+ return get_result (data = ans )
462
485
except Exception as e :
463
- return get_error_data_result ( str (e ))
464
- return result
486
+ return get_result ( data = f"**ERROR**: { str (e )} " )
487
+ return get_result ( data = ans )
465
488
466
489
467
490
@manager .route ("/chats/<chat_id>/sessions" , methods = ["GET" ]) # noqa: F821
@@ -556,10 +579,7 @@ def list_agent_session(tenant_id, agent_id):
556
579
if message_num != 0 and messages [message_num ]["role" ] != "user" :
557
580
chunk_list = []
558
581
# Add boundary and type checks to prevent KeyError
559
- if (chunk_num < len (conv ["reference" ]) and
560
- conv ["reference" ][chunk_num ] is not None and
561
- isinstance (conv ["reference" ][chunk_num ], dict ) and
562
- "chunks" in conv ["reference" ][chunk_num ]):
582
+ if chunk_num < len (conv ["reference" ]) and conv ["reference" ][chunk_num ] is not None and isinstance (conv ["reference" ][chunk_num ], dict ) and "chunks" in conv ["reference" ][chunk_num ]:
563
583
chunks = conv ["reference" ][chunk_num ]["chunks" ]
564
584
for chunk in chunks :
565
585
# Ensure chunk is a dictionary before calling get method
@@ -860,15 +880,7 @@ def begin_inputs(agent_id):
860
880
return get_error_data_result (f"Can't find agent by ID: { agent_id } " )
861
881
862
882
canvas = Canvas (json .dumps (cvs .dsl ), objs [0 ].tenant_id )
863
- return get_result (
864
- data = {
865
- "title" : cvs .title ,
866
- "avatar" : cvs .avatar ,
867
- "inputs" : canvas .get_component_input_form ("begin" ),
868
- "prologue" : canvas .get_prologue (),
869
- "mode" : canvas .get_mode ()
870
- }
871
- )
883
+ return get_result (data = {"title" : cvs .title , "avatar" : cvs .avatar , "inputs" : canvas .get_component_input_form ("begin" ), "prologue" : canvas .get_prologue (), "mode" : canvas .get_mode ()})
872
884
873
885
874
886
@manager .route ("/searchbots/ask" , methods = ["POST" ]) # noqa: F821
@@ -908,7 +920,7 @@ def stream():
908
920
return resp
909
921
910
922
911
- @manager .route ("/searchbots/retrieval_test" , methods = [' POST' ]) # noqa: F821
923
+ @manager .route ("/searchbots/retrieval_test" , methods = [" POST" ]) # noqa: F821
912
924
@validate_request ("kb_id" , "question" )
913
925
def retrieval_test_embedded ():
914
926
token = request .headers .get ("Authorization" ).split ()
@@ -938,18 +950,30 @@ def retrieval_test_embedded():
938
950
if not tenant_id :
939
951
return get_error_data_result (message = "permission denined." )
940
952
953
+ if req .get ("search_id" , "" ):
954
+ search_config = SearchService .get_detail (req .get ("search_id" , "" )).get ("search_config" , {})
955
+ meta_data_filter = search_config .get ("meta_data_filter" , {})
956
+ metas = DocumentService .get_meta_by_kbs (kb_ids )
957
+ if meta_data_filter .get ("method" ) == "auto" :
958
+ chat_mdl = LLMBundle (tenant_id , LLMType .CHAT , llm_name = search_config .get ("chat_id" , "" ))
959
+ filters = gen_meta_filter (chat_mdl , metas , question )
960
+ doc_ids .extend (meta_filter (metas , filters ))
961
+ if not doc_ids :
962
+ doc_ids = None
963
+ elif meta_data_filter .get ("method" ) == "manual" :
964
+ doc_ids .extend (meta_filter (metas , meta_data_filter ["manual" ]))
965
+ if not doc_ids :
966
+ doc_ids = None
967
+
941
968
try :
942
969
tenants = UserTenantService .query (user_id = tenant_id )
943
970
for kb_id in kb_ids :
944
971
for tenant in tenants :
945
- if KnowledgebaseService .query (
946
- tenant_id = tenant .tenant_id , id = kb_id ):
972
+ if KnowledgebaseService .query (tenant_id = tenant .tenant_id , id = kb_id ):
947
973
tenant_ids .append (tenant .tenant_id )
948
974
break
949
975
else :
950
- return get_json_result (
951
- data = False , message = 'Only owner of knowledgebase authorized for this operation.' ,
952
- code = settings .RetCode .OPERATING_ERROR )
976
+ return get_json_result (data = False , message = "Only owner of knowledgebase authorized for this operation." , code = settings .RetCode .OPERATING_ERROR )
953
977
954
978
e , kb = KnowledgebaseService .get_by_id (kb_ids [0 ])
955
979
if not e :
@@ -969,17 +993,11 @@ def retrieval_test_embedded():
969
993
question += keyword_extraction (chat_mdl , question )
970
994
971
995
labels = label_question (question , [kb ])
972
- ranks = settings .retrievaler .retrieval (question , embd_mdl , tenant_ids , kb_ids , page , size ,
973
- similarity_threshold , vector_similarity_weight , top ,
974
- doc_ids , rerank_mdl = rerank_mdl , highlight = req .get ("highlight" ),
975
- rank_feature = labels
976
- )
996
+ ranks = settings .retrievaler .retrieval (
997
+ question , embd_mdl , tenant_ids , kb_ids , page , size , similarity_threshold , vector_similarity_weight , top , doc_ids , rerank_mdl = rerank_mdl , highlight = req .get ("highlight" ), rank_feature = labels
998
+ )
977
999
if use_kg :
978
- ck = settings .kg_retrievaler .retrieval (question ,
979
- tenant_ids ,
980
- kb_ids ,
981
- embd_mdl ,
982
- LLMBundle (kb .tenant_id , LLMType .CHAT ))
1000
+ ck = settings .kg_retrievaler .retrieval (question , tenant_ids , kb_ids , embd_mdl , LLMBundle (kb .tenant_id , LLMType .CHAT ))
983
1001
if ck ["content_with_weight" ]:
984
1002
ranks ["chunks" ].insert (0 , ck )
985
1003
@@ -990,8 +1008,7 @@ def retrieval_test_embedded():
990
1008
return get_json_result (data = ranks )
991
1009
except Exception as e :
992
1010
if str (e ).find ("not_found" ) > 0 :
993
- return get_json_result (data = False , message = 'No chunk found! Check the chunk status please!' ,
994
- code = settings .RetCode .DATA_ERROR )
1011
+ return get_json_result (data = False , message = "No chunk found! Check the chunk status please!" , code = settings .RetCode .DATA_ERROR )
995
1012
return server_error_response (e )
996
1013
997
1014
0 commit comments