62 changes: 47 additions & 15 deletions backend/onyx/server/query_and_chat/chat_backend.py
@@ -1,6 +1,8 @@
import asyncio
import datetime
import json
import os
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -103,6 +105,7 @@
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.threadpool_concurrency import run_in_background
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
@@ -507,7 +510,7 @@ def stream_generator() -> Generator[str, None, None]:


@router.post("/send-chat-message", response_model=None, tags=PUBLIC_API_TAGS)
def handle_send_chat_message(
async def handle_send_chat_message(
chat_message_req: SendMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accessible_user),
@@ -572,34 +575,63 @@ def handle_send_chat_message(
# Note: LLM cost tracking is now handled in multi_llm.py
return result

# Streaming path, normal Onyx UI behavior
def stream_generator() -> Generator[str, None, None]:
# Use prod-cons pattern to continue processing even if request stops yielding
buffer: asyncio.Queue[str | None] = asyncio.Queue()
loop = asyncio.get_running_loop()

# Capture headers before spawning thread
litellm_headers = extract_headers(request.headers, LITELLM_PASS_THROUGH_HEADERS)
custom_tool_headers = get_custom_tool_additional_request_headers(request.headers)

def producer() -> None:
"""
Producer function that runs handle_stream_message_objects in a loop
and writes results to the buffer.
"""
state_container = ChatStateContainer()
try:
logger.debug("Producer started")
with get_session_with_current_tenant() as db_session:
for obj in handle_stream_message_objects(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
litellm_additional_headers=litellm_headers,
custom_tool_additional_headers=custom_tool_headers,
external_state_container=state_container,
):
yield get_json_line(obj.model_dump())
# Thread-safe put into the asyncio queue
loop.call_soon_threadsafe(
buffer.put_nowait, get_json_line(obj.model_dump())
)
Comment on lines +604 to +606

Contributor:
If the asyncio queue is full (when maxsize is set), put_nowait will raise an asyncio.QueueFull exception, which would crash the producer thread. The current code doesn't handle this case.

While the queue is currently unbounded, if backpressure is added in the future, this will need proper error handling.
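For reference, if a bounded buffer were ever introduced here, one way to tolerate a full queue without crashing the producer thread is to block the producer (not the event loop) on the async put(). A minimal sketch under that assumption; nothing below is in the PR, and the helper name is made up:

```python
import asyncio

# Hypothetical: only relevant if the buffer is ever bounded, e.g.
# buffer: asyncio.Queue[str | None] = asyncio.Queue(maxsize=256)


def put_with_backpressure(
    loop: asyncio.AbstractEventLoop,
    buffer: asyncio.Queue[str | None],
    item: str | None,
) -> None:
    # Schedule the awaitable put() on the event loop from this worker thread.
    # run_coroutine_threadsafe returns a concurrent.futures.Future we can
    # block on here, so the producer waits for the consumer to free a slot
    # instead of raising asyncio.QueueFull via put_nowait.
    future = asyncio.run_coroutine_threadsafe(buffer.put(item), loop)
    future.result()
```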

Contributor Author (@Danelegend, Jan 9, 2026):
There is currently no maxsize set, and for now we don't expect to add one.

# Note: LLM cost tracking is now handled in multi_llm.py

except Exception as e:
logger.exception("Error in chat message streaming")
yield json.dumps({"error": str(e)})

loop.call_soon_threadsafe(buffer.put_nowait, json.dumps({"error": str(e)}))
finally:
logger.debug("Stream generator finished")
# Signal end of stream
loop.call_soon_threadsafe(buffer.put_nowait, None)
logger.debug("Producer finished")

async def stream_from_buffer() -> AsyncGenerator[str, None]:
"""
Async generator that reads from the buffer and yields to the client.
"""
try:
while True:
item = await buffer.get()
if item is None:
# End of stream signal
break
yield item
except asyncio.CancelledError:
logger.warning("Stream cancelled (Consumer disconnected)")
finally:
logger.debug("Stream consumer finished")

return StreamingResponse(stream_generator(), media_type="text/event-stream")
run_in_background(producer)
Contributor:
The producer thread continues running even if the client disconnects, which is the intended fix for the refresh bug. However, if the consumer (stream_from_buffer) terminates early (e.g., client disconnect, error), the producer thread has no way to know and will keep processing. This could waste resources processing messages nobody will receive.

Consider tracking the consumer's state and giving the producer a way to check whether it should stop early, or use a cancellation token pattern (see the sketch below).
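Purely to illustrate the suggestion (the PR does not do this), the cancellation token could be a plain threading.Event shared between the two sides. A minimal sketch with made-up names, where the producer loop stands in for iterating handle_stream_message_objects(...):

```python
import asyncio
import threading
import time
from collections.abc import AsyncGenerator


def build_stream(loop: asyncio.AbstractEventLoop):
    """Illustrative producer/consumer pair wired to a cancellation token."""
    buffer: asyncio.Queue[str | None] = asyncio.Queue()
    stop_event = threading.Event()  # the "cancellation token"

    def producer() -> None:
        # Stand-in for iterating handle_stream_message_objects(...).
        for i in range(100):
            if stop_event.is_set():
                break  # consumer is gone; stop doing work early
            time.sleep(0.1)  # simulate one unit of work
            loop.call_soon_threadsafe(buffer.put_nowait, f"chunk {i}\n")
        loop.call_soon_threadsafe(buffer.put_nowait, None)  # end-of-stream

    async def stream_from_buffer() -> AsyncGenerator[str, None]:
        try:
            while (item := await buffer.get()) is not None:
                yield item
        finally:
            # Runs on normal completion and on client disconnect/cancel,
            # letting the producer notice and wind down.
            stop_event.set()

    return producer, stream_from_buffer
```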

Contributor Author:
We should only stop if the user presses the stop button. Another mechanism already in place takes care of that case.


return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")


@router.put("/set-message-as-latest")
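Stepping back from the diff, the shape of the change is: a synchronous producer runs on a background thread and feeds an asyncio.Queue via loop.call_soon_threadsafe, while an async generator drains the queue into the StreamingResponse, so the LLM work keeps running even if the HTTP request stops consuming. A standalone sketch of that shape, with a made-up endpoint and loop.run_in_executor standing in for Onyx's run_in_background helper:

```python
import asyncio
import json
import time
from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/demo-stream")  # illustrative endpoint, not the Onyx route
async def demo_stream() -> StreamingResponse:
    buffer: asyncio.Queue[str | None] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer() -> None:
        # Runs on a worker thread; in the PR this loop iterates
        # handle_stream_message_objects(...) inside a DB session.
        try:
            for i in range(10):
                time.sleep(0.2)  # stand-in for real work
                loop.call_soon_threadsafe(
                    buffer.put_nowait, json.dumps({"packet": i}) + "\n"
                )
        except Exception as e:
            loop.call_soon_threadsafe(
                buffer.put_nowait, json.dumps({"error": str(e)}) + "\n"
            )
        finally:
            loop.call_soon_threadsafe(buffer.put_nowait, None)  # end of stream

    async def stream_from_buffer() -> AsyncGenerator[str, None]:
        while (item := await buffer.get()) is not None:
            yield item

    # Stand-in for run_in_background(producer): the producer keeps running
    # on its thread even if the client disconnects mid-stream.
    loop.run_in_executor(None, producer)
    return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")
```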