diff --git a/models/s6.py b/models/s6.py
new file mode 100644
index 0000000..d9a2275
--- /dev/null
+++ b/models/s6.py
@@ -0,0 +1,209 @@
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+
+
+class MambaRecurrence(nn.Module):
+    """
+    Implements the Mamba recurrence layer for sequence modeling.
+
+    Args:
+        d_model (int): Dimension of the model (number of features).
+        d_state (int): Dimension of the state space. Defaults to 16.
+        dt_rank (int or str): Rank for the time-step parameterization. Defaults to
+            "auto".
+        dt_min (float): Minimum value for time-steps. Defaults to 0.001.
+        dt_max (float): Maximum value for time-steps. Defaults to 0.1.
+        dt_init (str): Initialization method for dt. Options are "constant" or "random".
+            Defaults to "random".
+        dt_scale (float): Scale factor for dt initialization. Defaults to 1.0.
+        dt_init_floor (float): Floor value for initializing time-steps. Defaults to
+            1e-4.
+        device (torch.device, optional): Device to run the computations on. If None,
+            uses CUDA if available.
+
+    Forward Args:
+        hidden_states (Tensor): Input tensor of shape (batch_size, sequence_length,
+            d_model).
+
+    Returns:
+        Tensor: Output tensor of the same shape as input.
+    """
+
+    def __init__(
+        self,
+        d_model,
+        d_state=16,
+        dt_rank="auto",
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init="random",
+        dt_scale=1.0,
+        dt_init_floor=1e-4,
+        device=None,
+    ):
+        super().__init__()
+
+        if device is None:
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            else:
+                self.device = torch.device("cpu")
+        else:
+            self.device = device
+        self.d_model = d_model
+        self.d_state = d_state
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+
+        self.x_proj = nn.Linear(
+            self.d_model,
+            self.dt_rank + self.d_state * 2,
+            bias=False,
+            device=self.device,
+        )
+        self.dt_proj = nn.Linear(
+            self.dt_rank, self.d_model, bias=True, device=self.device
+        )
+
+        # Initialize special dt projection to preserve variance at initialization
+        dt_init_std = self.dt_rank**-0.5 * dt_scale
+        if dt_init == "constant":
+            nn.init.constant_(self.dt_proj.weight, dt_init_std)
+        elif dt_init == "random":
+            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+        else:
+            raise NotImplementedError
+
+        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+        dt = torch.exp(
+            torch.rand(self.d_model, device=self.device)
+            * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
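+        # Check: inv_dt = dt + log(1 - exp(-dt)) = log(exp(dt) - 1), so
+        # F.softplus(inv_dt) = log(1 + exp(inv_dt)) = dt, recovering the sampled values.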
d -> b d l") + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split( + x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 + ) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=None, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=False, + ) + return rearrange(y, "b d l -> b l d") + + +class S6(nn.Module): + def __init__( + self, + num_blocks: int, + data_dim: int, + model_dim: int, + label_dim: int, + dropout_rate: float = 0.1, + second_embedding: bool = False, + use_glu: bool = False, + d_state: int = 16, + dt_rank: str | int = "auto", + ): + """ + d_state: state dimension of the SSM (default 16). + dt_rank: rank for dt parameterization ("auto" uses ceil(model_dim/16)). + """ + super().__init__() + self.second_embedding = second_embedding + + emb_dim = model_dim // 2 if second_embedding else model_dim + self.embedding = nn.Embedding(data_dim, emb_dim) + if second_embedding: + self.embedding2 = nn.Embedding(data_dim, emb_dim) + + self.blocks = nn.ModuleList( + [ + MambaRecurrence(model_dim, d_state=d_state, dt_rank=dt_rank) + for _ in range(num_blocks) + ] + ) + self.norms = nn.ModuleList([nn.LayerNorm(model_dim) for _ in range(num_blocks)]) + + self.dropout = nn.Dropout(dropout_rate) + self.linear = nn.Linear(model_dim, label_dim) + + self.use_glu = use_glu + if use_glu: + self.glu_projs = nn.ModuleList( + [nn.Linear(model_dim, 2 * model_dim) for _ in range(num_blocks)] + ) + else: + self.glu_projs = nn.ModuleList( + [nn.Linear(model_dim, model_dim) for _ in range(num_blocks)] + ) + self.act = nn.GLU() + + def mask_grads(self): + pass + + def _embed(self, x: torch.Tensor) -> torch.Tensor: + if not self.second_embedding: + # x: (B, L) + return self.embedding(x) # -> (B, L, model_dim) + else: + # x: (B, L, 2) + return torch.cat( + [self.embedding(x[:, :, 0]), self.embedding2(x[:, :, 1])], dim=-1 + ) # -> (B, L, model_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._embed(x) # (B, L, D=model_dim) + for i, (blk, ln) in enumerate(zip(self.blocks, self.norms)): + residual = x + x = blk(x) + h = self.glu_projs[i](x) + h = self.act(h) if self.use_glu else torch.tanh(h) + x = residual + x + h + x = self.dropout(ln(x)) + return self.linear(x) # (B, L, label_dim) diff --git a/train.py b/train.py index 4823ac9..95bac91 100644 --- a/train.py +++ b/train.py @@ -31,6 +31,16 @@ def set_seed(seed=42): torch.backends.cudnn.deterministic = True +def vfA_l2(block): + lcde = block.LCDE + tensors = [lcde.vf_A.weight if isinstance(lcde.vf_A, nn.Linear) else lcde.vf_A] + for name in ("vf_A_u", "vf_A_v"): + t = getattr(lcde, name, None) + if t is not None: + tensors.append(t) + return torch.sqrt(sum((t**2).sum() for t in tensors)) + + def train_model( config, data_dim, @@ -131,7 +141,7 @@ def train_model( norm = 0 for block in model.blocks: if hasattr(block, "LCDE"): - norm += torch.sum(block.LCDE.vf_A**2) ** 0.5 + norm += vfA_l2(block) if task == "C4": batch_size, seq_len, _ = outputs.shape @@ -287,6 +297,8 @@ def run_experiment(config): slstm_at = config.get("slstm_at", [1]) vf_A_norm_lambda = config.get("vf_A_norm_lambda", 0.001) rank = config.get("rank", 0) + d_state = config.get("d_state", 16) + dt_rank = config.get("dt_rank", "auto") device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -359,6 +371,21 @@ def train_dataloader_multilength(): use_glu=use_glu, second_embedding=second_embedding, ) + elif model_name == "S6": + from models.s6 import S6 + + model = S6( + num_blocks=num_blocks, + model_dim=model_dim, + data_dim=data_dim, + label_dim=label_dim, + dropout_rate=dropout_rate, + use_glu=use_glu, + second_embedding=second_embedding, + d_state=d_state, + dt_rank=dt_rank, + ) + elif model_name in ["deltanet", "gateddeltanet", "rwkv7", "rwkv6", "deltaproduct"]: from models.fla import StackedBlock