diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py
index 5b8aed1e..ac755735 100644
--- a/src/app/endpoints/a2a.py
+++ b/src/app/endpoints/a2a.py
@@ -42,6 +42,7 @@
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from constants import MEDIA_TYPE_EVENT_STREAM
 from models.config import Action
 from models.requests import QueryRequest
 from utils.mcp_headers import mcp_headers_dependency, McpHeaders
@@ -827,7 +828,7 @@ async def response_generator() -> AsyncIterator[bytes]:
         # Return streaming response with SSE content type for A2A protocol
         return StreamingResponse(
             response_generator(),
-            media_type="text/event-stream",
+            media_type=MEDIA_TYPE_EVENT_STREAM,
         )
 
     # Non-streaming mode: Buffer entire response
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index b6f29ecd..047278b1 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -34,6 +34,7 @@
     LLM_TOOL_CALL_EVENT,
     LLM_TOOL_RESULT_EVENT,
     LLM_TURN_COMPLETE_EVENT,
+    MEDIA_TYPE_EVENT_STREAM,
     MEDIA_TYPE_JSON,
     MEDIA_TYPE_TEXT,
 )
@@ -213,6 +214,12 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
         context=context,
     )
 
+    response_media_type = (
+        MEDIA_TYPE_TEXT
+        if query_request.media_type == MEDIA_TYPE_TEXT
+        else MEDIA_TYPE_EVENT_STREAM
+    )
+
     return StreamingResponse(
         generate_response(
             generator=generator,
@@ -220,7 +227,7 @@
             responses_params=responses_params,
             turn_summary=turn_summary,
         ),
-        media_type=query_request.media_type or MEDIA_TYPE_TEXT,
+        media_type=response_media_type,
     )
 
 
diff --git a/src/constants.py b/src/constants.py
index 1b7b6bee..628aeb45 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -132,6 +132,7 @@
 # Media type constants for streaming responses
 MEDIA_TYPE_JSON = "application/json"
 MEDIA_TYPE_TEXT = "text/plain"
+MEDIA_TYPE_EVENT_STREAM = "text/event-stream"
 
 # Streaming event type constants
 LLM_TOKEN_EVENT = "token"
diff --git a/src/models/responses.py b/src/models/responses.py
index c412256c..9b29be51 100644
--- a/src/models/responses.py
+++ b/src/models/responses.py
@@ -8,6 +8,7 @@
 from pydantic import BaseModel, Field
 from pydantic_core import SchemaError
 
+from constants import MEDIA_TYPE_EVENT_STREAM
 from models.config import Action, Configuration
 from quota.quota_exceed_error import QuotaExceedError
 from utils.types import RAGChunk, ReferencedDocument, ToolCallSummary, ToolResultSummary
@@ -485,8 +486,8 @@ def openapi_response(cls) -> dict[str, Any]:
             raise SchemaError(f"Examples not found in {cls.__name__}")
         example_value = model_examples[0]
         content = {
-            "text/event-stream": {
-                "schema": {"type": "string", "format": "text/event-stream"},
+            MEDIA_TYPE_EVENT_STREAM: {
+                "schema": {"type": "string", "format": MEDIA_TYPE_EVENT_STREAM},
                 "example": example_value,
             }
         }
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index e5742397..0f165c4f 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -390,6 +390,85 @@ async def mock_generate_response(
         )
 
         assert isinstance(response, StreamingResponse)
+        assert response.media_type == "text/event-stream"
+
+    @pytest.mark.asyncio
+    async def test_streaming_query_text_media_type_header(
+        self,
+        dummy_request: Request,  # pylint: disable=redefined-outer-name
+        setup_configuration: AppConfig,
+        mocker: MockerFixture,
+    ) -> None:
+        """Test streaming query uses plain text header when requested."""
+        query_request = QueryRequest(
+            query="What is Kubernetes?", media_type=MEDIA_TYPE_TEXT
+        )  # pyright: ignore[reportCallIssue]
+
+        mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)
+        mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
+        mocker.patch("app.endpoints.streaming_query.check_tokens_available")
+        mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
+
+        mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
+        mock_client_holder = mocker.Mock()
+        mock_client_holder.get_client.return_value = mock_client
+        mocker.patch(
+            "app.endpoints.streaming_query.AsyncLlamaStackClientHolder",
+            return_value=mock_client_holder,
+        )
+
+        mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
+        mock_responses_params.model = "provider1/model1"
+        mock_responses_params.conversation = "conv_123"
+        mock_responses_params.model_dump.return_value = {
+            "input": "test",
+            "model": "provider1/model1",
+        }
+        mocker.patch(
+            "app.endpoints.streaming_query.prepare_responses_params",
+            new=mocker.AsyncMock(return_value=mock_responses_params),
+        )
+
+        mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
+        mocker.patch(
+            "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
+            return_value=("provider1", "model1"),
+        )
+        mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+
+        async def mock_generator() -> AsyncIterator[str]:
+            yield "data: test\n\n"
+
+        mock_turn_summary = TurnSummary()
+        mocker.patch(
+            "app.endpoints.streaming_query.retrieve_response_generator",
+            return_value=(mock_generator(), mock_turn_summary),
+        )
+
+        async def mock_generate_response(
+            *_args: Any, **_kwargs: Any
+        ) -> AsyncIterator[str]:
+            async for item in mock_generator():
+                yield item
+
+        mocker.patch(
+            "app.endpoints.streaming_query.generate_response",
+            side_effect=mock_generate_response,
+        )
+        mocker.patch(
+            "app.endpoints.streaming_query.normalize_conversation_id",
+            return_value="123",
+        )
+
+        response = await streaming_query_endpoint_handler(
+            request=dummy_request,
+            query_request=query_request,
+            auth=MOCK_AUTH_STREAMING,
+            mcp_headers={},
+        )
+
+        assert isinstance(response, StreamingResponse)
+        assert response.media_type == MEDIA_TYPE_TEXT
 
     @pytest.mark.asyncio
     async def test_streaming_query_with_conversation(
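
Reviewer note on the behavioral change in streaming_query.py (not part of the patch itself): the old header expression `query_request.media_type or MEDIA_TYPE_TEXT` echoed `application/json` back as the Content-Type for JSON-mode requests even though the body is a Server-Sent Events stream, and fell back to `text/plain` when no media type was requested. The new rule keeps a plain-text header only for an explicit `text/plain` request. A minimal, self-contained sketch of that rule follows; the helper name `select_response_media_type` is hypothetical (the patch inlines the expression in the handler), and the constant values mirror src/constants.py.

# Sketch of the media-type selection introduced by this patch.
# The helper name is hypothetical; constants mirror src/constants.py.
MEDIA_TYPE_JSON = "application/json"
MEDIA_TYPE_TEXT = "text/plain"
MEDIA_TYPE_EVENT_STREAM = "text/event-stream"


def select_response_media_type(requested: str | None) -> str:
    """Pick the streaming response's Content-Type header.

    Only an explicit text/plain request keeps a plain-text header;
    everything else (application/json or no preference) is labeled
    as Server-Sent Events.
    """
    return MEDIA_TYPE_TEXT if requested == MEDIA_TYPE_TEXT else MEDIA_TYPE_EVENT_STREAM


# Explicit text/plain is preserved, matching the new unit test.
assert select_response_media_type(MEDIA_TYPE_TEXT) == MEDIA_TYPE_TEXT
# JSON-mode requests now get text/event-stream; the old expression
# would have echoed "application/json" here.
assert select_response_media_type(MEDIA_TYPE_JSON) == MEDIA_TYPE_EVENT_STREAM
# No requested media type now defaults to SSE instead of text/plain.
assert select_response_media_type(None) == MEDIA_TYPE_EVENT_STREAM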