@@ -774,7 +774,7 @@ def check_save_sql(self, session: Session, res: str) -> str:
774774
775775 return sql
776776
777- def check_save_chart (self , session : Session , res : str ) -> Dict [str , Any ]:
777+ def check_save_chart (self , session : Session , res : str , sql_prase : str ) -> Dict [str , Any ]:
778778
779779 json_str = extract_nested_json (res )
780780 if json_str is None :
@@ -814,7 +814,7 @@ def check_save_chart(self, session: Session, res: str) -> Dict[str, Any]:
814814 if error :
815815 raise SingleMessageError (message )
816816
817- save_chart (session = session , chart = orjson .dumps (chart ).decode (), record_id = self .record .id )
817+ save_chart (session = session , chart = orjson .dumps (chart ).decode (), record_id = self .record .id , sql_prase = sql_prase )
818818
819819 return chart
820820
@@ -989,6 +989,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
989989
990990 use_dynamic_ds : bool = self .current_assistant and self .current_assistant .type in dynamic_ds_types
991991 is_page_embedded : bool = self .current_assistant and self .current_assistant .type == 4
992+ is_assistant_embedded : bool = self .current_assistant and self .current_assistant .type == 1
992993 dynamic_sql_result = None
993994 sqlbot_temp_sql_text = None
994995 assistant_dynamic_sql = None
@@ -1092,7 +1093,16 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10921093
10931094 # filter chart
10941095 SQLBotLogUtil .info (full_chart_text )
1095- chart = self .check_save_chart (session = _session , res = full_chart_text )
1096+
1097+ # sql prase
1098+ if is_assistant_embedded :
1099+ sql_prase = self .generate_sql_paras (_session ,real_execute_sql ,full_chart_text )
1100+ if in_chat :
1101+ yield 'data:' + orjson .dumps (
1102+ {'content' : sql_prase ,
1103+ 'type' : 'sql_prase' }).decode () + '\n \n '
1104+
1105+ chart = self .check_save_chart (session = _session , res = full_chart_text ,sql_prase = sql_prase )
10961106 SQLBotLogUtil .info (chart )
10971107
10981108 if not stream :
@@ -1283,6 +1293,49 @@ def validate_history_ds(self, session: Session):
12831293 except Exception as e :
12841294 raise SingleMessageError (f"ds is invalid [{ str (e )} ]" )
12851295
1296+ def generate_sql_paras (self , _session : Session , real_execute_sql : Optional [str ] = '' ,chart : Optional [str ] = '' ):
1297+ # prase sql
1298+ prase_sql_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
1299+ prase_sql_msg .append (SystemMessage (self .chat_question .prase_sql_sys_question ()))
1300+ prase_sql_msg .append (HumanMessage (self .chat_question .prase_sql_user_question (real_execute_sql ,chart )))
1301+ self .current_logs [OperationEnum .PRASE_SQL ] = start_log (session = _session ,
1302+ ai_modal_id = self .chat_question .ai_modal_id ,
1303+ ai_modal_name = self .chat_question .ai_modal_name ,
1304+ operate = OperationEnum .PRASE_SQL ,
1305+ record_id = self .record .id ,
1306+ full_message = [{'type' : msg .type ,
1307+ 'content' : msg .content }
1308+ for
1309+ msg in prase_sql_msg ])
1310+
1311+ token_usage = {}
1312+ prase_res = process_stream (self .llm .stream (prase_sql_msg ), token_usage )
1313+ prase_full_thinking_text = ''
1314+ prase_full_text = ''
1315+ for chunk in prase_res :
1316+ if chunk .get ('content' ):
1317+ prase_full_text += chunk .get ('content' )
1318+ if chunk .get ('reasoning_content' ):
1319+ prase_full_thinking_text += chunk .get ('reasoning_content' )
1320+ prase_sql_msg .append (AIMessage (prase_full_text ))
1321+
1322+ self .current_logs [OperationEnum .PRASE_SQL ] = end_log (session = _session ,
1323+ log = self .current_logs [
1324+ OperationEnum .PRASE_SQL ],
1325+ full_message = [
1326+ {'type' : msg .type ,
1327+ 'content' : msg .content }
1328+ for msg in prase_sql_msg ],
1329+ reasoning_content = prase_full_thinking_text ,
1330+ token_usage = token_usage )
1331+
1332+ prase_json_str = extract_nested_json (prase_full_text )
1333+ return prase_json_str
1334+ # if prase_json_str is None:
1335+ # raise SingleMessageError(f'Cannot parse datasource from answer: {prase_full_text}')
1336+ # ds = orjson.loads(prase_json_str)
1337+ # return ds['info']
1338+
12861339
12871340def execute_sql_with_db (db : SQLDatabase , sql : str ) -> str :
12881341 """Execute SQL query using SQLDatabase
@@ -1455,7 +1508,6 @@ def process_stream(res: Iterator[BaseMessageChunk],
14551508 }
14561509 get_token_usage (chunk , token_usage )
14571510
1458-
14591511def get_lang_name (lang : str ):
14601512 if not lang :
14611513 return '简体中文'
0 commit comments