From cc7be56f3ba6c041eadd44f31a06dc6fdd336302 Mon Sep 17 00:00:00 2001 From: attilabalint Date: Thu, 20 Jun 2024 21:22:40 +0200 Subject: [PATCH] updated forecast client notebook --- notebooks/03. ForecastClient.ipynb | 24406 ++++++++++++++++++++++++++- 1 file changed, 24312 insertions(+), 94 deletions(-) diff --git a/notebooks/03. ForecastClient.ipynb b/notebooks/03. ForecastClient.ipynb index 4407312..2ef3dc1 100644 --- a/notebooks/03. ForecastClient.ipynb +++ b/notebooks/03. ForecastClient.ipynb @@ -2,24 +2,47 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:26.397469Z", + "start_time": "2024-06-20T12:13:22.828054Z" + } + }, "outputs": [], "source": [ - "import pandas as pd" + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Load Data" + "# Accessing the Electricity Demand Dataset" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, + "source": [ + "This notebook assumes that you have downloaded the [electricity-demand](https://huggingface.co/datasets/EDS-lab/electricity-demand/tree/main/data) dataset from Hugging Face. To run this notebook, first create a folder `data/electricity-demand/` and download the 3 files from the linked page into it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:26.783025Z", + "start_time": "2024-06-20T12:13:26.401893Z" + } + }, "outputs": [], "source": [ "from enfobench.datasets import ElectricityDemandDataset\n", @@ -29,66 +52,205 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:26.858410Z", + "start_time": "2024-06-20T12:13:26.786827Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, "outputs": [], "source": [ - "unique_ids = ds.metadata_subset.list_unique_ids()" - ], - "metadata": { - "collapsed": false - } + "unique_ids = ds.metadata_subset.list_unique_ids()\n", + "unique_id = unique_ids[0]" + ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:31.647854Z", + "start_time": "2024-06-20T12:13:26.860841Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, "outputs": [], "source": [ - "len(unique_ids)" + "target, weather, metadata = ds.get_data_by_unique_id(unique_id)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:31.680836Z", + "start_time": "2024-06-20T12:13:31.654637Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "DatetimeIndex: 25029 entries, 2012-09-24 12:00:00 to 2014-02-28 00:00:00\n", + "Data columns (total 1 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 y 25028 non-null float64\n", + "dtypes: float64(1)\n", + "memory usage: 391.1 KB\n" + ] + } + ], "source": [ - "unique_id = unique_ids[0]\n", - "unique_id" + "target.info()" ] 
}, { "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "target, past_covariates, metadata = ds.get_data_by_unique_id(unique_id)" - ], + "execution_count": 6, "metadata": { - "collapsed": false - } + "ExecuteTime": { + "end_time": "2024-06-20T12:13:31.705221Z", + "start_time": "2024-06-20T12:13:31.684762Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "DatetimeIndex: 12853 entries, 2012-09-17 12:00:00 to 2014-03-07 00:00:00\n", + "Data columns (total 32 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 temperature_2m 12853 non-null float32\n", + " 1 relative_humidity_2m 12853 non-null float32\n", + " 2 dew_point_2m 12853 non-null float32\n", + " 3 apparent_temperature 12853 non-null float32\n", + " 4 precipitation 12853 non-null float32\n", + " 5 rain 12853 non-null float32\n", + " 6 snowfall 12853 non-null float32\n", + " 7 snow_depth 12853 non-null float32\n", + " 8 weather_code 12853 non-null float32\n", + " 9 pressure_msl 12853 non-null float32\n", + " 10 surface_pressure 12853 non-null float32\n", + " 11 cloud_cover 12853 non-null float32\n", + " 12 cloud_cover_low 12853 non-null float32\n", + " 13 cloud_cover_mid 12853 non-null float32\n", + " 14 cloud_cover_high 12853 non-null float32\n", + " 15 et0_fao_evapotranspiration 12853 non-null float32\n", + " 16 vapour_pressure_deficit 12853 non-null float32\n", + " 17 wind_speed_10m 12853 non-null float32\n", + " 18 wind_direction_10m 12853 non-null float32\n", + " 19 wind_gusts_10m 12853 non-null float32\n", + " 20 soil_temperature_0_to_7cm 12853 non-null float32\n", + " 21 soil_temperature_7_to_28cm 12853 non-null float32\n", + " 22 soil_moisture_0_to_7cm 12853 non-null float32\n", + " 23 soil_moisture_7_to_28cm 12853 non-null float32\n", + " 24 is_day 12853 non-null float32\n", + " 25 sunshine_duration 12853 non-null 
float32\n", + " 26 shortwave_radiation 12853 non-null float32\n", + " 27 direct_radiation 12853 non-null float32\n", + " 28 diffuse_radiation 12853 non-null float32\n", + " 29 direct_normal_irradiance 12853 non-null float32\n", + " 30 global_tilted_irradiance 12853 non-null float32\n", + " 31 terrestrial_radiation 12853 non-null float32\n", + "dtypes: float32(32)\n", + "memory usage: 1.7 MB\n" + ] + } + ], + "source": [ + "weather.info()" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:31.714195Z", + "start_time": "2024-06-20T12:13:31.707526Z" + } + }, "outputs": [], "source": [ - "target.info()" - ], - "metadata": { - "collapsed": false - } + "from enfobench.datasets.utils import create_perfect_forecasts_from_covariates" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:33.850977Z", + "start_time": "2024-06-20T12:13:31.717998Z" + } + }, "outputs": [], "source": [ - "past_covariates.info()" - ], + "perfect_weather_forecasts = create_perfect_forecasts_from_covariates(\n", + " weather[['temperature_2m', 'relative_humidity_2m', 'wind_speed_10m', 'wind_direction_10m', 'cloud_cover']],\n", + " start=pd.Timestamp(\"2013-01-01T00:00:00\"),\n", + " horizon=pd.Timedelta(\"4 days\"),\n", + " step=pd.Timedelta(\"24 hour\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": { - "collapsed": false - } + "ExecuteTime": { + "end_time": "2024-06-20T12:13:33.870918Z", + "start_time": "2024-06-20T12:13:33.854999Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 41419 entries, 0 to 41418\n", + "Data columns (total 7 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 cutoff_date 41419 non-null datetime64[ns]\n", + " 1 timestamp 41419 non-null 
datetime64[us]\n", + " 2 temperature_2m 41419 non-null float32 \n", + " 3 relative_humidity_2m 41419 non-null float32 \n", + " 4 wind_speed_10m 41419 non-null float32 \n", + " 5 wind_direction_10m 41419 non-null float32 \n", + " 6 cloud_cover 41419 non-null float32 \n", + "dtypes: datetime64[ns](1), datetime64[us](1), float32(5)\n", + "memory usage: 1.4 MB\n" + ] + } + ], + "source": [ + "perfect_weather_forecasts.info()" + ] }, { "cell_type": "markdown", @@ -99,15 +261,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:33.896748Z", + "start_time": "2024-06-20T12:13:33.874753Z" + } + }, "outputs": [], "source": [ "from enfobench import Dataset\n", "\n", "dataset = Dataset(\n", " target=target,\n", - " past_covariates=past_covariates,\n", + " past_covariates=weather,\n", + " future_covariates=perfect_weather_forecasts,\n", " metadata=metadata,\n", ")" ] @@ -121,8 +289,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 37, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:35.400315Z", + "start_time": "2024-06-20T12:13:33.900506Z" + } + }, "outputs": [], "source": [ "from enfobench.evaluation import ForecastClient\n", @@ -132,67 +305,23945 @@ }, { "cell_type": "code", - "execution_count": null, - "outputs": [], + "execution_count": 38, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:35.439724Z", + "start_time": "2024-06-20T12:13:35.403780Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "ModelInfo(name='Darts.NaiveMean', authors=[{'name': 'Attila Balint', 'email': 'attila.balint@kuleuven.be'}], type='point', params={})" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "client.info()" - ], - "metadata": { - "collapsed": false - } + ] 
}, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:39.790623Z", + "start_time": "2024-06-20T12:13:35.443577Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 29/29 [00:02<00:00, 12.98it/s]\n" + ] + } + ], "source": [ "from enfobench.evaluation import cross_validate\n", "\n", "crossval_df = cross_validate(\n", " client,\n", " dataset,\n", - " start_date=pd.Timestamp(\"2023-01-01T00:00:00\"),\n", - " end_date=pd.Timestamp(\"2023-02-01T00:00:00\"),\n", + " start_date=pd.Timestamp(\"2013-06-01T10:00:00\"),\n", + " end_date=pd.Timestamp(\"2013-07-01T00:00:00\"),\n", " horizon=pd.Timedelta(\"38 hours\"),\n", " step=pd.Timedelta(\"1 day\"),\n", + " level=90,\n", ")" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 40, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:39.832924Z", + "start_time": "2024-06-20T12:13:39.806289Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
cutoff_datetimestampyhaty
02013-06-01 10:00:002013-06-01 10:30:000.2028680.389
12013-06-01 10:00:002013-06-01 11:00:000.2028680.094
22013-06-01 10:00:002013-06-01 11:30:000.2028680.170
32013-06-01 10:00:002013-06-01 12:00:000.2028680.081
42013-06-01 10:00:002013-06-01 12:30:000.2028680.318
\n", + "
" + ], + "text/plain": [ + " cutoff_date timestamp yhat y\n", + "0 2013-06-01 10:00:00 2013-06-01 10:30:00 0.202868 0.389\n", + "1 2013-06-01 10:00:00 2013-06-01 11:00:00 0.202868 0.094\n", + "2 2013-06-01 10:00:00 2013-06-01 11:30:00 0.202868 0.170\n", + "3 2013-06-01 10:00:00 2013-06-01 12:00:00 0.202868 0.081\n", + "4 2013-06-01 10:00:00 2013-06-01 12:30:00 0.202868 0.318" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "crossval_df" + "crossval_df.head()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:39.849880Z", + "start_time": "2024-06-20T12:13:39.835755Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, "outputs": [], "source": [ - "cutoff_date_to_plot = crossval_df.cutoff_date.unique()[0]\n", - "crossval_df.loc[crossval_df.cutoff_date == cutoff_date_to_plot].set_index(\"timestamp\").drop(\n", - " columns=[\"cutoff_date\"]\n", - ").plot()" + "import matplotlib.animation as animation\n", + "\n", + "plt.rcParams[\"animation.html\"] = \"jshtml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:51.339674Z", + "start_time": "2024-06-20T12:13:39.853659Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } ], + "source": [ + "plt.ioff()\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "\n", + "cutoff_dates = crossval_df.cutoff_date.unique()\n", + "\n", + "\n", + "def animate_forecast(t):  # render frame t: history and forecast for the t-th cutoff date\n", + " ax.cla()  # clear this figure's own axes, not pyplot's implicit current axes\n", + "\n", + " cutoff_date = cutoff_dates[t]\n", + " history = dataset.get_history(cutoff_date)\n", + " forecast = (\n", + " crossval_df.loc[crossval_df.cutoff_date == cutoff_date].set_index(\"timestamp\").drop(columns=[\"cutoff_date\"])\n", + " )\n", + "\n", + " ax.plot(history.index, history.y)\n", + " ax.plot(forecast.index, forecast.yhat)\n", + " ax.set_xlim(cutoff_dates[0] - pd.Timedelta('7D'), crossval_df.timestamp.max())\n", + " ax.set_ylabel(\"Energy (kWh)\")\n", + " ax.set_title(f\"Predicted energy consumption at {cutoff_date}\", fontsize=\"large\", loc=\"left\")\n", + "\n", + "\n", + "ani = animation.FuncAnimation(fig, animate_forecast, frames=len(cutoff_dates))\n", + "ani" + ] + }, + { + "cell_type": "code", + "execution_count": 43, "metadata": { - "collapsed": false - } + "ExecuteTime": { + "end_time": "2024-06-20T12:13:51.351012Z", + "start_time": "2024-06-20T12:13:51.344111Z" + } + }, + "outputs": [], + "source": [ + "plt.ion()\n", + "plt.close(fig)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Evaluate metrics" + "# Metrics" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 44, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:51.365870Z", + "start_time": "2024-06-20T12:13:51.356365Z" + } + }, "outputs": [], "source": [ "from enfobench.evaluation import evaluate_metrics\n", @@ -201,50 +24252,217 @@ }, { "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "evaluate_metrics(\n", - " crossval_df,\n", - " metrics={\"MAE\": mean_absolute_error, \"MBE\": mean_bias_error},\n", - ")" - ], + 
"execution_count": 45, "metadata": { - "collapsed": false - } + "ExecuteTime": { + "end_time": "2024-06-20T12:13:51.389558Z", + "start_time": "2024-06-20T12:13:51.370421Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAEMBEweight
00.1220820.0224031.0
\n", + "
" + ], + "text/plain": [ + " MAE MBE weight\n", + "0 0.122082 0.022403 1.0" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "overall_metrics = evaluate_metrics(crossval_df, metrics={\"MAE\": mean_absolute_error, \"MBE\": mean_bias_error})\n", + "overall_metrics" + ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 46, + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-20T12:13:51.560505Z", + "start_time": "2024-06-20T12:13:51.392953Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 29/29 [00:00<00:00, 297.25it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
cutoff_dateMAEMBEweight
02013-06-01 10:00:000.1069760.0397371.0
12013-06-02 10:00:000.1117810.0487341.0
22013-06-03 10:00:000.1297920.0259331.0
32013-06-04 10:00:000.1320670.0219461.0
42013-06-05 10:00:000.1251430.0047691.0
\n", + "
" + ], + "text/plain": [ + " cutoff_date MAE MBE weight\n", + "0 2013-06-01 10:00:00 0.106976 0.039737 1.0\n", + "1 2013-06-02 10:00:00 0.111781 0.048734 1.0\n", + "2 2013-06-03 10:00:00 0.129792 0.025933 1.0\n", + "3 2013-06-04 10:00:00 0.132067 0.021946 1.0\n", + "4 2013-06-05 10:00:00 0.125143 0.004769 1.0" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "metrics = evaluate_metrics(\n", + "metrics_per_cutoff = evaluate_metrics(\n", " crossval_df,\n", " metrics={\"MAE\": mean_absolute_error, \"MBE\": mean_bias_error},\n", " groupby=\"cutoff_date\",\n", - ")" + ")\n", + "metrics_per_cutoff.head()" ] }, { "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "metrics" - ], + "execution_count": 47, "metadata": { - "collapsed": false - } + "ExecuteTime": { + "end_time": "2024-06-20T12:13:52.185650Z", + "start_time": "2024-06-20T12:13:51.563801Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "metrics_per_cutoff.set_index(\"cutoff_date\")[[\"MAE\", \"MBE\"]].plot(ax=ax)\n", + "ax.set_xlabel(\"Cutoff date\")\n", + "ax.set_ylabel(\"Energy (kWh)\")\n", + "ax.set_title(\"Prediction metrics per cutoff point\", fontsize=\"large\", loc=\"left\")\n", + "plt.show()" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "metrics[[\"MAE\", \"MBE\"]].plot()" - ] + "source": [] }, { "cell_type": "code", @@ -270,7 +24488,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.5" } }, "nbformat": 4,