diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 407c42013..954828f23 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -5,7 +5,7 @@ """ import time -from typing import Annotated, Any, cast +from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from llama_stack_api.openai_responses import OpenAIResponseObject @@ -50,6 +50,15 @@ # Default values when RH Identity auth is not configured AUTH_DISABLED = "auth_disabled" +# Keep this tuple centralized so infer_endpoint can catch all expected backend +# failures in one place while preserving a single telemetry/error-mapping path. +_INFER_HANDLED_EXCEPTIONS = ( + RuntimeError, + APIConnectionError, + RateLimitError, + APIStatusError, + OpenAIAPIStatusError, +) def _get_rh_identity_context(request: Request) -> tuple[str, str]: @@ -66,7 +75,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]: Tuple of (org_id, system_id). Returns ("auth_disabled", "auth_disabled") when RH Identity auth is not configured or data is unavailable. """ - rh_identity: RHIdentityData | None = getattr( + rh_identity: Optional[RHIdentityData] = getattr( request.state, "rh_identity_data", None ) if rh_identity is None: @@ -103,13 +112,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str: Returns: Instructions string for the LLM, with system context if available. """ - if ( - configuration.customization is not None - and configuration.customization.system_prompt is not None - ): - base_prompt = configuration.customization.system_prompt - else: - base_prompt = constants.DEFAULT_SYSTEM_PROMPT + base_prompt = _get_base_prompt() context_parts = [] if systeminfo.os: @@ -126,6 +129,16 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str: return f"{base_prompt}\n\nUser's system: {system_context}" +def _get_base_prompt() -> str: + """Get the base system prompt with configuration fallback.""" + if ( + configuration.customization is not None + and configuration.customization.system_prompt is not None + ): + return configuration.customization.system_prompt + return constants.DEFAULT_SYSTEM_PROMPT + + def _get_default_model_id() -> str: """Get the default model ID from configuration. @@ -162,7 +175,10 @@ def _get_default_model_id() -> str: async def retrieve_simple_response( - question: str, instructions: str, tools: list | None = None + question: str, + instructions: str, + tools: Optional[list[Any]] = None, + model_id: Optional[str] = None, ) -> str: """Retrieve a simple response from the LLM for a stateless query. @@ -173,29 +189,30 @@ async def retrieve_simple_response( question: The combined user input (question + context). instructions: System instructions for the LLM. tools: Optional list of MCP tool definitions for the LLM. + model_id: Fully qualified model identifier in provider/model format. + When omitted, the configured default model is used. Returns: The LLM-generated response text. Raises: APIConnectionError: If the Llama Stack service is unreachable. - HTTPException: 503 if no model is configured. + HTTPException: 503 if no default model is configured. 
""" client = AsyncLlamaStackClientHolder().get_client() - model_id = _get_default_model_id() - - logger.debug("Using model %s for rlsapi v1 inference", model_id) + resolved_model_id = model_id or _get_default_model_id() + logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id) response = await client.responses.create( input=question, - model=model_id, + model=resolved_model_id, instructions=instructions, tools=tools or [], stream=False, store=False, ) response = cast(OpenAIResponseObject, response) - extract_token_usage(response.usage, model_id) + extract_token_usage(response.usage, resolved_model_id) return extract_text_from_response_items(response.output) @@ -205,6 +222,13 @@ def _get_cla_version(request: Request) -> str: return request.headers.get("User-Agent", "") +def _get_configured_default_model_name() -> str: + """Get configured default model name for telemetry payloads.""" + if configuration.inference is None: + return "" + return configuration.inference.default_model or "" + + def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments background_tasks: BackgroundTasks, infer_request: RlsapiV1InferRequest, @@ -222,11 +246,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position question=infer_request.question, response=response_text, inference_time=inference_time, - model=( - (configuration.inference.default_model or "") - if configuration.inference - else "" - ), + model=_get_configured_default_model_name(), org_id=org_id, system_id=system_id, request_id=request_id, @@ -277,6 +297,50 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po return inference_time +def _map_inference_error_to_http_exception( + error: Exception, model_id: str, request_id: str +) -> Optional[HTTPException]: + """Map known inference errors to HTTPException. + + Returns None for RuntimeError values that are not context-length related, + so callers can preserve existing re-raise behavior for unknown runtime + errors. 
+ """ + if isinstance(error, RuntimeError): + error_message = str(error).lower() + if "context_length" in error_message or "context length" in error_message: + logger.error("Prompt too long for request %s: %s", request_id, error) + error_response = PromptTooLongResponse(model=model_id) + return HTTPException(**error_response.model_dump()) + logger.error("Unexpected RuntimeError for request %s: %s", request_id, error) + return None + + if isinstance(error, APIConnectionError): + logger.error( + "Unable to connect to Llama Stack for request %s: %s", request_id, error + ) + error_response = ServiceUnavailableResponse( + backend_name="Llama Stack", + cause="Unable to connect to the inference backend", + ) + return HTTPException(**error_response.model_dump()) + + if isinstance(error, RateLimitError): + logger.error("Rate limit exceeded for request %s: %s", request_id, error) + error_response = QuotaExceededResponse( + response="The quota has been exceeded", + cause="Rate limit exceeded, please try again later", + ) + return HTTPException(**error_response.model_dump()) + + if isinstance(error, (APIStatusError, OpenAIAPIStatusError)): + logger.exception("API error for request %s: %s", request_id, error) + error_response = handle_known_apistatus_errors(error, model_id) + return HTTPException(**error_response.model_dump()) + + return None + + @router.post("/infer", responses=infer_responses) @authorize(Action.RLSAPI_V1_INFER) async def infer_endpoint( # pylint: disable=R0914 @@ -315,7 +379,7 @@ async def infer_endpoint( # pylint: disable=R0914 instructions = _build_instructions(infer_request.context.systeminfo) model_id = _get_default_model_id() provider, model = extract_provider_and_model_from_model_id(model_id) - mcp_tools = await get_mcp_tools(request_headers=request.headers) + mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers) logger.debug( "Request %s: Combined input source length: %d", request_id, len(input_source) ) @@ -323,86 +387,31 @@ async def infer_endpoint( # pylint: disable=R0914 start_time = time.monotonic() try: response_text = await retrieve_simple_response( - input_source, instructions, tools=mcp_tools + input_source, + instructions, + tools=cast(list[Any], mcp_tools), + model_id=model_id, ) inference_time = time.monotonic() - start_time - except RuntimeError as e: - if "context_length" in str(e).lower(): - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - model, - provider, - ) - logger.error("Prompt too long for request %s: %s", request_id, e) - error_response = PromptTooLongResponse(model=model_id) - raise HTTPException(**error_response.model_dump()) from e + except _INFER_HANDLED_EXCEPTIONS as error: _record_inference_failure( background_tasks, infer_request, request, request_id, - e, + error, start_time, model, provider, ) - logger.error("Unexpected RuntimeError for request %s: %s", request_id, e) - raise - except APIConnectionError as e: - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - model, - provider, - ) - logger.error( - "Unable to connect to Llama Stack for request %s: %s", request_id, e - ) - error_response = ServiceUnavailableResponse( - backend_name="Llama Stack", - cause="Unable to connect to the inference backend", - ) - raise HTTPException(**error_response.model_dump()) from e - except RateLimitError as e: - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - 
model,
-            provider,
-        )
-        logger.error(
-            "Unable to connect to Llama Stack for request %s: %s", request_id, e
-        )
-        error_response = ServiceUnavailableResponse(
-            backend_name="Llama Stack",
-            cause="Unable to connect to the inference backend",
-        )
-        raise HTTPException(**error_response.model_dump()) from e
-    except RateLimitError as e:
-        _record_inference_failure(
-            background_tasks,
-            infer_request,
-            request,
-            request_id,
-            e,
-            start_time,
-            model,
-            provider,
-        )
-        logger.error("Rate limit exceeded for request %s: %s", request_id, e)
-        error_response = QuotaExceededResponse(
-            response="The quota has been exceeded",
-            cause="Rate limit exceeded, please try again later",
-        )
-        raise HTTPException(**error_response.model_dump()) from e
-    except (APIStatusError, OpenAIAPIStatusError) as e:
-        _record_inference_failure(
-            background_tasks,
-            infer_request,
-            request,
+        mapped_error = _map_inference_error_to_http_exception(
+            error,
+            model_id,
             request_id,
-            e,
-            start_time,
-            model,
-            provider,
         )
-        logger.exception("API error for request %s: %s", request_id, e)
-        error_response = handle_known_apistatus_errors(e, model_id)
-        raise HTTPException(**error_response.model_dump()) from e
+        if mapped_error is not None:
+            raise mapped_error from error
+        raise
 
     if not response_text:
         logger.warning("Empty response from LLM for request %s", request_id)
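
A minimal sketch of how the extracted _map_inference_error_to_http_exception helper could be exercised in a unit test, assuming pytest and that the module is importable as app.endpoints.rlsapi_v1 under the src/ layout shown above; the test names, error strings, and model/request identifiers are illustrative only, not part of this diff.

# Illustrative sketch only: covers the context-length mapping path and the
# unknown-RuntimeError passthrough introduced by this change. Assumes pytest
# and the app.endpoints.rlsapi_v1 import path (from src/app/endpoints/rlsapi_v1.py).
from fastapi import HTTPException

from app.endpoints import rlsapi_v1


def test_context_length_runtime_error_is_mapped() -> None:
    # A RuntimeError mentioning context_length maps to an HTTPException built
    # from PromptTooLongResponse.
    exc = rlsapi_v1._map_inference_error_to_http_exception(
        RuntimeError("model context_length exceeded"),
        model_id="provider/model",
        request_id="req-1",
    )
    assert isinstance(exc, HTTPException)


def test_unknown_runtime_error_is_not_mapped() -> None:
    # Any other RuntimeError returns None so infer_endpoint re-raises it as-is.
    exc = rlsapi_v1._map_inference_error_to_http_exception(
        RuntimeError("unexpected failure"),
        model_id="provider/model",
        request_id="req-2",
    )
    assert exc is None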