diff --git a/README.md b/README.md index 7d88194..ee409de 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ $$ By default, `SLiCE` treats the provided sequence as path values and internally computes first differences with `torch.diff(..., prepend=zeros)`. Set `path_mode="increments"` to treat the provided sequence as increments instead. +By default, each step uses the first-order Euler update `I + A(X_i)`. Set +`transition_mode="matrix_exp"` to use `exp(A(X_i))` instead. ## Installation @@ -101,6 +103,8 @@ print(y.shape) Execution mode is controlled by `use_parallel` and `chunk_size`. `path_mode` determines whether the input sequence is interpreted as path values or increments. +`transition_mode` selects between the default Euler step and a matrix-exponential +transition. ### Use `SLiCELayer` as a residual sequence layer diff --git a/pyproject.toml b/pyproject.toml index 84a5f9e..ac76608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torch-slices" -version = "0.3.2" +version = "0.3.3" description = "Structured Linear CDE (SLiCE) layers for sequence modelling in PyTorch" readme = "README.md" license = "MIT" diff --git a/src/slices/__init__.py b/src/slices/__init__.py index 432f87c..bc1aa2c 100644 --- a/src/slices/__init__.py +++ b/src/slices/__init__.py @@ -12,5 +12,5 @@ from slices.slices import SLiCE, SLiCELayer, StackedSLiCE -__version__ = "0.3.2" +__version__ = "0.3.3" __all__ = ["SLiCE", "SLiCELayer", "StackedSLiCE"] diff --git a/src/slices/slices.py b/src/slices/slices.py index 485783b..6e79eb7 100644 --- a/src/slices/slices.py +++ b/src/slices/slices.py @@ -43,6 +43,9 @@ class SLiCE(nn.Module): path_mode (str): Whether the input is treated as path values ("values", default) or as increments ("increments"). + transition_mode (str): How the linear update is discretised. + "euler" uses I + A(X_i), while + "matrix_exp" uses exp(A(X_i)). 
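+            The exponential is the exact one-step flow of the frozen
+            linear system, at extra per-step cost.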
Shape: - Input: (batch_size, seq_len, input_dim) @@ -62,6 +65,7 @@ def __init__( use_parallel: bool = True, chunk_size: int = 256, path_mode: str = "values", + transition_mode: str = "euler", ): super().__init__() @@ -69,6 +73,8 @@ def __init__( hidden_dim = input_dim if path_mode not in {"values", "increments"}: raise ValueError("path_mode must be one of {'values', 'increments'}.") + if transition_mode not in {"euler", "matrix_exp"}: + raise ValueError("transition_mode must be one of {'euler', 'matrix_exp'}.") if block_size < 1: raise ValueError("block_size must be at least 1.") if not diagonal_dense and hidden_dim % block_size != 0: @@ -82,6 +88,7 @@ def __init__( self.scale = scale self.input_dependent_init = input_dependent_init self.path_mode = path_mode + self.transition_mode = transition_mode self.use_parallel = use_parallel if chunk_size < 1: @@ -172,17 +179,82 @@ def _prepare_augmented_inputs(self, x: torch.Tensor) -> torch.Tensor: ) return torch.cat((inc_ts, path), dim=-1) * self.scale + def _discretize_diagonal(self, A: torch.Tensor) -> torch.Tensor: + if self.transition_mode == "matrix_exp": + return torch.exp(A) + return 1.0 + A + + def _discretize_matrix(self, A: torch.Tensor) -> torch.Tensor: + if self.transition_mode == "matrix_exp": + return torch.matrix_exp(A) + + eye = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype) + eye = eye.view(*((1,) * (A.ndim - 2)), A.shape[-2], A.shape[-1]) + return eye + A + + def _build_elementwise_transform( + self, inp: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + A = self.vf_A(inp) + M = self._discretize_diagonal(A) + if self.bias: + b = self.vf_B(inp) + else: + b = torch.zeros_like(M) + return M, b + + def _build_blockdiag_transform( + self, inp: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + bsz = self.block_size + nblocks = self.hidden_dim // bsz + + A = self.vf_A(inp).view(*inp.shape[:-1], nblocks, bsz, bsz) + M = self._discretize_matrix(A) + if self.bias: + b = self.vf_B(inp).view(*inp.shape[:-1], nblocks, bsz) + else: + b = torch.zeros( + *inp.shape[:-1], + nblocks, + bsz, + device=inp.device, + dtype=inp.dtype, + ) + return M, b + + def _build_diagonal_dense_transform( + self, inp: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = self.block_size + hdiag = self.hidden_dim - bsz + + A_diag = self.vf_A_diag(inp) + M_diag = self._discretize_diagonal(A_diag) + + A_dense = self.vf_A_dense(inp).view(*inp.shape[:-1], bsz, bsz) + M_dense = self._discretize_matrix(A_dense) + + if self.bias: + B = self.vf_B(inp) + b_diag = B[..., :hdiag] + b_dense = B[..., hdiag:] + else: + b_diag = torch.zeros_like(M_diag) + b_dense = torch.zeros( + *inp.shape[:-1], + bsz, + device=inp.device, + dtype=inp.dtype, + ) + + return M_diag, M_dense, b_diag, b_dense + # ---- scan kernels: block_size == 1 (elementwise) ---- def _scan_kernels_elementwise(self) -> tuple[Callable, Callable, Callable]: def build(inp_chunk: torch.Tensor): - A = self.vf_A(inp_chunk) # (B, C, H) - M = 1.0 + A - if self.bias: - b = self.vf_B(inp_chunk) - else: - b = torch.zeros_like(M) - return (M, b) + return self._build_elementwise_transform(inp_chunk) def combine(lhs, rhs): # Composition: rhs ∘ lhs @@ -205,28 +277,7 @@ def _scan_kernels_blockdiag(self) -> tuple[Callable, Callable, Callable]: nblocks = self.hidden_dim // bsz def build(inp_chunk: torch.Tensor): - # A: (B, C, nblocks, b, b) - A = self.vf_A(inp_chunk).view( - inp_chunk.shape[0], inp_chunk.shape[1], nblocks, bsz, bsz - ) - eye = torch.eye(bsz, 
device=inp_chunk.device, dtype=inp_chunk.dtype).view( - 1, 1, 1, bsz, bsz - ) - M = eye + A - if self.bias: - b = self.vf_B(inp_chunk).view( - inp_chunk.shape[0], inp_chunk.shape[1], nblocks, bsz - ) - else: - b = torch.zeros( - inp_chunk.shape[0], - inp_chunk.shape[1], - nblocks, - bsz, - device=inp_chunk.device, - dtype=inp_chunk.dtype, - ) - return (M, b) + return self._build_blockdiag_transform(inp_chunk) def combine(lhs, rhs): M_l, b_l = lhs @@ -253,32 +304,7 @@ def _scan_kernels_diagonal_dense(self) -> tuple[Callable, Callable, Callable]: hdiag = h - bsz def build(inp_chunk: torch.Tensor): - A_diag = self.vf_A_diag(inp_chunk) # (B,C,hdiag) - M_diag = 1.0 + A_diag - - A_dense = self.vf_A_dense(inp_chunk).view( - inp_chunk.shape[0], inp_chunk.shape[1], bsz, bsz - ) - eye = torch.eye(bsz, device=inp_chunk.device, dtype=inp_chunk.dtype).view( - 1, 1, bsz, bsz - ) - M_dense = eye + A_dense - - if self.bias: - B = self.vf_B(inp_chunk) # (B,C,h) - b_diag = B[..., :hdiag] - b_dense = B[..., hdiag:] - else: - b_diag = torch.zeros_like(M_diag) - b_dense = torch.zeros( - inp_chunk.shape[0], - inp_chunk.shape[1], - bsz, - device=inp_chunk.device, - dtype=inp_chunk.dtype, - ) - - return (M_diag, M_dense, b_diag, b_dense) + return self._build_diagonal_dense_transform(inp_chunk) def combine(lhs, rhs): Md_l, Mdense_l, bd_l, bdense_l = lhs @@ -348,40 +374,28 @@ def _forward_recurrent(self, X: torch.Tensor) -> torch.Tensor: if self.diagonal_dense: y_diag = y[:, : -self.block_size] y_dense = y[:, -self.block_size :] - diag_state_transition = self.vf_A_diag(inp[:, i]) * y_diag - A = self.vf_A_dense(inp[:, i]) - dense_state_transition = torch.einsum( - "bij,bj->bi", - A.view(-1, self.block_size, self.block_size), - y_dense, + M_diag, M_dense, b_diag, b_dense = self._build_diagonal_dense_transform( + inp[:, i] ) - state_transition = torch.cat( - [diag_state_transition, dense_state_transition], dim=1 + y_diag = M_diag * y_diag + b_diag + y_dense = ( + torch.matmul(M_dense, y_dense.unsqueeze(-1)).squeeze(-1) + b_dense ) + y = torch.cat([y_diag, y_dense], dim=1) elif self.block_size > 1: - state_transition = self.vf_A(inp[:, i]).view( - -1, - self.hidden_dim // self.block_size, - self.block_size, - self.block_size, - ) - state_transition = ( - state_transition - @ y.view( + M, b = self._build_blockdiag_transform(inp[:, i]) + y = torch.matmul( + M, + y.view( -1, self.hidden_dim // self.block_size, self.block_size, 1, - ) - ).view(-1, self.hidden_dim) + ), + ).view(-1, self.hidden_dim) + b.view(-1, self.hidden_dim) else: - state_transition = self.vf_A(inp[:, i]) * y - - if self.bias: - bias_term = self.vf_B(inp[:, i]) - state_transition += bias_term - - y = y + state_transition + M, b = self._build_elementwise_transform(inp[:, i]) + y = M * y + b ys[:, i] = y return ys @@ -395,7 +409,8 @@ def _forward_parallel(self, X: torch.Tensor, chunk_size: int) -> torch.Tensor: Chunked parallel forward using torch.associative_scan (generic). Each step defines an affine transform: - y_i = M_i y_{i-1} + b_i, where M_i = I + A_i, b_i = B_i + y_i = M_i y_{i-1} + b_i, + where M_i is either I + A_i or exp(A_i), and b_i = B_i We scan-combine transforms within each chunk, then apply prefixes to chunk-start state. """ @@ -480,6 +495,7 @@ class SLiCELayer(nn.Module): dropout_rate (float): Dropout probability applied either on residual branches or on the block output, depending on dropout_position. path_mode (str): How the inner SLiCE interprets the input path. + transition_mode (str): How the inner SLiCE discretises each update. 
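+            Forwarded unchanged to the inner SLiCE ("euler" or "matrix_exp").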
norm_type (str): "rmsnorm" or "layernorm". Defaults to "rmsnorm". prenorm (bool): If True, apply normalisation before the SLiCE and feedforward branches; if False, use post-residual @@ -514,6 +530,7 @@ def __init__( chunk_size: int = 256, dropout_rate: float = 0.01, path_mode: str = "values", + transition_mode: str = "euler", norm_type: str = "rmsnorm", prenorm: bool = True, second_norm: bool = True, @@ -556,6 +573,7 @@ def __init__( use_parallel=use_parallel, chunk_size=chunk_size, path_mode=path_mode, + transition_mode=transition_mode, ) self.drop = nn.Dropout(p=dropout_rate) @@ -647,6 +665,7 @@ class StackedSLiCE(nn.Module): chunk_size (int): Chunk size used by each layer's inner SLiCE in parallel mode. dropout_rate (float): Dropout probability applied in each layer. path_mode (str): How each inner SLiCE interprets its input path. + transition_mode (str): How each inner SLiCE discretises each update. norm_type (str): "rmsnorm" or "layernorm" for each stacked layer. prenorm (bool): Whether each stacked layer uses pre-norm. second_norm (bool): Whether each stacked layer uses the second @@ -681,6 +700,7 @@ def __init__( chunk_size: int = 256, dropout_rate: float = 0.01, path_mode: str = "values", + transition_mode: str = "euler", norm_type: str = "rmsnorm", prenorm: bool = True, second_norm: bool = True, @@ -712,6 +732,7 @@ def __init__( chunk_size=chunk_size, dropout_rate=dropout_rate, path_mode=path_mode, + transition_mode=transition_mode, norm_type=norm_type, prenorm=prenorm, second_norm=second_norm, diff --git a/tests/test_slices.py b/tests/test_slices.py index 5d42c7a..713f16f 100644 --- a/tests/test_slices.py +++ b/tests/test_slices.py @@ -1,3 +1,5 @@ +import math + import pytest import torch @@ -41,13 +43,21 @@ def test_slice_raises_if_path_mode_is_invalid(): SLiCE(input_dim=3, hidden_dim=3, path_mode="not-a-mode") +def test_slice_raises_if_transition_mode_is_invalid(): + with pytest.raises(ValueError, match="transition_mode must be one of"): + SLiCE(input_dim=3, hidden_dim=3, transition_mode="not-a-mode") + + # ----------------------- # SLiCE: forward paths # ----------------------- +@pytest.mark.parametrize("transition_mode", ["euler", "matrix_exp"]) @pytest.mark.parametrize("use_parallel", [False, True]) -def test_slice_values_mode_matches_manual_external_differencing(use_parallel: bool): +def test_slice_values_mode_matches_manual_external_differencing( + use_parallel: bool, transition_mode: str +): x = _rand_x(batch=2, seq=6, dim=4, seed=1) dx = torch.diff(x, dim=1, prepend=torch.zeros_like(x[:, :1, :])) @@ -59,6 +69,7 @@ def test_slice_values_mode_matches_manual_external_differencing(use_parallel: bo bias=True, use_parallel=use_parallel, chunk_size=2, + transition_mode=transition_mode, ) m_values = SLiCE(**kwargs, path_mode="values") m_increments = SLiCE(**kwargs, path_mode="increments") @@ -136,6 +147,160 @@ def test_slice_forward_diagonal_dense_bias_true_with_grads(): _assert_grads_exist(m) +@pytest.mark.parametrize("use_parallel", [False, True]) +def test_slice_matrix_exp_matches_manual_blockdiag_reference(use_parallel: bool): + m = SLiCE( + input_dim=1, + hidden_dim=2, + block_size=2, + diagonal_dense=False, + bias=False, + scale=1.0, + use_parallel=use_parallel, + chunk_size=2, + path_mode="increments", + transition_mode="matrix_exp", + ) + m.eval() + + init_vec = torch.tensor([1.0, -0.5], dtype=torch.float32) + base_A = torch.tensor([[0.2, -0.1], [0.05, 0.3]], dtype=torch.float32) + + with torch.no_grad(): + m.init.copy_(init_vec.reshape_as(m.init)) + 
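+        # vf_A sees the augmented [inc_ts, data] input; zero it, then wire
+        # only the data channel (column 1) to base_A, so A(X_i) = X_i * base_A.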
m.vf_A.weight.zero_() + m.vf_A.weight[:, 1].copy_(base_A.reshape(-1)) + + X = torch.tensor([[[0.2], [-0.1], [0.05]]], dtype=torch.float32) + + y = init_vec + expected_states = [] + for scale in X[0, :, 0]: + y = torch.matrix_exp(scale * base_A) @ y + expected_states.append(y) + expected = torch.stack(expected_states) + + Y = m(X)[0] + + assert Y.shape == expected.shape + torch.testing.assert_close(Y, expected, rtol=1e-5, atol=1e-6) + + +@pytest.mark.parametrize("use_parallel", [False, True]) +def test_slice_matrix_exp_scalar_time_increments_produce_exact_rotation( + use_parallel: bool, +): + """ + Use a 1D time-increment input with A(dt) = dt * 2π [[0, -1], [1, 0]]. + + For dt = 1/4, exp(A(dt)) is an exact 90-degree counter-clockwise rotation, so + starting from [1, 0] we should visit the unit circle axes exactly. + """ + + m = SLiCE( + input_dim=1, + hidden_dim=2, + block_size=2, + diagonal_dense=False, + bias=False, + scale=1.0, + use_parallel=use_parallel, + chunk_size=2, + path_mode="increments", + transition_mode="matrix_exp", + ) + m.eval() + + rotation_generator = torch.tensor( + [[0.0, -2.0 * math.pi], [2.0 * math.pi, 0.0]], dtype=torch.float32 + ) + + with torch.no_grad(): + m.init.copy_(torch.tensor([1.0, 0.0], dtype=torch.float32)) + m.vf_A.weight.zero_() + # Augmented input channels are [inc_ts, dt]. Use only dt. + m.vf_A.weight[:, 1].copy_(rotation_generator.reshape(-1)) + + X = torch.full((1, 4, 1), 0.25, dtype=torch.float32) + expected = torch.tensor( + [ + [0.0, 1.0], + [-1.0, 0.0], + [0.0, -1.0], + [1.0, 0.0], + ], + dtype=torch.float32, + ) + + Y = m(X)[0] + + assert Y.shape == expected.shape + torch.testing.assert_close(Y, expected, rtol=1e-5, atol=1e-6) + + +@pytest.mark.parametrize("use_parallel", [False, True]) +def test_slice_euler_scalar_time_increments_do_not_match_exact_rotation( + use_parallel: bool, +): + """ + The same 1D time-increment rotation setup as above, but using the Euler + discretisation. This should follow the explicit Euler update, not the exact + quarter-turn rotation. + """ + + m = SLiCE( + input_dim=1, + hidden_dim=2, + block_size=2, + diagonal_dense=False, + bias=False, + scale=1.0, + use_parallel=use_parallel, + chunk_size=2, + path_mode="increments", + transition_mode="euler", + ) + m.eval() + + rotation_generator = torch.tensor( + [[0.0, -2.0 * math.pi], [2.0 * math.pi, 0.0]], dtype=torch.float32 + ) + + with torch.no_grad(): + m.init.copy_(torch.tensor([1.0, 0.0], dtype=torch.float32)) + m.vf_A.weight.zero_() + # Augmented input channels are [inc_ts, dt]. Use only dt. + m.vf_A.weight[:, 1].copy_(rotation_generator.reshape(-1)) + + X = torch.full((1, 4, 1), 0.25, dtype=torch.float32) + exact_rotation = torch.tensor( + [ + [0.0, 1.0], + [-1.0, 0.0], + [0.0, -1.0], + [1.0, 0.0], + ], + dtype=torch.float32, + ) + # Hand-computed explicit Euler states for + # (I + (pi / 2) * [[0, -1], [1, 0]]) applied four times to [1, 0]. 
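+    # First step: M = [[1, -pi/2], [pi/2, 1]] maps [1, 0] to [1, pi/2].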
+ expected_euler = torch.tensor( + [ + [1.0, 1.5707963267948966], + [-1.4674011002723395, 3.141592653589793], + [-6.4022033008170185, 0.8366043953472126], + [-7.716338412008886, -9.219953032970322], + ], + dtype=torch.float32, + ) + + Y = m(X)[0] + + assert Y.shape == exact_rotation.shape + torch.testing.assert_close(Y, expected_euler, rtol=1e-5, atol=1e-6) + assert not torch.allclose(Y, exact_rotation, rtol=1e-3, atol=1e-3) + + def test_slice_parallel_falls_back_when_associative_scan_is_unavailable(monkeypatch): x = _rand_x(batch=2, seq=4, dim=3, seed=10) monkeypatch.setattr( @@ -413,6 +578,30 @@ def test_stacked_slice_propagates_second_norm_toggle(): _assert_no_nan(y) +def test_stacked_slice_propagates_transition_mode(): + x = _rand_x(batch=2, seq=4, dim=6, seed=16) + + m = StackedSLiCE( + num_layers=2, + data_dim=6, + hidden_dim=8, + label_dim=5, + tokens=False, + block_size=4, + diagonal_dense=False, + use_parallel=False, + dropout_rate=0.0, + transition_mode="matrix_exp", + ) + m.eval() + + y = m(x) + + assert all(layer.slice.transition_mode == "matrix_exp" for layer in m.layers) + assert y.shape == (2, 4, 5) + _assert_no_nan(y) + + def test_stacked_slice_hidden_matches_forward_pre_projection(): batch, seq = 2, 5 vocab = 11 @@ -632,6 +821,7 @@ def test_slice_increments_mode_preserves_direct_input_behaviour_parallel(): @pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("transition_mode", ["euler", "matrix_exp"]) @pytest.mark.parametrize( "cfg", [ @@ -646,10 +836,23 @@ def test_slice_increments_mode_preserves_direct_input_behaviour_parallel(): ], ) @pytest.mark.parametrize("chunk_size", [1, 2, 8]) -def test_slice_parallel_matches_recurrent(cfg, bias: bool, chunk_size: int): +def test_slice_parallel_matches_recurrent( + cfg, bias: bool, transition_mode: str, chunk_size: int +): x = _rand_x(batch=2, seq=7, dim=cfg["input_dim"], seed=11) - m_recurrent = SLiCE(**cfg, bias=bias, use_parallel=False) - m_parallel = SLiCE(**cfg, bias=bias, use_parallel=True, chunk_size=chunk_size) + m_recurrent = SLiCE( + **cfg, + bias=bias, + use_parallel=False, + transition_mode=transition_mode, + ) + m_parallel = SLiCE( + **cfg, + bias=bias, + use_parallel=True, + chunk_size=chunk_size, + transition_mode=transition_mode, + ) m_parallel.load_state_dict(m_recurrent.state_dict()) y_recurrent = m_recurrent(x) @@ -670,6 +873,7 @@ def test_slice_parallel_with_input_dependent_init_matches_recurrent(): bias=True, input_dependent_init=True, use_parallel=False, + transition_mode="matrix_exp", ) m_parallel = SLiCE( input_dim=4, @@ -680,6 +884,7 @@ def test_slice_parallel_with_input_dependent_init_matches_recurrent(): input_dependent_init=True, use_parallel=True, chunk_size=3, + transition_mode="matrix_exp", ) m_parallel.load_state_dict(m_recurrent.state_dict()) diff --git a/uv.lock b/uv.lock index 3d70234..0859e24 100644 --- a/uv.lock +++ b/uv.lock @@ -2578,7 +2578,7 @@ wheels = [ [[package]] name = "torch-slices" -version = "0.3.2" +version = "0.3.3" source = { editable = "." } dependencies = [ { name = "numpy" },
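A note on why the chunked parallel path is valid: every step, under either discretisation, is the affine map `y -> M y + b`, and affine maps compose associatively, which is exactly what the `combine` kernels implement. A standalone sanity check of that identity in plain PyTorch (the names here are illustrative, not part of the package):

```python
import torch

# Composing y -> M1 @ y + b1, then y -> M2 @ y + b2, collapses to a single
# affine map (M2 @ M1, M2 @ b1 + b2) -- the associative "combine" step that
# the scan kernels rely on.
torch.manual_seed(0)
M1, M2 = torch.randn(3, 3), torch.randn(3, 3)
b1, b2 = torch.randn(3), torch.randn(3)
y0 = torch.randn(3)

sequential = M2 @ (M1 @ y0 + b1) + b2
composed = (M2 @ M1) @ y0 + (M2 @ b1 + b2)
torch.testing.assert_close(sequential, composed)
```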
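And a minimal usage sketch of the new flag, assuming the `torch-slices` package from this diff is installed; the keyword arguments mirror the ones exercised in the tests, and the printed shapes follow the `Shape` section of the `SLiCE` docstring.

```python
import torch
from slices import SLiCE

# Identical weights, two discretisations of the per-step transition:
# "euler" uses I + A(X_i); "matrix_exp" uses torch.matrix_exp(A(X_i)).
euler = SLiCE(input_dim=4, hidden_dim=8, block_size=4, diagonal_dense=False,
              transition_mode="euler")
exact = SLiCE(input_dim=4, hidden_dim=8, block_size=4, diagonal_dense=False,
              transition_mode="matrix_exp")
exact.load_state_dict(euler.state_dict())

x = torch.randn(2, 16, 4)  # (batch_size, seq_len, input_dim)
print(euler(x).shape, exact(x).shape)  # (2, 16, 8) for both
```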