diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 3a36f8b8..a3c75fdd 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals generate_topic_summary=True, media_type=None, vector_store_ids=vector_store_ids, + shield_ids=None, ) # Get LLM client and select model diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py new file mode 100644 index 00000000..dfd5f766 --- /dev/null +++ b/src/app/endpoints/query_v2.py @@ -0,0 +1,862 @@ +# pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks + +"""Handler for REST API call to provide answer to query using Response API.""" + +import json +import logging +from typing import Annotated, Any, Optional, cast + +from fastapi import APIRouter, Depends, Request +from llama_stack_api.openai_responses import ( + OpenAIResponseMCPApprovalRequest, + OpenAIResponseMCPApprovalResponse, + OpenAIResponseObject, + OpenAIResponseOutput, + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseOutputMessageWebSearchToolCall, +) +from llama_stack_client import AsyncLlamaStackClient + +import constants +import metrics +from app.endpoints.query import ( + query_endpoint_handler_base, + validate_attachments_metadata, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from configuration import AppConfig, configuration +from constants import DEFAULT_RAG_TOOL +from models.config import Action, ModelContextProtocolServer +from models.requests import QueryRequest +from models.responses import ( + ForbiddenResponse, + InternalServerErrorResponse, + NotFoundResponse, + QueryResponse, + QuotaExceededResponse, + ReferencedDocument, + ServiceUnavailableResponse, + UnauthorizedResponse, + UnprocessableEntityResponse, +) +from utils.endpoints import ( + check_configuration_loaded, + get_system_prompt, + get_topic_summary_system_prompt, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.query import parse_arguments_string +from utils.responses import extract_text_from_response_output_item +from utils.shields import ( + append_turn_to_conversation, + run_shield_moderation, +) +from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id +from utils.token_counter import TokenCounter +from utils.types import RAGChunk, ToolCallSummary, ToolResultSummary, TurnSummary + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["query_v1"]) + +query_v2_response: dict[int | str, dict[str, Any]] = { + 200: QueryResponse.openapi_response(), + 401: UnauthorizedResponse.openapi_response( + examples=["missing header", "missing token"] + ), + 403: ForbiddenResponse.openapi_response( + examples=["endpoint", "conversation read", "model override"] + ), + 404: NotFoundResponse.openapi_response( + examples=["conversation", "model", "provider"] + ), + # 413: PromptTooLongResponse.openapi_response(), + 422: UnprocessableEntityResponse.openapi_response(), + 429: QuotaExceededResponse.openapi_response(), + 500: InternalServerErrorResponse.openapi_response(examples=["configuration"]), + 503: ServiceUnavailableResponse.openapi_response(), +} + + +def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-many-branches + output_item: 
OpenAIResponseOutput, + rag_chunks: list[RAGChunk], +) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]: + """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary records. + + Processes OpenAI response output items and extracts tool call and result information. + Also parses RAG chunks from file_search_call items and appends them to the provided list. + + Args: + output_item: An OpenAIResponseOutput item from the response.output array + rag_chunks: List to append extracted RAG chunks to (from file_search_call items) + Returns: + A tuple of (ToolCallSummary, ToolResultSummary) one of them possibly None + if current llama stack Responses API does not provide the information. + + Supported tool types: + - function_call: Function tool calls with parsed arguments (no result) + - file_search_call: File search operations with results (also extracts RAG chunks) + - web_search_call: Web search operations (incomplete) + - mcp_call: MCP calls with server labels + - mcp_list_tools: MCP server tool listings + - mcp_approval_request: MCP approval requests (no result) + - mcp_approval_response: MCP approval responses (no call) + """ + item_type = getattr(output_item, "type", None) + + if item_type == "function_call": + item = cast(OpenAIResponseOutputMessageFunctionToolCall, output_item) + return ( + ToolCallSummary( + id=item.call_id, + name=item.name, + args=parse_arguments_string(item.arguments), + type="function_call", + ), + None, # not supported by Responses API at all + ) + + if item_type == "file_search_call": + file_search_item = cast( + OpenAIResponseOutputMessageFileSearchToolCall, output_item + ) + extract_rag_chunks_from_file_search_item(file_search_item, rag_chunks) + response_payload: Optional[dict[str, Any]] = None + if file_search_item.results is not None: + response_payload = { + "results": [result.model_dump() for result in file_search_item.results] + } + return ToolCallSummary( + id=file_search_item.id, + name=DEFAULT_RAG_TOOL, + args={"queries": file_search_item.queries}, + type="file_search_call", + ), ToolResultSummary( + id=file_search_item.id, + status=file_search_item.status, + content=json.dumps(response_payload) if response_payload else "", + type="file_search_call", + round=1, + ) + + # Incomplete OpenAI Responses API definition in LLS: action attribute not supported yet + if item_type == "web_search_call": + web_search_item = cast( + OpenAIResponseOutputMessageWebSearchToolCall, output_item + ) + return ( + ToolCallSummary( + id=web_search_item.id, + name="web_search", + args={}, + type="web_search_call", + ), + ToolResultSummary( + id=web_search_item.id, + status=web_search_item.status, + content="", + type="web_search_call", + round=1, + ), + ) + + if item_type == "mcp_call": + mcp_call_item = cast(OpenAIResponseOutputMessageMCPCall, output_item) + args = parse_arguments_string(mcp_call_item.arguments) + if mcp_call_item.server_label: + args["server_label"] = mcp_call_item.server_label + content = ( + mcp_call_item.error + if mcp_call_item.error + else (mcp_call_item.output if mcp_call_item.output else "") + ) + + return ToolCallSummary( + id=mcp_call_item.id, + name=mcp_call_item.name, + args=args, + type="mcp_call", + ), ToolResultSummary( + id=mcp_call_item.id, + status="success" if mcp_call_item.error is None else "failure", + content=content, + type="mcp_call", + round=1, + ) + + if item_type == "mcp_list_tools": + mcp_list_tools_item = cast(OpenAIResponseOutputMessageMCPListTools, output_item) + tools_info = [ + { + "name": 
tool.name, + "description": tool.description, + "input_schema": tool.input_schema, + } + for tool in mcp_list_tools_item.tools + ] + content_dict = { + "server_label": mcp_list_tools_item.server_label, + "tools": tools_info, + } + return ( + ToolCallSummary( + id=mcp_list_tools_item.id, + name="mcp_list_tools", + args={"server_label": mcp_list_tools_item.server_label}, + type="mcp_list_tools", + ), + ToolResultSummary( + id=mcp_list_tools_item.id, + status="success", + content=json.dumps(content_dict), + type="mcp_list_tools", + round=1, + ), + ) + + if item_type == "mcp_approval_request": + approval_request_item = cast(OpenAIResponseMCPApprovalRequest, output_item) + args = parse_arguments_string(approval_request_item.arguments) + return ( + ToolCallSummary( + id=approval_request_item.id, + name=approval_request_item.name, + args=args, + type="tool_call", + ), + None, + ) + + if item_type == "mcp_approval_response": + approval_response_item = cast(OpenAIResponseMCPApprovalResponse, output_item) + content_dict = {} + if approval_response_item.reason: + content_dict["reason"] = approval_response_item.reason + return ( + None, + ToolResultSummary( + id=approval_response_item.approval_request_id, + status="success" if approval_response_item.approve else "denied", + content=json.dumps(content_dict), + type="mcp_approval_response", + round=1, + ), + ) + + return None, None + + +async def get_topic_summary( # pylint: disable=too-many-nested-blocks + question: str, client: AsyncLlamaStackClient, model_id: str +) -> str: + """ + Get a topic summary for a question using Responses API. + + This is the Responses API version of get_topic_summary, which uses + client.responses.create() instead of the Agent API. + + Args: + question: The question to generate a topic summary for + client: The AsyncLlamaStackClient to use for the request + model_id: The llama stack model ID (full format: provider/model) + + Returns: + str: The topic summary for the question + """ + topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) + + # Use Responses API to generate topic summary + response = cast( + OpenAIResponseObject, + await client.responses.create( + input=question, + model=model_id, + instructions=topic_summary_system_prompt, + stream=False, + store=False, # Don't store topic summary requests + ), + ) + + # Extract text from response output + summary_text = "".join( + extract_text_from_response_output_item(output_item) + for output_item in response.output + ) + + return summary_text.strip() if summary_text else "" + + +@router.post("/query", responses=query_v2_response, summary="Query Endpoint Handler V1") +@authorize(Action.QUERY) +async def query_endpoint_handler_v2( + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> QueryResponse: + """ + Handle request to the /query endpoint using Responses API. + + This is a wrapper around query_endpoint_handler_base that provides + the Responses API specific retrieve_response and get_topic_summary functions. + + Returns: + QueryResponse: Contains the conversation ID and the LLM-generated response. 
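+
+    Example:
+        An illustrative request body (field values are placeholders; the
+        accepted fields are defined by models.requests.QueryRequest):
+        ```python
+        {
+            "query": "How do I expose a service outside the cluster?",
+            "vector_store_ids": ["ocp_docs"],
+            "shield_ids": ["llama-guard"],
+        }
+        ```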
+    """
+    check_configuration_loaded(configuration)
+    return await query_endpoint_handler_base(
+        request=request,
+        query_request=query_request,
+        auth=auth,
+        mcp_headers=mcp_headers,
+        retrieve_response_func=retrieve_response,
+        get_topic_summary_func=get_topic_summary,
+    )
+
+
+async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    query_request: QueryRequest,
+    token: str,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
+    *,
+    provider_id: str = "",
+) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
+    """
+    Retrieve response from LLMs and agents.
+
+    Retrieves a response from the Llama Stack LLM or agent for a
+    given query, handling shield configuration, tool usage, and
+    attachment validation.
+
+    This function configures system prompts, shields, and toolgroups
+    (including RAG and MCP integration) based on the query request
+    and system configuration. It validates attachments, manages the
+    conversation context, and forwards Model Context Protocol (MCP)
+    headers to the configured MCP servers. Corresponding metrics
+    are updated.
+
+    Parameters:
+        client (AsyncLlamaStackClient): The AsyncLlamaStackClient to use for the request.
+        model_id (str): The identifier of the LLM model to use.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): The authentication token for authorization.
+        mcp_headers (dict[str, dict[str, str]], optional): Per-request headers for MCP servers.
+        provider_id (str): The identifier of the LLM provider to use.
+
+    Returns:
+        tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: A tuple containing
+            a summary of the LLM or agent's response content, the conversation ID,
+            the list of parsed referenced documents, and token usage information.
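+
+    Example:
+        A minimal usage sketch (assumes an initialized client and a validated
+        query_request; the model ID uses the provider/model format):
+        ```python
+        summary, conversation_id, referenced_documents, token_usage = (
+            await retrieve_response(client, "provider/model", query_request, token)
+        )
+        ```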
+ """ + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + toolgroups = await prepare_tools_for_responses_api( + client, query_request, token, configuration, mcp_headers + ) + + # Prepare input for Responses API + # Convert attachments to text and concatenate with query + input_text = query_request.query + if query_request.attachments: + for attachment in query_request.attachments: + # Append attachment content with type label + input_text += ( + f"\n\n[Attachment: {attachment.attachment_type}]\n{attachment.content}" + ) + + # Handle conversation ID for Responses API + # Create conversation upfront if not provided + conversation_id = query_request.conversation_id + if conversation_id: + # Conversation ID was provided - convert to llama-stack format + logger.debug("Using existing conversation ID: %s", conversation_id) + llama_stack_conv_id = to_llama_stack_conversation_id(conversation_id) + else: + # No conversation_id provided - create a new conversation first + logger.debug("No conversation_id provided, creating new conversation") + + conversation = await client.conversations.create(metadata={}) + llama_stack_conv_id = conversation.id + # Store the normalized version for later use + conversation_id = normalize_conversation_id(llama_stack_conv_id) + logger.info( + "Created new conversation with ID: %s (normalized: %s)", + llama_stack_conv_id, + conversation_id, + ) + + # Run shield moderation before calling LLM + moderation_result = await run_shield_moderation( + client, input_text, query_request.shield_ids + ) + if moderation_result.blocked: + violation_message = moderation_result.message or "" + await append_turn_to_conversation( + client, llama_stack_conv_id, input_text, violation_message + ) + summary = TurnSummary( + llm_response=violation_message, + tool_calls=[], + tool_results=[], + rag_chunks=[], + ) + return ( + summary, + normalize_conversation_id(conversation_id), + [], + TokenCounter(), + ) + + # Create OpenAI response using responses API + create_kwargs: dict[str, Any] = { + "input": input_text, + "model": model_id, + "instructions": system_prompt, + "tools": cast(Any, toolgroups), + "stream": False, + "store": True, + "conversation": llama_stack_conv_id, + } + + response = await client.responses.create(**create_kwargs) + response = cast(OpenAIResponseObject, response) + logger.debug( + "Received response with ID: %s, conversation ID: %s, output items: %d", + response.id, + conversation_id, + len(response.output), + ) + + # Process OpenAI response format + llm_response = "" + tool_calls: list[ToolCallSummary] = [] + tool_results: list[ToolResultSummary] = [] + rag_chunks: list[RAGChunk] = [] + for output_item in response.output: + message_text = extract_text_from_response_output_item(output_item) + if message_text: + llm_response += message_text + + tool_call, tool_result = _build_tool_call_summary(output_item, rag_chunks) + if tool_call: + tool_calls.append(tool_call) + if tool_result: + tool_results.append(tool_result) + + logger.info( + "Response processing complete - Tool calls: %d, Response length: %d chars", + len(tool_calls), + len(llm_response), + ) + + summary = TurnSummary( + 
llm_response=llm_response, + tool_calls=tool_calls, + tool_results=tool_results, + rag_chunks=rag_chunks, + ) + + # Extract referenced documents and token usage from Responses API response + referenced_documents = parse_referenced_documents_from_responses_api(response) + model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id + token_usage = extract_token_usage_from_responses_api( + response, model_label, provider_id, system_prompt + ) + + if not summary.llm_response: + logger.warning( + "Response lacks content (conversation_id=%s)", + conversation_id, + ) + + return ( + summary, + normalize_conversation_id(conversation_id), + referenced_documents, + token_usage, + ) + + +def extract_rag_chunks_from_file_search_item( + item: OpenAIResponseOutputMessageFileSearchToolCall, + rag_chunks: list[RAGChunk], +) -> None: + """Extract RAG chunks from a file search tool call item and append to rag_chunks. + + Args: + item: The file search tool call item. + rag_chunks: List to append extracted RAG chunks to. + """ + if item.results is not None: + for result in item.results: + rag_chunk = RAGChunk( + content=result.text, source=result.filename, score=result.score + ) + rag_chunks.append(rag_chunk) + + +def parse_rag_chunks_from_responses_api( + response_obj: OpenAIResponseObject, +) -> list[RAGChunk]: + """ + Extract rag_chunks from the llama-stack OpenAI response. + + Args: + response_obj: The ResponseObject from OpenAI compatible response API in llama-stack. + + Returns: + List of RAGChunk with content, source, score + """ + rag_chunks: list[RAGChunk] = [] + + for output_item in response_obj.output: + item_type = getattr(output_item, "type", None) + if item_type == "file_search_call": + item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item) + extract_rag_chunks_from_file_search_item(item, rag_chunks) + + return rag_chunks + + +def parse_referenced_documents_from_responses_api( + response: OpenAIResponseObject, # pylint: disable=unused-argument +) -> list[ReferencedDocument]: + """ + Parse referenced documents from OpenAI Responses API response. + + Args: + response: The OpenAI Response API response object + + Returns: + list[ReferencedDocument]: List of referenced documents with doc_url and doc_title + """ + documents: list[ReferencedDocument] = [] + # Use a set to track unique documents by (doc_url, doc_title) tuple + seen_docs: set[tuple[Optional[str], Optional[str]]] = set() + + # Handle None response (e.g., when agent fails) + if response is None or not response.output: + return documents + + for output_item in response.output: + item_type = getattr(output_item, "type", None) + + # 1. 
Parse from file_search_call results + if item_type == "file_search_call": + results = getattr(output_item, "results", []) or [] + for result in results: + # Handle both object and dict access + if isinstance(result, dict): + attributes = result.get("attributes", {}) + else: + attributes = getattr(result, "attributes", {}) + + # Try to get URL from attributes + # Look for common URL fields in attributes + doc_url = ( + attributes.get("doc_url") + or attributes.get("docs_url") + or attributes.get("url") + or attributes.get("link") + ) + doc_title = attributes.get("title") + + if doc_title or doc_url: + # Treat empty string as None for URL to satisfy Optional[AnyUrl] + final_url = doc_url if doc_url else None + if (final_url, doc_title) not in seen_docs: + documents.append( + ReferencedDocument(doc_url=final_url, doc_title=doc_title) + ) + seen_docs.add((final_url, doc_title)) + + return documents + + +def extract_token_usage_from_responses_api( + response: OpenAIResponseObject, + model: str, + provider: str, + system_prompt: str = "", # pylint: disable=unused-argument +) -> TokenCounter: + """ + Extract token usage from OpenAI Responses API response and update metrics. + + This function extracts token usage information from the Responses API response + object and updates Prometheus metrics. If usage information is not available, + it returns zero values without estimation. + + Note: When llama stack internally uses chat_completions, the usage field may be + empty or a dict. This is expected and will be populated in future llama stack versions. + + Args: + response: The OpenAI Response API response object + model: The model identifier for metrics labeling + provider: The provider identifier for metrics labeling + system_prompt: The system prompt used (unused, kept for compatibility) + + Returns: + TokenCounter: Token usage information with input_tokens and output_tokens + """ + token_counter = TokenCounter() + token_counter.llm_calls = 1 + + # Extract usage from the response if available + # Note: usage attribute exists at runtime but may not be in type definitions + usage = getattr(response, "usage", None) + if usage: + try: + # Handle both dict and object cases due to llama_stack inconsistency: + # - When llama_stack converts to chat_completions internally, usage is a dict + # - When using proper Responses API, usage should be an object + # TODO: Remove dict handling once llama_stack standardizes on object type # pylint: disable=fixme + if isinstance(usage, dict): + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + else: + # Object with attributes (expected final behavior) + input_tokens = getattr(usage, "input_tokens", 0) + output_tokens = getattr(usage, "output_tokens", 0) + # Only set if we got valid values + if input_tokens or output_tokens: + token_counter.input_tokens = input_tokens or 0 + token_counter.output_tokens = output_tokens or 0 + + logger.debug( + "Extracted token usage from Responses API: input=%d, output=%d", + token_counter.input_tokens, + token_counter.output_tokens, + ) + + # Update Prometheus metrics only when we have actual usage data + try: + metrics.llm_token_sent_total.labels(provider, model).inc( + token_counter.input_tokens + ) + metrics.llm_token_received_total.labels(provider, model).inc( + token_counter.output_tokens + ) + except (AttributeError, TypeError, ValueError) as e: + logger.warning("Failed to update token metrics: %s", e) + _increment_llm_call_metric(provider, model) + else: + logger.debug( + "Usage object 
exists but tokens are 0 or None, treating as no usage info" + ) + # Still increment the call counter + _increment_llm_call_metric(provider, model) + except (AttributeError, KeyError, TypeError) as e: + logger.warning( + "Failed to extract token usage from response.usage: %s. Usage value: %s", + e, + usage, + ) + # Still increment the call counter + _increment_llm_call_metric(provider, model) + else: + # No usage information available - this is expected when llama stack + # internally converts to chat_completions + logger.debug( + "No usage information in Responses API response, token counts will be 0" + ) + # token_counter already initialized with 0 values + # Still increment the call counter + _increment_llm_call_metric(provider, model) + + return token_counter + + +def _increment_llm_call_metric(provider: str, model: str) -> None: + """Safely increment LLM call metric.""" + try: + metrics.llm_calls_total.labels(provider, model).inc() + except (AttributeError, TypeError, ValueError) as e: + logger.warning("Failed to update LLM call metric: %s", e) + + +def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]: + """ + Convert vector store IDs to tools format for Responses API. + + Args: + vector_store_ids: List of vector store identifiers + + Returns: + Optional[list[dict[str, Any]]]: List containing file_search tool configuration, + or None if no vector stores provided + """ + if not vector_store_ids: + return None + + return [ + { + "type": "file_search", + "vector_store_ids": vector_store_ids, + "max_num_results": 10, + } + ] + + +def get_mcp_tools( + mcp_servers: list[ModelContextProtocolServer], + token: str | None = None, + mcp_headers: dict[str, dict[str, str]] | None = None, +) -> list[dict[str, Any]]: + """ + Convert MCP servers to tools format for Responses API. + + Args: + mcp_servers: List of MCP server configurations + token: Optional authentication token for MCP server authorization + mcp_headers: Optional per-request headers for MCP servers, keyed by server URL + + Returns: + list[dict[str, Any]]: List of MCP tool definitions with server + details and optional auth headers + + The way it works is we go through all the defined mcp servers and + create a tool definitions for each of them. If MCP server definition + has a non-empty resolved_authorization_headers we create invocation + headers, following the algorithm: + 1. If the header value is 'kubernetes' the header value is a k8s token + 2. If the header value is 'client': + find the value for a given MCP server/header in mcp_headers. + if the value is not found omit this header, otherwise use found value + 3. otherwise use the value from resolved_authorization_headers directly + + This algorithm allows to: + 1. Use static global header values, provided by configuration + 2. Use user specific k8s token, which will work for the majority of kubernetes + based MCP servers + 3. 
Use user specific tokens (passed by the client) for user specific MCP headers + """ + + def _get_token_value(original: str, header: str) -> str | None: + """Convert to header value.""" + match original: + case constants.MCP_AUTH_KUBERNETES: + # use k8s token + if token is None or token == "": + return None + return f"Bearer {token}" + case constants.MCP_AUTH_CLIENT: + # use client provided token + if mcp_headers is None: + return None + c_headers = mcp_headers.get(mcp_server.name, None) + if c_headers is None: + return None + return c_headers.get(header, None) + case _: + # use provided + return original + + tools = [] + for mcp_server in mcp_servers: + # Base tool definition + tool_def = { + "type": "mcp", + "server_label": mcp_server.name, + "server_url": mcp_server.url, + "require_approval": "never", + } + + # Build headers + headers = {} + for name, value in mcp_server.resolved_authorization_headers.items(): + # for each defined header + h_value = _get_token_value(value, name) + # only add the header if we got value + if h_value is not None: + headers[name] = h_value + + # Skip server if auth headers were configured but not all could be resolved + if mcp_server.authorization_headers and len(headers) != len( + mcp_server.authorization_headers + ): + logger.warning( + "Skipping MCP server %s: required %d auth headers but only resolved %d", + mcp_server.name, + len(mcp_server.authorization_headers), + len(headers), + ) + continue + + if len(headers) > 0: + # add headers to tool definition + tool_def["headers"] = headers # type: ignore[index] + # collect tools info + tools.append(tool_def) + return tools + + +async def prepare_tools_for_responses_api( + client: AsyncLlamaStackClient, + query_request: QueryRequest, + token: str, + config: AppConfig, + mcp_headers: Optional[dict[str, dict[str, str]]] = None, +) -> Optional[list[dict[str, Any]]]: + """ + Prepare tools for Responses API including RAG and MCP tools. + + This function retrieves vector stores and combines them with MCP + server tools to create a unified toolgroups list for the Responses API. 
+ + Args: + client: The Llama Stack client instance + query_request: The user's query request + token: Authentication token for MCP tools + config: Configuration object containing MCP server settings + mcp_headers: Per-request headers for MCP servers + + Returns: + Optional[list[dict[str, Any]]]: List of tool configurations for the + Responses API, or None if no_tools is True or no tools are available + """ + if query_request.no_tools: + return None + + toolgroups = [] + # Get vector stores for RAG tools - use specified ones or fetch all + if query_request.vector_store_ids: + vector_store_ids = query_request.vector_store_ids + else: + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data + ] + + # Add RAG tools if vector stores are available + rag_tools = get_rag_tools(vector_store_ids) + if rag_tools: + toolgroups.extend(rag_tools) + + # Add MCP server tools + mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers) + if mcp_tools: + toolgroups.extend(mcp_tools) + logger.debug( + "Configured %d MCP tools: %s", + len(mcp_tools), + [tool.get("server_label", "unknown") for tool in mcp_tools], + ) + # Convert empty list to None for consistency with existing behavior + if not toolgroups: + return None + + return toolgroups diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py new file mode 100644 index 00000000..787b4565 --- /dev/null +++ b/src/app/endpoints/streaming_query_v2.py @@ -0,0 +1,480 @@ +"""Streaming query handler using Responses API (v2).""" + +import logging +from typing import Annotated, Any, AsyncIterator, Optional, cast + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import StreamingResponse +from llama_stack_api.openai_responses import ( + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseFailed, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseObjectStreamResponseOutputTextDone, +) +from llama_stack_client import AsyncLlamaStackClient + +from app.endpoints.query import ( + is_transcripts_enabled, + persist_user_conversation_details, + validate_attachments_metadata, +) +from app.endpoints.query_v2 import ( + _build_tool_call_summary, + extract_token_usage_from_responses_api, + get_topic_summary, + parse_referenced_documents_from_responses_api, + prepare_tools_for_responses_api, +) +from app.endpoints.streaming_query import ( + LLM_TOKEN_EVENT, + LLM_TOOL_CALL_EVENT, + LLM_TOOL_RESULT_EVENT, + format_stream_data, + stream_end_event, + stream_event, + stream_start_event, + streaming_query_endpoint_handler_base, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from configuration import configuration +from constants import MEDIA_TYPE_JSON +from models.config import Action +from models.context import ResponseGeneratorContext +from models.requests import QueryRequest +from models.responses import ( + ForbiddenResponse, + InternalServerErrorResponse, + NotFoundResponse, + QuotaExceededResponse, + ServiceUnavailableResponse, + StreamingQueryResponse, + UnauthorizedResponse, + UnprocessableEntityResponse, +) +from utils.endpoints import ( + cleanup_after_streaming, + get_system_prompt, +) +from utils.query import create_violation_stream +from utils.quota import consume_tokens, get_available_quotas +from utils.suid 
import normalize_conversation_id, to_llama_stack_conversation_id +from utils.mcp_headers import mcp_headers_dependency +from utils.shields import ( + append_turn_to_conversation, + run_shield_moderation, +) +from utils.token_counter import TokenCounter +from utils.transcripts import store_transcript +from utils.types import RAGChunk, TurnSummary + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["streaming_query_v1"]) +auth_dependency = get_auth_dependency() + +streaming_query_v2_responses: dict[int | str, dict[str, Any]] = { + 200: StreamingQueryResponse.openapi_response(), + 401: UnauthorizedResponse.openapi_response( + examples=["missing header", "missing token"] + ), + 403: ForbiddenResponse.openapi_response( + examples=["conversation read", "endpoint", "model override"] + ), + 404: NotFoundResponse.openapi_response( + examples=["conversation", "model", "provider"] + ), + # 413: PromptTooLongResponse.openapi_response(), + 422: UnprocessableEntityResponse.openapi_response(), + 429: QuotaExceededResponse.openapi_response(), + 500: InternalServerErrorResponse.openapi_response(examples=["configuration"]), + 503: ServiceUnavailableResponse.openapi_response(), +} + + +def create_responses_response_generator( # pylint: disable=too-many-locals,too-many-statements + context: ResponseGeneratorContext, +) -> Any: + """ + Create a response generator function for Responses API streaming. + + This factory function returns an async generator that processes streaming + responses from the Responses API and yields Server-Sent Events (SSE). + + Args: + context: Context object containing all necessary parameters for response generation + + Returns: + An async generator function that yields SSE-formatted strings + """ + + async def response_generator( # pylint: disable=too-many-branches,too-many-statements + turn_response: AsyncIterator[OpenAIResponseObjectStream], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. 
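+
+        Example:
+            An illustrative token event as it may appear on the wire; the exact
+            envelope comes from stream_event/format_stream_data, and the literal
+            event name for LLM_TOKEN_EVENT is an assumption here:
+
+                data: {"event": "token", "data": {"id": 0, "token": "Hello"}}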
+ """ + chunk_id = 0 + summary = TurnSummary( + llm_response="", tool_calls=[], tool_results=[], rag_chunks=[] + ) + + # Determine media type for response formatting + media_type = context.query_request.media_type or MEDIA_TYPE_JSON + + # Accumulators for Responses API + text_parts: list[str] = [] + emitted_turn_complete = False + + # Use the conversation_id from context (either provided or newly created) + conv_id = context.conversation_id + + # Track the latest response object from response.completed event + latest_response_object: Optional[Any] = None + + # RAG chunks + rag_chunks: list[RAGChunk] = [] + + logger.debug("Starting streaming response (Responses API) processing") + + async for chunk in turn_response: + event_type = getattr(chunk, "type", None) + logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) + + # Emit start event when response is created + if event_type == "response.created": + yield stream_start_event(conv_id) + + # Text streaming + if event_type == "response.output_text.delta": + delta_chunk = cast( + OpenAIResponseObjectStreamResponseOutputTextDelta, chunk + ) + if delta_chunk.delta: + text_parts.append(delta_chunk.delta) + yield stream_event( + { + "id": chunk_id, + "token": delta_chunk.delta, + }, + LLM_TOKEN_EVENT, + media_type, + ) + chunk_id += 1 + + # Final text of the output (capture, but emit at response.completed) + elif event_type == "response.output_text.done": + text_done_chunk = cast( + OpenAIResponseObjectStreamResponseOutputTextDone, chunk + ) + if text_done_chunk.text: + summary.llm_response = text_done_chunk.text + + # Content part started - emit an empty token to kick off UI streaming + elif event_type == "response.content_part.added": + yield stream_event( + { + "id": chunk_id, + "token": "", + }, + LLM_TOKEN_EVENT, + media_type, + ) + chunk_id += 1 + + # Process tool calls and results are emitted together when output items are done + # TODO(asimurka): support emitting tool calls and results separately when ready + elif event_type == "response.output_item.done": + output_item_done_chunk = cast( + OpenAIResponseObjectStreamResponseOutputItemDone, chunk + ) + if output_item_done_chunk.item.type == "message": + continue + tool_call, tool_result = _build_tool_call_summary( + output_item_done_chunk.item, rag_chunks + ) + if tool_call: + summary.tool_calls.append(tool_call) + yield stream_event( + tool_call.model_dump(), + LLM_TOOL_CALL_EVENT, + media_type, + ) + if tool_result: + summary.tool_results.append(tool_result) + yield stream_event( + tool_result.model_dump(), + LLM_TOOL_RESULT_EVENT, + media_type, + ) + + # Completed response - capture final text and response object + elif event_type == "response.completed": + # Capture the response object for token usage extraction + completed_chunk = cast( + OpenAIResponseObjectStreamResponseCompleted, chunk + ) + latest_response_object = completed_chunk.response + + if not emitted_turn_complete: + final_message = summary.llm_response or "".join(text_parts) + if not final_message: + final_message = "No response from the model" + summary.llm_response = final_message + yield stream_event( + { + "id": chunk_id, + "token": final_message, + }, + "turn_complete", + media_type, + ) + chunk_id += 1 + emitted_turn_complete = True + + # Incomplete response - emit error because LLS does not + # support incomplete responses "incomplete_detail" attribute yet + elif event_type == "response.incomplete": + error_response = InternalServerErrorResponse.query_failed( + "An unexpected error occurred while 
processing the request." + ) + logger.error("Error while obtaining answer for user question") + yield format_stream_data( + {"event": "error", "data": {**error_response.detail.model_dump()}} + ) + return + + # Failed response - emit error with custom cause from error message + elif event_type == "response.failed": + failed_chunk = cast(OpenAIResponseObjectStreamResponseFailed, chunk) + latest_response_object = failed_chunk.response + error_message = ( + failed_chunk.response.error.message + if failed_chunk.response.error + else "An unexpected error occurred while processing the request." + ) + error_response = InternalServerErrorResponse.query_failed(error_message) + logger.error("Error while obtaining answer for user question") + yield format_stream_data( + {"event": "error", "data": {**error_response.detail.model_dump()}} + ) + return + + logger.debug( + "Streaming complete - Tool calls: %d, Response chars: %d", + len(summary.tool_calls), + len(summary.llm_response), + ) + + # Extract token usage from the response object + token_usage = ( + extract_token_usage_from_responses_api( + latest_response_object, context.model_id, context.provider_id + ) + if latest_response_object is not None + else TokenCounter() + ) + consume_tokens( + configuration.quota_limiters, + configuration.token_usage_history, + context.user_id, + input_tokens=token_usage.input_tokens, + output_tokens=token_usage.output_tokens, + model_id=context.model_id, + provider_id=context.provider_id, + ) + referenced_documents = parse_referenced_documents_from_responses_api( + cast(OpenAIResponseObject, latest_response_object) + ) + available_quotas = get_available_quotas( + configuration.quota_limiters, context.user_id + ) + yield stream_end_event( + context.metadata_map, + token_usage, + available_quotas, + referenced_documents, + media_type, + ) + + # Perform cleanup tasks (database and cache operations)) + await cleanup_after_streaming( + user_id=context.user_id, + conversation_id=conv_id, + model_id=context.model_id, + provider_id=context.provider_id, + llama_stack_model_id=context.llama_stack_model_id, + query_request=context.query_request, + summary=summary, + metadata_map=context.metadata_map, + started_at=context.started_at, + client=context.client, + config=configuration, + skip_userid_check=context.skip_userid_check, + get_topic_summary_func=get_topic_summary, + is_transcripts_enabled_func=is_transcripts_enabled, + store_transcript_func=store_transcript, + persist_user_conversation_details_func=persist_user_conversation_details, + rag_chunks=[rag_chunk.model_dump() for rag_chunk in rag_chunks], + ) + + return response_generator + + +@router.post( + "/streaming_query", + response_class=StreamingResponse, + responses=streaming_query_v2_responses, + summary="Streaming Query Endpoint Handler V1", +) +@authorize(Action.STREAMING_QUERY) +async def streaming_query_endpoint_handler_v2( # pylint: disable=too-many-locals + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> StreamingResponse: + """ + Handle request to the /streaming_query endpoint using Responses API. + + Returns a streaming response using Server-Sent Events (SSE) format with + content type text/event-stream. + + Returns: + StreamingResponse: An HTTP streaming response yielding + SSE-formatted events for the query lifecycle with content type + text/event-stream. 
+
+    Raises:
+        HTTPException:
+            - 401: Unauthorized - Missing or invalid credentials
+            - 403: Forbidden - Insufficient permissions or model override not allowed
+            - 404: Not Found - Conversation, model, or provider not found
+            - 422: Unprocessable Entity - Request validation failed
+            - 429: Too Many Requests - Quota limit exceeded
+            - 500: Internal Server Error - Configuration not loaded or other server errors
+            - 503: Service Unavailable - Unable to connect to Llama Stack backend
+    """
+    return await streaming_query_endpoint_handler_base(
+        request=request,
+        query_request=query_request,
+        auth=auth,
+        mcp_headers=mcp_headers,
+        retrieve_response_func=retrieve_response,
+        create_response_generator_func=create_responses_response_generator,
+    )
+
+
+async def retrieve_response(  # pylint: disable=too-many-locals
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    query_request: QueryRequest,
+    token: str,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
+) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
+    """
+    Retrieve response from LLMs and agents.
+
+    Asynchronously retrieves a streaming response and conversation
+    ID from Llama Stack for a given user query.
+
+    This function configures shields, system prompt, and tool usage
+    based on the request and environment. It prepares the RAG and
+    MCP tool definitions, validates attachments if present, and
+    initiates a streaming Responses API call with the user's query
+    and any attachment content.
+
+    Parameters:
+        model_id (str): Identifier of the model to use for the query.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): Authentication token for downstream services.
+        mcp_headers (dict[str, dict[str, str]], optional):
+            Per-request headers for MCP (Model Context Protocol) servers.
+
+    Returns:
+        tuple: A tuple containing the streaming response object
+            and the normalized conversation ID.
+ """ + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + toolgroups = await prepare_tools_for_responses_api( + client, query_request, token, configuration, mcp_headers + ) + + # Prepare input for Responses API + # Convert attachments to text and concatenate with query + input_text = query_request.query + if query_request.attachments: + for attachment in query_request.attachments: + input_text += ( + f"\n\n[Attachment: {attachment.attachment_type}]\n" + f"{attachment.content}" + ) + + # Handle conversation ID for Responses API + # Create conversation upfront if not provided + conversation_id = query_request.conversation_id + if conversation_id: + # Conversation ID was provided - convert to llama-stack format + logger.debug("Using existing conversation ID: %s", conversation_id) + llama_stack_conv_id = to_llama_stack_conversation_id(conversation_id) + else: + # No conversation_id provided - create a new conversation first + logger.debug("No conversation_id provided, creating new conversation") + conversation = await client.conversations.create(metadata={}) + llama_stack_conv_id = conversation.id + # Store the normalized version for later use + conversation_id = normalize_conversation_id(llama_stack_conv_id) + logger.info( + "Created new conversation with ID: %s (normalized: %s)", + llama_stack_conv_id, + conversation_id, + ) + + # Run shield moderation before calling LLM + moderation_result = await run_shield_moderation( + client, input_text, query_request.shield_ids + ) + if moderation_result.blocked: + violation_message = moderation_result.message or "" + await append_turn_to_conversation( + client, llama_stack_conv_id, input_text, violation_message + ) + return ( + create_violation_stream(violation_message, moderation_result.shield_model), + normalize_conversation_id(conversation_id), + ) + + create_params: dict[str, Any] = { + "input": input_text, + "model": model_id, + "instructions": system_prompt, + "stream": True, + "store": True, + "tools": toolgroups, + "conversation": llama_stack_conv_id, + } + + response = await client.responses.create(**create_params) + response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) + + return response_stream, normalize_conversation_id(conversation_id) diff --git a/src/models/requests.py b/src/models/requests.py index 38bdc9ac..1924ed97 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -83,6 +83,7 @@ class QueryRequest(BaseModel): generate_topic_summary: Whether to generate topic summary for new conversations. media_type: The optional media type for response format (application/json or text/plain). vector_store_ids: The optional list of specific vector store IDs to query for RAG. + shield_ids: The optional list of safety shield IDs to apply. Example: ```python @@ -166,6 +167,14 @@ class QueryRequest(BaseModel): examples=["ocp_docs", "knowledge_base", "vector_db_1"], ) + shield_ids: Optional[list[str]] = Field( + None, + description="Optional list of safety shield IDs to apply. " + "If None, all configured shields are used. 
" + "If empty list, all shields are skipped.", + examples=["llama-guard", "custom-shield"], + ) + # provides examples for /docs endpoint model_config = { "extra": "forbid", diff --git a/src/utils/shields.py b/src/utils/shields.py index ecfa80f5..38aa4b19 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -1,7 +1,7 @@ """Utility functions for working with Llama Stack shields.""" import logging -from typing import Any, cast +from typing import Any, Optional, cast from fastapi import HTTPException from llama_stack_client import AsyncLlamaStackClient @@ -9,7 +9,8 @@ import metrics from models.responses import ( - NotFoundResponse, + NotFoundResponse, + UnprocessableEntityResponse, ) from utils.types import ShieldModerationResult @@ -65,24 +66,56 @@ def detect_shield_violations(output_items: list[Any]) -> bool: async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, + shield_ids: Optional[list[str]] = None, ) -> ShieldModerationResult: """ Run shield moderation on input text. - Iterates through all configured shields and runs moderation checks. + Iterates through configured shields and runs moderation checks. Raises HTTPException if shield model is not found. Parameters: client: The Llama Stack client. input_text: The text to moderate. + shield_ids: Optional list of shield IDs to use. If None, uses all shields. + If empty list, skips all shields. Returns: ShieldModerationResult: Result indicating if content was blocked and the message. + + Raises: + HTTPException: If shield's provider_resource_id is not configured or model not found. """ + all_shields = await client.shields.list() + + # Filter shields based on shield_ids parameter + if shield_ids is not None: + if len(shield_ids) == 0: + logger.info("shield_ids=[] provided, skipping all shields") + return ShieldModerationResult(blocked=False) + + shields_to_run = [s for s in all_shields if s.identifier in shield_ids] + + # Log warning if requested shield not found + requested = set(shield_ids) + available = {s.identifier for s in shields_to_run} + missing = requested - available + if missing: + logger.warning("Requested shields not found: %s", missing) + + # Reject if no requested shields were found (prevents accidental bypass) + if not shields_to_run: + response = UnprocessableEntityResponse( + response="Invalid shield configuration", + cause=f"Requested shield_ids not found: {sorted(missing)}", + ) + raise HTTPException(**response.model_dump()) + else: + shields_to_run = list(all_shields) + available_models = {model.id for model in await client.models.list()} - shields = await client.shields.list() - for shield in shields: + for shield in shields_to_run: if ( not shield.provider_resource_id or shield.provider_resource_id not in available_models diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index 5c352758..7702d70a 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -305,6 +305,75 @@ async def test_returns_blocked_on_bad_request_error( assert result.shield_model == "moderation-model" mock_metric.inc.assert_called_once() + @pytest.mark.asyncio + async def test_shield_ids_empty_list_skips_all_shields( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids=[] explicitly skips all shields (intentional bypass).""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + result = await run_shield_moderation(mock_client, "test 
input", shield_ids=[]) + + assert result.blocked is False + mock_client.shields.list.assert_called_once() + + @pytest.mark.asyncio + async def test_shield_ids_raises_exception_when_no_shields_found( + self, mocker: MockerFixture + ) -> None: + """Test shield_ids raises HTTPException when no requested shields exist.""" + mock_client = mocker.Mock() + shield = mocker.Mock() + shield.identifier = "shield-1" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) + + with pytest.raises(HTTPException) as exc_info: + await run_shield_moderation( + mock_client, "test input", shield_ids=["typo-shield"] + ) + + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert "Invalid shield configuration" in exc_info.value.detail["response"] # type: ignore + assert "typo-shield" in exc_info.value.detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_shield_ids_filters_to_specific_shield( + self, mocker: MockerFixture + ) -> None: + """Test that shield_ids filters to only specified shields.""" + mock_client = mocker.Mock() + + shield1 = mocker.Mock() + shield1.identifier = "shield-1" + shield1.provider_resource_id = "model-1" + shield2 = mocker.Mock() + shield2.identifier = "shield-2" + shield2.provider_resource_id = "model-2" + mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2]) + + model1 = mocker.Mock() + model1.id = "model-1" + mock_client.models.list = mocker.AsyncMock(return_value=[model1]) + + moderation_result = mocker.Mock() + moderation_result.results = [mocker.Mock(flagged=False)] + mock_client.moderations.create = mocker.AsyncMock( + return_value=moderation_result + ) + + result = await run_shield_moderation( + mock_client, "test input", shield_ids=["shield-1"] + ) + + assert result.blocked is False + assert mock_client.moderations.create.call_count == 1 + mock_client.moderations.create.assert_called_with( + input="test input", model="model-1" + ) + class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods """Tests for append_turn_to_conversation function."""