From 734515388e0a991169a70ba756c02063108dcee2 Mon Sep 17 00:00:00 2001 From: ahmad-573 Date: Tue, 14 Oct 2025 19:35:02 +0200 Subject: [PATCH] Add capability to pass initial hidden state to selective scan while minimising memory and time overhead and without modifying the cuda kernel --- mamba_ssm/ops/selective_scan_interface.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359c..6f99b3234 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -111,6 +111,25 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ """ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) +def select_scan_fn_init_hidden(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, init_state=None): + """ + Selective Scan with hidden state initialization + It uses the base kernel method but applies a transformation to B to mimic the initialization of the hidden state. + Memory and time overhead is minimal compared to ref method. + """ + delta2 = delta + delta_bias[..., None].float() + v = F.softplus(delta2[:,:,0]) * u[:,:,0] + 1e-7 + B = B.clone() + sig = (init_state[:,:,:] * torch.exp(F.softplus(delta2[:,:,0]).unsqueeze(-1) * A.unsqueeze(0)) ) / v.unsqueeze(-1) + dim = A.shape[0] + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + B = B.clone() + B[:,:,:,0] = B[:,:,:,0] + sig + C = C.clone() + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):