Skip to content

Commit

Permalink
first univariate evaluation server and client
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Jun 21, 2023
1 parent 4cdd2cc commit fe09b66
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 3 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ PACKAGE_NAME = enfobench
install:
pip install -U pip
pip install -e ."[test,dev]"
mypy --install-types

## Delete all compiled Python files
clean:
Expand Down Expand Up @@ -44,7 +45,7 @@ tests:
#################################################################################

## Build source distribution and wheel
build: lint
build: lint tests
hatch build

## Upload source distribution and wheel to PyPI
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ classifiers = [
"Topic :: Scientific/Engineering",
]
dependencies = [
"pandas<2.0.0,>=1.3.5",
"fastapi>=0.68.0,<1.0.0",
"pandas>=1.3.0,<2.0.0",
"pydantic>=1.0.0,<2.0.0",
"requests>=2.26.0,<3.0.0",
]
dynamic = ["version"]

Expand Down
3 changes: 3 additions & 0 deletions src/enfobench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .evaluation.client import ForecastClient
from .evaluation.protocols import ForecasterType, Model, ModelInfo
from .evaluation.server import server_factory
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.1"
__version__ = "0.0.2"
Empty file.
64 changes: 64 additions & 0 deletions src/enfobench/evaluation/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import io

import pandas as pd
import requests

from enfobench.evaluation.protocols import EnvironmentInfo, ModelInfo


def to_buffer(df: pd.DataFrame) -> io.BytesIO:
buffer = io.BytesIO()
df.to_parquet(buffer, index=False)
buffer.seek(0)
return buffer


class ForecastClient:
def __init__(self, host: str = "localhost", port: int = 3000, secure: bool = False):
self.base_url = f"{'https' if secure else 'http'}://{host}:{port}"
self.session = requests.Session()

def info(self) -> ModelInfo:
response = self.session.get(f"{self.base_url}/info")
if not response.ok:
response.raise_for_status()

return ModelInfo(**response.json())

def environment(self) -> EnvironmentInfo:
response = self.session.get(f"{self.base_url}/environment")
if not response.ok:
response.raise_for_status()

return EnvironmentInfo(**response.json())

def predict(
self,
horizon: int,
y: pd.DataFrame,
# X: pd.DataFrame,
level: list[int] | None = None,
) -> pd.DataFrame:
params: dict[str, int | list[int]] = {
"horizon": horizon,
}
if level is not None:
params["level"] = level

files = {
"y": to_buffer(y),
# "X": to_buffer(X),
}

response = self.session.post(
url=f"{self.base_url}/predict",
params=params,
files=files,
)
if not response.ok:
response.raise_for_status()

df = pd.read_parquet(io.BytesIO(response.content))
return df
39 changes: 39 additions & 0 deletions src/enfobench/evaluation/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Protocol

import pandas as pd
from pydantic import BaseModel


class ForecasterType(str, Enum):
point = "point"
quantile = "quantile"
density = "density"
ensemble = "ensemble"


class ModelInfo(BaseModel):
name: str
type: ForecasterType
params: dict[str, Any]


class EnvironmentInfo(BaseModel):
packages: dict[str, str]


class Model(Protocol):
def info(self) -> ModelInfo:
...

def forecast(
self,
h: int,
y: pd.Series,
X: pd.DataFrame | None = None,
level: list[int] | None = None,
**kwargs,
) -> pd.DataFrame:
...
60 changes: 60 additions & 0 deletions src/enfobench/evaluation/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import io
from typing import Annotated, List

import pandas as pd
import pkg_resources
from fastapi import FastAPI, File, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

from enfobench.evaluation.protocols import EnvironmentInfo, Model, ModelInfo


def server_factory(model: Model) -> FastAPI:
app = FastAPI()
environment = EnvironmentInfo(
packages={package.key: package.version for package in pkg_resources.working_set}
)

@app.get("/info", response_model=ModelInfo)
async def model_info():
"""Return model information."""
return model.info()

@app.get("/environment", response_model=EnvironmentInfo)
async def environment_info():
"""Return information of installed packages and their versions."""
return environment

@app.post("/predict")
async def predict(
horizon: int,
y: Annotated[bytes, File()],
# X: Annotated[bytes, File()],
level: List[int] | None = Query(None),
):
y_df = pd.read_parquet(io.BytesIO(y))
# X_df = pd.read_parquet(io.BytesIO(X))

y_df["ds"] = pd.to_datetime(y_df["ds"])
y = y_df.set_index("ds").y

forecast = model.forecast(
h=horizon,
y=y,
# X=X_df,
level=level,
)
forecast.fillna(0, inplace=True)

response = {
"forecast": jsonable_encoder(forecast.to_dict(orient="records")),
}
return JSONResponse(
content=response,
status_code=200,
)

return app

0 comments on commit fe09b66

Please sign in to comment.