-
Notifications
You must be signed in to change notification settings - Fork 78
LCORE-1409: Refactor of shield moderation and inline RAG content persistence #1291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |
| UnauthorizedResponse, | ||
| UnprocessableEntityResponse, | ||
| ) | ||
| from utils.conversations import append_turn_items_to_conversation | ||
| from utils.endpoints import ( | ||
| check_configuration_loaded, | ||
| validate_and_retrieve_conversation, | ||
|
|
@@ -59,14 +60,11 @@ | |
| get_topic_summary, | ||
| prepare_responses_params, | ||
| ) | ||
| from utils.shields import ( | ||
| append_turn_to_conversation, | ||
| run_shield_moderation, | ||
| validate_shield_ids_override, | ||
| ) | ||
| from utils.shields import run_shield_moderation, validate_shield_ids_override | ||
| from utils.suid import normalize_conversation_id | ||
| from utils.types import ( | ||
| ResponsesApiParams, | ||
| ShieldModerationResult, | ||
| TurnSummary, | ||
| ) | ||
| from utils.vector_search import build_rag_context | ||
|
|
@@ -158,14 +156,21 @@ async def query_endpoint_handler( | |
|
|
||
| client = AsyncLlamaStackClientHolder().get_client() | ||
|
|
||
| # Build RAG context from Inline RAG sources | ||
| inline_rag_context = await build_rag_context( | ||
| client, query_request.query, query_request.vector_store_ids, query_request.solr | ||
| ) | ||
|
|
||
| # Moderation input is the raw user content (query + attachments) without injected RAG | ||
| # context, to avoid false positives from retrieved document content. | ||
| moderation_input = prepare_input(query_request) | ||
| moderation_result = await run_shield_moderation( | ||
| client, moderation_input, query_request.shield_ids | ||
| ) | ||
|
|
||
| # Build RAG context from Inline RAG sources | ||
| inline_rag_context = await build_rag_context( | ||
| client, | ||
| moderation_result.decision, | ||
| query_request.query, | ||
| query_request.vector_store_ids, | ||
| query_request.solr, | ||
| ) | ||
|
|
||
| # Prepare API request parameters | ||
| responses_params = await prepare_responses_params( | ||
|
|
@@ -177,7 +182,7 @@ async def query_endpoint_handler( | |
| stream=False, | ||
| store=True, | ||
| request_headers=request.headers, | ||
| inline_rag_context=inline_rag_context.context_text or None, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Model uses default factory list |
||
| inline_rag_context=inline_rag_context.context_text, | ||
| ) | ||
|
|
||
| # Handle Azure token refresh if needed | ||
|
|
@@ -189,32 +194,22 @@ async def query_endpoint_handler( | |
| ): | ||
| client = await update_azure_token(client) | ||
|
|
||
| # Build index identification mapping for RAG source resolution | ||
| vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do this inside retrieve_response instead of passing as arguments |
||
| rag_id_mapping = configuration.rag_id_mapping | ||
|
|
||
| # Retrieve response using Responses API | ||
| turn_summary = await retrieve_response( | ||
| client, | ||
| responses_params, | ||
| query_request.shield_ids, | ||
| vector_store_ids, | ||
| rag_id_mapping, | ||
| moderation_input=moderation_input, | ||
| ) | ||
|
|
||
| # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript | ||
| rag_chunks = inline_rag_context.rag_chunks | ||
| tool_rag_chunks = turn_summary.rag_chunks or [] | ||
| logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks)) | ||
| turn_summary.rag_chunks = rag_chunks + tool_rag_chunks | ||
|
|
||
| # Add tool-based RAG documents and chunks | ||
| rag_documents = inline_rag_context.referenced_documents | ||
| tool_rag_documents = turn_summary.referenced_documents or [] | ||
| turn_summary.referenced_documents = deduplicate_referenced_documents( | ||
| rag_documents + tool_rag_documents | ||
| ) | ||
| turn_summary = await retrieve_response(client, responses_params, moderation_result) | ||
|
|
||
| if moderation_result.decision == "passed": | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only append RAG chunks if moderation passed |
||
| # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript | ||
| rag_chunks = inline_rag_context.rag_chunks | ||
| tool_rag_chunks = turn_summary.rag_chunks | ||
| logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks)) | ||
| turn_summary.rag_chunks = rag_chunks + tool_rag_chunks | ||
|
|
||
| # Add tool-based RAG documents and chunks | ||
| rag_documents = inline_rag_context.referenced_documents | ||
| tool_rag_documents = turn_summary.referenced_documents | ||
| turn_summary.referenced_documents = deduplicate_referenced_documents( | ||
| rag_documents + tool_rag_documents | ||
| ) | ||
|
|
||
| # Get topic summary for new conversation | ||
| if not user_conversation and query_request.generate_topic_summary: | ||
|
|
@@ -272,10 +267,7 @@ async def query_endpoint_handler( | |
| async def retrieve_response( # pylint: disable=too-many-locals | ||
| client: AsyncLlamaStackClient, | ||
| responses_params: ResponsesApiParams, | ||
| shield_ids: Optional[list[str]] = None, | ||
| vector_store_ids: Optional[list[str]] = None, | ||
| rag_id_mapping: Optional[dict[str, str]] = None, | ||
| moderation_input: Optional[str] = None, | ||
| moderation_result: ShieldModerationResult, | ||
| ) -> TurnSummary: | ||
| """ | ||
| Retrieve response from LLMs and agents. | ||
|
|
@@ -286,33 +278,21 @@ async def retrieve_response( # pylint: disable=too-many-locals | |
| Parameters: | ||
| client: The AsyncLlamaStackClient to use for the request. | ||
| responses_params: The Responses API parameters. | ||
| shield_ids: Optional list of shield IDs for moderation. | ||
| vector_store_ids: Vector store IDs used in the query for source resolution. | ||
| rag_id_mapping: Mapping from vector_db_id to user-facing rag_id. | ||
| moderation_input: Text to moderate. Should be the raw user content (query + | ||
| attachments) without injected RAG context to avoid false positives. | ||
| Falls back to responses_params.input if not provided. | ||
| moderation_result: The moderation result. | ||
|
|
||
| Returns: | ||
| TurnSummary: Summary of the LLM response content | ||
| """ | ||
| response: Optional[OpenAIResponseObject] = None | ||
| try: | ||
| moderation_result = await run_shield_moderation( | ||
| if moderation_result.decision == "blocked": | ||
| await append_turn_items_to_conversation( | ||
| client, | ||
| moderation_input or cast(str, responses_params.input), | ||
| shield_ids, | ||
| responses_params.conversation, | ||
| responses_params.input, | ||
| [moderation_result.refusal_response], | ||
| ) | ||
| if moderation_result.decision == "blocked": | ||
| # Handle shield moderation blocking | ||
| violation_message = moderation_result.message | ||
| await append_turn_to_conversation( | ||
| client, | ||
| responses_params.conversation, | ||
| cast(str, responses_params.input), | ||
| violation_message, | ||
| ) | ||
| return TurnSummary(llm_response=violation_message) | ||
| return TurnSummary(llm_response=moderation_result.message) | ||
| try: | ||
| response = await client.responses.create( | ||
| **responses_params.model_dump(exclude_none=True) | ||
| ) | ||
|
|
@@ -333,6 +313,8 @@ async def retrieve_response( # pylint: disable=too-many-locals | |
| error_response = handle_known_apistatus_errors(e, responses_params.model) | ||
| raise HTTPException(**error_response.model_dump()) from e | ||
|
|
||
| vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replaced from outer scope |
||
| rag_id_mapping = configuration.rag_id_mapping | ||
| return build_turn_summary( | ||
| response, responses_params.model, vector_store_ids, rag_id_mapping | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,7 @@ | |
| UnauthorizedResponse, | ||
| UnprocessableEntityResponse, | ||
| ) | ||
| from utils.conversations import append_turn_items_to_conversation | ||
| from utils.endpoints import ( | ||
| check_configuration_loaded, | ||
| validate_and_retrieve_conversation, | ||
|
|
@@ -189,10 +190,22 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals | |
|
|
||
| client = AsyncLlamaStackClientHolder().get_client() | ||
|
|
||
| # Moderation input is the raw user content (query + attachments) without injected RAG | ||
| # context, to avoid false positives from retrieved document content. | ||
| moderation_input = prepare_input(query_request) | ||
| moderation_result = await run_shield_moderation( | ||
| client, moderation_input, query_request.shield_ids | ||
| ) | ||
|
|
||
| # Build RAG context from Inline RAG sources | ||
| inline_rag_context = await build_rag_context( | ||
| client, query_request.query, query_request.vector_store_ids, query_request.solr | ||
| client, | ||
| moderation_result.decision, | ||
| query_request.query, | ||
| query_request.vector_store_ids, | ||
| query_request.solr, | ||
| ) | ||
|
|
||
| # Prepare API request parameters | ||
| responses_params = await prepare_responses_params( | ||
| client=client, | ||
|
|
@@ -203,7 +216,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals | |
| stream=True, | ||
| store=True, | ||
| request_headers=request.headers, | ||
| inline_rag_context=inline_rag_context.context_text or None, | ||
| inline_rag_context=inline_rag_context.context_text, | ||
| ) | ||
|
|
||
| # Handle Azure token refresh if needed | ||
|
|
@@ -227,8 +240,10 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals | |
| query_request=query_request, | ||
| started_at=started_at, | ||
| client=client, | ||
| moderation_result=moderation_result, | ||
| vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools), | ||
| rag_id_mapping=configuration.rag_id_mapping, | ||
| inline_rag_context=inline_rag_context, | ||
| ) | ||
|
|
||
| # Update metrics for the LLM call | ||
|
|
@@ -240,9 +255,14 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals | |
| generator, turn_summary = await retrieve_response_generator( | ||
| responses_params=responses_params, | ||
| context=context, | ||
| inline_rag_documents=inline_rag_context.referenced_documents, | ||
| ) | ||
|
|
||
| # Combine inline RAG results (BYOK + Solr) with tool-based results | ||
| if context.moderation_result.decision == "passed": | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combine tool-based and inline referenced docs here |
||
| turn_summary.referenced_documents = deduplicate_referenced_documents( | ||
| inline_rag_context.referenced_documents + turn_summary.referenced_documents | ||
| ) | ||
|
|
||
| response_media_type = ( | ||
| MEDIA_TYPE_TEXT | ||
| if query_request.media_type == MEDIA_TYPE_TEXT | ||
|
|
@@ -263,7 +283,6 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals | |
| async def retrieve_response_generator( | ||
| responses_params: ResponsesApiParams, | ||
| context: ResponseGeneratorContext, | ||
| inline_rag_documents: list[ReferencedDocument], | ||
| ) -> tuple[AsyncIterator[str], TurnSummary]: | ||
| """ | ||
| Retrieve the appropriate response generator. | ||
|
|
@@ -275,40 +294,41 @@ async def retrieve_response_generator( | |
| Args: | ||
| responses_params: The Responses API parameters | ||
| context: The response generator context | ||
| inline_rag_documents: Referenced documents from inline RAG (BYOK + Solr) | ||
|
|
||
| Returns: | ||
| tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary | ||
|
|
||
| """ | ||
| turn_summary = TurnSummary() | ||
| try: | ||
| moderation_result = await run_shield_moderation( | ||
| context.client, | ||
| prepare_input(context.query_request), | ||
| context.query_request.shield_ids, | ||
| ) | ||
| if moderation_result.decision == "blocked": | ||
| turn_summary.llm_response = moderation_result.message | ||
| await append_turn_to_conversation( | ||
| if context.moderation_result.decision == "blocked": | ||
| turn_summary.llm_response = context.moderation_result.message | ||
| await append_turn_items_to_conversation( | ||
| context.client, | ||
| responses_params.conversation, | ||
| cast(str, responses_params.input), | ||
| moderation_result.message, | ||
| responses_params.input, | ||
| [context.moderation_result.refusal_response], | ||
| ) | ||
| media_type = context.query_request.media_type or MEDIA_TYPE_JSON | ||
| return ( | ||
| shield_violation_generator(moderation_result.message, media_type), | ||
| shield_violation_generator( | ||
| context.moderation_result.message, | ||
| media_type, | ||
| ), | ||
| turn_summary, | ||
| ) | ||
| # Retrieve response stream (may raise exceptions) | ||
| response = await context.client.responses.create( | ||
| **responses_params.model_dump(exclude_none=True) | ||
| ) | ||
| # Store pre-RAG documents for later merging with tool-based RAG | ||
| turn_summary.inline_rag_documents = inline_rag_documents | ||
| return response_generator(response, context, turn_summary), turn_summary | ||
|
|
||
| return ( | ||
| response_generator( | ||
| response, | ||
| context, | ||
| turn_summary, | ||
| ), | ||
| turn_summary, | ||
| ) | ||
| # Handle known LLS client errors only at stream creation time and shield execution | ||
| except RuntimeError as e: # library mode wraps 413 into runtime error | ||
| if "context_length" in str(e).lower(): | ||
|
|
@@ -570,7 +590,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat | |
| turn_response: The streaming response from Llama Stack | ||
| context: The response generator context | ||
| turn_summary: TurnSummary to populate during streaming | ||
|
|
||
| Yields: | ||
| SSE-formatted strings for tokens, tool calls, tool results, | ||
| turn completion, and error events. | ||
|
|
@@ -741,15 +760,19 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat | |
| turn_summary.token_usage = extract_token_usage( | ||
| latest_response_object.usage, context.model_id | ||
| ) | ||
| tool_based_documents = parse_referenced_documents( | ||
| # Parse tool-based referenced documents from the final response object | ||
| tool_rag_docs = parse_referenced_documents( | ||
| latest_response_object, | ||
| vector_store_ids=context.vector_store_ids, | ||
| rag_id_mapping=context.rag_id_mapping, | ||
| ) | ||
|
|
||
| # Merge pre-RAG documents with tool-based documents and deduplicate | ||
| # Combine inline RAG results (BYOK + Solr) with tool-based results | ||
| turn_summary.referenced_documents = deduplicate_referenced_documents( | ||
| turn_summary.inline_rag_documents + tool_based_documents | ||
| context.inline_rag_context.referenced_documents + tool_rag_docs | ||
| ) | ||
| # Combine inline RAG chunks (BYOK + Solr) with tool-based chunks | ||
| turn_summary.rag_chunks = ( | ||
| context.inline_rag_context.rag_chunks + turn_summary.rag_chunks | ||
| ) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do moderation before the inline RAG