
llama : add support for Cohere2ForCausalLM #10900

Open · wants to merge 1 commit into master
Conversation

dranger003 (Contributor) commented on Dec 19, 2024

Closes #10816

Cohere updated their Command-R model architecture for C4AI Command R7B, which requires an update to llama.cpp. Looking at the HF code, the model appears to use a hybrid cache like Gemma2. Additional info from their model page on HF:

The model features three layers with sliding window attention (window size 4096) and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.

Summary of changes in this PR (based on my very limited knowledge of neural nets):

  • Add sliding window and RoPE dim count during conversion
  • Remove ATTN_K_NORM and ATTN_Q_NORM
  • Support alternating sliding window attention in build_cohere2 (modeled on llama.cpp's build_gemma2), using a pattern of 4 layers (see the sketch after this list)
  • Use LLAMA_ROPE_TYPE_NORM as the rope type
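
For context, here is a minimal, self-contained C++ sketch of the layer interleaving implied by "sliding_window_pattern": 4 with "local_attn_first": three consecutive SWA layers with RoPE, followed by one global-attention layer without positional embeddings. The helper layer_uses_swa and the standalone program are illustrative only and are not the actual build_cohere2 code.

#include <cstdio>

// Illustrative only: layers 0..2 of every block of 4 are local (SWA + RoPE),
// and every 4th layer uses global attention without positional embeddings.
static bool layer_uses_swa(int il, int sliding_window_pattern = 4) {
    return (il + 1) % sliding_window_pattern != 0;
}

int main() {
    const int n_layer = 32; // num_hidden_layers from config.json
    for (int il = 0; il < n_layer; ++il) {
        std::printf("layer %2d: %s\n", il,
                    layer_uses_swa(il) ? "SWA (window 4096) + RoPE"
                                       : "global attention, no positional embedding");
    }
    return 0;
}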

HF transformers implementation reference:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere2/modular_cohere2.py

Test weights:
https://huggingface.co/dranger003/c4ai-command-r7b-12-2024-GGUF

github-actions bot added the python (python script changes) label on Dec 19, 2024
dranger003 marked this pull request as draft on December 19, 2024 at 15:12
dranger003 (Contributor, Author) commented on Dec 19, 2024

HF config.json:

{
  "architectures": [
    "Cohere2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 5,
  "cache_implementation": "hybrid",
  "eos_token_id": 255001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "layer_norm_eps": 1e-05,
  "layer_switch": 4,
  "logit_scale": 0.25,
  "max_position_embeddings": 8192,
  "model_type": "cohere2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "order_of_interleaved_layers": "local_attn_first",
  "pad_token_id": 0,
  "position_embedding_type": "rope_gptj",
  "rope_scaling": null,
  "rope_theta": 50000,
  "rotary_pct": 1.0,
  "sliding_window": 4096,
  "sliding_window_pattern": 4,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.0.dev0",
  "use_cache": true,
  "use_embedding_sharing": true,
  "use_gated_activation": true,
  "use_parallel_block": true,
  "use_parallel_embedding": true,
  "vocab_size": 256000
}

dranger003 (Contributor, Author) commented:
Info from @foldl:

It uses (3 SWA layers + 1 global attention layer). So build_command_r needs to be updated, even though the result seems promising.

Here is an implementation of interleaved SWA/global-attention layers.

https://github.com/foldl/chatllm.cpp/blob/ff54a787948f02151b38231375be042b632a271e/models/cohere.cpp#L246C1-L258C1

class Cohere2Model(Model):
    model_arch = gguf.MODEL_ARCH.COHERE2

    def set_gguf_parameters(self):
dranger003 (Contributor, Author):
The config.json has "max_position_embeddings": 8192, but the model supports 128K context. Do we need to adjust this value here?

Contributor:
Don't quote me on this, but I think it's fine to leave this as-is and force users to adjust rope settings to enable the full context.
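
(For illustration only: if the GGUF metadata keeps the 8192 value, the user-facing overrides live in llama_context_params in the llama.cpp C API, sketched below. The model path and the n_ctx value are placeholders, and this makes no claim about which RoPE settings are actually correct for this model.)

#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) return 1; // argv[1]: path to a GGUF file (placeholder)

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == nullptr) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx           = 32768; // request more than the 8192 stored in the metadata
    cparams.rope_freq_base  = 0.0f;  // 0 = use the model's rope_theta (50000 here)
    cparams.rope_freq_scale = 0.0f;  // 0 = use the model's scaling factor

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == nullptr) { llama_free_model(model); return 1; }

    // ... run inference as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}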

src/llama.cpp (outdated):
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
dranger003 (Contributor, Author):
Do we need to pass build_rope_factors(il) as the c argument when calling ggml_rope_ext for this model?

Contributor:
RoPE is only applied to SWA layers.

dranger003 (Contributor, Author) commented on Dec 19, 2024:
Got it, looks like the cache is working now. Not sure if I still need build_rope_factors() though?
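
(Editorial sketch of what "RoPE is only applied to SWA layers" can look like inside the layer loop, following the pattern of the quoted diff above. The surrounding variables are assumed from the existing build context, the rope factors argument stays nullptr since rope_scaling is null in config.json, and this is not the PR's actual code.)

// Sketch only: assumes ctx0, il, Qcur, Kcur, inp_pos, n_embd_head, n_head, n_head_kv,
// n_tokens, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
// attn_factor, beta_fast, beta_slow and cb from the existing build context.
const bool is_swa = (il + 1) % 4 != 0; // 3 local (SWA) layers, then 1 global layer

if (is_swa) {
    // sliding-window layers: apply RoPE; no rope factors needed while rope_scaling is null
    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
} else {
    // global-attention layer: no positional embedding, only reshape for the attention op
    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);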

dranger003 marked this pull request as ready for review on December 20, 2024 at 00:26
dranger003 changed the title from "Add support for Cohere2ForCausalLM" to "llama : add support for Cohere2ForCausalLM" on Dec 20, 2024
Labels: python (python script changes)
Projects: none yet
Development: successfully merging this pull request may close these issues: Feature Request: Support for C4AI Command R7B / Cohere2ForCausalLM
3 participants