Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bamba architecture #10810

Draft
wants to merge 43 commits into
base: master
Choose a base branch
from
Draft

Conversation

gabe-l-hart
Copy link
Contributor

Description

This PR adds support for the forthcoming bamba architecture from IBM Research. It is a hybrid SSM architecture which is similar in nature to jamba, but uses the mamba2 mixer instead of the original mamba mixer.

Dependencies

This PR is based on in-flight work for this new model which will be published soon. There are other in-flight PRs for this model:

NOTE: In order to run the conversion steps, you will need to install the branch of transformers from the above PR until it is merged and released.

TODOs

  • Figure out trajectory for the mamba2 PR
  • Fix support on metal
  • Ensure support with cuda

Changes

This PR comes with some big changes in addition to the standard pattern for contributing a new architecture. The need for these big changes arises from how the KV cache is used differently for attention layers versus recurrent layers.

Library Interface

  • Add llama_model_is_hybrid. This mirrors the llama_model_is_recurrent function and currently only returns true for Bamba

llama.cpp project architecture

  • llama_hparams:
    • Add recurrent_layer_arr to support per-layer indicator for recurrence
    • Add bool recurrent_layer(uint32_t il) const to index into recurrent_layer_arr
    • Update n_embd_k_s and n_embd_v_s to be per-layer and return 0 for non-recurrent layers
  • llama_context:
    • Add struct llama_kv_cache kv_hybrid as a secondary KV cache for hybrid models. This is the biggest structural change that allows the recurrent and attention layers to operate independently in the cache. For non-hybrid models, this will never be initialized.
  • llama_kv_cache_init:
    • Add the recurrent argument. This allows recurrent models to initialize a non-recurrent cache (i.e. the attention cache for hybrid models)
    • Determine the size of the cache tensors on a per-layer basis so that layers which are managed by the other cache in a hybrid model have zero size (i.e. in the kv_self, all recurrent layers will have zero size, and in kv_hybrid, all attention layers will have zero size)
  • llm_load_hparams:
    • Automatically fill recurrent_layer_arr based on llama_model_is_recurrent(&model). For hybrid models, it falls to the individual model architecture construction to populate the per-layer entries correctly.
  • llm_build_mamba2:
    • Allow kv_hybrid as the cache that gets initialized instead of kv_self. This is determined based on a new hybrid flag passed by the build_<arch> implementation.
  • llm_build_context:
    • Add replicas of all KV-related members to support the second kv_hybrid cache (kv_hybrid, n_kv_hybrid, kv_head_hybrid, rs_zero_hybrid)
    • Populate all of ^ based on lctx.kv_hybrid
  • llama_set_inputs:
    • Do the recurrent branch for hybrid models and use kv_self or kv_hybrid as appropriate (kv_hybrid will be the recurrent cache for hybrid models)
  • llama_decode_internal:
    • Use simple_split for hybrid models as well, even if kv_self.recurrent is false
    • Add a second llama_kv_slot_restorer for kv_hybrid and perform the save/restore operations if (and only if) the model is hybrid
    • Update kv_hybrid ring buffer when needed
    • Defrag the kv_hybrid cache for hybrid models
  • llama_new_context_with_model:
    • Manage kv_size and kv_size_hybrid independently
    • Init kv_self as recurrent IFF the model is recurrent but not hybrid
    • Init the kv_hybrid cache for hybrid models

Model Architecture

  • Add architecture enum and layer set as normal
  • Add hparam for LLM_KV_ATTENTION_LAYER_INDICES ("%s.attention.layer_indices") to indicate per-layer list of which layers use attention
    • NOTE: Some hybrid models use a period and offset rather than an explicit list, but the list is the most flexible so I opted to only use that at the llama.cpp layer and require conversion from period/offset to a list in the conversion script
  • Add hparam LLM_KV_SSM_HEAD_DIM ("%s.ssm.head_dim") to set the head_dim from config rather than deducing it from d_inner / n_head
  • Add model enum entry in llm_load_tensors
  • Add build_bamba to construct the graph
  • Mark LLM_ARCH_BAMBA as hybrid in llama_model_is_hybrid

Conversion

  • Add Keys.HybridMamba section in constants.py to hold hybrid model parameters
    • NOTE: I'm torn on whether this should be Keys.HybridMamba, Keys.HybridRecurrent, or just Keys.Hybrid
  • Add new hparams with plumbing for ssm_head_dim and attn_layer_indices in constants.py and gguf_writer.py
  • Update tensor_mapping.py names for Bamba layer names
  • Add class BambaModel in convert_hf_to_gguf.py and base it on Mamba2Model

Open Questions

  • What are the plans for the current mamba2 PR?
  • Is there a better approach to handle they hybrid KV caching? I tried going the route of adding additional pointers into the single kv_self, but kept getting hung up on kv_self.size being different for the two different types of caching.
  • Are there other places where the KV cache is used that should be updated to also support kv_hybrid?

compilade and others added 30 commits August 21, 2024 18:00
* ggml : improve ggml_mul speed when masking recurrent states
* ggml : make the ggml_mul fast broadcast path more consistently formatted
The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.
The max index is 31, so trimming the arguments is necessary.
Whoops, this is needed for the offset in the concatenated output.
This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.
This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
There are likely still some missing hparams, but the tensor mapping should
be correct

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
….cpp parsing

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
… step rank

head count and time step rank are used for the same purpose in the model,
so we stick with the existing key. Chunk size is not used in this impl
because of the way the mixer is implemented without chunking.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Not necessary despite their presence in the model config.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
There are still problems at inference around matrix dimensions not lining
up, so there are likely still places where the per-layer sizes are not
being used correctly.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
For hybrid models, this value should be 0 for the non-recurrent layers

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
This also seems like not _quite_ the right direction

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
…rid models

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Still not fully working, but worth committing these:

* per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm)
* fix setting n_kv_hybrid when not worst_case
* Use the right n_kv for build_inp_s_copy when hybrid
* Use the right n_kv for recurrent section of llama_set_inputs
* Use the right logic to determine batch splitting for hybrid

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
I'd added this at one point, but it's not actually needed

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <[email protected]>
@github-actions github-actions bot added testing Everything test related python python script changes ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Dec 12, 2024
@ggerganov
Copy link
Owner

ggerganov commented Dec 13, 2024

What are the plans for the current mamba2 PR?

I think src/llama.cpp desperately needs some refactoring before we continue to introduce major changes to it. It's time to split it in multiple files and refactor the KV cache implementation to support different modes and be able to add tests. I have started several times to do that, but keep getting side-tracked by some other things.

After finalizing the TTS arch, I will try to finally do this refactoring. Or at the very least - the "split into separate source files" part.

I think mamba2 support should be added after this happens.

Is there a better approach to handle they hybrid KV caching? I tried going the route of adding additional pointers into the single kv_self, but kept getting hung up on kv_self.size being different for the two different types of caching.

Are there other places where the KV cache is used that should be updated to also support kv_hybrid?

It's hard to answer. We are hacking the new KV modes into the original design and things are sometimes hard to fit. That is why we need to reimplement this in order to allow different implementations for different use cases.

@gabe-l-hart
Copy link
Contributor Author

@ggerganov Thanks for the feedback!

It's time to split it in multiple files and refactor the KV cache implementation to support different modes and be able to add tests.

This is music to my ears! I honestly thought about trying to do this myself to help me decompose the problem of adding this hybrid support (knowing it would be throw-away, but a very good learning experience). I know I'm just dipping my toes in at this point, but if there's any help I can offer here, please let me know.

For the time being, unless you'd prefer otherwise, I'll keep this PR open in Draft as a guide to interested parties trying out the Bamba architecture. Once the big refactor moves forward, I'll look into refactoring this work on top of it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) ggml changes relating to the ggml tensor library for machine learning python python script changes testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants