Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/app/endpoints/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals
generate_topic_summary=True,
media_type=None,
vector_store_ids=vector_store_ids,
shield_ids=None,
)

# Get LLM client and select model
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/streaming_query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
9 changes: 9 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class QueryRequest(BaseModel):
generate_topic_summary: Whether to generate topic summary for new conversations.
media_type: The optional media type for response format (application/json or text/plain).
vector_store_ids: The optional list of specific vector store IDs to query for RAG.
shield_ids: The optional list of safety shield IDs to apply.

Example:
```python
Expand Down Expand Up @@ -166,6 +167,14 @@ class QueryRequest(BaseModel):
examples=["ocp_docs", "knowledge_base", "vector_db_1"],
)

shield_ids: Optional[list[str]] = Field(
None,
description="Optional list of safety shield IDs to apply. "
"If None, all configured shields are used. "
"If empty list, all shields are skipped.",
examples=["llama-guard", "custom-shield"],
)

# provides examples for /docs endpoint
model_config = {
"extra": "forbid",
Expand Down
28 changes: 25 additions & 3 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utility functions for working with Llama Stack shields."""

import logging
from typing import Any, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from llama_stack_client import AsyncLlamaStackClient, BadRequestError
Expand Down Expand Up @@ -63,26 +63,48 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
async def run_shield_moderation(
client: AsyncLlamaStackClient,
input_text: str,
shield_ids: Optional[list[str]] = None,
) -> ShieldModerationResult:
"""
Run shield moderation on input text.

Iterates through all configured shields and runs moderation checks.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.

Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.

Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.

Raises:
HTTPException: If shield's provider_resource_id is not configured or model not found.
"""
Comment on lines 68 to 85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Document the new 422 raise path in the docstring.
The function now raises HTTP 422 for invalid shield_ids, but the Raises section only mentions missing models.

✏️ Proposed docstring update
     Raises:
-        HTTPException: If shield's provider_resource_id is not configured or model not found.
+        HTTPException: If shield's provider_resource_id is not configured, model not found,
+            or the requested shield_ids do not match any available shields.
As per coding guidelines, follow Google Python docstring conventions, including Parameters, Returns, Raises, and Attributes sections as needed.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"""
Run shield moderation on input text.
Iterates through all configured shields and runs moderation checks.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.
Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.
Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.
Raises:
HTTPException: If shield's provider_resource_id is not configured or model not found.
"""
"""
Run shield moderation on input text.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.
Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.
Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.
Raises:
HTTPException: If shield's provider_resource_id is not configured, model not found,
or the requested shield_ids do not match any available shields.
"""
🤖 Prompt for AI Agents
In `@src/utils/shields.py` around lines 68 - 85, The docstring for
run_shield_moderation is missing documentation of the new HTTP 422 error path
for invalid shield_ids; update the Google-style docstring for
run_shield_moderation to include in the Raises section that an HTTPException
with status 422 is raised when shield_ids is invalid (e.g., contains unknown IDs
or is malformed), and ensure the Parameters, Returns, and Raises sections are
present and concise per Google Python docstring conventions, referencing the
ShieldModerationResult return type and the existing 500/404 or other
HTTPException cases for missing provider_resource_id or model-not-found.

all_shields = await client.shields.list()

# Filter shields based on shield_ids parameter
if shield_ids is not None:
if len(shield_ids) == 0:
logger.info("shield_ids=[] provided, skipping all shields")
return ShieldModerationResult(blocked=False)

shields_to_run = [s for s in all_shields if s.identifier in shield_ids]

# Log warning if requested shield not found
requested = set(shield_ids)
available = {s.identifier for s in shields_to_run}
missing = requested - available
if missing:
logger.warning("Requested shields not found: %s", missing)
else:
shields_to_run = list(all_shields)

available_models = {model.id for model in await client.models.list()}

for shield in await client.shields.list():
for shield in shields_to_run:
if (
not shield.provider_resource_id
or shield.provider_resource_id not in available_models
Expand Down
Loading