diff --git a/src/enfobench/__version__.py b/src/enfobench/__version__.py index 3ced358..b5fdc75 100644 --- a/src/enfobench/__version__.py +++ b/src/enfobench/__version__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/src/enfobench/evaluation/client.py b/src/enfobench/evaluation/client.py index 6e4155d..c21464b 100644 --- a/src/enfobench/evaluation/client.py +++ b/src/enfobench/evaluation/client.py @@ -38,7 +38,7 @@ def environment(self) -> EnvironmentInfo: def forecast( self, horizon: int, - target: pd.DataFrame, + history: pd.DataFrame, past_covariates: Optional[pd.DataFrame] = None, future_covariates: Optional[pd.DataFrame] = None, level: Optional[List[int]] = None, @@ -50,7 +50,7 @@ def forecast( params["level"] = level files = { - "target": to_buffer(target), + "history": to_buffer(history), } if past_covariates is not None: files["past_covariates"] = to_buffer(past_covariates) diff --git a/src/enfobench/evaluation/evaluate.py b/src/enfobench/evaluation/evaluate.py index 7138168..92f94f8 100644 --- a/src/enfobench/evaluation/evaluate.py +++ b/src/enfobench/evaluation/evaluate.py @@ -212,7 +212,7 @@ def cross_validate( forecast = model.forecast( horizon_length, - target=history, + history=history, past_covariates=past_covariates, future_covariates=future_covariates, level=level, diff --git a/src/enfobench/evaluation/protocols.py b/src/enfobench/evaluation/protocols.py index 49d63f5..3a67d6f 100644 --- a/src/enfobench/evaluation/protocols.py +++ b/src/enfobench/evaluation/protocols.py @@ -41,7 +41,7 @@ def info(self) -> ModelInfo: def forecast( self, horizon: int, - target: pd.DataFrame, + history: pd.DataFrame, past_covariates: Optional[pd.DataFrame] = None, future_covariates: Optional[pd.DataFrame] = None, level: Optional[List[int]] = None, diff --git a/src/enfobench/evaluation/server.py b/src/enfobench/evaluation/server.py index 3c87dd0..4150aeb 100644 --- a/src/enfobench/evaluation/server.py +++ b/src/enfobench/evaluation/server.py @@ -29,12 +29,12 @@ async def environment_info(): @app.post("/forecast") async def forecast( horizon: int, - target: Annotated[bytes, File()], + history: Annotated[bytes, File()], past_covariates: Annotated[Optional[bytes], File()] = None, future_covariates: Annotated[Optional[bytes], File()] = None, level: Optional[List[int]] = Query(None), ): - target_df = pd.read_parquet(io.BytesIO(target)) + history_df = pd.read_parquet(io.BytesIO(history)) past_covariates_df = ( pd.read_parquet(io.BytesIO(past_covariates)) if past_covariates is not None else None ) @@ -46,7 +46,7 @@ async def forecast( forecast_df = model.forecast( horizon=horizon, - target=target_df, + history=history_df, past_covariates=past_covariates_df, future_covariates=future_covariates_df, level=level, diff --git a/tests/conftest.py b/tests/conftest.py index 39b0553..20b98ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,17 +22,17 @@ def info(self) -> ModelInfo: def forecast( self, horizon: int, - target, + history, past_covariates=None, future_covariates=None, level=None, **kwargs, ): - index = create_forecast_index(target, horizon) + index = create_forecast_index(history, horizon) return pd.DataFrame( data={ "ds": index, - "yhat": np.full(horizon, fill_value=target["y"].mean()) + self.param1, + "yhat": np.full(horizon, fill_value=history["y"].mean()) + self.param1, } ) diff --git a/tests/test_evaluations/test_server.py b/tests/test_evaluations/test_server.py index 8089bcb..ea58e28 100644 --- a/tests/test_evaluations/test_server.py +++ b/tests/test_evaluations/test_server.py @@ -40,7 +40,7 @@ def test_environment_endpoint(forecast_client): def test_forecast_endpoint(forecast_client): horizon = 24 target_index = pd.date_range(start="2020-01-01", end="2021-01-01", freq="1H") - target_df = pd.DataFrame( + history_df = pd.DataFrame( data={ "y": np.random.normal(size=len(target_index)), "ds": target_index, @@ -53,7 +53,7 @@ def test_forecast_endpoint(forecast_client): "horizon": horizon, }, files={ - "target": to_buffer(target_df), + "history": to_buffer(history_df), }, ) assert response.status_code == 200