diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py
index d7efe3b2e60..3556f5ae3f5 100644
--- a/backend/onyx/server/query_and_chat/chat_backend.py
+++ b/backend/onyx/server/query_and_chat/chat_backend.py
@@ -1,6 +1,8 @@
+import asyncio
 import datetime
 import json
 import os
+from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from datetime import timedelta
 from uuid import UUID
@@ -103,6 +105,7 @@
 from onyx.utils.headers import get_custom_tool_additional_request_headers
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import mt_cloud_telemetry
+from onyx.utils.threadpool_concurrency import run_in_background
 from shared_configs.contextvars import get_current_tenant_id
 
 logger = setup_logger()
@@ -507,7 +510,7 @@ def stream_generator() -> Generator[str, None, None]:
 
 
 @router.post("/send-chat-message", response_model=None, tags=PUBLIC_API_TAGS)
-def handle_send_chat_message(
+async def handle_send_chat_message(
     chat_message_req: SendMessageRequest,
     request: Request,
     user: User | None = Depends(current_chat_accessible_user),
@@ -572,34 +575,63 @@ def handle_send_chat_message(
         # Note: LLM cost tracking is now handled in multi_llm.py
         return result
 
-    # Streaming path, normal Onyx UI behavior
-    def stream_generator() -> Generator[str, None, None]:
+    # Use prod-cons pattern to continue processing even if request stops yielding
+    buffer: asyncio.Queue[str | None] = asyncio.Queue()
+    loop = asyncio.get_running_loop()
+
+    # Capture headers before spawning thread
+    litellm_headers = extract_headers(request.headers, LITELLM_PASS_THROUGH_HEADERS)
+    custom_tool_headers = get_custom_tool_additional_request_headers(request.headers)
+
+    def producer() -> None:
+        """
+        Producer function that runs handle_stream_message_objects in a loop
+        and writes results to the buffer.
+        """
         state_container = ChatStateContainer()
         try:
+            logger.debug("Producer started")
             with get_session_with_current_tenant() as db_session:
                 for obj in handle_stream_message_objects(
                     new_msg_req=chat_message_req,
                     user=user,
                     db_session=db_session,
-                    litellm_additional_headers=extract_headers(
-                        request.headers, LITELLM_PASS_THROUGH_HEADERS
-                    ),
-                    custom_tool_additional_headers=get_custom_tool_additional_request_headers(
-                        request.headers
-                    ),
+                    litellm_additional_headers=litellm_headers,
+                    custom_tool_additional_headers=custom_tool_headers,
                     external_state_container=state_container,
                 ):
-                    yield get_json_line(obj.model_dump())
+                    # Thread-safe put into the asyncio queue
+                    loop.call_soon_threadsafe(
+                        buffer.put_nowait, get_json_line(obj.model_dump())
+                    )
             # Note: LLM cost tracking is now handled in multi_llm.py
-
         except Exception as e:
             logger.exception("Error in chat message streaming")
-            yield json.dumps({"error": str(e)})
-
+            loop.call_soon_threadsafe(buffer.put_nowait, json.dumps({"error": str(e)}))
         finally:
-            logger.debug("Stream generator finished")
+            # Signal end of stream
+            loop.call_soon_threadsafe(buffer.put_nowait, None)
+            logger.debug("Producer finished")
+
+    async def stream_from_buffer() -> AsyncGenerator[str, None]:
+        """
+        Async generator that reads from the buffer and yields to the client.
+        """
+        try:
+            while True:
+                item = await buffer.get()
+                if item is None:
+                    # End of stream signal
+                    break
+                yield item
+        except asyncio.CancelledError:
+            logger.warning("Stream cancelled (consumer disconnected)")
+        finally:
+            logger.debug("Stream consumer finished")
 
-    return StreamingResponse(stream_generator(), media_type="text/event-stream")
+    run_in_background(producer)
+
+    return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")
 
 
 @router.put("/set-message-as-latest")
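For reference, below is a minimal, standalone sketch of the handoff pattern this patch adopts: a worker thread produces items and hands them to the event loop via `loop.call_soon_threadsafe(queue.put_nowait, ...)`, while an async generator drains the queue until a `None` sentinel arrives. The names (`produce_items`, `consume`) and the plain `threading.Thread` standing in for Onyx's `run_in_background` are illustrative assumptions, not part of the Onyx codebase.

```python
import asyncio
import threading
import time
from collections.abc import AsyncGenerator


async def main() -> None:
    # asyncio.Queue is not thread-safe, so the worker thread never touches it
    # directly; every put is marshalled onto the event loop thread instead.
    buffer: asyncio.Queue[str | None] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def produce_items() -> None:
        # Runs in a plain thread (hypothetical stand-in for run_in_background).
        try:
            for i in range(5):
                time.sleep(0.1)  # stand-in for slow, blocking work
                loop.call_soon_threadsafe(buffer.put_nowait, f"item {i}")
        finally:
            # None is the end-of-stream sentinel the consumer waits for.
            loop.call_soon_threadsafe(buffer.put_nowait, None)

    async def consume() -> AsyncGenerator[str, None]:
        # Drain the queue until the sentinel, yielding each item to the caller.
        while True:
            item = await buffer.get()
            if item is None:
                break
            yield item

    threading.Thread(target=produce_items, daemon=True).start()
    async for item in consume():
        print(item)


if __name__ == "__main__":
    asyncio.run(main())
```

The producer keeps writing into the queue regardless of how quickly the consumer reads, which matches the stated intent of the patch's own comment ("continue processing even if request stops yielding"): the chat pipeline can run to completion even if the HTTP client stops consuming the stream.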