docs: add more notes on mamba (wip)
danbev committed Aug 12, 2024
1 parent 270939f commit cdaa605
Showing 1 changed file with 62 additions and 15 deletions: notes/mamba.md
@@ -6,21 +6,21 @@
It is said that Mamba might be just as influential as the transformer
architecture, but this remains to be seen.

One of the authors, Tri Dao, was also involved in the development of
[Flash Attention](./flash-attention.md), and Mamba likewise takes advantage of
the GPU hardware.

Transformers are efficient at training as they can be parallelized, in contrast
to RNNs which are sequential, which makes training large models a slow process.

But the issue with transformers is that they don't scale to long sequences,
because the self-attention mechanism is quadratic in the sequence length.
Every token has to attend to every other token in a sequence (n²), so if we
have 40 tokens that means 1600 attention operations, and this cost only grows
as the input sequence gets longer.
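
A quick check of that n² growth (plain Python, just for illustration):
```
# The number of attention score computations grows quadratically with
# the sequence length n.
for n in [40, 512, 4096, 32768]:
    print(f"{n:>6} tokens -> {n * n:>13,} attention scores")
```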

In this respect RNNs are more performant as they don't have the quadratic
scaling issue that the self-attention mechanism has (but they do have other
disadvantages, like slower training).

The core of Mamba is state space models (SSMs). Before we go further it might
make sense to review [RNNs](./rnn.md) and [SSMs](./state-space-models.md).
@@ -29,10 +29,55 @@
Selective state space models, which Mamba is a type of, give us a linear
recurrent network similar to RNNs, but also have the fast training that we get
from transformers. So we get the best of both worlds.

```
h_t = A_bar h_{t-1} + B_bar x_t

Where:
A_bar   = the state transition matrix.
B_bar   = the input projection matrix.
x_t     = the input at time t.
h_t     = the hidden state at time t.
h_{t-1} = the previous hidden state.
```

```
Input (x_t)
|
v
+----------+
| B | Input Projection Matrix
+----------+
|
v
+---+---+
| + | <---- A * h_{t-1} (Previous Hidden State)
+---+---+
|
v
+----------+
| S4D/SSM | State Space Model
+----------+
|
v
+------------+
| LayerNorm |
+------------+
|
v
+---------+
| SiLU | Activation Function
+---------+
|
v
Hidden State (h_t)
```
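
A minimal NumPy sketch of the recurrence above, with made-up dimensions (state
size 4, scalar input) just to show the shape of the computation; the real Mamba
block also makes A_bar and B_bar depend on the input (the "selective" part),
which is skipped here:
```
import numpy as np

rng = np.random.default_rng(0)

n = 4                                  # hidden state size (made up)
A_bar = rng.normal(size=(n, n)) * 0.1  # discretized state transition matrix
B_bar = rng.normal(size=(n, 1))        # discretized input projection matrix

h = np.zeros((n, 1))                   # h_0, the initial hidden state
xs = [0.5, -1.0, 2.0]                  # a tiny input sequence x_1..x_3

for t, x_t in enumerate(xs, start=1):
    # h_t = A_bar h_{t-1} + B_bar x_t
    h = A_bar @ h + B_bar * x_t
    print(f"h_{t} = {h.ravel()}")
```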

One major difference with state space models is that they have state, which is
not something transformers have (well, one might consider the kv-cache to be
the state). So transformers don't have an intrinsic state which gets updated as
the model processes a sequence. But neural networks like RNNs do have state,
though recall that they process the input sequentially.

To understand how Mamba fits in I found it useful to compare it to how
transformers look in a neural network:
@@ -97,10 +142,10 @@
h'(t) = Ah(t) + Bx(t) (state equation)
y(t)  = Ch(t) + Dx(t) (output equation) (Dx is not referred to in the paper)
h ∈ Rⁿ is like the hidden state in an RNN
x ∈ R¹ is the input sequence, x(t) is the input at time t.
y ∈ R¹ is the output sequence
A ∈ Rⁿ×ⁿ is the state transition matrix
B ∈ Rⁿ×¹ is the input projection matrix
C ∈ R¹×ⁿ is the output matrix
```
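
To connect these continuous-time equations to the discrete recurrence
h_t = A_bar h_{t-1} + B_bar x_t used earlier, the matrices are discretized with
a step size Δ. A hedged sketch of the zero-order-hold rule used in the Mamba
paper (the sizes and Δ below are made up for illustration):
```
import numpy as np
from scipy.linalg import expm            # matrix exponential

rng = np.random.default_rng(0)
n = 4                                    # state size (made up)
A = -np.diag(rng.uniform(0.5, 1.5, n))   # a stable, diagonal A (S4D style)
B = rng.normal(size=(n, 1))
C = rng.normal(size=(1, n))
delta = 0.1                              # step size (made up)

# Zero-order hold:
#   A_bar = exp(Δ A)
#   B_bar = (Δ A)^{-1} (exp(Δ A) - I) Δ B
A_bar = expm(delta * A)
B_bar = np.linalg.inv(delta * A) @ (A_bar - np.eye(n)) @ (delta * B)

# Step the discrete system and read out y_t = C h_t.
h = np.zeros((n, 1))
for x_t in [1.0, 0.0, 0.0]:              # an impulse followed by zeros
    h = A_bar @ h + B_bar * x_t
    print((C @ h).item())
```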
Now, the current state of the system is given by `h(t)`. And the matrix A can
@@ -360,7 +405,9 @@
a filter that is moved over the input and the dot product is computed. My
thought was: how is this possible when the input is sequential? It can't
access future values, so what is it convolving over?
I think the answer is that a causal convolution is used, where the filter is
only applied to past values. During training the model does have access to the
complete sequence, but during inference it does not.
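
A tiny sketch of what "causal" means here: the output at position t is a
weighted sum of the current and past inputs only (the 3-tap filter below is
made up, nothing Mamba-specific):
```
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # input sequence
k = np.array([0.5, 0.3, 0.2])             # a made-up 3-tap filter

# Causal convolution: y[t] = k[0]*x[t] + k[1]*x[t-1] + k[2]*x[t-2]
y = np.zeros_like(x)
for t in range(len(x)):
    for i in range(len(k)):
        if t - i >= 0:
            y[t] += k[i] * x[t - i]

# Same thing via np.convolve ("full" mode truncated to the input length).
assert np.allclose(y, np.convolve(x, k)[:len(x)])
print(y)   # y[0] only sees x[0]; y[4] sees x[2], x[3], x[4]
```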


So the system state is a representation of the past values, and the filter can
then be seen as moving across those past values. So the filter in
@@ -444,7 +491,7 @@
t = 1 +--------------+
+--------------+
```
What I'm trying to convey here is that the filter is moving across the input
and at each timestep, it is computing the weighted sums of the past state and
the current input. As more input comes in, the filter is "moved" across to the
next input.
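
To make the link between the recurrence and this filter view a little more
concrete, here is a hedged sketch (reusing the made-up matrices pattern from
earlier) showing that stepping the recurrence and convolving the input with the
kernel K = (C·B_bar, C·A_bar·B_bar, C·A_bar²·B_bar, ...) produce the same
outputs:
```
import numpy as np

rng = np.random.default_rng(0)
n, L = 4, 6                    # state size and sequence length (made up)
A_bar = rng.normal(size=(n, n)) * 0.1
B_bar = rng.normal(size=(n, 1))
C = rng.normal(size=(1, n))
x = rng.normal(size=L)

# 1) Recurrent view: step through the sequence one input at a time.
h = np.zeros((n, 1))
y_rec = []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).item())

# 2) Convolutional view: the filter taps are C A_bar^i B_bar.
K = [(C @ np.linalg.matrix_power(A_bar, i) @ B_bar).item() for i in range(L)]
y_conv = [sum(K[i] * x[t - i] for i in range(t + 1)) for t in range(L)]

assert np.allclose(y_rec, y_conv)
```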

@@ -495,7 +542,7 @@
One thing to keep in mind is that the state h is intended to capture the
history of the sequence x. How this is done depends on the transformation
matrices A and B. In practice, if the sequence is long then the model may
forget earlier information; the model prioritizes more recent information.
Just to draw a parallel to transformers, the self-attention mechanism can take
the entire sequence into account, but this can become very computationally
expensive as the sequence becomes very long.
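
A small illustration of that forgetting, under the assumption that A_bar's
eigenvalues lie inside the unit circle (the values below are made up): the
weight the first input carries in h_t shrinks as A_bar is applied over and
over:
```
import numpy as np

A_bar = np.diag([0.9, 0.7, 0.5])   # made-up stable A_bar (eigenvalues < 1)
B_bar = np.ones((3, 1))

# The first input x_1 shows up in h_t with weight A_bar^(t-1) B_bar.
for t in [1, 10, 50, 200]:
    w = np.linalg.matrix_power(A_bar, t - 1) @ B_bar
    print(f"t={t:>3}  |A_bar^(t-1) B_bar| = {np.linalg.norm(w):.2e}")
```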
