-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first univariate evaluation server and client
- Loading branch information
1 parent
4cdd2cc
commit fe09b66
Showing
8 changed files
with
173 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .evaluation.client import ForecastClient | ||
from .evaluation.protocols import ForecasterType, Model, ModelInfo | ||
from .evaluation.server import server_factory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.0.1" | ||
__version__ = "0.0.2" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import annotations | ||
|
||
import io | ||
|
||
import pandas as pd | ||
import requests | ||
|
||
from enfobench.evaluation.protocols import EnvironmentInfo, ModelInfo | ||
|
||
|
||
def to_buffer(df: pd.DataFrame) -> io.BytesIO: | ||
buffer = io.BytesIO() | ||
df.to_parquet(buffer, index=False) | ||
buffer.seek(0) | ||
return buffer | ||
|
||
|
||
class ForecastClient: | ||
def __init__(self, host: str = "localhost", port: int = 3000, secure: bool = False): | ||
self.base_url = f"{'https' if secure else 'http'}://{host}:{port}" | ||
self.session = requests.Session() | ||
|
||
def info(self) -> ModelInfo: | ||
response = self.session.get(f"{self.base_url}/info") | ||
if not response.ok: | ||
response.raise_for_status() | ||
|
||
return ModelInfo(**response.json()) | ||
|
||
def environment(self) -> EnvironmentInfo: | ||
response = self.session.get(f"{self.base_url}/environment") | ||
if not response.ok: | ||
response.raise_for_status() | ||
|
||
return EnvironmentInfo(**response.json()) | ||
|
||
def predict( | ||
self, | ||
horizon: int, | ||
y: pd.DataFrame, | ||
# X: pd.DataFrame, | ||
level: list[int] | None = None, | ||
) -> pd.DataFrame: | ||
params: dict[str, int | list[int]] = { | ||
"horizon": horizon, | ||
} | ||
if level is not None: | ||
params["level"] = level | ||
|
||
files = { | ||
"y": to_buffer(y), | ||
# "X": to_buffer(X), | ||
} | ||
|
||
response = self.session.post( | ||
url=f"{self.base_url}/predict", | ||
params=params, | ||
files=files, | ||
) | ||
if not response.ok: | ||
response.raise_for_status() | ||
|
||
df = pd.read_parquet(io.BytesIO(response.content)) | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from __future__ import annotations | ||
|
||
from enum import Enum | ||
from typing import Any, Protocol | ||
|
||
import pandas as pd | ||
from pydantic import BaseModel | ||
|
||
|
||
class ForecasterType(str, Enum): | ||
point = "point" | ||
quantile = "quantile" | ||
density = "density" | ||
ensemble = "ensemble" | ||
|
||
|
||
class ModelInfo(BaseModel): | ||
name: str | ||
type: ForecasterType | ||
params: dict[str, Any] | ||
|
||
|
||
class EnvironmentInfo(BaseModel): | ||
packages: dict[str, str] | ||
|
||
|
||
class Model(Protocol): | ||
def info(self) -> ModelInfo: | ||
... | ||
|
||
def forecast( | ||
self, | ||
h: int, | ||
y: pd.Series, | ||
X: pd.DataFrame | None = None, | ||
level: list[int] | None = None, | ||
**kwargs, | ||
) -> pd.DataFrame: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
import io | ||
from typing import Annotated, List | ||
|
||
import pandas as pd | ||
import pkg_resources | ||
from fastapi import FastAPI, File, Query | ||
from fastapi.encoders import jsonable_encoder | ||
from fastapi.responses import JSONResponse | ||
|
||
from enfobench.evaluation.protocols import EnvironmentInfo, Model, ModelInfo | ||
|
||
|
||
def server_factory(model: Model) -> FastAPI: | ||
app = FastAPI() | ||
environment = EnvironmentInfo( | ||
packages={package.key: package.version for package in pkg_resources.working_set} | ||
) | ||
|
||
@app.get("/info", response_model=ModelInfo) | ||
async def model_info(): | ||
"""Return model information.""" | ||
return model.info() | ||
|
||
@app.get("/environment", response_model=EnvironmentInfo) | ||
async def environment_info(): | ||
"""Return information of installed packages and their versions.""" | ||
return environment | ||
|
||
@app.post("/predict") | ||
async def predict( | ||
horizon: int, | ||
y: Annotated[bytes, File()], | ||
# X: Annotated[bytes, File()], | ||
level: List[int] | None = Query(None), | ||
): | ||
y_df = pd.read_parquet(io.BytesIO(y)) | ||
# X_df = pd.read_parquet(io.BytesIO(X)) | ||
|
||
y_df["ds"] = pd.to_datetime(y_df["ds"]) | ||
y = y_df.set_index("ds").y | ||
|
||
forecast = model.forecast( | ||
h=horizon, | ||
y=y, | ||
# X=X_df, | ||
level=level, | ||
) | ||
forecast.fillna(0, inplace=True) | ||
|
||
response = { | ||
"forecast": jsonable_encoder(forecast.to_dict(orient="records")), | ||
} | ||
return JSONResponse( | ||
content=response, | ||
status_code=200, | ||
) | ||
|
||
return app |