# Mamba clean code in JAX and PyTorch

This is my one-evening attempt to get more comfortable with JAX and Flax by porting a PyTorch implementation of Mamba [1]. The result is closer to a detailed interface of the model than a complete implementation: training and inference code still need to be written. I hope this code helps you become more confident with JAX, Flax, or state-space models [2].
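To give a flavor of what such a port involves, here is a minimal sketch of the selective SSM recurrence at the heart of Mamba, written with `jax.lax.scan`. This is an illustration under simplified assumptions (a single input channel, diagonal state matrix, ZOH discretization of A and an Euler-style discretization of B), not the repo's actual code; the function name `selective_ssm` and the shapes are mine.

```python
import jax
import jax.numpy as jnp


def selective_ssm(x, A, B, C, delta):
    """Sequential selective SSM: h_t = Abar_t * h_{t-1} + Bbar_t * x_t, y_t = C_t . h_t.

    Illustrative shapes (single channel):
      x:     (L,)    input sequence
      A:     (N,)    diagonal state matrix (negative entries for stability)
      B, C:  (L, N)  input-dependent projections
      delta: (L,)    input-dependent step sizes
    """
    Abar = jnp.exp(delta[:, None] * A[None, :])  # (L, N) discretized A (ZOH)
    Bbar = delta[:, None] * B                    # (L, N) discretized B (Euler-style)

    def step(h, inputs):
        Abar_t, Bbar_t, C_t, x_t = inputs
        h = Abar_t * h + Bbar_t * x_t            # state update
        return h, jnp.dot(C_t, h)                # carry new state, emit y_t

    h0 = jnp.zeros_like(A)
    _, y = jax.lax.scan(step, h0, (Abar, Bbar, C, x))
    return y
```

In the full model B, C, and delta are produced from the input by learned projections; that selectivity is what distinguishes Mamba from earlier SSMs such as S4, where they are input-independent.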

Feel free to contact me about any mistakes you find :) I have also tried to implement an associative scan in the jax folder, but it probably contains mistakes.
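The idea behind that attempt can be sketched independently of the repo's code: a first-order linear recurrence h_t = a_t * h_{t-1} + b_t composes associatively, so it can be evaluated in parallel with `jax.lax.associative_scan`. The function name `linear_recurrence_scan` below is mine, and this is a minimal sketch, not the repo's implementation.

```python
import jax
import jax.numpy as jnp


def linear_recurrence_scan(a, b):
    """Parallel evaluation of h_t = a_t * h_{t-1} + b_t (with h_0 = 0).

    A segment of the recurrence acts on the state as h -> A * h + B, and two
    segments compose associatively:
      (a_l, b_l) o (a_r, b_r) = (a_l * a_r, a_r * b_l + b_r)
    which is what lets associative_scan parallelize the recurrence.
    """
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    _, h = jax.lax.associative_scan(combine, (a, b))
    return h
```

Applied elementwise over the state dimension, the same trick parallelizes the diagonal SSM recurrence over the sequence length, which is how the official Mamba kernels avoid a purely sequential scan.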

This repo builds on the following: annotated-mamba, mamba-minimal (PyTorch), and the official Mamba implementation.

## References

[1] Gu & Dao (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
[2] Gu et al. (2022). Efficiently Modeling Long Sequences with Structured State Spaces.