diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index fd123bd21..a0a653c2c 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -36,6 +36,7 @@
     UnauthorizedResponse,
     UnprocessableEntityResponse,
 )
+from utils.conversations import append_turn_items_to_conversation
 from utils.endpoints import (
     check_configuration_loaded,
     validate_and_retrieve_conversation,
@@ -59,14 +60,11 @@
     get_topic_summary,
     prepare_responses_params,
 )
-from utils.shields import (
-    append_turn_to_conversation,
-    run_shield_moderation,
-    validate_shield_ids_override,
-)
+from utils.shields import run_shield_moderation, validate_shield_ids_override
 from utils.suid import normalize_conversation_id
 from utils.types import (
     ResponsesApiParams,
+    ShieldModerationResult,
     TurnSummary,
 )
 from utils.vector_search import build_rag_context
@@ -158,14 +156,21 @@ async def query_endpoint_handler(
 
     client = AsyncLlamaStackClientHolder().get_client()
 
-    # Build RAG context from Inline RAG sources
-    inline_rag_context = await build_rag_context(
-        client, query_request.query, query_request.vector_store_ids, query_request.solr
-    )
-
     # Moderation input is the raw user content (query + attachments) without injected RAG
     # context, to avoid false positives from retrieved document content.
     moderation_input = prepare_input(query_request)
+    moderation_result = await run_shield_moderation(
+        client, moderation_input, query_request.shield_ids
+    )
+
+    # Build RAG context from Inline RAG sources
+    inline_rag_context = await build_rag_context(
+        client,
+        moderation_result.decision,
+        query_request.query,
+        query_request.vector_store_ids,
+        query_request.solr,
+    )
 
     # Prepare API request parameters
     responses_params = await prepare_responses_params(
@@ -177,7 +182,7 @@
         stream=False,
         store=True,
         request_headers=request.headers,
-        inline_rag_context=inline_rag_context.context_text or None,
+        inline_rag_context=inline_rag_context.context_text,
     )
 
     # Handle Azure token refresh if needed
@@ -189,32 +194,22 @@
     ):
         client = await update_azure_token(client)
 
-    # Build index identification mapping for RAG source resolution
-    vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
-    rag_id_mapping = configuration.rag_id_mapping
-
     # Retrieve response using Responses API
-    turn_summary = await retrieve_response(
-        client,
-        responses_params,
-        query_request.shield_ids,
-        vector_store_ids,
-        rag_id_mapping,
-        moderation_input=moderation_input,
-    )
-
-    # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
-    rag_chunks = inline_rag_context.rag_chunks
-    tool_rag_chunks = turn_summary.rag_chunks or []
-    logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
-    turn_summary.rag_chunks = rag_chunks + tool_rag_chunks
-
-    # Add tool-based RAG documents and chunks
-    rag_documents = inline_rag_context.referenced_documents
-    tool_rag_documents = turn_summary.referenced_documents or []
-    turn_summary.referenced_documents = deduplicate_referenced_documents(
-        rag_documents + tool_rag_documents
-    )
+    turn_summary = await retrieve_response(client, responses_params, moderation_result)
+
+    if moderation_result.decision == "passed":
+        # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
+        rag_chunks = inline_rag_context.rag_chunks
+        tool_rag_chunks = turn_summary.rag_chunks
+        logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
+        turn_summary.rag_chunks = rag_chunks + tool_rag_chunks
+
+        # Merge inline and tool-based referenced documents and deduplicate
+        rag_documents = inline_rag_context.referenced_documents
+        tool_rag_documents = turn_summary.referenced_documents
+        turn_summary.referenced_documents = deduplicate_referenced_documents(
+            rag_documents + tool_rag_documents
+        )
 
     # Get topic summary for new conversation
     if not user_conversation and query_request.generate_topic_summary:
@@ -272,10 +267,7 @@
 async def retrieve_response(  # pylint: disable=too-many-locals
     client: AsyncLlamaStackClient,
     responses_params: ResponsesApiParams,
-    shield_ids: Optional[list[str]] = None,
-    vector_store_ids: Optional[list[str]] = None,
-    rag_id_mapping: Optional[dict[str, str]] = None,
-    moderation_input: Optional[str] = None,
+    moderation_result: ShieldModerationResult,
 ) -> TurnSummary:
     """
     Retrieve response from LLMs and agents.
 
@@ -286,33 +278,21 @@
     Parameters:
         client: The AsyncLlamaStackClient to use for the request.
         responses_params: The Responses API parameters.
-        shield_ids: Optional list of shield IDs for moderation.
-        vector_store_ids: Vector store IDs used in the query for source resolution.
-        rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
-        moderation_input: Text to moderate. Should be the raw user content (query +
-            attachments) without injected RAG context to avoid false positives.
-            Falls back to responses_params.input if not provided.
+        moderation_result: Result of shield moderation. A blocked result is
+            persisted to the conversation and returned without calling the LLM.
 
     Returns:
         TurnSummary: Summary of the LLM response content
     """
     response: Optional[OpenAIResponseObject] = None
-    try:
-        moderation_result = await run_shield_moderation(
+    if moderation_result.decision == "blocked":
+        await append_turn_items_to_conversation(
             client,
-            moderation_input or cast(str, responses_params.input),
-            shield_ids,
+            responses_params.conversation,
+            responses_params.input,
+            [moderation_result.refusal_response],
         )
-        if moderation_result.decision == "blocked":
-            # Handle shield moderation blocking
-            violation_message = moderation_result.message
-            await append_turn_to_conversation(
-                client,
-                responses_params.conversation,
-                cast(str, responses_params.input),
-                violation_message,
-            )
-            return TurnSummary(llm_response=violation_message)
+        return TurnSummary(llm_response=moderation_result.message)
+    try:
         response = await client.responses.create(
             **responses_params.model_dump(exclude_none=True)
         )
@@ -333,6 +313,8 @@
         error_response = handle_known_apistatus_errors(e, responses_params.model)
         raise HTTPException(**error_response.model_dump()) from e
 
+    vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
+    rag_id_mapping = configuration.rag_id_mapping
     return build_turn_summary(
         response, responses_params.model, vector_store_ids, rag_id_mapping
     )
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index d04e9d5be..99f0082c2 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -57,6 +57,7 @@
     UnauthorizedResponse,
     UnprocessableEntityResponse,
 )
+from utils.conversations import append_turn_items_to_conversation
 from utils.endpoints import (
     check_configuration_loaded,
     validate_and_retrieve_conversation,
@@ -189,10 +190,22 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
 
     client = AsyncLlamaStackClientHolder().get_client()
 
+    # Moderation input is the raw user content (query + attachments) without injected RAG
+    # context, to avoid false positives from retrieved document content.
+    moderation_input = prepare_input(query_request)
+    moderation_result = await run_shield_moderation(
+        client, moderation_input, query_request.shield_ids
+    )
+
     # Build RAG context from Inline RAG sources
     inline_rag_context = await build_rag_context(
-        client, query_request.query, query_request.vector_store_ids, query_request.solr
+        client,
+        moderation_result.decision,
+        query_request.query,
+        query_request.vector_store_ids,
+        query_request.solr,
     )
+
     # Prepare API request parameters
     responses_params = await prepare_responses_params(
         client=client,
@@ -203,7 +216,7 @@
         stream=True,
         store=True,
         request_headers=request.headers,
-        inline_rag_context=inline_rag_context.context_text or None,
+        inline_rag_context=inline_rag_context.context_text,
     )
 
     # Handle Azure token refresh if needed
@@ -227,8 +240,10 @@
         query_request=query_request,
         started_at=started_at,
         client=client,
+        moderation_result=moderation_result,
         vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools),
         rag_id_mapping=configuration.rag_id_mapping,
+        inline_rag_context=inline_rag_context,
     )
 
     # Update metrics for the LLM call
@@ -240,9 +255,14 @@
     generator, turn_summary = await retrieve_response_generator(
         responses_params=responses_params,
         context=context,
-        inline_rag_documents=inline_rag_context.referenced_documents,
     )
 
+    # Combine inline RAG results (BYOK + Solr) with tool-based results
+    if context.moderation_result.decision == "passed":
+        turn_summary.referenced_documents = deduplicate_referenced_documents(
+            inline_rag_context.referenced_documents + turn_summary.referenced_documents
+        )
+
     response_media_type = (
         MEDIA_TYPE_TEXT
         if query_request.media_type == MEDIA_TYPE_TEXT
@@ -263,7 +283,6 @@
 async def retrieve_response_generator(
     responses_params: ResponsesApiParams,
     context: ResponseGeneratorContext,
-    inline_rag_documents: list[ReferencedDocument],
 ) -> tuple[AsyncIterator[str], TurnSummary]:
     """
     Retrieve the appropriate response generator.
@@ -275,30 +294,26 @@ async def retrieve_response_generator(
     Args:
         responses_params: The Responses API parameters
         context: The response generator context
-        inline_rag_documents: Referenced documents from inline RAG (BYOK + Solr)
-
     Returns:
         tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
     """
     turn_summary = TurnSummary()
     try:
-        moderation_result = await run_shield_moderation(
-            context.client,
-            prepare_input(context.query_request),
-            context.query_request.shield_ids,
-        )
-        if moderation_result.decision == "blocked":
-            turn_summary.llm_response = moderation_result.message
-            await append_turn_to_conversation(
+        if context.moderation_result.decision == "blocked":
+            turn_summary.llm_response = context.moderation_result.message
+            await append_turn_items_to_conversation(
                 context.client,
                 responses_params.conversation,
-                cast(str, responses_params.input),
-                moderation_result.message,
+                responses_params.input,
+                [context.moderation_result.refusal_response],
             )
             media_type = context.query_request.media_type or MEDIA_TYPE_JSON
             return (
-                shield_violation_generator(moderation_result.message, media_type),
+                shield_violation_generator(
+                    context.moderation_result.message,
+                    media_type,
+                ),
                 turn_summary,
             )
         # Retrieve response stream (may raise exceptions)
@@ -306,9 +321,14 @@
             **responses_params.model_dump(exclude_none=True)
         )
-        # Store pre-RAG documents for later merging with tool-based RAG
-        turn_summary.inline_rag_documents = inline_rag_documents
-        return response_generator(response, context, turn_summary), turn_summary
-
+        return (
+            response_generator(
+                response,
+                context,
+                turn_summary,
+            ),
+            turn_summary,
+        )
-    # Handle know LLS client errors only at stream creation time and shield execution
+    # Handle known LLS client errors only at stream creation time
     except RuntimeError as e:  # library mode wraps 413 into runtime error
         if "context_length" in str(e).lower():
@@ -570,7 +590,6 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
         turn_response: The streaming response from Llama Stack
         context: The response generator context
         turn_summary: TurnSummary to populate during streaming
-
     Yields:
         SSE-formatted strings for tokens, tool calls, tool results, turn completion,
         and error events.
@@ -741,15 +760,19 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
             turn_summary.token_usage = extract_token_usage(
                 latest_response_object.usage, context.model_id
             )
-            tool_based_documents = parse_referenced_documents(
+            # Parse tool-based referenced documents from the final response object
+            tool_rag_docs = parse_referenced_documents(
                 latest_response_object,
                 vector_store_ids=context.vector_store_ids,
                 rag_id_mapping=context.rag_id_mapping,
             )
-
-            # Merge pre-RAG documents with tool-based documents and deduplicate
+            # Combine inline RAG results (BYOK + Solr) with tool-based results
             turn_summary.referenced_documents = deduplicate_referenced_documents(
-                turn_summary.inline_rag_documents + tool_based_documents
+                context.inline_rag_context.referenced_documents + tool_rag_docs
+            )
+            # Combine inline RAG chunks (BYOK + Solr) with tool-based chunks
+            turn_summary.rag_chunks = (
+                context.inline_rag_context.rag_chunks + turn_summary.rag_chunks
             )
diff --git a/src/constants.py b/src/constants.py
index 39dbb9fe0..20145a812 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -214,3 +214,5 @@
 # Environment variable to force StreamHandler instead of RichHandler
 # Set to any non-empty value to disable RichHandler
 LIGHTSPEED_STACK_DISABLE_RICH_HANDLER_ENV_VAR = "LIGHTSPEED_STACK_DISABLE_RICH_HANDLER"
+
+DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions."
diff --git a/src/models/context.py b/src/models/context.py
index 2ef76f36d..9876a1485 100644
--- a/src/models/context.py
+++ b/src/models/context.py
@@ -4,6 +4,7 @@
 from llama_stack_client import AsyncLlamaStackClient
 
 from models.requests import QueryRequest
+from utils.types import RAGContext, ShieldModerationResult
 
 
 @dataclass
@@ -23,6 +24,8 @@ class ResponseGeneratorContext:  # pylint: disable=too-many-instance-attributes
         query_request: The query request object
         started_at: Timestamp when the request started (ISO 8601 format)
         client: The Llama Stack client for API interactions
+        moderation_result: Outcome of shield moderation for this request
+        inline_rag_context: Inline RAG (BYOK + Solr) context built before the LLM call
         vector_store_ids: Vector store IDs used in the query for source resolution.
         rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
     """
@@ -42,7 +45,9 @@ class ResponseGeneratorContext:  # pylint: disable=too-many-instance-attributes
 
     # Dependencies & State
     client: AsyncLlamaStackClient
+    moderation_result: ShieldModerationResult
     # RAG index identification
+    inline_rag_context: RAGContext
     vector_store_ids: list[str] = field(default_factory=list)
     rag_id_mapping: dict[str, str] = field(default_factory=dict)
diff --git a/src/models/requests.py b/src/models/requests.py
index c091c6e14..d65c0b49e 100644
--- a/src/models/requests.py
+++ b/src/models/requests.py
@@ -179,8 +179,7 @@ class QueryRequest(BaseModel):
     shield_ids: Optional[list[str]] = Field(
         None,
         description="Optional list of safety shield IDs to apply. "
-        "If None, all configured shields are used. "
-        "If provided, must contain at least one valid shield ID (empty list raises 422 error).",
+        "If None, all configured shields are used. "
+        "An empty list runs no shields for this request.",
", examples=["llama-guard", "custom-shield"], ) diff --git a/src/utils/conversations.py b/src/utils/conversations.py index 66205594f..aa698aabe 100644 --- a/src/utils/conversations.py +++ b/src/utils/conversations.py @@ -3,7 +3,10 @@ import json from datetime import UTC, datetime from typing import Any, Optional, cast +from collections.abc import Sequence +from fastapi import HTTPException +from llama_stack_api import OpenAIResponseMessage, OpenAIResponseOutput from llama_stack_api.openai_responses import ( OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, OpenAIResponseOutputMessageFunctionToolCall as FunctionCall, @@ -11,6 +14,8 @@ OpenAIResponseOutputMessageMCPListTools as MCPListTools, OpenAIResponseOutputMessageWebSearchToolCall as WebSearchCall, ) +from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from llama_stack_client.types.conversations.item_create_params import Item from llama_stack_client.types.conversations.item_list_response import ( ItemListResponse, OpenAIResponseInputFunctionToolCallOutput as FunctionToolCallOutput, @@ -21,9 +26,14 @@ from constants import DEFAULT_RAG_TOOL from models.database.conversations import UserTurn -from models.responses import ConversationTurn, Message +from models.responses import ( + ConversationTurn, + InternalServerErrorResponse, + Message, + ServiceUnavailableResponse, +) from utils.responses import parse_arguments_string -from utils.types import ToolCallSummary, ToolResultSummary +from utils.types import ResponseInput, ToolCallSummary, ToolResultSummary def _extract_text_from_content(content: str | list[Any]) -> str: @@ -423,3 +433,46 @@ def build_conversation_turns_from_items( ) return chat_history + + +async def append_turn_items_to_conversation( + client: AsyncLlamaStackClient, + conversation_id: str, + user_input: ResponseInput, + llm_output: Sequence[OpenAIResponseOutput], +) -> None: + """ + Append a turn (user input + LLM output) to a conversation in LLS database. + + Args: + client: The Llama Stack client. + conversation_id: The Llama Stack conversation ID. + user_input: User input text or list of ResponseItem. + llm_output: Output from the LLM: a list of OpenAIResponseOutput. 
+ """ + if isinstance(user_input, str): + user_message = OpenAIResponseMessage( + role="user", + content=user_input, + ) + user_items = [user_message.model_dump()] + else: + user_items = [item.model_dump() for item in user_input] + + output_items = [item.model_dump() for item in llm_output] + + items = user_items + output_items + try: + await client.conversations.items.create( + conversation_id, + items=cast(list[Item], items), + ) + except APIConnectionError as e: + error_response = ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(e), + ) + raise HTTPException(**error_response.model_dump()) from e + except APIStatusError as e: + error_response = InternalServerErrorResponse.generic() + raise HTTPException(**error_response.model_dump()) from e diff --git a/src/utils/shields.py b/src/utils/shields.py index ff99fc3b0..a225cfd6c 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -3,8 +3,14 @@ from typing import Any, Optional from fastapi import HTTPException -from llama_stack_api import OpenAIResponseContentPartRefusal, OpenAIResponseMessage -from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from llama_stack_api import OpenAIResponseMessage +from llama_stack_client import ( + APIConnectionError, + APIStatusError as LLSApiStatusError, + AsyncLlamaStackClient, +) +from llama_stack_client.types import ShieldListResponse +from openai._exceptions import APIStatusError as OpenAIAPIStatusError import metrics from configuration import AppConfig @@ -16,17 +22,16 @@ UnprocessableEntityResponse, ServiceUnavailableResponse, ) -from utils.suid import get_suid +from utils.query import handle_known_apistatus_errors from utils.types import ( ShieldModerationBlocked, ShieldModerationPassed, ShieldModerationResult, ) +from constants import DEFAULT_VIOLATION_MESSAGE logger = get_logger(__name__) -DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions." - async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]: """ @@ -129,47 +134,11 @@ async def run_shield_moderation( Raises: HTTPException: If shield's provider_resource_id is not configured or model not found. """ - all_shields = await client.shields.list() - - # Filter shields based on shield_ids parameter - if shield_ids is not None: - if len(shield_ids) == 0: - response = UnprocessableEntityResponse( - response="Invalid shield configuration", - cause=( - "shield_ids provided but no shields selected. " - "Remove the parameter to use default shields." - ), - ) - raise HTTPException(**response.model_dump()) - - shields_to_run = [s for s in all_shields if s.identifier in shield_ids] - - # Log warning if requested shield not found - requested = set(shield_ids) - available = {s.identifier for s in shields_to_run} - missing = requested - available - if missing: - logger.warning("Requested shields not found: %s", missing) - - # Reject if no requested shields were found (prevents accidental bypass) - if not shields_to_run: - response = UnprocessableEntityResponse( - response="Invalid shield configuration", - cause=f"Requested shield_ids not found: {sorted(missing)}", - ) - raise HTTPException(**response.model_dump()) - else: - shields_to_run = list(all_shields) - + shields_to_run = await get_shields_for_request(client, shield_ids) available_models = {model.id for model in await client.models.list()} - for shield in shields_to_run: - # Only validate provider_resource_id against models for llama-guard. 
-        # Llama Stack does not verify that the llama-guard model is registered,
-        # so we check it here to fail fast with a clear error.
-        # Custom shield providers (e.g. lightspeed_question_validity) configure
-        # their model internally, so provider_resource_id is not a model ID.
+        # Only llama-guard's model is validated here, to fail fast with a clear
+        # error. Lightspeed safety providers configure their model internally,
+        # so provider_resource_id is not necessarily a valid model ID.
         if shield.provider_id == "llama-guard" and (
             not shield.provider_resource_id
             or shield.provider_resource_id not in available_models
         )
@@ -184,18 +153,17 @@
             moderation_result = await client.moderations.create(
                 input=input_text, model=shield.provider_resource_id
             )
-        # Known Llama Stack bug: error is raised when violation is present
-        # in the shield LLM response but has wrong format that cannot be parsed.
-        except ValueError:
-            logger.warning(
-                "Shield violation detected, treating as blocked",
+        except APIConnectionError as e:
+            error_response = ServiceUnavailableResponse(
+                backend_name="Llama Stack",
+                cause=str(e),
             )
-            metrics.llm_calls_validation_errors_total.inc()
-            return ShieldModerationBlocked(
-                message=DEFAULT_VIOLATION_MESSAGE,
-                moderation_id=f"modr_{get_suid()}",
-                refusal_response=create_refusal_response(DEFAULT_VIOLATION_MESSAGE),
+            raise HTTPException(**error_response.model_dump()) from e
+        except (LLSApiStatusError, OpenAIAPIStatusError) as e:
+            error_response = handle_known_apistatus_errors(
+                e, shield.provider_resource_id or ""
             )
+            raise HTTPException(**error_response.model_dump()) from e
 
         if moderation_result.results and moderation_result.results[0].flagged:
             result = moderation_result.results[0]
@@ -247,7 +215,7 @@
             cause=str(e),
         )
         raise HTTPException(**error_response.model_dump()) from e
-    except APIStatusError as e:
+    except LLSApiStatusError as e:
         error_response = InternalServerErrorResponse.generic()
         raise HTTPException(**error_response.model_dump()) from e
 
@@ -255,18 +223,61 @@
 def create_refusal_response(refusal_message: str) -> OpenAIResponseMessage:
     """Create a refusal response message object.
 
-    Creates an OpenAIResponseMessage with assistant role containing a refusal
-    content part. This can be used for both conversation items and response output.
-
     Args:
         refusal_message: The refusal message text.
 
     Returns:
-        OpenAIResponseMessage with refusal content.
+        OpenAIResponseMessage with refusal message.
     """
-    refusal_content = OpenAIResponseContentPartRefusal(refusal=refusal_message)
     return OpenAIResponseMessage(
         type="message",
         role="assistant",
-        content=[refusal_content],
+        content=refusal_message,
     )
+
+
+async def get_shields_for_request(
+    client: AsyncLlamaStackClient,
+    shield_ids: Optional[list[str]] = None,
+) -> ShieldListResponse:
+    """Resolve shields for the request: filtered by shield_ids or all configured.
+
+    Args:
+        client: Llama Stack client.
+        shield_ids: Optional list of shield IDs. If provided, only shields
+            with these identifiers are returned; if None, all configured
+            shields are returned.
+
+    Returns:
+        ShieldListResponse: List of Shield objects to run for this request.
+
+    Raises:
+        HTTPException: 404 if shield_ids is provided and any requested
+            shield is not configured in Llama Stack.
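+        HTTPException: 503 if Llama Stack cannot be reached, 500 on other
+            Llama Stack API errors.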
+ """ + if shield_ids == []: + return [] + try: + configured_shields: ShieldListResponse = await client.shields.list() + if shield_ids is None: + return configured_shields + requested = set(shield_ids) + configured_ids = {s.identifier for s in configured_shields} + missing = requested - configured_ids + if missing: + response = NotFoundResponse( + resource=f"Shield{'s' if len(missing) > 1 else ''}", + resource_id=", ".join(missing), + ) + raise HTTPException(**response.model_dump()) + + return [s for s in configured_shields if s.identifier in requested] + except APIConnectionError as e: + error_response = ServiceUnavailableResponse( + backend_name="Llama Stack", + cause=str(e), + ) + raise HTTPException(**error_response.model_dump()) from e + except LLSApiStatusError as e: + error_response = InternalServerErrorResponse.generic() + raise HTTPException(**error_response.model_dump()) from e diff --git a/src/utils/types.py b/src/utils/types.py index 4a725fd5e..52f3566c9 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -332,7 +332,6 @@ class TurnSummary(BaseModel): tool_results: list[ToolResultSummary] = Field(default_factory=list) rag_chunks: list[RAGChunk] = Field(default_factory=list) referenced_documents: list[ReferencedDocument] = Field(default_factory=list) - inline_rag_documents: list[ReferencedDocument] = Field(default_factory=list) token_usage: TokenCounter = Field(default_factory=TokenCounter) diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py index 485914e0b..b24b3214a 100644 --- a/src/utils/vector_search.py +++ b/src/utils/vector_search.py @@ -493,6 +493,7 @@ async def _fetch_solr_rag( async def build_rag_context( client: AsyncLlamaStackClient, + moderation_decision: str, query: str, vector_store_ids: Optional[list[str]], solr: Optional[dict[str, Any]] = None, @@ -503,12 +504,17 @@ async def build_rag_context( Args: client: The AsyncLlamaStackClient to use for the request - query_request: The user's query request - configuration: Application configuration + moderation_decision: The moderation decision + query: The user's query + vector_store_ids: The vector store IDs to query + solr: The Solr query parameters Returns: RAGContext containing formatted context text and referenced documents """ + if moderation_decision == "blocked": + return RAGContext() + # Fetch from all enabled RAG sources in parallel byok_chunks_task = _fetch_byok_rag(client, query, vector_store_ids) solr_chunks_task = _fetch_solr_rag(client, query, solr) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 32bc5b83d..06ee69926 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -17,7 +17,11 @@ from models.responses import QueryResponse from utils.token_counter import TokenCounter from utils.types import ( + RAGChunk, + RAGContext, + ReferencedDocument, ResponsesApiParams, + ShieldModerationPassed, ToolCallSummary, ToolResultSummary, TurnSummary, @@ -125,6 +129,10 @@ async def test_successful_query_no_conversation( "app.endpoints.query.get_topic_summary", new=mocker.AsyncMock(return_value=None), ) + mocker.patch( + "app.endpoints.query.run_shield_moderation", + new=mocker.AsyncMock(return_value=ShieldModerationPassed()), + ) mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" @@ -169,6 +177,93 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: assert response.conversation_id == "123" assert response.response 
== "Kubernetes is a container orchestration platform" + @pytest.mark.asyncio + async def test_query_merges_inline_and_tool_rag_chunks_and_documents( + self, + dummy_request: Request, + setup_configuration: AppConfig, + mocker: MockerFixture, + ) -> None: + """Test that inline RAG and tool-based RAG chunks/docs are correctly merged.""" + query_request = QueryRequest( + query="What is Kubernetes?" + ) # pyright: ignore[reportCallIssue] + + mocker.patch("app.endpoints.query.configuration", setup_configuration) + mocker.patch("app.endpoints.query.check_configuration_loaded") + mocker.patch("app.endpoints.query.check_tokens_available") + mocker.patch("app.endpoints.query.validate_model_provider_override") + + mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + mock_response_obj = mocker.Mock() + mock_response_obj.output = [] + mock_client.responses = mocker.Mock() + mock_client.responses.create = mocker.AsyncMock(return_value=mock_response_obj) + mock_client_holder = mocker.Mock() + mock_client_holder.get_client.return_value = mock_client + mocker.patch( + "app.endpoints.query.AsyncLlamaStackClientHolder", + return_value=mock_client_holder, + ) + mocker.patch( + "app.endpoints.query.run_shield_moderation", + new=mocker.AsyncMock(return_value=ShieldModerationPassed()), + ) + + inline_chunk = RAGChunk(content="inline chunk content", source="byok") + inline_doc = ReferencedDocument(doc_title="Inline Doc") + inline_rag = RAGContext( + context_text="", + rag_chunks=[inline_chunk], + referenced_documents=[inline_doc], + ) + mocker.patch( + "app.endpoints.query.build_rag_context", + new=mocker.AsyncMock(return_value=inline_rag), + ) + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_123" + mock_responses_params.tools = None + mock_responses_params.model_dump.return_value = { + "input": "test", + "model": "provider1/model1", + } + mocker.patch( + "app.endpoints.query.prepare_responses_params", + new=mocker.AsyncMock(return_value=mock_responses_params), + ) + + tool_chunk = RAGChunk(content="tool chunk content", source="vs-1") + tool_doc = ReferencedDocument(doc_title="Tool Doc") + mock_turn_summary = TurnSummary() + mock_turn_summary.rag_chunks = [tool_chunk] + mock_turn_summary.referenced_documents = [tool_doc] + + mocker.patch( + "app.endpoints.query.retrieve_response", + new=mocker.AsyncMock(return_value=mock_turn_summary), + ) + mocker.patch("app.endpoints.query.store_query_results") + mocker.patch("app.endpoints.query.consume_query_tokens") + mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + + response = await query_endpoint_handler( + request=dummy_request, + query_request=query_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, QueryResponse) + assert len(response.rag_chunks) == 2 + assert response.rag_chunks[0].content == "inline chunk content" + assert response.rag_chunks[1].content == "tool chunk content" + assert len(response.referenced_documents) == 2 + assert response.referenced_documents[0].doc_title == "Inline Doc" + assert response.referenced_documents[1].doc_title == "Tool Doc" + @pytest.mark.asyncio async def test_successful_query_with_conversation( self, @@ -214,7 +309,10 @@ async def test_successful_query_with_conversation( "app.endpoints.query.prepare_responses_params", new=mocker.AsyncMock(return_value=mock_responses_params), ) - + mocker.patch( + "app.endpoints.query.run_shield_moderation", + 
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
         mocker.patch(
             "app.endpoints.query.retrieve_response",
             new=mocker.AsyncMock(return_value=TurnSummary()),
@@ -275,6 +373,10 @@ async def test_query_with_attachments(
             "app.endpoints.query.get_topic_summary",
             new=mocker.AsyncMock(return_value=None),
         )
+        mocker.patch(
+            "app.endpoints.query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
         mock_responses_params.model = "provider1/model1"
@@ -335,6 +437,10 @@ async def test_query_with_topic_summary(
             "app.endpoints.query.AsyncLlamaStackClientHolder",
             return_value=mock_client_holder,
         )
+        mocker.patch(
+            "app.endpoints.query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
         mock_responses_params.model = "provider1/model1"
@@ -405,6 +511,10 @@ async def test_query_azure_token_refresh(
             "app.endpoints.query.get_topic_summary",
             new=mocker.AsyncMock(return_value=None),
         )
+        mocker.patch(
+            "app.endpoints.query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
         mock_responses_params.model = "azure/model1"
@@ -476,6 +586,7 @@ async def test_retrieve_response_success(self, mocker: MockerFixture) -> None:
         mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
         mock_responses_params.input = "test query"
         mock_responses_params.model = "provider1/model1"
+        mock_responses_params.tools = None
         mock_responses_params.model_dump.return_value = {
             "input": "test query",
             "model": "provider1/model1",
         }
@@ -492,10 +603,6 @@
         mock_response.output = [mock_output_item]
         mock_response.usage = mock_usage
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(return_value=mock_response)
 
         mock_summary = TurnSummary()
@@ -506,7 +613,9 @@
             return_value=mock_summary,
         )
 
-        result = await retrieve_response(mock_client, mock_responses_params)
+        result = await retrieve_response(
+            mock_client, mock_responses_params, ShieldModerationPassed()
+        )
 
         assert isinstance(result, TurnSummary)
         assert result.llm_response == "Response text"
@@ -527,19 +636,20 @@ async def test_retrieve_response_shield_blocked(
             "model": "provider1/model1",
         }
 
+        mock_refusal = mocker.Mock()
         mock_moderation_result = mocker.Mock()
         mock_moderation_result.decision = "blocked"
         mock_moderation_result.message = "Content blocked by moderation"
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mock_moderation_result),
-        )
+        mock_moderation_result.moderation_id = "mod_123"
+        mock_moderation_result.refusal_response = mock_refusal
         mock_append = mocker.patch(
-            "app.endpoints.query.append_turn_to_conversation",
+            "app.endpoints.query.append_turn_items_to_conversation",
             new=mocker.AsyncMock(),
         )
 
-        result = await retrieve_response(mock_client, mock_responses_params)
+        result = await retrieve_response(
+            mock_client, mock_responses_params, mock_moderation_result
+        )
 
         assert isinstance(result, TurnSummary)
         assert result.llm_response == "Content blocked by moderation"
@@ -558,10 +668,6 @@ async def test_retrieve_response_connection_error(
             "model": "provider1/model1",
        }
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=APIConnectionError(
                 message="Connection failed", request=mocker.Mock()
             )
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await retrieve_response(mock_client, mock_responses_params)
+            await retrieve_response(
+                mock_client, mock_responses_params, ShieldModerationPassed()
+            )
 
         assert exc_info.value.status_code == 503
 
@@ -587,10 +695,6 @@ async def test_retrieve_response_api_status_error(
             "model": "provider1/model1",
         }
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=APIStatusError(
                 message="API error", response=mocker.Mock(request=None), body=None
             )
         )
 
         with pytest.raises(HTTPException):
-            await retrieve_response(mock_client, mock_responses_params)
+            await retrieve_response(
+                mock_client, mock_responses_params, ShieldModerationPassed()
+            )
 
     @pytest.mark.asyncio
     async def test_retrieve_response_runtime_error_context_length(
@@ -623,16 +729,14 @@
             "model": "provider1/model1",
         }
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=RuntimeError("context_length exceeded")
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await retrieve_response(mock_client, mock_responses_params)
+            await retrieve_response(
+                mock_client, mock_responses_params, ShieldModerationPassed()
+            )
 
         assert exc_info.value.status_code == 413
 
@@ -650,16 +754,14 @@ async def test_retrieve_response_runtime_error_other(
             "model": "provider1/model1",
         }
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=RuntimeError("Some other error")
         )
 
         with pytest.raises(RuntimeError):
-            await retrieve_response(mock_client, mock_responses_params)
+            await retrieve_response(
+                mock_client, mock_responses_params, ShieldModerationPassed()
+            )
 
     @pytest.mark.asyncio
     async def test_retrieve_response_with_tool_calls(
@@ -670,6 +772,7 @@
         mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
         mock_responses_params.input = "test query"
         mock_responses_params.model = "provider1/model1"
+        mock_responses_params.tools = None
         mock_responses_params.model_dump.return_value = {
             "input": "test query",
             "model": "provider1/model1",
         }
@@ -682,10 +785,6 @@
         mock_response.output = [mocker.Mock(type="message")]
         mock_response.usage = mock_usage
 
-        mocker.patch(
-            "app.endpoints.query.run_shield_moderation",
-            return_value=mocker.Mock(decision="passed"),
-        )
         mock_client.responses.create = mocker.AsyncMock(return_value=mock_response)
 
         mock_tool_call = ToolCallSummary(id="1", name="test", args={})
@@ -702,7 +801,9 @@
             return_value=mock_summary,
         )
 
-        result = await retrieve_response(mock_client, mock_responses_params)
+        result = await retrieve_response(
+            mock_client, mock_responses_params, ShieldModerationPassed()
+        )
 
         assert result.llm_response == "Response text"
         assert len(result.tool_calls) == 1
@@ -711,4 +812,3 @@ async def test_retrieve_response_with_tool_calls(
         assert result.token_usage.output_tokens == 5
         assert result.rag_chunks == []
         assert result.referenced_documents == []
-        assert result.inline_rag_documents == []
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index cc3214169..3e0670e94 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -52,7 +52,14 @@
 from models.responses import InternalServerErrorResponse
 from utils.token_counter import TokenCounter
 from utils.stream_interrupts import StreamInterruptRegistry
-from utils.types import RAGContext, ReferencedDocument, ResponsesApiParams, TurnSummary
+from utils.types import (
+    RAGChunk,
+    RAGContext,
+    ReferencedDocument,
+    ResponsesApiParams,
+    ShieldModerationPassed,
+    TurnSummary,
+)
 
 MOCK_AUTH_STREAMING = (
     "00000001-0001-0001-0001-000000000001",
@@ -354,6 +361,10 @@ async def test_successful_streaming_query(
             "app.endpoints.streaming_query.prepare_responses_params",
             new=mocker.AsyncMock(return_value=mock_responses_params),
         )
+        mocker.patch(
+            "app.endpoints.streaming_query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
         mocker.patch(
@@ -437,6 +448,10 @@ async def test_streaming_query_text_media_type_header(
             "app.endpoints.streaming_query.prepare_responses_params",
             new=mocker.AsyncMock(return_value=mock_responses_params),
         )
+        mocker.patch(
+            "app.endpoints.streaming_query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
         mocker.patch(
@@ -531,6 +546,10 @@ async def test_streaming_query_with_conversation(
             "app.endpoints.streaming_query.prepare_responses_params",
             new=mocker.AsyncMock(return_value=mock_responses_params),
         )
+        mocker.patch(
+            "app.endpoints.streaming_query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
         mocker.patch(
@@ -623,6 +642,10 @@ async def test_streaming_query_with_attachments(
             "app.endpoints.streaming_query.prepare_responses_params",
             new=mocker.AsyncMock(return_value=mock_responses_params),
         )
+        mocker.patch(
+            "app.endpoints.streaming_query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
 
         mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
         mocker.patch(
@@ -725,6 +748,10 @@ async def test_streaming_query_azure_token_refresh(
             "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
             return_value=("azure", "model1"),
         )
+        mocker.patch(
+            "app.endpoints.streaming_query.run_shield_moderation",
+            new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
+        )
         mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
 
         async def mock_generator() -> AsyncIterator[str]:
@@ -784,17 +811,15 @@ async def test_retrieve_response_generator_success(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
+        mock_context.moderation_result = ShieldModerationPassed()
 
         async def mock_response_gen() -> AsyncIterator[str]:
             yield "test"
 
-        mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mocker.Mock(blocked=False)),
-        )
         mock_client.responses = mocker.Mock()
         mock_client.responses.create = mocker.AsyncMock(
             return_value=mock_response_gen()
         )
 
         async def mock_response_generator(
@@ -812,7 +837,7 @@
         )
 
         generator, turn_summary = await retrieve_response_generator(
-            mock_responses_params, mock_context, []
+            mock_responses_params, mock_context
         )
 
         assert isinstance(turn_summary, TurnSummary)
@@ -834,6 +859,7 @@ async def test_retrieve_response_generator_shield_blocked(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test", media_type=MEDIA_TYPE_TEXT
         )  # pyright: ignore[reportCallIssue]
@@ -841,17 +867,16 @@
         mock_moderation_result = mocker.Mock()
         mock_moderation_result.decision = "blocked"
         mock_moderation_result.message = "Content blocked"
+        mock_moderation_result.moderation_id = "mod_123"
+        mock_moderation_result.refusal_response = mocker.Mock()
+        mock_context.moderation_result = mock_moderation_result
         mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mock_moderation_result),
-        )
-        mocker.patch(
-            "app.endpoints.streaming_query.append_turn_to_conversation",
+            "app.endpoints.streaming_query.append_turn_items_to_conversation",
             new=mocker.AsyncMock(),
         )
 
         _generator, turn_summary = await retrieve_response_generator(
-            mock_responses_params, mock_context, []
+            mock_responses_params, mock_context
         )
 
         assert isinstance(turn_summary, TurnSummary)
@@ -878,14 +903,12 @@ async def test_retrieve_response_generator_connection_error(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
+        mock_context.moderation_result = ShieldModerationPassed()
 
-        mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mocker.Mock(blocked=False)),
-        )
         mock_request_obj = mocker.Mock()
         mock_client.responses = mocker.Mock()
         mock_client.responses.create = mocker.AsyncMock(
@@ -908,7 +931,7 @@
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await retrieve_response_generator(mock_responses_params, mock_context, [])
+            await retrieve_response_generator(mock_responses_params, mock_context)
 
         assert exc_info.value.status_code == 503
 
@@ -933,14 +956,12 @@ async def test_retrieve_response_generator_api_status_error(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
+        mock_context.moderation_result = ShieldModerationPassed()
 
-        mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mocker.Mock(blocked=False)),
-        )
         mock_request_obj = mocker.Mock()
         mock_client.responses = mocker.Mock()
         mock_client.responses.create = mocker.AsyncMock(
@@ -960,7 +981,7 @@
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await retrieve_response_generator(mock_responses_params, mock_context, [])
+            await retrieve_response_generator(mock_responses_params, mock_context)
 
         assert exc_info.value.status_code == 500
 
@@ -985,14 +1006,12 @@ async def test_retrieve_response_generator_runtime_error_context_length(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
+        mock_context.moderation_result = ShieldModerationPassed()
 
-        mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mocker.Mock(blocked=False)),
-        )
         mock_client.responses = mocker.Mock()
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=RuntimeError("context_length exceeded")
         )
@@ -1009,7 +1028,7 @@
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await retrieve_response_generator(mock_responses_params, mock_context, [])
+            await retrieve_response_generator(mock_responses_params, mock_context)
 
         assert exc_info.value.status_code == 413
 
@@ -1034,21 +1053,19 @@ async def test_retrieve_response_generator_runtime_error_other(
         mock_context.client = mock_client
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
+        mock_context.moderation_result = ShieldModerationPassed()
 
-        mocker.patch(
-            "app.endpoints.streaming_query.run_shield_moderation",
-            new=mocker.AsyncMock(return_value=mocker.Mock(blocked=False)),
-        )
         mock_client.responses = mocker.Mock()
         mock_client.responses.create = mocker.AsyncMock(
             side_effect=RuntimeError("Some other error")
         )
 
         with pytest.raises(RuntimeError):
-            await retrieve_response_generator(mock_responses_params, mock_context, [])
+            await retrieve_response_generator(mock_responses_params, mock_context)
 
 
 class TestGenerateResponse:
@@ -1077,6 +1094,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.user_id = "user_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test"
         )  # pyright: ignore[reportCallIssue]
@@ -1134,6 +1152,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.user_id = "user_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.query_request = QueryRequest(
             query="test", generate_topic_summary=True
         )  # pyright: ignore[reportCallIssue]
@@ -1186,6 +1205,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.conversation_id = "conv_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.user_id = "user_123"
         mock_context.query_request = QueryRequest(
             query="test"
@@ -1228,6 +1248,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.conversation_id = "conv_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.user_id = "user_123"
         mock_context.query_request = QueryRequest(
             query="test"
@@ -1273,6 +1294,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.conversation_id = "conv_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.user_id = "user_123"
         mock_context.query_request = QueryRequest(
            query="test", media_type=MEDIA_TYPE_JSON
@@ -1320,6 +1342,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_context.conversation_id = "conv_123"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
         mock_context.user_id = "user_123"
         mock_context.query_request = QueryRequest(
             query="test", media_type=MEDIA_TYPE_JSON
@@ -1608,6 +1631,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1637,6 +1661,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1667,6 +1692,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1707,6 +1733,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1748,6 +1775,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1796,6 +1824,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1846,6 +1875,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
         mock_turn_summary.llm_response = "Response"
 
@@ -1893,6 +1923,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1938,6 +1969,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -1982,6 +2014,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2024,6 +2057,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
        mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2067,6 +2101,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2108,6 +2143,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2128,6 +2164,61 @@
         assert len(result) > 0
         assert any("error" in item for item in result)
 
+    @pytest.mark.asyncio
+    async def test_response_generator_merges_inline_and_tool_rag_chunks_and_documents(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test that inline RAG and tool-based RAG chunks/docs are correctly merged."""
+        inline_chunk = RAGChunk(content="inline chunk content", source="byok")
+        inline_doc = ReferencedDocument(doc_title="Inline Doc")
+        inline_rag = RAGContext(
+            context_text="",
+            rag_chunks=[inline_chunk],
+            referenced_documents=[inline_doc],
+        )
+
+        tool_chunk = RAGChunk(content="tool chunk content", source="vs-1")
+        tool_ref_doc = ReferencedDocument(doc_title="Tool Doc")
+
+        mock_response_obj = mocker.Mock(spec=OpenAIResponseObject)
+        mock_response_obj.usage = mocker.Mock()
+        mock_response_obj.output = []
+
+        async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
+            completed_chunk = mocker.Mock(spec=CompletedChunk)
+            completed_chunk.type = "response.completed"
+            completed_chunk.response = mock_response_obj
+            yield completed_chunk
+
+        mock_context = mocker.Mock(spec=ResponseGeneratorContext)
+        mock_context.query_request = QueryRequest(
+            query="test", media_type=MEDIA_TYPE_JSON
+        )  # pyright: ignore[reportCallIssue]
+        mock_context.model_id = "provider1/model1"
+        mock_context.vector_store_ids = []
+        mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = inline_rag
+
+        mock_turn_summary = TurnSummary()
+        mock_turn_summary.rag_chunks = [tool_chunk]
+        mock_turn_summary.referenced_documents = [tool_ref_doc]
+        mocker.patch(
+            "app.endpoints.streaming_query.parse_referenced_documents",
+            return_value=[tool_ref_doc],
+        )
+
+        async for _ in response_generator(
+            mock_turn_response(), mock_context, mock_turn_summary
+        ):
+            pass
+
+        assert len(mock_turn_summary.rag_chunks) == 2
+        assert mock_turn_summary.rag_chunks[0].content == "inline chunk content"
+        assert mock_turn_summary.rag_chunks[1].content == "tool chunk content"
+        assert len(mock_turn_summary.referenced_documents) == 2
+        assert mock_turn_summary.referenced_documents[0].doc_title == "Inline Doc"
+        assert mock_turn_summary.referenced_documents[1].doc_title == "Tool Doc"
+
 
 class TestStreamHttpErrorEvent:
     """Tests for stream_http_error_event function."""
 
@@ -2230,6 +2321,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2282,6 +2374,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
        mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2354,6 +2447,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
@@ -2441,6 +2535,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         mock_context.model_id = "provider1/model1"
         mock_context.vector_store_ids = []
         mock_context.rag_id_mapping = {}
+        mock_context.inline_rag_context = RAGContext()
 
         mock_turn_summary = TurnSummary()
 
diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py
index 55ee56886..333c96df0 100644
--- a/tests/unit/utils/test_shields.py
+++ b/tests/unit/utils/test_shields.py
@@ -2,6 +2,7 @@
 
 import pytest
 from fastapi import HTTPException, status
+from llama_stack_client import APIConnectionError, APIStatusError
 from pytest_mock import MockerFixture
 
 from utils.shields import (
@@ -9,6 +10,7 @@
     append_turn_to_conversation,
     detect_shield_violations,
     get_available_shields,
+    get_shields_for_request,
     run_shield_moderation,
     validate_shield_ids_override,
 )
@@ -305,60 +307,25 @@ async def test_raises_http_exception_when_shield_has_no_provider_resource_id(
         assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
 
     @pytest.mark.asyncio
-    async def test_returns_blocked_on_bad_request_error(
+    async def test_shield_ids_empty_list_runs_no_shields_returns_passed(
         self, mocker: MockerFixture
     ) -> None:
-        """Test that run_shield_moderation returns blocked when ValueError is raised."""
-        mock_metric = mocker.patch(
-            "utils.shields.metrics.llm_calls_validation_errors_total"
-        )
-        mock_client = mocker.Mock()
-
-        # Setup shield
-        shield = mocker.Mock()
-        shield.identifier = "test-shield"
-        shield.provider_resource_id = "moderation-model"
-        mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
-
-        # Setup model
-        model = mocker.Mock()
-        model.id = "moderation-model"
-        mock_client.models.list = mocker.AsyncMock(return_value=[model])
-
-        # Setup moderation to raise ValueError (known Llama Stack bug)
-        mock_client.moderations.create = mocker.AsyncMock(
-            side_effect=ValueError("Bad request")
-        )
-
-        result = await run_shield_moderation(mock_client, "test input")
-
-        assert result.decision == "blocked"
-        assert result.message == DEFAULT_VIOLATION_MESSAGE
-        mock_metric.inc.assert_called_once()
-
-    @pytest.mark.asyncio
-    async def test_shield_ids_empty_list_raises_422(
-        self, mocker: MockerFixture
-    ) -> None:
-        """Test that shield_ids=[] raises HTTPException 422 (prevents bypass)."""
+        """Test that shield_ids=[] runs no shields and returns passed."""
         mock_client = mocker.Mock()
         shield = mocker.Mock()
         shield.identifier = "shield-1"
         mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
+        mock_client.models.list = mocker.AsyncMock(return_value=[])
 
-        with pytest.raises(HTTPException) as exc_info:
-            await run_shield_moderation(mock_client, "test input", shield_ids=[])
+        result = await run_shield_moderation(mock_client, "test input", shield_ids=[])
 
-        assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-        assert "shield_ids provided but no shields selected" in str(
-            exc_info.value.detail
-        )
+        assert result.decision == "passed"
 
     @pytest.mark.asyncio
-    async def test_shield_ids_raises_exception_when_no_shields_found(
+    async def test_shield_ids_raises_404_when_no_shields_found(
 
-        with pytest.raises(HTTPException) as exc_info:
-            await run_shield_moderation(mock_client, "test input", shield_ids=[])
+        result = await run_shield_moderation(mock_client, "test input", shield_ids=[])
 
-        assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-        assert "shield_ids provided but no shields selected" in str(
-            exc_info.value.detail
-        )
+        assert result.decision == "passed"
 
     @pytest.mark.asyncio
-    async def test_shield_ids_raises_exception_when_no_shields_found(
+    async def test_shield_ids_raises_404_when_no_shields_found(
         self, mocker: MockerFixture
     ) -> None:
-        """Test shield_ids raises HTTPException when no requested shields exist."""
+        """Test shield_ids raises HTTPException 404 when requested shield not configured."""
         mock_client = mocker.Mock()
         shield = mocker.Mock()
         shield.identifier = "shield-1"
@@ -369,8 +336,8 @@ async def test_shield_ids_raises_exception_when_no_shields_found(
             mock_client, "test input", shield_ids=["typo-shield"]
         )
 
-        assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-        assert "Invalid shield configuration" in exc_info.value.detail["response"]  # type: ignore
+        assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
+        assert "Shield" in exc_info.value.detail["response"]  # type: ignore
         assert "typo-shield" in exc_info.value.detail["cause"]  # type: ignore
 
     @pytest.mark.asyncio
@@ -518,3 +485,132 @@ def test_raises_422_when_empty_list_shield_ids_and_override_disabled(
             validate_shield_ids_override(query_request, mock_config)
 
         assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+
+
+class TestGetShieldsForRequest:
+    """Tests for get_shields_for_request function."""
+
+    @pytest.mark.asyncio
+    async def test_returns_all_shields_when_shield_ids_none(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Return all configured shields when shield_ids is None."""
+        mock_client = mocker.Mock()
+        shield1 = mocker.Mock()
+        shield1.identifier = "shield-1"
+        shield2 = mocker.Mock()
+        shield2.identifier = "shield-2"
+        mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])
+
+        result = await get_shields_for_request(mock_client, shield_ids=None)
+
+        assert len(result) == 2
+        assert result[0].identifier == "shield-1"
+        assert result[1].identifier == "shield-2"
+        mock_client.shields.list.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_returns_empty_list_when_no_shields_configured(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test that get_shields_for_request returns empty list when no shields configured."""
+        mock_client = mocker.Mock()
+        mock_client.shields.list = mocker.AsyncMock(return_value=[])
+
+        result = await get_shields_for_request(mock_client, shield_ids=None)
+
+        assert result == []
+
+    @pytest.mark.asyncio
+    async def test_filters_to_requested_shields_when_all_exist(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Test that get_shields_for_request returns only requested shields when all exist."""
+        mock_client = mocker.Mock()
+        shield1 = mocker.Mock()
+        shield1.identifier = "shield-1"
+        shield2 = mocker.Mock()
+        shield2.identifier = "shield-2"
+        shield3 = mocker.Mock()
+        shield3.identifier = "shield-3"
+        mock_client.shields.list = mocker.AsyncMock(
+            return_value=[shield1, shield2, shield3]
+        )
+
+        result = await get_shields_for_request(
+            mock_client, shield_ids=["shield-1", "shield-3"]
+        )
+
+        assert len(result) == 2
+        assert result[0].identifier == "shield-1"
+        assert result[1].identifier == "shield-3"
+
+    @pytest.mark.asyncio
+    async def test_raises_404_when_requested_shield_not_configured(
+        self, mocker: MockerFixture
+    ) -> None:
+        """Raise 404 when a requested shield is not configured."""
+        mock_client = mocker.Mock()
+        shield = mocker.Mock()
+        shield.identifier = "shield-1"
+        mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
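+
+        # "missing-shield" is requested but not configured; the helper should
+        # fail fast with 404 instead of silently filtering it out.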
assert "Shield" in exc_info.value.detail["response"] # type: ignore + assert "missing-shield" in exc_info.value.detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_raises_404_when_multiple_requested_shields_not_configured( + self, mocker: MockerFixture + ) -> None: + """Raise 404 with all missing ids when multiple shields not configured.""" + mock_client = mocker.Mock() + mock_client.shields.list = mocker.AsyncMock(return_value=[]) + + with pytest.raises(HTTPException) as exc_info: + await get_shields_for_request( + mock_client, shield_ids=["missing-1", "missing-2"] + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert "Shields" in exc_info.value.detail["response"] # type: ignore + cause = exc_info.value.detail["cause"] # type: ignore + assert "missing-1" in cause + assert "missing-2" in cause + + @pytest.mark.asyncio + async def test_raises_503_on_connection_error(self, mocker: MockerFixture) -> None: + """Raise 503 on APIConnectionError.""" + mock_client = mocker.Mock() + mock_client.shields.list = mocker.AsyncMock( + side_effect=APIConnectionError( + message="Connection failed", request=mocker.Mock() + ) + ) + + with pytest.raises(HTTPException) as exc_info: + await get_shields_for_request(mock_client, shield_ids=None) + + assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + + @pytest.mark.asyncio + async def test_raises_500_on_api_status_error(self, mocker: MockerFixture) -> None: + """Raise 500 on APIStatusError.""" + mock_client = mocker.Mock() + mock_client.shields.list = mocker.AsyncMock( + side_effect=APIStatusError( + message="Server error", + response=mocker.Mock(request=None), + body=None, + ) + ) + + with pytest.raises(HTTPException) as exc_info: + await get_shields_for_request(mock_client, shield_ids=None) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/tests/unit/utils/test_vector_search.py b/tests/unit/utils/test_vector_search.py index 4930cb846..930f59d36 100644 --- a/tests/unit/utils/test_vector_search.py +++ b/tests/unit/utils/test_vector_search.py @@ -462,7 +462,7 @@ async def test_both_sources_disabled(self, mocker) -> None: # type: ignore[no-u mocker.patch("utils.vector_search.configuration", config_mock) client_mock = mocker.AsyncMock() - context = await build_rag_context(client_mock, "test query", None) + context = await build_rag_context(client_mock, "passed", "test query", None) assert context.context_text == "" assert context.rag_chunks == [] @@ -497,7 +497,7 @@ async def test_byok_enabled_only(self, mocker) -> None: # type: ignore[no-untyp client_mock = mocker.AsyncMock() client_mock.vector_io.query.return_value = search_response - context = await build_rag_context(client_mock, "test query", None) + context = await build_rag_context(client_mock, "passed", "test query", None) assert len(context.rag_chunks) > 0 assert "BYOK content" in context.context_text