# SLiCEs

`slices` is a PyTorch implementation of **Structured Linear CDEs (SLiCEs)** with parallel-in-time computation.

It provides expressive, scalable sequence layers based on the methods developed in
[*Structured Linear CDEs: Maximally Expressive and Parallel-in-Time Sequence Models*](https://arxiv.org/abs/2505.17761).

## Mathematical form
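As a sketch (reconstructed from the linear-CDE formulation in the paper linked above; the exact parameterisation used by `slices` may differ), a linear CDE evolves a hidden state $h_t$ linearly against an input path $X$:

$$\mathrm{d}h_t = \sum_{k=1}^{d_X} A^{(k)} h_t \,\mathrm{d}X^{(k)}_t,$$

and an Euler discretisation gives the recurrence

$$h_i = \Big(I + \sum_{k} A^{(k)} \Delta X^{(k)}_i\Big) h_{i-1} \;=\; A(X_i)\, h_{i-1},$$

where the structure imposed on $A(X_i)$ (diagonal, block-diagonal, dense) is the subject of the sections below.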

Or install from source:

```bash
pip install git+https://github.com/datasig-ac-uk/slices.git
```

## Components

- **`SLiCE`**: the core structured recurrence
- **`SLiCELayer`**: a residual sequence layer built around `SLiCE`, with RMSNorm and a GELU MLP by default
- **`StackedSLiCE`**: a full sequence model that stacks `SLiCELayer`s with input and output projections for token or continuous inputs

`SLiCE` supports both:
- **Recurrent execution** for step-by-step updates
- **Parallel chunked scan execution** via `torch.associative_scan`
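The two modes agree because each step is an affine update and affine maps compose associatively. Here is a pure-Python sketch of the idea with a scalar state (a stand-in for SLiCE's structured matrices, not the library's code; `torch.associative_scan` evaluates the same kind of `combine` over a balanced tree):

```python
def combine(f, g):
    # Compose affine maps: if f(h) = a_f*h + b_f and g(h) = a_g*h + b_g,
    # then applying f first and g second gives a_g*a_f*h + (a_g*b_f + b_g).
    af, bf = f
    ag, bg = g
    return (ag * af, ag * bf + bg)

def recurrent_states(steps, h0):
    # Step-by-step execution: h_i = a_i * h_{i-1} + b_i.
    h, out = h0, []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out

def scan_states(steps, h0):
    # Inclusive prefix scan over the affine maps; written sequentially here,
    # but `combine` is associative, so a parallel scan can evaluate the same
    # compositions in a balanced tree across the sequence.
    out, acc = [], (1.0, 0.0)  # identity map
    for f in steps:
        acc = combine(acc, f)
        out.append(acc[0] * h0 + acc[1])
    return out
```

For a $d$-dimensional state the same combine rule applies with matrix products in place of scalar multiplies, which is where the matrix structure below starts to matter.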

## Structured transition matrices

### 1) Diagonal
Set:
- `diagonal_dense=False`
- `block_size=1`

Then $A(X_i)$ is diagonal, matching the diagonal state-transition setting used by Mamba (as discussed in
[*Structured Linear CDEs: Maximally Expressive and Parallel-in-Time Sequence Models*](https://arxiv.org/abs/2505.17761)).
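Concretely, a diagonal transition reduces each step to an elementwise multiply-add, while a dense transition needs a full matrix-vector product. A toy pure-Python comparison (illustrative only, not the library's kernels):

```python
def diagonal_step(h, a, b):
    # Diagonal A(X_i): O(d) per step, channels do not interact.
    return [ai * hi + bi for hi, ai, bi in zip(h, a, b)]

def dense_step(h, A, b):
    # Dense A(X_i): O(d^2) per step, every channel mixes into every other.
    return [sum(w * hj for w, hj in zip(row, h)) + bi
            for row, bi in zip(A, b)]
```

A diagonal step agrees with a dense step whose matrix carries the same values on the diagonal and zeros elsewhere; the dense mode trades the extra cost for cross-channel mixing.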

### 2) Block-diagonal
Set:
```python
y = layer(x)  # (8, 128, 64)
print(y.shape)
```

Execution mode is controlled by `use_parallel` and `chunk_size`.
`path_mode` determines whether the input sequence is interpreted as path values or increments.
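Under value semantics the layer sees the path samples $X_i$ themselves, whereas increments correspond to the consecutive differences $\Delta X_i = X_i - X_{i-1}$. A toy helper showing the distinction (the library's convention for the first step is not specified here):

```python
def to_increments(values):
    # Delta X_i = X_i - X_{i-1}: n path values yield n-1 differences.
    return [b - a for a, b in zip(values, values[1:])]
```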

### Use `SLiCELayer` as a residual sequence layer

```python
layer = SLiCELayer(...)  # constructor arguments omitted
y = layer(x) # (4, 256, 64)
```

`SLiCELayer` uses the following default structure:
- RMSNorm -> SLiCE -> residual
- RMSNorm -> Linear -> GELU -> Linear -> residual
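The pre-norm residual pattern above can be sketched in plain Python (standard RMSNorm and GELU definitions on lists of floats; illustrative only, not the package's implementation):

```python
import math

def rms_norm(xs, eps=1e-6):
    # RMSNorm: scale by the root-mean-square of the vector (no mean subtraction).
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))

def prenorm_residual(xs, sublayer):
    # Pre-norm: normalise first, run the sublayer, then add the input back.
    return [x + y for x, y in zip(xs, sublayer(rms_norm(xs)))]

def layer(xs, mixer):
    # One layer pass: a sequence-mixing sublayer (SLiCE's role), then the MLP.
    xs = prenorm_residual(xs, mixer)
    return prenorm_residual(xs, lambda hs: [gelu(h) for h in hs])
```

With `prenorm=False` the normalisation would instead be applied after the residual addition.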

To use a post-norm wrapper:
- `prenorm=False`

Other configuration options include:
- `norm_type="layernorm"`
- `ff_style="single"`
- `ff_mult=1`
- `ff_activation="glu"` or `ff_activation="tanh"`
- `dropout_position="output"`
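As a hypothetical example, these options could be assembled as keyword arguments (names taken from the list above; the exact `SLiCELayer` signature is not shown in this excerpt):

```python
layer_kwargs = dict(
    prenorm=False,              # post-norm wrapper
    norm_type="layernorm",      # LayerNorm instead of RMSNorm
    ff_style="single",          # single-stage feed-forward
    ff_mult=1,
    ff_activation="glu",        # or "tanh"
    dropout_position="output",
)
# hypothetical usage: layer = SLiCELayer(..., **layer_kwargs)
```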

### Build a full model with stacked layers

#### Token sequence mode (`tokens=True`)

```python
y = model(x)  # (16, 100, 10)
```

## Training example

`examples/language_disambiguation.py` trains a `StackedSLiCE` model end-to-end on a real dataset.

The example:
- uses a real dataset (**wikimedia/wikipedia**, English/French subset) for
**character-level language disambiguation**
- trains a compact token-mode `StackedSLiCE` end-to-end
- evaluates validation accuracy every `--eval-every` training steps
- prints sample predictions for quick qualitative inspection

To run it, first install the example dependencies:

Useful flags:

## Benchmarking

To compare recurrent and parallel throughput across sequence lengths and hidden dimensions:

```bash
uv run python examples/benchmark_parallel_vs_recurrent.py
```
This script:
- benchmarks all four SLiCE matrix modes (`diagonal`, `block_diagonal`, `diagonal_dense`, `dense`)
- uses the default value-path semantics unless `path_mode="increments"` is set in code
- prints timing and speedup tables
- saves a combined 3D plot to `examples/images/parallel_vs_recurrent_speedup_3d_all_modes.png`

For plotting in development, install development dependencies (includes `matplotlib`):