diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359..6f99b323 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -111,6 +111,25 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ """ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) +def select_scan_fn_init_hidden(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, init_state=None): + """ + Selective Scan with hidden state initialization + It uses the base kernel method but applies a transformation to B to mimic the initialization of the hidden state. + Memory and time overhead is minimal compared to ref method. + """ + delta2 = delta + delta_bias[..., None].float() + v = F.softplus(delta2[:,:,0]) * u[:,:,0] + 1e-7 + B = B.clone() + sig = (init_state[:,:,:] * torch.exp(F.softplus(delta2[:,:,0]).unsqueeze(-1) * A.unsqueeze(0)) ) / v.unsqueeze(-1) + dim = A.shape[0] + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + B = B.clone() + B[:,:,:,0] = B[:,:,:,0] + sig + C = C.clone() + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):