diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 1d2f4044b..b4bc94412 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -293,7 +293,7 @@ async def retrieve_response( # pylint: disable=too-many-locals response: Optional[OpenAIResponseObject] = None try: moderation_result = await run_shield_moderation( - client, responses_params.input, shield_ids + client, cast(str, responses_params.input), shield_ids ) if moderation_result.decision == "blocked": # Handle shield moderation blocking @@ -301,11 +301,13 @@ async def retrieve_response( # pylint: disable=too-many-locals await append_turn_to_conversation( client, responses_params.conversation, - responses_params.input, + cast(str, responses_params.input), violation_message, ) return TurnSummary(llm_response=violation_message) - response = await client.responses.create(**responses_params.model_dump()) + response = await client.responses.create( + **responses_params.model_dump(exclude_none=True) + ) response = cast(OpenAIResponseObject, response) except RuntimeError as e: # library mode wraps 413 into runtime error diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 5bd8bc52a..45998754e 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -285,14 +285,16 @@ async def retrieve_response_generator( turn_summary = TurnSummary() try: moderation_result = await run_shield_moderation( - context.client, responses_params.input, context.query_request.shield_ids + context.client, + cast(str, responses_params.input), + context.query_request.shield_ids, ) if moderation_result.decision == "blocked": turn_summary.llm_response = moderation_result.message await append_turn_to_conversation( context.client, responses_params.conversation, - responses_params.input, + cast(str, responses_params.input), moderation_result.message, ) media_type = context.query_request.media_type or MEDIA_TYPE_JSON @@ -302,7 +304,7 @@ async def retrieve_response_generator( ) # Retrieve response stream (may raise exceptions) response = await context.client.responses.create( - **responses_params.model_dump() + **responses_params.model_dump(exclude_none=True) ) # Store pre-RAG documents for later merging turn_summary.pre_rag_documents = doc_ids_from_chunks @@ -347,7 +349,7 @@ async def _persist_interrupted_turn( await append_turn_to_conversation( context.client, responses_params.conversation, - responses_params.input, + cast(str, responses_params.input), INTERRUPTED_RESPONSE_MESSAGE, ) except Exception: # pylint: disable=broad-except diff --git a/src/utils/query.py b/src/utils/query.py index 0bb47fb11..fef28eea6 100644 --- a/src/utils/query.py +++ b/src/utils/query.py @@ -223,7 +223,7 @@ def store_query_results( # pylint: disable=too-many-arguments summary: TurnSummary, query: str, skip_userid_check: bool, - attachments: list[Attachment] | None = None, + attachments: Optional[list[Attachment]] = None, topic_summary: Optional[str] = None, ) -> None: """ diff --git a/src/utils/responses.py b/src/utils/responses.py index a1d2c39e5..cbaac1a5d 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -11,6 +11,8 @@ OpenAIResponseContentPartRefusal as ContentPartRefusal, OpenAIResponseInputMessageContent as InputMessageContent, OpenAIResponseInputMessageContentText as InputTextPart, + OpenAIResponseInputToolFileSearch as InputToolFileSearch, + OpenAIResponseInputToolMCP as InputToolMCP, OpenAIResponseMessage as ResponseMessage, OpenAIResponseObject as ResponseObject, OpenAIResponseOutput as ResponseOutput, @@ -24,6 +26,7 @@ OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, OpenAIResponseMCPApprovalResponse as MCPApprovalResponse, OpenAIResponseUsage as ResponseUsage, + OpenAIResponseInputTool as InputTool, ) from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient @@ -101,12 +104,12 @@ async def get_topic_summary( async def prepare_tools( # pylint: disable=too-many-arguments,too-many-positional-arguments client: AsyncLlamaStackClient, - vector_store_ids: list[str] | None, - no_tools: bool | None, + vector_store_ids: Optional[list[str]], + no_tools: Optional[bool], token: str, - mcp_headers: McpHeaders | None = None, + mcp_headers: Optional[McpHeaders] = None, request_headers: Optional[Mapping[str, str]] = None, -) -> list[dict[str, Any]] | None: +) -> Optional[list[InputTool]]: """Prepare tools for Responses API including RAG and MCP tools. Args: @@ -124,7 +127,7 @@ async def prepare_tools( # pylint: disable=too-many-arguments,too-many-position if no_tools: return None - toolgroups = [] + toolgroups: list[InputTool] = [] # Get all vector stores if vector stores are not restricted by request if vector_store_ids is None: try: @@ -152,7 +155,7 @@ async def prepare_tools( # pylint: disable=too-many-arguments,too-many-position logger.debug( "Configured %d MCP tools: %s", len(mcp_tools), - [tool.get("server_label", "unknown") for tool in mcp_tools], + [tool.server_label for tool in mcp_tools], ) # Convert empty list to None for consistency with existing behavior if not toolgroups: @@ -162,8 +165,8 @@ async def prepare_tools( # pylint: disable=too-many-arguments,too-many-position def _build_provider_data_headers( - tools: list[dict[str, Any]] | None, -) -> dict[str, str] | None: + tools: Optional[list[InputTool]], +) -> Optional[dict[str, str]]: """Build extra HTTP headers containing MCP provider data for Llama Stack. Extracts per-server auth headers from MCP tool definitions and encodes @@ -181,9 +184,9 @@ def _build_provider_data_headers( return None mcp_headers: McpHeaders = { - tool["server_url"]: tool["headers"] + tool.server_url: tool.headers for tool in tools - if tool.get("type") == "mcp" and tool.get("headers") and tool.get("server_url") + if tool.type == "mcp" and tool.headers } if not mcp_headers: @@ -195,9 +198,9 @@ def _build_provider_data_headers( async def prepare_responses_params( # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments client: AsyncLlamaStackClient, query_request: QueryRequest, - user_conversation: UserConversation | None, + user_conversation: Optional[UserConversation], token: str, - mcp_headers: McpHeaders | None = None, + mcp_headers: Optional[McpHeaders] = None, stream: bool = False, store: bool = True, request_headers: Optional[Mapping[str, str]] = None, @@ -287,25 +290,26 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma def extract_vector_store_ids_from_tools( - tools: list[dict[str, Any]] | None, + tools: Optional[list[InputTool]], ) -> list[str]: """Extract vector store IDs from prepared tool configurations. Parameters: - tools: The prepared tools list from ResponsesApiParams. + tools: The prepared tools list (typed InputTool or dict). Returns: List of vector store IDs used in file_search tools, or empty list. """ if not tools: return [] + vector_store_ids: list[str] = [] for tool in tools: - if tool.get("type") == "file_search": - return tool.get("vector_store_ids", []) - return [] + if tool.type == "file_search": + vector_store_ids.extend(tool.vector_store_ids) + return vector_store_ids -def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None: +def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[InputToolFileSearch]]: """Convert vector store IDs to tools format for Responses API. Args: @@ -318,19 +322,18 @@ def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None: return None return [ - { - "type": "file_search", - "vector_store_ids": vector_store_ids, - "max_num_results": 10, - } + InputToolFileSearch( + vector_store_ids=vector_store_ids, + max_num_results=10, + ) ] async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many-locals - token: str | None = None, - mcp_headers: McpHeaders | None = None, + token: Optional[str] = None, + mcp_headers: Optional[McpHeaders] = None, request_headers: Optional[Mapping[str, str]] = None, -) -> list[dict[str, Any]]: +) -> list[InputToolMCP]: """Convert MCP servers to tools format for Responses API. Args: @@ -348,7 +351,7 @@ async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many- no headers are passed, and the server responds with 401 and WWW-Authenticate. """ - def _get_token_value(original: str, header: str) -> str | None: + def _get_token_value(original: str, header: str) -> Optional[str]: """Convert to header value.""" match original: case constants.MCP_AUTH_KUBERNETES: @@ -376,18 +379,10 @@ def _get_token_value(original: str, header: str) -> str | None: # use provided return original - tools = [] + tools: list[InputToolMCP] = [] for mcp_server in configuration.mcp_servers: - # Base tool definition - tool_def = { - "type": "mcp", - "server_label": mcp_server.name, - "server_url": mcp_server.url, - "require_approval": "never", - } - # Build headers - headers = {} + headers: dict[str, str] = {} for name, value in mcp_server.resolved_authorization_headers.items(): # for each defined header h_value = _get_token_value(value, name) @@ -427,21 +422,23 @@ def _get_token_value(original: str, header: str) -> str | None: existing_lower.add(h_name.lower()) # Build Authorization header - if headers.get("Authorization"): - tool_def["authorization"] = headers.pop("Authorization") - - if len(headers) > 0: - # add headers to tool definition - tool_def["headers"] = headers # type: ignore[index] - # collect tools info - tools.append(tool_def) + authorization = headers.pop("Authorization", None) + tools.append( + InputToolMCP( + server_label=mcp_server.name, + server_url=mcp_server.url, + require_approval="never", + headers=headers if headers else None, + authorization=authorization, + ) + ) return tools def parse_referenced_documents( # pylint: disable=too-many-locals - response: ResponseObject | None, - vector_store_ids: list[str] | None = None, - rag_id_mapping: dict[str, str] | None = None, + response: Optional[ResponseObject], + vector_store_ids: Optional[list[str]] = None, + rag_id_mapping: Optional[dict[str, str]] = None, ) -> list[ReferencedDocument]: """Parse referenced documents from Responses API response. @@ -455,7 +452,7 @@ def parse_referenced_documents( # pylint: disable=too-many-locals """ documents: list[ReferencedDocument] = [] # Use a set to track unique documents by (doc_url, doc_title) tuple - seen_docs: set[tuple[str | None, str | None]] = set() + seen_docs: set[tuple[Optional[str], Optional[str]]] = set() # Handle None response (e.g., when agent fails) if response is None or not response.output: @@ -489,7 +486,7 @@ def parse_referenced_documents( # pylint: disable=too-many-locals doc_title = attributes.get("title") if doc_title or doc_url: - # Treat empty string as None for URL to satisfy AnyUrl | None + # Treat empty string as None for URL to satisfy Optional[AnyUrl] final_url = doc_url if doc_url else None if (final_url, doc_title) not in seen_docs: documents.append( @@ -504,7 +501,7 @@ def parse_referenced_documents( # pylint: disable=too-many-locals return documents -def extract_token_usage(usage: ResponseUsage | None, model: str) -> TokenCounter: +def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCounter: """Extract token usage from Responses API usage object and update metrics. Args: @@ -549,9 +546,9 @@ def extract_token_usage(usage: ResponseUsage | None, model: str) -> TokenCounter def build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals output_item: ResponseOutput, rag_chunks: list[RAGChunk], - vector_store_ids: list[str] | None = None, - rag_id_mapping: dict[str, str] | None = None, -) -> tuple[ToolCallSummary | None, ToolResultSummary | None]: + vector_store_ids: Optional[list[str]] = None, + rag_id_mapping: Optional[dict[str, str]] = None, +) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]: """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary. Args: @@ -582,7 +579,7 @@ def build_tool_call_summary( # pylint: disable=too-many-return-statements,too-m extract_rag_chunks_from_file_search_item( file_search_item, rag_chunks, vector_store_ids, rag_id_mapping ) - response_payload: dict[str, Any] | None = None + response_payload: Optional[dict[str, Any]] = None if file_search_item.results is not None: response_payload = { "results": [result.model_dump() for result in file_search_item.results] @@ -709,7 +706,7 @@ def build_mcp_tool_call_from_arguments_done( output_index: int, arguments: str, mcp_call_items: dict[int, tuple[str, str]], -) -> ToolCallSummary | None: +) -> Optional[ToolCallSummary]: """Build ToolCallSummary from MCP call arguments completion event. Args: @@ -765,7 +762,7 @@ def _resolve_source_for_result( result: Any, vector_store_ids: list[str], rag_id_mapping: dict[str, str], -) -> str | None: +) -> Optional[str]: """Resolve the human-friendly index name for a file search result. Uses the vector store mapping to convert internal llama-stack IDs @@ -785,14 +782,14 @@ def _resolve_source_for_result( if len(vector_store_ids) > 1: attributes = getattr(result, "attributes", {}) or {} - attr_store_id: str | None = attributes.get("vector_store_id") + attr_store_id: Optional[str] = attributes.get("vector_store_id") if attr_store_id: return rag_id_mapping.get(attr_store_id, attr_store_id) return None -def _build_chunk_attributes(result: Any) -> dict[str, Any] | None: +def _build_chunk_attributes(result: Any) -> Optional[dict[str, Any]]: """Extract document metadata attributes from a file search result. Parameters: @@ -812,8 +809,8 @@ def _build_chunk_attributes(result: Any) -> dict[str, Any] | None: def extract_rag_chunks_from_file_search_item( item: FileSearchCall, rag_chunks: list[RAGChunk], - vector_store_ids: list[str] | None = None, - rag_id_mapping: dict[str, str] | None = None, + vector_store_ids: Optional[list[str]] = None, + rag_id_mapping: Optional[dict[str, str]] = None, ) -> None: """Extract RAG chunks from a file search tool call item. @@ -914,7 +911,7 @@ async def check_model_configured( async def select_model_for_responses( client: AsyncLlamaStackClient, - user_conversation: UserConversation | None, + user_conversation: Optional[UserConversation], ) -> str: """Select model for Responses API if not explicitly specified in the request. @@ -979,10 +976,10 @@ async def select_model_for_responses( def build_turn_summary( - response: ResponseObject | None, + response: Optional[ResponseObject], model: str, - vector_store_ids: list[str] | None = None, - rag_id_mapping: dict[str, str] | None = None, + vector_store_ids: Optional[list[str]] = None, + rag_id_mapping: Optional[dict[str, str]] = None, ) -> TurnSummary: """Build a TurnSummary from a ResponseObject. @@ -1023,7 +1020,7 @@ def build_turn_summary( def extract_text_from_response_items( - response_items: Sequence[ResponseItem] | None, + response_items: Optional[Sequence[ResponseItem]], ) -> str: """Extract text from response items iteratively. @@ -1102,7 +1099,7 @@ def deduplicate_referenced_documents( docs: list[ReferencedDocument], ) -> list[ReferencedDocument]: """Remove duplicate referenced documents based on URL and title.""" - seen: set[tuple[str | None, str | None]] = set() + seen: set[tuple[Optional[str], Optional[str]]] = set() out: list[ReferencedDocument] = [] for d in docs: key = (str(d.doc_url) if d.doc_url else None, d.doc_title) diff --git a/src/utils/types.py b/src/utils/types.py index 7d7df876b..220a85239 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -5,6 +5,8 @@ from llama_stack_api import ImageContentItem, TextContentItem from llama_stack_api.openai_responses import ( OpenAIResponseInputFunctionToolCallOutput as FunctionToolCallOutput, + OpenAIResponseInputTool as InputTool, + OpenAIResponseInputToolChoice as ToolChoice, OpenAIResponseMCPApprovalRequest as McpApprovalRequest, OpenAIResponseMCPApprovalResponse as McpApprovalResponse, OpenAIResponseMessage as ResponseMessage, @@ -13,6 +15,8 @@ OpenAIResponseOutputMessageMCPCall as McpCall, OpenAIResponseOutputMessageMCPListTools as McpListTools, OpenAIResponseOutputMessageWebSearchToolCall as WebSearchToolCall, + OpenAIResponsePrompt as Prompt, + OpenAIResponseText as Text, ) from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.lib.agents.types import ( @@ -129,21 +133,90 @@ class ShieldModerationBlocked(BaseModel): Field(discriminator="decision"), ] +IncludeParameter: TypeAlias = Literal[ + "web_search_call.action.sources", + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", +] + +ResponseItem: TypeAlias = ( + ResponseMessage + | WebSearchToolCall + | FileSearchToolCall + | FunctionToolCallOutput + | McpCall + | McpListTools + | McpApprovalRequest + | FunctionToolCall + | McpApprovalResponse +) + +ResponseInput: TypeAlias = str | list[ResponseItem] + class ResponsesApiParams(BaseModel): - """Parameters for a Llama Stack Responses API request.""" + """Parameters for a Llama Stack Responses API request. - input: str = Field(description="The input text with attachments appended") + All fields accepted by the Llama Stack client responses.create() body are + included so that dumped model can be passed directly to response create. + """ + + input: ResponseInput = Field(description="The input text or structured input items") model: str = Field(description='The full model ID in format "provider/model"') + conversation: str = Field(description="The conversation ID in llama-stack format") + include: Optional[list[IncludeParameter]] = Field( + default=None, + description="Output item types to include in the response", + ) instructions: Optional[str] = Field( default=None, description="The resolved system prompt" ) - tools: Optional[list[dict[str, Any]]] = Field( - default=None, description="Prepared tool groups for Responses API" + max_infer_iters: Optional[int] = Field( + default=None, + description="Maximum number of inference iterations", + ) + max_tool_calls: Optional[int] = Field( + default=None, + description="Maximum tool calls allowed in a single response", + ) + metadata: Optional[dict[str, str]] = Field( + default=None, + description="Custom metadata for tracking or logging", + ) + parallel_tool_calls: Optional[bool] = Field( + default=None, + description="Whether the model can make multiple tool calls in parallel", + ) + previous_response_id: Optional[str] = Field( + default=None, + description="Identifier of the previous response in a multi-turn conversation", + ) + prompt: Optional[Prompt] = Field( + default=None, + description="Prompt template with variables for dynamic substitution", ) - conversation: str = Field(description="The conversation ID in llama-stack format") - stream: bool = Field(description="Whether to stream the response") store: bool = Field(description="Whether to store the response") + stream: bool = Field(description="Whether to stream the response") + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature (e.g. 0.0-2.0)", + ) + text: Optional[Text] = Field( + default=None, + description="Text response configuration (format constraints)", + ) + tool_choice: Optional[ToolChoice] = Field( + default=None, + description="Tool selection strategy", + ) + tools: Optional[list[InputTool]] = Field( + default=None, + description="Prepared tool groups for Responses API (same type as ResponsesRequest.tools)", + ) extra_headers: Optional[dict[str, str]] = Field( default=None, description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)", @@ -227,10 +300,10 @@ class TurnSummary(BaseModel): class TranscriptMetadata(BaseModel): """Metadata for a transcript entry.""" - provider: str | None = None + provider: Optional[str] = None model: str - query_provider: str | None = None - query_model: str | None = None + query_provider: Optional[str] = None + query_model: Optional[str] = None user_id: str conversation_id: str timestamp: str @@ -248,28 +321,3 @@ class Transcript(BaseModel): attachments: list[dict[str, Any]] = Field(default_factory=list) tool_calls: list[dict[str, Any]] = Field(default_factory=list) tool_results: list[dict[str, Any]] = Field(default_factory=list) - - -ResponseItem: TypeAlias = ( - ResponseMessage - | WebSearchToolCall - | FileSearchToolCall - | FunctionToolCallOutput - | McpCall - | McpListTools - | McpApprovalRequest - | FunctionToolCall - | McpApprovalResponse -) - -ResponseInput: TypeAlias = str | list[ResponseItem] - -IncludeParameter: TypeAlias = Literal[ - "web_search_call.action.sources", - "code_interpreter_call.outputs", - "computer_call_output.output.image_url", - "file_search_call.results", - "message.input_image.image_url", - "message.output_text.logprobs", - "reasoning.encrypted_content", -] diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 4f760ef76..f1f3ef924 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -9,6 +9,8 @@ import pytest from fastapi import HTTPException from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFileSearch as InputToolFileSearch, + OpenAIResponseInputToolMCP as InputToolMCP, OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, OpenAIResponseOutputMessageFunctionToolCall as FunctionCall, OpenAIResponseOutputMessageMCPCall as MCPCall, @@ -339,9 +341,9 @@ def test_get_rag_tools_with_vector_stores(self) -> None: tools = get_rag_tools(["db1", "db2"]) assert isinstance(tools, list) assert len(tools) == 1 - assert tools[0]["type"] == "file_search" - assert tools[0]["vector_store_ids"] == ["db1", "db2"] - assert tools[0]["max_num_results"] == 10 + assert tools[0].type == "file_search" + assert tools[0].vector_store_ids == ["db1", "db2"] + assert tools[0].max_num_results == 10 class TestGetMCPTools: @@ -364,10 +366,10 @@ async def test_get_mcp_tools_without_auth(self, mocker: MockerFixture) -> None: tools_no_auth = await get_mcp_tools(token=None) assert len(tools_no_auth) == 2 - assert tools_no_auth[0]["type"] == "mcp" - assert tools_no_auth[0]["server_label"] == "fs" - assert tools_no_auth[0]["server_url"] == "http://localhost:3000" - assert "headers" not in tools_no_auth[0] + assert tools_no_auth[0].type == "mcp" + assert tools_no_auth[0].server_label == "fs" + assert tools_no_auth[0].server_url == "http://localhost:3000" + assert tools_no_auth[0].headers is None @pytest.mark.asyncio async def test_get_mcp_tools_with_kubernetes_auth( @@ -387,7 +389,7 @@ async def test_get_mcp_tools_with_kubernetes_auth( mocker.patch("utils.responses.configuration", mock_config) tools_k8s = await get_mcp_tools(token="user-k8s-token") assert len(tools_k8s) == 1 - assert tools_k8s[0]["authorization"] == "Bearer user-k8s-token" + assert tools_k8s[0].authorization == "Bearer user-k8s-token" @pytest.mark.asyncio async def test_get_mcp_tools_with_mcp_headers(self, mocker: MockerFixture) -> None: @@ -412,10 +414,10 @@ async def test_get_mcp_tools_with_mcp_headers(self, mocker: MockerFixture) -> No } tools = await get_mcp_tools(token=None, mcp_headers=mcp_headers) assert len(tools) == 1 - assert tools[0]["headers"] == { + assert tools[0].headers == { "X-Custom": "custom-value", } - assert tools[0]["authorization"] == "client-provided-token" + assert tools[0].authorization == "client-provided-token" # Test with mcp_headers=None (server should be skipped) tools_no_headers = await get_mcp_tools(token=None, mcp_headers=None) @@ -491,7 +493,7 @@ async def test_get_mcp_tools_with_static_headers( tools = await get_mcp_tools(token=None) assert len(tools) == 1 - assert tools[0]["authorization"] == "static-secret-token" + assert tools[0].authorization == "static-secret-token" @pytest.mark.asyncio async def test_get_mcp_tools_with_mixed_headers( @@ -525,8 +527,8 @@ async def test_get_mcp_tools_with_mixed_headers( tools = await get_mcp_tools(token="k8s-token", mcp_headers=mcp_headers) assert len(tools) == 1 - assert tools[0]["authorization"] == "Bearer k8s-token" - assert tools[0]["headers"] == { + assert tools[0].authorization == "Bearer k8s-token" + assert tools[0].headers == { "X-API-Key": "secret-api-key", "X-Custom": "client-custom-value", } @@ -576,8 +578,8 @@ async def test_get_mcp_tools_includes_server_without_auth( tools = await get_mcp_tools(token=None, mcp_headers=None) assert len(tools) == 1 - assert tools[0]["server_label"] == "public-server" - assert "headers" not in tools[0] + assert tools[0].server_label == "public-server" + assert tools[0].headers is None @pytest.mark.asyncio async def test_get_mcp_tools_oauth_no_headers_raises_401_with_www_authenticate( @@ -646,7 +648,7 @@ async def test_get_mcp_tools_with_propagated_headers( token=None, mcp_headers=None, request_headers=request_headers ) assert len(tools) == 1 - assert tools[0]["headers"] == { + assert tools[0].headers == { "x-rh-identity": "encoded-identity", "x-request-id": "req-456", } @@ -680,9 +682,10 @@ async def test_get_mcp_tools_propagated_headers_do_not_overwrite_auth_headers( ) assert len(tools) == 1 # Authorization from authorization_headers goes to tool's authorization field - assert tools[0]["authorization"] == "secret-token" + assert tools[0].authorization == "secret-token" # x-rh-identity from propagated headers should be in headers - assert tools[0]["headers"]["x-rh-identity"] == "identity-value" + assert tools[0].headers is not None + assert tools[0].headers["x-rh-identity"] == "identity-value" @pytest.mark.asyncio async def test_get_mcp_tools_propagated_headers_missing_from_request( @@ -707,7 +710,7 @@ async def test_get_mcp_tools_propagated_headers_missing_from_request( token=None, mcp_headers=None, request_headers=request_headers ) assert len(tools) == 1 - assert tools[0]["headers"] == {"x-rh-identity": "identity-value"} + assert tools[0].headers == {"x-rh-identity": "identity-value"} @pytest.mark.asyncio async def test_get_mcp_tools_propagated_headers_no_request_headers( @@ -753,8 +756,8 @@ async def test_get_mcp_tools_propagated_headers_additive_with_mcp_headers( token=None, mcp_headers=mcp_hdrs, request_headers=request_headers ) assert len(tools) == 1 - assert tools[0]["authorization"] == "Bearer client-token" - assert tools[0]["headers"] == {"x-rh-identity": "identity-value"} + assert tools[0].authorization == "Bearer client-token" + assert tools[0].headers == {"x-rh-identity": "identity-value"} @pytest.mark.asyncio async def test_get_mcp_tools_mixed_case_precedence( @@ -785,10 +788,11 @@ async def test_get_mcp_tools_mixed_case_precedence( ) assert len(tools) == 1 # Auth header goes to tool's authorization field - assert tools[0]["authorization"] == "file-secret" + assert tools[0].authorization == "file-secret" # Propagated header should be in headers (Authorization not in headers) - assert tools[0]["headers"]["x-rh-identity"] == "identity-value" - assert len(tools[0]["headers"]) == 1 + assert tools[0].headers is not None + assert tools[0].headers["x-rh-identity"] == "identity-value" + assert len(tools[0].headers) == 1 class TestGetTopicSummary: @@ -895,8 +899,8 @@ async def test_prepare_tools_with_vector_store_ids( result = await prepare_tools(mock_client, ["vs1", "vs2"], False, "token") assert result is not None assert len(result) == 1 - assert result[0]["type"] == "file_search" - assert result[0]["vector_store_ids"] == ["vs1", "vs2"] + assert result[0].type == "file_search" + assert result[0].vector_store_ids == ["vs1", "vs2"] @pytest.mark.asyncio async def test_prepare_tools_fetch_vector_stores( @@ -918,7 +922,7 @@ async def test_prepare_tools_fetch_vector_stores( result = await prepare_tools(mock_client, None, False, "token") assert result is not None assert len(result) == 1 - assert result[0]["vector_store_ids"] == ["vs1", "vs2"] + assert result[0].vector_store_ids == ["vs1", "vs2"] @pytest.mark.asyncio async def test_prepare_tools_connection_error(self, mocker: MockerFixture) -> None: @@ -938,13 +942,16 @@ async def test_prepare_tools_connection_error(self, mocker: MockerFixture) -> No async def test_prepare_tools_with_mcp_servers(self, mocker: MockerFixture) -> None: """Test prepare_tools includes MCP tools.""" mock_client = mocker.AsyncMock() - mock_mcp_tool = {"type": "mcp", "server_label": "test-server"} + mock_mcp_tool = InputToolMCP( + server_label="test-server", + server_url="http://test", + ) mocker.patch("utils.responses.get_mcp_tools", return_value=[mock_mcp_tool]) result = await prepare_tools(mock_client, ["vs1"], False, "token") assert result is not None assert len(result) == 2 # RAG tool + MCP tool - assert any(tool.get("type") == "mcp" for tool in result) + assert any(tool.type == "mcp" for tool in result) @pytest.mark.asyncio async def test_prepare_tools_api_status_error(self, mocker: MockerFixture) -> None: @@ -1146,20 +1153,18 @@ async def test_prepare_responses_params_includes_mcp_provider_data_headers( # Simulate MCP tools with headers (as returned by prepare_tools/get_mcp_tools) mcp_tools_with_headers = [ - { - "type": "mcp", - "server_label": "mcp::aap-controller", - "server_url": "http://aap.foo.redhat.com:8004/sse", - "require_approval": "never", - "headers": {"X-Authorization": "client-token"}, - }, - { - "type": "mcp", - "server_label": "mcp::aap-lightspeed", - "server_url": "http://aap.foo.redhat.com:8005/sse", - "require_approval": "never", - "headers": {"X-Authorization": "client-token-2"}, - }, + InputToolMCP( + server_label="mcp::aap-controller", + server_url="http://aap.foo.redhat.com:8004/sse", + require_approval="never", + headers={"X-Authorization": "client-token"}, + ), + InputToolMCP( + server_label="mcp::aap-lightspeed", + server_url="http://aap.foo.redhat.com:8005/sse", + require_approval="never", + headers={"X-Authorization": "client-token-2"}, + ), ] mock_config = mocker.Mock() @@ -1986,27 +1991,27 @@ class TestExtractVectorStoreIdsFromTools: def test_with_file_search_tool(self) -> None: """Test extraction from file_search tool definition.""" tools = [ - {"type": "file_search", "vector_store_ids": ["vs-1", "vs-2"]}, - {"type": "mcp", "server_label": "test"}, + InputToolFileSearch(vector_store_ids=["vs-1", "vs-2"]), + InputToolMCP(server_label="test", server_url="http://test"), ] result = extract_vector_store_ids_from_tools(tools) assert result == ["vs-1", "vs-2"] def test_with_no_file_search(self) -> None: """Test extraction returns empty list when no file_search tool.""" - tools = [{"type": "mcp", "server_label": "test"}] + tools = [InputToolMCP(server_label="test", server_url="http://test")] result = extract_vector_store_ids_from_tools(tools) - assert result == [] + assert not result def test_with_none_tools(self) -> None: """Test extraction returns empty list for None tools.""" result = extract_vector_store_ids_from_tools(None) - assert result == [] + assert not result def test_with_empty_tools(self) -> None: """Test extraction returns empty list for empty tools list.""" result = extract_vector_store_ids_from_tools([]) - assert result == [] + assert not result class TestExtractRagChunksWithIndexResolution: