diff --git a/docs/openapi.json b/docs/openapi.json index 6174075cf..57ca778ee 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1661,7 +1661,7 @@ "type": "string", "format": "text/event-stream" }, - "example": "data: {\"event\": \"start\", \"data\": {\"conversation_id\": \"123e4567-e89b-12d3-a456-426614174000\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 0, \"token\": \"No Violation\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 1, \"token\": \"\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 2, \"token\": \"Hello\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 3, \"token\": \"!\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 4, \"token\": \" How\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 5, \"token\": \" can\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 6, \"token\": \" I\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 7, \"token\": \" assist\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 8, \"token\": \" you\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 9, \"token\": \" today\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 10, \"token\": \"?\"}}\n\ndata: {\"event\": \"turn_complete\", \"data\": {\"token\": \"Hello! How can I assist you today?\"}}\n\ndata: {\"event\": \"end\", \"data\": {\"referenced_documents\": [], \"truncated\": null, \"input_tokens\": 11, \"output_tokens\": 19}, \"available_quotas\": {}}\n\n" + "example": "data: {\"event\": \"start\", \"data\": {\"conversation_id\": \"123e4567-e89b-12d3-a456-426614174000\", \"request_id\": \"123e4567-e89b-12d3-a456-426614174001\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 0, \"token\": \"No Violation\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 1, \"token\": \"\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 2, \"token\": \"Hello\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 3, \"token\": \"!\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 4, \"token\": \" How\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 5, \"token\": \" can\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 6, \"token\": \" I\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 7, \"token\": \" assist\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 8, \"token\": \" you\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 9, \"token\": \" today\"}}\n\ndata: {\"event\": \"token\", \"data\": {\"id\": 10, \"token\": \"?\"}}\n\ndata: {\"event\": \"turn_complete\", \"data\": {\"token\": \"Hello! How can I assist you today?\"}}\n\ndata: {\"event\": \"end\", \"data\": {\"referenced_documents\": [], \"truncated\": null, \"input_tokens\": 11, \"output_tokens\": 19}, \"available_quotas\": {}}\n\n" } } }, @@ -1912,6 +1912,121 @@ } } }, + "/v1/streaming_query/interrupt": { + "post": { + "tags": [ + "streaming_query_interrupt" + ], + "summary": "Streaming Query Interrupt Endpoint Handler", + "description": "Interrupt an in-progress streaming query by request identifier.\n\nParameters:\n interrupt_request: Request payload containing the stream request ID.\n auth: Auth context tuple resolved from the authentication dependency.\n registry: Stream interrupt registry dependency used to cancel streams.\n\nReturns:\n StreamingInterruptResponse: Confirmation payload when interruption succeeds.\n\nRaises:\n HTTPException: If no active stream for the given request ID can be interrupted.", + "operationId": "stream_interrupt_endpoint_handler_v1_streaming_query_interrupt_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StreamingInterruptRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StreamingInterruptResponse" + }, + "example": { + "interrupted": true, + "message": "Streaming request interrupted", + "request_id": "123e4567-e89b-12d3-a456-426614174000" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnauthorizedResponse" + }, + "examples": { + "missing header": { + "value": { + "detail": { + "cause": "No Authorization header found", + "response": "Missing or invalid credentials provided by client" + } + } + }, + "missing token": { + "value": { + "detail": { + "cause": "No token found in Authorization header", + "response": "Missing or invalid credentials provided by client" + } + } + } + } + } + } + }, + "403": { + "description": "Permission denied", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForbiddenResponse" + }, + "examples": { + "endpoint": { + "value": { + "detail": { + "cause": "User 6789 is not authorized to access this endpoint.", + "response": "User does not have permission to access this endpoint" + } + } + } + } + } + } + }, + "404": { + "description": "Resource not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/NotFoundResponse" + }, + "examples": { + "streaming request": { + "value": { + "detail": { + "cause": "Streaming Request with ID 123e4567-e89b-12d3-a456-426614174000 does not exist", + "response": "Streaming Request not found" + } + } + } + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/config": { "get": { "tags": [ @@ -4332,7 +4447,7 @@ ], "summary": "Handle A2A Jsonrpc", "description": "Handle A2A JSON-RPC requests following the A2A protocol specification.\n\nThis endpoint uses the DefaultRequestHandler from the A2A SDK to handle\nall JSON-RPC requests including message/send, message/stream, etc.\n\nThe A2A SDK application is created per-request to include authentication\ncontext while still leveraging FastAPI's authorization middleware.\n\nAutomatically detects streaming requests (message/stream JSON-RPC method)\nand returns a StreamingResponse to enable real-time chunk delivery.\n\nArgs:\n request: FastAPI request object\n auth: Authentication tuple\n mcp_headers: MCP headers for context propagation\n\nReturns:\n JSON-RPC response or streaming response", - "operationId": "handle_a2a_jsonrpc_a2a_post", + "operationId": "handle_a2a_jsonrpc_a2a_get", "responses": { "200": { "description": "Successful Response", @@ -4350,7 +4465,7 @@ ], "summary": "Handle A2A Jsonrpc", "description": "Handle A2A JSON-RPC requests following the A2A protocol specification.\n\nThis endpoint uses the DefaultRequestHandler from the A2A SDK to handle\nall JSON-RPC requests including message/send, message/stream, etc.\n\nThe A2A SDK application is created per-request to include authentication\ncontext while still leveraging FastAPI's authorization middleware.\n\nAutomatically detects streaming requests (message/stream JSON-RPC method)\nand returns a StreamingResponse to enable real-time chunk delivery.\n\nArgs:\n request: FastAPI request object\n auth: Authentication tuple\n mcp_headers: MCP headers for context propagation\n\nReturns:\n JSON-RPC response or streaming response", - "operationId": "handle_a2a_jsonrpc_a2a_post", + "operationId": "handle_a2a_jsonrpc_a2a_get", "responses": { "200": { "description": "Successful Response", @@ -7349,6 +7464,13 @@ "response": "Rag not found" }, "label": "rag" + }, + { + "detail": { + "cause": "Streaming Request with ID 123e4567-e89b-12d3-a456-426614174000 does not exist", + "response": "Streaming Request not found" + }, + "label": "streaming request" } ] }, @@ -9224,6 +9346,73 @@ } ] }, + "StreamingInterruptRequest": { + "properties": { + "request_id": { + "type": "string", + "title": "Request Id", + "description": "The active streaming request ID to interrupt", + "examples": [ + "123e4567-e89b-12d3-a456-426614174000" + ] + } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "request_id" + ], + "title": "StreamingInterruptRequest", + "description": "Model representing a request to interrupt an active streaming query.\n\nAttributes:\n request_id: Unique ID of the active streaming request to interrupt.", + "examples": [ + { + "request_id": "123e4567-e89b-12d3-a456-426614174000" + } + ] + }, + "StreamingInterruptResponse": { + "properties": { + "request_id": { + "type": "string", + "title": "Request Id", + "description": "The streaming request ID targeted by the interrupt call", + "examples": [ + "123e4567-e89b-12d3-a456-426614174000" + ] + }, + "interrupted": { + "type": "boolean", + "title": "Interrupted", + "description": "Whether an in-progress stream was interrupted", + "examples": [ + true + ] + }, + "message": { + "type": "string", + "title": "Message", + "description": "Human-readable interruption status message", + "examples": [ + "Streaming request interrupted" + ] + } + }, + "type": "object", + "required": [ + "request_id", + "interrupted", + "message" + ], + "title": "StreamingInterruptResponse", + "description": "Model representing a response to a streaming interrupt request.\n\nAttributes:\n request_id: The streaming request ID targeted by the interrupt call.\n interrupted: Whether an in-progress stream was interrupted.\n message: Human-readable interruption status message.\n\nExample:\n ```python\n response = StreamingInterruptResponse(\n request_id=\"123e4567-e89b-12d3-a456-426614174000\",\n interrupted=True,\n message=\"Streaming request interrupted\",\n )\n ```", + "examples": [ + { + "interrupted": true, + "message": "Streaming request interrupted", + "request_id": "123e4567-e89b-12d3-a456-426614174000" + } + ] + }, "TLSConfiguration": { "properties": { "tls_certificate_path": { diff --git a/src/app/endpoints/stream_interrupt.py b/src/app/endpoints/stream_interrupt.py new file mode 100644 index 000000000..55eb58672 --- /dev/null +++ b/src/app/endpoints/stream_interrupt.py @@ -0,0 +1,91 @@ +"""Endpoint for interrupting in-progress streaming query requests.""" + +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException + +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from models.config import Action +from models.requests import StreamingInterruptRequest +from models.responses import ( + ForbiddenResponse, + NotFoundResponse, + StreamingInterruptResponse, + UnauthorizedResponse, +) +from utils.stream_interrupts import ( + CancelStreamResult, + StreamInterruptRegistry, + get_stream_interrupt_registry, +) + +router = APIRouter(tags=["streaming_query_interrupt"]) + +stream_interrupt_responses: dict[int | str, dict[str, Any]] = { + 200: StreamingInterruptResponse.openapi_response(), + 401: UnauthorizedResponse.openapi_response( + examples=["missing header", "missing token"] + ), + 403: ForbiddenResponse.openapi_response(examples=["endpoint"]), + 404: NotFoundResponse.openapi_response(examples=["streaming request"]), +} + + +@router.post( + "/streaming_query/interrupt", + responses=stream_interrupt_responses, + summary="Streaming Query Interrupt Endpoint Handler", +) +@authorize(Action.STREAMING_QUERY) +async def stream_interrupt_endpoint_handler( + interrupt_request: StreamingInterruptRequest, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], + registry: Annotated[ + StreamInterruptRegistry, Depends(get_stream_interrupt_registry) + ], +) -> StreamingInterruptResponse: + """Interrupt an in-progress streaming query by request identifier. + + Parameters: + interrupt_request: Request payload containing the stream request ID. + auth: Auth context tuple resolved from the authentication dependency. + registry: Stream interrupt registry dependency used to cancel streams. + + Returns: + StreamingInterruptResponse: Confirmation payload when interruption succeeds. + + Raises: + HTTPException: If no active stream for the given request ID can be interrupted. + """ + user_id, _, _, _ = auth + request_id = interrupt_request.request_id + cancel_result = registry.cancel_stream(request_id, user_id) + if cancel_result == CancelStreamResult.NOT_FOUND: + response = NotFoundResponse( + resource="streaming request", + resource_id=request_id, + ) + raise HTTPException(**response.model_dump()) + if cancel_result == CancelStreamResult.FORBIDDEN: + response = ForbiddenResponse( + response="User does not have permission to interrupt this streaming request", + cause=( + f"User {user_id} does not own streaming request " + f"with ID {request_id}" + ), + ) + raise HTTPException(**response.model_dump()) + if cancel_result == CancelStreamResult.ALREADY_DONE: + return StreamingInterruptResponse( + request_id=request_id, + interrupted=False, + message="Streaming request already completed; nothing to interrupt", + ) + + return StreamingInterruptResponse( + request_id=request_id, + interrupted=True, + message="Streaming request interrupted", + ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index b45c4f625..d0eb3d401 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,5 +1,6 @@ """Streaming query handler using Responses API.""" +import asyncio import datetime import json @@ -83,7 +84,8 @@ append_turn_to_conversation, run_shield_moderation, ) -from utils.suid import normalize_conversation_id +from utils.stream_interrupts import get_stream_interrupt_registry +from utils.suid import get_suid, normalize_conversation_id from utils.token_counter import TokenCounter from utils.types import ResponsesApiParams, TurnSummary from utils.vector_search import format_rag_context_for_injection, perform_vector_search @@ -206,9 +208,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ): client = await update_azure_token(client) + request_id = get_suid() + # Create context with index identification mapping for RAG source resolution context = ResponseGeneratorContext( conversation_id=normalize_conversation_id(responses_params.conversation), + request_id=request_id, model_id=responses_params.model, user_id=user_id, skip_userid_check=_skip_userid_check, @@ -322,7 +327,10 @@ async def generate_response( """Wrap a generator with cleanup logic. Re-yields events from the generator, handles errors, and ensures - persistence and token consumption after completion. + persistence and token consumption after completion. When the + stream is interrupted via ``CancelledError``, all post-stream side + effects (token consumption, result storage) are skipped and the + request is deregistered from the interrupt registry. Args: generator: The base generator to wrap @@ -333,13 +341,34 @@ async def generate_response( Yields: SSE-formatted strings from the wrapped generator """ - yield stream_start_event(context.conversation_id) + user_id = context.user_id - # Re-yield all events from the generator + current_task = asyncio.current_task() + if current_task is not None: + get_stream_interrupt_registry().register_stream( + request_id=context.request_id, + user_id=user_id, + task=current_task, + ) + else: + logger.warning( + "No current asyncio task for request %s; stream interruption will not be available", + context.request_id, + ) + + stream_completed = False try: + yield stream_start_event( + conversation_id=context.conversation_id, + request_id=context.request_id, + ) + + # Re-yield all events from the generator async for event in generator: yield event + stream_completed = True + # Handle known LLS client errors during response generation time except RuntimeError as e: # library mode wraps 413 into runtime error error_response = ( @@ -348,19 +377,26 @@ async def generate_response( else InternalServerErrorResponse.generic() ) yield stream_http_error_event(error_response, context.query_request.media_type) - return except APIConnectionError as e: error_response = ServiceUnavailableResponse( backend_name="Llama Stack", cause=str(e), ) yield stream_http_error_event(error_response, context.query_request.media_type) - return except (LLSApiStatusError, OpenAIAPIStatusError) as e: error_response = handle_known_apistatus_errors(e, responses_params.model) yield stream_http_error_event(error_response, context.query_request.media_type) + except asyncio.CancelledError: + logger.info("Streaming request %s interrupted by user", context.request_id) + yield stream_interrupted_event(context.request_id) + finally: + get_stream_interrupt_registry().deregister_stream(context.request_id) + + if not stream_completed: return + # Post-stream side effects: only run when streaming finished successfully + # Get topic summary for new conversations if needed topic_summary = None if not context.query_request.conversation_id: @@ -659,16 +695,17 @@ def format_stream_data(d: dict) -> str: return f"data: {data}\n\n" -def stream_start_event(conversation_id: str) -> str: - """ - Yield the start of the data stream. +def stream_start_event(conversation_id: str, request_id: str) -> str: + """Format an SSE start event for a streaming response. - Format a Server-Sent Events (SSE) start event containing the - conversation ID. + The payload contains both the conversation ID and the request ID + so the client can correlate the stream with a conversation and + use the request ID to issue an interrupt if needed. Parameters: - conversation_id (str): Unique identifier for the - conversation. + conversation_id (str): Unique identifier for the conversation. + request_id (str): Unique SUID for this streaming request, + returned to the client for interrupt support. Returns: str: SSE-formatted string representing the start event. @@ -678,6 +715,30 @@ def stream_start_event(conversation_id: str) -> str: "event": "start", "data": { "conversation_id": conversation_id, + "request_id": request_id, + }, + } + ) + + +def stream_interrupted_event(request_id: str) -> str: + """Format an SSE event indicating the stream was interrupted. + + Emitted to the client just before the generator closes so the + frontend can distinguish an intentional user-initiated interruption + from an unexpected connection drop. + + Parameters: + request_id (str): Unique identifier for the interrupted request. + + Returns: + str: SSE-formatted string representing the interrupted event. + """ + return format_stream_data( + { + "event": "interrupted", + "data": { + "request_id": request_id, }, } ) diff --git a/src/app/routers.py b/src/app/routers.py index 95c91552e..78663e18f 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -13,6 +13,7 @@ config, feedback, streaming_query, + stream_interrupt, authorized, conversations_v2, conversations_v1, @@ -52,6 +53,7 @@ def include_routers(app: FastAPI) -> None: # Query endpoints app.include_router(query.router, prefix="/v1") app.include_router(streaming_query.router, prefix="/v1") + app.include_router(stream_interrupt.router, prefix="/v1") app.include_router(config.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") app.include_router(conversations_v1.router, prefix="/v1") diff --git a/src/models/context.py b/src/models/context.py index 68e2d25c9..2ef76f36d 100644 --- a/src/models/context.py +++ b/src/models/context.py @@ -16,6 +16,7 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes Attributes: conversation_id: The conversation identifier + request_id: Unique identifier for the streaming request user_id: The user identifier skip_userid_check: Whether to skip user ID validation model_id: The model identifier @@ -28,6 +29,7 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes # Conversation & User context conversation_id: str + request_id: str user_id: str skip_userid_check: bool diff --git a/src/models/requests.py b/src/models/requests.py index 4448940af..4cc5dc429 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -267,6 +267,46 @@ def validate_media_type(self) -> Self: return self +class StreamingInterruptRequest(BaseModel): + """Model representing a request to interrupt an active streaming query. + + Attributes: + request_id: Unique ID of the active streaming request to interrupt. + """ + + request_id: str = Field( + description="The active streaming request ID to interrupt", + examples=["123e4567-e89b-12d3-a456-426614174000"], + ) + + model_config = { + "extra": "forbid", + "json_schema_extra": { + "examples": [ + {"request_id": "123e4567-e89b-12d3-a456-426614174000"}, + ] + }, + } + + @field_validator("request_id") + @classmethod + def check_request_id(cls, value: str) -> str: + """Validate that request identifier matches expected SUID format. + + Parameters: + value: Request identifier submitted by the caller. + + Returns: + str: The validated request identifier. + + Raises: + ValueError: If the request identifier is not a valid SUID. + """ + if not suid.check_suid(value): + raise ValueError(f"Improper request ID {value}") + return value + + class FeedbackCategory(str, Enum): """Enum representing predefined feedback categories for AI responses. diff --git a/src/models/responses.py b/src/models/responses.py index 9b29be513..7df36cc7d 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -503,7 +503,8 @@ def openapi_response(cls) -> dict[str, Any]: "examples": [ ( 'data: {"event": "start", "data": {' - '"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}}\n\n' + '"conversation_id": "123e4567-e89b-12d3-a456-426614174000", ' + '"request_id": "123e4567-e89b-12d3-a456-426614174001"}}\n\n' 'data: {"event": "token", "data": {' '"id": 0, "token": "No Violation"}}\n\n' 'data: {"event": "token", "data": {' @@ -538,6 +539,52 @@ def openapi_response(cls) -> dict[str, Any]: } +class StreamingInterruptResponse(AbstractSuccessfulResponse): + """Model representing a response to a streaming interrupt request. + + Attributes: + request_id: The streaming request ID targeted by the interrupt call. + interrupted: Whether an in-progress stream was interrupted. + message: Human-readable interruption status message. + + Example: + ```python + response = StreamingInterruptResponse( + request_id="123e4567-e89b-12d3-a456-426614174000", + interrupted=True, + message="Streaming request interrupted", + ) + ``` + """ + + request_id: str = Field( + description="The streaming request ID targeted by the interrupt call", + examples=["123e4567-e89b-12d3-a456-426614174000"], + ) + + interrupted: bool = Field( + description="Whether an in-progress stream was interrupted", + examples=[True], + ) + + message: str = Field( + description="Human-readable interruption status message", + examples=["Streaming request interrupted"], + ) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "request_id": "123e4567-e89b-12d3-a456-426614174000", + "interrupted": True, + "message": "Streaming request interrupted", + } + ] + } + } + + class InfoResponse(AbstractSuccessfulResponse): """Model representing a response to an info request. @@ -1749,6 +1796,16 @@ class NotFoundResponse(AbstractErrorResponse): ), }, }, + { + "label": "streaming request", + "detail": { + "response": "Streaming Request not found", + "cause": ( + "Streaming Request with ID " + "123e4567-e89b-12d3-a456-426614174000 does not exist" + ), + }, + }, ] } } diff --git a/src/utils/stream_interrupts.py b/src/utils/stream_interrupts.py new file mode 100644 index 000000000..cf8cea180 --- /dev/null +++ b/src/utils/stream_interrupts.py @@ -0,0 +1,114 @@ +"""In-memory registry for interrupting active streaming requests.""" + +import asyncio +from dataclasses import dataclass +from enum import Enum +from threading import Lock +from log import get_logger +from utils.types import Singleton + +logger = get_logger(__name__) + + +@dataclass +class ActiveStream: + """Represents one active streaming request bound to a user. + + Attributes: + user_id: Owner of the streaming request. + task: Asyncio task producing the stream response. + """ + + user_id: str + task: asyncio.Task[None] + + +class CancelStreamResult(str, Enum): + """Outcomes when attempting to cancel a stream.""" + + CANCELLED = "cancelled" + NOT_FOUND = "not_found" + FORBIDDEN = "forbidden" + ALREADY_DONE = "already_done" + + +class StreamInterruptRegistry(metaclass=Singleton): + """Registry for active streaming tasks keyed by request ID.""" + + def __init__(self) -> None: + """Initialize an empty registry with a lock for thread-safety.""" + self._streams: dict[str, ActiveStream] = {} + self._lock = Lock() + + def register_stream( + self, request_id: str, user_id: str, task: asyncio.Task[None] + ) -> None: + """Register an active stream task for interrupt support. + + Parameters: + request_id: Unique streaming request identifier. + user_id: User identifier that owns the stream. + task: Asyncio task associated with the stream. + """ + with self._lock: + self._streams[request_id] = ActiveStream(user_id=user_id, task=task) + + def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult: + """Cancel an active stream owned by user. + + The entire lookup-check-cancel sequence is performed under the + lock so that a concurrent ``deregister_stream`` cannot remove + the entry between the ownership check and the cancel call. + + Parameters: + request_id: Unique streaming request identifier. + user_id: User identifier attempting the interruption. + + Returns: + CancelStreamResult: Structured cancellation result. + """ + with self._lock: + stream = self._streams.get(request_id) + if stream is None: + return CancelStreamResult.NOT_FOUND + if stream.user_id != user_id: + logger.warning( + "User %s attempted to interrupt request %s owned by another user", + user_id, + request_id, + ) + return CancelStreamResult.FORBIDDEN + if stream.task.done(): + return CancelStreamResult.ALREADY_DONE + stream.task.cancel() + return CancelStreamResult.CANCELLED + + def deregister_stream(self, request_id: str) -> None: + """Remove stream task from registry once completed/cancelled. + + Parameters: + request_id: Unique streaming request identifier. + """ + with self._lock: + self._streams.pop(request_id, None) + + def get_stream(self, request_id: str) -> ActiveStream | None: + """Get currently registered stream metadata for tests/introspection. + + Parameters: + request_id: Unique streaming request identifier. + + Returns: + ActiveStream | None: Registered stream metadata, or None when absent. + """ + with self._lock: + return self._streams.get(request_id) + + +def get_stream_interrupt_registry() -> StreamInterruptRegistry: + """Return the module-level interrupt registry. + + Exposed as a callable so it can be used as a FastAPI dependency + and overridden in tests via ``app.dependency_overrides``. + """ + return StreamInterruptRegistry() diff --git a/tests/integration/endpoints/test_stream_interrupt_integration.py b/tests/integration/endpoints/test_stream_interrupt_integration.py new file mode 100644 index 000000000..6788f08ef --- /dev/null +++ b/tests/integration/endpoints/test_stream_interrupt_integration.py @@ -0,0 +1,66 @@ +"""Integration tests for the streaming query interrupt lifecycle.""" + +import asyncio +from collections.abc import Generator + +import pytest +from fastapi import HTTPException + +from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler +from models.requests import StreamingInterruptRequest +from utils.stream_interrupts import StreamInterruptRegistry + +TEST_REQUEST_ID = "123e4567-e89b-12d3-a456-426614174003" +OWNER_USER_ID = "00000001-0001-0001-0001-000000000001" + + +@pytest.fixture(name="registry") +def registry_fixture() -> Generator[StreamInterruptRegistry, None, None]: + """Provide singleton registry with deterministic per-test cleanup.""" + registry = StreamInterruptRegistry() + registry.deregister_stream(TEST_REQUEST_ID) + yield registry + registry.deregister_stream(TEST_REQUEST_ID) + + +@pytest.mark.asyncio +async def test_stream_interrupt_full_round_trip( + registry: StreamInterruptRegistry, +) -> None: + """Full lifecycle: register, interrupt, then verify deregistration.""" + + async def pending_stream() -> None: + await asyncio.sleep(10) + + task = asyncio.create_task(pending_stream()) + registry.register_stream(TEST_REQUEST_ID, OWNER_USER_ID, task) + + assert registry.get_stream(TEST_REQUEST_ID) is not None + + response = await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest(request_id=TEST_REQUEST_ID), + auth=(OWNER_USER_ID, "mock_username", False, "mock_token"), + registry=registry, + ) + assert response.interrupted is True + + with pytest.raises(asyncio.CancelledError): + await task + + completed_response = await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest(request_id=TEST_REQUEST_ID), + auth=(OWNER_USER_ID, "mock_username", False, "mock_token"), + registry=registry, + ) + assert completed_response.interrupted is False + + registry.deregister_stream(TEST_REQUEST_ID) + assert registry.get_stream(TEST_REQUEST_ID) is None + + with pytest.raises(HTTPException) as exc_info: + await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest(request_id=TEST_REQUEST_ID), + auth=(OWNER_USER_ID, "mock_username", False, "mock_token"), + registry=registry, + ) + assert exc_info.value.status_code == 404 diff --git a/tests/integration/test_openapi_json.py b/tests/integration/test_openapi_json.py index 53dbacc02..bf93f5b0e 100644 --- a/tests/integration/test_openapi_json.py +++ b/tests/integration/test_openapi_json.py @@ -223,6 +223,11 @@ def test_servers_section_present_from_url(spec_from_url: dict[str, Any]) -> None "post", {"200", "401", "403", "404", "422", "429", "500", "503"}, ), + ( + "/v1/streaming_query/interrupt", + "post", + {"200", "401", "403", "404"}, + ), ("/v1/config", "get", {"200", "401", "403", "500"}), ("/v1/feedback", "post", {"200", "401", "403", "404", "500"}), ("/v1/feedback/status", "get", {"200"}), @@ -305,6 +310,11 @@ def test_paths_and_responses_exist_from_file( "post", {"200", "401", "403", "404", "422", "429", "500", "503"}, ), + ( + "/v1/streaming_query/interrupt", + "post", + {"200", "401", "403", "404"}, + ), ("/v1/config", "get", {"200", "401", "403", "500"}), ("/v1/feedback", "post", {"200", "401", "403", "404", "500"}), ("/v1/feedback/status", "get", {"200"}), diff --git a/tests/unit/app/endpoints/test_stream_interrupt.py b/tests/unit/app/endpoints/test_stream_interrupt.py new file mode 100644 index 000000000..8a767ee36 --- /dev/null +++ b/tests/unit/app/endpoints/test_stream_interrupt.py @@ -0,0 +1,150 @@ +"""Unit tests for streaming query interrupt endpoint.""" + +import asyncio +from collections.abc import Generator + +import pytest +from fastapi import HTTPException + +from app.endpoints.stream_interrupt import stream_interrupt_endpoint_handler +from models.requests import StreamingInterruptRequest +from models.responses import StreamingInterruptResponse +from utils.stream_interrupts import StreamInterruptRegistry + +REQUEST_ID_SUCCESS = "123e4567-e89b-12d3-a456-426614174000" +REQUEST_ID_NOT_FOUND = "123e4567-e89b-12d3-a456-426614174001" +REQUEST_ID_WRONG_USER = "123e4567-e89b-12d3-a456-426614174002" +REQUEST_ID_ALREADY_COMPLETED = "123e4567-e89b-12d3-a456-426614174004" + +OWNER_USER_ID = "00000001-0001-0001-0001-000000000001" +NON_OWNER_USER_ID = "00000001-0001-0001-0001-000000000999" + +TEST_REQUEST_IDS = ( + REQUEST_ID_SUCCESS, + REQUEST_ID_NOT_FOUND, + REQUEST_ID_WRONG_USER, + REQUEST_ID_ALREADY_COMPLETED, +) + + +@pytest.fixture(name="registry") +def registry_fixture() -> Generator[StreamInterruptRegistry, None, None]: + """Provide singleton registry with deterministic per-test cleanup.""" + registry = StreamInterruptRegistry() + for request_id in TEST_REQUEST_IDS: + registry.deregister_stream(request_id) + yield registry + for request_id in TEST_REQUEST_IDS: + registry.deregister_stream(request_id) + + +@pytest.mark.asyncio +async def test_stream_interrupt_endpoint_success( + registry: StreamInterruptRegistry, +) -> None: + """Interrupt endpoint cancels an active stream for the same user.""" + + async def pending_stream() -> None: + await asyncio.sleep(10) + + task = asyncio.create_task(pending_stream()) + registry.register_stream(REQUEST_ID_SUCCESS, OWNER_USER_ID, task) + + response = await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest(request_id=REQUEST_ID_SUCCESS), + auth=(OWNER_USER_ID, "mock_username", False, "mock_token"), + registry=registry, + ) + + assert isinstance(response, StreamingInterruptResponse) + assert response.request_id == REQUEST_ID_SUCCESS + assert response.interrupted is True + + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_stream_interrupt_endpoint_not_found( + registry: StreamInterruptRegistry, +) -> None: + """Interrupt endpoint returns 404 for unknown request id.""" + with pytest.raises(HTTPException) as exc_info: + await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest( + request_id=REQUEST_ID_NOT_FOUND + ), + auth=( + OWNER_USER_ID, + "mock_username", + False, + "mock_token", + ), + registry=registry, + ) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_stream_interrupt_endpoint_wrong_user( + registry: StreamInterruptRegistry, +) -> None: + """Interrupt endpoint does not cancel streams owned by other users.""" + + async def pending_stream() -> None: + await asyncio.sleep(10) + + task = asyncio.create_task(pending_stream()) + registry.register_stream( + request_id=REQUEST_ID_WRONG_USER, + user_id=OWNER_USER_ID, + task=task, + ) + + with pytest.raises(HTTPException) as exc_info: + await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest( + request_id=REQUEST_ID_WRONG_USER + ), + auth=( + NON_OWNER_USER_ID, + "mock_username", + False, + "mock_token", + ), + registry=registry, + ) + + assert exc_info.value.status_code == 403 + assert task.done() is False + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_stream_interrupt_endpoint_already_completed( + registry: StreamInterruptRegistry, +) -> None: + """Interrupt endpoint reports already-completed streams without error.""" + + async def completed_stream() -> None: + return None + + task = asyncio.create_task(completed_stream()) + await task + registry.register_stream(REQUEST_ID_ALREADY_COMPLETED, OWNER_USER_ID, task) + + response = await stream_interrupt_endpoint_handler( + interrupt_request=StreamingInterruptRequest( + request_id=REQUEST_ID_ALREADY_COMPLETED + ), + auth=(OWNER_USER_ID, "mock_username", False, "mock_token"), + registry=registry, + ) + + assert isinstance(response, StreamingInterruptResponse) + assert response.request_id == REQUEST_ID_ALREADY_COMPLETED + assert response.interrupted is False diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 998a295ce..f72794ebf 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -2,6 +2,7 @@ """Unit tests for the /streaming_query (v2) endpoint using Responses API.""" # pylint: disable=too-many-lines,too-many-function-args +import asyncio import json from collections.abc import AsyncIterator from typing import Any @@ -50,6 +51,7 @@ from models.requests import Attachment, QueryRequest from models.responses import InternalServerErrorResponse from utils.token_counter import TokenCounter +from utils.stream_interrupts import StreamInterruptRegistry from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary MOCK_AUTH_STREAMING = ( @@ -1056,6 +1058,16 @@ async def test_retrieve_response_generator_runtime_error_other( class TestGenerateResponse: """Tests for generate_response function.""" + @pytest.fixture(autouse=True) + def isolate_stream_interrupt_registry(self, mocker: MockerFixture) -> Any: + """Patch registry accessor with a per-test mock registry instance.""" + test_registry = mocker.Mock(spec=StreamInterruptRegistry) + mocker.patch( + "app.endpoints.streaming_query.get_stream_interrupt_registry", + return_value=test_registry, + ) + return test_registry + @pytest.mark.asyncio async def test_generate_response_success(self, mocker: MockerFixture) -> None: """Test successful response generation.""" @@ -1074,6 +1086,7 @@ async def mock_generator() -> AsyncIterator[str]: ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_response_obj = mocker.Mock() mock_response_obj.output = [] @@ -1100,7 +1113,10 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) @@ -1127,6 +1143,7 @@ async def mock_generator() -> AsyncIterator[str]: ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) mock_responses_params = mocker.Mock(spec=ResponsesApiParams) @@ -1150,7 +1167,10 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) @@ -1170,11 +1190,13 @@ async def mock_generator() -> AsyncIterator[str]: mock_context.conversation_id = "conv_123" mock_context.vector_store_ids = [] mock_context.rag_id_mapping = {} + mock_context.user_id = "user_123" mock_context.query_request = QueryRequest( query="test" ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" @@ -1183,7 +1205,10 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) @@ -1207,11 +1232,13 @@ async def mock_generator() -> AsyncIterator[str]: mock_context.conversation_id = "conv_123" mock_context.vector_store_ids = [] mock_context.rag_id_mapping = {} + mock_context.user_id = "user_123" mock_context.query_request = QueryRequest( query="test" ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" @@ -1226,7 +1253,10 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) @@ -1247,9 +1277,11 @@ async def mock_generator() -> AsyncIterator[str]: mock_context.conversation_id = "conv_123" mock_context.vector_store_ids = [] mock_context.rag_id_mapping = {} + mock_context.user_id = "user_123" mock_context.query_request = QueryRequest( query="test", media_type=MEDIA_TYPE_JSON ) # pyright: ignore[reportCallIssue] + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" @@ -1268,7 +1300,10 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) @@ -1289,9 +1324,11 @@ async def mock_generator() -> AsyncIterator[str]: mock_context.conversation_id = "conv_123" mock_context.vector_store_ids = [] mock_context.rag_id_mapping = {} + mock_context.user_id = "user_123" mock_context.query_request = QueryRequest( query="test", media_type=MEDIA_TYPE_JSON ) # pyright: ignore[reportCallIssue] + mock_context.request_id = "123e4567-e89b-12d3-a456-426614174000" mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" @@ -1310,13 +1347,71 @@ async def mock_generator() -> AsyncIterator[str]: result = [] async for item in generate_response( - mock_generator(), mock_context, mock_responses_params, mock_turn_summary + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, ): result.append(item) assert len(result) > 0 assert any("error" in item for item in result) + @pytest.mark.asyncio + async def test_generate_response_cancelled_skips_side_effects( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test cancelled stream exits without quota consumption and persistence.""" + + async def mock_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + raise asyncio.CancelledError() + + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_123" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="test", media_type=MEDIA_TYPE_JSON + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + + mock_turn_summary = TurnSummary() + mock_turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) + + consume_query_tokens_mock = mocker.patch( + "app.endpoints.streaming_query.consume_query_tokens" + ) + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + + test_request_id = "123e4567-e89b-12d3-a456-426614174000" + mock_context.request_id = test_request_id + + result = [] + async for item in generate_response( + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + assert any("start" in item for item in result) + assert any('"event": "interrupted"' in item for item in result) + assert not any('"event": "end"' in item for item in result) + consume_query_tokens_mock.assert_not_called() + store_query_results_mock.assert_not_called() + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + class TestResponseGenerator: """Tests for response_generator function.""" @@ -1899,10 +1994,11 @@ class TestStreamStartEvent: # pylint: disable=too-few-public-methods def test_stream_start_event(self) -> None: """Test start event formatting.""" - result = stream_start_event("conv_123") + result = stream_start_event("conv_123", "123e4567-e89b-12d3-a456-426614174000") assert "start" in result assert "conv_123" in result + assert "123e4567-e89b-12d3-a456-426614174000" in result class TestShieldViolationGenerator: diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 776aa0472..d29f7d026 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -18,6 +18,7 @@ health, config, feedback, + stream_interrupt, streaming_query, authorized, metrics, @@ -106,7 +107,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 19 + assert len(app.routers) == 20 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -126,6 +127,7 @@ def test_include_routers() -> None: assert metrics.router in app.get_routers() assert rlsapi_v1.router in app.get_routers() assert a2a.router in app.get_routers() + assert stream_interrupt.router in app.get_routers() def test_check_prefixes() -> None: @@ -142,7 +144,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 19 + assert len(app.routers) == 20 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -163,3 +165,4 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(metrics.router) == "" assert app.get_router_prefix(rlsapi_v1.router) == "/v1" assert app.get_router_prefix(a2a.router) == "" + assert app.get_router_prefix(stream_interrupt.router) == "/v1" diff --git a/tests/unit/models/responses/test_error_responses.py b/tests/unit/models/responses/test_error_responses.py index aaf5047ab..3ecf441b8 100644 --- a/tests/unit/models/responses/test_error_responses.py +++ b/tests/unit/models/responses/test_error_responses.py @@ -472,13 +472,14 @@ def test_openapi_response(self) -> None: # Verify example count matches schema examples count assert len(examples) == expected_count - assert expected_count == 4 + assert expected_count == 5 # Verify all labeled examples are present assert "conversation" in examples assert "provider" in examples assert "model" in examples assert "rag" in examples + assert "streaming request" in examples # Verify example structure for one example conversation_example = examples["conversation"]