49 changes: 49 additions & 0 deletions docs/models/bagel.md
@@ -70,6 +70,9 @@ Create a YAML configuration file based on the template above, adjusting:
attn_implementation: "eager" # or "flash_attention_2"
extra_kwargs:
visual_und: false # Enable/disable visual understanding
# Optional: Enable Native Sparse Attention
# monkey_patch_kwargs:
# patch_type: ["nsa"]

# Training hyperparameters
per_device_train_batch_size: 1
@@ -136,6 +139,52 @@ fsdp_config:

## Advanced Features

### Native Sparse Attention (NSA) Support

We support Native Sparse Attention (NSA) training on BAGEL through monkey patching, which improves memory efficiency and training speed for long sequences. NSA replaces the standard attention mechanism with a sparse variant that reduces computational complexity.
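
At a high level, the patched attention computes three branches (compressed, selected top-k sparse, and sliding-window attention) and mixes them with per-head sigmoid gates predicted from the hidden states. Below is a minimal PyTorch sketch of the gating step, with toy shapes and random tensors standing in for the branch outputs; it is an illustration, not the library API:

```python
import torch

total_tokens, num_heads, head_dim = 8, 4, 16
hidden = torch.randn(total_tokens, num_heads * head_dim)

# One gate per head and per branch, squashed to (0, 1) with a sigmoid.
g_proj = torch.nn.Linear(num_heads * head_dim, num_heads * 3, bias=False)
g_cmp, g_slc, g_swa = (
    g_proj(hidden).view(total_tokens, num_heads, 3).sigmoid().unbind(-1)
)

# In the real patch these come from the compressed, top-k sparse, and
# sliding-window attention kernels; random tensors stand in here.
out_cmp = torch.randn(total_tokens, num_heads, head_dim)
out_slc = torch.randn(total_tokens, num_heads, head_dim)
out_swa = torch.randn(total_tokens, num_heads, head_dim)

attn_output = (
    out_cmp * g_cmp.unsqueeze(-1)
    + out_slc * g_slc.unsqueeze(-1)
    + out_swa * g_swa.unsqueeze(-1)
)
print(attn_output.shape)  # torch.Size([8, 4, 16])
```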

#### Prerequisites

Install the native sparse attention library:

```bash
pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
```
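
A quick way to confirm the installation, assuming the package exposes the `native_sparse_attention` namespace that the patch in this PR imports:

```python
# Should import without error once the library is installed.
from native_sparse_attention.ops import (
    compressed_attention,
    linear_compress,
    topk_sparse_attention,
)

print("native_sparse_attention ops available")
```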

#### Configuration

Enable NSA by adding the monkey patch configuration to your model config:

```yaml
model_config:
load_from_pretrained_path: "your-model-checkpoint-path"
  attn_implementation: "eager"  # or "sdpa"
extra_kwargs:
visual_und: false
monkey_patch_kwargs:
patch_type: ["nsa"]
# NSA configuration parameters (all optional with defaults shown)
block_size: 64 # Size of attention blocks
compress_type: "weightedpool" # Options: weightedpool, linear, avgpool
kernel_size: 32 # Compression kernel size
kernel_stride: 16 # Compression kernel stride
topk: 16 # Number of top-k blocks to keep
init_blocks: 1 # Number of initial blocks to always include
local_blocks: 2 # Number of local blocks around current position
window_size: 512 # Local attention window size
```

#### NSA Parameters

For detailed descriptions of each parameter, see the [native-sparse-attention-triton ops documentation](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).

#### Usage Notes

- NSA is most beneficial for longer sequences
- The sparse attention pattern is learned during training and adapts to the data
- All NSA parameters can be tuned based on your specific use case and hardware constraints
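
The patch can also be applied programmatically, for example when experimenting outside the YAML config path. A minimal sketch, assuming you already have a `Bagel` model instance in memory (model loading is elided):

```python
from lmms_engine.models.bagel import apply_nsa_to_bagel

# `model` is an already-constructed Bagel instance (loading elided).
apply_nsa_to_bagel(
    model,
    block_size=64,
    compress_type="weightedpool",  # or "linear" / "avgpool"
    kernel_size=32,
    kernel_stride=16,
    topk=16,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
```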


### Sequence Packing

BAGEL supports efficient sequence packing to maximize GPU utilization:
2 changes: 2 additions & 0 deletions src/lmms_engine/models/bagel/__init__.py
@@ -6,6 +6,7 @@
from lmms_engine.mapping_func import register_model

from .bagel import Bagel, BagelConfig
from .monkey_patch import apply_nsa_to_bagel
from .qwen2_navit import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from .siglip_navit import SiglipVisionConfig, SiglipVisionModel

@@ -23,4 +24,5 @@
"Qwen2ForCausalLM",
"SiglipVisionConfig",
"SiglipVisionModel",
"apply_nsa_to_bagel",
]
121 changes: 121 additions & 0 deletions src/lmms_engine/models/bagel/monkey_patch.py
@@ -0,0 +1,121 @@
import torch
from loguru import logger
from torch import nn

from lmms_engine.models.monkey_patch import MONKEY_PATCHER
from lmms_engine.utils import Logging

from .bagel import Bagel

try:
from native_sparse_attention.module.native_sparse_attention import (
COMPRESS_TYPE_TO_FUNC,
COMPRESS_TYPE_TO_WEIGHT,
)
except ImportError:
logger.warning(
"native_sparse_attention is not installed, please install with"
" `pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git`"
)


def add_g_proj_to_attention_layers(model: Bagel, nsa_config: dict):
"""
Add g_proj linear layers to all attention layers in the Bagel model.

Args:
model (Bagel): The Bagel model to modify
"""
# Access the language model's decoder layers
for layer in model.language_model.model.layers:
# Each layer has a self_attn module
if hasattr(layer, "self_attn"):
attn_layer = layer.self_attn
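            # g_proj predicts 3 sigmoid gates per attention head, one for each
            # NSA branch (compressed, selected/top-k, sliding window).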
g_proj = nn.Linear(model.hidden_size, model.num_heads * 3, bias=False)
g_proj = g_proj.to(model.dtype)
compress_func = COMPRESS_TYPE_TO_FUNC[nsa_config["compress_type"]]
compress_key = COMPRESS_TYPE_TO_WEIGHT[nsa_config["compress_type"]](
attn_layer.config.num_key_value_heads,
attn_layer.head_dim,
nsa_config["kernel_size"],
)
compress_value = COMPRESS_TYPE_TO_WEIGHT[nsa_config["compress_type"]](
attn_layer.config.num_key_value_heads,
attn_layer.head_dim,
nsa_config["kernel_size"],
)
intra_block_pe = torch.nn.Parameter(
torch.zeros(
attn_layer.config.num_key_value_heads,
nsa_config["kernel_size"],
attn_layer.head_dim,
)
)
attn_layer.compress_func = compress_func
parameters = {
"g_proj": g_proj,
"compress_key": compress_key,
"compress_value": compress_value,
"intra_block_pe": intra_block_pe,
}
# set nsa config
for key, value in nsa_config.items():
setattr(attn_layer, key, value)
setattr(attn_layer.config, key, value)

for key, value in parameters.items():
if isinstance(value, torch.nn.Module) or isinstance(
value, torch.nn.Parameter
):
value = value.to(dtype=model.dtype)
if isinstance(value, torch.nn.Parameter):
attn_layer.register_parameter(key, value)
elif isinstance(value, torch.Tensor):
attn_layer.register_parameter(
key, torch.nn.Parameter(value, requires_grad=True)
)
else:
setattr(attn_layer, key, value)


@MONKEY_PATCHER.register("bagel", "nsa")
def apply_nsa_to_bagel(
model: Bagel,
block_size: int = 64,
compress_type: str = "weightedpool", # weightedpool, linear, avgpool
kernel_size: int = 32,
kernel_stride: int = 16,
topk: int = 16,
init_blocks: int = 1,
local_blocks: int = 2,
window_size: int = 512,
**kwargs,
):
"""
Apply NSA modifications to Bagel model.

Args:
model (Bagel): The Bagel model to modify
**kwargs: Additional keyword arguments
"""
nsa_config = {
"block_size": block_size,
"compress_type": compress_type,
"kernel_size": kernel_size,
"kernel_stride": kernel_stride,
"topk": topk,
"init_blocks": init_blocks,
"local_blocks": local_blocks,
"window_size": window_size,
}
    Logging.info("Patching g_proj and NSA parameters into the Bagel attention layers")
    add_g_proj_to_attention_layers(model, nsa_config)
    Logging.info(
        f"NSA applied to Bagel model, model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B"
    )
model.config.nsa_config = nsa_config

from .nsa_op import forward_train as nsa_forward_train
from .qwen2_navit import PackedAttentionMoT

PackedAttentionMoT.forward_train = nsa_forward_train
195 changes: 195 additions & 0 deletions src/lmms_engine/models/bagel/nsa_op.py
@@ -0,0 +1,195 @@
from typing import List, Tuple

import torch
from loguru import logger
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb

try:
from native_sparse_attention.ops import (
compressed_attention,
linear_compress,
topk_sparse_attention,
)
except ImportError:
logger.warning(
"native_sparse_attention is not installed, please install with"
" `pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git`"
)

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func


def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
):
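    # packed_sequence: [total_tokens, hidden_size] concatenation of all samples
    # in the batch; sample_lens holds the per-sample token counts used for
    # variable-length (varlen) attention below.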
packed_query_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_heads * self.head_dim)
)
packed_key_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)
packed_value_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)

packed_sequence_und = packed_sequence[packed_und_token_indexes]
packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(
packed_sequence_gen
)

packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(
packed_sequence_gen
)

packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(
packed_sequence_gen
)

g = self.g_proj(packed_sequence)
g = g.view(1, packed_sequence.shape[0], self.num_heads, 3)
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)

packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
packed_key_states = packed_key_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = packed_value_states.view(
-1, self.num_key_value_heads, self.head_dim
)
if self.config.freeze_und:
packed_value_states[packed_und_token_indexes] = packed_value_states[
packed_und_token_indexes
].detach()

packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

packed_query_states_[packed_und_token_indexes] = self.q_norm(
packed_query_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_query_states_[packed_und_token_indexes] = packed_query_states_[
packed_und_token_indexes
].detach()
packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(
packed_query_states[packed_gen_token_indexes]
)

packed_key_states_[packed_und_token_indexes] = self.k_norm(
packed_key_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_key_states_[packed_und_token_indexes] = packed_key_states_[
packed_und_token_indexes
].detach()
packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(
packed_key_states[packed_gen_token_indexes]
)
    # cumulative sequence lengths of the packed samples: [0, l0, l0 + l1, ...]
    cu_seqlens = torch.tensor(
        [0] + sample_lens, dtype=torch.int32, device=packed_query_states_.device
    ).cumsum(dim=0, dtype=torch.int32)

# 1. key value compression
compressed_key, compressed_cu_seqlens = self.compress_func(
packed_key_states_,
self.compress_key,
cu_seqlens,
self.kernel_size,
self.kernel_stride,
self.intra_block_pe,
)
compressed_value, _ = self.compress_func(
packed_value_states,
self.compress_value,
cu_seqlens,
self.kernel_size,
self.kernel_stride,
None,
)

packed_cos, packed_sin = packed_position_embeddings
packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
packed_query_states_,
packed_key_states_,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
compressed_attn_output, topk_idx = compressed_attention(
packed_query_states_,
compressed_key,
compressed_value,
self.kernel_size,
self.kernel_stride,
self.block_size,
self.topk,
cu_seqlens,
compressed_cu_seqlens,
seqlens.max().item(),
compressed_seqlens.max().item(),
None,
self.init_blocks,
self.local_blocks,
)

# topk sparse attention
sparse_attn_output = topk_sparse_attention(
packed_query_states_,
packed_key_states_,
packed_value_states,
topk_idx,
self.block_size,
cu_seqlens,
None,
)

# sliding window attention
sliding_attn_output = flash_attn_varlen_func(
packed_query_states_,
packed_key_states_,
packed_value_states,
cu_seqlens,
cu_seqlens,
seqlens.max().item(),
seqlens.max().item(),
causal=True,
window_size=(self.window_size, -1),
)

    # mix the three branches with their matching learned gates
    attn_output = (
        compressed_attn_output * g_cmp.unsqueeze(-1)
        + sparse_attn_output * g_slc.unsqueeze(-1)
        + sliding_attn_output * g_swa.unsqueeze(-1)
    )

    # [1, total_tokens, num_heads, head_dim] -> [total_tokens, num_heads * head_dim]
    packed_attn_output = attn_output.squeeze(0)
    packed_attn_output = packed_attn_output.reshape(-1, self.num_heads * self.head_dim)
packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
packed_attn_output_[packed_und_token_indexes] = self.o_proj(
packed_attn_output[packed_und_token_indexes]
)
packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(
packed_attn_output[packed_gen_token_indexes]
)

return packed_attn_output_