Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 118 additions & 17 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from constants import (
INTERRUPTED_RESPONSE_MESSAGE,
LLM_TOKEN_EVENT,
LLM_TOOL_CALL_EVENT,
LLM_TOOL_RESULT_EVENT,
Expand Down Expand Up @@ -320,6 +321,110 @@ async def retrieve_response_generator(
raise HTTPException(**error_response.model_dump()) from e


async def _persist_interrupted_turn(
    context: ResponseGeneratorContext,
    responses_params: ResponsesApiParams,
    turn_summary: TurnSummary,
) -> None:
    """Persist the user query and an interrupted response into the conversation.

    Invoked after a streaming request is cancelled so the exchange is
    still recorded. Every failure here is logged and swallowed so the
    original cancellation is never masked.

    Parameters:
        context: The response generator context.
        responses_params: The Responses API parameters.
        turn_summary: TurnSummary whose llm_response is already set to
            the interrupted message.
    """
    # Best-effort: record the user input plus the canned interrupted
    # reply in the conversation history.
    try:
        await append_turn_to_conversation(
            context.client,
            responses_params.conversation,
            responses_params.input,
            INTERRUPTED_RESPONSE_MESSAGE,
        )
    except Exception:  # pylint: disable=broad-except
        logger.exception(
            "Failed to append interrupted turn to conversation for request %s",
            context.request_id,
        )

    # Best-effort: persist the query results with a UTC completion
    # timestamp in the same ISO-like format used for completed turns.
    try:
        finished = datetime.datetime.now(datetime.UTC)
        store_query_results(
            user_id=context.user_id,
            conversation_id=context.conversation_id,
            model=responses_params.model,
            completed_at=finished.strftime("%Y-%m-%dT%H:%M:%SZ"),
            started_at=context.started_at,
            summary=turn_summary,
            query_request=context.query_request,
            skip_userid_check=context.skip_userid_check,
            topic_summary=None,
        )
    except Exception:  # pylint: disable=broad-except
        logger.exception(
            "Failed to store interrupted query results for request %s",
            context.request_id,
        )


def _register_interrupt_callback(
    context: ResponseGeneratorContext,
    responses_params: ResponsesApiParams,
    turn_summary: TurnSummary,
) -> list[bool]:
    """Build an interrupt callback and register the stream for cancellation.

    The callback is scheduled as a **separate** asyncio task by
    ``cancel_stream`` so it executes regardless of where the
    ``CancelledError`` is raised in the ASGI stack.

    A mutable one-element list serves as a shared guard so that the
    callback and the in-generator ``CancelledError`` handler never
    both persist the same turn.

    Parameters:
        context: The response generator context.
        responses_params: The Responses API parameters.
        turn_summary: TurnSummary populated during streaming.

    Returns:
        A mutable list ``[False]`` used as a persist-done guard; the
        caller should check ``guard[0]`` before persisting and set
        it to ``True`` afterwards.
    """
    persist_done: list[bool] = [False]

    async def _on_interrupt() -> None:
        # Bail out if the in-generator handler already persisted the turn.
        if persist_done[0]:
            return
        persist_done[0] = True
        turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE
        await _persist_interrupted_turn(context, responses_params, turn_summary)

    task = asyncio.current_task()
    if task is None:
        # Without a task handle there is nothing for cancel_stream to cancel.
        logger.warning(
            "No current asyncio task for request %s; "
            "stream interruption will not be available",
            context.request_id,
        )
    else:
        get_stream_interrupt_registry().register_stream(
            request_id=context.request_id,
            user_id=context.user_id,
            task=task,
            on_interrupt=_on_interrupt,
        )

    return persist_done


async def generate_response(
generator: AsyncIterator[str],
context: ResponseGeneratorContext,
Expand All @@ -330,9 +435,9 @@ async def generate_response(

Re-yields events from the generator, handles errors, and ensures
persistence and token consumption after completion. When the
stream is interrupted via ``CancelledError``, all post-stream side
effects (token consumption, result storage) are skipped and the
request is deregistered from the interrupt registry.
stream is interrupted via ``CancelledError``, the user query and
an interrupted response are persisted to the conversation, but
token consumption is skipped (no usage data is available).

Args:
generator: The base generator to wrap
Expand All @@ -343,20 +448,9 @@ async def generate_response(
Yields:
SSE-formatted strings from the wrapped generator
"""
user_id = context.user_id

current_task = asyncio.current_task()
if current_task is not None:
get_stream_interrupt_registry().register_stream(
request_id=context.request_id,
user_id=user_id,
task=current_task,
)
else:
logger.warning(
"No current asyncio task for request %s; stream interruption will not be available",
context.request_id,
)
persist_guard = _register_interrupt_callback(
context, responses_params, turn_summary
)

stream_completed = False
try:
Expand Down Expand Up @@ -390,6 +484,13 @@ async def generate_response(
yield stream_http_error_event(error_response, context.query_request.media_type)
except asyncio.CancelledError:
logger.info("Streaming request %s interrupted by user", context.request_id)
current_task = asyncio.current_task()
if current_task is not None:
current_task.uncancel()
if not persist_guard[0]:
persist_guard[0] = True
turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE
await _persist_interrupted_turn(context, responses_params, turn_summary)
yield stream_interrupted_event(context.request_id)
finally:
get_stream_interrupt_registry().deregister_stream(context.request_id)
Expand Down
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request"

# Response stored in the conversation when the user interrupts a streaming request
INTERRUPTED_RESPONSE_MESSAGE = "You interrupted this request."

# Supported attachment types
ATTACHMENT_TYPES = frozenset(
{
Expand Down
35 changes: 31 additions & 4 deletions src/utils/stream_interrupts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""In-memory registry for interrupting active streaming requests."""

import asyncio
from dataclasses import dataclass
from collections.abc import Callable, Coroutine
from typing import Any
from dataclasses import dataclass, field
from enum import Enum
from threading import Lock
from log import get_logger
Expand All @@ -17,10 +19,16 @@ class ActiveStream:
Attributes:
user_id: Owner of the streaming request.
task: Asyncio task producing the stream response.
on_interrupt: Optional async callback invoked when the stream
is cancelled, scheduled as a separate task so it runs
regardless of where the ``CancelledError`` lands.
"""

user_id: str
task: asyncio.Task[None]
on_interrupt: Callable[[], Coroutine[Any, Any, None]] | None = field(
default=None, repr=False
)


class CancelStreamResult(str, Enum):
Expand All @@ -41,17 +49,25 @@ def __init__(self) -> None:
self._lock = Lock()

def register_stream(
self, request_id: str, user_id: str, task: asyncio.Task[None]
self,
request_id: str,
user_id: str,
task: asyncio.Task[None],
on_interrupt: Callable[[], Coroutine[Any, Any, None]] | None = None,
) -> None:
"""Register an active stream task for interrupt support.

Parameters:
request_id: Unique streaming request identifier.
user_id: User identifier that owns the stream.
task: Asyncio task associated with the stream.
on_interrupt: Optional async callback to run when the stream
is cancelled, executed in a separate task.
"""
with self._lock:
self._streams[request_id] = ActiveStream(user_id=user_id, task=task)
self._streams[request_id] = ActiveStream(
user_id=user_id, task=task, on_interrupt=on_interrupt
)

def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult:
"""Cancel an active stream owned by user.
Expand All @@ -60,13 +76,19 @@ def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult:
lock so that a concurrent ``deregister_stream`` cannot remove
the entry between the ownership check and the cancel call.

When an ``on_interrupt`` callback was registered, it is
scheduled as a **separate** asyncio task after the cancel so
persistence runs regardless of where the ``CancelledError``
is raised (inside the generator or in Starlette's send).

Parameters:
request_id: Unique streaming request identifier.
user_id: User identifier attempting the interruption.

Returns:
CancelStreamResult: Structured cancellation result.
"""
on_interrupt = None
with self._lock:
stream = self._streams.get(request_id)
if stream is None:
Expand All @@ -81,7 +103,12 @@ def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult:
if stream.task.done():
return CancelStreamResult.ALREADY_DONE
stream.task.cancel()
return CancelStreamResult.CANCELLED
on_interrupt = stream.on_interrupt

if on_interrupt is not None:
asyncio.get_running_loop().create_task(on_interrupt())

return CancelStreamResult.CANCELLED

def deregister_stream(self, request_id: str) -> None:
"""Remove stream task from registry once completed/cancelled.
Expand Down
Loading
Loading