diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index 73668e452..a79670540 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -33,9 +33,13 @@
 from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from observability import InferenceEventData, build_inference_event, send_splunk_event
-from utils.query import handle_known_apistatus_errors
+from utils.query import (
+    extract_provider_and_model_from_model_id,
+    handle_known_apistatus_errors,
+)
 from utils.responses import (
     extract_text_from_response_items,
+    extract_token_usage,
     get_mcp_tools,
 )
 from utils.suid import get_suid
@@ -191,6 +195,7 @@ async def retrieve_simple_response(
         store=False,
     )
     response = cast(OpenAIResponseObject, response)
+    extract_token_usage(response.usage, model_id)
     return extract_text_from_response_items(response.output)
 
 
@@ -242,6 +247,8 @@ def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-po
     request_id: str,
     error: Exception,
     start_time: float,
+    model: str,
+    provider: str,
 ) -> float:
     """Record metrics and queue Splunk event for an inference failure.
 
@@ -257,7 +264,7 @@
         The total inference time in seconds.
     """
     inference_time = time.monotonic() - start_time
-    metrics.llm_calls_failures_total.inc()
+    metrics.llm_calls_failures_total.labels(provider, model).inc()
     _queue_splunk_event(
         background_tasks,
         infer_request,
@@ -307,6 +314,7 @@ async def infer_endpoint(
     input_source = infer_request.get_input_source()
     instructions = _build_instructions(infer_request.context.systeminfo)
     model_id = _get_default_model_id()
+    model, provider = extract_provider_and_model_from_model_id(model_id)
    mcp_tools = await get_mcp_tools(request_headers=request.headers)
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
     )
@@ -321,19 +329,40 @@
     except RuntimeError as e:
         if "context_length" in str(e).lower():
             _record_inference_failure(
-                background_tasks, infer_request, request, request_id, e, start_time
+                background_tasks,
+                infer_request,
+                request,
+                request_id,
+                e,
+                start_time,
+                model,
+                provider,
             )
             logger.error("Prompt too long for request %s: %s", request_id, e)
             error_response = PromptTooLongResponse(model=model_id)
             raise HTTPException(**error_response.model_dump()) from e
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
         raise
     except APIConnectionError as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
        )
         logger.error(
             "Unable to connect to Llama Stack for request %s: %s", request_id, e
         )
@@ -345,7 +374,14 @@ async def infer_endpoint(
         raise HTTPException(**error_response.model_dump()) from e
     except RateLimitError as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.error("Rate limit exceeded for request %s: %s", request_id, e)
         error_response = QuotaExceededResponse(
@@ -355,7 +391,14 @@ async def infer_endpoint(
         raise HTTPException(**error_response.model_dump()) from e
     except (APIStatusError, OpenAIAPIStatusError) as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.exception("API error for request %s: %s", request_id, e)
         error_response = handle_known_apistatus_errors(e, model_id)
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
index 912f32ff3..5c4e4e44f 100644
--- a/src/metrics/__init__.py
+++ b/src/metrics/__init__.py
@@ -33,7 +33,9 @@
 )
 
 # Metric that counts how many LLM calls failed
-llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
+llm_calls_failures_total = Counter(
+    "ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
+)
 
 # Metric that counts how many LLM calls had validation errors
 llm_calls_validation_errors_total = Counter(
diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py
index cdd30636d..d2e1dbbab 100644
--- a/tests/unit/app/endpoints/test_metrics.py
+++ b/tests/unit/app/endpoints/test_metrics.py
@@ -42,7 +42,6 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
     assert "# TYPE ls_provider_model_configuration gauge" in response_body
     assert "# TYPE ls_llm_calls_total counter" in response_body
     assert "# TYPE ls_llm_calls_failures_total counter" in response_body
-    assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
     assert "# TYPE ls_llm_validation_errors_total counter" in response_body
     assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
     assert "# TYPE ls_llm_token_sent_total counter" in response_body
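
A minimal, self-contained sketch of the labelled-counter pattern this patch switches to; it is not part of the patch, and the provider/model values are made up for illustration. Once a prometheus_client Counter is declared with label names it must be incremented through .labels(...), and each distinct (provider, model) pair becomes its own time series. A labelled counter also exposes no samples (including its _created series) until at least one label combination has been observed, which is presumably why the ls_llm_calls_failures_created assertion is dropped from the metrics test.

from prometheus_client import Counter

# Declared with two label dimensions, mirroring ls_llm_calls_failures_total above.
llm_calls_failures_total = Counter(
    "ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
)

def record_failure(provider: str, model: str) -> None:
    # .labels() takes values positionally, in the declared label order.
    llm_calls_failures_total.labels(provider, model).inc()

record_failure("openai", "gpt-4o-mini")  # illustrative values only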