3 changes: 2 additions & 1 deletion src/app/endpoints/a2a.py
@@ -42,6 +42,7 @@
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from constants import MEDIA_TYPE_EVENT_STREAM
 from models.config import Action
 from models.requests import QueryRequest
 from utils.mcp_headers import mcp_headers_dependency, McpHeaders
@@ -827,7 +828,7 @@ async def response_generator() -> AsyncIterator[bytes]:
     # Return streaming response with SSE content type for A2A protocol
     return StreamingResponse(
         response_generator(),
-        media_type="text/event-stream",
+        media_type=MEDIA_TYPE_EVENT_STREAM,
     )
 
     # Non-streaming mode: Buffer entire response
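Note: with the endpoint now advertising `text/event-stream`, any SSE-aware client can consume the A2A stream. A minimal sketch using httpx (the URL and request payload are hypothetical; only the media type comes from this diff):

```python
import httpx


async def read_sse(url: str) -> None:
    """Print each SSE data event from a streaming endpoint."""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json={"query": "hello"}) as response:
            # The endpoint now declares the SSE media type explicitly.
            assert response.headers["content-type"].startswith("text/event-stream")
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line.removeprefix("data: "))


# Usage: asyncio.run(read_sse("http://localhost:8080/a2a"))  # hypothetical URL
```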
9 changes: 8 additions & 1 deletion src/app/endpoints/streaming_query.py
@@ -34,6 +34,7 @@
     LLM_TOOL_CALL_EVENT,
     LLM_TOOL_RESULT_EVENT,
     LLM_TURN_COMPLETE_EVENT,
+    MEDIA_TYPE_EVENT_STREAM,
     MEDIA_TYPE_JSON,
     MEDIA_TYPE_TEXT,
 )
@@ -213,14 +214,20 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
         context=context,
     )
 
+    response_media_type = (
+        MEDIA_TYPE_TEXT
+        if query_request.media_type == MEDIA_TYPE_TEXT
+        else MEDIA_TYPE_EVENT_STREAM
+    )
+
     return StreamingResponse(
         generate_response(
             generator=generator,
             context=context,
             responses_params=responses_params,
             turn_summary=turn_summary,
         ),
-        media_type=query_request.media_type or MEDIA_TYPE_TEXT,
+        media_type=response_media_type,
     )


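Note: the behavioral change here is subtle. The old expression `query_request.media_type or MEDIA_TYPE_TEXT` fell back to `text/plain` whenever the client omitted `media_type`; the new conditional defaults to `text/event-stream` and serves plain text only when it is explicitly requested. A minimal sketch of the selection logic (the standalone function is illustrative, not the module's actual layout):

```python
MEDIA_TYPE_TEXT = "text/plain"
MEDIA_TYPE_EVENT_STREAM = "text/event-stream"


def select_media_type(requested: str | None) -> str:
    """Default to SSE; use plain text only when explicitly requested."""
    return MEDIA_TYPE_TEXT if requested == MEDIA_TYPE_TEXT else MEDIA_TYPE_EVENT_STREAM


# Old behavior: `requested or MEDIA_TYPE_TEXT` yielded "text/plain" for None.
assert select_media_type(None) == MEDIA_TYPE_EVENT_STREAM
assert select_media_type(MEDIA_TYPE_TEXT) == MEDIA_TYPE_TEXT
assert select_media_type("application/json") == MEDIA_TYPE_EVENT_STREAM
```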
1 change: 1 addition & 0 deletions src/constants.py
@@ -132,6 +132,7 @@
 # Media type constants for streaming responses
 MEDIA_TYPE_JSON = "application/json"
 MEDIA_TYPE_TEXT = "text/plain"
+MEDIA_TYPE_EVENT_STREAM = "text/event-stream"
 
 # Streaming event type constants
 LLM_TOKEN_EVENT = "token"
5 changes: 3 additions & 2 deletions src/models/responses.py
@@ -8,6 +8,7 @@
 from pydantic import BaseModel, Field
 from pydantic_core import SchemaError
 
+from constants import MEDIA_TYPE_EVENT_STREAM
 from models.config import Action, Configuration
 from quota.quota_exceed_error import QuotaExceedError
 from utils.types import RAGChunk, ReferencedDocument, ToolCallSummary, ToolResultSummary
@@ -485,8 +486,8 @@ def openapi_response(cls) -> dict[str, Any]:
             raise SchemaError(f"Examples not found in {cls.__name__}")
         example_value = model_examples[0]
         content = {
-            "text/event-stream": {
-                "schema": {"type": "string", "format": "text/event-stream"},
+            MEDIA_TYPE_EVENT_STREAM: {
+                "schema": {"type": "string", "format": MEDIA_TYPE_EVENT_STREAM},
                 "example": example_value,
             }
         }
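Note: keying the OpenAPI content dict by the shared constant keeps the media-type key and the schema's `format` field from drifting apart. A sketch of the structure `openapi_response` now produces (the example value is hypothetical):

```python
MEDIA_TYPE_EVENT_STREAM = "text/event-stream"

example_value = 'data: {"event": "token", "data": {"token": "Hello"}}\n\n'  # hypothetical

content = {
    MEDIA_TYPE_EVENT_STREAM: {
        "schema": {"type": "string", "format": MEDIA_TYPE_EVENT_STREAM},
        "example": example_value,
    }
}
```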
79 changes: 79 additions & 0 deletions tests/unit/app/endpoints/test_streaming_query.py
@@ -390,6 +390,85 @@ async def mock_generate_response(
         )
 
         assert isinstance(response, StreamingResponse)
+        assert response.media_type == "text/event-stream"
+
+    @pytest.mark.asyncio
+    async def test_streaming_query_text_media_type_header(
+        self,
+        dummy_request: Request,  # pylint: disable=redefined-outer-name
+        setup_configuration: AppConfig,
+        mocker: MockerFixture,
+    ) -> None:
+        """Test streaming query uses plain text header when requested."""
+        query_request = QueryRequest(
+            query="What is Kubernetes?", media_type=MEDIA_TYPE_TEXT
+        )  # pyright: ignore[reportCallIssue]
+
+        mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)
+        mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
+        mocker.patch("app.endpoints.streaming_query.check_tokens_available")
+        mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
+
+        mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
+        mock_client_holder = mocker.Mock()
+        mock_client_holder.get_client.return_value = mock_client
+        mocker.patch(
+            "app.endpoints.streaming_query.AsyncLlamaStackClientHolder",
+            return_value=mock_client_holder,
+        )
+
+        mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
+        mock_responses_params.model = "provider1/model1"
+        mock_responses_params.conversation = "conv_123"
+        mock_responses_params.model_dump.return_value = {
+            "input": "test",
+            "model": "provider1/model1",
+        }
+        mocker.patch(
+            "app.endpoints.streaming_query.prepare_responses_params",
+            new=mocker.AsyncMock(return_value=mock_responses_params),
+        )
+
+        mocker.patch("app.endpoints.streaming_query.AzureEntraIDManager")
+        mocker.patch(
+            "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
+            return_value=("provider1", "model1"),
+        )
+        mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+
+        async def mock_generator() -> AsyncIterator[str]:
+            yield "data: test\n\n"
+
+        mock_turn_summary = TurnSummary()
+        mocker.patch(
+            "app.endpoints.streaming_query.retrieve_response_generator",
+            return_value=(mock_generator(), mock_turn_summary),
+        )
+
+        async def mock_generate_response(
+            *_args: Any, **_kwargs: Any
+        ) -> AsyncIterator[str]:
+            async for item in mock_generator():
+                yield item
+
+        mocker.patch(
+            "app.endpoints.streaming_query.generate_response",
+            side_effect=mock_generate_response,
+        )
+        mocker.patch(
+            "app.endpoints.streaming_query.normalize_conversation_id",
+            return_value="123",
+        )
+
+        response = await streaming_query_endpoint_handler(
+            request=dummy_request,
+            query_request=query_request,
+            auth=MOCK_AUTH_STREAMING,
+            mcp_headers={},
+        )
+
+        assert isinstance(response, StreamingResponse)
+        assert response.media_type == MEDIA_TYPE_TEXT
 
     @pytest.mark.asyncio
     async def test_streaming_query_with_conversation(