4 changes: 4 additions & 0 deletions README.md
@@ -26,6 +26,8 @@ $$
By default, `SLiCE` treats the provided sequence as path values and internally
computes first differences with `torch.diff(..., prepend=zeros)`. Set
`path_mode="increments"` to treat the provided sequence as increments instead.
By default, each step uses the first-order Euler update `I + A(X_i)`. Set
`transition_mode="matrix_exp"` to use `exp(A(X_i))` instead.
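For example (a minimal sketch; the shapes and `input_dim=4` below are illustrative assumptions, with `input_dim` taken to be the leading constructor argument):

```python
import torch
from slices import SLiCE

x = torch.randn(8, 128, 4)  # (batch_size, seq_len, input_dim), path values

euler = SLiCE(input_dim=4)                               # I + A(X_i) steps
expm = SLiCE(input_dim=4, transition_mode="matrix_exp")  # exp(A(X_i)) steps
incr = SLiCE(input_dim=4, path_mode="increments")        # input holds increments

y = expm(x)  # (batch_size, seq_len, hidden_dim); hidden_dim defaults to input_dim
```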

## Installation

@@ -101,6 +103,8 @@ print(y.shape)

Execution mode is controlled by `use_parallel` and `chunk_size`.
`path_mode` determines whether the input sequence is interpreted as path values or increments.
`transition_mode` selects between the default Euler step and a matrix-exponential
transition.
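As a sketch of how the execution modes relate (sizes are illustrative; `load_state_dict` is used only so the two instances share weights and their outputs are comparable):

```python
import torch
from slices import SLiCE

x = torch.randn(2, 64, 4)
par = SLiCE(input_dim=4, use_parallel=True, chunk_size=64)
rec = SLiCE(input_dim=4, use_parallel=False)
rec.load_state_dict(par.state_dict())  # same weights, different execution path

# Both modes realise the same affine recurrence, so the outputs should
# agree up to floating-point tolerance.
torch.testing.assert_close(par(x), rec(x), rtol=1e-4, atol=1e-5)
```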

### Use `SLiCELayer` as a residual sequence layer

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "torch-slices"
version = "0.3.2"
version = "0.3.3"
description = "Structured Linear CDE (SLiCE) layers for sequence modelling in PyTorch"
readme = "README.md"
license = "MIT"
2 changes: 1 addition & 1 deletion src/slices/__init__.py
@@ -12,5 +12,5 @@

from slices.slices import SLiCE, SLiCELayer, StackedSLiCE

__version__ = "0.3.2"
__version__ = "0.3.3"
__all__ = ["SLiCE", "SLiCELayer", "StackedSLiCE"]
185 changes: 103 additions & 82 deletions src/slices/slices.py
@@ -43,6 +43,9 @@ class SLiCE(nn.Module):
path_mode (str): Whether the input is treated as path values
("values", default) or as increments
("increments").
transition_mode (str): How the linear update is discretised.
"euler" uses I + A(X_i), while
"matrix_exp" uses exp(A(X_i)).

Shape:
- Input: (batch_size, seq_len, input_dim)
@@ -62,13 +65,16 @@ def __init__(
use_parallel: bool = True,
chunk_size: int = 256,
path_mode: str = "values",
transition_mode: str = "euler",
):
super().__init__()

if hidden_dim is None:
hidden_dim = input_dim
if path_mode not in {"values", "increments"}:
raise ValueError("path_mode must be one of {'values', 'increments'}.")
if transition_mode not in {"euler", "matrix_exp"}:
raise ValueError("transition_mode must be one of {'euler', 'matrix_exp'}.")
if block_size < 1:
raise ValueError("block_size must be at least 1.")
if not diagonal_dense and hidden_dim % block_size != 0:
@@ -82,6 +88,7 @@ def __init__(
self.scale = scale
self.input_dependent_init = input_dependent_init
self.path_mode = path_mode
self.transition_mode = transition_mode

self.use_parallel = use_parallel
if chunk_size < 1:
@@ -172,17 +179,82 @@ def _prepare_augmented_inputs(self, x: torch.Tensor) -> torch.Tensor:
)
return torch.cat((inc_ts, path), dim=-1) * self.scale

def _discretize_diagonal(self, A: torch.Tensor) -> torch.Tensor:
if self.transition_mode == "matrix_exp":
return torch.exp(A)
return 1.0 + A

def _discretize_matrix(self, A: torch.Tensor) -> torch.Tensor:
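# "euler" applies the first-order update I + A; "matrix_exp" applies exp(A),
# the exact unit-step transition of dy = A y dt. The two agree to O(A^2),
# since exp(A) = I + A + A^2/2! + ...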
if self.transition_mode == "matrix_exp":
return torch.matrix_exp(A)

eye = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
eye = eye.view(*((1,) * (A.ndim - 2)), A.shape[-2], A.shape[-1])
return eye + A

def _build_elementwise_transform(
self, inp: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
A = self.vf_A(inp)
M = self._discretize_diagonal(A)
if self.bias:
b = self.vf_B(inp)
else:
b = torch.zeros_like(M)
return M, b

def _build_blockdiag_transform(
self, inp: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
bsz = self.block_size
nblocks = self.hidden_dim // bsz

A = self.vf_A(inp).view(*inp.shape[:-1], nblocks, bsz, bsz)
M = self._discretize_matrix(A)
if self.bias:
b = self.vf_B(inp).view(*inp.shape[:-1], nblocks, bsz)
else:
b = torch.zeros(
*inp.shape[:-1],
nblocks,
bsz,
device=inp.device,
dtype=inp.dtype,
)
return M, b

def _build_diagonal_dense_transform(
self, inp: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = self.block_size
hdiag = self.hidden_dim - bsz
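# Layout: the leading hdiag channels evolve elementwise under a diagonal A,
# while the trailing bsz channels evolve under one dense bsz x bsz block.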

A_diag = self.vf_A_diag(inp)
M_diag = self._discretize_diagonal(A_diag)

A_dense = self.vf_A_dense(inp).view(*inp.shape[:-1], bsz, bsz)
M_dense = self._discretize_matrix(A_dense)

if self.bias:
B = self.vf_B(inp)
b_diag = B[..., :hdiag]
b_dense = B[..., hdiag:]
else:
b_diag = torch.zeros_like(M_diag)
b_dense = torch.zeros(
*inp.shape[:-1],
bsz,
device=inp.device,
dtype=inp.dtype,
)

return M_diag, M_dense, b_diag, b_dense

# ---- scan kernels: block_size == 1 (elementwise) ----

def _scan_kernels_elementwise(self) -> tuple[Callable, Callable, Callable]:
def build(inp_chunk: torch.Tensor):
A = self.vf_A(inp_chunk) # (B, C, H)
M = 1.0 + A
if self.bias:
b = self.vf_B(inp_chunk)
else:
b = torch.zeros_like(M)
return (M, b)
return self._build_elementwise_transform(inp_chunk)

def combine(lhs, rhs):
# Composition: rhs ∘ lhs
@@ -205,28 +277,7 @@ def _scan_kernels_blockdiag(self) -> tuple[Callable, Callable, Callable]:
nblocks = self.hidden_dim // bsz

def build(inp_chunk: torch.Tensor):
# A: (B, C, nblocks, b, b)
A = self.vf_A(inp_chunk).view(
inp_chunk.shape[0], inp_chunk.shape[1], nblocks, bsz, bsz
)
eye = torch.eye(bsz, device=inp_chunk.device, dtype=inp_chunk.dtype).view(
1, 1, 1, bsz, bsz
)
M = eye + A
if self.bias:
b = self.vf_B(inp_chunk).view(
inp_chunk.shape[0], inp_chunk.shape[1], nblocks, bsz
)
else:
b = torch.zeros(
inp_chunk.shape[0],
inp_chunk.shape[1],
nblocks,
bsz,
device=inp_chunk.device,
dtype=inp_chunk.dtype,
)
return (M, b)
return self._build_blockdiag_transform(inp_chunk)

def combine(lhs, rhs):
M_l, b_l = lhs
@@ -253,32 +304,7 @@ def _scan_kernels_diagonal_dense(self) -> tuple[Callable, Callable, Callable]:
hdiag = h - bsz

def build(inp_chunk: torch.Tensor):
A_diag = self.vf_A_diag(inp_chunk) # (B,C,hdiag)
M_diag = 1.0 + A_diag

A_dense = self.vf_A_dense(inp_chunk).view(
inp_chunk.shape[0], inp_chunk.shape[1], bsz, bsz
)
eye = torch.eye(bsz, device=inp_chunk.device, dtype=inp_chunk.dtype).view(
1, 1, bsz, bsz
)
M_dense = eye + A_dense

if self.bias:
B = self.vf_B(inp_chunk) # (B,C,h)
b_diag = B[..., :hdiag]
b_dense = B[..., hdiag:]
else:
b_diag = torch.zeros_like(M_diag)
b_dense = torch.zeros(
inp_chunk.shape[0],
inp_chunk.shape[1],
bsz,
device=inp_chunk.device,
dtype=inp_chunk.dtype,
)

return (M_diag, M_dense, b_diag, b_dense)
return self._build_diagonal_dense_transform(inp_chunk)

def combine(lhs, rhs):
Md_l, Mdense_l, bd_l, bdense_l = lhs
@@ -348,40 +374,28 @@ def _forward_recurrent(self, X: torch.Tensor) -> torch.Tensor:
if self.diagonal_dense:
y_diag = y[:, : -self.block_size]
y_dense = y[:, -self.block_size :]
diag_state_transition = self.vf_A_diag(inp[:, i]) * y_diag
A = self.vf_A_dense(inp[:, i])
dense_state_transition = torch.einsum(
"bij,bj->bi",
A.view(-1, self.block_size, self.block_size),
y_dense,
M_diag, M_dense, b_diag, b_dense = self._build_diagonal_dense_transform(
inp[:, i]
)
state_transition = torch.cat(
[diag_state_transition, dense_state_transition], dim=1
y_diag = M_diag * y_diag + b_diag
y_dense = (
torch.matmul(M_dense, y_dense.unsqueeze(-1)).squeeze(-1) + b_dense
)
y = torch.cat([y_diag, y_dense], dim=1)
elif self.block_size > 1:
state_transition = self.vf_A(inp[:, i]).view(
-1,
self.hidden_dim // self.block_size,
self.block_size,
self.block_size,
)
state_transition = (
state_transition
@ y.view(
M, b = self._build_blockdiag_transform(inp[:, i])
y = torch.matmul(
M,
y.view(
-1,
self.hidden_dim // self.block_size,
self.block_size,
1,
)
).view(-1, self.hidden_dim)
),
).view(-1, self.hidden_dim) + b.view(-1, self.hidden_dim)
else:
state_transition = self.vf_A(inp[:, i]) * y

if self.bias:
bias_term = self.vf_B(inp[:, i])
state_transition += bias_term

y = y + state_transition
M, b = self._build_elementwise_transform(inp[:, i])
y = M * y + b
ys[:, i] = y

return ys
Expand All @@ -395,7 +409,8 @@ def _forward_parallel(self, X: torch.Tensor, chunk_size: int) -> torch.Tensor:
Chunked parallel forward using torch.associative_scan (generic).

Each step defines an affine transform:
y_i = M_i y_{i-1} + b_i, where M_i = I + A_i, b_i = B_i
y_i = M_i y_{i-1} + b_i,
where M_i is either I + A_i or exp(A_i), and b_i = B_i.
We scan-combine the transforms within each chunk, then apply the
prefixes to the chunk-start state.
"""
@@ -480,6 +495,7 @@ class SLiCELayer(nn.Module):
dropout_rate (float): Dropout probability applied either on residual branches
or on the block output, depending on dropout_position.
path_mode (str): How the inner SLiCE interprets the input path.
transition_mode (str): How the inner SLiCE discretises each update.
norm_type (str): "rmsnorm" or "layernorm". Defaults to "rmsnorm".
prenorm (bool): If True, apply normalisation before the SLiCE and
feedforward branches; if False, use post-residual
@@ -514,6 +530,7 @@ def __init__(
chunk_size: int = 256,
dropout_rate: float = 0.01,
path_mode: str = "values",
transition_mode: str = "euler",
norm_type: str = "rmsnorm",
prenorm: bool = True,
second_norm: bool = True,
@@ -556,6 +573,7 @@ def __init__(
use_parallel=use_parallel,
chunk_size=chunk_size,
path_mode=path_mode,
transition_mode=transition_mode,
)

self.drop = nn.Dropout(p=dropout_rate)
@@ -647,6 +665,7 @@ class StackedSLiCE(nn.Module):
chunk_size (int): Chunk size used by each layer's inner SLiCE in parallel mode.
dropout_rate (float): Dropout probability applied in each layer.
path_mode (str): How each inner SLiCE interprets its input path.
transition_mode (str): How each inner SLiCE discretises each update.
norm_type (str): "rmsnorm" or "layernorm" for each stacked layer.
prenorm (bool): Whether each stacked layer uses pre-norm.
second_norm (bool): Whether each stacked layer uses the second
@@ -681,6 +700,7 @@ def __init__(
chunk_size: int = 256,
dropout_rate: float = 0.01,
path_mode: str = "values",
transition_mode: str = "euler",
norm_type: str = "rmsnorm",
prenorm: bool = True,
second_norm: bool = True,
@@ -712,6 +732,7 @@ def __init__(
chunk_size=chunk_size,
dropout_rate=dropout_rate,
path_mode=path_mode,
transition_mode=transition_mode,
norm_type=norm_type,
prenorm=prenorm,
second_norm=second_norm,