diff --git a/Karpeev2024RiemannianGen/code/ckpt.pth b/Karpeev2024RiemannianGen/code/ckpt.pth
new file mode 100644
index 0000000..e84b2c5
Binary files /dev/null and b/Karpeev2024RiemannianGen/code/ckpt.pth differ
diff --git a/Karpeev2024RiemannianGen/code/score-based model.ipynb b/Karpeev2024RiemannianGen/code/score-based model.ipynb
new file mode 100644
index 0000000..cc905c5
--- /dev/null
+++ b/Karpeev2024RiemannianGen/code/score-based model.ipynb
@@ -0,0 +1,2020 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "6f4478ea94964503aed6695c389cd8d4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0ddcab2b812f48558a8cdfc2aa22abee",
+ "IPY_MODEL_c2a0b4a38be5484fa2992c6d2b3154e9",
+ "IPY_MODEL_3a6bee3c17884ae6bd5655cc75fdacd3"
+ ],
+ "layout": "IPY_MODEL_8b7219c434164f58a21b788ee3d0d1e1"
+ }
+ },
+ "0ddcab2b812f48558a8cdfc2aa22abee": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ee984355aa894cd2a0b68edc85823036",
+ "placeholder": "",
+ "style": "IPY_MODEL_424d7a61642b44ad86d01dd88bc04d87",
+ "value": "Average Loss: 90.768392: 100%"
+ }
+ },
+ "c2a0b4a38be5484fa2992c6d2b3154e9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_129881e1ff164c218164b8f6466b8bf8",
+ "max": 50,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_017c61b626974a5681acb5fd030b927c",
+ "value": 50
+ }
+ },
+ "3a6bee3c17884ae6bd5655cc75fdacd3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9ac6148f92494e37990294a60a63164c",
+ "placeholder": "",
+ "style": "IPY_MODEL_518bcf6e325b410292f18771b5f0ab1c",
+ "value": " 50/50 [2:00:43<00:00, 143.31s/it]"
+ }
+ },
+ "8b7219c434164f58a21b788ee3d0d1e1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ee984355aa894cd2a0b68edc85823036": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "424d7a61642b44ad86d01dd88bc04d87": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "129881e1ff164c218164b8f6466b8bf8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "017c61b626974a5681acb5fd030b927c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9ac6148f92494e37990294a60a63164c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "518bcf6e325b410292f18771b5f0ab1c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "import random as rn\n",
+ "import pandas as pd\n",
+ "import tensorflow as tf\n",
+ "import matplotlib.pyplot as plt\n",
+ "from tensorflow.keras.models import Sequential\n",
+ "from tensorflow.keras.layers import LSTM, Dense, Normalization\n",
+ "\n",
+ "# Fix seed for NumPy\n",
+ "np.random.seed(42)\n",
+ "\n",
+ "# Fix seed for TensorFlow\n",
+ "tf.random.set_seed(42)\n",
+ "\n",
+ "rn.seed(42)\n",
+ "\n",
+ "tf.keras.utils.set_random_seed(42)"
+ ],
+ "metadata": {
+ "id": "RX2AKL-mDuL7"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Data generation and plot"
+ ],
+ "metadata": {
+ "id": "391Mb-JwrN2d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "wdd3K0E9DMzy"
+ },
+ "outputs": [],
+ "source": [
+ "def make_sine_ts(n_points, start_time=0, dimension=1, n_periods=4, ampl=10):\n",
+ "    \"\"\"Generate a noisy sine time series as a pandas DataFrame.\n",
+ "\n",
+ "    Args:\n",
+ "        n_points: number of time steps to generate.\n",
+ "        start_time: phase offset (added to the sine argument, in radians).\n",
+ "        dimension: number of sine columns (identical signal, independent noise).\n",
+ "        n_periods: number of full sine periods over the n_points window.\n",
+ "        ampl: sine amplitude; the Gaussian noise std is ampl / 10.\n",
+ "\n",
+ "    Returns:\n",
+ "        DataFrame with a 'Time' column followed by 'Sine_1'..'Sine_<dimension>'.\n",
+ "    \"\"\"\n",
+ "    sigma = ampl / 10  # noise standard deviation: 10% of the amplitude\n",
+ "    time = np.arange(1, n_points + 1)\n",
+ "    # (n_points, dimension) sine signal plus i.i.d. Gaussian noise\n",
+ "    series_sine = ampl * np.sin(np.tile(time * (2 * np.pi * n_periods) / n_points + start_time, (dimension, 1)).T) + sigma * np.random.randn(n_points, dimension)\n",
+ "    table = np.column_stack((time, series_sine))\n",
+ "    columns = ['Time'] + [f'Sine_{i}' for i in range(1, dimension + 1)]\n",
+ "    ts = pd.DataFrame(table, columns=columns)\n",
+ "    return ts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# generate synthetic time series: 28 univariate sine series that differ\n",
+ "# only in phase offset (i * 10 radians) and in their noise draws\n",
+ "time = 500000\n",
+ "time_series_data = pd.DataFrame()\n",
+ "for i in range(28):\n",
+ "    ts = make_sine_ts(time, dimension=1, start_time=i * 10)\n",
+ "    time_series_data[f'Sine_{i+1}'] = ts.iloc[:, 1]  # keep the value column, drop 'Time'\n",
+ "time_series_data"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 444
+ },
+ "id": "z25QSoBsDOEx",
+ "outputId": "51606af9-8656-442a-f025-ae26a685758f"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " Sine_1 Sine_2 Sine_3 Sine_4 Sine_5 Sine_6 \\\n",
+ "0 0.497217 -6.503658 9.298829 -10.018320 8.515519 -3.087901 \n",
+ "1 -0.137259 -6.261307 9.008358 -9.833064 6.953748 -1.567958 \n",
+ "2 0.649197 -4.758246 10.286693 -10.470742 7.579386 -2.304788 \n",
+ "3 1.525040 -5.023999 9.330359 -9.593537 5.970097 -3.081937 \n",
+ "4 -0.231640 -5.019861 9.995089 -12.345050 7.139608 -2.803368 \n",
+ "... ... ... ... ... ... ... \n",
+ "499995 1.331610 -5.005519 11.045806 -9.072662 6.202448 -1.579102 \n",
+ "499996 -0.111264 -5.337173 8.345270 -9.980975 9.403580 -2.855766 \n",
+ "499997 -0.244700 -5.552347 8.396881 -9.643462 6.636783 -2.613017 \n",
+ "499998 -0.797962 -3.970647 8.733367 -10.006217 6.389225 -2.123757 \n",
+ "499999 -1.641133 -4.963781 10.963403 -8.098568 8.320622 -2.658918 \n",
+ "\n",
+ " Sine_7 Sine_8 Sine_9 Sine_10 ... Sine_19 Sine_20 \\\n",
+ "0 -4.155745 6.761862 -10.122948 8.129912 ... -6.787299 10.515523 \n",
+ "1 -4.818426 8.256200 -9.895204 9.407118 ... -6.252890 12.183682 \n",
+ "2 -3.387350 8.695239 -8.756535 7.312252 ... -8.106997 8.541812 \n",
+ "3 -2.547643 7.994574 -10.223428 9.714817 ... -8.088379 9.707112 \n",
+ "4 -2.606996 7.881453 -10.264317 10.024044 ... -8.275953 9.007172 \n",
+ "... ... ... ... ... ... ... ... \n",
+ "499995 -3.000656 7.146260 -8.603152 9.276178 ... -7.040916 10.604008 \n",
+ "499996 -2.594933 8.081965 -11.012017 8.579396 ... -8.076067 9.392880 \n",
+ "499997 -2.584798 7.640324 -9.260246 9.577715 ... -7.690593 11.444883 \n",
+ "499998 -4.265820 7.397909 -10.272636 7.342878 ... -7.761913 10.622368 \n",
+ "499999 -3.889068 6.765554 -10.882433 7.051304 ... -7.133268 10.901605 \n",
+ "\n",
+ " Sine_21 Sine_22 Sine_23 Sine_24 Sine_25 Sine_26 \\\n",
+ "0 -10.214769 6.224937 -0.134012 -6.372275 8.521845 -8.934159 \n",
+ "1 -8.528483 5.866819 0.773012 -5.965893 9.047096 -11.491004 \n",
+ "2 -8.747523 4.362181 0.890155 -7.038727 9.793125 -10.618115 \n",
+ "3 -9.644432 4.559934 1.583819 -6.081536 10.010120 -9.884599 \n",
+ "4 -11.797206 4.055618 1.258454 -5.195012 9.155152 -8.276825 \n",
+ "... ... ... ... ... ... ... \n",
+ "499995 -8.147221 3.497439 2.050472 -4.679167 9.401032 -10.171174 \n",
+ "499996 -8.679248 4.729691 1.935377 -6.885263 8.956104 -9.482445 \n",
+ "499997 -7.913844 5.496157 1.996778 -6.425471 9.931313 -8.826545 \n",
+ "499998 -9.025079 6.211347 1.162714 -5.282455 10.086486 -10.105187 \n",
+ "499999 -6.078834 5.088728 1.510166 -6.862937 10.246452 -10.460876 \n",
+ "\n",
+ " Sine_27 Sine_28 \n",
+ "0 6.005677 -1.905193 \n",
+ "1 7.182516 -0.079388 \n",
+ "2 7.192885 -3.053790 \n",
+ "3 6.513398 -0.453277 \n",
+ "4 4.816136 -2.346869 \n",
+ "... ... ... \n",
+ "499995 6.709303 -1.560944 \n",
+ "499996 6.596734 0.355482 \n",
+ "499997 7.147098 -0.736237 \n",
+ "499998 5.222044 -2.804623 \n",
+ "499999 7.546284 -2.599624 \n",
+ "\n",
+ "[500000 rows x 28 columns]"
+ ],
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Sine_1 | \n",
+ " Sine_2 | \n",
+ " Sine_3 | \n",
+ " Sine_4 | \n",
+ " Sine_5 | \n",
+ " Sine_6 | \n",
+ " Sine_7 | \n",
+ " Sine_8 | \n",
+ " Sine_9 | \n",
+ " Sine_10 | \n",
+ " ... | \n",
+ " Sine_19 | \n",
+ " Sine_20 | \n",
+ " Sine_21 | \n",
+ " Sine_22 | \n",
+ " Sine_23 | \n",
+ " Sine_24 | \n",
+ " Sine_25 | \n",
+ " Sine_26 | \n",
+ " Sine_27 | \n",
+ " Sine_28 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.497217 | \n",
+ " -6.503658 | \n",
+ " 9.298829 | \n",
+ " -10.018320 | \n",
+ " 8.515519 | \n",
+ " -3.087901 | \n",
+ " -4.155745 | \n",
+ " 6.761862 | \n",
+ " -10.122948 | \n",
+ " 8.129912 | \n",
+ " ... | \n",
+ " -6.787299 | \n",
+ " 10.515523 | \n",
+ " -10.214769 | \n",
+ " 6.224937 | \n",
+ " -0.134012 | \n",
+ " -6.372275 | \n",
+ " 8.521845 | \n",
+ " -8.934159 | \n",
+ " 6.005677 | \n",
+ " -1.905193 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " -0.137259 | \n",
+ " -6.261307 | \n",
+ " 9.008358 | \n",
+ " -9.833064 | \n",
+ " 6.953748 | \n",
+ " -1.567958 | \n",
+ " -4.818426 | \n",
+ " 8.256200 | \n",
+ " -9.895204 | \n",
+ " 9.407118 | \n",
+ " ... | \n",
+ " -6.252890 | \n",
+ " 12.183682 | \n",
+ " -8.528483 | \n",
+ " 5.866819 | \n",
+ " 0.773012 | \n",
+ " -5.965893 | \n",
+ " 9.047096 | \n",
+ " -11.491004 | \n",
+ " 7.182516 | \n",
+ " -0.079388 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.649197 | \n",
+ " -4.758246 | \n",
+ " 10.286693 | \n",
+ " -10.470742 | \n",
+ " 7.579386 | \n",
+ " -2.304788 | \n",
+ " -3.387350 | \n",
+ " 8.695239 | \n",
+ " -8.756535 | \n",
+ " 7.312252 | \n",
+ " ... | \n",
+ " -8.106997 | \n",
+ " 8.541812 | \n",
+ " -8.747523 | \n",
+ " 4.362181 | \n",
+ " 0.890155 | \n",
+ " -7.038727 | \n",
+ " 9.793125 | \n",
+ " -10.618115 | \n",
+ " 7.192885 | \n",
+ " -3.053790 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1.525040 | \n",
+ " -5.023999 | \n",
+ " 9.330359 | \n",
+ " -9.593537 | \n",
+ " 5.970097 | \n",
+ " -3.081937 | \n",
+ " -2.547643 | \n",
+ " 7.994574 | \n",
+ " -10.223428 | \n",
+ " 9.714817 | \n",
+ " ... | \n",
+ " -8.088379 | \n",
+ " 9.707112 | \n",
+ " -9.644432 | \n",
+ " 4.559934 | \n",
+ " 1.583819 | \n",
+ " -6.081536 | \n",
+ " 10.010120 | \n",
+ " -9.884599 | \n",
+ " 6.513398 | \n",
+ " -0.453277 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " -0.231640 | \n",
+ " -5.019861 | \n",
+ " 9.995089 | \n",
+ " -12.345050 | \n",
+ " 7.139608 | \n",
+ " -2.803368 | \n",
+ " -2.606996 | \n",
+ " 7.881453 | \n",
+ " -10.264317 | \n",
+ " 10.024044 | \n",
+ " ... | \n",
+ " -8.275953 | \n",
+ " 9.007172 | \n",
+ " -11.797206 | \n",
+ " 4.055618 | \n",
+ " 1.258454 | \n",
+ " -5.195012 | \n",
+ " 9.155152 | \n",
+ " -8.276825 | \n",
+ " 4.816136 | \n",
+ " -2.346869 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 499995 | \n",
+ " 1.331610 | \n",
+ " -5.005519 | \n",
+ " 11.045806 | \n",
+ " -9.072662 | \n",
+ " 6.202448 | \n",
+ " -1.579102 | \n",
+ " -3.000656 | \n",
+ " 7.146260 | \n",
+ " -8.603152 | \n",
+ " 9.276178 | \n",
+ " ... | \n",
+ " -7.040916 | \n",
+ " 10.604008 | \n",
+ " -8.147221 | \n",
+ " 3.497439 | \n",
+ " 2.050472 | \n",
+ " -4.679167 | \n",
+ " 9.401032 | \n",
+ " -10.171174 | \n",
+ " 6.709303 | \n",
+ " -1.560944 | \n",
+ "
\n",
+ " \n",
+ " 499996 | \n",
+ " -0.111264 | \n",
+ " -5.337173 | \n",
+ " 8.345270 | \n",
+ " -9.980975 | \n",
+ " 9.403580 | \n",
+ " -2.855766 | \n",
+ " -2.594933 | \n",
+ " 8.081965 | \n",
+ " -11.012017 | \n",
+ " 8.579396 | \n",
+ " ... | \n",
+ " -8.076067 | \n",
+ " 9.392880 | \n",
+ " -8.679248 | \n",
+ " 4.729691 | \n",
+ " 1.935377 | \n",
+ " -6.885263 | \n",
+ " 8.956104 | \n",
+ " -9.482445 | \n",
+ " 6.596734 | \n",
+ " 0.355482 | \n",
+ "
\n",
+ " \n",
+ " 499997 | \n",
+ " -0.244700 | \n",
+ " -5.552347 | \n",
+ " 8.396881 | \n",
+ " -9.643462 | \n",
+ " 6.636783 | \n",
+ " -2.613017 | \n",
+ " -2.584798 | \n",
+ " 7.640324 | \n",
+ " -9.260246 | \n",
+ " 9.577715 | \n",
+ " ... | \n",
+ " -7.690593 | \n",
+ " 11.444883 | \n",
+ " -7.913844 | \n",
+ " 5.496157 | \n",
+ " 1.996778 | \n",
+ " -6.425471 | \n",
+ " 9.931313 | \n",
+ " -8.826545 | \n",
+ " 7.147098 | \n",
+ " -0.736237 | \n",
+ "
\n",
+ " \n",
+ " 499998 | \n",
+ " -0.797962 | \n",
+ " -3.970647 | \n",
+ " 8.733367 | \n",
+ " -10.006217 | \n",
+ " 6.389225 | \n",
+ " -2.123757 | \n",
+ " -4.265820 | \n",
+ " 7.397909 | \n",
+ " -10.272636 | \n",
+ " 7.342878 | \n",
+ " ... | \n",
+ " -7.761913 | \n",
+ " 10.622368 | \n",
+ " -9.025079 | \n",
+ " 6.211347 | \n",
+ " 1.162714 | \n",
+ " -5.282455 | \n",
+ " 10.086486 | \n",
+ " -10.105187 | \n",
+ " 5.222044 | \n",
+ " -2.804623 | \n",
+ "
\n",
+ " \n",
+ " 499999 | \n",
+ " -1.641133 | \n",
+ " -4.963781 | \n",
+ " 10.963403 | \n",
+ " -8.098568 | \n",
+ " 8.320622 | \n",
+ " -2.658918 | \n",
+ " -3.889068 | \n",
+ " 6.765554 | \n",
+ " -10.882433 | \n",
+ " 7.051304 | \n",
+ " ... | \n",
+ " -7.133268 | \n",
+ " 10.901605 | \n",
+ " -6.078834 | \n",
+ " 5.088728 | \n",
+ " 1.510166 | \n",
+ " -6.862937 | \n",
+ " 10.246452 | \n",
+ " -10.460876 | \n",
+ " 7.546284 | \n",
+ " -2.599624 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
500000 rows × 28 columns
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "time_series_data"
+ }
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Forecast with score based model"
+ ],
+ "metadata": {
+ "id": "68jZbwWpYWub"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.decomposition import PCA\n",
+ "from sklearn.manifold import MDS\n",
+ "from sklearn.metrics import mean_squared_error\n",
+ "from scipy.linalg import hankel\n",
+ "from sklearn.covariance import LedoitWolf"
+ ],
+ "metadata": {
+ "id": "zSrJWMwb5963"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "time_series_data.shape"
+ ],
+ "metadata": {
+ "id": "5oL91A1MMD1i",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "213a7b63-87e5-4a60-c008-137faaf141f4"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(500000, 28)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 5
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "window_size = 6"
+ ],
+ "metadata": {
+ "id": "Lw8VtWANCKRp"
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def create_hankel_matrix(time_series, window_size):\n",
+ "    \"\"\"Build the trajectory (Hankel) matrix of sliding windows of a series.\n",
+ "\n",
+ "    scipy.linalg.hankel is given the first column (the series without its\n",
+ "    last window_size-1 points) and the last row (the final window_size\n",
+ "    points), so each row is one consecutive length-window_size window and\n",
+ "    the result has shape (len(time_series) - window_size + 1, window_size).\n",
+ "    \"\"\"\n",
+ "    hankel_matrix = hankel(time_series[:-window_size+1], time_series[-window_size:])\n",
+ "    return hankel_matrix"
+ ],
+ "metadata": {
+ "id": "IX0QaWv1_U0e"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def calculate_distance_matrix(matrices):\n",
+ "    \"\"\"Compute one covariance matrix per time index.\n",
+ "\n",
+ "    Args:\n",
+ "        matrices: array of shape (num_series, num_times, window_size) —\n",
+ "            the stacked Hankel matrices, one per series.\n",
+ "\n",
+ "    Returns:\n",
+ "        Array of shape (num_times, num_series, num_series); entry [i] is the\n",
+ "        biased (1/T) sample covariance of the series over window i.\n",
+ "\n",
+ "    NOTE(review): despite the name, these are covariance matrices (SPD),\n",
+ "    not pairwise distances.\n",
+ "    \"\"\"\n",
+ "    # matrices shape (num_series, num_times, window_sz)\n",
+ "    distances_matrix = np.zeros((matrices.shape[1], matrices.shape[0], matrices.shape[0]))\n",
+ "\n",
+ "    # For each time index, accumulate outer products of the centered window\n",
+ "    # columns, then normalize by the window length T.\n",
+ "    for i in range(distances_matrix.shape[0]):\n",
+ "        mu = np.mean(matrices[:,i,:], axis=-1)  # per-series mean over the window\n",
+ "        T = matrices.shape[2]\n",
+ "        for t in range(T):\n",
+ "            x = (matrices[:,i,t] - mu).reshape(-1, 1)  # centered column vector\n",
+ "            distances_matrix[i] += x @ x.T\n",
+ "        distances_matrix[i] *= 1/T\n",
+ "    return distances_matrix"
+ ],
+ "metadata": {
+ "id": "E1A5wfDRCnlh"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def reshape_to_2d(matrix):\n",
+ "    \"\"\"Flatten all trailing axes: (N, ...) -> (N, prod(...)).\"\"\"\n",
+ "    return matrix.reshape(matrix.shape[0], -1)\n",
+ "\n",
+ "def reshape_to_3d(matrix, num):\n",
+ "    \"\"\"Inverse of reshape_to_2d: (N, num * k) -> (N, num, k).\"\"\"\n",
+ "    return matrix.reshape(matrix.shape[0], num, -1)"
+ ],
+ "metadata": {
+ "id": "Nfs2Jg3uKEKI"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def get_distance_matrix(time_series, window_size=6):\n",
+ "    \"\"\"Sliding-window covariance representation of a multivariate series.\n",
+ "\n",
+ "    Builds one Hankel matrix per column of the DataFrame `time_series`,\n",
+ "    stacks them to shape (num_series, num_windows, window_size), and returns\n",
+ "    the per-window covariance matrices, shape (num_windows, num_series,\n",
+ "    num_series).\n",
+ "    \"\"\"\n",
+ "    matrices = np.stack([create_hankel_matrix(time_series.iloc[:, i], window_size) for i in range(time_series.shape[1])])\n",
+ "    print(matrices.shape)  # (num_series, num_windows, window_size)\n",
+ "    distances = calculate_distance_matrix(matrices)\n",
+ "    print(distances.shape)  # (num_windows, num_series, num_series)\n",
+ "    return distances"
+ ],
+ "metadata": {
+ "id": "pf1HrIN9LtZ2"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "X = get_distance_matrix(time_series_data)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "POySDwjOFk84",
+ "outputId": "b25ec455-8b28-4788-f618-fb817cc4bc04"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "(28, 499995, 6)\n",
+ "(499995, 28, 28)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])\n",
+ "X.shape"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1IdWi25aVUOL",
+ "outputId": "a6633c89-5894-454a-f407-160435c2cd4a"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(499995, 1, 28, 28)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 12
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# chronological 80/20 split (no shuffling) into train and test sets\n",
+ "N = int(0.8 * len(X))\n",
+ "X_train, X_test = X[:N], X[N:]\n",
+ "X_train.shape, X_test.shape"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "7axvnxSgGY82",
+ "outputId": "281d0656-cb5a-465c-e3ae-d8d034715a81"
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "((399996, 1, 28, 28), (99999, 1, 28, 28))"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 13
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "import functools\n",
+ "\n",
+ "from torch.optim import Adam\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "from torchvision import datasets, transforms\n",
+ "from torchvision.datasets import MNIST\n",
+ "import tqdm"
+ ],
+ "metadata": {
+ "id": "rr9h4xIyKxg3"
+ },
+ "execution_count": 14,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "YyQtV7155Nht"
+ },
+ "source": [
+ "#@title Defining a time-dependent score-based model (double click to expand or collapse)\n",
+ "\n",
+ "class GaussianFourierProjection(nn.Module):\n",
+ "    \"\"\"Gaussian random features for encoding time steps.\n",
+ "\n",
+ "    Maps a batch of scalars t to [sin(2*pi*W*t), cos(2*pi*W*t)] with fixed\n",
+ "    random frequencies W, giving an embed_dim-dimensional embedding.\n",
+ "    \"\"\"\n",
+ "    def __init__(self, embed_dim, scale=30.):\n",
+ "        super().__init__()\n",
+ "        # Randomly sample weights during initialization. These weights are fixed\n",
+ "        # during optimization and are not trainable.\n",
+ "        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)\n",
+ "    def forward(self, x):\n",
+ "        # x: (batch,) time values -> (batch, embed_dim) features\n",
+ "        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi\n",
+ "        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)\n",
+ "\n",
+ "\n",
+ "class Dense(nn.Module):\n",
+ "    \"\"\"A fully connected layer that reshapes outputs to feature maps.\n",
+ "\n",
+ "    The (batch, output_dim) linear output is returned as\n",
+ "    (batch, output_dim, 1, 1) so it broadcasts over NCHW conv feature maps.\n",
+ "    \"\"\"\n",
+ "    def __init__(self, input_dim, output_dim):\n",
+ "        super().__init__()\n",
+ "        self.dense = nn.Linear(input_dim, output_dim)\n",
+ "    def forward(self, x):\n",
+ "        # append two singleton spatial dims for broadcasting\n",
+ "        return self.dense(x)[..., None, None]\n",
+ "\n",
+ "\n",
+ "class ScoreNet(nn.Module):\n",
+ "    \"\"\"A time-dependent score-based model built upon U-Net architecture.\n",
+ "\n",
+ "    NOTE(review): the stride/kernel arithmetic of the encoder, decoder and\n",
+ "    skip connections only lines up for 28x28 single-channel inputs (the\n",
+ "    shape used in this notebook) — confirm before reusing elsewhere.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):\n",
+ "        \"\"\"Initialize a time-dependent score-based network.\n",
+ "\n",
+ "        Args:\n",
+ "            marginal_prob_std: A function that takes time t and gives the standard\n",
+ "                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).\n",
+ "            channels: The number of channels for feature maps of each resolution.\n",
+ "            embed_dim: The dimensionality of Gaussian random feature embeddings.\n",
+ "        \"\"\"\n",
+ "        super().__init__()\n",
+ "        # Gaussian random feature embedding layer for time\n",
+ "        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),\n",
+ "                                   nn.Linear(embed_dim, embed_dim))\n",
+ "        # Encoding layers where the resolution decreases\n",
+ "        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)\n",
+ "        self.dense1 = Dense(embed_dim, channels[0])\n",
+ "        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])\n",
+ "        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)\n",
+ "        self.dense2 = Dense(embed_dim, channels[1])\n",
+ "        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])\n",
+ "        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)\n",
+ "        self.dense3 = Dense(embed_dim, channels[2])\n",
+ "        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])\n",
+ "        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)\n",
+ "        self.dense4 = Dense(embed_dim, channels[3])\n",
+ "        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])\n",
+ "\n",
+ "        # Decoding layers where the resolution increases\n",
+ "        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)\n",
+ "        self.dense5 = Dense(embed_dim, channels[2])\n",
+ "        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])\n",
+ "        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)\n",
+ "        self.dense6 = Dense(embed_dim, channels[1])\n",
+ "        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])\n",
+ "        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)\n",
+ "        self.dense7 = Dense(embed_dim, channels[0])\n",
+ "        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])\n",
+ "        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)\n",
+ "\n",
+ "        # The swish activation function\n",
+ "        self.act = lambda x: x * torch.sigmoid(x)\n",
+ "        self.marginal_prob_std = marginal_prob_std\n",
+ "\n",
+ "    def forward(self, x, t):\n",
+ "        # Obtain the Gaussian random feature embedding for t\n",
+ "        embed = self.act(self.embed(t))\n",
+ "        # Encoding path\n",
+ "        h1 = self.conv1(x)\n",
+ "        ## Incorporate information from t\n",
+ "        h1 += self.dense1(embed)\n",
+ "        ## Group normalization\n",
+ "        h1 = self.gnorm1(h1)\n",
+ "        h1 = self.act(h1)\n",
+ "        h2 = self.conv2(h1)\n",
+ "        h2 += self.dense2(embed)\n",
+ "        h2 = self.gnorm2(h2)\n",
+ "        h2 = self.act(h2)\n",
+ "        h3 = self.conv3(h2)\n",
+ "        h3 += self.dense3(embed)\n",
+ "        h3 = self.gnorm3(h3)\n",
+ "        h3 = self.act(h3)\n",
+ "        h4 = self.conv4(h3)\n",
+ "        h4 += self.dense4(embed)\n",
+ "        h4 = self.gnorm4(h4)\n",
+ "        h4 = self.act(h4)\n",
+ "\n",
+ "        # Decoding path\n",
+ "        h = self.tconv4(h4)\n",
+ "        ## Skip connection from the encoding path\n",
+ "        h += self.dense5(embed)\n",
+ "        h = self.tgnorm4(h)\n",
+ "        h = self.act(h)\n",
+ "        h = self.tconv3(torch.cat([h, h3], dim=1))\n",
+ "        h += self.dense6(embed)\n",
+ "        h = self.tgnorm3(h)\n",
+ "        h = self.act(h)\n",
+ "        h = self.tconv2(torch.cat([h, h2], dim=1))\n",
+ "        h += self.dense7(embed)\n",
+ "        h = self.tgnorm2(h)\n",
+ "        h = self.act(h)\n",
+ "        h = self.tconv1(torch.cat([h, h1], dim=1))\n",
+ "\n",
+ "        # Normalize output by the marginal std so predictions scale like the\n",
+ "        # true score (the training loss targets score * std = -z)\n",
+ "        h = h / self.marginal_prob_std(t)[:, None, None, None]\n",
+ "        return h"
+ ],
+ "execution_count": 15,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Set up the SDE\n",
+ "\n",
+ "device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}\n",
+ "\n",
+ "def marginal_prob_std(t, sigma):\n",
+ "    \"\"\"Compute the standard deviation of p_{0t}(x(t) | x(0)).\n",
+ "\n",
+ "    Args:\n",
+ "        t: A vector of time steps.\n",
+ "        sigma: The sigma in our SDE.\n",
+ "\n",
+ "    Returns:\n",
+ "        The standard deviation sqrt((sigma^(2t) - 1) / (2 log sigma)).\n",
+ "    \"\"\"\n",
+ "    # as_tensor avoids the copy (and the UserWarning) that torch.tensor\n",
+ "    # emits when t is already a tensor, as it is during training\n",
+ "    t = torch.as_tensor(t, device=device)\n",
+ "    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))\n",
+ "\n",
+ "def diffusion_coeff(t, sigma):\n",
+ "    \"\"\"Compute the diffusion coefficient of our SDE.\n",
+ "\n",
+ "    Args:\n",
+ "        t: A vector of time steps.\n",
+ "        sigma: The sigma in our SDE.\n",
+ "\n",
+ "    Returns:\n",
+ "        The vector of diffusion coefficients sigma^t.\n",
+ "    \"\"\"\n",
+ "    # sigma**t is already a tensor when t is one; as_tensor avoids re-wrapping\n",
+ "    return torch.as_tensor(sigma**t, device=device)\n",
+ "\n",
+ "sigma = 25.0  #@param {'type':'number'}\n",
+ "marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)\n",
+ "diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)"
+ ],
+ "metadata": {
+ "id": "QQGWsLzUTEsB"
+ },
+ "execution_count": 16,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zOsoqPdXHuL5"
+ },
+ "source": [
+ "#@title Define the loss function (double click to expand or collapse)\n",
+ "\n",
+ "def loss_fn(model, x, marginal_prob_std, eps=1e-5):\n",
+ "    \"\"\"The loss function for training score-based generative models.\n",
+ "\n",
+ "    Denoising score matching: perturb x with Gaussian noise z scaled by the\n",
+ "    marginal std at a random time t, then penalize the residual between the\n",
+ "    predicted score (times std) and -z.\n",
+ "\n",
+ "    Args:\n",
+ "        model: A PyTorch model instance that represents a\n",
+ "            time-dependent score-based model.\n",
+ "        x: A mini-batch of training data.\n",
+ "        marginal_prob_std: A function that gives the standard deviation of\n",
+ "            the perturbation kernel.\n",
+ "        eps: A tolerance value for numerical stability.\n",
+ "    \"\"\"\n",
+ "    # t ~ Uniform(eps, 1): eps keeps t away from 0, where the std vanishes\n",
+ "    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps\n",
+ "    z = torch.randn_like(x)\n",
+ "    std = marginal_prob_std(random_t)\n",
+ "    perturbed_x = x + z * std[:, None, None, None]\n",
+ "    score = model(perturbed_x, random_t)\n",
+ "    # squared residual summed over channel/spatial dims, averaged over batch;\n",
+ "    # minimized when score = -z / std\n",
+ "    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))\n",
+ "    return loss"
+ ],
+ "execution_count": 17,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "8PPsLx4dGCGa",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 208,
+ "referenced_widgets": [
+ "6f4478ea94964503aed6695c389cd8d4",
+ "0ddcab2b812f48558a8cdfc2aa22abee",
+ "c2a0b4a38be5484fa2992c6d2b3154e9",
+ "3a6bee3c17884ae6bd5655cc75fdacd3",
+ "8b7219c434164f58a21b788ee3d0d1e1",
+ "ee984355aa894cd2a0b68edc85823036",
+ "424d7a61642b44ad86d01dd88bc04d87",
+ "129881e1ff164c218164b8f6466b8bf8",
+ "017c61b626974a5681acb5fd030b927c",
+ "9ac6148f92494e37990294a60a63164c",
+ "518bcf6e325b410292f18771b5f0ab1c"
+ ]
+ },
+ "outputId": "6d2dd74d-7a80-4cc1-d302-f820e5a6c5e1"
+ },
+ "source": [
+ "#@title Training (double click to expand or collapse)\n",
+ "\n",
+ "score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))\n",
+ "score_model = score_model.to(device)\n",
+ "\n",
+ "n_epochs = 50#@param {'type':'integer'}\n",
+ "## size of a mini-batch\n",
+ "batch_size = 32 #@param {'type':'integer'}\n",
+ "## learning rate\n",
+ "lr=1e-4 #@param {'type':'number'}\n",
+ "\n",
+ "# dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)\n",
+ "dataset = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(X_train).float())\n",
+ "data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
+ "\n",
+ "optimizer = Adam(score_model.parameters(), lr=lr)\n",
+ "tqdm_epoch = tqdm.notebook.trange(n_epochs)\n",
+ "for epoch in tqdm_epoch:\n",
+ " avg_loss = 0.\n",
+ " num_items = 0\n",
+ " for x, y in data_loader:\n",
+ " x = x.to(device)\n",
+ " loss = loss_fn(score_model, x, marginal_prob_std_fn)\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " avg_loss += loss.item() * x.shape[0]\n",
+ " num_items += x.shape[0]\n",
+ " # Print the averaged training loss so far.\n",
+ " tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))\n",
+ " # Update the checkpoint after each epoch of training.\n",
+ " torch.save(score_model.state_dict(), 'ckpt.pth')"
+ ],
+ "execution_count": 18,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
+ " warnings.warn(_create_warning_msg(\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " 0%| | 0/50 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "6f4478ea94964503aed6695c389cd8d4"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ ":15: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " t = torch.tensor(t, device=device)\n",
+ "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6FxBTOSSH2QR"
+ },
+ "source": [
#@title Define the Euler-Maruyama sampler (double click to expand or collapse)

## The number of sampling steps.
num_steps = 500#@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
                           init_x,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3):
  """Generate samples from score-based models with the Euler-Maruyama solver.

  Integrates the reverse-time SDE from t=1 down to t=eps with a fixed step.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    init_x: The initial state; it is scaled by the t=1 marginal standard
      deviation before integration starts.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: The number of samplers to generate by calling this function once.
    num_steps: The number of sampling steps.
      Equivalent to the number of discretized time steps.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples.
  """
  ones = torch.ones(batch_size, device=device)
  # Scale the starting state to the prior std at t = 1.
  x = init_x * marginal_prob_std(ones)[:, None, None, None]
  time_steps = torch.linspace(1., eps, num_steps, device=device)
  step_size = time_steps[0] - time_steps[1]
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):
      batch_time_step = ones * time_step
      g = diffusion_coeff(batch_time_step)[:, None, None, None]
      # Drift update (the mean of the next state) ...
      mean_x = x + g**2 * score_model(x, batch_time_step) * step_size
      # ... plus the diffusion noise for the stochastic update.
      x = mean_x + torch.sqrt(step_size) * g * torch.randn_like(x)
  # Do not include any noise in the last sampling step.
  return mean_x
+ ],
+ "execution_count": 19,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
#@title Define the Predictor-Corrector sampler (double click to expand or collapse)

signal_to_noise_ratio = 0.16 #@param {'type':'number'}

## The number of sampling steps.
num_steps = 500#@param {'type':'integer'}
def pc_sampler(score_model,
               init_x,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=num_steps,
               snr=signal_to_noise_ratio,
               device='cuda',
               eps=1e-3):
  """Generate samples from score-based models with the Predictor-Corrector method.

  Each iteration runs one Langevin-MCMC corrector step whose step size is
  set from the requested signal-to-noise ratio, followed by one reverse-time
  Euler-Maruyama predictor step.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    init_x: The initial state; it is scaled by the t=1 marginal standard
      deviation before sampling starts.
    marginal_prob_std: A function that gives the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient
      of the SDE.
    batch_size: The number of samplers to generate by calling this function once.
    num_steps: The number of sampling steps.
      Equivalent to the number of discretized time steps.
    snr: The signal-to-noise ratio used to size the Langevin corrector step.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples.
  """
  ones = torch.ones(batch_size, device=device)
  x = init_x * marginal_prob_std(ones)[:, None, None, None]
  time_steps = np.linspace(1., eps, num_steps)
  step_size = time_steps[0] - time_steps[1]
  # sqrt of the per-sample dimensionality -- the reference noise magnitude;
  # loop-invariant, so compute it once.
  noise_norm = np.sqrt(np.prod(x.shape[1:]))
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):
      batch_time_step = ones * time_step

      # Corrector step (Langevin MCMC)
      score = score_model(x, batch_time_step)
      score_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean()
      langevin_step_size = 2 * (snr * noise_norm / score_norm)**2
      x = x + langevin_step_size * score + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

      # Predictor step (Euler-Maruyama)
      g = diffusion_coeff(batch_time_step)
      x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
      x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)

  # The last step does not include any noise
  return x_mean
+ ],
+ "metadata": {
+ "id": "bd_FC442w1Pq"
+ },
+ "execution_count": 42,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
#@title Define the ODE sampler (double click to expand or collapse)

from scipy import integrate

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
                init_x,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cuda',
                z=None,
                eps=1e-3):
  """Generate samples from score-based models with black-box ODE solvers.

  Solves the probability-flow ODE from t=1 down to t=eps with scipy's RK45.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    init_x: The starting state used when `z` is None; it is scaled by the
      t=1 marginal standard deviation.
    marginal_prob_std: A function that returns the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samplers to generate by calling this function once.
    atol: Tolerance of absolute errors.
    rtol: Tolerance of relative errors.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    z: The latent code that governs the final sample. If None, we start from p_1;
      otherwise, we start from the given z.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples, shaped like the starting state.
  """
  t = torch.ones(batch_size, device=device)
  # Choose the starting point of the reverse-time integration.
  if z is not None:
    init_x = z
  else:
    init_x = init_x * marginal_prob_std(t)[:, None, None, None]

  shape = init_x.shape

  def eval_score(sample, time_steps):
    """Evaluate the score model on flat numpy input, returning flat numpy."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
    with torch.no_grad():
      score = score_model(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)

  def ode_func(t, x):
    """Right-hand side of the probability-flow ODE for the solver."""
    time_steps = np.ones((shape[0],)) * t
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    return -0.5 * (g**2) * eval_score(x, time_steps)

  # Run the black-box ODE solver.
  solution = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(),
                                 rtol=rtol, atol=atol, method='RK45')
  print(f"Number of function evaluations: {solution.nfev}")
  return torch.tensor(solution.y[:, -1], device=device).reshape(shape)
+ ],
+ "metadata": {
+ "id": "rsx-ik6Jw4iv"
+ },
+ "execution_count": 43,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kKoAPnr7Pf2B",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 606
+ },
+ "outputId": "73a1e373-f55e-4df3-a77b-10c5ad81f3cf"
+ },
+ "source": [
+ "#@title Sampling (double click to expand or collapse)\n",
+ "\n",
+ "from torchvision.utils import make_grid\n",
+ "\n",
+ "## Load the pre-trained checkpoint from disk.\n",
+ "device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}\n",
+ "\n",
+ "ckpt = torch.load('ckpt.pth', map_location=device)\n",
+ "score_model.load_state_dict(ckpt)\n",
+ "\n",
+ "sample_batch_size = 64 #@param {'type':'integer'}\n",
+ "sampler = ode_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}\n",
+ "\n",
+ "init_x = torch.tensor(X_test[:sample_batch_size].copy(), device=device).float()\n",
+ "# init_x = torch.randn(sample_batch_size, 1, 28, 28, device=device)\n",
+ "\n",
+ "## Generate samples using the specified sampler.\n",
+ "samples = sampler(score_model,\n",
+ " init_x,\n",
+ " marginal_prob_std_fn,\n",
+ " diffusion_coeff_fn,\n",
+ " sample_batch_size,\n",
+ " device=device)\n",
+ "test_sampled = samples\n",
+ "## Sample visualization.\n",
+ "samples = samples.clamp(0.0, 1.0)\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))\n",
+ "\n",
+ "plt.figure(figsize=(6,6))\n",
+ "plt.axis('off')\n",
+ "plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)\n",
+ "plt.show()"
+ ],
+ "execution_count": 62,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":15: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " t = torch.tensor(t, device=device)\n",
+ ":28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return torch.tensor(sigma**t, device=device)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Number of function evaluations: 290\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "