From 5749e9b85385bac48d026f15f9de48b08e8a97da Mon Sep 17 00:00:00 2001 From: attilabalint Date: Tue, 25 Jun 2024 12:06:31 +0200 Subject: [PATCH] propagating error message through server-client cross-validation --- src/enfobench/__version__.py | 2 +- src/enfobench/evaluation/client.py | 6 ++++- src/enfobench/evaluation/server.py | 35 ++++++++++++++++++++++-------- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py index a5f830a..bc8c296 100644 --- a/src/enfobench/__version__.py +++ b/src/enfobench/__version__.py @@ -1 +1 @@ -__version__ = "0.7.1" +__version__ = "0.7.2" diff --git a/src/enfobench/evaluation/client.py b/src/enfobench/evaluation/client.py index 2fc6e6c..0f98aa5 100644 --- a/src/enfobench/evaluation/client.py +++ b/src/enfobench/evaluation/client.py @@ -4,6 +4,7 @@ import pandas as pd import requests +from requests import HTTPError from enfobench.core import ModelInfo from enfobench.evaluation.server import EnvironmentInfo @@ -58,7 +59,10 @@ def forecast( params=params, files=files, ) - if response.status_code != HTTPStatus.OK: + if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR: + response = json.loads(response.text) + raise HTTPError(response.get("error", "Internal Server Error"), response=response) + elif response.status_code != HTTPStatus.OK: response.raise_for_status() df = pd.DataFrame.from_records(response.json()["forecast"]) diff --git a/src/enfobench/evaluation/server.py b/src/enfobench/evaluation/server.py index e3fa74c..f29b1e1 100644 --- a/src/enfobench/evaluation/server.py +++ b/src/enfobench/evaluation/server.py @@ -2,14 +2,17 @@ import json import sys from typing import Annotated, Any +import traceback as tb import pandas as pd import pkg_resources -from fastapi import FastAPI, File, Query +from fastapi import File, Query from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel from starlette.responses import RedirectResponse +from fastapi import FastAPI +from starlette.exceptions import HTTPException from enfobench.core import ForecasterType, Model @@ -54,6 +57,13 @@ def server_factory(model: Model) -> FastAPI: packages={package.key: package.version for package in pkg_resources.working_set}, ) + @app.exception_handler(HTTPException) + async def http_exception_handler(request, exc): + return JSONResponse({ + "error": exc.detail, + "status_code": exc.status_code, + }, status_code=exc.status_code) + @app.get("/", include_in_schema=False) async def index(): return RedirectResponse(url="/docs") @@ -82,14 +92,21 @@ async def forecast( future_covariates_df = pd.read_parquet(io.BytesIO(future_covariates)) if future_covariates is not None else None metadata = json.load(io.BytesIO(metadata)) if metadata is not None else None - forecast_df = model.forecast( - horizon=horizon, - history=history_df, - past_covariates=past_covariates_df, - future_covariates=future_covariates_df, - metadata=metadata, - level=level, - ) + try: + forecast_df = model.forecast( + horizon=horizon, + history=history_df, + past_covariates=past_covariates_df, + future_covariates=future_covariates_df, + metadata=metadata, + level=level, + ) + except Exception as e: + with io.StringIO() as file: + tb.print_exception(e, file=file) + detail = file.getvalue() + raise HTTPException(500, detail=detail) + forecast_df.fillna(0, inplace=True) forecast_df.rename_axis("timestamp", inplace=True) forecast_df.reset_index(inplace=True)