Consider allowing hidden state initialisation via ssm_state input parameter for selective_scan_fn  #258

@govorunov

Please consider adding an ssm_state input parameter to selective_scan_fn so that the hidden state of the Mamba block can be initialised.
Please also consider making the hidden state differentiable, since selective_scan_fn currently states:

Note that the gradient of the last state is not considered in the backward pass.

This change could open the path to an encoder-decoder Mamba architecture and to an encoder-only, BERT-like architecture.
By analogy with RNNs, the architecture would work as follows: the Mamba encoder runs over the input sequence and its outputs are ignored; the last hidden state is then used to initialise the decoder together with an input token, and the decoder unrolls the state recursively.
For the encoder to work, the last hidden state has to be differentiable. This should also open a route to encoder-only BERT-like architectures, classification and embedding problems, etc.
For the decoder to work, the Mamba block needs to be able to accept a hidden state at initialisation.
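
To make the request concrete, here is a minimal pure-Python sketch of a sequential selective scan that accepts an initial state. It is not the library's CUDA kernel or its API: the function name `selective_scan_ref_sketch`, the single-channel list-based layout, and the simplified discretisation (`exp(delta * A)` for the state, `delta * B * x` for the input) are all assumptions made for illustration. Only the parameter name `ssm_state` comes from this proposal.

```python
import math


def selective_scan_ref_sketch(x, delta, A, B, C, ssm_state=None):
    """Sequential single-channel selective scan (hypothetical illustration).

    x, delta : lists of length L (input and step size per time step)
    A        : list of length N (diagonal state matrix)
    B, C     : lists of L lists, each of length N (input-dependent projections)
    ssm_state: optional initial hidden state of length N; zeros if None
    Returns (outputs, last_state).
    """
    n = len(A)
    h = [0.0] * n if ssm_state is None else list(ssm_state)
    ys = []
    for x_t, d_t, B_t, C_t in zip(x, delta, B, C):
        # Assumed discretised update: h <- exp(delta * A) * h + delta * B * x
        h = [math.exp(d_t * A[i]) * h[i] + d_t * B_t[i] * x_t for i in range(n)]
        # Readout: y = C . h
        ys.append(sum(C_t[i] * h[i] for i in range(n)))
    return ys, h


# Toy data: sequence length 4, state size 2.
x = [1.0, -0.5, 2.0, 0.25]
delta = [0.1, 0.2, 0.3, 0.4]
A = [-1.0, -0.5]
B = [[1.0, 0.5], [0.2, -0.3], [0.7, 0.1], [-0.4, 0.9]]
C = [[0.3, 0.6], [1.0, -1.0], [0.5, 0.5], [-0.2, 0.8]]

# Reference: process the whole sequence in one call.
y_full, h_full = selective_scan_ref_sketch(x, delta, A, B, C)

# "Encoder": consume the first half, ignore outputs, keep the last state.
_, h_enc = selective_scan_ref_sketch(x[:2], delta[:2], A, B[:2], C[:2])

# "Decoder": continue from the encoder's state via ssm_state.
y_dec, h_dec = selective_scan_ref_sketch(
    x[2:], delta[2:], A, B[2:], C[2:], ssm_state=h_enc
)
```

With an initial state, running the second half from `h_enc` reproduces the outputs and final state of the single full-sequence call, which is the handoff an encoder-decoder setup would rely on. The sketch does not cover the differentiability half of the request; that would additionally require the backward pass to propagate the gradient of the last state.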

Related issues: #233, #101

PS: Excellent work! Very impressive (especially the CUDA part)!
