fix(chat): post llm loop callback #7309
Changes from 7 commits

backend/onyx/chat/process_message.py

@@ -5,6 +5,7 @@
import re
import traceback
from collections.abc import Callable
from uuid import UUID

from sqlalchemy.orm import Session

@@ -45,6 +46,7 @@
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import User
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project

@@ -532,6 +534,17 @@ def check_is_connected() -> bool:
    # External container allows non-streaming callers to access accumulated state
    state_container = external_state_container or ChatStateContainer()

    def llm_loop_completion_callback(
        state_container: ChatStateContainer,
    ) -> None:
        llm_loop_completion_handle(
            state_container=state_container,
            db_session=db_session,
            chat_session_id=str(chat_session.id),
            is_connected=check_is_connected,
            assistant_message=assistant_response,
        )

    # Run the LLM loop with explicit wrapper for stop signal handling
    # The wrapper runs run_llm_loop in a background thread and polls every 300ms
    # for stop signals. run_llm_loop itself doesn't know about stopping.
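
For readers who want the mechanics: below is a minimal sketch of the polling pattern described in the comment above. It is not the actual `run_chat_loop_with_state_containers` implementation; the queue-based hand-off and the helper name are assumptions.

```python
import queue
import threading
from collections.abc import Callable, Generator
from typing import Any

STOP_POLL_INTERVAL_SECONDS = 0.3  # the 300ms poll interval described in the comment above


def run_loop_with_stop_polling(
    llm_loop: Callable[..., Generator[Any, None, None]],
    is_connected: Callable[[], bool],
    **loop_kwargs: Any,
) -> Generator[Any, None, None]:
    """Sketch: run the LLM loop in a background thread and poll for stop signals.

    The loop itself knows nothing about stopping; the wrapper simply stops
    draining its output once is_connected() reports a stop or disconnect.
    """
    packets: queue.Queue = queue.Queue()
    done = threading.Event()

    def _worker() -> None:
        try:
            for packet in llm_loop(**loop_kwargs):
                packets.put(packet)
        finally:
            done.set()

    threading.Thread(target=_worker, daemon=True).start()

    while not (done.is_set() and packets.empty()):
        try:
            yield packets.get(timeout=STOP_POLL_INTERVAL_SECONDS)
        except queue.Empty:
            if not is_connected():
                # Stop requested or client disconnected: stop yielding packets.
                break
```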

@@ -547,6 +560,7 @@ def check_is_connected() -> bool:
        yield from run_chat_loop_with_state_containers(
            run_deep_research_llm_loop,
            llm_loop_completion_callback,
            is_connected=check_is_connected,
            emitter=emitter,
            state_container=state_container,

@@ -563,6 +577,7 @@ def check_is_connected() -> bool:
    else:
        yield from run_chat_loop_with_state_containers(
            run_llm_loop,
            llm_loop_completion_callback,
            is_connected=check_is_connected,  # Not passed through to run_llm_loop
            emitter=emitter,
            state_container=state_container,

@@ -580,51 +595,6 @@ def check_is_connected() -> bool:
            chat_session_id=str(chat_session.id),
        )
        # Determine if stopped by user
        completed_normally = check_is_connected()
        if not completed_normally:
            logger.debug(f"Chat session {chat_session.id} stopped by user")

        # Build final answer based on completion status
        if completed_normally:
            if state_container.answer_tokens is None:
                raise RuntimeError(
                    "LLM run completed normally but did not return an answer."
                )
            final_answer = state_container.answer_tokens
        else:
            # Stopped by user - append stop message
            if state_container.answer_tokens:
                final_answer = (
                    state_container.answer_tokens
                    + " ... \n\nGeneration was stopped by the user."
                )
            else:
                final_answer = "Generation was stopped by the user."

        # Build citation_docs_info from accumulated citations in state container
        citation_docs_info: list[CitationDocInfo] = []
        seen_citation_nums: set[int] = set()
        for citation_num, search_doc in state_container.citation_to_doc.items():
            if citation_num not in seen_citation_nums:
                seen_citation_nums.add(citation_num)
                citation_docs_info.append(
                    CitationDocInfo(
                        search_doc=search_doc,
                        citation_number=citation_num,
                    )
                )

        save_chat_turn(
            message_text=final_answer,
            reasoning_tokens=state_container.reasoning_tokens,
            citation_docs_info=citation_docs_info,
            tool_calls=state_container.tool_calls,
            db_session=db_session,
            assistant_message=assistant_response,
            is_clarification=state_container.is_clarification,
        )

    except ValueError as e:
        logger.exception("Failed to process chat message.")

@@ -677,6 +647,57 @@ def check_is_connected() -> bool:
        return

def llm_loop_completion_handle(
    state_container: ChatStateContainer,
    is_connected: Callable[[], bool],
    db_session: Session,
    chat_session_id: str,
    assistant_message: ChatMessage,
) -> None:
    # Determine if stopped by user
    completed_normally = is_connected()
    # Build final answer based on completion status
    if completed_normally:
        if state_container.answer_tokens is None:
            raise RuntimeError(
                "LLM run completed normally but did not return an answer."
            )

Comment on lines +660 to +664

Contributor:
[P2] When the LLM completes normally but `answer_tokens` is None, this raises RuntimeError. However, since this runs in the completion callback (which executes in the `finally` block), this exception will propagate differently than the original design where it was caught by the exception handlers in `handle_stream_message_objects`. The error handling path has changed.
(Path: backend/onyx/chat/process_message.py, lines 678-682)

Contributor (Author):
This is handled in the layer that runs this func.
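
To make the author's reply concrete, here is a minimal sketch of what "handled in the layer that runs this func" could look like, assuming the runner wraps the completion callback in its own finally/except. The runner's real implementation is not part of this diff, so treat the body as illustrative only.

```python
import logging
from collections.abc import Callable, Generator
from typing import Any

logger = logging.getLogger(__name__)


def run_chat_loop_with_state_containers_sketch(
    llm_loop: Callable[..., Generator[Any, None, None]],
    completion_callback: Callable[[Any], None],
    *,
    is_connected: Callable[[], bool],
    emitter: Any,
    state_container: Any,
    **kwargs: Any,
) -> Generator[Any, None, None]:
    # is_connected drives the stop polling shown in the earlier sketch; it is
    # accepted here only to mirror the call sites in the diff.
    try:
        yield from llm_loop(emitter=emitter, state_container=state_container, **kwargs)
    finally:
        try:
            # Persist the turn whether the loop finished, was stopped, or raised.
            completion_callback(state_container)
        except Exception:
            # A failure here (such as the RuntimeError flagged above) is logged and
            # contained so it cannot mask an in-flight exception from the loop itself.
            logger.exception("LLM loop completion callback failed")
```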
        final_answer = state_container.answer_tokens
    else:
        # Stopped by user - append stop message
        logger.debug(f"Chat session {chat_session_id} stopped by user")
        if state_container.answer_tokens:
            final_answer = (
                state_container.answer_tokens
                + " ... \n\nGeneration was stopped by the user."
            )
        else:
            final_answer = "The generation was stopped by the user."

    # Build citation_docs_info from accumulated citations in state container
    citation_docs_info: list[CitationDocInfo] = []
    seen_citation_nums: set[int] = set()
    for citation_num, search_doc in state_container.citation_to_doc.items():
        if citation_num not in seen_citation_nums:
            seen_citation_nums.add(citation_num)
            citation_docs_info.append(
                CitationDocInfo(
                    search_doc=search_doc,
                    citation_number=citation_num,
                )
            )

    save_chat_turn(
        message_text=final_answer,
        reasoning_tokens=state_container.reasoning_tokens,
        citation_docs_info=citation_docs_info,
        tool_calls=state_container.tool_calls,
        db_session=db_session,
        assistant_message=assistant_message,
        is_clarification=state_container.is_clarification,
    )


def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
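
Before moving to the second file: `ChatStateContainer` itself is not defined in this diff. A rough sketch of its shape, inferred only from the attributes the code above reads (field names come from the diff; types and defaults are guesses):

```python
from dataclasses import dataclass, field
from typing import Any


# Inferred shape of ChatStateContainer, based only on the attributes read in the
# hunks above; the real class is defined elsewhere in the codebase and may differ.
@dataclass
class ChatStateContainerSketch:
    answer_tokens: str | None = None       # accumulated answer text from the LLM loop
    reasoning_tokens: str | None = None    # accumulated reasoning text, if any
    citation_to_doc: dict[int, Any] = field(default_factory=dict)  # citation number -> search doc
    tool_calls: list[Any] = field(default_factory=list)            # tool calls made during the loop
    is_clarification: bool = False         # whether the turn ended as a clarification question
```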

Second changed file: integration-test chat helper

@@ -164,6 +164,87 @@ def send_message(
        return streamed_response

    @staticmethod
    def send_message_with_disconnect(
        chat_session_id: UUID,
        message: str,
        disconnect_after_packets: int = 0,
        parent_message_id: int | None = None,
        user_performing_action: DATestUser | None = None,
        file_descriptors: list[FileDescriptor] | None = None,
        search_doc_ids: list[int] | None = None,
        retrieval_options: RetrievalDetails | None = None,
        query_override: str | None = None,
        regenerate: bool | None = None,
        llm_override: LLMOverride | None = None,
        prompt_override: PromptOverride | None = None,
        alternate_assistant_id: int | None = None,
        use_existing_user_message: bool = False,
        forced_tool_ids: list[int] | None = None,
    ) -> None:
        """
        Send a message and simulate client disconnect before stream completes.

        This is useful for testing how the server handles client disconnections
        during streaming responses.

        Args:
            chat_session_id: The chat session ID
            message: The message to send
            disconnect_after_packets: Disconnect after receiving this many packets.
                If None, disconnect_after_type must be specified.
            disconnect_after_type: Disconnect after receiving a packet of this type

Contributor:
P2: Docstring references non-existent parameter `disconnect_after_type`.
                (e.g., "message_start", "search_tool_start"). If None,
                disconnect_after_packets must be specified.
            ... (other standard message parameters)

        Returns:
            StreamedResponse containing data received before disconnect,
            with is_disconnected=True flag set.

Comment on lines +202 to +203

Contributor:
P2: Docstring incorrectly states the function returns a `StreamedResponse`; the method is annotated `-> None` and returns nothing.
        """
        chat_message_req = CreateChatMessageRequest(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            message=message,
            file_descriptors=file_descriptors or [],
            search_doc_ids=search_doc_ids or [],
            retrieval_options=retrieval_options,
            rerank_settings=None,
            query_override=query_override,
            regenerate=regenerate,
            llm_override=llm_override,
            prompt_override=prompt_override,
            alternate_assistant_id=alternate_assistant_id,
            use_existing_user_message=use_existing_user_message,
            forced_tool_ids=forced_tool_ids,
        )

        headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        cookies = user_performing_action.cookies if user_performing_action else None

        packets_received = 0

        with requests.post(
            f"{API_SERVER_URL}/chat/send-message",
            json=chat_message_req.model_dump(),
            headers=headers,
            stream=True,
            cookies=cookies,
        ) as response:
            for line in response.iter_lines():
                if not line:
                    continue

                packets_received += 1
                if packets_received > disconnect_after_packets:
                    break

        return None

    @staticmethod
    def analyze_response(response: Response) -> StreamedResponse:
        response_data = cast(
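
A possible usage sketch for the new helper in an integration test. Only `send_message_with_disconnect` and its parameters come from this PR; the import path, class name, test name, and fixtures below are assumptions.

```python
from uuid import UUID

# Assumed import path and class name; the class hosting send_message_with_disconnect
# is not shown in this excerpt.
from tests.integration.common_utils.managers.chat import ChatSessionManager


def test_partial_answer_saved_after_disconnect(chat_session_id: UUID, admin_user) -> None:
    # Simulate the client dropping the stream after a handful of packets.
    ChatSessionManager.send_message_with_disconnect(
        chat_session_id=chat_session_id,
        message="Write a very long answer so we can disconnect mid-stream.",
        disconnect_after_packets=5,
        user_performing_action=admin_user,
    )

    # With the post-LLM-loop callback in place, llm_loop_completion_handle should
    # still run and persist the partial answer plus the stop notice, so fetching
    # the session history afterwards is expected to show an assistant message.
```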