added usage section to README.md
attila-balint-kul committed Jun 22, 2023
1 parent c3ed3da commit b334a6f
Showing 3 changed files with 77 additions and 6 deletions.
74 changes: 70 additions & 4 deletions README.md
@@ -14,7 +14,6 @@ benchmark forecast models.

- [Installation](#installation)
- [Usage](#usage)
- [Features](#features)
- [Contributing](#contributing)
- [License](#license)

@@ -28,12 +27,79 @@ pip install enfobench

## Usage

Import your dataset and make sure that the timestamp column is named 'ds' and the target values column is named 'y'.

```python
import pandas as pd

# Load your dataset; the timestamp column must be named 'ds' and the target values 'y'
data = (
    pd.read_csv("../path/to/your/data.csv")
    .rename(columns={"timestamp": "ds", "value": "y"})
)
y = data.set_index("ds")["y"]
```
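If you don't have a dataset at hand, a synthetic series with the expected `ds`/`y` layout can be constructed instead (an illustrative sketch, not part of the package):

```python
import numpy as np
import pandas as pd

# One year of hourly data with a daily cycle plus noise
rng = np.random.default_rng(42)
index = pd.date_range("2018-01-01", periods=365 * 24, freq="H")
data = pd.DataFrame({
    "ds": index,
    "y": np.sin(2 * np.pi * index.hour / 24) + rng.normal(0, 0.1, len(index)),
})
y = data.set_index("ds")["y"]
```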

You can perform cross validation locally on any model that adheres to the `enfobench.Model` protocol.

```python
import pandas as pd

from enfobench.evaluation import cross_validate

# MyModel is a placeholder for your own model class
# implementing the enfobench.Model protocol
model = MyModel()

# Run cross validation on your model
cv_results = cross_validate(
    model,
    start=pd.Timestamp("2018-01-01"),
    end=pd.Timestamp("2018-01-31"),
    horizon=pd.Timedelta("24 hours"),
    step=pd.Timedelta("1 day"),
    y=y,
)
```
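The `enfobench.Model` protocol itself is not spelled out here; the following is a minimal sketch of the kind of interface such a model might expose (a hypothetical naive model, not the package's actual protocol definition):

```python
import pandas as pd


class MyModel:
    """Hypothetical naive model; the actual enfobench.Model protocol may differ."""

    def forecast(self, horizon: int, history: pd.Series) -> pd.DataFrame:
        # Repeat the last observed value over the forecast horizon
        freq = history.index.inferred_freq
        future = pd.date_range(history.index[-1], periods=horizon + 1, freq=freq)[1:]
        return pd.DataFrame({"ds": future, "yhat": history.iloc[-1]})
```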

You can use the same cross validation interface with a model served behind an API.

```python
import pandas as pd

from enfobench.evaluation import cross_validate, ForecastClient

# Create a client for your model served behind an API
client = ForecastClient(host='localhost', port=3000)

# Run cross validation against the served model
cv_results = cross_validate(
    client,
    start=pd.Timestamp("2018-01-01"),
    end=pd.Timestamp("2018-01-31"),
    horizon=pd.Timedelta("24 hours"),
    step=pd.Timedelta("1 day"),
    y=y,
)
```

The package also provides common metrics that you can quickly evaluate on your results.

```python
from enfobench.evaluation import evaluate_metrics_on_forecasts
from enfobench.evaluation.metrics import (
    mean_bias_error,
    mean_absolute_error,
    mean_squared_error,
    root_mean_squared_error,
)

# Merge the cross validation results with the original data
forecasts = cv_results.merge(data, on="ds", how="left")

metrics = evaluate_metrics_on_forecasts(
    forecasts,
    metrics={
        "mean_bias_error": mean_bias_error,
        "mean_absolute_error": mean_absolute_error,
        "mean_squared_error": mean_squared_error,
        "root_mean_squared_error": root_mean_squared_error,
    },
)
```
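As a sanity check, the same quantities can be computed directly with pandas and numpy (illustrative only; the `yhat` forecast column name is an assumption, not confirmed by the package):

```python
import numpy as np
import pandas as pd

# Hypothetical merged frame: 'y' is the target, 'yhat' the forecast
forecasts = pd.DataFrame({"y": [1.0, 2.0, 3.0], "yhat": [1.5, 1.5, 3.5]})

errors = forecasts["yhat"] - forecasts["y"]
mbe = errors.mean()           # mean bias error
mae = errors.abs().mean()     # mean absolute error
mse = (errors ** 2).mean()    # mean squared error
rmse = np.sqrt(mse)           # root mean squared error
```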

## Contributing
2 changes: 1 addition & 1 deletion src/enfobench/__version__.py
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.1.1"
7 changes: 6 additions & 1 deletion src/enfobench/evaluation/_cross_validate.py
@@ -42,6 +42,7 @@ def generate_cutoff_dates(
def cross_validate(
    model: Union[Model, ForecastClient],
    start: pd.Timestamp,
    end: pd.Timestamp,
    horizon: pd.Timedelta,
    step: pd.Timedelta,
    y: pd.Series,
@@ -56,6 +57,8 @@ def cross_validate(
            Model to cross-validate.
        start:
            Start date of the time series.
        end:
            End date of the time series.
        horizon:
            Forecast horizon.
        step:
@@ -69,7 +72,9 @@
            Frequency of the time series.
            (Optional, if not provided, it will be inferred from the time series index.)
    """
    cutoff_dates = generate_cutoff_dates(start, y.index[-1], horizon, step)
    if end > y.index[-1]:
        raise ValueError("End date is beyond the target values.")
    cutoff_dates = generate_cutoff_dates(start, end, horizon, step)
    horizon_length = steps_in_horizon(horizon, freq or y.index.inferred_freq)

    # Cross-validation
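The `end` argument introduced in this diff bounds the cutoff dates. A minimal sketch of what `generate_cutoff_dates` might do, inferred from the signature shown above (an assumption, not the package's actual implementation):

```python
import pandas as pd


def generate_cutoff_dates(start, end, horizon, step):
    """Yield cutoff dates from start, stepping forward while a full
    forecast horizon still fits before end (sketch, not the real code)."""
    cutoffs = []
    cutoff = start
    while cutoff + horizon <= end:
        cutoffs.append(cutoff)
        cutoff += step
    return cutoffs


cutoffs = generate_cutoff_dates(
    start=pd.Timestamp("2018-01-01"),
    end=pd.Timestamp("2018-01-05"),
    horizon=pd.Timedelta("24 hours"),
    step=pd.Timedelta("1 day"),
)
```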
