
Conversation

@kcz358
Collaborator

@kcz358 kcz358 commented Sep 22, 2025

Motivation

Modifications

Refer to the docs for information about the changes.

Checklist

  • Format your code
  • Add unit tests
  • Update documentation as needed, including docstrings or example tutorials.

Summary by CodeRabbit

  • New Features
    • Introduces an optional NSA-based selective attention mode for Bagel models, enabling block-wise compression and sliding-window attention.
    • Adds support for variable-length and packed sequences in the new attention path.
  • Performance Improvements
    • Reduces memory usage and improves speed on long sequences via compression and selective attention.
    • More efficient training path through the new attention forward implementation.
  • Refactor
    • Expands the public API to expose the NSA patch entry point for Bagel models.

@coderabbitai

coderabbitai bot commented Sep 22, 2025

Walkthrough

Adds an NSA-based attention path for Bagel models: new naive NSA ops, a Bagel-specific NSA forward_train, and a monkey patch that augments attention layers with a g_proj and overrides PackedAttentionMoT.forward_train. The bagel package exports apply_nsa_to_bagel to register and apply the patch.
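
A minimal usage sketch follows, based on the entry point and parameter names shown in the sequence diagrams below; build_bagel_model and the numeric values are illustrative assumptions, not defaults taken from this PR:

from lmms_engine.models.bagel import apply_nsa_to_bagel

model = build_bagel_model()  # hypothetical helper standing in for any constructed Bagel model
model = apply_nsa_to_bagel(
    model,
    block_size=64,     # compression block size (assumed value)
    block_counts=16,   # number of selected blocks per query (assumed value)
    window_size=512,   # sliding-window span (assumed value)
)
# After patching, PackedAttentionMoT.forward_train dispatches to the NSA path.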

Changes

Cohort / File(s) Summary of Changes
Bagel NSA integration
src/lmms_engine/models/bagel/__init__.py, src/lmms_engine/models/bagel/monkey_patch.py, src/lmms_engine/models/bagel/nsa_op.py
Exposes apply_nsa_to_bagel in package init. Adds monkey patch to attach g_proj to each attention layer, propagate NSA config, and override PackedAttentionMoT.forward_train with NSA implementation. Introduces NSA forward_train that builds packed QKV, applies rotary embeddings, calls NSA compute, and projects outputs.
NSA core implementation
src/lmms_engine/models/nsa/naive.py
Implements compression, naive NSA, variable-length compression, and an orchestration function naive_nsa_with_compression supporting block pruning and optional sliding window attention.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User
  participant M as Model (Bagel)
  participant P as apply_nsa_to_bagel
  participant L as Layers (attn)
  participant C as PackedAttentionMoT
  participant N as nsa_forward_train

  U->>P: apply_nsa_to_bagel(model, block_size, block_counts, window_size)
  P->>L: add g_proj, set block/window config
  P->>C: override forward_train = N
  P-->>U: return patched model

  Note over C,N: Training-time forward now uses NSA path
sequenceDiagram
  autonumber
  participant T as Trainer
  participant C as PackedAttentionMoT (patched)
  participant E as nsa_forward_train
  participant O as NSA Ops (naive.py)

  T->>C: forward_train(packed seq, lens, masks, pos, idx_und, idx_gen)
  C->>E: delegate
  E->>E: project QKV (UND/MoE), normalize, rotary
  E->>O: naive_nsa_with_compression(q,k,v, g_cmp, g_slc, g_swa, block_counts, block_size, window_size, cu_seqlens)
  O-->>E: attention output
  E->>E: output projections (UND/MoE), concat
  E-->>C: packed_attn_output
  C-->>T: return
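The second diagram spells out the argument order of the orchestration call. As a hedged, illustrative invocation (tensor shapes, gate layout, and the import path are assumptions inferred from the file listing, not verified against naive.py):

import torch

from lmms_engine.models.nsa.naive import naive_nsa_with_compression

T, HQ, HKV, D = 1024, 16, 4, 128           # packed tokens, query/KV heads, head dim (assumed)
q = torch.randn(T, HQ, D)
k = torch.randn(T, HKV, D)
v = torch.randn(T, HKV, D)
g_cmp = torch.rand(T, HQ)                   # gate for the compressed branch (assumed layout)
g_slc = torch.rand(T, HQ)                   # gate for the selected-block branch (assumed layout)
g_swa = torch.rand(T, HQ)                   # gate for the sliding-window branch (assumed layout)
cu_seqlens = torch.tensor([0, 512, 1024])   # packed-sequence boundaries (assumed)

packed_attn_output, block_indices = naive_nsa_with_compression(
    q, k, v, g_cmp, g_slc, g_swa,
    block_counts=16, block_size=64, window_size=512,
    cu_seqlens=cu_seqlens,
)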

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble code where layers meet,
A g_proj sprouts—how neat, discreet!
I patch the path where queries stream,
Blocks compress, then windows gleam.
With NSA’s hop my ears take flight—
Packed tokens dance through soft-moon light.
Thump-thump, the model learns just right. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

Description Check: ⚠️ Warning
Explanation: The PR description provides a Modifications section and an example monkey_patch_kwargs but leaves the Motivation section empty and all Checklist items unchecked. It omits required details from the repository template, such as the high-level goals/rationale, a clear enumeration of affected APIs/files and how to enable/use the feature, and reproducible testing/benchmark instructions beyond an informal note about slowness and OOM. Because these required sections are incomplete or missing, the PR description does not meet the repository's template requirements.
Resolution: Populate the Motivation section with goals and rationale; list the key files/APIs changed and how to enable the NSA monkey-patch; add concrete testing and benchmarking steps (hardware, commands, expected resource usage) and either include unit tests or justify their absence; complete the Checklist; and update documentation or docstrings to cover backend-specific limitations and OOM risks.

Docstring Coverage: ⚠️ Warning
Explanation: Docstring coverage is 50.00%, which is below the required threshold of 80.00%.
Resolution: Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)

Title Check: ✅ Passed
Explanation: The title "feat: Bagel naive implementation of sparse attention" succinctly and accurately captures the primary change in this PR: adding a naive NSA (sparse attention) implementation and the related Bagel monkey-patch and operator modules shown in the changed files. It is concise, specific to the main feature, and avoids noisy details, so a teammate scanning the PR history will understand the primary intent. The title therefore meets the repository guideline for a short, focused PR title.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch bagel/nsa

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (10)
src/lmms_engine/models/nsa/naive.py (6)

165-166: Remove unused local g_swa_i.

Assigned but not used; keep gating at the aggregated o_swa instead (see prior comment).

Apply this diff:

-            # [HQ]
-            g_swa_i = g_swa_b[i_q]
+            # g_swa gating is applied to o_swa after FA2 call.

19-29: Fix return annotation for compression.

It returns a tuple (k_cmp, v_cmp), not a single tensor.

Apply this diff:

-@torch.compile
-def compression(k: torch.Tensor, v: torch.Tensor, block_size: int) -> torch.Tensor:
+@torch.compile
+def compression(
+    k: torch.Tensor, v: torch.Tensor, block_size: int
+) -> tuple[torch.Tensor, torch.Tensor]:

225-235: Fix return annotation for naive_nsa_compression.

Function returns (block_indices, o_cmp).

Apply this diff:

-) -> torch.LongTensor:
+) -> tuple[torch.LongTensor, torch.Tensor]:

280-290: Fix return annotation for naive_nsa_compression_varlen.

Function returns (block_indices, o_cmp).

Apply this diff:

-) -> torch.LongTensor:
+) -> tuple[torch.LongTensor, torch.Tensor]:

252-259: Nit: typos and minor mask clarity.

Rename casual_mask -> causal_mask; small readability win. Also ensure both masks use the same device (q_b.device).

Apply this diff:

-    casual_mask = (
+    causal_mask = (
         (torch.arange(T) - BS + 1)[:, None] // BS < torch.arange(C)[None, :]
     ).to(q.device)
-    empty_mask = casual_mask.all(-1, True)
+    empty_mask = causal_mask.all(-1, True)
@@
     attn_cmp = attn_cmp.masked_fill(
-        casual_mask & empty_mask.logical_not(), float("-inf")
+        causal_mask & empty_mask.logical_not(), float("-inf")
     )
@@
-        casual_mask = (
+        causal_mask = (
             (torch.arange(T_b) - BS + 1)[:, None] // BS < torch.arange(C_b)[None, :]
         ).to(q_b.device)
         local_mask = (
             torch.arange(T_b)[:, None] // BS == torch.arange(C_b)[None, :]
-        ).to(q.device)
+        ).to(q_b.device)
@@
-        attn_cmp = attn_cmp.masked_fill(casual_mask, float("-inf"))
+        attn_cmp = attn_cmp.masked_fill(causal_mask, float("-inf"))

Also applies to: 322-328


100-116: Index tensor c construction is per-call constant; precompute/cast once.

Minor perf: move to outside the batch loop or compute with correct dtype/device once per function.
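
For illustration only (the helper, names, and shapes below are assumed, not taken from naive.py), hoisting the constant index tensor out of the per-sample loop looks like:

import torch

def select_leading_blocks(scores: torch.Tensor, num_blocks: int) -> torch.Tensor:
    # Assumes scores.shape[-1] >= num_blocks.
    # Build the index tensor once per call, with its final dtype/device.
    c = torch.arange(num_blocks, device=scores.device, dtype=torch.long)
    # The per-sample loop reuses `c` instead of rebuilding or re-casting it.
    return torch.stack(
        [scores[b].index_select(-1, c) for b in range(scores.shape[0])]
    )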

src/lmms_engine/models/bagel/nsa_op.py (2)

99-111: Unused block_indices; consider discarding to reduce overhead/noise.

If not needed, assign to _ to document intent.

Apply this diff:

-    packed_attn_output, block_indices = naive_nsa_with_compression(
+    packed_attn_output, _ = naive_nsa_with_compression(

13-13: attention_mask is unused.

Keep the signature, but assign to _ to silence linters.

Apply this diff:

     attention_mask,
@@
+    _ = attention_mask  # intentionally unused
src/lmms_engine/models/bagel/monkey_patch.py (2)

24-36: Avoid setattr for static attributes; assign directly. Also propagate dtype carefully.

Direct assignment is clearer; keep g_proj dtype consistent with attention module params if they differ from model.dtype.

Apply this diff:

             attn_layer.g_proj = g_proj
             attn_layer.block_size = block_size
             attn_layer.window_size = window_size
             attn_layer.block_counts = block_counts
-            setattr(attn_layer.config, "block_size", block_size)
-            setattr(attn_layer.config, "window_size", window_size)
-            setattr(attn_layer.config, "block_counts", block_counts)
+            attn_layer.config.block_size = block_size
+            attn_layer.config.window_size = window_size
+            attn_layer.config.block_counts = block_counts

38-59: Global monkey patching is sweeping; gate or make reversible.

Consider a context manager or feature flag to avoid affecting unrelated Bagel instances during the same process.
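
A minimal sketch of a reversible variant, assuming the patch only needs to swap PackedAttentionMoT.forward_train (the g_proj attachment and config propagation are omitted, and the import paths follow this PR's file layout):

from contextlib import contextmanager

from lmms_engine.models.bagel.nsa_op import forward_train as nsa_forward_train
from lmms_engine.models.bagel.qwen2_navit import PackedAttentionMoT

@contextmanager
def nsa_patched():
    # Still process-global while active, but the original method is restored on exit.
    original = PackedAttentionMoT.forward_train
    PackedAttentionMoT.forward_train = nsa_forward_train
    try:
        yield
    finally:
        PackedAttentionMoT.forward_train = original

# Usage inside a training script: with nsa_patched(): run_training()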

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a7eabe and fa8ef38.

📒 Files selected for processing (4)
  • src/lmms_engine/models/bagel/__init__.py (2 hunks)
  • src/lmms_engine/models/bagel/monkey_patch.py (1 hunks)
  • src/lmms_engine/models/bagel/nsa_op.py (1 hunks)
  • src/lmms_engine/models/nsa/naive.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

**/*.py: Use early returns to avoid nested conditionals
Use clear, descriptive names for variables and functions
Prefix handler functions with "handle" (e.g., handle_event)
Prefer constants over functions when possible
Avoid repetition (DRY) by extracting shared logic
Prefer functional, immutable approaches when not verbose
Define composing (higher-level) functions before their components
Mark issues in existing code with comments starting with "TODO:"
Favor functional and stateless approaches where they improve clarity
Keep core logic clean; push implementation details to the edges

Files:

  • src/lmms_engine/models/bagel/__init__.py
  • src/lmms_engine/models/bagel/nsa_op.py
  • src/lmms_engine/models/nsa/naive.py
  • src/lmms_engine/models/bagel/monkey_patch.py
🧬 Code graph analysis (3)
src/lmms_engine/models/bagel/__init__.py (1)
src/lmms_engine/models/bagel/monkey_patch.py (1)
  • apply_nsa_to_bagel (39-58)
src/lmms_engine/models/bagel/nsa_op.py (1)
src/lmms_engine/models/nsa/naive.py (1)
  • naive_nsa_with_compression (350-457)
src/lmms_engine/models/bagel/monkey_patch.py (4)
src/lmms_engine/utils/logging_utils.py (2)
  • Logging (36-86)
  • info (47-53)
src/lmms_engine/models/bagel/bagel.py (1)
  • Bagel (124-1431)
src/lmms_engine/models/bagel/nsa_op.py (1)
  • forward_train (9-126)
src/lmms_engine/models/bagel/qwen2_navit.py (1)
  • PackedAttentionMoT (443-801)
🪛 Ruff (0.13.1)
src/lmms_engine/models/bagel/nsa_op.py

13-13: Unused function argument: attention_mask

(ARG001)


96-96: Consider [0, *sample_lens] instead of concatenation

Replace with [0, *sample_lens]

(RUF005)


99-99: Unpacked variable block_indices is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

src/lmms_engine/models/nsa/naive.py

89-91: Avoid specifying long messages outside the exception class

(TRY003)


165-165: Local variable g_swa_i is assigned to but never used

Remove assignment to unused variable g_swa_i

(F841)


405-407: Avoid specifying long messages outside the exception class

(TRY003)

src/lmms_engine/models/bagel/monkey_patch.py

31-31: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


32-32: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


33-33: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


40-40: Unused function argument: kwargs

(ARG001)

🔇 Additional comments (1)
src/lmms_engine/models/bagel/__init__.py (1)

9-9: Public API exposure looks fine.

Import and export of apply_nsa_to_bagel are consistent with usage.

Also applies to: 27-27

@kcz358 kcz358 requested a review from Luodian September 29, 2025 05:48
@Luodian
Collaborator

Luodian commented Sep 29, 2025

I'll take a closer look at it tonight.

@kcz358
Collaborator Author

kcz358 commented Sep 29, 2025

Switched to a better Triton kernel, so the one we hand-rolled earlier is no longer needed.

@kcz358 kcz358 merged commit 30fb799 into main Oct 7, 2025
2 checks passed
@kcz358 kcz358 deleted the bagel/nsa branch October 7, 2025 13:51