Skip to content

Commit

Permalink
propagating error message through server-client cross-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
attilabalint committed Jun 25, 2024
1 parent bd647a9 commit 5749e9b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.1"
__version__ = "0.7.2"
6 changes: 5 additions & 1 deletion src/enfobench/evaluation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pandas as pd
import requests
from requests import HTTPError

from enfobench.core import ModelInfo
from enfobench.evaluation.server import EnvironmentInfo
Expand Down Expand Up @@ -58,7 +59,10 @@ def forecast(
params=params,
files=files,
)
if response.status_code != HTTPStatus.OK:
if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
response = json.loads(response.text)
raise HTTPError(response.get("error", "Internal Server Error"), response=response)
elif response.status_code != HTTPStatus.OK:
response.raise_for_status()

df = pd.DataFrame.from_records(response.json()["forecast"])
Expand Down
35 changes: 26 additions & 9 deletions src/enfobench/evaluation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import json
import sys
from typing import Annotated, Any
import traceback as tb

import pandas as pd
import pkg_resources
from fastapi import FastAPI, File, Query
from fastapi import File, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from fastapi import FastAPI
from starlette.exceptions import HTTPException

from enfobench.core import ForecasterType, Model

Expand Down Expand Up @@ -54,6 +57,13 @@ def server_factory(model: Model) -> FastAPI:
packages={package.key: package.version for package in pkg_resources.working_set},
)

@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse({
"error": exc.detail,
"status_code": exc.status_code,
}, status_code=exc.status_code)

@app.get("/", include_in_schema=False)
async def index():
return RedirectResponse(url="/docs")
Expand Down Expand Up @@ -82,14 +92,21 @@ async def forecast(
future_covariates_df = pd.read_parquet(io.BytesIO(future_covariates)) if future_covariates is not None else None
metadata = json.load(io.BytesIO(metadata)) if metadata is not None else None

forecast_df = model.forecast(
horizon=horizon,
history=history_df,
past_covariates=past_covariates_df,
future_covariates=future_covariates_df,
metadata=metadata,
level=level,
)
try:
forecast_df = model.forecast(
horizon=horizon,
history=history_df,
past_covariates=past_covariates_df,
future_covariates=future_covariates_df,
metadata=metadata,
level=level,
)
except Exception as e:
with io.StringIO() as file:
tb.print_exception(e, file=file)
detail = file.getvalue()
raise HTTPException(500, detail=detail)

forecast_df.fillna(0, inplace=True)
forecast_df.rename_axis("timestamp", inplace=True)
forecast_df.reset_index(inplace=True)
Expand Down

0 comments on commit 5749e9b

Please sign in to comment.