187 changes: 98 additions & 89 deletions src/app/endpoints/rlsapi_v1.py
@@ -5,7 +5,7 @@
"""

import time
from typing import Annotated, Any, cast
from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from llama_stack_api.openai_responses import OpenAIResponseObject
@@ -50,6 +50,15 @@

# Default values when RH Identity auth is not configured
AUTH_DISABLED = "auth_disabled"
# Keep this tuple centralized so infer_endpoint can catch all expected backend
# failures in one place while preserving a single telemetry/error-mapping path.
_INFER_HANDLED_EXCEPTIONS = (
RuntimeError,
APIConnectionError,
RateLimitError,
APIStatusError,
OpenAIAPIStatusError,
)
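A minimal caller sketch of the pattern this tuple enables: one except clause covers every expected backend failure, and a single helper decides whether to convert the error or re-raise. The names `do_inference`, `model_id`, and `request_id` are placeholders, not identifiers from this diff.

```python
# Sketch only: do_inference, model_id and request_id are illustrative placeholders.
try:
    response_text = await do_inference()
except _INFER_HANDLED_EXCEPTIONS as error:
    mapped = _map_inference_error_to_http_exception(error, model_id, request_id)
    if mapped is not None:
        raise mapped from error
    raise  # unknown RuntimeError: preserve the original traceback
```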


def _get_rh_identity_context(request: Request) -> tuple[str, str]:
@@ -66,7 +75,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
Tuple of (org_id, system_id). Returns ("auth_disabled", "auth_disabled")
when RH Identity auth is not configured or data is unavailable.
"""
rh_identity: RHIdentityData | None = getattr(
rh_identity: Optional[RHIdentityData] = getattr(
request.state, "rh_identity_data", None
)
if rh_identity is None:
@@ -103,13 +112,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
Returns:
Instructions string for the LLM, with system context if available.
"""
if (
configuration.customization is not None
and configuration.customization.system_prompt is not None
):
base_prompt = configuration.customization.system_prompt
else:
base_prompt = constants.DEFAULT_SYSTEM_PROMPT
base_prompt = _get_base_prompt()

context_parts = []
if systeminfo.os:
@@ -126,6 +129,16 @@
return f"{base_prompt}\n\nUser's system: {system_context}"


def _get_base_prompt() -> str:
"""Get the base system prompt with configuration fallback."""
if (
configuration.customization is not None
and configuration.customization.system_prompt is not None
):
return configuration.customization.system_prompt
return constants.DEFAULT_SYSTEM_PROMPT


def _get_default_model_id() -> str:
"""Get the default model ID from configuration.

@@ -162,7 +175,10 @@ def _get_default_model_id() -> str:


async def retrieve_simple_response(
question: str, instructions: str, tools: list | None = None
question: str,
instructions: str,
tools: Optional[list[Any]] = None,
model_id: Optional[str] = None,
) -> str:
"""Retrieve a simple response from the LLM for a stateless query.

@@ -173,29 +189,30 @@
question: The combined user input (question + context).
instructions: System instructions for the LLM.
tools: Optional list of MCP tool definitions for the LLM.
model_id: Fully qualified model identifier in provider/model format.
When omitted, the configured default model is used.

Returns:
The LLM-generated response text.

Raises:
APIConnectionError: If the Llama Stack service is unreachable.
HTTPException: 503 if no model is configured.
HTTPException: 503 if no default model is configured.
"""
client = AsyncLlamaStackClientHolder().get_client()
model_id = _get_default_model_id()

logger.debug("Using model %s for rlsapi v1 inference", model_id)
resolved_model_id = model_id or _get_default_model_id()
logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id)

response = await client.responses.create(
input=question,
model=model_id,
model=resolved_model_id,
instructions=instructions,
tools=tools or [],
stream=False,
store=False,
)
response = cast(OpenAIResponseObject, response)
extract_token_usage(response.usage, model_id)
extract_token_usage(response.usage, resolved_model_id)

return extract_text_from_response_items(response.output)
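A hedged usage sketch of the widened signature: the question, instructions, and `provider/model` string below are illustrative values, not part of this change.

```python
# Explicit model: useful when the caller has already resolved provider/model.
text = await retrieve_simple_response(
    question="How do I check the SELinux mode?",
    instructions="You are a helpful assistant.",
    model_id="vllm/granite-3-8b-instruct",  # hypothetical identifier
)

# Omitting model_id (and tools) falls back to the configured default model.
text = await retrieve_simple_response(
    question="How do I check the SELinux mode?",
    instructions="You are a helpful assistant.",
)
```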

@@ -205,6 +222,13 @@ def _get_cla_version(request: Request) -> str:
return request.headers.get("User-Agent", "")


def _get_configured_default_model_name() -> str:
"""Get configured default model name for telemetry payloads."""
if configuration.inference is None:
return ""
return configuration.inference.default_model or ""


def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments
background_tasks: BackgroundTasks,
infer_request: RlsapiV1InferRequest,
@@ -222,11 +246,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
question=infer_request.question,
response=response_text,
inference_time=inference_time,
model=(
(configuration.inference.default_model or "")
if configuration.inference
else ""
),
model=_get_configured_default_model_name(),
org_id=org_id,
system_id=system_id,
request_id=request_id,
@@ -277,6 +297,50 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
return inference_time


def _map_inference_error_to_http_exception(
error: Exception, model_id: str, request_id: str
) -> Optional[HTTPException]:
"""Map known inference errors to HTTPException.

Returns None for RuntimeError values that are not context-length related,
so callers can preserve existing re-raise behavior for unknown runtime
errors.
"""
if isinstance(error, RuntimeError):
error_message = str(error).lower()
if "context_length" in error_message or "context length" in error_message:
logger.error("Prompt too long for request %s: %s", request_id, error)
error_response = PromptTooLongResponse(model=model_id)
return HTTPException(**error_response.model_dump())
logger.error("Unexpected RuntimeError for request %s: %s", request_id, error)
return None

if isinstance(error, APIConnectionError):
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, error
)
error_response = ServiceUnavailableResponse(
backend_name="Llama Stack",
cause="Unable to connect to the inference backend",
)
return HTTPException(**error_response.model_dump())

if isinstance(error, RateLimitError):
logger.error("Rate limit exceeded for request %s: %s", request_id, error)
error_response = QuotaExceededResponse(
response="The quota has been exceeded",
cause="Rate limit exceeded, please try again later",
)
return HTTPException(**error_response.model_dump())

if isinstance(error, (APIStatusError, OpenAIAPIStatusError)):
logger.exception("API error for request %s: %s", request_id, error)
error_response = handle_known_apistatus_errors(error, model_id)
return HTTPException(**error_response.model_dump())

return None
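A minimal test sketch for the two RuntimeError branches, assuming a pytest-style layout and that the module is importable as `app.endpoints.rlsapi_v1` (the import path is a guess): context-length errors map to an HTTPException, other RuntimeErrors return None so the caller re-raises.

```python
from fastapi import HTTPException

# Import path is an assumption; adjust to the project's actual layout.
from app.endpoints.rlsapi_v1 import _map_inference_error_to_http_exception


def test_context_length_runtime_error_is_mapped() -> None:
    error = RuntimeError("requested tokens exceed context_length")
    mapped = _map_inference_error_to_http_exception(error, "provider/model", "req-1")
    assert isinstance(mapped, HTTPException)


def test_unknown_runtime_error_is_not_mapped() -> None:
    error = RuntimeError("some other backend failure")
    assert _map_inference_error_to_http_exception(error, "provider/model", "req-1") is None
```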


@router.post("/infer", responses=infer_responses)
@authorize(Action.RLSAPI_V1_INFER)
async def infer_endpoint( # pylint: disable=R0914
@@ -315,94 +379,39 @@ async def infer_endpoint( # pylint: disable=R0914
instructions = _build_instructions(infer_request.context.systeminfo)
model_id = _get_default_model_id()
provider, model = extract_provider_and_model_from_model_id(model_id)
mcp_tools = await get_mcp_tools(request_headers=request.headers)
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
)

start_time = time.monotonic()
try:
response_text = await retrieve_simple_response(
input_source, instructions, tools=mcp_tools
input_source,
instructions,
tools=cast(list[Any], mcp_tools),
model_id=model_id,
)
inference_time = time.monotonic() - start_time
except RuntimeError as e:
if "context_length" in str(e).lower():
_record_inference_failure(
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Prompt too long for request %s: %s", request_id, e)
error_response = PromptTooLongResponse(model=model_id)
raise HTTPException(**error_response.model_dump()) from e
except _INFER_HANDLED_EXCEPTIONS as error:
_record_inference_failure(
background_tasks,
infer_request,
request,
request_id,
e,
error,
start_time,
model,
provider,
)
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
raise
except APIConnectionError as e:
_record_inference_failure(
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, e
)
error_response = ServiceUnavailableResponse(
backend_name="Llama Stack",
cause="Unable to connect to the inference backend",
)
raise HTTPException(**error_response.model_dump()) from e
except RateLimitError as e:
_record_inference_failure(
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
error_response = QuotaExceededResponse(
response="The quota has been exceeded",
cause="Rate limit exceeded, please try again later",
)
raise HTTPException(**error_response.model_dump()) from e
except (APIStatusError, OpenAIAPIStatusError) as e:
_record_inference_failure(
background_tasks,
infer_request,
request,
mapped_error = _map_inference_error_to_http_exception(
error,
model_id,
request_id,
e,
start_time,
model,
provider,
)
logger.exception("API error for request %s: %s", request_id, e)
error_response = handle_known_apistatus_errors(e, model_id)
raise HTTPException(**error_response.model_dump()) from e
if mapped_error is not None:
raise mapped_error from error
raise

if not response_text:
logger.warning("Empty response from LLM for request %s", request_id)