Bamba architecture #10810

Draft: wants to merge 43 commits into base: master

Changes from all commits (43 commits)
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
9a68f75
feat(jamba): First pass at GGUF conversion for Jamba models
gabe-l-hart Nov 26, 2024
246dfdb
feat(jamba): Add jamba architecture to llama.cpp enums
gabe-l-hart Nov 26, 2024
e3525e9
feat(convert): Full pass at hparam conversion
gabe-l-hart Dec 2, 2024
fd98682
fix(bamba conv): Jamba -> Bamba
gabe-l-hart Dec 3, 2024
1c1e008
fix(bamba): Jamba->Bamba in llama.cpp
gabe-l-hart Dec 3, 2024
e0af809
feat(bamba): hparam parsing in llama.cpp
gabe-l-hart Dec 3, 2024
fd3bb30
fix(bamba conv): Fixes in tensor name and hparam conversion for llama…
gabe-l-hart Dec 4, 2024
3ee0ae3
feat(bamba): Full tensor parsing for bamba
gabe-l-hart Dec 4, 2024
dfe8d3d
fix(bamba conv): Remove chunk size and consolidate head count w/ time…
gabe-l-hart Dec 5, 2024
41fc019
fix(bamba): Remove ssm_head_count and ssm_chunk_size in llama.cpp
gabe-l-hart Dec 5, 2024
e7b1abb
feat(bamba): Partially complete work on constructing the forward graph
gabe-l-hart Dec 5, 2024
f2478bc
fix: Get n_head_kv per-layer in build_bamba
gabe-l-hart Dec 9, 2024
d3a34e0
fix: per-layer recurrent embd_[kv]_s
gabe-l-hart Dec 9, 2024
92653d0
WIP: Partial work towards separate hybrid cache
gabe-l-hart Dec 9, 2024
44bf431
fix: Only allocate kv cache tensors for the appropriate layers in hyb…
gabe-l-hart Dec 10, 2024
4543ed5
feat: Update the logic in llama_decode_internal for kv_hybrid cache
gabe-l-hart Dec 10, 2024
204e78f
fix: A number of places where hybrid needs to be handled
gabe-l-hart Dec 10, 2024
97e6ba8
fix: Remove outdated TODO in conversion script
gabe-l-hart Dec 12, 2024
b83e9a6
fix: Remove unused LLM_KV_ATTENTION_LAYER_COUNT
gabe-l-hart Dec 12, 2024
222 changes: 220 additions & 2 deletions convert_hf_to_gguf.py
@@ -265,6 +265,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

        return [(self.map_tensor_name(name), data_torch)]

    # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that)
    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
        del new_name, bid  # unused

        return data_torch.squeeze()

    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
        del name, new_name, bid, n_dims  # unused

@@ -296,7 +302,7 @@ def prepare_tensors(self):
                    break

            for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                data = self.reshape_tensors(data_torch, new_name, bid).numpy()

                # if data ends up empty, it means data_torch was a scalar tensor -> restore
                if len(data.shape) == 0:
@@ -510,7 +516,10 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
+        vocab_size = max(
+            self.hparams.get("vocab_size", len(tokenizer.vocab)),
+            len(tokenizer.vocab)
+        )
        assert max(tokenizer.vocab.values()) < vocab_size

        tokpre = self.get_vocab_base_pre(tokenizer)
@@ -2994,6 +3003,215 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        return [(new_name, data_torch)]


@Model.register("Mamba2ForCausalLM")
class Mamba2Model(Model):
model_arch = gguf.MODEL_ARCH.MAMBA2

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# n_groups and d_inner are used during reshaping
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
elif (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)
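A quick worked check of the ceiling-division padding used in `set_vocab` above; the numbers are illustrative, not tied to any particular checkpoint:

```python
vocab_size, pad_vocab = 50277, 16
padded = -(vocab_size // -pad_vocab) * pad_vocab  # equivalent to ceil(50277 / 16) * 16
assert padded == 50288
```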

    def set_gguf_parameters(self):
        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
        head_dim = self.find_hparam(["head_dim"], optional=True) or 64

        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        # TODO: does this really matter?
        assert self.d_inner == 2 * self.d_model
        assert self.d_inner % head_dim == 0

        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
        self.gguf_writer.add_embedding_length(self.d_model)
        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
        self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_file_type(self.ftype)
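To make the derived values written above concrete, here is a small arithmetic sketch; the `d_model` value is illustrative rather than taken from a specific Mamba-2 checkpoint:

```python
d_model = 4096
d_inner = 2 * d_model            # block expansion factor of 2, asserted above
head_dim = 64                    # default when "head_dim" is absent from the hparams
time_step_rank = d_inner // head_dim
print(d_inner, time_step_rank)   # 8192 128
```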

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
            name = name.removeprefix("model.")

        if name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"

        new_name = self.map_tensor_name(name)

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
            data_torch = -torch.exp(data_torch)

        yield (new_name, data_torch)

    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
        if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
            gguf.MODEL_TENSOR.SSM_A,
            gguf.MODEL_TENSOR.SSM_D,
        ]):
            # unsqueeze A to use similar shape semantics as Mamba-1
            # (D is also unsqueezed, but for more straightforward broadcast internally)
            return data_torch.reshape((*data_torch.shape, 1))

        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
            return data_torch.reshape((self.n_group, self.d_inner // self.n_group))

        return data_torch.squeeze()
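As a rough sketch of what `modify_tensors` and `reshape_tensors` do to the SSM tensors (shapes below are made up for illustration): `A_log` becomes `A = -exp(A_log)` and gains a trailing dimension so it broadcasts like Mamba-1's A, and the group norm weight is reshaped to one row per group.

```python
import torch

n_head, n_group, d_inner = 8, 2, 256                       # illustrative sizes
A_log = torch.randn(n_head)
A = -torch.exp(A_log)                                      # modify_tensors: A_log -> A
A = A.reshape(*A.shape, 1)                                 # reshape_tensors: (n_head,) -> (n_head, 1)
ssm_norm = torch.randn(d_inner)
ssm_norm = ssm_norm.reshape(n_group, d_inner // n_group)   # (2, 128)
```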


@Model.register("BambaForCausalLM")
class BambaModel(Mamba2Model):
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
model_arch = gguf.MODEL_ARCH.BAMBA

def __init__(self, *args, **kwargs):

# Hybrid mamba models use a prefix for the mamba-specific params.
# TODO: Extend this if the prefix(es) need to be configurable
self.hparam_prefixes = ["mamba"]

super().__init__(*args, **kwargs)

# Use Llama conversion for attention
self._transformer_model_class: type[Model] = LlamaModel

# Lists of which layers use ssm vs attention
self._attn_layers = self.hparams.get("attn_layer_indices", [])
if not self._attn_layers:
attn_period = self.hparams.get("attn_layer_period")
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
attn_offset = self.hparams.get("attn_layer_offset")
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
self._attn_layers = [
i for i in range(self.block_count)
if i % attn_period == attn_offset
]
self._ssm_layers = [
i for i in range(self.block_count)
if i not in self._attn_layers
]

# n_group and d_inner are used during reshape_tensors for mamaba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model

    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        prefixed = []
        for pfx in self.hparam_prefixes:
            prefixed.extend(
                "_".join([pfx, k])
                for k in keys
            )
        keys = list(keys) + prefixed
        return super().find_hparam(keys, *args, **kwargs)
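A minimal illustration of the key expansion performed by this `find_hparam` override, rewritten as a standalone helper (hypothetical, for exposition only):

```python
def prefixed_keys(keys: list[str], prefixes: tuple[str, ...] = ("mamba",)) -> list[str]:
    # Original keys first, then each prefixed variant, mirroring the override above.
    return list(keys) + ["_".join([pfx, k]) for pfx in prefixes for k in keys]


print(prefixed_keys(["conv_kernel", "d_conv"]))
# ['conv_kernel', 'd_conv', 'mamba_conv_kernel', 'mamba_d_conv']
```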

    def set_gguf_parameters(self):

        ## General Params ##
        self.gguf_writer.add_embedding_length(self.d_model)
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        ## Mamba mixer params ##
        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
        self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
        # in llama.cpp
        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

        ## Attention params ##
        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))

        ## Feed Forward Params ##
        self.gguf_writer.add_layer_norm_rms_eps(
            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
        )

        ## Validation ##
        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

        ## UNUSED ##
        # "tie_word_embeddings" <-- Implied by presence of output weights
        # "num_logits_to_keep" <-- Always only keep final token logits
        # "use_cache" <-- KV Cache always enabled
        # "use_mamba_kernels" <-- I think this will always be true if available?
        # "chunk_size" <-- This is used in the mixer implementation in transformers, but not here

    def modify_tensors(
        self, data_torch: Tensor, name: str, bid: int | None
    ) -> Iterable[tuple[str, Tensor]]:

        # Determine whether this is a mamba layer or an attention layer
        if bid in self._ssm_layers:
            for mamba_new_name, data_torch in super().modify_tensors(
                data_torch, name, bid
            ):
                yield mamba_new_name, data_torch
        elif bid in self._attn_layers:
            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
                self, data_torch, name, bid
            ):
                yield llama_new_name, data_torch
        else:
            yield self.map_tensor_name(name), data_torch

    def reshape_tensors(
        self, data_torch: Tensor, new_name: str, bid: int | None,
    ) -> Tensor:
        if bid in self._ssm_layers:
            return super().reshape_tensors(data_torch, new_name, bid)
        elif bid in self._attn_layers:
            return self._transformer_model_class.reshape_tensors(
                self, data_torch, new_name, bid
            )
        return data_torch


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
    model_arch = gguf.MODEL_ARCH.COMMAND_R
3 changes: 2 additions & 1 deletion ggml/include/ggml.h
@@ -1769,7 +1769,8 @@ extern "C" {
            struct ggml_tensor * dt,
            struct ggml_tensor * A,
            struct ggml_tensor * B,
-            struct ggml_tensor * C);
+            struct ggml_tensor * C,
+            struct ggml_tensor * ids);

    // partition into non-overlapping windows with padding if needed
    // example: