Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,21 @@ async def retrieve_response( # pylint: disable=too-many-locals
response: Optional[OpenAIResponseObject] = None
try:
moderation_result = await run_shield_moderation(
client, responses_params.input, shield_ids
client, cast(str, responses_params.input), shield_ids
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Query endpoints always use string-like inputs so it is safe to cast.

)
if moderation_result.decision == "blocked":
# Handle shield moderation blocking
violation_message = moderation_result.message
await append_turn_to_conversation(
client,
responses_params.conversation,
responses_params.input,
cast(str, responses_params.input),
violation_message,
)
return TurnSummary(llm_response=violation_message)
response = await client.responses.create(**responses_params.model_dump())
response = await client.responses.create(
**responses_params.model_dump(exclude_none=True)
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explicitly drop unset attributes (queries use reduced set of responses attributes).

response = cast(OpenAIResponseObject, response)

except RuntimeError as e: # library mode wraps 413 into runtime error
Expand Down
10 changes: 6 additions & 4 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,16 @@ async def retrieve_response_generator(
turn_summary = TurnSummary()
try:
moderation_result = await run_shield_moderation(
context.client, responses_params.input, context.query_request.shield_ids
context.client,
cast(str, responses_params.input),
context.query_request.shield_ids,
)
if moderation_result.decision == "blocked":
turn_summary.llm_response = moderation_result.message
await append_turn_to_conversation(
context.client,
responses_params.conversation,
responses_params.input,
cast(str, responses_params.input),
moderation_result.message,
)
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
Expand All @@ -302,7 +304,7 @@ async def retrieve_response_generator(
)
# Retrieve response stream (may raise exceptions)
response = await context.client.responses.create(
**responses_params.model_dump()
**responses_params.model_dump(exclude_none=True)
)
# Store pre-RAG documents for later merging
turn_summary.pre_rag_documents = doc_ids_from_chunks
Expand Down Expand Up @@ -347,7 +349,7 @@ async def _persist_interrupted_turn(
await append_turn_to_conversation(
context.client,
responses_params.conversation,
responses_params.input,
cast(str, responses_params.input),
INTERRUPTED_RESPONSE_MESSAGE,
)
except Exception: # pylint: disable=broad-except
Expand Down
2 changes: 1 addition & 1 deletion src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def store_query_results( # pylint: disable=too-many-arguments
summary: TurnSummary,
query: str,
skip_userid_check: bool,
attachments: list[Attachment] | None = None,
attachments: Optional[list[Attachment]] = None,
topic_summary: Optional[str] = None,
) -> None:
"""
Expand Down
Loading
Loading