1+ import asyncio
2+ import json
3+ import logging
4+ from typing import Any , AsyncIterator , Awaitable
5+
16from fastapi import APIRouter , Depends , HTTPException
7+ from fastapi .responses import StreamingResponse
28from sqlalchemy .ext .asyncio import AsyncSession
39
410from app .db .conversation_repository import ConversationRepository
1622from app .services .academic .search_service import AcademicSearchService
1723
1824router = APIRouter ()
25+ logger = logging .getLogger (__name__ )
1926
2027@router .post ("/search" , response_model = PaperSearchResponse )
2128async def search_papers (
@@ -42,54 +49,17 @@ async def intelligent_search(
4249 # 对 IntelligentSearchService.__init__ 的可选依赖参数生成 OpenAPI schema。
4350 service = IntelligentSearchService ()
4451
45- try :
46- response = await service .search (request )
47-
48- # 如果用户已登录,保存对话历史
49- if current_user :
50- repo = ConversationRepository (db )
51-
52- # 如果没有指定conversation_id,创建新对话
53- conversation_id = request .conversation_id
54- if not conversation_id :
55- # 使用查询的前50个字符作为标题
56- title = request .query [:50 ] if len (request .query ) <= 50 else request .query [:47 ] + "..."
57- conversation = await repo .create_conversation (
58- current_user .id ,
59- ConversationCreate (title = title )
60- )
61- conversation_id = conversation .id
62-
63- # 保存用户消息
64- await repo .add_message (
65- conversation_id ,
66- ConversationMessageCreate (
67- role = "user" ,
68- content = request .query ,
69- )
70- )
52+ if request .stream :
53+ return StreamingResponse (
54+ _stream_intelligent_search (service , request , current_user , db ),
55+ media_type = "text/event-stream" ,
56+ )
7157
72- # 保存AI回复
73- assistant_content = response .answer .response if response .answer else ""
74- await repo .add_message (
75- conversation_id ,
76- ConversationMessageCreate (
77- role = "assistant" ,
78- content = assistant_content ,
79- extra_data = {
80- "papers_count" : len (response .papers ) if response .papers else 0 ,
81- "search_performed" : response .search_performed ,
82- }
83- )
84- )
85-
86- # 将conversation_id添加到响应中(扩展响应模型)
87- # 注意:这需要修改IntelligentSearchResponse schema
88- response_dict = response .model_dump ()
89- response_dict ["conversation_id" ] = conversation_id
90- return response_dict
91-
92- return response
58+ try :
59+ payload , persist = await _prepare_stream_payload (response , request , current_user , db )
60+ if persist :
61+ await persist ()
62+ return payload
9363 except ValueError as exc :
9464 raise HTTPException (status_code = 400 , detail = str (exc )) from exc
9565
@@ -102,3 +72,102 @@ async def get_paper_details(paper_id: str):
10272async def get_recommendations (domain : str ):
10373 """根据研究领域推荐论文"""
10474 pass
75+
76+
77+ async def _prepare_stream_payload (
78+ response : IntelligentSearchResponse ,
79+ request : IntelligentSearchRequest ,
80+ current_user : User | None ,
81+ db : AsyncSession ,
82+ ) -> tuple [Any , Awaitable [None ] | None ]:
83+ """Return a payload ready for serialization plus a persistence coroutine."""
84+
85+ if not current_user :
86+ return response , None
87+
88+ repo = ConversationRepository (db )
89+ conversation_id = request .conversation_id
90+ if conversation_id :
91+ conversation = await repo .get_conversation (conversation_id , current_user .id )
92+ if not conversation or conversation .category != "search" :
93+ raise HTTPException (
94+ status_code = 404 ,
95+ detail = "Conversation not found for intelligent search" ,
96+ )
97+ else :
98+ title = request .query [:50 ] if len (request .query ) <= 50 else request .query [:47 ] + "..."
99+ conversation = await repo .create_conversation (
100+ current_user .id ,
101+ ConversationCreate (title = title , category = "search" ),
102+ )
103+ conversation_id = conversation .id
104+
105+ payload = response .model_dump ()
106+ payload ["conversation_id" ] = conversation_id
107+
108+ async def persist () -> None :
109+ await repo .add_message (
110+ conversation_id ,
111+ ConversationMessageCreate (
112+ role = "user" ,
113+ content = request .query ,
114+ ),
115+ )
116+ assistant_content = response .answer .response if response .answer else ""
117+ await repo .add_message (
118+ conversation_id ,
119+ ConversationMessageCreate (
120+ role = "assistant" ,
121+ content = assistant_content ,
122+ extra_data = {
123+ "papers_count" : len (response .papers ) if response .papers else 0 ,
124+ "search_performed" : response .search_performed ,
125+ },
126+ ),
127+ )
128+
129+ return payload , persist
130+
131+
132+ def _format_sse (event : dict [str , Any ]) -> bytes :
133+ serialized = json .dumps (event , ensure_ascii = False )
134+ return f"data: { serialized } \n \n " .encode ("utf-8" )
135+
136+
137+ async def _stream_intelligent_search (
138+ service : IntelligentSearchService ,
139+ request : IntelligentSearchRequest ,
140+ current_user : User | None ,
141+ db : AsyncSession ,
142+ ) -> AsyncIterator [bytes ]:
143+ queue : asyncio .Queue [dict [str , Any ] | None ] = asyncio .Queue ()
144+
145+ async def push_delta (token : str ) -> None :
146+ await queue .put ({"type" : "token" , "content" : token })
147+
148+ async def produce () -> None :
149+ try :
150+ response = await service .search (request , stream_callback = push_delta )
151+ payload , persist = await _prepare_stream_payload (response , request , current_user , db )
152+ payload_dict = payload if isinstance (payload , dict ) else payload .model_dump ()
153+ await queue .put ({"type" : "result" , "payload" : payload_dict })
154+ if persist :
155+ try :
156+ await persist ()
157+ except Exception as exc : # pragma: no cover
158+ logger .exception ("Failed to persist streaming conversation: %s" , exc )
159+ except Exception as exc : # pragma: no cover - streaming path
160+ logger .exception ("Intelligent search streaming失败: %s" , exc )
161+ await queue .put ({"type" : "error" , "message" : str (exc )})
162+ finally :
163+ await queue .put (None )
164+
165+ producer = asyncio .create_task (produce ())
166+ try :
167+ while True :
168+ event = await queue .get ()
169+ if event is None :
170+ break
171+ yield _format_sse (event )
172+ finally :
173+ await producer
0 commit comments