diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 44ff5680217..535acef439a 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -102,6 +102,14 @@ class MessageResponseIDInfo(BaseModel): reserved_assistant_message_id: int +class MultiModelMessageResponseIDInfo(BaseModel): + """Response info for multi-model chat: one user message, multiple assistant messages.""" + + user_message_id: int | None + reserved_assistant_message_ids: list[int] # One per model + model_names: list[str] # Display names for UI + + class StreamingError(BaseModel): error: str stack_trace: str | None = None @@ -200,6 +208,7 @@ class LLMMetricsContainer(BaseModel): Packet | StreamStopInfo | MessageResponseIDInfo + | MultiModelMessageResponseIDInfo | StreamingError | UserKnowledgeFilePacket | CreateChatSessionID diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 44ce1eb1d19..c5250d2a18a 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -5,6 +5,9 @@ import re import traceback +from collections.abc import Callable +from collections.abc import Generator +from typing import Any from uuid import UUID from sqlalchemy.orm import Session @@ -17,6 +20,7 @@ from onyx.chat.chat_utils import get_custom_agent_prompt from onyx.chat.chat_utils import is_last_assistant_message_clarification from onyx.chat.chat_utils import load_all_chat_files +from onyx.chat.emitter import Emitter from onyx.chat.emitter import get_default_emitter from onyx.chat.llm_loop import run_llm_loop from onyx.chat.models import AnswerStream @@ -26,6 +30,7 @@ from onyx.chat.models import CreateChatSessionID from onyx.chat.models import ExtractedProjectFiles from onyx.chat.models import MessageResponseIDInfo +from onyx.chat.models import MultiModelMessageResponseIDInfo from onyx.chat.models import ProjectFileMetadata from onyx.chat.models import ProjectSearchConfig from onyx.chat.models import StreamingError @@ -63,10 +68,13 @@ from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE from onyx.server.query_and_chat.models import CreateChatMessageRequest from onyx.server.query_and_chat.models import SendMessageRequest +from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from onyx.server.query_and_chat.streaming_models import CitationInfo +from onyx.server.query_and_chat.streaming_models import OverallStop from onyx.server.query_and_chat.streaming_models import Packet +from onyx.server.query_and_chat.streaming_models import PacketException from onyx.server.usage_limits import check_llm_cost_limit_for_provider from onyx.tools.constants import SEARCH_TOOL_ID from onyx.tools.interface import Tool @@ -272,6 +280,159 @@ def _get_project_search_availability( ) +def _run_multi_model_chat_loops( + llms: list[LLM], + model_names: list[str], + emitter: "Emitter", + state_containers: list[ChatStateContainer], + check_is_connected: Callable[[], bool], + simple_chat_history: list, + tools: list[Tool], + custom_agent_prompt: str | None, + extracted_project_files: "ExtractedProjectFiles", + persona: Any, + memories: list, + token_counter: Callable, + db_session: Session, + forced_tool_id: int | None, + user_identity: "LLMUserIdentity", + chat_session_id: str, +) -> Generator[Packet, None, None]: + """ + Run multiple LLM loops in parallel for multi-model chat. 
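+    A worker thread is started per model and all packets are funneled through a
+    shared queue; a ``None`` sentinel marks each model's completion, and a failure
+    in one model is surfaced as an error stop packet while the others keep streaming.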
+ Each model's packets are tagged with a model_index in their placement. + """ + import queue + import threading + + # Shared queue for all model outputs + shared_queue: queue.Queue[tuple[int, Packet | None]] = queue.Queue() + [threading.Event() for _ in llms] + + def run_model_loop( + model_index: int, llm: LLM, state_container: ChatStateContainer + ) -> None: + """Run a single model's LLM loop and emit packets to shared queue with model_index.""" + try: + # Create a model-specific emitter that tags packets with model_index + class ModelIndexEmitter: + def __init__(self, model_idx: int, shared_q: queue.Queue): + self.model_idx = model_idx + self.shared_q = shared_q + self.bus = queue.Queue() # Compatibility with existing code + + def emit(self, packet: Packet) -> None: + # Clone the placement with model_index added + new_placement = Placement( + turn_index=packet.placement.turn_index, + tab_index=packet.placement.tab_index, + sub_turn_index=packet.placement.sub_turn_index, + model_index=self.model_idx, + ) + tagged_packet = Packet(placement=new_placement, obj=packet.obj) + self.shared_q.put((self.model_idx, tagged_packet)) + + model_emitter = ModelIndexEmitter(model_index, shared_queue) + + # Run the LLM loop for this model + run_llm_loop( + model_emitter, + simple_chat_history=simple_chat_history, + tools=tools, + custom_agent_prompt=custom_agent_prompt, + project_files=extracted_project_files, + persona=persona, + memories=memories, + llm=llm, + token_counter=token_counter, + db_session=db_session, + forced_tool_id=forced_tool_id, + user_identity=user_identity, + chat_session_id=chat_session_id, + state_container=state_container, + ) + + # Signal completion for this model + shared_queue.put((model_index, None)) + + except Exception as e: + logger.exception( + f"Error in model {model_index} ({model_names[model_index]}): {e}" + ) + # Emit error packet for this model + error_packet = Packet( + placement=Placement(turn_index=0, model_index=model_index), + obj=PacketException(type="error", exception=e), + ) + shared_queue.put((model_index, error_packet)) + shared_queue.put((model_index, None)) # Signal completion + + # Start all model loops in parallel + threads = [] + for i, (llm, state_container) in enumerate(zip(llms, state_containers)): + thread = threading.Thread(target=run_model_loop, args=(i, llm, state_container)) + thread.start() + threads.append(thread) + + # Track completion status for each model + completed = [False] * len(llms) + last_turn_indices = [0] * len(llms) + + try: + while not all(completed): + try: + model_idx, packet = shared_queue.get(timeout=0.3) + + if packet is None: + # Model completed + completed[model_idx] = True + continue + + # Track turn index for stop packet + if packet.placement.turn_index > last_turn_indices[model_idx]: + last_turn_indices[model_idx] = packet.placement.turn_index + + # Check for stop packet or exception + if isinstance(packet.obj, OverallStop): + yield packet + completed[model_idx] = True + elif isinstance(packet.obj, PacketException): + # Don't raise - just emit error for this model and continue others + yield Packet( + placement=Placement( + turn_index=last_turn_indices[model_idx], + model_index=model_idx, + ), + obj=OverallStop(type="stop", stop_reason="error"), + ) + completed[model_idx] = True + else: + yield packet + + except queue.Empty: + # Check for user cancellation + if not check_is_connected(): + # Emit stop packets for all incomplete models + for i, is_completed in enumerate(completed): + if not is_completed: + yield Packet( 
+ placement=Placement( + turn_index=last_turn_indices[i] + 1, + model_index=i, + ), + obj=OverallStop( + type="stop", stop_reason="user_cancelled" + ), + ) + completed[i] = True + break + + finally: + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=5.0) + + def handle_stream_message_objects( new_msg_req: SendMessageRequest, user: User | None, @@ -502,19 +663,6 @@ def handle_stream_message_objects( # TODO Need to think of some way to support selected docs from the sidebar - # Reserve a message id for the assistant response for frontend to track packets - assistant_response = reserve_message_id( - db_session=db_session, - chat_session_id=chat_session.id, - parent_message=user_message.id, - message_type=MessageType.ASSISTANT, - ) - - yield MessageResponseIDInfo( - user_message_id=user_message.id, - reserved_assistant_message_id=assistant_response.id, - ) - # Convert the chat history into a simple format that is free of any DB objects # and is easy to parse for the agent loop simple_chat_history = convert_chat_history( @@ -536,51 +684,66 @@ def handle_stream_message_objects( def check_is_connected() -> bool: return check_stop_signal(chat_session.id, redis_client) - # Use external state container if provided, otherwise create internal one - # External container allows non-streaming callers to access accumulated state - state_container = external_state_container or ChatStateContainer() - - # Run the LLM loop with explicit wrapper for stop signal handling - # The wrapper runs run_llm_loop in a background thread and polls every 300ms - # for stop signals. run_llm_loop itself doesn't know about stopping. - # Note: DB session is not thread safe but nothing else uses it and the - # reference is passed directly so it's ok. - if new_msg_req.deep_research: - if chat_session.project_id: - raise RuntimeError("Deep research is not supported for projects") - - # Skip clarification if the last assistant message was a clarification - # (user has already responded to a clarification question) - skip_clarification = is_last_assistant_message_clarification(chat_history) - - yield from run_chat_loop_with_state_containers( - run_deep_research_llm_loop, - is_connected=check_is_connected, - emitter=emitter, - state_container=state_container, - simple_chat_history=simple_chat_history, - tools=tools, - custom_agent_prompt=custom_agent_prompt, - llm=llm, - token_counter=token_counter, - db_session=db_session, - skip_clarification=skip_clarification, - user_identity=user_identity, - chat_session_id=str(chat_session.id), + # Check for multi-model chat mode (2 or more models) + if new_msg_req.llm_overrides and len(new_msg_req.llm_overrides) >= 2: + # Multi-model chat: run N models in parallel + if new_msg_req.deep_research: + raise RuntimeError( + "Deep research is not supported for multi-model chat" + ) + + # Create LLM instances for each model override + llms: list[LLM] = [] + model_names: list[str] = [] + for llm_override in new_msg_req.llm_overrides: + model_llm = get_llm_for_persona( + persona=persona, + user=user, + llm_override=llm_override, + additional_headers=litellm_additional_headers, + long_term_logger=long_term_logger, + ) + llms.append(model_llm) + model_names.append( + llm_override.model_version + or llm_override.model_provider + or f"Model {len(llms)}" + ) + + # Reserve message IDs for all N assistant responses (all have same parent) + num_models = len(new_msg_req.llm_overrides) + assistant_responses = [] + for _ in range(num_models): + assistant_response = 
reserve_message_id( + db_session=db_session, + chat_session_id=chat_session.id, + parent_message=user_message.id, + message_type=MessageType.ASSISTANT, + ) + assistant_responses.append(assistant_response) + + yield MultiModelMessageResponseIDInfo( + user_message_id=user_message.id, + reserved_assistant_message_ids=[ar.id for ar in assistant_responses], + model_names=model_names, ) - else: - yield from run_chat_loop_with_state_containers( - run_llm_loop, - is_connected=check_is_connected, # Not passed through to run_llm_loop + + # Create state containers for each model + state_containers = [ChatStateContainer() for _ in range(num_models)] + + # Run all N models in parallel + yield from _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, emitter=emitter, - state_container=state_container, + state_containers=state_containers, + check_is_connected=check_is_connected, simple_chat_history=simple_chat_history, tools=tools, custom_agent_prompt=custom_agent_prompt, - project_files=extracted_project_files, + extracted_project_files=extracted_project_files, persona=persona, memories=memories, - llm=llm, token_counter=token_counter, db_session=db_session, forced_tool_id=forced_tool_id, @@ -588,50 +751,163 @@ def check_is_connected() -> bool: chat_session_id=str(chat_session.id), ) - # Determine if stopped by user - completed_normally = check_is_connected() - if not completed_normally: - logger.debug(f"Chat session {chat_session.id} stopped by user") - - # Build final answer based on completion status - if completed_normally: - if state_container.answer_tokens is None: - raise RuntimeError( - "LLM run completed normally but did not return an answer." + # Save each model's response + completed_normally = check_is_connected() + for i, (assistant_response, state_container) in enumerate( + zip(assistant_responses, state_containers) + ): + if completed_normally: + if state_container.answer_tokens is None: + final_answer = ( + f"Model {model_names[i]} did not return an answer." + ) + else: + final_answer = state_container.answer_tokens + else: + if state_container.answer_tokens: + final_answer = ( + state_container.answer_tokens + + " ... The generation was stopped by the user here." + ) + else: + final_answer = "The generation was stopped by the user." + + # Build citation_docs_info from accumulated citations + citation_docs_info: list[CitationDocInfo] = [] + seen_citation_nums: set[int] = set() + for citation_num, search_doc in state_container.citation_to_doc.items(): + if citation_num not in seen_citation_nums: + seen_citation_nums.add(citation_num) + citation_docs_info.append( + CitationDocInfo( + search_doc=search_doc, + citation_number=citation_num, + ) + ) + + save_chat_turn( + message_text=final_answer, + reasoning_tokens=state_container.reasoning_tokens, + citation_docs_info=citation_docs_info, + tool_calls=state_container.tool_calls, + db_session=db_session, + assistant_message=assistant_response, + is_clarification=state_container.is_clarification, ) - final_answer = state_container.answer_tokens + else: - # Stopped by user - append stop message - if state_container.answer_tokens: - final_answer = ( - state_container.answer_tokens - + " ... The generation was stopped by the user here." 
+ # Single-model chat (existing flow) + # Reserve a message id for the assistant response for frontend to track packets + assistant_response = reserve_message_id( + db_session=db_session, + chat_session_id=chat_session.id, + parent_message=user_message.id, + message_type=MessageType.ASSISTANT, + ) + + yield MessageResponseIDInfo( + user_message_id=user_message.id, + reserved_assistant_message_id=assistant_response.id, + ) + + # Use external state container if provided, otherwise create internal one + # External container allows non-streaming callers to access accumulated state + state_container = external_state_container or ChatStateContainer() + + # Run the LLM loop with explicit wrapper for stop signal handling + # The wrapper runs run_llm_loop in a background thread and polls every 300ms + # for stop signals. run_llm_loop itself doesn't know about stopping. + # Note: DB session is not thread safe but nothing else uses it and the + # reference is passed directly so it's ok. + if new_msg_req.deep_research: + if chat_session.project_id: + raise RuntimeError("Deep research is not supported for projects") + + # Skip clarification if the last assistant message was a clarification + # (user has already responded to a clarification question) + skip_clarification = is_last_assistant_message_clarification( + chat_history + ) + + yield from run_chat_loop_with_state_containers( + run_deep_research_llm_loop, + is_connected=check_is_connected, + emitter=emitter, + state_container=state_container, + simple_chat_history=simple_chat_history, + tools=tools, + custom_agent_prompt=custom_agent_prompt, + llm=llm, + token_counter=token_counter, + db_session=db_session, + skip_clarification=skip_clarification, + user_identity=user_identity, + chat_session_id=str(chat_session.id), ) else: - final_answer = "The generation was stopped by the user." - - # Build citation_docs_info from accumulated citations in state container - citation_docs_info: list[CitationDocInfo] = [] - seen_citation_nums: set[int] = set() - for citation_num, search_doc in state_container.citation_to_doc.items(): - if citation_num not in seen_citation_nums: - seen_citation_nums.add(citation_num) - citation_docs_info.append( - CitationDocInfo( - search_doc=search_doc, - citation_number=citation_num, - ) + yield from run_chat_loop_with_state_containers( + run_llm_loop, + is_connected=check_is_connected, # Not passed through to run_llm_loop + emitter=emitter, + state_container=state_container, + simple_chat_history=simple_chat_history, + tools=tools, + custom_agent_prompt=custom_agent_prompt, + project_files=extracted_project_files, + persona=persona, + memories=memories, + llm=llm, + token_counter=token_counter, + db_session=db_session, + forced_tool_id=forced_tool_id, + user_identity=user_identity, + chat_session_id=str(chat_session.id), ) - save_chat_turn( - message_text=final_answer, - reasoning_tokens=state_container.reasoning_tokens, - citation_docs_info=citation_docs_info, - tool_calls=state_container.tool_calls, - db_session=db_session, - assistant_message=assistant_response, - is_clarification=state_container.is_clarification, - ) + # Determine if stopped by user + completed_normally = check_is_connected() + if not completed_normally: + logger.debug(f"Chat session {chat_session.id} stopped by user") + + # Build final answer based on completion status + if completed_normally: + if state_container.answer_tokens is None: + raise RuntimeError( + "LLM run completed normally but did not return an answer." 
+ ) + final_answer = state_container.answer_tokens + else: + # Stopped by user - append stop message + if state_container.answer_tokens: + final_answer = ( + state_container.answer_tokens + + " ... The generation was stopped by the user here." + ) + else: + final_answer = "The generation was stopped by the user." + + # Build citation_docs_info from accumulated citations in state container + citation_docs_info: list[CitationDocInfo] = [] + seen_citation_nums: set[int] = set() + for citation_num, search_doc in state_container.citation_to_doc.items(): + if citation_num not in seen_citation_nums: + seen_citation_nums.add(citation_num) + citation_docs_info.append( + CitationDocInfo( + search_doc=search_doc, + citation_number=citation_num, + ) + ) + + save_chat_turn( + message_text=final_answer, + reasoning_tokens=state_container.reasoning_tokens, + citation_docs_info=citation_docs_info, + tool_calls=state_container.tool_calls, + db_session=db_session, + assistant_message=assistant_response, + is_clarification=state_container.is_clarification, + ) except ValueError as e: logger.exception("Failed to process chat message.") diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index b23d71b9722..fe0cbf9053d 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -83,6 +83,8 @@ class SendMessageRequest(BaseModel): message: str llm_override: LLMOverride | None = None + # For multi-model chat: list of LLM overrides to compare (2-3 models supported) + llm_overrides: list[LLMOverride] | None = None allowed_tool_ids: list[int] | None = None forced_tool_id: int | None = None diff --git a/backend/onyx/server/query_and_chat/placement.py b/backend/onyx/server/query_and_chat/placement.py index b75e57fc140..cd1f07e9b90 100644 --- a/backend/onyx/server/query_and_chat/placement.py +++ b/backend/onyx/server/query_and_chat/placement.py @@ -8,3 +8,5 @@ class Placement(BaseModel): tab_index: int = 0 # Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later sub_turn_index: int | None = None + # For multi-model chat: identifies which model's response this packet belongs to (0, 1, or 2) + model_index: int | None = None diff --git a/backend/tests/unit/onyx/chat/test_multi_model_chat.py b/backend/tests/unit/onyx/chat/test_multi_model_chat.py new file mode 100644 index 00000000000..f8c190a3bd2 --- /dev/null +++ b/backend/tests/unit/onyx/chat/test_multi_model_chat.py @@ -0,0 +1,672 @@ +"""Tests for multi-model chat functionality in process_message.py.""" + +import threading +import time +from typing import Any +from unittest.mock import MagicMock +from unittest.mock import patch + +from onyx.chat.models import MultiModelMessageResponseIDInfo +from onyx.llm.override_models import LLMOverride +from onyx.server.query_and_chat.models import SendMessageRequest +from onyx.server.query_and_chat.placement import Placement +from onyx.server.query_and_chat.streaming_models import AgentResponseDelta +from onyx.server.query_and_chat.streaming_models import OverallStop +from onyx.server.query_and_chat.streaming_models import Packet + + +# ============================================================================= +# Model Tests +# ============================================================================= + + +class TestMultiModelMessageResponseIDInfo: + """Tests for MultiModelMessageResponseIDInfo model.""" + + def test_creation_with_valid_data(self) -> None: + """Test 
creating the model with valid data.""" + info = MultiModelMessageResponseIDInfo( + user_message_id=1, + reserved_assistant_message_ids=[10, 11, 12], + model_names=["GPT-4", "Claude", "Gemini"], + ) + assert info.user_message_id == 1 + assert info.reserved_assistant_message_ids == [10, 11, 12] + assert info.model_names == ["GPT-4", "Claude", "Gemini"] + + def test_creation_with_two_models(self) -> None: + """Test creating the model with 2 models.""" + info = MultiModelMessageResponseIDInfo( + user_message_id=1, + reserved_assistant_message_ids=[10, 11], + model_names=["GPT-4", "Claude"], + ) + assert len(info.reserved_assistant_message_ids) == 2 + assert len(info.model_names) == 2 + + def test_creation_with_null_user_message_id(self) -> None: + """Test creating the model with null user_message_id.""" + info = MultiModelMessageResponseIDInfo( + user_message_id=None, + reserved_assistant_message_ids=[10, 11], + model_names=["Model A", "Model B"], + ) + assert info.user_message_id is None + + def test_serialization(self) -> None: + """Test JSON serialization of the model.""" + info = MultiModelMessageResponseIDInfo( + user_message_id=1, + reserved_assistant_message_ids=[10, 11, 12], + model_names=["GPT-4", "Claude", "Gemini"], + ) + data = info.model_dump() + assert data["user_message_id"] == 1 + assert data["reserved_assistant_message_ids"] == [10, 11, 12] + assert data["model_names"] == ["GPT-4", "Claude", "Gemini"] + + def test_deserialization(self) -> None: + """Test JSON deserialization of the model.""" + data = { + "user_message_id": 5, + "reserved_assistant_message_ids": [20, 21], + "model_names": ["Model X", "Model Y"], + } + info = MultiModelMessageResponseIDInfo(**data) + assert info.user_message_id == 5 + assert info.reserved_assistant_message_ids == [20, 21] + + +class TestPlacementWithModelIndex: + """Tests for Placement with model_index field.""" + + def test_placement_with_model_index(self) -> None: + """Test creating Placement with model_index.""" + placement = Placement(turn_index=0, model_index=1) + assert placement.turn_index == 0 + assert placement.model_index == 1 + + def test_placement_without_model_index(self) -> None: + """Test creating Placement without model_index (defaults to None).""" + placement = Placement(turn_index=0) + assert placement.model_index is None + + def test_placement_with_all_fields(self) -> None: + """Test creating Placement with all fields including model_index.""" + placement = Placement( + turn_index=2, + tab_index=1, + sub_turn_index=0, + model_index=2, + ) + assert placement.turn_index == 2 + assert placement.tab_index == 1 + assert placement.sub_turn_index == 0 + assert placement.model_index == 2 + + def test_placement_serialization_with_model_index(self) -> None: + """Test serialization includes model_index.""" + placement = Placement(turn_index=0, model_index=0) + data = placement.model_dump() + assert "model_index" in data + assert data["model_index"] == 0 + + +class TestSendMessageRequestWithLLMOverrides: + """Tests for SendMessageRequest with llm_overrides field.""" + + def test_send_message_request_with_llm_overrides(self) -> None: + """Test creating SendMessageRequest with llm_overrides list.""" + overrides = [ + LLMOverride(model_provider="openai", model_version="gpt-4"), + LLMOverride(model_provider="anthropic", model_version="claude-3"), + ] + # Note: SendMessageRequest requires either chat_session_id or chat_session_info + # The validator auto-creates chat_session_info if neither is provided + request = SendMessageRequest( + 
message="Hello", + llm_overrides=overrides, + ) + assert request.llm_overrides is not None + assert len(request.llm_overrides) == 2 + + def test_send_message_request_with_three_overrides(self) -> None: + """Test creating SendMessageRequest with 3 llm_overrides.""" + overrides = [ + LLMOverride(model_provider="openai", model_version="gpt-4"), + LLMOverride(model_provider="anthropic", model_version="claude-3"), + LLMOverride(model_provider="google", model_version="gemini"), + ] + request = SendMessageRequest( + message="Hello", + llm_overrides=overrides, + ) + assert len(request.llm_overrides) == 3 + + def test_send_message_request_without_overrides(self) -> None: + """Test creating SendMessageRequest without llm_overrides.""" + request = SendMessageRequest(message="Hello") + assert request.llm_overrides is None + + def test_send_message_request_with_single_override(self) -> None: + """Test creating SendMessageRequest with single llm_override (not list).""" + override = LLMOverride(model_provider="openai", model_version="gpt-4") + request = SendMessageRequest( + message="Hello", + llm_override=override, + ) + assert request.llm_override is not None + assert request.llm_overrides is None + + +# ============================================================================= +# Multi-Model Loop Tests +# ============================================================================= + + +class TestRunMultiModelChatLoops: + """Tests for _run_multi_model_chat_loops function.""" + + def _create_mock_llm( + self, model_name: str, response_content: str = "Test response" + ) -> MagicMock: + """Create a mock LLM for testing.""" + mock_llm = MagicMock() + mock_llm.config = MagicMock() + mock_llm.config.model_name = model_name + return mock_llm + + def _create_mock_state_container(self) -> MagicMock: + """Create a mock ChatStateContainer.""" + container = MagicMock() + container.answer_tokens = None + container.reasoning_tokens = None + container.tool_calls = [] + container.citation_to_doc = {} + container.is_clarification = False + return container + + def test_model_index_tagging_with_two_models(self) -> None: + """Test that packets are tagged with correct model_index for 2 models.""" + # Import the function we're testing + from onyx.chat.process_message import _run_multi_model_chat_loops + + # Create mock LLMs + llms = [ + self._create_mock_llm("model-1"), + self._create_mock_llm("model-2"), + ] + model_names = ["Model 1", "Model 2"] + state_containers = [ + self._create_mock_state_container(), + self._create_mock_state_container(), + ] + + # Track model indices from emitted packets + emitted_model_indices: set[int] = set() + + # Mock run_llm_loop to emit a simple packet and complete + def mock_run_llm_loop( + emitter: Any, + simple_chat_history: Any, + tools: Any, + custom_agent_prompt: Any, + project_files: Any, + persona: Any, + memories: Any, + llm: Any, + token_counter: Any, + db_session: Any, + forced_tool_id: Any, + user_identity: Any, + chat_session_id: Any, + state_container: Any, + ) -> None: + # Emit a simple response delta + packet = Packet( + placement=Placement(turn_index=0), + obj=AgentResponseDelta(content="Hello"), + ) + emitter.emit(packet) + # Mark answer as complete + state_container.answer_tokens = "Hello" + + with patch( + "onyx.chat.process_message.run_llm_loop", side_effect=mock_run_llm_loop + ): + mock_emitter = MagicMock() + mock_db_session = MagicMock() + + packets = list( + _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, + emitter=mock_emitter, + 
state_containers=state_containers, + check_is_connected=lambda: True, + simple_chat_history=[], + tools=[], + custom_agent_prompt=None, + extracted_project_files=MagicMock(), + persona=MagicMock(), + memories=[], + token_counter=MagicMock(), + db_session=mock_db_session, + forced_tool_id=None, + user_identity=MagicMock(), + chat_session_id="test-session", + ) + ) + + # Collect model indices + for packet in packets: + if packet.placement.model_index is not None: + emitted_model_indices.add(packet.placement.model_index) + + # Should have packets from both models (indices 0 and 1) + assert 0 in emitted_model_indices + assert 1 in emitted_model_indices + assert 2 not in emitted_model_indices # Only 2 models + + def test_model_index_tagging_with_three_models(self) -> None: + """Test that packets are tagged with correct model_index for 3 models.""" + from onyx.chat.process_message import _run_multi_model_chat_loops + + llms = [ + self._create_mock_llm("model-1"), + self._create_mock_llm("model-2"), + self._create_mock_llm("model-3"), + ] + model_names = ["Model 1", "Model 2", "Model 3"] + state_containers = [ + self._create_mock_state_container(), + self._create_mock_state_container(), + self._create_mock_state_container(), + ] + + emitted_model_indices: set[int] = set() + + def mock_run_llm_loop( + emitter: Any, + simple_chat_history: Any, + tools: Any, + custom_agent_prompt: Any, + project_files: Any, + persona: Any, + memories: Any, + llm: Any, + token_counter: Any, + db_session: Any, + forced_tool_id: Any, + user_identity: Any, + chat_session_id: Any, + state_container: Any, + ) -> None: + packet = Packet( + placement=Placement(turn_index=0), + obj=AgentResponseDelta(content="Response"), + ) + emitter.emit(packet) + state_container.answer_tokens = "Response" + + with patch( + "onyx.chat.process_message.run_llm_loop", side_effect=mock_run_llm_loop + ): + mock_emitter = MagicMock() + + packets = list( + _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, + emitter=mock_emitter, + state_containers=state_containers, + check_is_connected=lambda: True, + simple_chat_history=[], + tools=[], + custom_agent_prompt=None, + extracted_project_files=MagicMock(), + persona=MagicMock(), + memories=[], + token_counter=MagicMock(), + db_session=MagicMock(), + forced_tool_id=None, + user_identity=MagicMock(), + chat_session_id="test-session", + ) + ) + + for packet in packets: + if packet.placement.model_index is not None: + emitted_model_indices.add(packet.placement.model_index) + + # Should have packets from all 3 models + assert 0 in emitted_model_indices + assert 1 in emitted_model_indices + assert 2 in emitted_model_indices + + def test_parallel_execution(self) -> None: + """Test that multiple models run in parallel.""" + from onyx.chat.process_message import _run_multi_model_chat_loops + + llms = [ + self._create_mock_llm("model-1"), + self._create_mock_llm("model-2"), + ] + model_names = ["Model 1", "Model 2"] + state_containers = [ + self._create_mock_state_container(), + self._create_mock_state_container(), + ] + + execution_times: dict[int, float] = {} + lock = threading.Lock() + + def mock_run_llm_loop( + emitter: Any, + simple_chat_history: Any, + tools: Any, + custom_agent_prompt: Any, + project_files: Any, + persona: Any, + memories: Any, + llm: Any, + token_counter: Any, + db_session: Any, + forced_tool_id: Any, + user_identity: Any, + chat_session_id: Any, + state_container: Any, + ) -> None: + # Record when this model started + model_idx = llms.index(llm) + start_time = 
time.time() + + # Simulate some work + time.sleep(0.1) + + with lock: + execution_times[model_idx] = start_time + + packet = Packet( + placement=Placement(turn_index=0), + obj=AgentResponseDelta(content="Done"), + ) + emitter.emit(packet) + state_container.answer_tokens = "Done" + + with patch( + "onyx.chat.process_message.run_llm_loop", side_effect=mock_run_llm_loop + ): + list( + _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, + emitter=MagicMock(), + state_containers=state_containers, + check_is_connected=lambda: True, + simple_chat_history=[], + tools=[], + custom_agent_prompt=None, + extracted_project_files=MagicMock(), + persona=MagicMock(), + memories=[], + token_counter=MagicMock(), + db_session=MagicMock(), + forced_tool_id=None, + user_identity=MagicMock(), + chat_session_id="test-session", + ) + ) + + # Both models should have started close together (parallel) + assert len(execution_times) == 2 + time_diff = abs(execution_times[0] - execution_times[1]) + # They should start within 0.05 seconds of each other + assert time_diff < 0.05, f"Models did not start in parallel: {time_diff}s apart" + + def test_error_isolation(self) -> None: + """Test that one model's error doesn't crash others.""" + from onyx.chat.process_message import _run_multi_model_chat_loops + + llms = [ + self._create_mock_llm("model-1"), + self._create_mock_llm("model-2"), + ] + model_names = ["Model 1", "Model 2"] + state_containers = [ + self._create_mock_state_container(), + self._create_mock_state_container(), + ] + + call_count = [0] + + def mock_run_llm_loop( + emitter: Any, + simple_chat_history: Any, + tools: Any, + custom_agent_prompt: Any, + project_files: Any, + persona: Any, + memories: Any, + llm: Any, + token_counter: Any, + db_session: Any, + forced_tool_id: Any, + user_identity: Any, + chat_session_id: Any, + state_container: Any, + ) -> None: + model_idx = llms.index(llm) + call_count[0] += 1 + + if model_idx == 0: + # First model raises an error + raise RuntimeError("Model 1 failed!") + else: + # Second model succeeds + packet = Packet( + placement=Placement(turn_index=0), + obj=AgentResponseDelta(content="Success from model 2"), + ) + emitter.emit(packet) + state_container.answer_tokens = "Success from model 2" + + with patch( + "onyx.chat.process_message.run_llm_loop", side_effect=mock_run_llm_loop + ): + packets = list( + _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, + emitter=MagicMock(), + state_containers=state_containers, + check_is_connected=lambda: True, + simple_chat_history=[], + tools=[], + custom_agent_prompt=None, + extracted_project_files=MagicMock(), + persona=MagicMock(), + memories=[], + token_counter=MagicMock(), + db_session=MagicMock(), + forced_tool_id=None, + user_identity=MagicMock(), + chat_session_id="test-session", + ) + ) + + # Both models should have been called + assert call_count[0] == 2 + + # Should have packets from both models (including error for model 0) + model_indices_seen = {p.placement.model_index for p in packets} + assert 0 in model_indices_seen # Error model + assert 1 in model_indices_seen # Success model + + # Model 1 should have an OverallStop packet with error stop_reason + model_0_stops = [ + p + for p in packets + if p.placement.model_index == 0 and isinstance(p.obj, OverallStop) + ] + assert len(model_0_stops) > 0 + assert model_0_stops[0].obj.stop_reason == "error" + + def test_user_cancellation(self) -> None: + """Test that user cancellation emits stop packets for incomplete models. 
+ + Note: This tests the cancellation detection during the queue polling loop. + When check_is_connected returns False, the loop should emit stop packets + for any models that haven't completed yet. + """ + from onyx.chat.process_message import _run_multi_model_chat_loops + + llms = [ + self._create_mock_llm("model-1"), + self._create_mock_llm("model-2"), + ] + model_names = ["Model 1", "Model 2"] + state_containers = [ + self._create_mock_state_container(), + self._create_mock_state_container(), + ] + + # Event to control when model threads should complete + model_wait_event = threading.Event() + + def mock_run_llm_loop( + emitter: Any, + simple_chat_history: Any, + tools: Any, + custom_agent_prompt: Any, + project_files: Any, + persona: Any, + memories: Any, + llm: Any, + token_counter: Any, + db_session: Any, + forced_tool_id: Any, + user_identity: Any, + chat_session_id: Any, + state_container: Any, + ) -> None: + # Emit a packet + packet = Packet( + placement=Placement(turn_index=0), + obj=AgentResponseDelta(content="Starting..."), + ) + emitter.emit(packet) + + # Wait for event (simulates long-running operation) + # This will time out if cancellation happens first + model_wait_event.wait(timeout=2.0) + + state_container.answer_tokens = "Completed" + + # Start returning False after a short delay to trigger cancellation + call_count = [0] + + def check_is_connected() -> bool: + call_count[0] += 1 + # Return False after a few checks to trigger cancellation + if call_count[0] >= 3: + return False + return True + + with patch( + "onyx.chat.process_message.run_llm_loop", side_effect=mock_run_llm_loop + ): + packets = list( + _run_multi_model_chat_loops( + llms=llms, + model_names=model_names, + emitter=MagicMock(), + state_containers=state_containers, + check_is_connected=check_is_connected, + simple_chat_history=[], + tools=[], + custom_agent_prompt=None, + extracted_project_files=MagicMock(), + persona=MagicMock(), + memories=[], + token_counter=MagicMock(), + db_session=MagicMock(), + forced_tool_id=None, + user_identity=MagicMock(), + chat_session_id="test-session", + ) + ) + + # Release the wait event so threads can clean up + model_wait_event.set() + + # check_is_connected should have been called multiple times + assert call_count[0] >= 3 + + # Should have stop packets for models + stop_packets = [p for p in packets if isinstance(p.obj, OverallStop)] + + # When cancelled, we should get stop packets with user_cancelled reason + cancelled_stops = [ + p for p in stop_packets if p.obj.stop_reason == "user_cancelled" + ] + # Both models should have been cancelled since they were still waiting + assert len(cancelled_stops) == 2 + + +# ============================================================================= +# Integration-Level Tests (testing the detection logic) +# ============================================================================= + + +class TestMultiModelDetection: + """Tests for the multi-model detection logic in handle_stream_message_objects.""" + + def test_detection_with_one_model_uses_normal_mode(self) -> None: + """Test that 1 model in llm_overrides should NOT trigger multi-model mode. + + Note: The current implementation requires >= 2 models for multi-model mode. + If only 1 model is in llm_overrides, it falls through to single-model path. 
+ """ + overrides = [LLMOverride(model_provider="openai", model_version="gpt-4")] + + # 1 override should NOT trigger multi-model mode + assert len(overrides) < 2 + + def test_detection_with_two_models_triggers_multi_model(self) -> None: + """Test that 2 models in llm_overrides triggers multi-model mode.""" + overrides = [ + LLMOverride(model_provider="openai", model_version="gpt-4"), + LLMOverride(model_provider="anthropic", model_version="claude-3"), + ] + + # 2 overrides should trigger multi-model mode + assert len(overrides) >= 2 + + def test_detection_with_three_models_triggers_multi_model(self) -> None: + """Test that 3 models in llm_overrides triggers multi-model mode.""" + overrides = [ + LLMOverride(model_provider="openai", model_version="gpt-4"), + LLMOverride(model_provider="anthropic", model_version="claude-3"), + LLMOverride(model_provider="google", model_version="gemini"), + ] + + # 3 overrides should trigger multi-model mode + assert len(overrides) >= 2 + + def test_model_name_fallback(self) -> None: + """Test model name fallback chain: model_version -> model_provider -> 'Model N'.""" + # Test with model_version + override1 = LLMOverride(model_provider="openai", model_version="gpt-4") + name1 = override1.model_version or override1.model_provider or "Model 1" + assert name1 == "gpt-4" + + # Test with only model_provider + override2 = LLMOverride(model_provider="anthropic") + name2 = override2.model_version or override2.model_provider or "Model 2" + assert name2 == "anthropic" + + # Test with neither + override3 = LLMOverride() + name3 = override3.model_version or override3.model_provider or "Model 3" + assert name3 == "Model 3" diff --git a/web/src/app/chat/components/ChatPage.tsx b/web/src/app/chat/components/ChatPage.tsx index d9854c8525e..1dce17d5928 100644 --- a/web/src/app/chat/components/ChatPage.tsx +++ b/web/src/app/chat/components/ChatPage.tsx @@ -9,7 +9,12 @@ import { import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { usePopup } from "@/components/admin/connectors/Popup"; import { SEARCH_PARAM_NAMES } from "@/app/chat/services/searchParams"; -import { useFederatedConnectors, useFilters, useLlmManager } from "@/lib/hooks"; +import { + useFederatedConnectors, + useFilters, + useLlmManager, + LlmDescriptor, +} from "@/lib/hooks"; import { useForcedTools } from "@/lib/hooks/useForcedTools"; import OnyxInitializingLoader from "@/components/OnyxInitializingLoader"; import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces"; @@ -257,6 +262,10 @@ export default function ChatPage({ firstMessage }: ChatPageProps) { const [projectPanelVisible, setProjectPanelVisible] = useState(true); const chatInputBarRef = useRef(null); + // Multi-model selection state (lifted from ChatInputBar for sharing with ChatUI) + const [multiModelMode, setMultiModelMode] = useState(false); + const [selectedModels, setSelectedModels] = useState([]); + const filterManager = useFilters(); const isDefaultAgent = useIsDefaultAgent({ @@ -428,11 +437,12 @@ export default function ChatPage({ firstMessage }: ChatPageProps) { } const handleChatInputSubmit = useCallback( - (message: string) => { + (message: string, selectedModels?: LlmDescriptor[]) => { onSubmit({ message, currentMessageFiles: currentMessageFiles, deepResearch: deepResearchEnabled, + selectedModels, }); if (showOnboarding) { finishOnboarding(); @@ -651,6 +661,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) { onMessageSelection={onMessageSelection} 
stopGenerating={stopGenerating} handleResubmitLastMessage={handleResubmitLastMessage} + selectedModels={selectedModels} /> )} @@ -710,6 +721,10 @@ export default function ChatPage({ firstMessage }: ChatPageProps) { (!isLoadingOnboarding && onboardingState.currentStep !== OnboardingStep.Complete) } + selectedModels={selectedModels} + setSelectedModels={setSelectedModels} + multiModelMode={multiModelMode} + setMultiModelMode={setMultiModelMode} /> diff --git a/web/src/app/chat/components/input/ChatInputBar.tsx b/web/src/app/chat/components/input/ChatInputBar.tsx index 086d15da191..36b3dc2c46c 100644 --- a/web/src/app/chat/components/input/ChatInputBar.tsx +++ b/web/src/app/chat/components/input/ChatInputBar.tsx @@ -9,6 +9,9 @@ import React, { import { FiPlus } from "react-icons/fi"; import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces"; import LLMPopover from "@/refresh-components/popovers/LLMPopover"; +import MultiModelSelector from "@/refresh-components/popovers/MultiModelSelector"; +import { useMultiModelEnabled } from "@/app/chat/hooks/useMultiModelEnabled"; +import { LlmDescriptor } from "@/lib/hooks"; import { InputPrompt } from "@/app/chat/interfaces"; import { FilterManager, LlmManager, useFederatedConnectors } from "@/lib/hooks"; import usePromptShortcuts from "@/hooks/usePromptShortcuts"; @@ -36,7 +39,13 @@ import { getIconForAction, hasSearchToolsAvailable, } from "@/app/chat/services/actionUtils"; -import { SvgArrowUp, SvgHourglass, SvgPlusCircle, SvgStop } from "@opal/icons"; +import { + SvgArrowUp, + SvgHourglass, + SvgPlusCircle, + SvgSliders, + SvgStop, +} from "@opal/icons"; const MAX_INPUT_HEIGHT = 200; @@ -89,7 +98,7 @@ export interface ChatInputBarProps { selectedDocuments: OnyxDocument[]; initialMessage?: string; stopGenerating: () => void; - onSubmit: (message: string) => void; + onSubmit: (message: string, selectedModels?: LlmDescriptor[]) => void; onHeightChange?: (delta: number) => void; llmManager: LlmManager; chatState: ChatState; @@ -107,6 +116,12 @@ export interface ChatInputBarProps { setPresentingDocument?: (document: MinimalOnyxDocument) => void; toggleDeepResearch: () => void; disabled: boolean; + + // Multi-model selection (lifted to parent) + selectedModels: LlmDescriptor[]; + setSelectedModels: (models: LlmDescriptor[]) => void; + multiModelMode: boolean; + setMultiModelMode: (mode: boolean) => void; } const ChatInputBar = React.memo( @@ -134,6 +149,11 @@ const ChatInputBar = React.memo( toggleDeepResearch, setPresentingDocument, disabled, + // Multi-model selection + selectedModels, + setSelectedModels, + multiModelMode, + setMultiModelMode, }, ref ) => { @@ -159,6 +179,9 @@ const ChatInputBar = React.memo( const { currentMessageFiles, setCurrentMessageFiles } = useProjectsContext(); + // Multi-model chat feature flag + const multiModelEnabled = useMultiModelEnabled(); + const currentIndexingFiles = useMemo(() => { return currentMessageFiles.filter( (file) => file.status === UserFileStatus.PROCESSING @@ -519,7 +542,12 @@ const ChatInputBar = React.memo( ) { event.preventDefault(); if (message) { - onSubmit(message); + // Pass selected models if in multi-model mode with 2+ models selected + const models = + multiModelMode && selectedModels.length >= 2 + ? selectedModels + : undefined; + onSubmit(message, models); } } }} @@ -674,16 +702,46 @@ const ChatInputBar = React.memo( {/* Bottom right controls */}
- {/* LLM popover - loads when ready */} + {/* Multi-model toggle - only shown when feature flag is enabled */} + {multiModelEnabled && ( + { + setMultiModelMode(!multiModelMode); + if (multiModelMode) { + setSelectedModels([]); + } + }} + engaged={multiModelMode} + action + folded + disabled={disabled} + className="bg-transparent" + > + Multi-model + + )} + + {/* LLM popover or Multi-model selector - loads when ready */}
- + {multiModelMode ? ( + + ) : ( + + )}
{/* Submit button - always visible */} @@ -697,7 +755,12 @@ const ChatInputBar = React.memo( if (chatState == "streaming") { stopGenerating(); } else if (message) { - onSubmit(message); + // Pass selected models if in multi-model mode with 2+ models selected + const models = + multiModelMode && selectedModels.length >= 2 + ? selectedModels + : undefined; + onSubmit(message, models); } }} /> diff --git a/web/src/app/chat/hooks/useChatController.ts b/web/src/app/chat/hooks/useChatController.ts index 0fcc268ca43..fd5a936a56a 100644 --- a/web/src/app/chat/hooks/useChatController.ts +++ b/web/src/app/chat/hooks/useChatController.ts @@ -33,6 +33,7 @@ import { FileDescriptor, Message, MessageResponseIDInfo, + MultiModelMessageResponseIDInfo, RegenerationState, RetrievalType, StreamingError, @@ -50,6 +51,7 @@ import { CurrentMessageFIFO, updateCurrentMessageFIFO, } from "../services/currentMessageFIFO"; +import { LLMOverrideParams } from "../services/lib"; import { buildFilters } from "@/lib/search/utils"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import { @@ -90,6 +92,8 @@ export interface OnSubmitProps { isSeededChat?: boolean; modelOverride?: LlmDescriptor; regenerationRequest?: RegenerationRequest | null; + // Multi-model chat: up to 3 models selected + selectedModels?: LlmDescriptor[]; } interface RegenerationRequest { @@ -348,7 +352,10 @@ export function useChatController({ isSeededChat, modelOverride, regenerationRequest, + selectedModels, }: OnSubmitProps) => { + // Check if this is multi-model mode (2 or 3 models selected) + const isMultiModelMode = selectedModels && selectedModels.length >= 2; const projectId = params(SEARCH_PARAM_NAMES.PROJECT_ID); { const params = new URLSearchParams(searchParams?.toString() || ""); @@ -566,9 +573,12 @@ export function useChatController({ // immediately reflects the user message let initialUserNode: Message; let initialAssistantNode: Message; + // For multi-model mode, we create 3 assistant nodes + let initialAssistantNodes: Message[] = []; if (regenerationRequest) { // For regeneration: keep the existing user message, only create new assistant + // Note: regeneration is not supported in multi-model mode initialUserNode = regenerationRequest.parentMessage; initialAssistantNode = buildEmptyMessage({ messageType: "assistant", @@ -588,12 +598,30 @@ export function useChatController({ ); initialUserNode = result.initialUserNode; initialAssistantNode = result.initialAssistantNode; + + // In multi-model mode, create N assistant nodes (one per selected model) + if (isMultiModelMode && selectedModels) { + for (let i = 0; i < selectedModels.length; i++) { + initialAssistantNodes.push( + buildEmptyMessage({ + messageType: "assistant", + parentNodeId: initialUserNode.nodeId, + nodeIdOffset: i + 1, + }) + ); + } + } } // make messages appear + clear input bar - const messagesToUpsert = regenerationRequest - ? 
[initialAssistantNode] // Only upsert the new assistant for regeneration - : [initialUserNode, initialAssistantNode]; // Upsert both for normal/edit flow + let messagesToUpsert: Message[]; + if (regenerationRequest) { + messagesToUpsert = [initialAssistantNode]; // Only upsert the new assistant for regeneration + } else if (isMultiModelMode) { + messagesToUpsert = [initialUserNode, ...initialAssistantNodes]; // Upsert user + 3 assistants + } else { + messagesToUpsert = [initialUserNode, initialAssistantNode]; // Upsert both for normal/edit flow + } currentMessageTreeLocal = upsertToCompleteMessageTree({ messages: messagesToUpsert, completeMessageTreeOverride: currentMessageTreeLocal, @@ -626,6 +654,21 @@ export function useChatController({ let newUserMessageId: number | null = null; let newAssistantMessageId: number | null = null; + // Multi-model mode state tracking (dynamically sized based on selected models) + const numModels = selectedModels?.length ?? 0; + let newAssistantMessageIds: (number | null)[] = isMultiModelMode + ? Array(numModels).fill(null) + : []; + let packetsPerModel: Packet[][] = isMultiModelMode + ? Array.from({ length: numModels }, () => []) + : []; + let documentsPerModel: OnyxDocument[][] = isMultiModelMode + ? Array.from({ length: numModels }, () => []) + : []; + let citationsPerModel: (CitationMap | null)[] = isMultiModelMode + ? Array(numModels).fill(null) + : []; + try { const lastSuccessfulMessageId = getLastSuccessfulMessageId( currentMessageTreeLocal @@ -648,6 +691,14 @@ export function useChatController({ ? forcedToolIds[0] : null; + // Build llmOverrides for multi-model mode + const llmOverrides: LLMOverrideParams[] | undefined = isMultiModelMode + ? selectedModels!.map((model) => ({ + model_provider: model.name, + model_version: model.modelName, + })) + : undefined; + const stack = new CurrentMessageFIFO(); updateCurrentMessageFIFO(stack, { signal: controller.signal, @@ -669,13 +720,16 @@ export function useChatController({ filterManager.timeRange, filterManager.selectedTags ), - modelProvider: - modelOverride?.name || llmManager.currentLlm.name || undefined, - modelVersion: - modelOverride?.modelName || - llmManager.currentLlm.modelName || - searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || - undefined, + // In multi-model mode, don't pass single model override + modelProvider: isMultiModelMode + ? undefined + : modelOverride?.name || llmManager.currentLlm.name || undefined, + modelVersion: isMultiModelMode + ? undefined + : modelOverride?.modelName || + llmManager.currentLlm.modelName || + searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || + undefined, temperature: llmManager.temperature || undefined, deepResearch, enabledToolIds: @@ -685,6 +739,7 @@ export function useChatController({ .map((tool) => tool.id) : undefined, forcedToolId: effectiveForcedToolId, + llmOverrides, }); const delay = (ms: number) => { @@ -708,12 +763,29 @@ export function useChatController({ // Transition from 'loading' to 'streaming'. 
updateChatStateAction(frozenSessionId, "streaming"); + // Handle message ID responses (single model and multi-model) if ((packet as MessageResponseIDInfo).user_message_id) { newUserMessageId = (packet as MessageResponseIDInfo) .user_message_id; } + // Multi-model mode: handle MultiModelMessageResponseIDInfo if ( + isMultiModelMode && + (packet as MultiModelMessageResponseIDInfo) + .reserved_assistant_message_ids + ) { + const multiModelResponse = + packet as MultiModelMessageResponseIDInfo; + newAssistantMessageIds = + multiModelResponse.reserved_assistant_message_ids; + newUserMessageId = + multiModelResponse.user_message_id ?? newUserMessageId; + } + + // Single model mode: handle MessageResponseIDInfo + if ( + !isMultiModelMode && (packet as MessageResponseIDInfo).reserved_assistant_message_id ) { newAssistantMessageId = (packet as MessageResponseIDInfo) @@ -764,31 +836,78 @@ export function useChatController({ } } else if (Object.hasOwn(packet, "obj")) { console.debug("Object packet:", JSON.stringify(packet)); - packets.push(packet as Packet); - - // Check if the packet contains document information - const packetObj = (packet as Packet).obj; - - if (packetObj.type === "citation_info") { - // Individual citation packet from backend streaming - const citationInfo = packetObj as { - type: "citation_info"; - citation_number: number; - document_id: string; - }; - // Incrementally build citations map - citations = { - ...(citations || {}), - [citationInfo.citation_number]: citationInfo.document_id, - }; - } else if (packetObj.type === "message_start") { - const messageStart = packetObj as MessageStart; - if (messageStart.final_documents) { - documents = messageStart.final_documents; - updateSelectedNodeForDocDisplay( - frozenSessionId, - initialAssistantNode.nodeId - ); + const typedPacket = packet as Packet; + + // In multi-model mode, route packets by model_index + if (isMultiModelMode) { + const modelIndex = typedPacket.placement?.model_index ?? 
0; + if ( + modelIndex >= 0 && + modelIndex < packetsPerModel.length && + packetsPerModel[modelIndex] + ) { + // Create new array to ensure React detects the change + packetsPerModel[modelIndex] = [ + ...packetsPerModel[modelIndex]!, + typedPacket, + ]; + + // Check if the packet contains document information + const packetObj = typedPacket.obj; + + if (packetObj.type === "citation_info") { + const citationInfo = packetObj as { + type: "citation_info"; + citation_number: number; + document_id: string; + }; + citationsPerModel[modelIndex] = { + ...(citationsPerModel[modelIndex] || {}), + [citationInfo.citation_number]: citationInfo.document_id, + }; + } else if (packetObj.type === "message_start") { + const messageStart = packetObj as MessageStart; + if (messageStart.final_documents) { + documentsPerModel[modelIndex] = + messageStart.final_documents; + // Select the first model's documents for display + if (modelIndex === 0 && initialAssistantNodes[0]) { + updateSelectedNodeForDocDisplay( + frozenSessionId, + initialAssistantNodes[0].nodeId + ); + } + } + } + } + } else { + // Single model mode + packets.push(typedPacket); + + // Check if the packet contains document information + const packetObj = typedPacket.obj; + + if (packetObj.type === "citation_info") { + // Individual citation packet from backend streaming + const citationInfo = packetObj as { + type: "citation_info"; + citation_number: number; + document_id: string; + }; + // Incrementally build citations map + citations = { + ...(citations || {}), + [citationInfo.citation_number]: citationInfo.document_id, + }; + } else if (packetObj.type === "message_start") { + const messageStart = packetObj as MessageStart; + if (messageStart.final_documents) { + documents = messageStart.final_documents; + updateSelectedNodeForDocDisplay( + frozenSessionId, + initialAssistantNode.nodeId + ); + } } } } else { @@ -800,8 +919,43 @@ export function useChatController({ parentMessage = parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!; - currentMessageTreeLocal = upsertToCompleteMessageTree({ - messages: [ + // Build messages to upsert based on mode + let messagesToUpsertInLoop: Message[]; + + if (isMultiModelMode) { + // Multi-model mode: update user node + all 3 assistant nodes + const updatedUserNode = { + ...initialUserNode, + messageId: newUserMessageId ?? undefined, + files: files, + }; + + const updatedAssistantNodes = initialAssistantNodes.map( + (node, idx) => ({ + ...node, + messageId: newAssistantMessageIds[idx] ?? undefined, + message: "", + type: "assistant" as const, + retrievalType, + query: query, + documents: documentsPerModel[idx] || [], + citations: citationsPerModel[idx] || {}, + files: [] as FileDescriptor[], + toolCall: null, + stackTrace: null, + overridden_model: selectedModels?.[idx]?.modelName, + stopReason: stopReason, + packets: packetsPerModel[idx] || [], + }) + ); + + messagesToUpsertInLoop = [ + updatedUserNode, + ...updatedAssistantNodes, + ]; + } else { + // Single model mode (existing logic) + messagesToUpsertInLoop = [ { ...initialUserNode, messageId: newUserMessageId ?? 
undefined, @@ -823,7 +977,11 @@ export function useChatController({ stopReason: stopReason, packets: packets, }, - ], + ]; + } + + currentMessageTreeLocal = upsertToCompleteMessageTree({ + messages: messagesToUpsertInLoop, // Pass the latest map state completeMessageTreeOverride: currentMessageTreeLocal, chatSessionId: frozenSessionId!, diff --git a/web/src/app/chat/hooks/useChatSessionController.ts b/web/src/app/chat/hooks/useChatSessionController.ts index 91d8e6929ab..d6485c9c997 100644 --- a/web/src/app/chat/hooks/useChatSessionController.ts +++ b/web/src/app/chat/hooks/useChatSessionController.ts @@ -323,6 +323,10 @@ export function useChatSessionController({ const onMessageSelection = useCallback( (nodeId: number) => { + console.log( + "[MultiModel] onMessageSelection called with nodeId:", + nodeId + ); updateCurrentSelectedNodeForDocDisplay(nodeId); const currentMessageTree = useChatSessionStore .getState() @@ -330,15 +334,35 @@ export function useChatSessionController({ ?.messageTree; if (currentMessageTree) { + const message = currentMessageTree.get(nodeId); + const parent = message?.parentNodeId + ? currentMessageTree.get(message.parentNodeId) + : null; + console.log( + "[MultiModel] Message found:", + !!message, + "Parent found:", + !!parent + ); + console.log( + "[MultiModel] Parent childrenNodeIds:", + parent?.childrenNodeIds + ); + console.log( + "[MultiModel] Parent latestChildNodeId before:", + parent?.latestChildNodeId + ); + const newMessageTree = setMessageAsLatest(currentMessageTree, nodeId); + const treeChanged = newMessageTree !== currentMessageTree; + console.log("[MultiModel] Tree changed:", treeChanged); + const currentSessionId = useChatSessionStore.getState().currentSessionId; if (currentSessionId) { updateSessionMessageTree(currentSessionId, newMessageTree); } - const message = currentMessageTree.get(nodeId); - if (message?.messageId) { // Makes actual API call to set message as latest in the DB so we can // edit this message and so it sticks around on page reload @@ -346,6 +370,8 @@ export function useChatSessionController({ } else { console.error("Message has no messageId", nodeId); } + } else { + console.log("[MultiModel] No currentMessageTree found!"); } }, [updateCurrentSelectedNodeForDocDisplay, updateSessionMessageTree] diff --git a/web/src/app/chat/hooks/useMultiModelEnabled.ts b/web/src/app/chat/hooks/useMultiModelEnabled.ts new file mode 100644 index 00000000000..54b8d320c75 --- /dev/null +++ b/web/src/app/chat/hooks/useMultiModelEnabled.ts @@ -0,0 +1,25 @@ +import { useEffect, useState } from "react"; + +/** + * Feature flag hook that checks if multi-model chat is enabled via cookie. + * Set cookie `multi-model-enabled=true` to enable the feature. 
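+ * The cookie is read once on mount (the effect has an empty dependency list),
+ * so after setting or clearing it, reload the page for the change to take effect.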
+ * + * Example (in browser console): + * document.cookie = "multi-model-enabled=true"; + */ +export function useMultiModelEnabled(): boolean { + const [enabled, setEnabled] = useState(false); + + useEffect(() => { + const cookies = document.cookie.split(";"); + const multiModelCookie = cookies.find((c) => + c.trim().startsWith("multi-model-enabled=") + ); + if (multiModelCookie) { + const value = multiModelCookie.split("=")[1]?.trim(); + setEnabled(value === "true"); + } + }, []); + + return enabled; +} diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 3242156bb08..a0b1a7cf4e0 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -209,6 +209,12 @@ export interface MessageResponseIDInfo { reserved_assistant_message_id: number; } +export interface MultiModelMessageResponseIDInfo { + user_message_id: number | null; + reserved_assistant_message_ids: number[]; + model_names: string[]; +} + export interface UserKnowledgeFilePacket { user_files: FileDescriptor[]; } diff --git a/web/src/app/chat/message/MultiModelResponseView.tsx b/web/src/app/chat/message/MultiModelResponseView.tsx new file mode 100644 index 00000000000..2cb653a5ae7 --- /dev/null +++ b/web/src/app/chat/message/MultiModelResponseView.tsx @@ -0,0 +1,169 @@ +"use client"; + +import React, { useCallback, useMemo } from "react"; +import { Message, FeedbackType } from "@/app/chat/interfaces"; +import { Packet } from "@/app/chat/services/streamingModels"; +import { FullChatState } from "@/app/chat/message/messageComponents/interfaces"; +import { LlmManager, LlmDescriptor } from "@/lib/hooks"; +import AIMessage, { + AIMessageProps, + RegenerationFactory, +} from "@/app/chat/message/messageComponents/AIMessage"; +import Text from "@/refresh-components/texts/Text"; +import { cn } from "@/lib/utils"; +import { SvgCheck } from "@opal/icons"; +import AgentAvatar from "@/refresh-components/avatars/AgentAvatar"; + +export interface MultiModelResponse { + nodeId: number; + messageId?: number; + modelName: string; + packets: Packet[]; + isHighlighted: boolean; + currentFeedback?: FeedbackType | null; +} + +export interface MultiModelResponseViewProps { + responses: MultiModelResponse[]; + chatState: FullChatState; + llmManager: LlmManager | null; + parentMessage?: Message | null; + onHighlightChange: (nodeId: number) => void; + onRegenerate?: RegenerationFactory; +} + +export default function MultiModelResponseView({ + responses, + chatState, + llmManager, + parentMessage, + onHighlightChange, + onRegenerate, +}: MultiModelResponseViewProps) { + const handleSelectResponse = useCallback( + (nodeId: number) => { + onHighlightChange(nodeId); + }, + [onHighlightChange] + ); + + return ( +
+ {/* Single Onyx icon on the left */} +
+ +
+ + {/* Content area */} +
+ {/* Header */} +
+ + Answering with {responses.length} models - click to select + +
+ + {/* Responses Grid - dynamic columns based on number of models */} +
+          {responses.map((response, index) => (
+            <MultiModelResponseCard
+              key={response.nodeId}
+              response={response}
+              chatState={chatState}
+              llmManager={llmManager}
+              parentMessage={parentMessage}
+              onSelect={() => handleSelectResponse(response.nodeId)}
+              onRegenerate={onRegenerate}
+              index={index}
+            />
+          ))}
+
+
+ ); +} + +interface MultiModelResponseCardProps { + response: MultiModelResponse; + chatState: FullChatState; + llmManager: LlmManager | null; + parentMessage?: Message | null; + onSelect: () => void; + onRegenerate?: RegenerationFactory; + index: number; +} + +function MultiModelResponseCard({ + response, + chatState, + llmManager, + parentMessage, + onSelect, + onRegenerate, + index, +}: MultiModelResponseCardProps) { + const { + nodeId, + messageId, + modelName, + packets, + isHighlighted, + currentFeedback, + } = response; + + return ( +
+ {/* Model Header */} +
+
+ + {modelName} + +
+ {isHighlighted && ( +
+ + + Selected + +
+ )} +
+ + {/* Response Content */} +
+ +
+
+ ); +} diff --git a/web/src/app/chat/message/messageComponents/AIMessage.tsx b/web/src/app/chat/message/messageComponents/AIMessage.tsx index 56d01fee208..6404998ede7 100644 --- a/web/src/app/chat/message/messageComponents/AIMessage.tsx +++ b/web/src/app/chat/message/messageComponents/AIMessage.tsx @@ -82,6 +82,8 @@ export interface AIMessageProps { onRegenerate?: RegenerationFactory; // Parent message needed to construct regeneration request parentMessage?: Message | null; + // Hide the avatar icon (used in multi-model view where a single icon is shown outside) + hideAvatar?: boolean; } // TODO: Consider more robust comparisons: @@ -104,7 +106,9 @@ function arePropsEqual(prev: AIMessageProps, next: AIMessageProps): boolean { prev.otherMessagesCanSwitchTo === next.otherMessagesCanSwitchTo && prev.onRegenerate === next.onRegenerate && prev.parentMessage?.messageId === next.parentMessage?.messageId && - prev.llmManager?.isLoadingProviders === next.llmManager?.isLoadingProviders + prev.llmManager?.isLoadingProviders === + next.llmManager?.isLoadingProviders && + prev.hideAvatar === next.hideAvatar // Skip: chatState.regenerate, chatState.setPresentingDocument, // most of llmManager, onMessageSelection (function/object props) ); @@ -121,6 +125,7 @@ const AIMessage = React.memo(function AIMessage({ onMessageSelection, onRegenerate, parentMessage, + hideAvatar = false, }: AIMessageProps) { const markdownRef = useRef(null); const finalAnswerRef = useRef(null); @@ -530,7 +535,7 @@ const AIMessage = React.memo(function AIMessage({ data-testid={displayComplete ? "onyx-ai-message" : undefined} className="flex items-start pb-5 md:pt-5" > - + {!hideAvatar && } {/* w-full ensures the MultiToolRenderer non-expanded state takes up the full width */}
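// Illustrative sketch: with two models selected in multi-model mode, the request
// body built in sendMessage() below would carry something like the following.
// The shape of each llm_overrides entry is an assumption that mirrors the single
// llm_override object; the provider and model names are hypothetical.
//
//   llm_override: null,
//   llm_overrides: [
//     { temperature: null, model_provider: "openai", model_version: "gpt-4o" },
//     { temperature: null, model_provider: "anthropic", model_version: "claude-sonnet-4" },
//   ],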
{ // Build payload for new send-chat-message API const payload = { @@ -142,14 +154,16 @@ export async function* sendMessage({ deep_research: deepResearch ?? false, allowed_tool_ids: enabledToolIds, forced_tool_id: forcedToolId ?? null, + // Use llm_overrides if provided (multi-model mode), otherwise use single llm_override llm_override: - temperature || modelVersion + !llmOverrides && (temperature || modelVersion) ? { temperature, model_provider: modelProvider, model_version: modelVersion, } : null, + llm_overrides: llmOverrides ?? null, }; const body = JSON.stringify(payload); diff --git a/web/src/app/chat/services/messageTree.ts b/web/src/app/chat/services/messageTree.ts index 66adb917c20..ed40bef3e94 100644 --- a/web/src/app/chat/services/messageTree.ts +++ b/web/src/app/chat/services/messageTree.ts @@ -239,7 +239,7 @@ export function setMessageAsLatest( } const newMessages = new Map(currentMessages); - const updatedParent = { + const updatedParent: Message = { ...parent, latestChildNodeId: nodeId, }; diff --git a/web/src/app/chat/services/streamingModels.ts b/web/src/app/chat/services/streamingModels.ts index 8d926548561..62b6f21da97 100644 --- a/web/src/app/chat/services/streamingModels.ts +++ b/web/src/app/chat/services/streamingModels.ts @@ -317,6 +317,7 @@ export interface Placement { turn_index: number; tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel sub_turn_index?: number | null; + model_index?: number; // For multi-model chat - identifies which model's response (0, 1, or 2) } // Packet wrapper for streaming objects diff --git a/web/src/refresh-components/popovers/MultiModelSelector.tsx b/web/src/refresh-components/popovers/MultiModelSelector.tsx new file mode 100644 index 00000000000..f74b462c904 --- /dev/null +++ b/web/src/refresh-components/popovers/MultiModelSelector.tsx @@ -0,0 +1,462 @@ +"use client"; + +import { useState, useEffect, useCallback, useMemo, useRef } from "react"; +import { + Popover, + PopoverContent, + PopoverTrigger, + PopoverMenu, +} from "@/components/ui/popover"; +import { LlmDescriptor, LlmManager } from "@/lib/hooks"; +import { structureValue } from "@/lib/llm/utils"; +import { + getProviderIcon, + AGGREGATOR_PROVIDERS, +} from "@/app/admin/configuration/llm/utils"; +import SelectButton from "@/refresh-components/buttons/SelectButton"; +import LineItem from "@/refresh-components/buttons/LineItem"; +import InputTypeIn from "@/refresh-components/inputs/InputTypeIn"; +import Text from "@/refresh-components/texts/Text"; +import SimpleLoader from "@/refresh-components/loaders/SimpleLoader"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; +import { + SvgCheck, + SvgChevronDown, + SvgChevronRight, + SvgSliders, + SvgX, +} from "@opal/icons"; +import { IconProps } from "@/components/icons/icons"; +import Checkbox from "@/refresh-components/inputs/Checkbox"; + +interface LLMOption { + name: string; + provider: string; + providerDisplayName: string; + modelName: string; + displayName: string; + description?: string; + vendor: string | null; + maxInputTokens?: number | null; + region?: string | null; + version?: string | null; + supportsReasoning?: boolean; + supportsImageInput?: boolean; +} + +export interface MultiModelSelectorProps { + llmManager: LlmManager; + selectedModels: LlmDescriptor[]; + onModelsChange: (models: LlmDescriptor[]) => void; + maxModels?: number; + disabled?: boolean; +} + +export default function 
MultiModelSelector({ + llmManager, + selectedModels, + onModelsChange, + maxModels = 3, + disabled = false, +}: MultiModelSelectorProps) { + const llmProviders = llmManager.llmProviders; + const isLoadingProviders = llmManager.isLoadingProviders; + + const [open, setOpen] = useState(false); + const [searchQuery, setSearchQuery] = useState(""); + + const searchInputRef = useRef(null); + const scrollContainerRef = useRef(null); + + const llmOptions = useMemo(() => { + if (!llmProviders) { + return []; + } + + const seenKeys = new Set(); + const options: LLMOption[] = []; + + llmProviders.forEach((llmProvider) => { + llmProvider.model_configurations + .filter((modelConfiguration) => modelConfiguration.is_visible) + .forEach((modelConfiguration) => { + const key = `${llmProvider.provider}:${modelConfiguration.name}`; + + if (seenKeys.has(key)) { + return; + } + seenKeys.add(key); + + const displayName = + modelConfiguration.display_name || modelConfiguration.name; + + options.push({ + name: llmProvider.name, + provider: llmProvider.provider, + providerDisplayName: + llmProvider.provider_display_name || llmProvider.provider, + modelName: modelConfiguration.name, + displayName, + vendor: modelConfiguration.vendor || null, + maxInputTokens: modelConfiguration.max_input_tokens, + region: modelConfiguration.region || null, + version: modelConfiguration.version || null, + supportsReasoning: modelConfiguration.supports_reasoning || false, + supportsImageInput: + modelConfiguration.supports_image_input || false, + }); + }); + }); + + return options; + }, [llmProviders]); + + const filteredOptions = useMemo(() => { + if (!searchQuery.trim()) { + return llmOptions; + } + const query = searchQuery.toLowerCase(); + return llmOptions.filter( + (opt) => + opt.displayName.toLowerCase().includes(query) || + opt.modelName.toLowerCase().includes(query) || + (opt.vendor && opt.vendor.toLowerCase().includes(query)) + ); + }, [llmOptions, searchQuery]); + + const groupedOptions = useMemo(() => { + const groups = new Map< + string, + { + displayName: string; + options: LLMOption[]; + Icon: React.FunctionComponent; + } + >(); + + filteredOptions.forEach((option) => { + const provider = option.provider.toLowerCase(); + const isAggregator = AGGREGATOR_PROVIDERS.has(provider); + + const groupKey = + isAggregator && option.vendor + ? 
`${provider}/${option.vendor.toLowerCase()}` + : provider; + + if (!groups.has(groupKey)) { + let displayName: string; + + if (isAggregator && option.vendor) { + const vendorDisplayName = + option.vendor.charAt(0).toUpperCase() + option.vendor.slice(1); + displayName = `${option.providerDisplayName}/${vendorDisplayName}`; + } else { + displayName = option.providerDisplayName; + } + + groups.set(groupKey, { + displayName, + options: [], + Icon: getProviderIcon(provider), + }); + } + + groups.get(groupKey)!.options.push(option); + }); + + const sortedKeys = Array.from(groups.keys()).sort((a, b) => + groups.get(a)!.displayName.localeCompare(groups.get(b)!.displayName) + ); + + return sortedKeys.map((key) => { + const group = groups.get(key)!; + return { + key, + displayName: group.displayName, + options: group.options, + Icon: group.Icon, + }; + }); + }, [filteredOptions]); + + const [expandedGroups, setExpandedGroups] = useState([]); + + useEffect(() => { + if (!open) { + setSearchQuery(""); + } else { + // Expand all groups by default when opening + setExpandedGroups(groupedOptions.map((g) => g.key)); + } + }, [open, groupedOptions]); + + const isSearching = searchQuery.trim().length > 0; + + const effectiveExpandedGroups = useMemo(() => { + if (isSearching) { + return groupedOptions.map((g) => g.key); + } + return expandedGroups; + }, [isSearching, groupedOptions, expandedGroups]); + + const handleAccordionChange = (value: string[]) => { + if (!isSearching) { + setExpandedGroups(value); + } + }; + + const isModelSelected = (option: LLMOption) => { + return selectedModels.some( + (m) => m.modelName === option.modelName && m.provider === option.provider + ); + }; + + const handleToggleModel = (option: LLMOption) => { + const isSelected = isModelSelected(option); + + if (isSelected) { + // Remove model + const newModels = selectedModels.filter( + (m) => + !(m.modelName === option.modelName && m.provider === option.provider) + ); + onModelsChange(newModels); + } else { + // Add model if under max + if (selectedModels.length < maxModels) { + const newModel: LlmDescriptor = { + name: option.name, + modelName: option.modelName, + provider: option.provider, + }; + onModelsChange([...selectedModels, newModel]); + } + } + }; + + const handleClearAll = () => { + onModelsChange([]); + }; + + const renderModelItem = (option: LLMOption) => { + const isSelected = isModelSelected(option); + const canSelect = selectedModels.length < maxModels || isSelected; + + const capabilities: string[] = []; + if (option.supportsReasoning) { + capabilities.push("Reasoning"); + } + if (option.supportsImageInput) { + capabilities.push("Vision"); + } + const description = + capabilities.length > 0 ? capabilities.join(", ") : undefined; + + return ( +
+ canSelect && handleToggleModel(option)} + icon={() => null} + rightChildren={ + canSelect && handleToggleModel(option)} + /> + } + className={!canSelect ? "opacity-50 cursor-not-allowed" : ""} + > + {option.displayName} + +
+ ); + }; + + const buttonLabel = useMemo(() => { + if (selectedModels.length === 0) { + return "Select models"; + } + return `${selectedModels.length}/${maxModels} models`; + }, [selectedModels.length, maxModels]); + + return ( + + +
+ setOpen(true)} + transient={open} + rightChevronIcon + disabled={disabled} + className={disabled ? "bg-transparent" : ""} + > + {buttonLabel} + +
+
+ +
+ {/* Header with count and clear button */} +
+ + Select up to {maxModels} models + + {selectedModels.length > 0 && ( + + )} +
+ + {/* Selected models display */} + {selectedModels.length > 0 && ( +
+ {selectedModels.map((model, idx) => { + const option = llmOptions.find( + (o) => + o.modelName === model.modelName && + o.provider === model.provider + ); + return ( +
+ + {option?.displayName || model.modelName} + + +
+ ); + })} +
+ )} + + {/* Search Input */} + setSearchQuery(e.target.value)} + placeholder="Search models..." + /> + + {/* Model List with Vendor Groups */} + + {isLoadingProviders + ? [ +
+ + + Loading models... + +
, + ] + : groupedOptions.length === 0 + ? [ +
+ + No models found + +
, + ] + : groupedOptions.length === 1 + ? [ +
+ {groupedOptions[0]!.options.map(renderModelItem)} +
, + ] + : [ + + {groupedOptions.map((group) => { + const isExpanded = effectiveExpandedGroups.includes( + group.key + ); + return ( + + +
+
+ +
+ + {group.displayName} + +
+
+
+                        {isExpanded ? (
+                          <SvgChevronDown />
+                        ) : (
+                          <SvgChevronRight />
+                        )}
+ + + +
+ {group.options.map(renderModelItem)} +
+
+ + ); + })} + , + ]} + +
+ + + ); +} diff --git a/web/src/sections/ChatUI.tsx b/web/src/sections/ChatUI.tsx index 91aa44d2869..40a61a70fbd 100644 --- a/web/src/sections/ChatUI.tsx +++ b/web/src/sections/ChatUI.tsx @@ -17,6 +17,10 @@ import { ErrorBanner } from "@/app/chat/message/Resubmit"; import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces"; import { LlmDescriptor, LlmManager } from "@/lib/hooks"; import AIMessage from "@/app/chat/message/messageComponents/AIMessage"; +import MultiModelResponseView, { + MultiModelResponse, +} from "@/app/chat/message/MultiModelResponseView"; +import { useMultiModelEnabled } from "@/app/chat/hooks/useMultiModelEnabled"; import { ProjectFile } from "@/app/chat/projects/projectsService"; import { useScrollonStream } from "@/app/chat/services/lib"; import useScreenSize from "@/hooks/useScreenSize"; @@ -59,10 +63,13 @@ export interface ChatUIProps { forceSearch?: boolean; queryOverride?: string; isSeededChat?: boolean; + selectedModels?: LlmDescriptor[]; }) => Promise; onMessageSelection: (nodeId: number) => void; stopGenerating: () => void; handleResubmitLastMessage: () => void; + // Multi-model selection from parent + selectedModels: LlmDescriptor[]; } const ChatUI = React.memo( @@ -78,6 +85,7 @@ const ChatUI = React.memo( onMessageSelection, stopGenerating, handleResubmitLastMessage, + selectedModels, }: ChatUIProps, ref: ForwardedRef ) => { @@ -89,6 +97,7 @@ const ChatUI = React.memo( const error = useUncaughtError(); const messageTree = useCurrentMessageTree(); const currentChatState = useCurrentChatState(); + const isMultiModelEnabled = useMultiModelEnabled(); // Stable fallbacks to avoid changing prop identities on each render const emptyDocs = useMemo(() => [], []); @@ -106,9 +115,11 @@ const ChatUI = React.memo( const onSubmitRef = useRef(onSubmit); const deepResearchEnabledRef = useRef(deepResearchEnabled); const currentMessageFilesRef = useRef(currentMessageFiles); + const selectedModelsRef = useRef(selectedModels); onSubmitRef.current = onSubmit; deepResearchEnabledRef.current = deepResearchEnabled; currentMessageFilesRef.current = currentMessageFiles; + selectedModelsRef.current = selectedModels; const createRegenerator = useCallback( (regenerationRequest: { @@ -138,11 +149,62 @@ const ChatUI = React.memo( messageIdToResend: msgId, currentMessageFiles: [], deepResearch: deepResearchEnabledRef.current, + // Preserve multi-model selection when editing + selectedModels: + selectedModelsRef.current.length >= 2 + ? 
selectedModelsRef.current + : undefined, }); }, [] // Stable - uses refs for latest values ); + // Handle multi-model response highlight change + const handleMultiModelHighlightChange = useCallback( + (nodeId: number) => { + console.log( + "[MultiModel] handleMultiModelHighlightChange called with nodeId:", + nodeId + ); + // Update local state immediately for responsive UI + onMessageSelection(nodeId); + }, + [onMessageSelection] + ); + + // Helper to check if a user message has multi-model responses + const getMultiModelResponses = useCallback( + (userMessage: Message): MultiModelResponse[] | null => { + if (!isMultiModelEnabled || !messageTree) return null; + + const childrenNodeIds = userMessage.childrenNodeIds || []; + // Multi-model mode creates 2 or 3 assistant children + if (childrenNodeIds.length < 2) return null; + + const childMessages = childrenNodeIds + .map((nodeId) => messageTree.get(nodeId)) + .filter( + (msg): msg is Message => + msg !== undefined && msg.type === "assistant" + ); + + // Need at least 2 assistant messages for multi-model view + if (childMessages.length < 2) return null; + + const latestChildNodeId = userMessage.latestChildNodeId; + + return childMessages.map((msg) => ({ + nodeId: msg.nodeId, + messageId: msg.messageId, + modelName: msg.overridden_model || "Model", + packets: msg.packets || [], + isHighlighted: msg.nodeId === latestChildNodeId, + currentFeedback: msg.currentFeedback, + })); + }, + [isMultiModelEnabled, messageTree] + ); + const handleScroll = useCallback(() => { const container = scrollContainerRef.current; if (!container) return; @@ -260,6 +322,9 @@ const ChatUI = React.memo( const nextMessage = messages.length > i + 1 ? messages[i + 1] : null; + // Check for multi-model responses + const multiModelResponses = getMultiModelResponses(message); + return (
+ + {/* Render MultiModelResponseView if this user message has multi-model responses */} + {multiModelResponses && ( + + )}
); } else if (message.type === "assistant") { @@ -302,6 +385,16 @@ const ChatUI = React.memo( // since this is a "parsed" version of the message tree // so the previous message is guaranteed to be the parent of the current message const previousMessage = i !== 0 ? messages[i - 1] : null; + + // Check if this assistant message is part of a multi-model response + // If so, skip rendering since it's already rendered in MultiModelResponseView + if ( + previousMessage?.type === "user" && + getMultiModelResponses(previousMessage) + ) { + return null; + } + const chatStateData = { assistant: liveAssistant, docs: message.documents ?? emptyDocs,