Skip to content

Commit

Permalink
clear naming between history and target
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-balint-kul committed Jun 28, 2023
1 parent b2baf5e commit 3b65bc4
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.2.2"
4 changes: 2 additions & 2 deletions src/enfobench/evaluation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def environment(self) -> EnvironmentInfo:
def forecast(
self,
horizon: int,
target: pd.DataFrame,
history: pd.DataFrame,
past_covariates: Optional[pd.DataFrame] = None,
future_covariates: Optional[pd.DataFrame] = None,
level: Optional[List[int]] = None,
Expand All @@ -50,7 +50,7 @@ def forecast(
params["level"] = level

files = {
"target": to_buffer(target),
"history": to_buffer(history),
}
if past_covariates is not None:
files["past_covariates"] = to_buffer(past_covariates)
Expand Down
2 changes: 1 addition & 1 deletion src/enfobench/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def cross_validate(

forecast = model.forecast(
horizon_length,
target=history,
history=history,
past_covariates=past_covariates,
future_covariates=future_covariates,
level=level,
Expand Down
2 changes: 1 addition & 1 deletion src/enfobench/evaluation/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def info(self) -> ModelInfo:
def forecast(
self,
horizon: int,
target: pd.DataFrame,
history: pd.DataFrame,
past_covariates: Optional[pd.DataFrame] = None,
future_covariates: Optional[pd.DataFrame] = None,
level: Optional[List[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/enfobench/evaluation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ async def environment_info():
@app.post("/forecast")
async def forecast(
horizon: int,
target: Annotated[bytes, File()],
history: Annotated[bytes, File()],
past_covariates: Annotated[Optional[bytes], File()] = None,
future_covariates: Annotated[Optional[bytes], File()] = None,
level: Optional[List[int]] = Query(None),
):
target_df = pd.read_parquet(io.BytesIO(target))
history_df = pd.read_parquet(io.BytesIO(history))
past_covariates_df = (
pd.read_parquet(io.BytesIO(past_covariates)) if past_covariates is not None else None
)
Expand All @@ -46,7 +46,7 @@ async def forecast(

forecast_df = model.forecast(
horizon=horizon,
target=target_df,
history=history_df,
past_covariates=past_covariates_df,
future_covariates=future_covariates_df,
level=level,
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ def info(self) -> ModelInfo:
def forecast(
self,
horizon: int,
target,
history,
past_covariates=None,
future_covariates=None,
level=None,
**kwargs,
):
index = create_forecast_index(target, horizon)
index = create_forecast_index(history, horizon)
return pd.DataFrame(
data={
"ds": index,
"yhat": np.full(horizon, fill_value=target["y"].mean()) + self.param1,
"yhat": np.full(horizon, fill_value=history["y"].mean()) + self.param1,
}
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_evaluations/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_environment_endpoint(forecast_client):
def test_forecast_endpoint(forecast_client):
horizon = 24
target_index = pd.date_range(start="2020-01-01", end="2021-01-01", freq="1H")
target_df = pd.DataFrame(
history_df = pd.DataFrame(
data={
"y": np.random.normal(size=len(target_index)),
"ds": target_index,
Expand All @@ -53,7 +53,7 @@ def test_forecast_endpoint(forecast_client):
"horizon": horizon,
},
files={
"target": to_buffer(target_df),
"history": to_buffer(history_df),
},
)
assert response.status_code == 200
Expand Down

0 comments on commit 3b65bc4

Please sign in to comment.