diff --git a/docs/models/bagel.md b/docs/models/bagel.md
index e2427eed..d74708d1 100644
--- a/docs/models/bagel.md
+++ b/docs/models/bagel.md
@@ -70,6 +70,9 @@ Create a YAML configuration file based on the template above, adjusting:
   attn_implementation: "eager" # or "flash_attention_2"
   extra_kwargs:
     visual_und: false # Enable/disable visual understanding
+  # Optional: Enable Native Sparse Attention
+  # monkey_patch_kwargs:
+  #   patch_type: ["nsa"]
 
 # Training hyperparameters
 per_device_train_batch_size: 1
@@ -136,6 +139,52 @@ fsdp_config:
 
 ## Advanced Features
 
+### Native Sparse Attention (NSA) Support
+
+We support Native Sparse Attention (NSA) training on BAGEL through monkey patching to improve memory efficiency and training speed for long sequences. NSA replaces the standard attention mechanism with a sparse variant that reduces computational complexity.
+
+#### Prerequisites
+
+Install the native sparse attention library:
+
+```bash
+pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
+```
+
+#### Configuration
+
+Enable NSA by adding the monkey patch configuration to your model config:
+
+```yaml
+model_config:
+  load_from_pretrained_path: "your-model-checkpoint-path"
+  attn_implementation: "eager" # or "sdpa"
+  extra_kwargs:
+    visual_und: false
+  monkey_patch_kwargs:
+    patch_type: ["nsa"]
+    # NSA configuration parameters (all optional, defaults shown)
+    block_size: 64 # Size of attention blocks
+    compress_type: "weightedpool" # Options: weightedpool, linear, avgpool
+    kernel_size: 32 # Compression kernel size
+    kernel_stride: 16 # Compression kernel stride
+    topk: 16 # Number of top-k blocks to keep
+    init_blocks: 1 # Number of initial blocks to always include
+    local_blocks: 2 # Number of local blocks around current position
+    window_size: 512 # Local attention window size
+```
+
+#### NSA Parameters
+
+Refer to the [native-sparse-attention-triton ops README](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme) for the precise meaning of each parameter.
+
+#### Usage Notes
+
+- NSA is most beneficial for longer sequences, where full attention becomes the main memory and compute bottleneck
+- The sparse attention pattern is learned during training and adapts to the data
+- All NSA parameters can be tuned based on your specific use case and hardware constraints
+
+
 ### Sequence Packing
 
 BAGEL supports efficient sequence packing to maximize GPU utilization:
diff --git a/src/lmms_engine/models/bagel/__init__.py b/src/lmms_engine/models/bagel/__init__.py
index a1f1ee81..ad0541ff 100644
--- a/src/lmms_engine/models/bagel/__init__.py
+++ b/src/lmms_engine/models/bagel/__init__.py
@@ -6,6 +6,7 @@ from lmms_engine.mapping_func import register_model
 
 from .bagel import Bagel, BagelConfig
+from .monkey_patch import apply_nsa_to_bagel
 from .qwen2_navit import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
 from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
 
 
@@ -23,4 +24,5 @@
     "Qwen2ForCausalLM",
     "SiglipVisionConfig",
     "SiglipVisionModel",
+    "apply_nsa_to_bagel",
 ]
diff --git a/src/lmms_engine/models/bagel/monkey_patch.py b/src/lmms_engine/models/bagel/monkey_patch.py
new file mode 100644
index 00000000..faf62e50
--- /dev/null
+++ b/src/lmms_engine/models/bagel/monkey_patch.py
@@ -0,0 +1,121 @@
+import torch
+from loguru import logger
+from torch import nn
+
+from lmms_engine.models.monkey_patch import MONKEY_PATCHER
+from lmms_engine.utils import Logging
+
+from .bagel import Bagel
+
+try:
+    from native_sparse_attention.module.native_sparse_attention import (
+        COMPRESS_TYPE_TO_FUNC,
+        COMPRESS_TYPE_TO_WEIGHT,
+    )
+except ImportError:
+    logger.warning(
+        "native_sparse_attention is not installed, please install with"
+        " `pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git`"
+    )
+
+
+def add_g_proj_to_attention_layers(model: Bagel, nsa_config: dict):
+    """
+    Add g_proj linear layers to all attention layers in the Bagel model.
+
+    Args:
+        model (Bagel): The Bagel model to modify
+    """
+    # Access the language model's decoder layers
+    for layer in model.language_model.model.layers:
+        # Each layer has a self_attn module
+        if hasattr(layer, "self_attn"):
+            attn_layer = layer.self_attn
+            g_proj = nn.Linear(model.hidden_size, model.num_heads * 3, bias=False)
+            g_proj = g_proj.to(model.dtype)
+            compress_func = COMPRESS_TYPE_TO_FUNC[nsa_config["compress_type"]]
+            compress_key = COMPRESS_TYPE_TO_WEIGHT[nsa_config["compress_type"]](
+                attn_layer.config.num_key_value_heads,
+                attn_layer.head_dim,
+                nsa_config["kernel_size"],
+            )
+            compress_value = COMPRESS_TYPE_TO_WEIGHT[nsa_config["compress_type"]](
+                attn_layer.config.num_key_value_heads,
+                attn_layer.head_dim,
+                nsa_config["kernel_size"],
+            )
+            intra_block_pe = torch.nn.Parameter(
+                torch.zeros(
+                    attn_layer.config.num_key_value_heads,
+                    nsa_config["kernel_size"],
+                    attn_layer.head_dim,
+                )
+            )
+            attn_layer.compress_func = compress_func
+            parameters = {
+                "g_proj": g_proj,
+                "compress_key": compress_key,
+                "compress_value": compress_value,
+                "intra_block_pe": intra_block_pe,
+            }
+            # set nsa config
+            for key, value in nsa_config.items():
+                setattr(attn_layer, key, value)
+                setattr(attn_layer.config, key, value)
+
+            for key, value in parameters.items():
+                if isinstance(value, torch.nn.Module) or isinstance(
+                    value, torch.nn.Parameter
+                ):
+                    value = value.to(dtype=model.dtype)
+                if isinstance(value, torch.nn.Parameter):
+                    attn_layer.register_parameter(key, value)
+                elif isinstance(value, torch.Tensor):
+                    attn_layer.register_parameter(
+                        key, torch.nn.Parameter(value, requires_grad=True)
+                    )
+                else:
+                    setattr(attn_layer, key, value)
+
+
+@MONKEY_PATCHER.register("bagel", "nsa")
+def apply_nsa_to_bagel(
+    model: Bagel,
+    block_size: int = 64,
+    compress_type: str = "weightedpool",  # weightedpool, linear, avgpool
+    kernel_size: int = 32,
+    kernel_stride: int = 16,
+    topk: int = 16,
+    init_blocks: int = 1,
+    local_blocks: int = 2,
+    window_size: int = 512,
+    **kwargs,
+):
+    """
+    Apply NSA modifications to Bagel model.
+
+    Args:
+        model (Bagel): The Bagel model to modify
+        **kwargs: Additional keyword arguments
+    """
+    nsa_config = {
+        "block_size": block_size,
+        "compress_type": compress_type,
+        "kernel_size": kernel_size,
+        "kernel_stride": kernel_stride,
+        "topk": topk,
+        "init_blocks": init_blocks,
+        "local_blocks": local_blocks,
+        "window_size": window_size,
+    }
+    Logging.info("Patch g_proj to bagel model")
+    add_g_proj_to_attention_layers(model, nsa_config)
+    Logging.info(
+        f"NSA applied to bagel model, Model size: {sum(p.numel() for p in model.parameters()) / 1e9} B"
+    )
+    model.config.nsa_config = nsa_config
+
+    from .nsa_op import forward_train as nsa_forward_train
+    from .qwen2_navit import PackedAttentionMoT
+
+    PackedAttentionMoT.forward_train = nsa_forward_train
diff --git a/src/lmms_engine/models/bagel/nsa_op.py b/src/lmms_engine/models/bagel/nsa_op.py
new file mode 100644
index 00000000..71684b3d
--- /dev/null
+++ b/src/lmms_engine/models/bagel/nsa_op.py
@@ -0,0 +1,195 @@
+from typing import List, Tuple
+
+import torch
+from loguru import logger
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
+
+try:
+    from native_sparse_attention.ops import (
+        compressed_attention,
+        linear_compress,
+        topk_sparse_attention,
+    )
+except ImportError:
+    logger.warning(
+        "native_sparse_attention is not installed, please install with"
+        " `pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git`"
+    )
+
+from transformers.utils import is_flash_attn_2_available
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+
+
+def forward_train(
+    self,
+    packed_sequence: torch.Tensor,
+    sample_lens: List[int],
+    attention_mask,
+    packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    packed_und_token_indexes: torch.LongTensor,
+    packed_gen_token_indexes: torch.LongTensor,
+):
+    packed_query_states = packed_sequence.new_zeros(
+        (packed_sequence.shape[0], self.num_heads * self.head_dim)
+    )
+    packed_key_states = packed_sequence.new_zeros(
+        (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
+    )
+    packed_value_states = packed_sequence.new_zeros(
+        (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
+    )
+
+    packed_sequence_und = packed_sequence[packed_und_token_indexes]
+    packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
+
+    packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
+    packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(
+        packed_sequence_gen
+    )
+
+    packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
+    packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(
+        packed_sequence_gen
+    )
+
+    packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
+    packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(
+        packed_sequence_gen
+    )
+
+    g = self.g_proj(packed_sequence)
+    g = g.view(1, packed_sequence.shape[0], self.num_heads, 3)
+    g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
+
+    packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
+    packed_key_states = packed_key_states.view(
+        -1, self.num_key_value_heads, self.head_dim
+    )
+    packed_value_states = packed_value_states.view(
+        -1, self.num_key_value_heads, self.head_dim
+    )
+    if self.config.freeze_und:
+        packed_value_states[packed_und_token_indexes] = packed_value_states[
+            packed_und_token_indexes
+        ].detach()
+
+    packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
+    packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
+
+    packed_query_states_[packed_und_token_indexes] = self.q_norm(
+        packed_query_states[packed_und_token_indexes]
+    )
+    if self.config.freeze_und:
+        packed_query_states_[packed_und_token_indexes] = packed_query_states_[
+            packed_und_token_indexes
+        ].detach()
+    packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(
+        packed_query_states[packed_gen_token_indexes]
+    )
+
+    packed_key_states_[packed_und_token_indexes] = self.k_norm(
+        packed_key_states[packed_und_token_indexes]
+    )
+    if self.config.freeze_und:
+        packed_key_states_[packed_und_token_indexes] = packed_key_states_[
+            packed_und_token_indexes
+        ].detach()
+    packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(
+        packed_key_states[packed_gen_token_indexes]
+    )
+    cu_seqlens = torch.cumsum(
+        torch.tensor([0] + sample_lens, device=packed_query_states_.device), dim=0
+    ).to(torch.int32)  # cumulative sequence offsets expected by the varlen kernels
+
+    # 1. key value compression
+    compressed_key, compressed_cu_seqlens = self.compress_func(
+        packed_key_states_,
+        self.compress_key,
+        cu_seqlens,
+        self.kernel_size,
+        self.kernel_stride,
+        self.intra_block_pe,
+    )
+    compressed_value, _ = self.compress_func(
+        packed_value_states,
+        self.compress_value,
+        cu_seqlens,
+        self.kernel_size,
+        self.kernel_stride,
+        None,
+    )
+
+    packed_cos, packed_sin = packed_position_embeddings
+    packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
+        packed_query_states_,
+        packed_key_states_,
+        packed_cos,
+        packed_sin,
+        unsqueeze_dim=1,
+    )
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
+    compressed_attn_output, topk_idx = compressed_attention(
+        packed_query_states_,
+        compressed_key,
+        compressed_value,
+        self.kernel_size,
+        self.kernel_stride,
+        self.block_size,
+        self.topk,
+        cu_seqlens,
+        compressed_cu_seqlens,
+        seqlens.max().item(),
+        compressed_seqlens.max().item(),
+        None,
+        self.init_blocks,
+        self.local_blocks,
+    )
+
+    # topk sparse attention
+    sparse_attn_output = topk_sparse_attention(
+        packed_query_states_,
+        packed_key_states_,
+        packed_value_states,
+        topk_idx,
+        self.block_size,
+        cu_seqlens,
+        None,
+    )
+
+    # sliding window attention
+    sliding_attn_output = flash_attn_varlen_func(
+        packed_query_states_,
+        packed_key_states_,
+        packed_value_states,
+        cu_seqlens,
+        cu_seqlens,
+        seqlens.max().item(),
+        seqlens.max().item(),
+        causal=True,
+        window_size=(self.window_size, -1),
+    )
+
+    attn_output = (
+        compressed_attn_output * g_cmp.unsqueeze(-1)
+        + sparse_attn_output * g_slc.unsqueeze(-1)
+        + sliding_attn_output * g_swa.unsqueeze(-1)
+    )
+
+    packed_attn_output = attn_output.squeeze(0)
+
+    packed_attn_output = packed_attn_output.reshape(
+        -1, self.num_heads * self.head_dim
+    )
+    packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
+    packed_attn_output_[packed_und_token_indexes] = self.o_proj(
+        packed_attn_output[packed_und_token_indexes]
+    )
+    packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(
+        packed_attn_output[packed_gen_token_indexes]
+    )
+
+    return packed_attn_output_
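For reference, the patch registered above can also be applied programmatically instead of through `monkey_patch_kwargs`. A minimal sketch, assuming a `Bagel` instance is already available; the `from_pretrained` call and checkpoint path below are illustrative placeholders, not the trainer's actual loading code:

```python
from lmms_engine.models.bagel import Bagel, apply_nsa_to_bagel

# Placeholder loading step: in practice the trainer builds the model from
# `model_config.load_from_pretrained_path` in the YAML config.
model = Bagel.from_pretrained("your-model-checkpoint-path")

# Same parameters and defaults as the YAML `monkey_patch_kwargs` entries.
apply_nsa_to_bagel(
    model,
    block_size=64,
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    topk=16,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)

# After this call, every decoder layer's self_attn carries g_proj, compress_key,
# compress_value and intra_block_pe, and PackedAttentionMoT.forward_train has been
# swapped for the NSA variant defined in nsa_op.py.
```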
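A note on the `cu_seqlens` convention used by `nsa_op.forward_train`: the varlen kernels (`flash_attn_varlen_func`, `compressed_attention`, `topk_sparse_attention`) expect cumulative token offsets rather than raw per-sample lengths. A small illustration with made-up lengths:

```python
import torch

# Three packed samples of lengths 3, 5 and 2 give offsets [0, 3, 8, 10].
sample_lens = [3, 5, 2]
cu_seqlens = torch.cumsum(torch.tensor([0] + sample_lens), dim=0).to(torch.int32)

print(cu_seqlens)                        # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(cu_seqlens[1:] - cu_seqlens[:-1])  # recovers the per-sample lengths 3, 5, 2
```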
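The heart of the patched `forward_train` is a gated mix of three attention branches: compressed attention over pooled key/value blocks, top-k block-sparse attention, and local sliding-window attention. The added `g_proj` produces three sigmoid gates per token and head, one per branch. A self-contained toy of just the gating arithmetic, with random tensors standing in for the real branch outputs (shapes and values here are made up):

```python
import torch

total_tokens, num_heads, head_dim = 8, 4, 16
hidden_size = num_heads * head_dim  # toy choice; the real hidden size comes from the LLM config

# g_proj maps hidden states to three gates per attention head.
g_proj = torch.nn.Linear(hidden_size, num_heads * 3, bias=False)
hidden = torch.randn(total_tokens, hidden_size)

g = g_proj(hidden).view(total_tokens, num_heads, 3)
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)  # each (total_tokens, num_heads)

# Stand-ins for the three branch outputs, all (total_tokens, num_heads, head_dim).
compressed_out = torch.randn(total_tokens, num_heads, head_dim)
sparse_out = torch.randn(total_tokens, num_heads, head_dim)
sliding_out = torch.randn(total_tokens, num_heads, head_dim)

attn_output = (
    compressed_out * g_cmp.unsqueeze(-1)
    + sparse_out * g_slc.unsqueeze(-1)
    + sliding_out * g_swa.unsqueeze(-1)
)
print(attn_output.shape)  # torch.Size([8, 4, 16])
```

In the real patch, the three branch outputs come from `compressed_attention`, `topk_sparse_attention`, and `flash_attn_varlen_func` over the packed sequence, and the gated sum is then flattened and routed through `o_proj` / `o_proj_moe_gen`.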