diff --git a/README.md b/README.md
index 687b6d1..c5c931f 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
$$
y_i = y_{i-1} + A(X_i)y_{i-1} + B(X_i),
$$
-where $A(\cdot): \mathbb{R}^D \rightarrow \mathbb{R}^{H \times H}$ and $B(\cdot): \mathbb{R}^D \rightarrow \mathbb{R}^H$ are *learned linear maps*, the initial state $y_0$ is either a function of $X_0$ or a learnt vector, and the input is augmented with an extra channel:
+where $A(\cdot): \mathbb{R}^{D+1} \rightarrow \mathbb{R}^{H \times H}$ and $B(\cdot): \mathbb{R}^{D+1} \rightarrow \mathbb{R}^H$ are *learned linear maps*, the initial state $y_0$ is either a function of $X_0$ or a learnt vector, and the driving path is augmented with an extra channel:

- `inc` = a constant “increment” channel (all ones)

@@ -20,6 +20,10 @@
$$
X_i = [inc_i, x_i] \in \mathbb{R}^{D+1}.
$$
+By default, `SLiCE` treats the provided sequence as path values and internally
+computes first differences with `torch.diff(..., prepend=zeros)`. Set
+`path_mode="increments"` to treat the provided sequence as increments instead.
+
## Installation

```bash
@@ -35,7 +39,7 @@ pip install git+https://github.com/datasig-ac-uk/slices.git
## What's included

- **`SLiCE`**: the core structured recurrence
-- **`SLiCELayer`**: a residual layer wrapping `SLiCE` with a post-activation stage (`GLU` or `tanh`)
+- **`SLiCELayer`**: a residual SLiCE layer with RMSNorm + GELU MLP by default
- **`StackedSLiCE`**: stacks multiple `SLiCELayer`s with an embedding + output projection (supports tokens or continuous inputs)

`SLiCE` supports both:
@@ -92,6 +96,7 @@
print(y.shape)
```

Execution mode is configured via constructor arguments (`use_parallel`, `chunk_size`).
+`path_mode` determines how `SLiCE` treats the sequence you pass in.

### Use `SLiCELayer` as a residual sequence layer

@@ -99,19 +104,29 @@ Execution mode is configured via constructor arguments (`use_parallel`, `chunk_s
import torch
from slices import SLiCELayer

-x = torch.randn(4, 256, 64)
+x = torch.randn(4, 256, 64)  # (batch, seq, input_dim)

layer = SLiCELayer(
    input_dim=64,
    block_size=4,
    diagonal_dense=True,
-    dropout_rate=0.01,
-    use_glu=True,
)

y = layer(x)  # (4, 256, 64)
```

+`SLiCELayer` defaults to this structure:
+- RMSNorm -> SLiCE -> residual
+- RMSNorm -> Linear -> GELU -> Linear -> residual
+
+To recover the previous LayerNorm + single-stage wrapper, combine these toggles:
+- `norm_type="layernorm"`
+- `prenorm=False`
+- `ff_style="single"`
+- `ff_mult=1`
+- `ff_activation="glu"` or `ff_activation="tanh"`
+- `dropout_position="output"`
+
### Stack layers for a full model

#### Token sequence mode (`tokens=True`)
@@ -135,7 +150,6 @@ model = StackedSLiCE(
    tokens=True,
    block_size=4,
    diagonal_dense=False,
-    use_glu=True,
)

logits = model(x)  # (batch, seq_len, vocab_size)
@@ -149,7 +163,7 @@ Uses an `nn.Linear(data_dim, hidden_dim)` front-end.
import torch
from slices import StackedSLiCE

-x = torch.randn(16, 100, 12)  # (batch, seq, data_dim)
+x = torch.randn(16, 100, 12)  # (batch, seq, input_dim)

model = StackedSLiCE(
    num_layers=3,
@@ -174,7 +188,7 @@ This example: **character-level language disambiguation**

- trains a compact token-mode `StackedSLiCE` end-to-end
- evaluates validation accuracy every `--eval-every` training steps
-- prints sample predictions so you can inspect model behavior quickly
+- prints sample predictions so you can inspect model behaviour quickly

To run it, first install the example dependencies:

@@ -205,6 +219,7 @@ uv run python examples/benchmark_parallel_vs_recurrent.py

This script:

- benchmarks all four SLiCE matrix modes (`diagonal`, `block_diagonal`, `diagonal_dense`, `dense`)
+- uses the default value-path semantics unless `path_mode="increments"` is set in code
- prints timing/speedup tables
- saves a combined 3D plot to `examples/images/parallel_vs_recurrent_speedup_3d_all_modes.png`

diff --git a/examples/benchmark_parallel_vs_recurrent.py b/examples/benchmark_parallel_vs_recurrent.py
index 9d5a5b6..1228595 100644
--- a/examples/benchmark_parallel_vs_recurrent.py
+++ b/examples/benchmark_parallel_vs_recurrent.py
@@ -487,6 +487,7 @@ def _benchmark_mode_grid(
                continue

            for s_idx, seq_len in enumerate(seq_lens):
+                # Random raw path values; SLiCE differences them internally by default.
                x = torch.randn(
                    args.batch_size, seq_len, dim, device=device, dtype=torch.float32
                )

diff --git a/examples/language_disambiguation.py b/examples/language_disambiguation.py
index 11a1e2d..d31040d 100644
--- a/examples/language_disambiguation.py
+++ b/examples/language_disambiguation.py
@@ -333,7 +333,8 @@ def main() -> None:
    x_train, y_train = encode_texts(train_texts, train_labels, vocab, args.max_seq_len)
    x_val, y_val = encode_texts(val_texts, val_labels, vocab, args.max_seq_len)

-    # 3) Create model and optimiser.
+    # 3) Create model and optimiser. The default layer stack uses
+    # RMSNorm -> SLiCE -> residual -> GELU MLP -> residual.
    model = StackedSLiCE(
        num_layers=NUM_LAYERS,
        data_dim=len(vocab),
@@ -343,7 +344,6 @@
        block_size=BLOCK_SIZE,
        diagonal_dense=False,
        scale=scale,
-        use_glu=True,
        dropout_rate=DROPOUT,
        use_parallel=True,
        chunk_size=128,

diff --git a/src/slices/__init__.py b/src/slices/__init__.py
index b0c7157..1cc05e2 100644
--- a/src/slices/__init__.py
+++ b/src/slices/__init__.py
@@ -6,7 +6,7 @@
Main components:
    - SLiCE: Core structured recurrence
-    - SLiCELayer: Residual layer wrapping SLiCE with post-activation
+    - SLiCELayer: Residual SLiCE layer (RMSNorm + GELU MLP by default)
    - StackedSLiCE: Stacked model with embedding and output projection
"""

diff --git a/src/slices/slices.py b/src/slices/slices.py
index 1995d60..91b5dd8 100644
--- a/src/slices/slices.py
+++ b/src/slices/slices.py
@@ -6,6 +6,19 @@
import torch.nn as nn


+class RMSNorm(nn.Module):
+    """A minimal RMSNorm used by the default SLiCELayer configuration."""
+
+    def __init__(self, d_model: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = float(eps)
+        self.weight = nn.Parameter(torch.ones(d_model))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+        return (x * inv_rms) * self.weight
+
+
class SLiCE(nn.Module):
    """
    A structured linear controlled differential equation (SLiCE) recurrence.
@@ -27,6 +40,9 @@ class SLiCE(nn.Module):
        dense block of size block_size x block_size.
    init_std (float): Standard deviation for vector field initialisation.
scale (float): Scaling factor applied to the input. + path_mode (str): Whether the input is treated as path values + ("values", default) or as increments + ("increments"). Shape: - Input: (batch_size, seq_len, input_dim) @@ -45,11 +61,14 @@ def __init__( input_dependent_init: bool = False, use_parallel: bool = True, chunk_size: int = 256, + path_mode: str = "values", ): super().__init__() if hidden_dim is None: hidden_dim = input_dim + if path_mode not in {"values", "increments"}: + raise ValueError("path_mode must be one of {'values', 'increments'}.") if block_size < 1: raise ValueError("block_size must be at least 1.") if not diagonal_dense and hidden_dim % block_size != 0: @@ -62,6 +81,7 @@ def __init__( self.init_std = init_std self.scale = scale self.input_dependent_init = input_dependent_init + self.path_mode = path_mode self.use_parallel = use_parallel if chunk_size < 1: @@ -136,6 +156,22 @@ def __init__( self.vf_B = nn.Linear(self.input_dim + 1, self.hidden_dim, bias=False) nn.init.normal_(self.vf_B.weight, mean=0.0, std=self.init_std) + def _prepare_driving_path(self, x: torch.Tensor) -> torch.Tensor: + if self.path_mode == "values": + return torch.diff( + x, + dim=1, + prepend=torch.zeros_like(x[:, :1, :]), + ) + return x + + def _prepare_augmented_inputs(self, x: torch.Tensor) -> torch.Tensor: + path = self._prepare_driving_path(x) + inc_ts = torch.ones( + path.shape[0], path.shape[1], 1, device=x.device, dtype=x.dtype + ) + return torch.cat((inc_ts, path), dim=-1) * self.scale + # ---- scan kernels: block_size == 1 (elementwise) ---- def _scan_kernels_elementwise(self) -> tuple[Callable, Callable, Callable]: @@ -292,13 +328,7 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: def _forward_recurrent(self, X: torch.Tensor) -> torch.Tensor: batch_size, seq_len, in_dim = X.shape - # Add the increments of a sample counting channel. - inc_ts = torch.full( - (batch_size, seq_len, 1), 1.0, device=X.device, dtype=X.dtype - ) - inp = torch.cat((inc_ts, X), dim=-1) # shape: (batch_size, seq_len, x_dim) - # Scale the input - inp = inp * self.scale + inp = self._prepare_augmented_inputs(X) # Initialise the hidden state if self.input_dependent_init: @@ -373,10 +403,7 @@ def _forward_parallel(self, X: torch.Tensor, chunk_size: int) -> torch.Tensor: batch_size, seq_len, _ = X.shape - inc_ts = torch.full( - (batch_size, seq_len, 1), 1.0, device=X.device, dtype=X.dtype - ) - inp = torch.cat((inc_ts, X), dim=-1) * self.scale # (B, T, D+1) + inp = self._prepare_augmented_inputs(X) if self.input_dependent_init: y_start = self.init(X[:, 0, :]) @@ -418,13 +445,24 @@ def _forward_parallel(self, X: torch.Tensor, chunk_size: int) -> torch.Tensor: class SLiCELayer(nn.Module): """ - A residual block wrapping a SLiCE. Includes: - 1. SLiCE forward pass - 2. Residual connection - 3. A Linear→GLU (or tanh) stage - 4. Residual connection - 5. LayerNorm - 6. Dropout + A residual layer wrapping a SLiCE. + + SLiCELayer defaults to this structure: + 1. RMSNorm + 2. SLiCE + 3. Residual connection + 4. RMSNorm + 5. Token MLP with hidden size ff_mult * input_dim and GELU + 6. Residual connection + 7. Dropout on each residual branch + + Optional toggles for the LayerNorm + GLU/tanh single-stage wrapper: + - norm_type="layernorm" + - prenorm=False + - ff_style="single" + - ff_mult=1 + - ff_activation="glu" or "tanh" + - dropout_position="output" The output dimension of the SLiCE is the same as the input dimension to preserve shape for the residual. 
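The `path_mode="values"` differencing implemented in `_prepare_driving_path` is easy to sanity-check in isolation. A minimal sketch with illustrative numbers (plain PyTorch, mirroring the `torch.diff(..., prepend=zeros)` call above):

```python
import torch

# A path with values [1, 3, 6] along the sequence dimension.
x = torch.tensor([[[1.0], [3.0], [6.0]]])  # (batch=1, seq=3, dim=1)

# Same differencing as _prepare_driving_path: the first step is measured
# from an implicit zero origin, so the path's starting value is kept.
dx = torch.diff(x, dim=1, prepend=torch.zeros_like(x[:, :1, :]))
print(dx.squeeze())  # tensor([1., 2., 3.])
```

Because of the zero prepend, running `SLiCE` with `path_mode="values"` on `x` matches running it with `path_mode="increments"` on `dx`; the new `test_slice_values_mode_matches_manual_external_differencing` below asserts exactly this equivalence.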
@@ -437,9 +475,21 @@ class SLiCELayer(nn.Module): init_std (float): Standard deviation for weight initialisation in the SLiCE. use_parallel (bool): Whether the inner SLiCE uses parallel scan execution. chunk_size (int): Chunk size used by the inner SLiCE when in parallel mode. - dropout_rate (float): Dropout probability applied after the residual addition. - use_glu (bool): Whether to apply a Linear -> GLU stage after the residual or - a Linear -> tanh stage. + dropout_rate (float): Dropout probability applied either on residual branches + or on the block output, depending on dropout_position. + path_mode (str): How the inner SLiCE interprets the input path. + norm_type (str): "rmsnorm" or "layernorm". Defaults to "rmsnorm". + prenorm (bool): If True, apply normalisation before the SLiCE and + feedforward branches; if False, apply one norm after + both residual updates. + ff_style (str): "mlp" for Linear -> activation -> Linear, or + "single" for a single Linear -> activation branch. + ff_activation (str): "gelu", "glu", or "tanh". + ff_mult (int): Expansion factor for the hidden feedforward size. + dropout_position (str): "residual" to drop branch outputs before + residual addition, or "output" to drop the + final layer output. + norm_eps (float): Epsilon used by the normalisation layers. Shape: - Input: (batch_size, seq_len, input_dim) @@ -458,9 +508,35 @@ def __init__( use_parallel: bool = True, chunk_size: int = 256, dropout_rate: float = 0.01, - use_glu: bool = False, + path_mode: str = "values", + norm_type: str = "rmsnorm", + prenorm: bool = True, + ff_style: str = "mlp", + ff_activation: str = "gelu", + ff_mult: int = 4, + dropout_position: str = "residual", + norm_eps: float = 1e-6, ): super().__init__() + if norm_type not in {"rmsnorm", "layernorm"}: + raise ValueError("norm_type must be one of {'rmsnorm', 'layernorm'}.") + if ff_style not in {"mlp", "single"}: + raise ValueError("ff_style must be one of {'mlp', 'single'}.") + if ff_activation not in {"gelu", "glu", "tanh"}: + raise ValueError("ff_activation must be one of {'gelu', 'glu', 'tanh'}.") + if ff_mult < 1: + raise ValueError("ff_mult must be at least 1.") + if ff_style == "single" and ff_mult != 1: + raise ValueError("ff_mult must be 1 when ff_style='single'.") + if dropout_position not in {"residual", "output"}: + raise ValueError("dropout_position must be one of {'residual', 'output'}.") + + self.norm_type = norm_type + self.prenorm = prenorm + self.ff_style = ff_style + self.ff_activation = ff_activation + self.ff_mult = ff_mult + self.dropout_position = dropout_position self.slice = SLiCE( input_dim=input_dim, hidden_dim=None, @@ -472,30 +548,39 @@ def __init__( input_dependent_init=input_dependent_init, use_parallel=use_parallel, chunk_size=chunk_size, + path_mode=path_mode, ) - self.norm = nn.LayerNorm(input_dim) - # Linear -> GLU or Linear -> tanh stage - self.use_glu = use_glu - if self.use_glu: - # Expand from input_dim -> 2*input_dim for GLU gating - self.linear = nn.Linear(input_dim, 2 * input_dim) - self.act = nn.GLU(dim=-1) + self.drop = nn.Dropout(p=dropout_rate) + if self.prenorm: + if norm_type == "rmsnorm": + self.slice_norm = RMSNorm(input_dim, eps=norm_eps) + self.ff_norm = RMSNorm(input_dim, eps=norm_eps) + else: + self.slice_norm = nn.LayerNorm(input_dim, eps=norm_eps) + self.ff_norm = nn.LayerNorm(input_dim, eps=norm_eps) else: - self.linear = nn.Linear(input_dim, input_dim) - self.act = lambda x: torch.tanh(x) + if norm_type == "rmsnorm": + self.norm = RMSNorm(input_dim, eps=norm_eps) + else: + 
self.norm = nn.LayerNorm(input_dim, eps=norm_eps) - self.drop = nn.Dropout(p=dropout_rate) + ff_hidden_dim = ff_mult * input_dim + ff_in_dim = 2 * ff_hidden_dim if ff_activation == "glu" else ff_hidden_dim + self.ff_in = nn.Linear(input_dim, ff_in_dim) + self.ff_out = ( + nn.Linear(ff_hidden_dim, input_dim) if ff_style == "mlp" else nn.Identity() + ) + if ff_activation == "gelu": + self.act = nn.GELU() + elif ff_activation == "glu": + self.act = nn.GLU(dim=-1) + else: + self.act = nn.Tanh() def forward(self, X: torch.Tensor) -> torch.Tensor: """ - Forward pass: - 1. Compute SLiCE on input - 2. Apply residual skip connection - 3. Apply Linear -> GLU (or tanh) stage - 4. Add residual skip connection - 5. LayerNorm - 6. Dropout + Forward pass for a configurable SLiCELayer. Args: X (torch.Tensor): shape (batch_size, seq_len, input_dim) @@ -503,26 +588,25 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: shape (batch_size, seq_len, input_dim) """ - # Step 1: SLiCE - ys = self.slice(X) # shape: (batch_size, seq_len, input_dim) - - # Step 2: Residual skip - ys = ys + X - - # Step 3: Linear -> GLU (or tanh) - ys_lin = self.linear(ys) # shape: (batch_size, seq_len, 2*input_dim) - ys_lin = self.act(ys_lin) # shape: (batch_size, seq_len, input_dim) - - # Step 4: Residual skip - ys = ys + ys_lin - - # Step 5: LayerNorm - ys = self.norm(ys) + slice_input = self.slice_norm(X) if self.prenorm else X + slice_out = self.slice(slice_input) + if self.dropout_position == "residual": + X = X + self.drop(slice_out) + else: + X = X + slice_out - # Step 6: Dropout - ys = self.drop(ys) + ff_input = self.ff_norm(X) if self.prenorm else X + ff_out = self.ff_out(self.act(self.ff_in(ff_input))) + if self.dropout_position == "residual": + X = X + self.drop(ff_out) + else: + X = X + ff_out - return ys + if not self.prenorm: + X = self.norm(X) + if self.dropout_position == "output": + X = self.drop(X) + return X class StackedSLiCE(nn.Module): @@ -535,17 +619,22 @@ class StackedSLiCE(nn.Module): data_dim (int): Dimension of the input. hidden_dim (int): Hidden dimension used in each SLiCELayer. label_dim (int): Size of the output dimension. - block_size (int): The size of the blocks along the diagonal of A in each block. + block_size (int): The size of the blocks along the diagonal of A in each layer. diagonal_dense (bool): If True, A is composed of a diagonal matrix and a dense - block in each block. - init_std (float): Standard deviation for the initialisation in each block. - use_parallel (bool): Whether each block's inner SLiCE uses + block in each layer. + init_std (float): Standard deviation for the initialisation in each layer. + use_parallel (bool): Whether each layer's inner SLiCE uses parallel scan execution. - chunk_size (int): Chunk size used by each block's inner SLiCE in parallel mode. - dropout_rate (float): Dropout probability applied in each block after the - residual. - use_glu (bool): Whether to apply a Linear -> GLU or Linear -> tanh stage after - the residual. + chunk_size (int): Chunk size used by each layer's inner SLiCE in parallel mode. + dropout_rate (float): Dropout probability applied in each layer. + path_mode (str): How each inner SLiCE interprets its input path. + norm_type (str): "rmsnorm" or "layernorm" for each stacked layer. + prenorm (bool): Whether each stacked layer uses pre-norm. + ff_style (str): "mlp" or "single" feedforward branch shape. + ff_activation (str): "gelu", "glu", or "tanh". + ff_mult (int): Expansion factor for the feedforward hidden size. 
+ dropout_position (str): "residual" or "output". + norm_eps (float): Epsilon used by the normalisation layers. Shape: - Input: (batch_size, seq_len) if the input is tokens or @@ -569,7 +658,14 @@ def __init__( use_parallel: bool = True, chunk_size: int = 256, dropout_rate: float = 0.01, - use_glu: bool = False, + path_mode: str = "values", + norm_type: str = "rmsnorm", + prenorm: bool = True, + ff_style: str = "mlp", + ff_activation: str = "gelu", + ff_mult: int = 4, + dropout_position: str = "residual", + norm_eps: float = 1e-6, ): super().__init__() self.tokens = tokens @@ -592,7 +688,14 @@ def __init__( use_parallel=use_parallel, chunk_size=chunk_size, dropout_rate=dropout_rate, - use_glu=use_glu, + path_mode=path_mode, + norm_type=norm_type, + prenorm=prenorm, + ff_style=ff_style, + ff_activation=ff_activation, + ff_mult=ff_mult, + dropout_position=dropout_position, + norm_eps=norm_eps, ) for _ in range(num_layers) ] diff --git a/tests/test_slices.py b/tests/test_slices.py index 9bd6da1..cd8d614 100644 --- a/tests/test_slices.py +++ b/tests/test_slices.py @@ -36,11 +36,40 @@ def test_slice_raises_if_block_size_less_than_one(): SLiCE(input_dim=3, hidden_dim=3, block_size=0) +def test_slice_raises_if_path_mode_is_invalid(): + with pytest.raises(ValueError, match="path_mode must be one of"): + SLiCE(input_dim=3, hidden_dim=3, path_mode="not-a-mode") + + # ----------------------- # SLiCE: forward paths # ----------------------- +@pytest.mark.parametrize("use_parallel", [False, True]) +def test_slice_values_mode_matches_manual_external_differencing(use_parallel: bool): + x = _rand_x(batch=2, seq=6, dim=4, seed=1) + dx = torch.diff(x, dim=1, prepend=torch.zeros_like(x[:, :1, :])) + + kwargs = dict( + input_dim=4, + hidden_dim=4, + block_size=2, + diagonal_dense=False, + bias=True, + use_parallel=use_parallel, + chunk_size=2, + ) + m_values = SLiCE(**kwargs, path_mode="values") + m_increments = SLiCE(**kwargs, path_mode="increments") + m_increments.load_state_dict(m_values.state_dict()) + + y_values = m_values(x) + y_increments = m_increments(dx) + + torch.testing.assert_close(y_values, y_increments, rtol=1e-5, atol=1e-6) + + def test_slice_forward_block_diagonal_block_size_gt1_bias_true_with_grads(): # diagonal_dense=False, block_size>1 hits the 4D block-matmul path x = _rand_x(batch=3, seq=5, dim=4, seed=2).requires_grad_(True) @@ -164,13 +193,54 @@ def test_slice_diagonal_dense_block_size_one_runs(): # ----------------------- -# SLiCELayer: both GLU and tanh paths +# SLiCELayer # ----------------------- -@pytest.mark.parametrize("use_glu", [False, True]) +def test_slice_layer_default_forward_shape(): + x = _rand_x(batch=2, seq=4, dim=6, seed=7) + + block = SLiCELayer( + input_dim=6, + block_size=2, + diagonal_dense=False, + use_parallel=False, + dropout_rate=0.05, + ) + + y = block(x) + assert block.norm_type == "rmsnorm" + assert block.prenorm + assert block.ff_style == "mlp" + assert block.ff_activation == "gelu" + assert block.dropout_position == "residual" + assert y.shape == x.shape + _assert_no_nan(y) + + +def test_slice_layer_default_backward(): + x = _rand_x(batch=2, seq=4, dim=6, seed=14).requires_grad_(True) + + block = SLiCELayer( + input_dim=6, + block_size=2, + diagonal_dense=False, + use_parallel=False, + dropout_rate=0.0, + ) + + y = block(x) + y.sum().backward() + + assert x.grad is not None + _assert_grads_exist(block) + + +@pytest.mark.parametrize("ff_activation", ["tanh", "glu"]) @pytest.mark.parametrize("diagonal_dense", [False, True]) -def 
test_slice_block_forward_covers_glu_and_tanh(use_glu: bool, diagonal_dense: bool): +def test_slice_layer_single_stage_toggles_cover_glu_and_tanh( + ff_activation: str, diagonal_dense: bool +): x = _rand_x(batch=2, seq=4, dim=6, seed=7) block = SLiCELayer( @@ -179,10 +249,21 @@ def test_slice_block_forward_covers_glu_and_tanh(use_glu: bool, diagonal_dense: diagonal_dense=diagonal_dense, use_parallel=False, dropout_rate=0.05, - use_glu=use_glu, + norm_type="layernorm", + prenorm=False, + ff_style="single", + ff_activation=ff_activation, + ff_mult=1, + dropout_position="output", ) y = block(x) + assert block.norm_type == "layernorm" + assert not block.prenorm + assert block.ff_style == "single" + assert block.ff_mult == 1 + assert block.ff_activation == ff_activation + assert block.dropout_position == "output" assert y.shape == x.shape _assert_no_nan(y) @@ -193,7 +274,7 @@ def test_slice_block_forward_covers_glu_and_tanh(use_glu: bool, diagonal_dense: def test_stacked_slice_tokens_path(): - # tokens=True uses nn.Embedding + # tokens=True uses nn.Embedding and the default layer structure. batch, seq = 2, 5 vocab = 11 hidden = 8 @@ -211,7 +292,6 @@ def test_stacked_slice_tokens_path(): diagonal_dense=False, use_parallel=False, dropout_rate=0.0, - use_glu=True, ) m.eval() @@ -221,7 +301,7 @@ def test_stacked_slice_tokens_path(): def test_stacked_slice_continuous_path(): - # tokens=False uses nn.Linear embedding + # tokens=False uses nn.Linear embedding. batch, seq = 2, 4 data_dim = 6 hidden = 8 @@ -239,7 +319,12 @@ def test_stacked_slice_continuous_path(): diagonal_dense=True, use_parallel=False, dropout_rate=0.0, - use_glu=False, + norm_type="layernorm", + prenorm=False, + ff_style="single", + ff_activation="tanh", + ff_mult=1, + dropout_position="output", ) m.eval() @@ -248,7 +333,7 @@ def test_stacked_slice_continuous_path(): _assert_no_nan(y) -def test_slice_input_dim2_golden_example(): +def test_slice_increments_mode_preserves_direct_input_behaviour(): """ Hand calculated example with input_dim=2 so augmented inp has 3 channels: inp = [inc_ts, x1, x2] @@ -262,6 +347,7 @@ def test_slice_input_dim2_golden_example(): bias=False, scale=1.0, use_parallel=False, + path_mode="increments", ) m.eval() @@ -326,10 +412,11 @@ def test_slice_input_dim2_golden_example(): # ----------------------- -def test_slice_input_dim2_golden_example_parallel(): +def test_slice_increments_mode_preserves_direct_input_behaviour_parallel(): """ - Same hand-calculated setup as test_slice_input_dim2_golden_example, - but evaluated through the parallel/chunked path. + Same hand-calculated setup as + test_slice_increments_mode_preserves_direct_input_behaviour, but evaluated + through the parallel/chunked path. """ m = SLiCE( @@ -341,6 +428,7 @@ def test_slice_input_dim2_golden_example_parallel(): scale=1.0, use_parallel=True, chunk_size=4, + path_mode="increments", ) m.eval()
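As a closing illustration of the configuration surface this patch introduces, here is a minimal usage sketch (assuming the patched `slices` package is installed; the parameter names are exactly those added above, while the tensor sizes are arbitrary):

```python
import torch
from slices import SLiCELayer

x = torch.randn(2, 16, 8)  # (batch, seq, input_dim)

# Default configuration: pre-norm RMSNorm branches and a GELU MLP.
default_layer = SLiCELayer(input_dim=8, block_size=2, diagonal_dense=False)

# Legacy-style configuration: post-norm LayerNorm with a single GLU stage.
legacy_layer = SLiCELayer(
    input_dim=8,
    block_size=2,
    diagonal_dense=False,
    norm_type="layernorm",
    prenorm=False,
    ff_style="single",
    ff_activation="glu",
    ff_mult=1,  # required when ff_style="single"
    dropout_position="output",
)

# Both variants preserve the input shape, so they stack freely.
assert default_layer(x).shape == legacy_layer(x).shape == (2, 16, 8)
```

Note that `ff_mult=1` must be passed explicitly alongside `ff_style="single"`, matching the `ValueError` guard in `SLiCELayer.__init__`.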