Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards keras3 #505

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.8"
python: "3.10"

# Build documentation in the docs/ directory with Sphinx
sphinx:
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Changelogs for this project are recorded in this file since v0.2.0.

## [Towards v0.7]

### Changed

* The `shapelets` module now depends on Keras3+ (should be keras-backend-blind) and not anymore on TF

## [v0.6.3]

### Changed
Expand Down
30 changes: 12 additions & 18 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ jobs:
vmImage: 'ubuntu-latest'
strategy:
matrix:
Python38:
python.version: '3.8'
# Python38:
# python.version: '3.8'
Python39:
python.version: '3.9'
variables:
Expand Down Expand Up @@ -42,8 +42,7 @@ jobs:
set -xe
python -m pip install pytest pytest-azurepipelines
python -m pip install scikit-learn==1.2
python -m pip install tensorflow==2.9.0
python -m pytest -v tslearn/ --doctest-modules
export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules
displayName: 'Test'

- job: 'linux_without_torch'
Expand Down Expand Up @@ -80,9 +79,7 @@ jobs:
- script: |
set -xe
python -m pip install pytest pytest-azurepipelines
python -m pip install scikit-learn==1.2
python -m pip install tensorflow==2.9.0
python -m pytest -v tslearn/ --doctest-modules -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_'
python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/shapelets/ -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_'
displayName: 'Test'


Expand Down Expand Up @@ -124,8 +121,7 @@ jobs:
python -m pip install coverage pytest-cov
python -m pip install cesium pandas stumpy
python -m pip install scikit-learn==1.2
python -m pip install tensorflow==2.9.0
python -m pytest -v tslearn/ --doctest-modules --cov=tslearn
export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules --cov=tslearn
displayName: 'Test'

# Upload coverage to codecov.io
Expand All @@ -139,8 +135,8 @@ jobs:
vmImage: 'macOS-12'
strategy:
matrix:
Python38:
python.version: '3.8'
# Python38:
# python.version: '3.8'
Python39:
python.version: '3.9'
Python310:
Expand Down Expand Up @@ -176,8 +172,7 @@ jobs:
set -xe
python -m pip install pytest pytest-azurepipelines
python -m pip install scikit-learn==1.2
python -m pip install tensorflow==2.9.0
python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators'
export KERAS_BACKEND="torch" && python -m pytest -v tslearn/ --doctest-modules -k 'not test_all_estimators'
displayName: 'Test'


Expand All @@ -186,9 +181,9 @@ jobs:
vmImage: 'windows-latest'
strategy:
matrix:
Python38:
python_ver: '38'
python.version: '3.8'
# Python38:
# python_ver: '38'
# python.version: '3.8'
Python39:
python_ver: '39'
python.version: '3.9'
Expand Down Expand Up @@ -219,6 +214,5 @@ jobs:
- script: |
python -m pip install pytest pytest-azurepipelines
python -m pip install scikit-learn==1.2
python -m pip install tensorflow==2.9.0
python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py
set KERAS_BACKEND=torch && python -m pytest -v tslearn/ --doctest-modules --ignore tslearn/tests/test_estimators.py --ignore tslearn/utils/cast.py
displayName: 'Test'
4 changes: 2 additions & 2 deletions docs/examples/classification/plot_shapelet_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tslearn.datasets import CachedDatasets
from tslearn.preprocessing import TimeSeriesScalerMinMax
from tslearn.shapelets import LearningShapelets
from tensorflow.keras.optimizers import Adam
from keras.optimizers import Adam

# Set a seed to ensure determinism
numpy.random.seed(42)
Expand All @@ -45,7 +45,7 @@
# Define the model and fit it using the training data
shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes,
weight_regularizer=0.0001,
optimizer=Adam(lr=0.01),
optimizer=Adam(learning_rate=0.01),
max_iter=300,
verbose=0,
scale=False,
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/classification/plot_shapelet_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tslearn.preprocessing import TimeSeriesScalerMinMax
from tslearn.shapelets import LearningShapelets, \
grabocka_params_to_shapelet_size_dict
from tensorflow.keras.optimizers import Adam
from keras.optimizers import Adam

# Set a seed to ensure determinism
numpy.random.seed(42)
Expand All @@ -51,7 +51,7 @@
# Define the model and fit it using the training data
shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes,
weight_regularizer=0.001,
optimizer=Adam(lr=0.01),
optimizer=Adam(learning_rate=0.01),
max_iter=250,
verbose=0,
scale=False,
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/classification/plot_shapelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy
from sklearn.metrics import accuracy_score
import tensorflow as tf
import keras
import matplotlib.pyplot as plt

from tslearn.datasets import CachedDatasets
Expand Down Expand Up @@ -49,7 +49,7 @@
# Define the model using parameters provided by the authors (except that we
# use fewer iterations here)
shp_clf = LearningShapelets(n_shapelets_per_size=shapelet_sizes,
optimizer=tf.optimizers.Adam(.01),
optimizer=keras.optimizers.Adam(learning_rate=.01),
batch_size=16,
weight_regularizer=.01,
max_iter=200,
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements_rtd.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ipykernel
nbsphinx
sphinx-gallery
pillow
tensorflow>=2
keras>=3
Pygments
numba
sphinx_bootstrap_theme
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ numba
scipy
scikit-learn
joblib>=0.12
tensorflow>=2
keras>=3
pandas
cesium
h5py
2 changes: 1 addition & 1 deletion requirements_nocast.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ numba
scipy
scikit-learn
joblib>=0.12
tensorflow>=2
keras>=3
h5py
2 changes: 1 addition & 1 deletion tslearn/shapelets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
The :mod:`tslearn.shapelets` module gathers Shapelet-based algorithms.

It depends on the `tensorflow` library for optimization (TF2 is required).
It depends on the `keras` library (Keras3+ is required).

**User guide:** See the :ref:`Shapelets <shapelets>` section for further
details.
Expand Down
77 changes: 47 additions & 30 deletions tslearn/shapelets/shapelets.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from tensorflow.keras.models import Model, model_from_json
from tensorflow.keras.layers import (InputSpec, Dense, Conv1D, Layer, Input,
concatenate, add)
from tensorflow.keras.metrics import (categorical_accuracy,
categorical_crossentropy,
binary_accuracy, binary_crossentropy)
from tensorflow.keras.utils import to_categorical
import keras
from keras.models import Model, model_from_json
from keras.layers import (InputSpec, Dense, Conv1D, Layer, Input,
concatenate, add)
from keras.metrics import (categorical_accuracy,
categorical_crossentropy,
binary_accuracy, binary_crossentropy)
from keras.utils import to_categorical
from keras.regularizers import l2
from keras.initializers import Initializer
import keras.ops as ops

from sklearn.base import ClassifierMixin, TransformerMixin
from sklearn.utils import check_array, check_X_y
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.multiclass import unique_labels
from tensorflow.keras.regularizers import l2
from tensorflow.keras.initializers import Initializer
import tensorflow.keras.backend as K
import numpy
import tensorflow as tf

import warnings

Expand All @@ -35,33 +36,46 @@ class GlobalMinPooling1D(Layer):

Examples
--------
>>> x = tf.constant([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf])
>>> x = tf.reshape(x, [1, 5, 1])
>>> x = numpy.array([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf])
>>> x = x.reshape([1, 5, 1])
>>> GlobalMinPooling1D()(x).numpy()
array([[5.]], dtype=float32)
"""

def __init__(self, **kwargs):
def __init__(self, data_format=None, keepdims=False, **kwargs):
super().__init__(**kwargs)

self.data_format = (
"channels_last" if data_format is None else data_format
)
self.keepdims = keepdims
self.input_spec = InputSpec(ndim=3)

def compute_output_shape(self, input_shape):
return input_shape[0], input_shape[2]

def call(self, inputs, **kwargs):
inputs_without_nans = tf.where(tf.math.is_finite(inputs),
inputs,
tf.zeros_like(inputs) + numpy.inf)
return tf.reduce_min(inputs_without_nans, axis=1)
def call(self, inputs):
steps_axis = 1 if self.data_format == "channels_last" else 2
inputs_without_nans = ops.where(ops.isfinite(inputs),
inputs,
ops.zeros_like(inputs) + ops.max(inputs[ops.isfinite(inputs)]))
return ops.min(inputs_without_nans, axis=steps_axis, keepdims=self.keepdims)


class GlobalArgminPooling1D(Layer):
"""Global min pooling operation for temporal data.
"""Global argmin pooling operation for temporal data.
# Input shape
3D tensor with shape: `(batch_size, steps, features)`.
# Output shape
2D tensor with shape:
`(batch_size, features)`

Examples
--------
>>> x = numpy.array([5.0, numpy.nan, 6.8, numpy.nan, numpy.inf])
>>> x = x.reshape([1, 5, 1])
>>> GlobalArgminPooling1D()(x).numpy()
array([[0.]], dtype=float32)
"""

def __init__(self, **kwargs):
Expand All @@ -72,7 +86,10 @@ def compute_output_shape(self, input_shape):
return input_shape[0], input_shape[2]

def call(self, inputs, **kwargs):
return K.cast(K.argmin(inputs, axis=1), dtype=K.floatx())
inputs_without_nans = ops.where(ops.isfinite(inputs),
inputs,
ops.zeros_like(inputs) + ops.max(inputs[ops.isfinite(inputs)]))
return ops.cast(ops.argmin(inputs_without_nans, axis=1), dtype=float)


def _kmeans_init_shapelets(X, n_shapelets, shp_len, n_draw=10000):
Expand Down Expand Up @@ -106,7 +123,7 @@ def __call__(self, shape, dtype=None):
shapelets = _kmeans_init_shapelets(self.X_,
n_shapelets,
shp_len)[:, :, 0]
return tf.convert_to_tensor(shapelets, dtype=K.floatx())
return ops.convert_to_tensor(shapelets, dtype=float)

def get_config(self):
return {'data': self.X_}
Expand Down Expand Up @@ -140,11 +157,11 @@ def build(self, input_shape):

def call(self, x, **kwargs):
# (x - y)^2 = x^2 + y^2 - 2 * x * y
x_sq = K.expand_dims(K.sum(x ** 2, axis=2), axis=-1)
y_sq = K.reshape(K.sum(self.kernel ** 2, axis=1),
x_sq = ops.expand_dims(ops.sum(x ** 2, axis=2), axis=-1)
y_sq = ops.reshape(ops.sum(self.kernel ** 2, axis=1),
(1, 1, self.n_shapelets))
xy = K.dot(x, K.transpose(self.kernel))
return (x_sq + y_sq - 2 * xy) / K.int_shape(self.kernel)[1]
xy = ops.dot(x, ops.transpose(self.kernel))
return (x_sq + y_sq - 2 * xy) / ops.shape(self.kernel)[1]

def compute_output_shape(self, input_shape):
return input_shape[0], input_shape[1], self.n_shapelets
Expand Down Expand Up @@ -426,7 +443,7 @@ def fit(self, X, y):
self._check_series_length(X)

if self.random_state is not None:
tf.keras.utils.set_random_seed(seed=self.random_state)
keras.utils.set_random_seed(seed=self.random_state)

n_ts, sz, d = X.shape
self._X_fit_dims = X.shape
Expand Down Expand Up @@ -567,7 +584,7 @@ def locate(self, X):
>>> X[0, 4:7, 0] = numpy.array([1, 2, 3])
>>> y = [1, 0, 0]
>>> # Data is all zeros except a motif 1-2-3 in the first time series
>>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=0,
>>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=1,
... verbose=0)
>>> _ = clf.fit(X, y)
>>> weights_shapelet = [
Expand Down Expand Up @@ -780,7 +797,7 @@ def get_weights(self, layer_name=None):
--------
>>> from tslearn.generators import random_walk_blobs
>>> X, y = random_walk_blobs(n_ts_per_blob=100, sz=256, d=1, n_blobs=3)
>>> clf = LearningShapelets(n_shapelets_per_size={10: 5}, max_iter=0,
>>> clf = LearningShapelets(n_shapelets_per_size={10: 5}, max_iter=1,
... verbose=0)
>>> clf.fit(X, y).get_weights("classification")[0].shape
(5, 3)
Expand Down Expand Up @@ -816,7 +833,7 @@ def set_weights(self, weights, layer_name=None):
--------
>>> from tslearn.generators import random_walk_blobs
>>> X, y = random_walk_blobs(n_ts_per_blob=10, sz=16, d=1, n_blobs=3)
>>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=0,
>>> clf = LearningShapelets(n_shapelets_per_size={3: 1}, max_iter=1,
... verbose=0)
>>> _ = clf.fit(X, y)
>>> weights_shapelet = [
Expand Down
2 changes: 1 addition & 1 deletion tslearn/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_all_classes():
# keras is likely not installed
warnings.warn('Skipped common tests for shapelets '
'as it could not be imported. keras '
'(and tensorflow) are probably not '
'is probably not '
'installed!')
continue
elif name.endswith('pytorch_backend'):
Expand Down
Loading
Loading