From 0e9afa08ef4af4d757d20a144afaa2069e517501 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 10:22:10 +0100 Subject: [PATCH 01/14] Towards keras3 --- CHANGELOG.md | 4 +++ tslearn/shapelets/shapelets.py | 58 ++++++++++++++++++---------------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 542df512..b5d2d37a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ Changelogs for this project are recorded in this file since v0.2.0. ## [Towards v0.7] +### Changed + +* The `shapelets` module now depends on Keras3+ (should be keras-backend-blind) and not anymore on TF + ## [v0.6.3] ### Changed diff --git a/tslearn/shapelets/shapelets.py b/tslearn/shapelets/shapelets.py index 50476a2b..6e46aa0c 100644 --- a/tslearn/shapelets/shapelets.py +++ b/tslearn/shapelets/shapelets.py @@ -1,19 +1,20 @@ -from tensorflow.keras.models import Model, model_from_json -from tensorflow.keras.layers import (InputSpec, Dense, Conv1D, Layer, Input, - concatenate, add) -from tensorflow.keras.metrics import (categorical_accuracy, - categorical_crossentropy, - binary_accuracy, binary_crossentropy) -from tensorflow.keras.utils import to_categorical +import keras +from keras.models import Model, model_from_json +from keras.layers import (InputSpec, Dense, Conv1D, Layer, Input, + concatenate, add) +from keras.metrics import (categorical_accuracy, + categorical_crossentropy, + binary_accuracy, binary_crossentropy) +from keras.utils import to_categorical +from keras.regularizers import l2 +from keras.initializers import Initializer +import keras.ops as ops + from sklearn.base import ClassifierMixin, TransformerMixin from sklearn.utils import check_array, check_X_y from sklearn.utils.validation import check_is_fitted from sklearn.utils.multiclass import unique_labels -from tensorflow.keras.regularizers import l2 -from tensorflow.keras.initializers import Initializer -import tensorflow.keras.backend as K import numpy -import tensorflow as tf import warnings @@ -35,28 +36,31 @@ class GlobalMinPooling1D(Layer): Examples -------- - >>> x = tf.constant([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf]) - >>> x = tf.reshape(x, [1, 5, 1]) + >>> x = numpy.array([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf]) + >>> x = x.reshape([1, 5, 1]) >>> GlobalMinPooling1D()(x).numpy() array([[5.]], dtype=float32) """ - def __init__(self, **kwargs): + def __init__(self, data_format=None, keepdims=False, **kwargs): super().__init__(**kwargs) + + self.data_format = ( + "channels_last" if data_format is None else data_format + ) + self.keepdims = keepdims self.input_spec = InputSpec(ndim=3) def compute_output_shape(self, input_shape): return input_shape[0], input_shape[2] - def call(self, inputs, **kwargs): - inputs_without_nans = tf.where(tf.math.is_finite(inputs), - inputs, - tf.zeros_like(inputs) + numpy.inf) - return tf.reduce_min(inputs_without_nans, axis=1) + def call(self, inputs): + steps_axis = 1 if self.data_format == "channels_last" else 2 + return ops.min(inputs, axis=steps_axis, keepdims=self.keepdims) class GlobalArgminPooling1D(Layer): - """Global min pooling operation for temporal data. + """Global argmin pooling operation for temporal data. # Input shape 3D tensor with shape: `(batch_size, steps, features)`. # Output shape @@ -72,7 +76,7 @@ def compute_output_shape(self, input_shape): return input_shape[0], input_shape[2] def call(self, inputs, **kwargs): - return K.cast(K.argmin(inputs, axis=1), dtype=K.floatx()) + return ops.cast(ops.argmin(inputs, axis=1), dtype=float) def _kmeans_init_shapelets(X, n_shapelets, shp_len, n_draw=10000): @@ -106,7 +110,7 @@ def __call__(self, shape, dtype=None): shapelets = _kmeans_init_shapelets(self.X_, n_shapelets, shp_len)[:, :, 0] - return tf.convert_to_tensor(shapelets, dtype=K.floatx()) + return ops.convert_to_tensor(shapelets, dtype=float) def get_config(self): return {'data': self.X_} @@ -140,11 +144,11 @@ def build(self, input_shape): def call(self, x, **kwargs): # (x - y)^2 = x^2 + y^2 - 2 * x * y - x_sq = K.expand_dims(K.sum(x ** 2, axis=2), axis=-1) - y_sq = K.reshape(K.sum(self.kernel ** 2, axis=1), + x_sq = ops.expand_dims(ops.sum(x ** 2, axis=2), axis=-1) + y_sq = ops.reshape(ops.sum(self.kernel ** 2, axis=1), (1, 1, self.n_shapelets)) - xy = K.dot(x, K.transpose(self.kernel)) - return (x_sq + y_sq - 2 * xy) / K.int_shape(self.kernel)[1] + xy = ops.dot(x, ops.transpose(self.kernel)) + return (x_sq + y_sq - 2 * xy) / ops.shape(self.kernel)[1] def compute_output_shape(self, input_shape): return input_shape[0], input_shape[1], self.n_shapelets @@ -426,7 +430,7 @@ def fit(self, X, y): self._check_series_length(X) if self.random_state is not None: - tf.keras.utils.set_random_seed(seed=self.random_state) + keras.utils.set_random_seed(seed=self.random_state) n_ts, sz, d = X.shape self._X_fit_dims = X.shape From 6e7c0de2087e6f01f2068be183a980ff667cf346 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 10:30:57 +0100 Subject: [PATCH 02/14] edit requirements --- azure-pipelines.yml | 2 +- docs/requirements_rtd.txt | 2 +- requirements.txt | 2 +- requirements_nocast.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cc04f70c..920b4f48 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -42,7 +42,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install tensorflow==2.9.0 + python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules displayName: 'Test' diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 6859395e..b1ab951c 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -8,7 +8,7 @@ ipykernel nbsphinx sphinx-gallery pillow -tensorflow>=2 +keras>=3 Pygments numba sphinx_bootstrap_theme diff --git a/requirements.txt b/requirements.txt index b23ded6e..79ce5ae6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ numba scipy scikit-learn joblib>=0.12 -tensorflow>=2 +keras>=3 pandas cesium h5py diff --git a/requirements_nocast.txt b/requirements_nocast.txt index 270f87f5..2d71623d 100644 --- a/requirements_nocast.txt +++ b/requirements_nocast.txt @@ -3,5 +3,5 @@ numba scipy scikit-learn joblib>=0.12 -tensorflow>=2 +keras>=3 h5py From 7c72f13134982c14dc4bf0a8259dc9c86f003af3 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 10:39:53 +0100 Subject: [PATCH 03/14] keras3 not available for Python 3.8 --- .readthedocs.yml | 2 +- azure-pipelines.yml | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 44d7f569..302155f1 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.9" # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 920b4f48..114abdbd 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -10,8 +10,8 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: - Python38: - python.version: '3.8' + # Python38: + # python.version: '3.8' Python39: python.version: '3.9' variables: @@ -81,7 +81,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install tensorflow==2.9.0 + python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_' displayName: 'Test' @@ -124,7 +124,7 @@ jobs: python -m pip install coverage pytest-cov python -m pip install cesium pandas stumpy python -m pip install scikit-learn==1.2 - python -m pip install tensorflow==2.9.0 + python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules --cov=tslearn displayName: 'Test' @@ -139,8 +139,8 @@ jobs: vmImage: 'macOS-12' strategy: matrix: - Python38: - python.version: '3.8' + # Python38: + # python.version: '3.8' Python39: python.version: '3.9' Python310: @@ -176,7 +176,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install tensorflow==2.9.0 + python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators' displayName: 'Test' @@ -186,9 +186,9 @@ jobs: vmImage: 'windows-latest' strategy: matrix: - Python38: - python_ver: '38' - python.version: '3.8' + # Python38: + # python_ver: '38' + # python.version: '3.8' Python39: python_ver: '39' python.version: '3.9' @@ -219,6 +219,6 @@ jobs: - script: | python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install tensorflow==2.9.0 + python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py displayName: 'Test' From c25b6b75f515cc9e84f5dcfacf76d5544faa67f5 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 10:50:18 +0100 Subject: [PATCH 04/14] add tensorflow --- docs/requirements_rtd.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index b1ab951c..30e94ee3 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -9,6 +9,7 @@ nbsphinx sphinx-gallery pillow keras>=3 +tensorflow Pygments numba sphinx_bootstrap_theme From b28cb6374266e3425fffb54c1e4dd6c28319cd79 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 10:54:50 +0100 Subject: [PATCH 05/14] RTD on Python 3.10 --- .readthedocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 302155f1..7a479688 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # Build documentation in the docs/ directory with Sphinx sphinx: From 8e6550e6d418de59c5c51eae1a4acda2f00c319c Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 11:12:32 +0100 Subject: [PATCH 06/14] torch instead of TF for RTD --- docs/requirements_rtd.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 30e94ee3..b1ab951c 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -9,7 +9,6 @@ nbsphinx sphinx-gallery pillow keras>=3 -tensorflow Pygments numba sphinx_bootstrap_theme From d578fd862e0b37c5efa6f7f662aa16d3ea2149e7 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 11:21:16 +0100 Subject: [PATCH 07/14] version specifier --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 114abdbd..4d360897 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -176,7 +176,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install keras>=3 tensorflow + python -m pip install keras>=3.0 tensorflow python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators' displayName: 'Test' From 32d5637cc0d836d0bba636c0e978393e9917b2dd Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 11:37:22 +0100 Subject: [PATCH 08/14] remove occurrences of TF --- .../examples/classification/plot_shapelet_distances.py | 2 +- .../examples/classification/plot_shapelet_locations.py | 2 +- docs/examples/classification/plot_shapelets.py | 4 ++-- tslearn/shapelets/__init__.py | 2 +- tslearn/tests/test_estimators.py | 2 +- tslearn/tests/test_shapelets.py | 10 +++++----- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/examples/classification/plot_shapelet_distances.py b/docs/examples/classification/plot_shapelet_distances.py index fdd9a307..a83c1b8b 100644 --- a/docs/examples/classification/plot_shapelet_distances.py +++ b/docs/examples/classification/plot_shapelet_distances.py @@ -24,7 +24,7 @@ from tslearn.datasets import CachedDatasets from tslearn.preprocessing import TimeSeriesScalerMinMax from tslearn.shapelets import LearningShapelets -from tensorflow.keras.optimizers import Adam +from keras.optimizers import Adam # Set a seed to ensure determinism numpy.random.seed(42) diff --git a/docs/examples/classification/plot_shapelet_locations.py b/docs/examples/classification/plot_shapelet_locations.py index c55987d2..288c78de 100644 --- a/docs/examples/classification/plot_shapelet_locations.py +++ b/docs/examples/classification/plot_shapelet_locations.py @@ -25,7 +25,7 @@ from tslearn.preprocessing import TimeSeriesScalerMinMax from tslearn.shapelets import LearningShapelets, \ grabocka_params_to_shapelet_size_dict -from tensorflow.keras.optimizers import Adam +from keras.optimizers import Adam # Set a seed to ensure determinism numpy.random.seed(42) diff --git a/docs/examples/classification/plot_shapelets.py b/docs/examples/classification/plot_shapelets.py index 150b70a1..d32ac777 100644 --- a/docs/examples/classification/plot_shapelets.py +++ b/docs/examples/classification/plot_shapelets.py @@ -16,7 +16,7 @@ import numpy from sklearn.metrics import accuracy_score -import tensorflow as tf +import keras import matplotlib.pyplot as plt from tslearn.datasets import CachedDatasets @@ -49,7 +49,7 @@ # Define the model using parameters provided by the authors (except that we # use fewer iterations here) shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes, - optimizer=tf.optimizers.Adam(.01), + optimizer=keras.optimizers.Adam(.01), batch_size=16, weight_regularizer=.01, max_iter=200, diff --git a/tslearn/shapelets/__init__.py b/tslearn/shapelets/__init__.py index cb5bc570..b84c11e1 100644 --- a/tslearn/shapelets/__init__.py +++ b/tslearn/shapelets/__init__.py @@ -1,7 +1,7 @@ """ The :mod:`tslearn.shapelets` module gathers Shapelet-based algorithms. -It depends on the `tensorflow` library for optimization (TF2 is required). +It depends on the `keras` library (Keras3+ is required). **User guide:** See the :ref:`Shapelets ` section for further details. diff --git a/tslearn/tests/test_estimators.py b/tslearn/tests/test_estimators.py index 08108e38..0cb62edd 100644 --- a/tslearn/tests/test_estimators.py +++ b/tslearn/tests/test_estimators.py @@ -77,7 +77,7 @@ def _get_all_classes(): # keras is likely not installed warnings.warn('Skipped common tests for shapelets ' 'as it could not be imported. keras ' - '(and tensorflow) are probably not ' + 'is probably not ' 'installed!') continue elif name.endswith('pytorch_backend'): diff --git a/tslearn/tests/test_shapelets.py b/tslearn/tests/test_shapelets.py index f58a519b..9aa3c71c 100644 --- a/tslearn/tests/test_shapelets.py +++ b/tslearn/tests/test_shapelets.py @@ -8,9 +8,9 @@ def test_shapelets(): - pytest.importorskip('tensorflow') + pytest.importorskip('keras') from tslearn.shapelets import LearningShapelets - import tensorflow as tf + import keras n, sz, d = 15, 10, 2 rng = np.random.RandomState(0) @@ -27,7 +27,7 @@ def test_shapelets(): clf = LearningShapelets(n_shapelets_per_size={2: 5}, max_iter=1, verbose=0, - optimizer=tf.optimizers.Adam(.1), + optimizer=keras.optimizers.Adam(.1), random_state=0) cross_validate(clf, time_series, y, cv=2) @@ -62,7 +62,7 @@ def test_shapelets(): def test_shapelet_lengths(): - pytest.importorskip('tensorflow') + pytest.importorskip('keras') from tslearn.shapelets import LearningShapelets # Test variable-length @@ -97,7 +97,7 @@ def test_shapelet_lengths(): def test_series_lengths(): - pytest.importorskip('tensorflow') + pytest.importorskip('keras') from tslearn.shapelets import LearningShapelets # Test long shapelets From df661c631016a5e72e3d29113efd6d105fee8b0e Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 11:49:07 +0100 Subject: [PATCH 09/14] fix adam LR --- docs/examples/classification/plot_shapelet_distances.py | 2 +- docs/examples/classification/plot_shapelet_locations.py | 2 +- docs/examples/classification/plot_shapelets.py | 2 +- tslearn/tests/test_shapelets.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/examples/classification/plot_shapelet_distances.py b/docs/examples/classification/plot_shapelet_distances.py index a83c1b8b..c1199556 100644 --- a/docs/examples/classification/plot_shapelet_distances.py +++ b/docs/examples/classification/plot_shapelet_distances.py @@ -45,7 +45,7 @@ # Define the model and fit it using the training data shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes, weight_regularizer=0.0001, - optimizer=Adam(lr=0.01), + optimizer=Adam(learning_rate=0.01), max_iter=300, verbose=0, scale=False, diff --git a/docs/examples/classification/plot_shapelet_locations.py b/docs/examples/classification/plot_shapelet_locations.py index 288c78de..556af225 100644 --- a/docs/examples/classification/plot_shapelet_locations.py +++ b/docs/examples/classification/plot_shapelet_locations.py @@ -51,7 +51,7 @@ # Define the model and fit it using the training data shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes, weight_regularizer=0.001, - optimizer=Adam(lr=0.01), + optimizer=Adam(learning_rate=0.01), max_iter=250, verbose=0, scale=False, diff --git a/docs/examples/classification/plot_shapelets.py b/docs/examples/classification/plot_shapelets.py index d32ac777..202b5eba 100644 --- a/docs/examples/classification/plot_shapelets.py +++ b/docs/examples/classification/plot_shapelets.py @@ -49,7 +49,7 @@ # Define the model using parameters provided by the authors (except that we # use fewer iterations here) shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes, - optimizer=keras.optimizers.Adam(.01), + optimizer=keras.optimizers.Adam(learning_rate=.01), batch_size=16, weight_regularizer=.01, max_iter=200, diff --git a/tslearn/tests/test_shapelets.py b/tslearn/tests/test_shapelets.py index 9aa3c71c..466f10b5 100644 --- a/tslearn/tests/test_shapelets.py +++ b/tslearn/tests/test_shapelets.py @@ -27,7 +27,7 @@ def test_shapelets(): clf = LearningShapelets(n_shapelets_per_size={2: 5}, max_iter=1, verbose=0, - optimizer=keras.optimizers.Adam(.1), + optimizer=keras.optimizers.Adam(learning_rate=.1), random_state=0) cross_validate(clf, time_series, y, cv=2) From f2ea5f6892602a61a28053396270a314521f5823 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 12:13:53 +0100 Subject: [PATCH 10/14] remove TF in tests --- azure-pipelines.yml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 4d360897..54beba3f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -42,8 +42,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install keras>=3 tensorflow - python -m pytest -v tslearn/ --doctest-modules + export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules displayName: 'Test' - job: 'linux_without_torch' @@ -80,8 +79,6 @@ jobs: - script: | set -xe python -m pip install pytest pytest-azurepipelines - python -m pip install scikit-learn==1.2 - python -m pip install keras>=3 tensorflow python -m pytest -v tslearn/ --doctest-modules -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_' displayName: 'Test' @@ -124,8 +121,7 @@ jobs: python -m pip install coverage pytest-cov python -m pip install cesium pandas stumpy python -m pip install scikit-learn==1.2 - python -m pip install keras>=3 tensorflow - python -m pytest -v tslearn/ --doctest-modules --cov=tslearn + export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --cov=tslearn displayName: 'Test' # Upload coverage to codecov.io @@ -176,8 +172,7 @@ jobs: set -xe python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install keras>=3.0 tensorflow - python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators' + export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators' displayName: 'Test' @@ -219,6 +214,5 @@ jobs: - script: | python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - python -m pip install keras>=3 tensorflow - python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py + export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py displayName: 'Test' From 7dba779eca059503d42f9f76b01c98b6347e36de Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 14:00:32 +0100 Subject: [PATCH 11/14] properly deal with nans in keras-torch --- tslearn/shapelets/shapelets.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tslearn/shapelets/shapelets.py b/tslearn/shapelets/shapelets.py index 6e46aa0c..6b000ffa 100644 --- a/tslearn/shapelets/shapelets.py +++ b/tslearn/shapelets/shapelets.py @@ -56,7 +56,10 @@ def compute_output_shape(self, input_shape): def call(self, inputs): steps_axis = 1 if self.data_format == "channels_last" else 2 - return ops.min(inputs, axis=steps_axis, keepdims=self.keepdims) + inputs_without_nans = ops.where(ops.isfinite(inputs), + inputs, + ops.zeros_like(inputs) + ops.max(inputs[ops.isfinite(inputs)])) + return ops.min(inputs_without_nans, axis=steps_axis, keepdims=self.keepdims) class GlobalArgminPooling1D(Layer): @@ -66,6 +69,13 @@ class GlobalArgminPooling1D(Layer): # Output shape 2D tensor with shape: `(batch_size, features)` + + Examples + -------- + >>> x = numpy.array([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf]) + >>> x = x.reshape([1, 5, 1]) + >>> GlobalArgminPooling1D()(x).numpy() + array([[0.]], dtype=float32) """ def __init__(self, **kwargs): @@ -76,7 +86,10 @@ def compute_output_shape(self, input_shape): return input_shape[0], input_shape[2] def call(self, inputs, **kwargs): - return ops.cast(ops.argmin(inputs, axis=1), dtype=float) + inputs_without_nans = ops.where(ops.isfinite(inputs), + inputs, + ops.zeros_like(inputs) + ops.max(inputs[ops.isfinite(inputs)])) + return ops.cast(ops.argmin(inputs_without_nans, axis=1), dtype=float) def _kmeans_init_shapelets(X, n_shapelets, shp_len, n_draw=10000): @@ -571,7 +584,7 @@ def locate(self, X): >>> X[0, 4:7, 0] = numpy.array([1, 2, 3]) >>> y = [1, 0, 0] >>> # Data is all zeros except a motif 1-2-3 in the first time series - >>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=0, + >>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=1, ... verbose=0) >>> _ = clf.fit(X, y) >>> weights_shapelet = [ @@ -784,7 +797,7 @@ def get_weights(self, layer_name=None): -------- >>> from tslearn.generators import random_walk_blobs >>> X, y = random_walk_blobs(n_ts_per_blob=100, sz=256, d=1, n_blobs=3) - >>> clf = LearningShapelets(n_shapelets_per_size={10: 5}, max_iter=0, + >>> clf = LearningShapelets(n_shapelets_per_size={10: 5}, max_iter=1, ... verbose=0) >>> clf.fit(X, y).get_weights("classification")[0].shape (5, 3) @@ -820,7 +833,7 @@ def set_weights(self, weights, layer_name=None): -------- >>> from tslearn.generators import random_walk_blobs >>> X, y = random_walk_blobs(n_ts_per_blob=10, sz=16, d=1, n_blobs=3) - >>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=0, + >>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=1, ... verbose=0) >>> _ = clf.fit(X, y) >>> weights_shapelet = [ From 84a66ea7e2f424b9b717a253f51d0de50755df33 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 14:14:43 +0100 Subject: [PATCH 12/14] Windows equivalent to export is set --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 54beba3f..84925523 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -214,5 +214,5 @@ jobs: - script: | python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py + set KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py displayName: 'Test' From f2cfa2799aa3a45fe9834b68d594df7abf1c673d Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 15:12:25 +0100 Subject: [PATCH 13/14] do not test shapelets without torch --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 84925523..b08d4233 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -79,7 +79,7 @@ jobs: - script: | set -xe python -m pip install pytest pytest-azurepipelines - python -m pytest -v tslearn/ --doctest-modules -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_' + python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/shapelets/ -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_' displayName: 'Test' From aecce61991026c705ecc9412da52df0e04202452 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 19 Jan 2024 15:21:22 +0100 Subject: [PATCH 14/14] bugfix --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b08d4233..347525f6 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -214,5 +214,5 @@ jobs: - script: | python -m pip install pytest pytest-azurepipelines python -m pip install scikit-learn==1.2 - set KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py + set KERAS_BACKEND=torch && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py displayName: 'Test'