Conversation

molbap
Contributor

@molbap molbap commented Sep 5, 2025

What does this PR do?

As per title, adds support for LongCat-Flash, a 560B MoE from Meituan.

Status:

  • The current modeling_longcat_flash file loads the checkpoint without trust_remote_code, using the base_model_tp_plan defined in the config. `from_pretrained(model_id, tp_plan="auto")` loads the model properly.
  • The chat template is the one provided by the authors.
  • A no-op hook was added to deepseek_v3 to abstract LoRA scaling.
  • Testing out generations and correctness. # DOING
  • A few modular adjustments remain to derive from DeepSeek-V3; estimated ~300 LOC total.
  • Quality & last touches
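
For context on the first point: a base_model_tp_plan maps module-name patterns to sharding styles that the tensor-parallel loader applies at load time. The sketch below is illustrative only; the module names and the lookup helper are hypothetical, and the real entries for LongCat-Flash live in its config.

```python
import fnmatch
from typing import Optional

# Hypothetical sketch of a base_model_tp_plan: keys are module-name
# patterns ("*" matches any substring), values are sharding styles.
# These entries are illustrative, not the actual LongCat-Flash plan.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

def shard_style(module_name: str, plan: dict) -> Optional[str]:
    """Return the sharding style for a module name, or None if unsharded."""
    for pattern, style in plan.items():
        if fnmatch.fnmatch(module_name, pattern):
            return style
    return None
```

With a plan like this, a loader can decide per-parameter whether to split along rows or columns across TP ranks, which is why no custom remote code is needed.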

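The no-op hook mentioned in the list above follows a common pattern: the base class exposes a scaling seam that does nothing, and a subclass overrides it. A minimal sketch of the pattern with hypothetical names (this is not the actual transformers code):

```python
class BaseMLP:
    """Toy stand-in for a DeepSeek-V3-style MLP (hypothetical names).

    `scale_output` is a no-op hook; a derived model can override it to
    apply LoRA-style scaling without duplicating the forward logic.
    """

    def forward(self, x: float) -> float:
        h = 2.0 * x                        # pretend up-projection
        return self.scale_output(0.5 * h)  # pretend down-projection

    def scale_output(self, h: float) -> float:
        return h  # no-op in the base class


class LoraScaledMLP(BaseMLP):
    """Derived model that applies a scaling factor via the hook."""

    def __init__(self, scaling: float):
        self.scaling = scaling

    def scale_output(self, h: float) -> float:
        return h * self.scaling
```

This keeps the base model's forward pass untouched while letting LongCat-Flash-style scaling be layered on in modular code.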
Launch snippet:

# launch_longcat.py
import torch
from transformers import AutoTokenizer, LongcatFlashForCausalLM

torch.manual_seed(30)
model_id = "meituan-longcat/LongCat-Flash-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_id)

chat = [
    {"role": "user", "content": "Hello! What is the capital of France? What can you tell me about it?"},
]

model = LongcatFlashForCausalLM.from_pretrained(
    model_id,
    tp_plan="auto",
    dtype=torch.bfloat16,
    trust_remote_code=False,  # can be removed.
)

inputs = tokenizer.apply_chat_template(
    chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=30)
print(tokenizer.batch_decode(outputs))

Note that you will need at least 2 nodes with 8 H100s each (16 GPUs total) to launch the model with TP, for example:

torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0  --rdzv-id <an_id> --rdzv-backend c10d --rdzv-endpoint $NODE_ID:$NODE_PORT --log-dir ./logs_longcat launch_longcat.py

(run the same command with --node_rank=1 on the second node)

Contributor

github-actions bot commented Sep 5, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
