From 8e6f0a8a6a554ac4dd9ecb4fcca9d590ef7fd686 Mon Sep 17 00:00:00 2001 From: Major Hayden Date: Mon, 2 Mar 2026 08:30:11 -0600 Subject: [PATCH 1/2] refactor(rlsapi): simplify infer flow and error mapping Signed-off-by: Major Hayden --- src/app/endpoints/rlsapi_v1.py | 175 +++++++++++++++++---------------- 1 file changed, 90 insertions(+), 85 deletions(-) diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 407c42013..19d66d2b0 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -103,13 +103,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str: Returns: Instructions string for the LLM, with system context if available. """ - if ( - configuration.customization is not None - and configuration.customization.system_prompt is not None - ): - base_prompt = configuration.customization.system_prompt - else: - base_prompt = constants.DEFAULT_SYSTEM_PROMPT + base_prompt = _get_base_prompt() context_parts = [] if systeminfo.os: @@ -126,6 +120,16 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str: return f"{base_prompt}\n\nUser's system: {system_context}" +def _get_base_prompt() -> str: + """Get the base system prompt with configuration fallback.""" + if ( + configuration.customization is not None + and configuration.customization.system_prompt is not None + ): + return configuration.customization.system_prompt + return constants.DEFAULT_SYSTEM_PROMPT + + def _get_default_model_id() -> str: """Get the default model ID from configuration. @@ -162,7 +166,10 @@ def _get_default_model_id() -> str: async def retrieve_simple_response( - question: str, instructions: str, tools: list | None = None + question: str, + instructions: str, + tools: list[Any] | None = None, + model_id: str | None = None, ) -> str: """Retrieve a simple response from the LLM for a stateless query. @@ -173,22 +180,23 @@ async def retrieve_simple_response( question: The combined user input (question + context). instructions: System instructions for the LLM. tools: Optional list of MCP tool definitions for the LLM. + model_id: Fully qualified model identifier in provider/model format. + When omitted, the configured default model is used. Returns: The LLM-generated response text. Raises: APIConnectionError: If the Llama Stack service is unreachable. - HTTPException: 503 if no model is configured. + HTTPException: 503 if no default model is configured. """ client = AsyncLlamaStackClientHolder().get_client() - model_id = _get_default_model_id() - - logger.debug("Using model %s for rlsapi v1 inference", model_id) + resolved_model_id = model_id or _get_default_model_id() + logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id) response = await client.responses.create( input=question, - model=model_id, + model=resolved_model_id, instructions=instructions, tools=tools or [], stream=False, @@ -205,6 +213,13 @@ def _get_cla_version(request: Request) -> str: return request.headers.get("User-Agent", "") +def _get_configured_default_model_name() -> str: + """Get configured default model name for telemetry payloads.""" + if configuration.inference is None: + return "" + return configuration.inference.default_model or "" + + def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments background_tasks: BackgroundTasks, infer_request: RlsapiV1InferRequest, @@ -222,11 +237,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position question=infer_request.question, response=response_text, inference_time=inference_time, - model=( - (configuration.inference.default_model or "") - if configuration.inference - else "" - ), + model=_get_configured_default_model_name(), org_id=org_id, system_id=system_id, request_id=request_id, @@ -277,6 +288,49 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po return inference_time +def _map_inference_error_to_http_exception( + error: Exception, model_id: str, request_id: str +) -> HTTPException | None: + """Map known inference errors to HTTPException. + + Returns None for RuntimeError values that are not context-length related, + so callers can preserve existing re-raise behavior for unknown runtime + errors. + """ + if isinstance(error, RuntimeError): + if "context_length" in str(error).lower(): + logger.error("Prompt too long for request %s: %s", request_id, error) + error_response = PromptTooLongResponse(model=model_id) + return HTTPException(**error_response.model_dump()) + logger.error("Unexpected RuntimeError for request %s: %s", request_id, error) + return None + + if isinstance(error, APIConnectionError): + logger.error( + "Unable to connect to Llama Stack for request %s: %s", request_id, error + ) + error_response = ServiceUnavailableResponse( + backend_name="Llama Stack", + cause="Unable to connect to the inference backend", + ) + return HTTPException(**error_response.model_dump()) + + if isinstance(error, RateLimitError): + logger.error("Rate limit exceeded for request %s: %s", request_id, error) + error_response = QuotaExceededResponse( + response="The quota has been exceeded", + cause="Rate limit exceeded, please try again later", + ) + return HTTPException(**error_response.model_dump()) + + if isinstance(error, (APIStatusError, OpenAIAPIStatusError)): + logger.exception("API error for request %s: %s", request_id, error) + error_response = handle_known_apistatus_errors(error, model_id) + return HTTPException(**error_response.model_dump()) + + return None + + @router.post("/infer", responses=infer_responses) @authorize(Action.RLSAPI_V1_INFER) async def infer_endpoint( # pylint: disable=R0914 @@ -323,86 +377,37 @@ async def infer_endpoint( # pylint: disable=R0914 start_time = time.monotonic() try: response_text = await retrieve_simple_response( - input_source, instructions, tools=mcp_tools + input_source, + instructions, + tools=cast(list[Any], mcp_tools), + model_id=model_id, ) inference_time = time.monotonic() - start_time - except RuntimeError as e: - if "context_length" in str(e).lower(): - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - model, - provider, - ) - logger.error("Prompt too long for request %s: %s", request_id, e) - error_response = PromptTooLongResponse(model=model_id) - raise HTTPException(**error_response.model_dump()) from e - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - model, - provider, - ) - logger.error("Unexpected RuntimeError for request %s: %s", request_id, e) - raise - except APIConnectionError as e: - _record_inference_failure( - background_tasks, - infer_request, - request, - request_id, - e, - start_time, - model, - provider, - ) - logger.error( - "Unable to connect to Llama Stack for request %s: %s", request_id, e - ) - error_response = ServiceUnavailableResponse( - backend_name="Llama Stack", - cause="Unable to connect to the inference backend", - ) - raise HTTPException(**error_response.model_dump()) from e - except RateLimitError as e: + except ( + RuntimeError, + APIConnectionError, + RateLimitError, + APIStatusError, + OpenAIAPIStatusError, + ) as error: _record_inference_failure( background_tasks, infer_request, request, request_id, - e, + error, start_time, model, provider, ) - logger.error("Rate limit exceeded for request %s: %s", request_id, e) - error_response = QuotaExceededResponse( - response="The quota has been exceeded", - cause="Rate limit exceeded, please try again later", - ) - raise HTTPException(**error_response.model_dump()) from e - except (APIStatusError, OpenAIAPIStatusError) as e: - _record_inference_failure( - background_tasks, - infer_request, - request, + mapped_error = _map_inference_error_to_http_exception( + error, + model_id, request_id, - e, - start_time, - model, - provider, ) - logger.exception("API error for request %s: %s", request_id, e) - error_response = handle_known_apistatus_errors(e, model_id) - raise HTTPException(**error_response.model_dump()) from e + if mapped_error is not None: + raise mapped_error from error + raise if not response_text: logger.warning("Empty response from LLM for request %s", request_id) From 50cc05d97c5c150732c5a4df374057ae1ce94b94 Mon Sep 17 00:00:00 2001 From: Major Hayden Date: Mon, 2 Mar 2026 09:24:59 -0600 Subject: [PATCH 2/2] refactor(rlsapi): tighten infer exception typing Signed-off-by: Major Hayden --- src/app/endpoints/rlsapi_v1.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 19d66d2b0..954828f23 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -5,7 +5,7 @@ """ import time -from typing import Annotated, Any, cast +from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from llama_stack_api.openai_responses import OpenAIResponseObject @@ -50,6 +50,15 @@ # Default values when RH Identity auth is not configured AUTH_DISABLED = "auth_disabled" +# Keep this tuple centralized so infer_endpoint can catch all expected backend +# failures in one place while preserving a single telemetry/error-mapping path. +_INFER_HANDLED_EXCEPTIONS = ( + RuntimeError, + APIConnectionError, + RateLimitError, + APIStatusError, + OpenAIAPIStatusError, +) def _get_rh_identity_context(request: Request) -> tuple[str, str]: @@ -66,7 +75,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]: Tuple of (org_id, system_id). Returns ("auth_disabled", "auth_disabled") when RH Identity auth is not configured or data is unavailable. """ - rh_identity: RHIdentityData | None = getattr( + rh_identity: Optional[RHIdentityData] = getattr( request.state, "rh_identity_data", None ) if rh_identity is None: @@ -168,8 +177,8 @@ def _get_default_model_id() -> str: async def retrieve_simple_response( question: str, instructions: str, - tools: list[Any] | None = None, - model_id: str | None = None, + tools: Optional[list[Any]] = None, + model_id: Optional[str] = None, ) -> str: """Retrieve a simple response from the LLM for a stateless query. @@ -203,7 +212,7 @@ async def retrieve_simple_response( store=False, ) response = cast(OpenAIResponseObject, response) - extract_token_usage(response.usage, model_id) + extract_token_usage(response.usage, resolved_model_id) return extract_text_from_response_items(response.output) @@ -290,7 +299,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po def _map_inference_error_to_http_exception( error: Exception, model_id: str, request_id: str -) -> HTTPException | None: +) -> Optional[HTTPException]: """Map known inference errors to HTTPException. Returns None for RuntimeError values that are not context-length related, @@ -298,7 +307,8 @@ def _map_inference_error_to_http_exception( errors. """ if isinstance(error, RuntimeError): - if "context_length" in str(error).lower(): + error_message = str(error).lower() + if "context_length" in error_message or "context length" in error_message: logger.error("Prompt too long for request %s: %s", request_id, error) error_response = PromptTooLongResponse(model=model_id) return HTTPException(**error_response.model_dump()) @@ -369,7 +379,7 @@ async def infer_endpoint( # pylint: disable=R0914 instructions = _build_instructions(infer_request.context.systeminfo) model_id = _get_default_model_id() provider, model = extract_provider_and_model_from_model_id(model_id) - mcp_tools = await get_mcp_tools(request_headers=request.headers) + mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers) logger.debug( "Request %s: Combined input source length: %d", request_id, len(input_source) ) @@ -383,13 +393,7 @@ async def infer_endpoint( # pylint: disable=R0914 model_id=model_id, ) inference_time = time.monotonic() - start_time - except ( - RuntimeError, - APIConnectionError, - RateLimitError, - APIStatusError, - OpenAIAPIStatusError, - ) as error: + except _INFER_HANDLED_EXCEPTIONS as error: _record_inference_failure( background_tasks, infer_request,