diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index 8aa3a0420..e221a595f 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -107,7 +107,10 @@ async def get_conversations_list_endpoint_handler( return ConversationsListResponseV2(conversations=conversations) -@router.get("/conversations/{conversation_id}", responses=conversation_get_responses) +@router.get( + "/conversations/{conversation_id}", + responses=conversation_get_responses, +) @authorize(Action.GET_CONVERSATION) async def get_conversation_endpoint_handler( request: Request, # pylint: disable=unused-argument @@ -257,8 +260,12 @@ def build_conversation_turn_from_cache_entry(entry: CacheEntry) -> ConversationT """ # Create Message objects for user and assistant messages = [ - Message(content=entry.query, type="user"), - Message(content=entry.response, type="assistant"), + Message(content=entry.query, type="user", referenced_documents=None), + Message( + content=entry.response, + type="assistant", + referenced_documents=entry.referenced_documents or None, + ), ] # Extract tool calls and results (default to empty lists if None) diff --git a/src/models/responses.py b/src/models/responses.py index 946a71fbf..1de5ab523 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -866,6 +866,7 @@ class Message(BaseModel): Attributes: content: The message content. type: The type of message. + referenced_documents: Optional list of documents referenced in an assistant response. """ content: str = Field( @@ -878,6 +879,10 @@ class Message(BaseModel): description="The type of message", examples=["user", "assistant", "system", "developer"], ) + referenced_documents: Optional[list[ReferencedDocument]] = Field( + None, + description="List of documents referenced in the response (assistant messages only)", + ) class ConversationTurn(BaseModel): diff --git a/src/utils/conversations.py b/src/utils/conversations.py index 577c3fce7..c8fc496d5 100644 --- a/src/utils/conversations.py +++ b/src/utils/conversations.py @@ -71,7 +71,7 @@ def _parse_message_item(item: MessageOutput) -> Message: """ content_text = _extract_text_from_content(item.content) message_type = item.role - return Message(content=content_text, type=message_type) + return Message(content=content_text, type=message_type, referenced_documents=None) def _build_tool_call_summary_from_item( # pylint: disable=too-many-return-statements diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index d1423a7c2..5ca4faf0b 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -436,7 +436,7 @@ async def test_build_conversation_turns_from_items_with_model_dump( result = build_conversation_turns_from_items( mock_items, mock_db_turns, conversation_start_time ) - actual_history = [turn.model_dump() for turn in result] + actual_history = [turn.model_dump(exclude_none=True) for turn in result] assert actual_history == expected_chat_history @pytest.mark.asyncio @@ -745,7 +745,9 @@ async def test_successful_conversation_retrieval( assert isinstance(response, ConversationResponse) assert response.conversation_id == VALID_CONVERSATION_ID - actual_history = [turn.model_dump() for turn in response.chat_history] + actual_history = [ + turn.model_dump(exclude_none=True) for turn in response.chat_history + ] assert actual_history == expected_chat_history @pytest.mark.asyncio diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 57019dad8..8e8787809 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -26,7 +26,7 @@ ConversationUpdateResponse, ) from tests.unit.utils.auth_helpers import mock_authorization_resolvers -from utils.types import ToolCallSummary, ToolResultSummary +from utils.types import ReferencedDocument, ToolCallSummary, ToolResultSummary MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" @@ -95,6 +95,133 @@ def test_build_turn_with_tool_calls(self) -> None: assert len(turn.tool_results) == 1 assert turn.tool_results[0].status == "success" + def test_build_turn_with_referenced_documents(self) -> None: + """Test that referenced_documents from cache are included in the assistant message.""" + ref_docs = [ + ReferencedDocument( + doc_url="https://docs.example.com/page1", + doc_title="Page 1", + source="vs_abc123", + ), + ReferencedDocument( + doc_url="https://docs.example.com/page2", + doc_title="Page 2", + source="vs_abc123", + ), + ] + entry = CacheEntry( + query="What is RHDH?", + response="RHDH is a developer hub.", + provider="vllm", + model="llama-3", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=ref_docs, + ) + + turn = build_conversation_turn_from_cache_entry(entry) + + assert len(turn.messages) == 2 + user_msg = turn.messages[0] + assistant_msg = turn.messages[1] + + assert user_msg.type == "user" + assert user_msg.referenced_documents is None + + assert assistant_msg.type == "assistant" + assert assistant_msg.referenced_documents is not None + assert len(assistant_msg.referenced_documents) == 2 + assert ( + str(assistant_msg.referenced_documents[0].doc_url) + == "https://docs.example.com/page1" + ) + assert assistant_msg.referenced_documents[0].doc_title == "Page 1" + assert ( + str(assistant_msg.referenced_documents[1].doc_url) + == "https://docs.example.com/page2" + ) + assert assistant_msg.referenced_documents[1].doc_title == "Page 2" + + def test_build_turn_without_referenced_documents(self) -> None: + """Test that assistant message has no referenced_documents when cache entry has none.""" + entry = CacheEntry( + query="Hello", + response="Hi there!", + provider="openai", + model="gpt-4", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + ) + + turn = build_conversation_turn_from_cache_entry(entry) + + assert turn.messages[1].type == "assistant" + assert turn.messages[1].referenced_documents is None + + def test_build_turn_with_empty_referenced_documents(self) -> None: + """Test assistant message has no referenced_documents when cache has empty list.""" + entry = CacheEntry( + query="Hello", + response="Hi there!", + provider="openai", + model="gpt-4", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=[], + ) + + turn = build_conversation_turn_from_cache_entry(entry) + + assert turn.messages[1].type == "assistant" + assert turn.messages[1].referenced_documents is None + + def test_build_turn_serialization_excludes_none_referenced_documents(self) -> None: + """Test that model_dump(exclude_none=True) omits referenced_documents when None.""" + entry = CacheEntry( + query="Hello", + response="Hi there!", + provider="openai", + model="gpt-4", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + ) + + turn = build_conversation_turn_from_cache_entry(entry) + dumped = turn.model_dump(exclude_none=True) + + user_msg_dict = dumped["messages"][0] + assistant_msg_dict = dumped["messages"][1] + assert "referenced_documents" not in user_msg_dict + assert "referenced_documents" not in assistant_msg_dict + + def test_build_turn_serialization_includes_referenced_documents(self) -> None: + """Test that model_dump(exclude_none=True) includes referenced_documents when present.""" + ref_docs = [ + ReferencedDocument( + doc_url="https://docs.example.com/page1", + doc_title="Page 1", + ), + ] + entry = CacheEntry( + query="What is RHDH?", + response="RHDH is a developer hub.", + provider="vllm", + model="llama-3", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=ref_docs, + ) + + turn = build_conversation_turn_from_cache_entry(entry) + dumped = turn.model_dump(exclude_none=True) + + user_msg_dict = dumped["messages"][0] + assistant_msg_dict = dumped["messages"][1] + assert "referenced_documents" not in user_msg_dict + assert "referenced_documents" in assistant_msg_dict + assert len(assistant_msg_dict["referenced_documents"]) == 1 + assert assistant_msg_dict["referenced_documents"][0]["doc_title"] == "Page 1" + @pytest.fixture def mock_configuration(mocker: MockerFixture) -> MockType: @@ -428,6 +555,95 @@ async def test_successful_retrieval( assert len(response.chat_history) == 1 assert response.chat_history[0].messages[0].content == "query" + @pytest.mark.asyncio + async def test_successful_retrieval_includes_referenced_documents( + self, mocker: MockerFixture, mock_configuration: MockType + ) -> None: + """Test that GET conversation includes referenced_documents in assistant messages.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) + mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) + ref_docs = [ + ReferencedDocument( + doc_url="https://docs.example.com/intro", + doc_title="Introduction", + source="vs_abc123", + ), + ReferencedDocument( + doc_url="https://docs.example.com/guide", + doc_title="User Guide", + source="vs_abc123", + ), + ] + mock_configuration.conversation_cache.list.return_value = [ + mocker.Mock(conversation_id=VALID_CONVERSATION_ID) + ] + mock_configuration.conversation_cache.get.return_value = [ + CacheEntry( + query="What is RHDH?", + response="RHDH is a developer hub.", + provider="vllm", + model="llama-3", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + referenced_documents=ref_docs, + ) + ] + + response = await get_conversation_endpoint_handler( + request=mocker.Mock(), + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, + ) + + assert response is not None + assert len(response.chat_history) == 1 + turn = response.chat_history[0] + + user_msg = turn.messages[0] + assert user_msg.type == "user" + assert user_msg.referenced_documents is None + + assistant_msg = turn.messages[1] + assert assistant_msg.type == "assistant" + assert assistant_msg.referenced_documents is not None + assert len(assistant_msg.referenced_documents) == 2 + assert assistant_msg.referenced_documents[0].doc_title == "Introduction" + assert assistant_msg.referenced_documents[1].doc_title == "User Guide" + + @pytest.mark.asyncio + async def test_successful_retrieval_without_referenced_documents( + self, mocker: MockerFixture, mock_configuration: MockType + ) -> None: + """Test that GET conversation works when cache entry has no referenced_documents.""" + mock_authorization_resolvers(mocker) + mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) + mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) + mock_configuration.conversation_cache.list.return_value = [ + mocker.Mock(conversation_id=VALID_CONVERSATION_ID) + ] + mock_configuration.conversation_cache.get.return_value = [ + CacheEntry( + query="Hello", + response="Hi there!", + provider="openai", + model="gpt-4", + started_at="2024-01-01T00:00:00Z", + completed_at="2024-01-01T00:00:05Z", + ) + ] + + response = await get_conversation_endpoint_handler( + request=mocker.Mock(), + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, + ) + + assert response is not None + turn = response.chat_history[0] + assert turn.messages[1].type == "assistant" + assert turn.messages[1].referenced_documents is None + @pytest.mark.asyncio async def test_with_skip_userid_check( self, mocker: MockerFixture, mock_configuration: MockType diff --git a/tests/unit/models/responses/test_successful_responses.py b/tests/unit/models/responses/test_successful_responses.py index 408ce8725..4bade3991 100644 --- a/tests/unit/models/responses/test_successful_responses.py +++ b/tests/unit/models/responses/test_successful_responses.py @@ -607,8 +607,10 @@ def test_constructor(self) -> None: ) assert isinstance(response, AbstractSuccessfulResponse) assert response.conversation_id == "123e4567-e89b-12d3-a456-426614174000" - # Convert ConversationTurn objects to dicts for comparison - actual_history = [turn.model_dump() for turn in response.chat_history] + # Convert ConversationTurn objects to dicts for comparison (exclude None for clean output) + actual_history = [ + turn.model_dump(exclude_none=True) for turn in response.chat_history + ] assert actual_history == chat_history def test_empty_chat_history(self) -> None: