Major Changes: support qwen; change as model adapter
YizhaoGao committed Feb 23, 2025
1 parent c242a73 commit e85b648
Showing 21 changed files with 1,764 additions and 312 deletions.
92 changes: 53 additions & 39 deletions README.md
@@ -6,26 +6,33 @@
![SeerAttention Architecture](figures/seer.png)


Official implementation of **SeerAttention** - a novel trainable sparse attention mechanism that learns intrinsic sparsity patterns directly from LLMs through self-distillation at post-training time. Achieves faster inference while maintaining accuracy for long-context processing.
Official implementation of **SeerAttention** - a novel trainable sparse attention mechanism that learns intrinsic sparsity patterns directly from LLMs through self-distillation at post-training time. Achieves faster inference while maintaining accuracy for long-context prefilling.


## News
- **2025/2/23**: Support Qwen! Change the distillation into a model adapter so that only the AttnGates are saved.
- **2025/2/18**: DeepSeek's Native Sparse Attention ([NSA](https://arxiv.org/abs/2502.11089)) and Kimi's Mixture of Block Attention ([MoBA](https://github.com/MoonshotAI/MoBA)) both adopt trainable sparse attention concepts similar to ours for pretrained models. Great work!


## Key Features
**Block-level Sparsity** - Learns dynamic sparsity patterns at block level
**Trainable Sparse Attention** - Outperforms static/predefined attention sparsity
**Block-level Sparsity** - Hardware-efficient sparsity at the block level
**Self-Distillation** - Lightweight training of attention gates (original weights frozen)
**Efficient Kernel** - Custom block-sparse FlashAttention implementation
**Better Accuracy** - Outperforms static/heuristic sparse attention methods
**Efficient Kernel** - Block-sparse FlashAttention implementation
**Easy Integration** - Works with existing transformer architectures



## HF models
## Hugging Face Models
The codebase now saves only the distilled AttnGates' weights. During inference, you can compose the AttnGates with the original base model. Check out the latest Hugging Face repos!

The trained AttnGate with the base model [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) is available on Hugging Face!
| Model | HF Link |
|-------|---------|
| SeerAttention-Llama-3.1-8B | [SeerAttention/SeerAttention-Llama-3.1-8B](https://huggingface.co/SeerAttention/SeerAttention-Llama-3.1-8B) |
| Base Model | HF Link | AttnGates Size |
|------------------------|-----------------------------------------------------------------------------------------------------------------------------------------|----------------|
| Llama-3.1-8B-Instruct | [SeerAttention/SeerAttention-Llama-3.1-8B-AttnGates](https://huggingface.co/SeerAttention/SeerAttention-Llama-3.1-8B-AttnGates) | 101 MB |
| Llama-3.1-70B-Instruct | [SeerAttention/SeerAttention-Llama-3.1-70B-AttnGates](https://huggingface.co/SeerAttention/SeerAttention-Llama-3.1-70B-AttnGates) | 503 MB |
| Qwen2.5-7B-Instruct | [SeerAttention/SeerAttention-Qwen2.5-7B-AttnGates](https://huggingface.co/SeerAttention/SeerAttention-Qwen2.5-7B-AttnGates) | 77 MB |
| Qwen2.5-14B-Instruct | [SeerAttention/SeerAttention-Qwen2.5-14B-AttnGates](https://huggingface.co/SeerAttention/SeerAttention-Qwen2.5-14B-AttnGates) | 189 MB |
| Qwen2.5-32B-Instruct | [SeerAttention/SeerAttention-Qwen2.5-32B-AttnGates](https://huggingface.co/SeerAttention/SeerAttention-Qwen2.5-32B-AttnGates) | 252 MB |
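Because only the AttnGate weights are stored, the adapters are small enough to fetch ahead of time. Below is a minimal sketch using the standard `huggingface_hub` API; the local directory layout is an illustrative assumption, not something the codebase requires.

```python
# Sketch: pre-download an AttnGate adapter and its base model for offline use.
# The "models/" paths below are illustrative only.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="SeerAttention/SeerAttention-Llama-3.1-8B-AttnGates",
    local_dir="models/SeerAttention-Llama-3.1-8B-AttnGates",  # ~101 MB of gate weights
)
snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",  # gated repo: requires HF authentication
    local_dir="models/meta-llama/Llama-3.1-8B-Instruct",
)
```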

## Quick Start

@@ -39,16 +46,49 @@ pip install -e .
```


### 2. Download the pretrained models
```bash
mkdir models
huggingface-cli download meta-llama/Llama-3.1-8B-Instruct --local-dir models/meta-llama/Llama-3.1-8B-Instruct
```

### 2. Inference with AttnGate Adapter
During inference, we automatically compose your original base model with our distilled AttnGates.

SeerAttention supports two sparsification methods (Threshold / TopK) to convert a soft gating score into a hard binary attention mask. Currently, a single sparsity configuration is used for all attention heads; you are encouraged to explore other configurations to trade off speedup against quality.
```python
import torch

from transformers import AutoTokenizer, AutoConfig
from seer_attn import SeerAttnLlamaForCausalLM

model_name = "SeerAttention/SeerAttention-Llama-3.1-8B-AttnGates"
config = AutoConfig.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(
config.base_model,
padding_side="left",
)

# This will compose the AttnGates with the base model
model = SeerAttnLlamaForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
seerattn_sparsity_method='threshold', # Using a threshold based sparse method
seerattn_threshold = 5e-4, # Higher = sparser, typical range 5e-4 ~ 5e-3
)

# Or using a TopK based sparse method
model = SeerAttnLlamaForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
seerattn_sparsity_method='nz_ratio',
seerattn_nz_ratio = 0.5, # Lower = sparser, typical range 0.1 ~ 0.9
)

model = model.cuda()

# Ready for inference
```
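Once composed, the model behaves like a regular Hugging Face causal LM. A minimal generation example follows; the prompt and decoding settings are illustrative.

```python
# Usage sketch: a standard generate() call with the composed SeerAttention model.
prompt = "Summarize the key idea of block-sparse attention in one sentence."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Strip the prompt tokens before decoding the answer.
answer = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(answer)
```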

### 3. Training Attention Gates with Self-distillation
Only the AttnGates are trained to mimic the block-level attention scores; in other words, the original model's weights are frozen.

```bash
# Script to reproduce SeerAttention-Llama-3.1-8B
bash run_distillation.sh
```
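Conceptually, the ground truth for each AttnGate is a block-level mask obtained by pooling the full attention map, and the gate is trained to reproduce it while every other weight stays frozen. The sketch below illustrates that loss; the pooling choice, top-k selection, and tensor shapes are simplifying assumptions rather than the exact implementation.

```python
import torch
import torch.nn.functional as F

def block_pool_attention(attn_probs: torch.Tensor, block_size: int) -> torch.Tensor:
    """Max-pool a [heads, q_len, k_len] attention map down to [heads, q_blocks, k_blocks]."""
    h, q_len, k_len = attn_probs.shape
    blocks = attn_probs.reshape(h, q_len // block_size, block_size, k_len // block_size, block_size)
    return blocks.amax(dim=(2, 4))

def gate_distillation_loss(gate_logits, attn_probs, block_size=64, keep_ratio=0.5):
    # Ground truth: the most important key blocks per query block are marked 1, the rest 0.
    pooled = block_pool_attention(attn_probs, block_size)
    k = max(1, int(keep_ratio * pooled.shape[-1]))
    kth_value = pooled.topk(k, dim=-1).values[..., -1:]
    mask_ground_truth = (pooled >= kth_value).float()
    # Only the AttnGate parameters receive gradients; the base model is frozen.
    return F.binary_cross_entropy_with_logits(gate_logits, mask_ground_truth)
```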

@@ -75,32 +115,6 @@ loss = self.loss_func(predict_mask, mask_ground_truth)
```


### 4. Inference with Sparse Attention
SeerAttention supports two sparse methods (Threshold / TopK) to convert a soft gating score to hard binary attention mask. Currently we simply use a single sparse configuration for all the attention heads.
```python
from seer_attn import SeerAttnLlamaForCausalLM

# Using a threshold based sparse method
model = SeerAttnLlamaForCausalLM.from_pretrained(
'/path/to/your/model/',
torch_dtype=torch.bfloat16,
seerattn_sparsity_method='threshold',
seerattn_threshold = 5e-4, # Higher = sparser, typical range 5e-4 ~ 2e-3
)


# Or using a TopK based sparse method
model = SeerAttnLlamaForCausalLM.from_pretrained(
'/path/to/your/model/',
torch_dtype=torch.bfloat16,
seerattn_sparsity_method='nz_ratio',
seerattn_nz_ratio = 0.5, # Lower = sparser, typical range 0.1 ~ 0.9
)

model = model.cuda()

# Ready to inference
```

## Evaluation
For efficiency, we evaluate `block_sparse_attn` against full attention implemented with FlashAttention-2.
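One simple way to reproduce this kind of comparison is to time an attention callable over long sequences. The sketch below uses PyTorch's built-in `scaled_dot_product_attention` as the dense baseline; the `block_sparse_attn` kernel's Python API is not shown here, so treat the swap-in point as a placeholder.

```python
import time
import torch
import torch.nn.functional as F

def bench(attn_fn, q, k, v, iters=20, warmup=3):
    # Average forward latency of an attention callable on GPU.
    for _ in range(warmup):
        attn_fn(q, k, v)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        attn_fn(q, k, v)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

# Illustrative long-context shape: [batch, heads, seq_len, head_dim].
q = k = v = torch.randn(1, 32, 32768, 128, device="cuda", dtype=torch.bfloat16)
dense_s = bench(lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True), q, k, v)
print(f"dense attention: {dense_s * 1e3:.1f} ms / forward")
# Swap in the repo's block-sparse kernel here to measure the speedup.
```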
111 changes: 85 additions & 26 deletions distillation.py
@@ -9,12 +9,10 @@
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from torch.distributed import barrier
from seer_attn import SeerAttnLlamaForCausalLM

from seer_attn import SeerAttnLlamaForCausalLM, SeerAttnQwen2ForCausalLM
from transformers.trainer_utils import get_last_checkpoint
from huggingface_hub import login

from mytrainer import AttnGateTrainer

from datasets import load_dataset, load_from_disk
import warnings

@@ -30,7 +28,8 @@

@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3.1-8B")
base_model: str = field(default="meta-llama/Meta-Llama-3.1-8B")
model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3.1-8B", metadata={"help": "The local model path."})
seerattn_gate_type: Optional[str] = field(
default="Qavg_Kmaxmin",
metadata={"help": "AttnGate pooling type. Currently support combination of max min avg pooling for both q and k."},
@@ -43,10 +42,13 @@ class ModelArguments:
default=128,
metadata={"help": "AttnGate hidden size."},
)
seerattn_gate_force_double: Optional[bool] = field(
default=False,
metadata={"help": "Force using double linear for AttnGate."},
)

@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
training_max_length: int = field(
default=65536,
@@ -72,8 +74,41 @@ class TrainingArguments(transformers.TrainingArguments):
default=False,
metadata={"help": "If the dataset is already toknized."},
)
save_entire_model: bool = field(
default=False,
metadata={"help": "Save entire model."},
)



class AttnGateTrainer(Trainer):
def __init__(
self,
orig_weight_training=False,
gate_loss_scale=1.0,
fix_mask_predictor=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)

self.gate_loss_scale = gate_loss_scale
self.orig_weight_training = orig_weight_training
self.fix_mask_predictor = fix_mask_predictor

def compute_loss(self, model, inputs, **kwargs):
outputs = model(**inputs)

original_loss = outputs.get("loss")
mask_loss = outputs.get("mask_loss")

del outputs

if self.orig_weight_training:
tok_loss = original_loss + self.gate_loss_scale * mask_loss
else:
tok_loss = self.gate_loss_scale * mask_loss

return tok_loss


def smart_tokenizer_and_embedding_resize(
Expand All @@ -98,6 +133,7 @@ def smart_tokenizer_and_embedding_resize(
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg


def tokenize_fn(tokenizer, tranining_max_length, example):
outputs = tokenizer(
tokenizer.eos_token.join(example["text"]),
@@ -114,29 +150,40 @@ def train():
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()

if model_args.model_name_or_path is None:
model_args.model_name_or_path = model_args.base_model

config = transformers.AutoConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)


original_vocab_size = config.vocab_size
config.base_model = model_args.base_model
config.seerattn_gate_type = model_args.seerattn_gate_type
config.seerattn_gate_block_size = model_args.seerattn_gate_block_size
config.seerattn_gate_hidden_size = model_args.seerattn_gate_hidden_size

model = SeerAttnLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)
config.seerattn_gate_force_double = model_args.seerattn_gate_force_double


if "llama" in model_args.model_name_or_path.lower():
model = SeerAttnLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
load_gate=False,
config=config,
torch_dtype=torch.bfloat16,
)
elif "qwen" in model_args.model_name_or_path.lower():
model = SeerAttnQwen2ForCausalLM.from_pretrained(
model_args.model_name_or_path,
load_gate=False,
config=config,
torch_dtype=torch.bfloat16,
)

print("Using AttnGate type:", model_args.seerattn_gate_type)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
padding_side="right",
use_fast=True,
)
@@ -151,6 +198,7 @@ def train():
if tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN


smart_tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
@@ -160,7 +208,7 @@ def train():
for n, p in model.named_parameters():
if training_args.trainable_params in n:
p.requires_grad = True
torch.nn.init.xavier_uniform_(p)
# torch.nn.init.xavier_uniform_(p)
else:
p.requires_grad = False

@@ -172,7 +220,7 @@ def train():
dataset = load_from_disk(training_args.dataset_name)
dataset = dataset['input_ids']
else:
dataset = load_dataset(training_args.dataset_name, cache_dir=training_args.cache_dir, trust_remote_code=True)
dataset = load_dataset(training_args.dataset_name, trust_remote_code=True)
dataset = dataset.map(partial(tokenize_fn,tokenizer, training_args.training_max_length),batched=True, num_proc=128, remove_columns=["text", "meta", ])
dataset = dataset['train']

@@ -181,7 +229,6 @@

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


trainer = AttnGateTrainer(
model=model,
tokenizer=tokenizer,
Expand All @@ -192,15 +239,27 @@ def train():
data_collator=data_collator
)

if training_args.resume_from_checkpoint is not None:
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
else:
trainer.train()
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if training_args.resume_from_checkpoint is None and last_checkpoint is not None:
print(f"Found checkpoint {last_checkpoint}. Resuming training.")
training_args.resume_from_checkpoint = last_checkpoint

trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
if training_args.save_entire_model:
trainer.save_model(output_dir=training_args.output_dir)
elif rank == 0:
if hasattr(trainer.model, 'module'):
state_dict = trainer.model.module.state_dict()
else:
state_dict = trainer.model.state_dict()

model.config.vocab_size = original_vocab_size
model.config.save_pretrained(training_args.output_dir)
attn_gate_state_dict = {k: v for k, v in state_dict.items() if "attn_gate" in k}
torch.save(attn_gate_state_dict, os.path.join(training_args.output_dir, "attn_gate_weights.pth"))

if __name__ == "__main__":
# login(token=os.getenv('HUGGING_FACE_HUB_TOKEN'))
train()

