Skip to content

Commit 98ba09a

Browse files
authored
Merge pull request #43 from Dynamite2003/feature/streaming
feat: add streaming output of LLM
2 parents eddbb2d + 4bd6bba commit 98ba09a

File tree

20 files changed

+1226
-143
lines changed

20 files changed

+1226
-143
lines changed
Lines changed: 116 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
import asyncio
2+
import json
3+
import logging
4+
from typing import Any, AsyncIterator, Awaitable
5+
16
from fastapi import APIRouter, Depends, HTTPException
7+
from fastapi.responses import StreamingResponse
28
from sqlalchemy.ext.asyncio import AsyncSession
39

410
from app.db.conversation_repository import ConversationRepository
@@ -16,6 +22,7 @@
1622
from app.services.academic.search_service import AcademicSearchService
1723

1824
router = APIRouter()
25+
logger = logging.getLogger(__name__)
1926

2027
@router.post("/search", response_model=PaperSearchResponse)
2128
async def search_papers(
@@ -42,54 +49,17 @@ async def intelligent_search(
4249
# 对 IntelligentSearchService.__init__ 的可选依赖参数生成 OpenAPI schema。
4350
service = IntelligentSearchService()
4451

45-
try:
46-
response = await service.search(request)
47-
48-
# 如果用户已登录,保存对话历史
49-
if current_user:
50-
repo = ConversationRepository(db)
51-
52-
# 如果没有指定conversation_id,创建新对话
53-
conversation_id = request.conversation_id
54-
if not conversation_id:
55-
# 使用查询的前50个字符作为标题
56-
title = request.query[:50] if len(request.query) <= 50 else request.query[:47] + "..."
57-
conversation = await repo.create_conversation(
58-
current_user.id,
59-
ConversationCreate(title=title)
60-
)
61-
conversation_id = conversation.id
62-
63-
# 保存用户消息
64-
await repo.add_message(
65-
conversation_id,
66-
ConversationMessageCreate(
67-
role="user",
68-
content=request.query,
69-
)
70-
)
52+
if request.stream:
53+
return StreamingResponse(
54+
_stream_intelligent_search(service, request, current_user, db),
55+
media_type="text/event-stream",
56+
)
7157

72-
# 保存AI回复
73-
assistant_content = response.answer.response if response.answer else ""
74-
await repo.add_message(
75-
conversation_id,
76-
ConversationMessageCreate(
77-
role="assistant",
78-
content=assistant_content,
79-
extra_data={
80-
"papers_count": len(response.papers) if response.papers else 0,
81-
"search_performed": response.search_performed,
82-
}
83-
)
84-
)
85-
86-
# 将conversation_id添加到响应中(扩展响应模型)
87-
# 注意:这需要修改IntelligentSearchResponse schema
88-
response_dict = response.model_dump()
89-
response_dict["conversation_id"] = conversation_id
90-
return response_dict
91-
92-
return response
58+
try:
59+
payload, persist = await _prepare_stream_payload(response, request, current_user, db)
60+
if persist:
61+
await persist()
62+
return payload
9363
except ValueError as exc:
9464
raise HTTPException(status_code=400, detail=str(exc)) from exc
9565

@@ -102,3 +72,102 @@ async def get_paper_details(paper_id: str):
10272
async def get_recommendations(domain: str):
10373
"""根据研究领域推荐论文"""
10474
pass
75+
76+
77+
async def _prepare_stream_payload(
78+
response: IntelligentSearchResponse,
79+
request: IntelligentSearchRequest,
80+
current_user: User | None,
81+
db: AsyncSession,
82+
) -> tuple[Any, Awaitable[None] | None]:
83+
"""Return a payload ready for serialization plus a persistence coroutine."""
84+
85+
if not current_user:
86+
return response, None
87+
88+
repo = ConversationRepository(db)
89+
conversation_id = request.conversation_id
90+
if conversation_id:
91+
conversation = await repo.get_conversation(conversation_id, current_user.id)
92+
if not conversation or conversation.category != "search":
93+
raise HTTPException(
94+
status_code=404,
95+
detail="Conversation not found for intelligent search",
96+
)
97+
else:
98+
title = request.query[:50] if len(request.query) <= 50 else request.query[:47] + "..."
99+
conversation = await repo.create_conversation(
100+
current_user.id,
101+
ConversationCreate(title=title, category="search"),
102+
)
103+
conversation_id = conversation.id
104+
105+
payload = response.model_dump()
106+
payload["conversation_id"] = conversation_id
107+
108+
async def persist() -> None:
109+
await repo.add_message(
110+
conversation_id,
111+
ConversationMessageCreate(
112+
role="user",
113+
content=request.query,
114+
),
115+
)
116+
assistant_content = response.answer.response if response.answer else ""
117+
await repo.add_message(
118+
conversation_id,
119+
ConversationMessageCreate(
120+
role="assistant",
121+
content=assistant_content,
122+
extra_data={
123+
"papers_count": len(response.papers) if response.papers else 0,
124+
"search_performed": response.search_performed,
125+
},
126+
),
127+
)
128+
129+
return payload, persist
130+
131+
132+
def _format_sse(event: dict[str, Any]) -> bytes:
133+
serialized = json.dumps(event, ensure_ascii=False)
134+
return f"data: {serialized}\n\n".encode("utf-8")
135+
136+
137+
async def _stream_intelligent_search(
138+
service: IntelligentSearchService,
139+
request: IntelligentSearchRequest,
140+
current_user: User | None,
141+
db: AsyncSession,
142+
) -> AsyncIterator[bytes]:
143+
queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
144+
145+
async def push_delta(token: str) -> None:
146+
await queue.put({"type": "token", "content": token})
147+
148+
async def produce() -> None:
149+
try:
150+
response = await service.search(request, stream_callback=push_delta)
151+
payload, persist = await _prepare_stream_payload(response, request, current_user, db)
152+
payload_dict = payload if isinstance(payload, dict) else payload.model_dump()
153+
await queue.put({"type": "result", "payload": payload_dict})
154+
if persist:
155+
try:
156+
await persist()
157+
except Exception as exc: # pragma: no cover
158+
logger.exception("Failed to persist streaming conversation: %s", exc)
159+
except Exception as exc: # pragma: no cover - streaming path
160+
logger.exception("Intelligent search streaming失败: %s", exc)
161+
await queue.put({"type": "error", "message": str(exc)})
162+
finally:
163+
await queue.put(None)
164+
165+
producer = asyncio.create_task(produce())
166+
try:
167+
while True:
168+
event = await queue.get()
169+
if event is None:
170+
break
171+
yield _format_sse(event)
172+
finally:
173+
await producer

backend/app/api/v1/endpoints/conversations.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Conversation API endpoints."""
2-
from typing import List
2+
from typing import List, Literal, Optional
33

44
from fastapi import APIRouter, Depends, HTTPException, Query, status
55
from sqlalchemy.ext.asyncio import AsyncSession
@@ -37,6 +37,9 @@ async def create_conversation(
3737
async def list_conversations(
3838
skip: int = Query(0, ge=0),
3939
limit: int = Query(50, ge=1, le=100),
40+
category: Optional[Literal["search", "reading"]] = Query(
41+
None, description="对话类别过滤:search 或 reading"
42+
),
4043
current_user: User = Depends(get_current_user),
4144
db: AsyncSession = Depends(get_db),
4245
):
@@ -46,6 +49,7 @@ async def list_conversations(
4649
current_user.id,
4750
skip=skip,
4851
limit=limit,
52+
category=category,
4953
)
5054

5155
# 为每个对话添加额外信息
@@ -67,6 +71,7 @@ async def list_conversations(
6771
"id": conv.id,
6872
"user_id": conv.user_id,
6973
"title": conv.title,
74+
"category": conv.category,
7075
"created_at": conv.created_at,
7176
"updated_at": conv.updated_at,
7277
"is_deleted": conv.is_deleted,
@@ -81,14 +86,17 @@ async def list_conversations(
8186
@router.get("/{conversation_id}", response_model=ConversationDetail)
8287
async def get_conversation(
8388
conversation_id: int,
89+
category: Optional[Literal["search", "reading"]] = Query(
90+
None, description="过滤对话类别"
91+
),
8492
current_user: User = Depends(get_current_user),
8593
db: AsyncSession = Depends(get_db),
8694
):
8795
"""获取特定对话详情"""
8896
repo = ConversationRepository(db)
8997
conversation = await repo.get_conversation(conversation_id, current_user.id)
9098

91-
if not conversation:
99+
if not conversation or (category and conversation.category != category):
92100
raise HTTPException(
93101
status_code=status.HTTP_404_NOT_FOUND,
94102
detail="对话不存在或无权访问",
@@ -101,14 +109,17 @@ async def get_conversation(
101109
async def update_conversation(
102110
conversation_id: int,
103111
data: ConversationUpdate,
112+
category: Optional[Literal["search", "reading"]] = Query(
113+
None, description="过滤对话类别"
114+
),
104115
current_user: User = Depends(get_current_user),
105116
db: AsyncSession = Depends(get_db),
106117
):
107118
"""更新对话信息"""
108119
repo = ConversationRepository(db)
109120
conversation = await repo.update_conversation(conversation_id, current_user.id, data)
110121

111-
if not conversation:
122+
if not conversation or (category and conversation.category != category):
112123
raise HTTPException(
113124
status_code=status.HTTP_404_NOT_FOUND,
114125
detail="对话不存在或无权访问",
@@ -128,6 +139,7 @@ async def update_conversation(
128139
"id": conversation.id,
129140
"user_id": conversation.user_id,
130141
"title": conversation.title,
142+
"category": conversation.category,
131143
"created_at": conversation.created_at,
132144
"updated_at": conversation.updated_at,
133145
"is_deleted": conversation.is_deleted,
@@ -141,26 +153,32 @@ async def update_conversation(
141153
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
142154
async def delete_conversation(
143155
conversation_id: int,
156+
category: Optional[Literal["search", "reading"]] = Query(
157+
None, description="过滤对话类别"
158+
),
144159
current_user: User = Depends(get_current_user),
145160
db: AsyncSession = Depends(get_db),
146161
):
147162
"""删除对话"""
148163
repo = ConversationRepository(db)
149-
success = await repo.delete_conversation(conversation_id, current_user.id)
150-
151-
if not success:
164+
conversation = await repo.get_conversation(conversation_id, current_user.id)
165+
if not conversation or (category and conversation.category != category):
152166
raise HTTPException(
153167
status_code=status.HTTP_404_NOT_FOUND,
154168
detail="对话不存在或无权访问",
155169
)
156170

171+
await repo.delete_conversation(conversation_id, current_user.id)
157172
return None
158173

159174

160175
@router.post("/{conversation_id}/messages", response_model=ConversationMessage, status_code=status.HTTP_201_CREATED)
161176
async def add_message(
162177
conversation_id: int,
163178
data: ConversationMessageCreate,
179+
category: Optional[Literal["search", "reading"]] = Query(
180+
None, description="过滤对话类别"
181+
),
164182
current_user: User = Depends(get_current_user),
165183
db: AsyncSession = Depends(get_db),
166184
):
@@ -169,7 +187,7 @@ async def add_message(
169187

170188
# 验证对话存在且属于当前用户
171189
conversation = await repo.get_conversation(conversation_id, current_user.id)
172-
if not conversation:
190+
if not conversation or (category and conversation.category != category):
173191
raise HTTPException(
174192
status_code=status.HTTP_404_NOT_FOUND,
175193
detail="对话不存在或无权访问",

0 commit comments

Comments
 (0)