Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 42 additions & 60 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
UnauthorizedResponse,
UnprocessableEntityResponse,
)
from utils.conversations import append_turn_items_to_conversation
from utils.endpoints import (
check_configuration_loaded,
validate_and_retrieve_conversation,
Expand All @@ -59,14 +60,11 @@
get_topic_summary,
prepare_responses_params,
)
from utils.shields import (
append_turn_to_conversation,
run_shield_moderation,
validate_shield_ids_override,
)
from utils.shields import run_shield_moderation, validate_shield_ids_override
from utils.suid import normalize_conversation_id
from utils.types import (
ResponsesApiParams,
ShieldModerationResult,
TurnSummary,
)
from utils.vector_search import build_rag_context
Expand Down Expand Up @@ -158,14 +156,21 @@ async def query_endpoint_handler(

client = AsyncLlamaStackClientHolder().get_client()

# Build RAG context from Inline RAG sources
inline_rag_context = await build_rag_context(
client, query_request.query, query_request.vector_store_ids, query_request.solr
)

# Moderation input is the raw user content (query + attachments) without injected RAG
# context, to avoid false positives from retrieved document content.
moderation_input = prepare_input(query_request)
moderation_result = await run_shield_moderation(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do moderation before the inline RAG

client, moderation_input, query_request.shield_ids
)

# Build RAG context from Inline RAG sources
inline_rag_context = await build_rag_context(
client,
moderation_result.decision,
query_request.query,
query_request.vector_store_ids,
query_request.solr,
)

# Prepare API request parameters
responses_params = await prepare_responses_params(
Expand All @@ -177,7 +182,7 @@ async def query_endpoint_handler(
stream=False,
store=True,
request_headers=request.headers,
inline_rag_context=inline_rag_context.context_text or None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model uses default factory list

inline_rag_context=inline_rag_context.context_text,
)

# Handle Azure token refresh if needed
Expand All @@ -189,32 +194,22 @@ async def query_endpoint_handler(
):
client = await update_azure_token(client)

# Build index identification mapping for RAG source resolution
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
Copy link
Contributor Author

@asimurka asimurka Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do this inside retrieve_response instead of passing as arguments

rag_id_mapping = configuration.rag_id_mapping

# Retrieve response using Responses API
turn_summary = await retrieve_response(
client,
responses_params,
query_request.shield_ids,
vector_store_ids,
rag_id_mapping,
moderation_input=moderation_input,
)

# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
rag_chunks = inline_rag_context.rag_chunks
tool_rag_chunks = turn_summary.rag_chunks or []
logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
turn_summary.rag_chunks = rag_chunks + tool_rag_chunks

# Add tool-based RAG documents and chunks
rag_documents = inline_rag_context.referenced_documents
tool_rag_documents = turn_summary.referenced_documents or []
turn_summary.referenced_documents = deduplicate_referenced_documents(
rag_documents + tool_rag_documents
)
turn_summary = await retrieve_response(client, responses_params, moderation_result)

if moderation_result.decision == "passed":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only append RAG chunks if moderation passed

# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
rag_chunks = inline_rag_context.rag_chunks
tool_rag_chunks = turn_summary.rag_chunks
logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
turn_summary.rag_chunks = rag_chunks + tool_rag_chunks

# Add tool-based RAG documents and chunks
rag_documents = inline_rag_context.referenced_documents
tool_rag_documents = turn_summary.referenced_documents
turn_summary.referenced_documents = deduplicate_referenced_documents(
rag_documents + tool_rag_documents
)

# Get topic summary for new conversation
if not user_conversation and query_request.generate_topic_summary:
Expand Down Expand Up @@ -272,10 +267,7 @@ async def query_endpoint_handler(
async def retrieve_response( # pylint: disable=too-many-locals
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
shield_ids: Optional[list[str]] = None,
vector_store_ids: Optional[list[str]] = None,
rag_id_mapping: Optional[dict[str, str]] = None,
moderation_input: Optional[str] = None,
moderation_result: ShieldModerationResult,
) -> TurnSummary:
"""
Retrieve response from LLMs and agents.
Expand All @@ -286,33 +278,21 @@ async def retrieve_response( # pylint: disable=too-many-locals
Parameters:
client: The AsyncLlamaStackClient to use for the request.
responses_params: The Responses API parameters.
shield_ids: Optional list of shield IDs for moderation.
vector_store_ids: Vector store IDs used in the query for source resolution.
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
moderation_input: Text to moderate. Should be the raw user content (query +
attachments) without injected RAG context to avoid false positives.
Falls back to responses_params.input if not provided.
moderation_result: The moderation result.

Returns:
TurnSummary: Summary of the LLM response content
"""
response: Optional[OpenAIResponseObject] = None
try:
moderation_result = await run_shield_moderation(
if moderation_result.decision == "blocked":
await append_turn_items_to_conversation(
client,
moderation_input or cast(str, responses_params.input),
shield_ids,
responses_params.conversation,
responses_params.input,
[moderation_result.refusal_response],
)
if moderation_result.decision == "blocked":
# Handle shield moderation blocking
violation_message = moderation_result.message
await append_turn_to_conversation(
client,
responses_params.conversation,
cast(str, responses_params.input),
violation_message,
)
return TurnSummary(llm_response=violation_message)
return TurnSummary(llm_response=moderation_result.message)
try:
response = await client.responses.create(
**responses_params.model_dump(exclude_none=True)
)
Expand All @@ -333,6 +313,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
error_response = handle_known_apistatus_errors(e, responses_params.model)
raise HTTPException(**error_response.model_dump()) from e

vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced from outer scope

rag_id_mapping = configuration.rag_id_mapping
return build_turn_summary(
response, responses_params.model, vector_store_ids, rag_id_mapping
)
73 changes: 48 additions & 25 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
UnauthorizedResponse,
UnprocessableEntityResponse,
)
from utils.conversations import append_turn_items_to_conversation
from utils.endpoints import (
check_configuration_loaded,
validate_and_retrieve_conversation,
Expand Down Expand Up @@ -189,10 +190,22 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals

client = AsyncLlamaStackClientHolder().get_client()

# Moderation input is the raw user content (query + attachments) without injected RAG
# context, to avoid false positives from retrieved document content.
moderation_input = prepare_input(query_request)
moderation_result = await run_shield_moderation(
client, moderation_input, query_request.shield_ids
)

# Build RAG context from Inline RAG sources
inline_rag_context = await build_rag_context(
client, query_request.query, query_request.vector_store_ids, query_request.solr
client,
moderation_result.decision,
query_request.query,
query_request.vector_store_ids,
query_request.solr,
)

# Prepare API request parameters
responses_params = await prepare_responses_params(
client=client,
Expand All @@ -203,7 +216,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
stream=True,
store=True,
request_headers=request.headers,
inline_rag_context=inline_rag_context.context_text or None,
inline_rag_context=inline_rag_context.context_text,
)

# Handle Azure token refresh if needed
Expand All @@ -227,8 +240,10 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
query_request=query_request,
started_at=started_at,
client=client,
moderation_result=moderation_result,
vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools),
rag_id_mapping=configuration.rag_id_mapping,
inline_rag_context=inline_rag_context,
)

# Update metrics for the LLM call
Expand All @@ -240,9 +255,14 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
generator, turn_summary = await retrieve_response_generator(
responses_params=responses_params,
context=context,
inline_rag_documents=inline_rag_context.referenced_documents,
)

# Combine inline RAG results (BYOK + Solr) with tool-based results
if context.moderation_result.decision == "passed":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine tool-based and inline referenced docs here

turn_summary.referenced_documents = deduplicate_referenced_documents(
inline_rag_context.referenced_documents + turn_summary.referenced_documents
)

response_media_type = (
MEDIA_TYPE_TEXT
if query_request.media_type == MEDIA_TYPE_TEXT
Expand All @@ -263,7 +283,6 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
async def retrieve_response_generator(
responses_params: ResponsesApiParams,
context: ResponseGeneratorContext,
inline_rag_documents: list[ReferencedDocument],
) -> tuple[AsyncIterator[str], TurnSummary]:
"""
Retrieve the appropriate response generator.
Expand All @@ -275,40 +294,41 @@ async def retrieve_response_generator(
Args:
responses_params: The Responses API parameters
context: The response generator context
inline_rag_documents: Referenced documents from inline RAG (BYOK + Solr)

Returns:
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary

"""
turn_summary = TurnSummary()
try:
moderation_result = await run_shield_moderation(
context.client,
prepare_input(context.query_request),
context.query_request.shield_ids,
)
if moderation_result.decision == "blocked":
turn_summary.llm_response = moderation_result.message
await append_turn_to_conversation(
if context.moderation_result.decision == "blocked":
turn_summary.llm_response = context.moderation_result.message
await append_turn_items_to_conversation(
context.client,
responses_params.conversation,
cast(str, responses_params.input),
moderation_result.message,
responses_params.input,
[context.moderation_result.refusal_response],
)
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
return (
shield_violation_generator(moderation_result.message, media_type),
shield_violation_generator(
context.moderation_result.message,
media_type,
),
turn_summary,
)
# Retrieve response stream (may raise exceptions)
response = await context.client.responses.create(
**responses_params.model_dump(exclude_none=True)
)
# Store pre-RAG documents for later merging with tool-based RAG
turn_summary.inline_rag_documents = inline_rag_documents
return response_generator(response, context, turn_summary), turn_summary

return (
response_generator(
response,
context,
turn_summary,
),
turn_summary,
)
    # Handle known LLS client errors only at stream creation time and shield execution
except RuntimeError as e: # library mode wraps 413 into runtime error
if "context_length" in str(e).lower():
Expand Down Expand Up @@ -570,7 +590,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
turn_response: The streaming response from Llama Stack
context: The response generator context
turn_summary: TurnSummary to populate during streaming

Yields:
SSE-formatted strings for tokens, tool calls, tool results,
turn completion, and error events.
Expand Down Expand Up @@ -741,15 +760,19 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
turn_summary.token_usage = extract_token_usage(
latest_response_object.usage, context.model_id
)
tool_based_documents = parse_referenced_documents(
# Parse tool-based referenced documents from the final response object
tool_rag_docs = parse_referenced_documents(
latest_response_object,
vector_store_ids=context.vector_store_ids,
rag_id_mapping=context.rag_id_mapping,
)

# Merge pre-RAG documents with tool-based documents and deduplicate
# Combine inline RAG results (BYOK + Solr) with tool-based results
turn_summary.referenced_documents = deduplicate_referenced_documents(
turn_summary.inline_rag_documents + tool_based_documents
context.inline_rag_context.referenced_documents + tool_rag_docs
)
# Combine inline RAG chunks (BYOK + Solr) with tool-based chunks
turn_summary.rag_chunks = (
context.inline_rag_context.rag_chunks + turn_summary.rag_chunks
)


Expand Down
2 changes: 2 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,5 @@
# Environment variable to force StreamHandler instead of RichHandler
# Set to any non-empty value to disable RichHandler
LIGHTSPEED_STACK_DISABLE_RICH_HANDLER_ENV_VAR = "LIGHTSPEED_STACK_DISABLE_RICH_HANDLER"

DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions."
5 changes: 5 additions & 0 deletions src/models/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from llama_stack_client import AsyncLlamaStackClient

from models.requests import QueryRequest
from utils.types import RAGContext, ShieldModerationResult


@dataclass
Expand All @@ -23,6 +24,8 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes
query_request: The query request object
started_at: Timestamp when the request started (ISO 8601 format)
client: The Llama Stack client for API interactions
moderation_result: The moderation result
inline_rag_context: Inline RAG context
vector_store_ids: Vector store IDs used in the query for source resolution.
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
"""
Expand All @@ -42,7 +45,9 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes

# Dependencies & State
client: AsyncLlamaStackClient
moderation_result: ShieldModerationResult

# RAG index identification
inline_rag_context: RAGContext
vector_store_ids: list[str] = field(default_factory=list)
rag_id_mapping: dict[str, str] = field(default_factory=dict)
3 changes: 1 addition & 2 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ class QueryRequest(BaseModel):
shield_ids: Optional[list[str]] = Field(
None,
description="Optional list of safety shield IDs to apply. "
"If None, all configured shields are used. "
"If provided, must contain at least one valid shield ID (empty list raises 422 error).",
"If None, all configured shields are used. ",
examples=["llama-guard", "custom-shield"],
)

Expand Down
Loading
Loading