diff --git a/backend/onyx/chat/chat_state.py b/backend/onyx/chat/chat_state.py index a194f844ed4..da1e77503af 100644 --- a/backend/onyx/chat/chat_state.py +++ b/backend/onyx/chat/chat_state.py @@ -94,6 +94,7 @@ def get_is_clarification(self) -> bool: def run_chat_loop_with_state_containers( func: Callable[..., None], + completion_callback: Callable[[ChatStateContainer], None], is_connected: Callable[[], bool], emitter: Emitter, state_container: ChatStateContainer, @@ -196,3 +197,12 @@ def run_with_exception_capture() -> None: # Skip waiting if user disconnected to exit quickly. if is_connected(): wait_on_background(thread) + try: + completion_callback(state_container) + except Exception as e: + emitter.emit( + Packet( + placement=Placement(turn_index=last_turn_index + 1), + obj=PacketException(type="error", exception=e), + ) + ) diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 36caaa256c6..5d1c6ada7d5 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -5,6 +5,7 @@ import re import traceback +from collections.abc import Callable from uuid import UUID from sqlalchemy.orm import Session @@ -45,6 +46,7 @@ from onyx.db.chat import get_or_create_root_message from onyx.db.chat import reserve_message_id from onyx.db.memory import get_memories +from onyx.db.models import ChatMessage from onyx.db.models import User from onyx.db.projects import get_project_token_count from onyx.db.projects import get_user_files_from_project @@ -532,6 +534,17 @@ def check_is_connected() -> bool: # External container allows non-streaming callers to access accumulated state state_container = external_state_container or ChatStateContainer() + def llm_loop_completion_callback( + state_container: ChatStateContainer, + ) -> None: + llm_loop_completion_handle( + state_container=state_container, + db_session=db_session, + chat_session_id=str(chat_session.id), + is_connected=check_is_connected, + assistant_message=assistant_response, + ) + # Run the LLM loop with explicit wrapper for stop signal handling # The wrapper runs run_llm_loop in a background thread and polls every 300ms # for stop signals. run_llm_loop itself doesn't know about stopping. @@ -547,6 +560,7 @@ def check_is_connected() -> bool: yield from run_chat_loop_with_state_containers( run_deep_research_llm_loop, + llm_loop_completion_callback, is_connected=check_is_connected, emitter=emitter, state_container=state_container, @@ -563,6 +577,7 @@ def check_is_connected() -> bool: else: yield from run_chat_loop_with_state_containers( run_llm_loop, + llm_loop_completion_callback, is_connected=check_is_connected, # Not passed through to run_llm_loop emitter=emitter, state_container=state_container, @@ -580,51 +595,6 @@ def check_is_connected() -> bool: chat_session_id=str(chat_session.id), ) - # Determine if stopped by user - completed_normally = check_is_connected() - if not completed_normally: - logger.debug(f"Chat session {chat_session.id} stopped by user") - - # Build final answer based on completion status - if completed_normally: - if state_container.answer_tokens is None: - raise RuntimeError( - "LLM run completed normally but did not return an answer." - ) - final_answer = state_container.answer_tokens - else: - # Stopped by user - append stop message - if state_container.answer_tokens: - final_answer = ( - state_container.answer_tokens - + " ... \n\nGeneration was stopped by the user." - ) - else: - final_answer = "Generation was stopped by the user." 
- - # Build citation_docs_info from accumulated citations in state container - citation_docs_info: list[CitationDocInfo] = [] - seen_citation_nums: set[int] = set() - for citation_num, search_doc in state_container.citation_to_doc.items(): - if citation_num not in seen_citation_nums: - seen_citation_nums.add(citation_num) - citation_docs_info.append( - CitationDocInfo( - search_doc=search_doc, - citation_number=citation_num, - ) - ) - - save_chat_turn( - message_text=final_answer, - reasoning_tokens=state_container.reasoning_tokens, - citation_docs_info=citation_docs_info, - tool_calls=state_container.tool_calls, - db_session=db_session, - assistant_message=assistant_response, - is_clarification=state_container.is_clarification, - ) - except ValueError as e: logger.exception("Failed to process chat message.") @@ -677,6 +647,57 @@ def check_is_connected() -> bool: return +def llm_loop_completion_handle( + state_container: ChatStateContainer, + is_connected: Callable[[], bool], + db_session: Session, + chat_session_id: str, + assistant_message: ChatMessage, +) -> None: + # Determine if stopped by user + completed_normally = is_connected() + # Build final answer based on completion status + if completed_normally: + if state_container.answer_tokens is None: + raise RuntimeError( + "LLM run completed normally but did not return an answer." + ) + final_answer = state_container.answer_tokens + else: + # Stopped by user - append stop message + logger.debug(f"Chat session {chat_session_id} stopped by user") + if state_container.answer_tokens: + final_answer = ( + state_container.answer_tokens + + " ... \n\nGeneration was stopped by the user." + ) + else: + final_answer = "The generation was stopped by the user." + + # Build citation_docs_info from accumulated citations in state container + citation_docs_info: list[CitationDocInfo] = [] + seen_citation_nums: set[int] = set() + for citation_num, search_doc in state_container.citation_to_doc.items(): + if citation_num not in seen_citation_nums: + seen_citation_nums.add(citation_num) + citation_docs_info.append( + CitationDocInfo( + search_doc=search_doc, + citation_number=citation_num, + ) + ) + + save_chat_turn( + message_text=final_answer, + reasoning_tokens=state_container.reasoning_tokens, + citation_docs_info=citation_docs_info, + tool_calls=state_container.tool_calls, + db_session=db_session, + assistant_message=assistant_message, + is_clarification=state_container.is_clarification, + ) + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, diff --git a/backend/tests/integration/common_utils/managers/chat.py b/backend/tests/integration/common_utils/managers/chat.py index e2e476e843c..8081ac498c7 100644 --- a/backend/tests/integration/common_utils/managers/chat.py +++ b/backend/tests/integration/common_utils/managers/chat.py @@ -164,6 +164,87 @@ def send_message( return streamed_response + @staticmethod + def send_message_with_disconnect( + chat_session_id: UUID, + message: str, + disconnect_after_packets: int = 0, + parent_message_id: int | None = None, + user_performing_action: DATestUser | None = None, + file_descriptors: list[FileDescriptor] | None = None, + search_doc_ids: list[int] | None = None, + retrieval_options: RetrievalDetails | None = None, + query_override: str | None = None, + regenerate: bool | None = None, + llm_override: LLMOverride | None = None, + prompt_override: PromptOverride | None = None, + alternate_assistant_id: int | None = None, + use_existing_user_message: bool = False, + 
        forced_tool_ids: list[int] | None = None,
+    ) -> None:
+        """
+        Send a message and simulate a client disconnect before the stream completes.
+
+        This is useful for testing how the server handles client disconnections
+        during streaming responses.
+
+        Args:
+            chat_session_id: The chat session ID
+            message: The message to send
+            disconnect_after_packets: Disconnect after receiving more than this many
+                packets. The connection is closed as soon as the threshold is
+                exceeded, so with the default of 0 the client disconnects right
+                after the first packet.
+            ... (other standard message parameters)
+
+        Returns:
+            None. Any data received before the disconnect is discarded; callers
+            should inspect the chat session afterwards (e.g. via get_chat_history)
+            to verify server-side handling.
+        """
+        chat_message_req = CreateChatMessageRequest(
+            chat_session_id=chat_session_id,
+            parent_message_id=parent_message_id,
+            message=message,
+            file_descriptors=file_descriptors or [],
+            search_doc_ids=search_doc_ids or [],
+            retrieval_options=retrieval_options,
+            rerank_settings=None,
+            query_override=query_override,
+            regenerate=regenerate,
+            llm_override=llm_override,
+            prompt_override=prompt_override,
+            alternate_assistant_id=alternate_assistant_id,
+            use_existing_user_message=use_existing_user_message,
+            forced_tool_ids=forced_tool_ids,
+        )
+
+        headers = (
+            user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS
+        )
+        cookies = user_performing_action.cookies if user_performing_action else None
+
+        packets_received = 0
+
+        with requests.post(
+            f"{API_SERVER_URL}/chat/send-message",
+            json=chat_message_req.model_dump(),
+            headers=headers,
+            stream=True,
+            cookies=cookies,
+        ) as response:
+            for line in response.iter_lines():
+                if not line:
+                    continue
+
+                packets_received += 1
+                if packets_received > disconnect_after_packets:
+                    break
+
+        return None
+
     @staticmethod
     def analyze_response(response: Response) -> StreamedResponse:
         response_data = cast(
diff --git a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
index 35129a3ef5c..f4a61aded20 100644
--- a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
+++ b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
@@ -1,8 +1,15 @@
+import time
+
+from onyx.configs.constants import MessageType
 from tests.integration.common_utils.managers.chat import ChatSessionManager
 from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
 from tests.integration.common_utils.test_models import DATestUser
 from tests.integration.conftest import DocumentBuilderType
 
+TERMINATED_RESPONSE_MESSAGE = (
+    "Response was terminated prior to completion, try regenerating."
+)
+
 
 def test_send_two_messages(basic_user: DATestUser) -> None:
     # Create a chat session
@@ -104,3 +111,59 @@ def test_send_message__basic_searches(
     # short doc should be more relevant and thus first
     assert response.top_documents[0].document_id == short_doc.id
     assert response.top_documents[1].document_id == long_doc.id
+
+
+def test_send_message_disconnect_and_cleanup(
+    reset: None, admin_user: DATestUser
+) -> None:
+    """
+    Test that when a client disconnects mid-stream:
+    1. The client sends a message and disconnects after receiving just 1 packet
+    2. The client later checks that the assistant message was still completed server-side
+
+    Note: There is an interim period (between the disconnect and the check) where we
+    expect to see some sort of placeholder 'loading' message.
+    """
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+
+    # Send a message and disconnect after receiving just 1 packet
+    ChatSessionManager.send_message_with_disconnect(
+        chat_session_id=test_chat_session.id,
+        message="What are some important events that happened today?",
+        user_performing_action=admin_user,
+        disconnect_after_packets=1,
+    )
+
+    # Poll once per second, for up to a minute, until the assistant message is finalized
+    increment_seconds = 1
+    max_seconds = 60
+    msg = TERMINATED_RESPONSE_MESSAGE
+
+    for _ in range(max_seconds // increment_seconds):
+        time.sleep(increment_seconds)
+
+        # Get the chat history
+        chat_history = ChatSessionManager.get_chat_history(
+            chat_session=test_chat_session,
+            user_performing_action=admin_user,
+        )
+
+        # Find the assistant message
+        assistant_message = None
+        for chat_obj in chat_history:
+            if chat_obj.message_type == MessageType.ASSISTANT:
+                assistant_message = chat_obj
+                break
+
+        assert assistant_message is not None, "Assistant message should exist"
+        msg = assistant_message.message
+
+        if msg != TERMINATED_RESPONSE_MESSAGE:
+            break
+
+    assert msg != TERMINATED_RESPONSE_MESSAGE, (
+        "Assistant message should no longer be the terminated response message after cleanup, "
+        f"got: {msg}"
+    )
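For context, here is a minimal, self-contained sketch of the completion-callback pattern the chat_state.py hunk introduces: the loop runs in a background thread, the caller polls for disconnects, and the callback is invoked either way, which is what lets the accumulated answer still be persisted in the disconnect test above. All names below (StateContainer, run_loop_with_completion, fake_llm_loop) are illustrative stand-ins, not the real ChatStateContainer / run_chat_loop_with_state_containers APIs.

```python
# Minimal sketch of the completion-callback pattern; assumed, simplified names only.
import threading
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class StateContainer:
    # Stand-in for ChatStateContainer: accumulates the partial answer.
    answer_tokens: str | None = None


def run_loop_with_completion(
    func: Callable[[StateContainer], None],
    completion_callback: Callable[[StateContainer], None],
    is_connected: Callable[[], bool],
    state: StateContainer,
) -> None:
    # Run the "LLM loop" in a background thread so the caller can poll for disconnects.
    thread = threading.Thread(target=func, args=(state,))
    thread.start()
    while thread.is_alive():
        thread.join(timeout=0.3)  # poll roughly every 300ms, like the wrapper in the diff
        if not is_connected():
            break  # stop waiting early, but still invoke the completion callback below

    # The callback runs whether or not the client is still connected; in the real code
    # this is where the final answer, citations, and tool calls get persisted.
    try:
        completion_callback(state)
    except Exception as exc:
        # The real wrapper emits an error Packet here instead of printing.
        print(f"completion callback failed: {exc}")


if __name__ == "__main__":
    def fake_llm_loop(s: StateContainer) -> None:
        s.answer_tokens = "partial answer"

    run_loop_with_completion(
        func=fake_llm_loop,
        completion_callback=lambda s: print(f"saving: {s.answer_tokens!r}"),
        is_connected=lambda: False,  # simulate a client that has already disconnected
        state=StateContainer(),
    )
```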