Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : add support for Cohere2ForCausalLM #10900

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3167,6 +3167,24 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)


@Model.register("Cohere2ForCausalLM")
class Cohere2Model(Model):
model_arch = gguf.MODEL_ARCH.COHERE2

def set_gguf_parameters(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The config.json has "max_position_embeddings": 8192, but the model supports 128K context. Do we need to adjust this value here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't quote me on this but I think it's fine to leave this as-is and force users to adjust rope settings to enable the full context

super().set_gguf_parameters()

self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

rotary_pct = self.hparams["rotary_pct"]
hidden_size = self.hparams["hidden_size"]
num_attention_heads = self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)


@Model.register("OlmoForCausalLM")
@Model.register("OLMoForCausalLM")
class OlmoModel(Model):
Expand Down
14 changes: 14 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
COHERE2 = auto()
DBRX = auto()
OLMO = auto()
OLMO2 = auto()
Expand Down Expand Up @@ -435,6 +436,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO2: "olmo2",
Expand Down Expand Up @@ -1114,6 +1116,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
],
MODEL_ARCH.COHERE2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.DBRX: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
189 changes: 189 additions & 0 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ enum llm_arch {
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
LLM_ARCH_OLMO2,
Expand Down Expand Up @@ -235,6 +236,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_OLMO2, "olmo2" },
Expand Down Expand Up @@ -1240,6 +1242,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
},
},
{
LLM_ARCH_COHERE2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_DBRX,
{
Expand Down Expand Up @@ -6110,6 +6127,16 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_COHERE2:
{
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_DBRX:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
Expand Down Expand Up @@ -8863,6 +8890,32 @@ static bool llm_load_tensors(
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_COHERE2:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);

// output
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
// init output from the input tok embed
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
llama_model_loader::TENSOR_DUPLICATED);

for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
}
}
break;
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
Expand Down Expand Up @@ -14783,6 +14836,137 @@ struct llm_build_context {

}

struct ggml_cgraph * build_cohere2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const float f_logit_scale = hparams.f_logit_scale;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
// cohere2 requires different mask for layers using sliding window (SWA)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();

// sliding window switch pattern
const int32_t sliding_window_pattern = 4;

for (int il = 0; il < n_layer; ++il) {
// three layers sliding window attention (window size 4096) and ROPE
// fourth layer uses global attention without positional embeddings
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;

// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
struct ggml_tensor * ffn_inp = cur;

// self-attention
{
// rope freq factors for 128k context
struct ggml_tensor * rope_factors = build_rope_factors(il);

// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

if (is_sliding) {
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
beta_fast, beta_slow);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
} else {
// For non-sliding layers, just reshape without applying RoPE
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il);

Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cb(Kcur, "Kcur", il);
}

cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}

struct ggml_tensor * attn_out = cur;

// feed-forward network
{
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
cb, il);
cb(cur, "ffn_out", il);
}

// add together residual + FFN + self-attention
cur = ggml_add(ctx0, cur, inpL);
cur = ggml_add(ctx0, cur, attn_out);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);

if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);
}

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

// ref: https://allenai.org/olmo
// based on the original build_llama() function, changes:
// * non-parametric layer norm
Expand Down Expand Up @@ -17530,6 +17714,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_command_r();
} break;
case LLM_ARCH_COHERE2:
{
result = llm.build_cohere2();
} break;
case LLM_ARCH_DBRX:
{
result = llm.build_dbrx();
Expand Down Expand Up @@ -20802,6 +20990,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MINICPM:
case LLM_ARCH_XVERSE:
case LLM_ARCH_COMMAND_R:
case LLM_ARCH_COHERE2:
case LLM_ARCH_OLMO:
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
Expand Down
Loading