diff --git a/README.md b/README.md
index 1af739b..3329b72 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,33 @@ method for training NCDEs.
---
+## Update – 22nd May 2025
+
+This repository now supports **Structured Linear Controlled Differential Equations** (SLiCEs), which replace the non-linear vector fields of NCDEs and Log-NCDEs with structured linear vector fields, retaining the same maximal expressivity whilst being significantly more efficient.
+
+SLiCEs are defined by
+
+$$
+h_t = h_0 + \int_0^t \sum_{i=1}^{d_X} A^i_{\theta} h_s \mathrm{d}X_s,
+$$
+
+where each $A^i_{\theta} \in \mathbb{R}^{d_h \times d_h}$ is a trainable matrix acting on the hidden state. When the $A^i_{\theta}$ are dense, this system is known as a **Linear Neural CDE (LNCDE)** and these models are *maximally expressive* (i.e., universal), see [here](https://github.com/Benjamin-Walker/selective-ssms-and-linear-cdes). However, the computational cost and number of parameters when using dense matrices scale as $\mathcal{O}(d_h^3)$, making them impractical for large models.
+
+SLiCEs offer a solution: they retain the maximal expressivity **while reducing computational and memory costs** by structuring the $A^i_{\theta}$ matrices. This repository includes three SLiCE variants:
+- **D-LNCDE**: Diagonal matrices: fastest, but limited expressivity.
+- **BD-LNCDE**: Block-diagonal matrices: maximally expressive and efficient.
+- **DE-LNCDE**: Fully dense matrices: maximally expressive, but computationally expensive.
+
+**In practice**: Replacing the non-linear vector field of a Log-NCDE with the block-diagonal vector field of a BD-LNCDE leads to **20× faster training** per step on the UEA multivariate time-series tasks whilst achieving the same average test accuracy. The figure below compares models on their average test accuracy, average time per 1000 training steps, and average GPU memory, which is represented by the area of each circle.
+
+
+
+
+
+For further details and an expansive comparison with other state-of-the-art sequence models, see the [official SLiCE repository](https://github.com/Benjamin-Walker/structured-linear-cdes).
+
+---
+
## Introduction
Neural controlled differential equations (NCDEs) treat time series data as observations from a control path $X_t$,
@@ -57,6 +84,7 @@ The code for preprocessing the datasets, training S5, LRU, NCDE, NRDE, and Log-N
- `optax` for neural network optimisers.
- `diffrax` for differential equation solvers.
- `signax` for calculating the signature.
+- `roughpy` for calculating the Hall basis.
- `sktime` for handling time series data in ARFF format.
- `tqdm` for progress bars.
- `matplotlib` for plotting.
@@ -67,7 +95,7 @@ conda create -n Log-NCDE python=3.10
conda activate Log-NCDE
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitue for correct Jax pip install: https://jax.readthedocs.io/en/latest/installation.html
-pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.8 optax==0.2.2 diffrax==0.6.0 signax==0.1.1
+pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.12.2 optax==0.2.4 diffrax==0.7.0 signax==0.1.1 roughpy==0.2.0
```
If running `data_dir/process_uea.py` throws this error: No module named 'packaging'
diff --git a/assets/time_vs_acc.png b/assets/time_vs_acc.png
new file mode 100644
index 0000000..ff99a60
Binary files /dev/null and b/assets/time_vs_acc.png differ
diff --git a/data_dir/datasets.py b/data_dir/datasets.py
index 9757743..cca6eab 100644
--- a/data_dir/datasets.py
+++ b/data_dir/datasets.py
@@ -233,6 +233,11 @@ def dataset_generator(
)
+def _scale_to_minus_one_one(x, data_min, data_max, eps=1e-8):
+    """Affinely map x from [data_min, data_max] to [-1, 1] with broadcasting."""
+ return 2.0 * (x - data_min) / (data_max - data_min + eps) - 1.0
+
+
def create_uea_dataset(
data_dir,
name,
@@ -242,6 +247,7 @@ def create_uea_dataset(
depth,
include_time,
T,
+ scale=False,
*,
key,
):
@@ -294,6 +300,21 @@ def create_uea_dataset(
)
data = jnp.concatenate([ts[:, :, None], data], axis=2)
+ if scale:
+ if use_presplit:
+ # stack (N,L,C) arrays along N to get all samples
+ all_data = jnp.concatenate([train_data, val_data, test_data], axis=0)
+ data_min = all_data.min(axis=(0, 1), keepdims=True)
+ data_max = all_data.max(axis=(0, 1), keepdims=True)
+
+ train_data = _scale_to_minus_one_one(train_data, data_min, data_max)
+ val_data = _scale_to_minus_one_one(val_data, data_min, data_max)
+ test_data = _scale_to_minus_one_one(test_data, data_min, data_max)
+ else:
+ data_min = data.min(axis=(0, 1), keepdims=True)
+ data_max = data.max(axis=(0, 1), keepdims=True)
+ data = _scale_to_minus_one_one(data, data_min, data_max)
+
return dataset_generator(
name,
data,
@@ -396,6 +417,7 @@ def create_dataset(
depth,
include_time,
T,
+ scale=False,
*,
key,
):
@@ -416,6 +438,7 @@ def create_dataset(
depth,
include_time,
T,
+ scale=scale,
key=key,
)
elif name[:-1] in toy_subfolders:
diff --git a/data_dir/process_uea.py b/data_dir/process_uea.py
index d22a18d..766b6d5 100644
--- a/data_dir/process_uea.py
+++ b/data_dir/process_uea.py
@@ -83,16 +83,19 @@ def convert_all_files(data_dir):
train_file, test_file
)
data = jnp.concatenate([train_data, test_data])
+ orig_data_len = data.shape[0]
labels = jnp.concatenate([train_labels, test_labels])
- unique_rows, indices, inverse_indices = np.unique(
- data, axis=0, return_index=True, return_inverse=True
- )
- data = data[indices]
- labels = labels[indices]
- print(
- f"Deleting {len(inverse_indices) - len(indices)} repeated samples in {ds_name}"
- )
+ # keep first occurrence of each unique row
+ _, first_idx = np.unique(data, axis=0, return_index=True)
+
+ # restore original ordering of those first occurrences
+ keep_idx = np.sort(first_idx)
+
+ data = data[keep_idx]
+ labels = labels[keep_idx]
+
+ print(f"Deleting {orig_data_len - len(data)} repeated samples in {ds_name}")
original_idxs = (
jnp.arange(0, train_data.shape[0]),
diff --git a/experiment_configs/repeats/bd_linear_ncde/EigenWorms.json b/experiment_configs/repeats/bd_linear_ncde/EigenWorms.json
new file mode 100644
index 0000000..bbd55ec
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/EigenWorms.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EigenWorms",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "128",
+ "lambd": 0.001,
+ "block_size": 4,
+ "stepsize": 12,
+ "depth": 2
+}
diff --git a/experiment_configs/repeats/bd_linear_ncde/EthanolConcentration.json b/experiment_configs/repeats/bd_linear_ncde/EthanolConcentration.json
new file mode 100644
index 0000000..726a504
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/EthanolConcentration.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EthanolConcentration",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 4,
+ "depth": 1,
+ "stepsize": 1,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/bd_linear_ncde/Heartbeat.json b/experiment_configs/repeats/bd_linear_ncde/Heartbeat.json
new file mode 100644
index 0000000..95d4b4a
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/Heartbeat.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "Heartbeat",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 4,
+ "depth": 2,
+ "stepsize": 2,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/bd_linear_ncde/MotorImagery.json b/experiment_configs/repeats/bd_linear_ncde/MotorImagery.json
new file mode 100644
index 0000000..6dcf39d
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/MotorImagery.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "MotorImagery",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 4,
+ "depth": 2,
+ "stepsize": 16,
+ "lambd": 0.001
+}
diff --git a/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP1.json b/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP1.json
new file mode 100644
index 0000000..0c70112
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP1.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP1",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 4,
+ "stepsize": 16,
+ "depth": 2,
+ "lambd": 0.0
+}
diff --git a/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP2.json b/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP2.json
new file mode 100644
index 0000000..4f6972e
--- /dev/null
+++ b/experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP2.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP2",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "128",
+ "block_size": 4,
+ "stepsize": 4,
+ "depth": 2,
+ "lambd": 0.001
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/EigenWorms.json b/experiment_configs/repeats/dense_linear_ncde/EigenWorms.json
new file mode 100644
index 0000000..b09098b
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/EigenWorms.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 16,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EigenWorms",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "128",
+ "lambd": 0.001,
+ "block_size": 128,
+ "stepsize": 12,
+ "depth": 2
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/EthanolConcentration.json b/experiment_configs/repeats/dense_linear_ncde/EthanolConcentration.json
new file mode 100644
index 0000000..7c1aec5
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/EthanolConcentration.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EthanolConcentration",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 64,
+ "depth": 1,
+ "stepsize": 1,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/Heartbeat.json b/experiment_configs/repeats/dense_linear_ncde/Heartbeat.json
new file mode 100644
index 0000000..a488b7e
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/Heartbeat.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "Heartbeat",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 16,
+ "depth": 2,
+ "stepsize": 2,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/MotorImagery.json b/experiment_configs/repeats/dense_linear_ncde/MotorImagery.json
new file mode 100644
index 0000000..a2d303c
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/MotorImagery.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "MotorImagery",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 16,
+ "depth": 2,
+ "stepsize": 16,
+ "lambd": 0.001
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP1.json b/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP1.json
new file mode 100644
index 0000000..2667ab8
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP1.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP1",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 64,
+ "stepsize": 16,
+ "depth": 2,
+ "lambd": 0.0
+}
diff --git a/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP2.json b/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP2.json
new file mode 100644
index 0000000..f7ab867
--- /dev/null
+++ b/experiment_configs/repeats/dense_linear_ncde/SelfRegulationSCP2.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP2",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "128",
+ "block_size": 128,
+ "stepsize": 4,
+ "depth": 2,
+ "lambd": 0.001
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/EigenWorms.json b/experiment_configs/repeats/diagonal_linear_ncde/EigenWorms.json
new file mode 100644
index 0000000..92d5f45
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/EigenWorms.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EigenWorms",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "128",
+ "lambd": 0.001,
+ "block_size": 1,
+ "stepsize": 12,
+ "depth": 1
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/EthanolConcentration.json b/experiment_configs/repeats/diagonal_linear_ncde/EthanolConcentration.json
new file mode 100644
index 0000000..9a71a45
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/EthanolConcentration.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "EthanolConcentration",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 1,
+ "depth": 1,
+ "stepsize": 1,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/Heartbeat.json b/experiment_configs/repeats/diagonal_linear_ncde/Heartbeat.json
new file mode 100644
index 0000000..ff81f81
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/Heartbeat.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "Heartbeat",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "True",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 1,
+ "depth": 1,
+ "stepsize": 2,
+ "lambd": 0.000001
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/MotorImagery.json b/experiment_configs/repeats/diagonal_linear_ncde/MotorImagery.json
new file mode 100644
index 0000000..76ca26e
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/MotorImagery.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "MotorImagery",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.001",
+ "hidden_dim": "16",
+ "block_size": 1,
+ "depth": 1,
+ "stepsize": 16,
+ "lambd": 0.001
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP1.json b/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP1.json
new file mode 100644
index 0000000..e60a2c4
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP1.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP1",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "64",
+ "block_size": 1,
+ "stepsize": 16,
+ "depth": 1,
+ "lambd": 0.0
+}
diff --git a/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP2.json b/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP2.json
new file mode 100644
index 0000000..ec72948
--- /dev/null
+++ b/experiment_configs/repeats/diagonal_linear_ncde/SelfRegulationSCP2.json
@@ -0,0 +1,30 @@
+{
+ "seeds": [
+ 2345,
+ 3456,
+ 4567,
+ 5678,
+ 6789
+ ],
+ "data_dir": "data_dir",
+ "output_parent_dir": "",
+ "lr_scheduler": "lambda lr: lr",
+ "num_steps": 100000,
+ "print_steps": 1000,
+ "early_stopping_steps": 10,
+ "batch_size": 32,
+ "model_name": "linear_ncde",
+ "metric": "accuracy",
+ "classification": true,
+ "dataset_name": "SelfRegulationSCP2",
+ "use_presplit": false,
+ "T": 1,
+ "scale": 1,
+ "time": "False",
+ "lr": "0.0001",
+ "hidden_dim": "128",
+ "block_size": 1,
+ "stepsize": 4,
+ "depth": 1,
+ "lambd": 0.001
+}
diff --git a/models/LinearNeuralCDEs.py b/models/LinearNeuralCDEs.py
new file mode 100644
index 0000000..faa7ef3
--- /dev/null
+++ b/models/LinearNeuralCDEs.py
@@ -0,0 +1,217 @@
+"""
+This module implements the `LogLinearCDE` class using JAX and Equinox. The model is a
+block-diagonal Linear Controlled Differential Equation (CDE), where the output is
+approximated during training using the Log-ODE method.
+
+Attributes of the `LogLinearCDE` model:
+- `init_layer`: The linear layer used to initialize the hidden state $h_0$ from the input $x_0$.
+- `out_layer`: The linear layer used to produce final predictions from the hidden state.
+- `vf_A`: Learnable parameters for the linear vector field, shaped as flattened block matrices.
+- `hidden_dim`: The dimension of the hidden state $h_t$.
+- `block_size`: Size of each square block in the block-diagonal vector field.
+- `num_blocks`: Number of blocks, computed as `hidden_dim // block_size`.
+- `parallel_steps`: Number of log-flow matrices composed in parallel (using associative scan).
+- `logsig_depth`: The depth of the log-signature used in the Log-ODE method.
+- `basis_list`: The list of basis elements of the free Lie algebra up to the specified depth.
+- `lambd`: Regularization parameter applied to vector field scaling.
+- `w_init_std`: Standard deviation for the initial weights of the vector field.
+- `classification`: Boolean indicating if the model is used for classification tasks.
+
+The class includes:
+- `log_ode`: Method for computing the iterated Lie brackets of the linear vector fields.
+- `__call__`: Performs the forward pass, where flows are composed and applied to the hidden state
+ either step-by-step or in parallel (using associative scan), followed by output projection.
+"""
+
+from __future__ import annotations
+
+from typing import List, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import roughpy as rp
+
+
+def to_tuple(el):
+ """Convert a basis element which may be an int or a nested [x,y] list into a nested tuple."""
+ if isinstance(el, int):
+ return (el,)
+ else:
+ return to_tuple(el[0]), to_tuple(el[1])
+
+
+def depth(b):
+ """Compute the 'depth' of a bracket structure."""
+ if isinstance(b, int):
+ return 1
+ elif isinstance(b, list):
+ return max(depth(b[0]), depth(b[1])) + 1
+ else:
+ raise TypeError("Invalid basis element type.")
+
+
+class LogLinearCDE(eqx.Module):
+ init_layer: eqx.nn.Linear
+ out_layer: eqx.nn.Linear
+ vf_A: jnp.ndarray
+ hidden_dim: int
+ block_size: int
+ num_blocks: int
+ parallel_steps: int
+ logsig_depth: int
+ basis_list: List[Tuple[int, ...]]
+ lambd: float
+ w_init_std: float
+ classification: bool
+
+ lip2: bool = True
+ nondeterministic: bool = False
+ stateful: bool = False
+
+ def __init__(
+ self,
+ *,
+ data_dim: int,
+ hidden_dim: int,
+ label_dim: int,
+ block_size: int,
+ logsig_depth: int,
+ lambd: float = 1.0,
+ w_init_std: float = 0.25,
+ parallel_steps: int = 128,
+ classification: bool = True,
+ key,
+ ):
+ if hidden_dim % block_size != 0:
+ raise ValueError("hidden_dim must be divisible by block_size.")
+ self.hidden_dim = hidden_dim
+ self.block_size = block_size
+ self.num_blocks = hidden_dim // block_size
+ self.parallel_steps = parallel_steps
+ self.logsig_depth = logsig_depth
+ ctx = rp.get_context(width=data_dim, depth=self.logsig_depth, coeffs=rp.DPReal)
+ basis = ctx.lie_basis
+ basis_list = []
+ for i in range(basis.size(self.logsig_depth)):
+ basis_list.append(eval(str(basis.index_to_key(i))))
+ self.basis_list = basis_list
+ self.lambd = lambd
+ self.w_init_std = w_init_std
+
+ k_init, k_A, k_B = jr.split(key, 3)
+ self.init_layer = eqx.nn.Linear(data_dim, hidden_dim, key=k_init)
+ self.out_layer = eqx.nn.Linear(hidden_dim, label_dim, key=k_B)
+
+ self.vf_A = (
+ jr.normal(k_A, (data_dim + 1, self.num_blocks * block_size * block_size))
+ * self.w_init_std
+ / jnp.sqrt(block_size)
+ )
+ self.classification = classification
+
+ def log_ode(self, vf):
+
+ basis_index = {}
+ for i, b in enumerate(self.basis_list):
+ basis_index[to_tuple(b)] = i
+
+ depth_to_elements = {}
+ for i, b in enumerate(self.basis_list):
+ d = depth(b)
+ depth_to_elements.setdefault(d, []).append((i, b))
+
+ A_arrays = [None] * len(self.basis_list)
+
+ for i_b, b in depth_to_elements[1]:
+ A_arrays[i_b] = vf[b - 1, :, :]
+
+ max_depth = max(depth_to_elements.keys())
+ for d in range(2, max_depth + 1):
+ curr_elements = depth_to_elements[d]
+
+ left_indices = []
+ right_indices = []
+ for (i_b, b) in curr_elements:
+ u_tuple = to_tuple(b[0])
+ v_tuple = to_tuple(b[1])
+ i_u = basis_index[u_tuple]
+ i_v = basis_index[v_tuple]
+ left_indices.append(i_u)
+ right_indices.append(i_v)
+
+ A_left = jnp.stack([A_arrays[i_u] for i_u in left_indices], axis=0)
+ A_right = jnp.stack([A_arrays[i_v] for i_v in right_indices], axis=0)
+
+ A_uv = jnp.einsum("ijk,ikl->ijl", A_right, A_left) - jnp.einsum(
+ "ijk,ikl->ijl", A_left, A_right
+ )
+
+ for idx, (i_b, b) in enumerate(curr_elements):
+ A_arrays[i_b] = A_uv[idx]
+
+ return jnp.stack(A_arrays, axis=2)
+
+ def __call__(self, X):
+ ts, logsigs, x0 = X
+
+ y0 = self.init_layer(x0)
+
+ vfs = self.vf_A.reshape(-1, self.num_blocks, self.block_size, self.block_size)
+ lie_brackets = jax.vmap(self.log_ode, in_axes=(1))(vfs)
+ log_flows = jnp.einsum("ijkl,ml->mijk", lie_brackets, logsigs[:, 1:])
+ flows = log_flows + jnp.eye(self.block_size)[None, None, :, :]
+
+ def step(y, flow):
+ y_block = y.reshape(self.num_blocks, self.block_size, 1)
+ y_next = flow @ y_block
+ y_next = y_next.reshape(
+ self.hidden_dim,
+ )
+ return y_next, y_next
+
+ def parallel_step(y, flows):
+ compose = lambda a, b: jnp.matmul(b, a)
+ flow_total = jax.lax.associative_scan(compose, flows)
+ y_block = y.reshape(self.num_blocks, self.block_size, 1)
+ y_new = jnp.matmul(flow_total, y_block).reshape(-1, self.hidden_dim)
+ return y_new[-1], y_new
+
+ if self.parallel_steps == 1:
+ scan_fn = step
+ remainder = 0
+ scan_inp = flows
+ else:
+ scan_fn = parallel_step
+ t = len(flows)
+ remainder = (t - 1) % self.parallel_steps
+ core = flows[1:] if remainder == 0 else flows[1:-remainder]
+ scan_inp = jnp.reshape(
+ core,
+ (
+ -1,
+ self.parallel_steps,
+ self.num_blocks,
+ self.block_size,
+ self.block_size,
+ ),
+ )
+
+        _, ys = jax.lax.scan(scan_fn, y0, scan_inp)  # (T-1, H)
+ if len(ys.shape) == 3:
+ ys = jnp.reshape(ys, (-1, self.hidden_dim))
+ ys = jnp.vstack([y0, ys])
+ if remainder != 0:
+ inp_rem = flows[-remainder:]
+ _, y_rem = jax.lax.scan(step, ys[-1], inp_rem)
+ ys = jnp.vstack([ys, y_rem])
+
+ if self.classification:
+ ys = jnp.mean(ys, axis=0)
+ preds = jax.nn.softmax(self.out_layer(ys))
+ else:
+ ys = jax.vmap(self.out_layer)(ys)
+ preds = jnp.tanh(ys)
+
+ return preds
diff --git a/models/generate_model.py b/models/generate_model.py
index e41a3d3..0c95213 100644
--- a/models/generate_model.py
+++ b/models/generate_model.py
@@ -42,6 +42,7 @@
import equinox as eqx
import jax.random as jr
+from models.LinearNeuralCDEs import LogLinearCDE
from models.LogNeuralCDEs import LogNeuralCDE
from models.LRU import LRU
from models.NeuralCDEs import NeuralCDE, NeuralRDE
@@ -58,6 +59,7 @@ def create_model(
label_dim,
hidden_dim,
num_blocks=None,
+ block_size=None,
vf_depth=None,
vf_width=None,
classification=True,
@@ -70,6 +72,8 @@ def create_model(
max_steps=16**4,
scale=1.0,
lambd=0.0,
+ stepsize=1,
+ w_init_std=0.25,
*,
key,
):
@@ -99,7 +103,24 @@ def create_model(
),
None,
)
- if model_name == "ncde":
+    elif model_name in (
+        "bd_linear_ncde", "diagonal_linear_ncde", "dense_linear_ncde"
+    ):
+ return (
+ LogLinearCDE(
+ data_dim=data_dim,
+ hidden_dim=hidden_dim,
+ label_dim=label_dim,
+ block_size=block_size,
+ logsig_depth=logsig_depth,
+ lambd=lambd,
+ w_init_std=w_init_std,
+ classification=classification,
+ key=key,
+ ),
+ None,
+ )
+ elif model_name == "ncde":
if vf_width is None or vf_depth is None:
raise ValueError("Must specify vf_width and vf_depth for a NCDE.")
return (
diff --git a/results/memory_time_results.json b/results/memory_time_results.json
index 8a50651..db6c913 100644
--- a/results/memory_time_results.json
+++ b/results/memory_time_results.json
@@ -55,9 +55,57 @@
"MotorImagery": 4056,
"SelfRegulationSCP1": 904,
"SelfRegulationSCP2": 1222
+ },
+ "bd_linear_ncde": {
+ "EigenWorms": 3494,
+ "EthanolConcentration": 1192,
+ "Heartbeat": 2732,
+ "MotorImagery": 8612,
+ "SelfRegulationSCP1": 666,
+ "SelfRegulationSCP2": 932
+ },
+ "diagonal_linear_ncde": {
+ "EigenWorms": 3486,
+ "EthanolConcentration": 670,
+ "Heartbeat": 920,
+ "MotorImagery": 6552,
+ "SelfRegulationSCP1": 662,
+ "SelfRegulationSCP2": 664
+ },
+ "dense_linear_ncde": {
+ "EigenWorms": 18580,
+ "EthanolConcentration": 11170,
+ "Heartbeat": 2712,
+ "MotorImagery": 8622,
+ "SelfRegulationSCP1": 666,
+ "SelfRegulationSCP2": 11172
}
},
"time": {
+ "dense_linear_ncde": {
+ "EigenWorms": 145.44,
+ "EthanolConcentration": 71.80,
+ "Heartbeat": 196.6,
+ "MotorImagery": 202.78,
+ "SelfRegulationSCP1": 8.27,
+ "SelfRegulationSCP2": 60.51
+ },
+ "bd_linear_ncde": {
+ "EigenWorms": 15.46,
+ "EthanolConcentration": 12.54,
+ "Heartbeat": 138.09,
+ "MotorImagery": 148.16,
+ "SelfRegulationSCP1": 6.77,
+ "SelfRegulationSCP2": 10.71
+ },
+ "diagonal_linear_ncde": {
+ "EigenWorms": 7.56,
+ "EthanolConcentration": 7.97,
+ "Heartbeat": 6.77,
+ "MotorImagery": 6.16,
+ "SelfRegulationSCP1": 6.39,
+ "SelfRegulationSCP2": 6.91
+ },
"ncde": {
"EigenWorms": 24595.42,
"EthanolConcentration": 2216.64,
@@ -190,4 +238,4 @@
"S5": "S5",
"S6": "S6"
}
-}
\ No newline at end of file
+}
diff --git a/results/paper_outputs.zip b/results/paper_outputs.zip
index 9fa846c..7162aaf 100644
Binary files a/results/paper_outputs.zip and b/results/paper_outputs.zip differ
diff --git a/run_experiment.py b/run_experiment.py
index fc530ea..feca309 100644
--- a/run_experiment.py
+++ b/run_experiment.py
@@ -43,7 +43,18 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
metric = data["metric"]
use_presplit = data["use_presplit"]
T = data["T"]
- if model_name in ["lru", "S5", "S6", "mamba"]:
+ if model_name in [
+ "lru",
+ "S5",
+ "S6",
+ "mamba",
+ "rnn_linear",
+ "rnn_lstm",
+ "rnn_gru",
+ "bd_linear_ncde",
+ "diagonal_linear_ncde",
+ "dense_linear_ncde",
+ ]:
dt0 = None
else:
dt0 = float(data["dt0"])
@@ -52,6 +63,7 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
include_time = data["time"].lower() == "true"
hidden_dim = int(data["hidden_dim"])
if model_name in ["log_ncde", "nrde", "ncde"]:
+ block_size = None
vf_depth = int(data["vf_depth"])
vf_width = int(data["vf_width"])
if model_name in ["log_ncde", "nrde"]:
@@ -67,13 +79,26 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
ssm_dim = None
num_blocks = None
else:
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ):
+ block_size = int(data["block_size"])
+ ssm_dim = None
+ stepsize = int(float(data["stepsize"]))
+ logsig_depth = int(data["depth"])
+ lambd = float(data["lambd"])
+ num_blocks = None
+ else:
+ block_size = None
+ ssm_dim = int(data["ssm_dim"])
+ stepsize = 1
+ logsig_depth = 1
+ lambd = None
+ num_blocks = int(data["num_blocks"])
vf_depth = None
vf_width = None
- logsig_depth = 1
- stepsize = 1
- lambd = None
- ssm_dim = int(data["ssm_dim"])
- num_blocks = int(data["num_blocks"])
if model_name == "S5":
ssm_blocks = int(data["ssm_blocks"])
else:
@@ -141,6 +166,7 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
model_args = {
"num_blocks": num_blocks,
+ "block_size": block_size,
"hidden_dim": hidden_dim,
"vf_depth": vf_depth,
"vf_width": vf_width,
@@ -151,6 +177,7 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
"stepsize_controller": diffrax.ConstantStepSize(),
"scale": scale,
"lambd": lambd,
+ "stepsize": stepsize,
}
run_args = {
"data_dir": data_dir,
@@ -190,7 +217,16 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
if pytorch_experiments:
model_names = ["mamba", "S6"]
else:
- model_names = ["ncde", "log_ncde", "nrde", "S5", "lru"]
+ model_names = [
+ # "ncde",
+ # "log_ncde",
+ # "nrde",
+ # "S5",
+ # "lru",
+ "bd_linear_ncde",
+ "dense_linear_ncde",
+ "diagonal_linear_ncde",
+ ]
dataset_names = [
"EigenWorms",
"EthanolConcentration",
diff --git a/train.py b/train.py
index 2356a0e..dd6bcd9 100644
--- a/train.py
+++ b/train.py
@@ -77,11 +77,14 @@ def classification_loss(diff_model, static_model, X, y, state, key):
)
norm = 0
if model.lip2:
- for layer in model.vf.mlp.layers:
- norm += jnp.mean(
- jnp.linalg.norm(layer.weight, axis=-1)
- + jnp.linalg.norm(layer.bias, axis=-1)
- )
+ if hasattr(model, "vf"):
+ for layer in model.vf.mlp.layers:
+ norm += jnp.mean(
+ jnp.linalg.norm(layer.weight, axis=-1)
+ + jnp.linalg.norm(layer.bias, axis=-1)
+ )
+ else:
+ norm += jnp.mean(jnp.linalg.norm(model.vf_A, axis=-1))
norm *= model.lambd
return (
jnp.mean(-jnp.sum(y * jnp.log(pred_y + 1e-8), axis=1)) + norm,
@@ -99,12 +102,15 @@ def regression_loss(diff_model, static_model, X, y, state, key):
         pred_y = pred_y[:, :, 0]
         norm = 0
         if model.lip2:
-            for layer in model.vf.mlp.layers:
-                norm += jnp.mean(
-                    jnp.linalg.norm(layer.weight, axis=-1)
-                    + jnp.linalg.norm(layer.bias, axis=-1)
-                )
-            norm *= model.lambd
+            if hasattr(model, "vf"):
+                for layer in model.vf.mlp.layers:
+                    norm += jnp.mean(
+                        jnp.linalg.norm(layer.weight, axis=-1)
+                        + jnp.linalg.norm(layer.bias, axis=-1)
+                    )
+            else:
+                norm += jnp.mean(jnp.linalg.norm(model.vf_A, axis=-1))
+            norm *= model.lambd
return (
jnp.mean(jnp.mean((pred_y - y) ** 2, axis=1)) + norm,
state,
@@ -121,6 +126,8 @@ def make_step(model, filter_spec, X, y, loss_fn, state, opt, opt_state, key):
def train_model(
+ model_name,
+ dataset_name,
model,
metric,
filter_spec,
@@ -188,6 +195,13 @@ def train_model(
):
stepkey, key = jr.split(key, 2)
X, y = data
+
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ) and dataset_name == "Heartbeat":
+ X = (X[0], X[1] / 10, X[2])
model, state, opt_state, value = make_step(
model, filter_spec, X, y, loss_fn, state, opt, opt_state, stepkey
)
@@ -199,6 +213,12 @@ def train_model(
stepkey, key = jr.split(key, 2)
inference_model = eqx.tree_inference(model, value=True)
X, y = data
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ) and dataset_name == "Heartbeat":
+ X = (X[0], X[1] / 10, X[2])
prediction, _ = calc_output(
inference_model,
X,
@@ -224,6 +244,12 @@ def train_model(
stepkey, key = jr.split(key, 2)
inference_model = eqx.tree_inference(model, value=True)
X, y = data
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ) and dataset_name == "Heartbeat":
+ X = (X[0], X[1] / 10, X[2])
prediction, _ = calc_output(
inference_model,
X,
@@ -266,6 +292,12 @@ def train_model(
stepkey, key = jr.split(key, 2)
inference_model = eqx.tree_inference(model, value=True)
X, y = data
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ) and dataset_name == "Heartbeat":
+ X = (X[0], X[1] / 10, X[2])
prediction, _ = calc_output(
inference_model,
X,
@@ -360,6 +392,15 @@ def create_dataset_model_and_train(
datasetkey, modelkey, trainkey, key = jr.split(key, 4)
print(f"Creating dataset {dataset_name}")
+ if (
+ model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ):
+ scale = True
+ else:
+ scale = False
+
dataset = create_dataset(
data_dir,
dataset_name,
@@ -369,6 +410,7 @@ def create_dataset_model_and_train(
T=T,
use_idxs=False,
use_presplit=use_presplit,
+ scale=scale,
key=datasetkey,
)
@@ -387,7 +429,13 @@ def create_dataset_model_and_train(
key=modelkey,
)
filter_spec = jax.tree_util.tree_map(lambda _: True, model)
- if model_name == "nrde" or model_name == "log_ncde":
+ if (
+ model_name == "nrde"
+ or model_name == "log_ncde"
+ or model_name == "bd_linear_ncde"
+ or model_name == "diagonal_linear_ncde"
+ or model_name == "dense_linear_ncde"
+ ):
dataloaders = dataset.path_dataloaders
if model_name == "log_ncde":
where = lambda model: (model.intervals, model.pairs)
@@ -403,6 +451,8 @@ def create_dataset_model_and_train(
dataloaders = dataset.raw_dataloaders
return train_model(
+ model_name,
+ dataset_name,
model,
metric,
filter_spec,