-
Notifications
You must be signed in to change notification settings - Fork 10k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bamba architecture #10810
base: master
Are you sure you want to change the base?
Bamba architecture #10810
Conversation
* ggml : improve ggml_mul speed when masking recurrent states
* ggml : make the ggml_mul fast broadcast path more consistently formatted
The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly.
The max index is 31, so trimming the arguments is necessary.
Whoops, this is needed for the offset in the concatenated output.
This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity.
This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
There are likely still some missing hparams, but the tensor mapping should be correct Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
….cpp parsing Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
… step rank head count and time step rank are used for the same purpose in the model, so we stick with the existing key. Chunk size is not used in this impl because of the way the mixer is implemented without chunking. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Not necessary despite their presence in the model config. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
There are still problems at inference around matrix dimensions not lining up, so there are likely still places where the per-layer sizes are not being used correctly. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
For hybrid models, this value should be 0 for the non-recurrent layers Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
This also seems like not _quite_ the right direction Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
…rid models Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Still not fully working, but worth committing these: * per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm) * fix setting n_kv_hybrid when not worst_case * Use the right n_kv for build_inp_s_copy when hybrid * Use the right n_kv for recurrent section of llama_set_inputs * Use the right logic to determine batch splitting for hybrid Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
I'd added this at one point, but it's not actually needed Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <[email protected]>
I think After finalizing the TTS arch, I will try to finally do this refactoring. Or at the very least - the "split into separate source files" part. I think mamba2 support should be added after this happens.
It's hard to answer. We are hacking the new KV modes into the original design and things are sometimes hard to fit. That is why we need to reimplement this in order to allow different implementations for different use cases. |
@ggerganov Thanks for the feedback!
This is music to my ears! I honestly thought about trying to do this myself to help me decompose the problem of adding this hybrid support (knowing it would be throw-away, but a very good learning experience). I know I'm just dipping my toes in at this point, but if there's any help I can offer here, please let me know. For the time being, unless you'd prefer otherwise, I'll keep this PR open in Draft as a guide to interested parties trying out the Bamba architecture. Once the big refactor moves forward, I'll look into refactoring this work on top of it. |
Description
This PR adds support for the forthcoming
bamba
architecture from IBM Research. It is a hybrid SSM architecture which is similar in nature to jamba, but uses the mamba2 mixer instead of the originalmamba
mixer.Dependencies
This PR is based on in-flight work for this new model which will be published soon. There are other in-flight PRs for this model:
transformers
: Add the Bamba Model huggingface/transformers#34982vllm
: Bamba VLLM Draft fabianlim/vllm#2NOTE: In order to run the conversion steps, you will need to install the branch of
transformers
from the above PR until it is merged and released.TODOs
metal
cuda
Changes
This PR comes with some big changes in addition to the standard pattern for contributing a new architecture. The need for these big changes arises from how the KV cache is used differently for
attention
layers versusrecurrent
layers.Library Interface
llama_model_is_hybrid
. This mirrors thellama_model_is_recurrent
function and currently only returnstrue
forBamba
llama.cpp
project architecturellama_hparams
:recurrent_layer_arr
to support per-layer indicator for recurrencebool recurrent_layer(uint32_t il) const
to index intorecurrent_layer_arr
n_embd_k_s
andn_embd_v_s
to be per-layer and return0
for non-recurrent layersllama_context
:struct llama_kv_cache kv_hybrid
as a secondary KV cache for hybrid models. This is the biggest structural change that allows the recurrent and attention layers to operate independently in the cache. For non-hybrid models, this will never be initialized.llama_kv_cache_init
:recurrent
argument. This allows recurrent models to initialize a non-recurrent cache (i.e. the attention cache for hybrid models)kv_self
, all recurrent layers will have zero size, and inkv_hybrid
, all attention layers will have zero size)llm_load_hparams
:recurrent_layer_arr
based onllama_model_is_recurrent(&model)
. For hybrid models, it falls to the individual model architecture construction to populate the per-layer entries correctly.llm_build_mamba2
:kv_hybrid
as the cache that gets initialized instead ofkv_self
. This is determined based on a newhybrid
flag passed by thebuild_<arch>
implementation.llm_build_context
:kv_hybrid
cache (kv_hybrid
,n_kv_hybrid
,kv_head_hybrid
,rs_zero_hybrid
)lctx.kv_hybrid
llama_set_inputs
:recurrent
branch for hybrid models and usekv_self
orkv_hybrid
as appropriate (kv_hybrid
will be the recurrent cache for hybrid models)llama_decode_internal
:simple_split
for hybrid models as well, even ifkv_self.recurrent
is falsellama_kv_slot_restorer
forkv_hybrid
and perform thesave
/restore
operations if (and only if) the model is hybridkv_hybrid
ring buffer when neededkv_hybrid
cache for hybrid modelsllama_new_context_with_model
:kv_size
andkv_size_hybrid
independentlykv_self
asrecurrent
IFF the model is recurrent but not hybridkv_hybrid
cache for hybrid modelsModel Architecture
LLM_KV_ATTENTION_LAYER_INDICES
("%s.attention.layer_indices"
) to indicate per-layer list of which layers use attentionllama.cpp
layer and require conversion from period/offset to a list in the conversion scriptLLM_KV_SSM_HEAD_DIM
("%s.ssm.head_dim"
) to set thehead_dim
from config rather than deducing it fromd_inner / n_head
llm_load_tensors
build_bamba
to construct the graphLLM_ARCH_BAMBA
as hybrid inllama_model_is_hybrid
Conversion
Keys.HybridMamba
section inconstants.py
to hold hybrid model parametersKeys.HybridMamba
,Keys.HybridRecurrent
, or justKeys.Hybrid
ssm_head_dim
andattn_layer_indices
inconstants.py
andgguf_writer.py
tensor_mapping.py
names for Bamba layer namesclass BambaModel
inconvert_hf_to_gguf.py
and base it onMamba2Model
Open Questions
mamba2
PR?kv_self
, but kept getting hung up onkv_self.size
being different for the two different types of caching.kv_hybrid
?