feat: Bagel naive implementation of sparse attention #45
Conversation
Walkthrough

Adds an NSA-based attention path for Bagel models: new naive NSA ops, a Bagel-specific NSA forward_train, and a monkey patch that augments attention layers with a g_proj and overrides PackedAttentionMoT.forward_train. The bagel package exports apply_nsa_to_bagel to register and apply the patch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant M as Model (Bagel)
    participant P as apply_nsa_to_bagel
    participant L as Layers (attn)
    participant C as PackedAttentionMoT
    participant N as nsa_forward_train
    U->>P: apply_nsa_to_bagel(model, block_size, block_counts, window_size)
    P->>L: add g_proj, set block/window config
    P->>C: override forward_train = N
    P-->>U: return patched model
    Note over C,N: Training-time forward now uses NSA path
```
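For orientation, here is a minimal usage sketch of the patch entry point. The import path, function name, and parameter names come from this PR; the model construction and the concrete values are placeholders chosen for illustration.

```python
# Hypothetical usage of the new patch entry point; values are illustrative only.
from lmms_engine.models.bagel import apply_nsa_to_bagel

model = ...  # placeholder for a constructed Bagel model

model = apply_nsa_to_bagel(
    model,
    block_size=64,     # size of each key/value block (assumed value)
    block_counts=16,   # number of selected blocks per query (assumed value)
    window_size=512,   # sliding-window span for the local branch (assumed value)
)
# From here on, PackedAttentionMoT.forward_train routes through nsa_forward_train.
```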
```mermaid
sequenceDiagram
    autonumber
    participant T as Trainer
    participant C as PackedAttentionMoT (patched)
    participant E as nsa_forward_train
    participant O as NSA Ops (naive.py)
    T->>C: forward_train(packed seq, lens, masks, pos, idx_und, idx_gen)
    C->>E: delegate
    E->>E: project QKV (UND/MoE), normalize, rotary
    E->>O: naive_nsa_with_compression(q, k, v, g_cmp, g_slc, g_swa, block_counts, block_size, window_size, cu_seqlens)
    O-->>E: attention output
    E->>E: output projections (UND/MoE), concat
    E-->>C: packed_attn_output
    C-->>T: return
```
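Conceptually, the NSA path mixes three attention branches — compressed, selected, and sliding-window — using the gates derived from g_proj. The snippet below is a schematic of that gating step under assumed shapes, not the actual code in naive.py.

```python
import torch

def combine_nsa_branches(
    o_cmp: torch.Tensor,  # output of attention over compressed blocks, [..., H, D]
    o_slc: torch.Tensor,  # output of attention over selected blocks,   [..., H, D]
    o_swa: torch.Tensor,  # output of sliding-window attention,         [..., H, D]
    g_cmp: torch.Tensor,  # gate weights for the compressed branch,     [..., H]
    g_slc: torch.Tensor,  # gate weights for the selected branch,       [..., H]
    g_swa: torch.Tensor,  # gate weights for the sliding-window branch, [..., H]
) -> torch.Tensor:
    """Schematic NSA gating: a gate-weighted sum of the three branch outputs."""
    return (
        g_cmp.unsqueeze(-1) * o_cmp
        + g_slc.unsqueeze(-1) * o_slc
        + g_swa.unsqueeze(-1) * o_swa
    )
```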
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 5
🧹 Nitpick comments (10)
src/lmms_engine/models/nsa/naive.py (6)
165-166: Remove unused local g_swa_i.

Assigned but not used; keep gating at the aggregated o_swa instead (see prior comment). Apply this diff:

```diff
-        # [HQ]
-        g_swa_i = g_swa_b[i_q]
+        # g_swa gating is applied to o_swa after FA2 call.
```
19-29: Fix return annotation for compression.

It returns a tuple (k_cmp, v_cmp), not a single tensor. Apply this diff:

```diff
-@torch.compile
-def compression(k: torch.Tensor, v: torch.Tensor, block_size: int) -> torch.Tensor:
+@torch.compile
+def compression(
+    k: torch.Tensor, v: torch.Tensor, block_size: int
+) -> tuple[torch.Tensor, torch.Tensor]:
```
225-235: Fix return annotation for naive_nsa_compression.

The function returns (block_indices, o_cmp). Apply this diff:

```diff
-) -> torch.LongTensor:
+) -> tuple[torch.LongTensor, torch.Tensor]:
```
280-290: Fix return annotation for naive_nsa_compression_varlen.

The function returns (block_indices, o_cmp). Apply this diff:

```diff
-) -> torch.LongTensor:
+) -> tuple[torch.LongTensor, torch.Tensor]:
```
252-259: Nit: typos and minor mask clarity.

Rename casual_mask -> causal_mask; small readability win. Also ensure both masks use the same device (q_b.device). Apply this diff:

```diff
-    casual_mask = (
+    causal_mask = (
         (torch.arange(T) - BS + 1)[:, None] // BS < torch.arange(C)[None, :]
     ).to(q.device)
-    empty_mask = casual_mask.all(-1, True)
+    empty_mask = causal_mask.all(-1, True)
@@
-    attn_cmp = attn_cmp.masked_fill(
-        casual_mask & empty_mask.logical_not(), float("-inf")
+    attn_cmp = attn_cmp.masked_fill(
+        causal_mask & empty_mask.logical_not(), float("-inf")
     )
```

```diff
-    casual_mask = (
+    causal_mask = (
         (torch.arange(T_b) - BS + 1)[:, None] // BS < torch.arange(C_b)[None, :]
-    ).to(q_b.device)
+    ).to(q_b.device)
     local_mask = (
         torch.arange(T_b)[:, None] // BS == torch.arange(C_b)[None, :]
-    ).to(q.device)
+    ).to(q_b.device)
@@
-    attn_cmp = attn_cmp.masked_fill(casual_mask, float("-inf"))
+    attn_cmp = attn_cmp.masked_fill(causal_mask, float("-inf"))
```

Also applies to: 322-328
100-116: Index tensor c construction is a per-call constant; precompute/cast once.

Minor perf: move it outside the batch loop, or compute it with the correct dtype/device once per function.
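A rough sketch of what that hoisting looks like in general; the variable names and sizes below are placeholders, not the actual identifiers in naive.py.

```python
import torch

num_samples, num_blocks = 4, 32               # placeholder sizes
device, dtype = torch.device("cpu"), torch.long

# Before: the block-index tensor is rebuilt on every iteration of the
# per-sample loop even though its value never changes.
for _ in range(num_samples):
    c = torch.arange(num_blocks, device=device, dtype=dtype)
    # ... use c for this sample ...

# After: build it once, with the target device/dtype, before the loop.
c = torch.arange(num_blocks, device=device, dtype=dtype)
for _ in range(num_samples):
    pass  # ... use the shared c for this sample ...
```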
src/lmms_engine/models/bagel/nsa_op.py (2)
99-111: Unused block_indices; consider discarding to reduce overhead/noise.

If not needed, assign to _ to document intent. Apply this diff:

```diff
-    packed_attn_output, block_indices = naive_nsa_with_compression(
+    packed_attn_output, _ = naive_nsa_with_compression(
```
13-13: attention_mask is unused.

Keep the signature, but assign to _ to silence linters. Apply this diff:

```diff
-    attention_mask,
+    attention_mask,
@@
+    _ = attention_mask  # intentionally unused
```

src/lmms_engine/models/bagel/monkey_patch.py (2)
24-36: Avoid setattr for static attributes; assign directly. Also propagate dtype carefully.

Direct assignment is clearer; keep g_proj dtype consistent with the attention module's params if they differ from model.dtype. Apply this diff:

```diff
     attn_layer.g_proj = g_proj
     attn_layer.block_size = block_size
     attn_layer.window_size = window_size
     attn_layer.block_counts = block_counts
-    setattr(attn_layer.config, "block_size", block_size)
-    setattr(attn_layer.config, "window_size", window_size)
-    setattr(attn_layer.config, "block_counts", block_counts)
+    attn_layer.config.block_size = block_size
+    attn_layer.config.window_size = window_size
+    attn_layer.config.block_counts = block_counts
```
38-59: Global monkey patching is sweeping; gate or make reversible.

Consider a context manager or feature flag to avoid affecting unrelated Bagel instances during the same process.
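One possible shape for that, as a sketch: the context-manager name is hypothetical, and it assumes the patch's only class-level change is overriding PackedAttentionMoT.forward_train (per-layer g_proj/config attributes are left in place).

```python
from contextlib import contextmanager

from lmms_engine.models.bagel import apply_nsa_to_bagel
from lmms_engine.models.bagel.qwen2_navit import PackedAttentionMoT


@contextmanager
def nsa_patched_bagel(model, **nsa_kwargs):
    """Apply the NSA patch for the duration of a `with` block, then restore.

    Hypothetical helper: it restores the class-level forward_train on exit,
    but does not undo the per-layer attributes added by apply_nsa_to_bagel.
    """
    original_forward_train = PackedAttentionMoT.forward_train
    try:
        yield apply_nsa_to_bagel(model, **nsa_kwargs)
    finally:
        PackedAttentionMoT.forward_train = original_forward_train
```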
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/lmms_engine/models/bagel/__init__.py (2 hunks)
src/lmms_engine/models/bagel/monkey_patch.py (1 hunks)
src/lmms_engine/models/bagel/nsa_op.py (1 hunks)
src/lmms_engine/models/nsa/naive.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
**/*.py: Use early returns to avoid nested conditionals
Use clear, descriptive names for variables and functions
Prefix handler functions with "handle" (e.g., handle_event)
Prefer constants over functions when possible
Avoid repetition (DRY) by extracting shared logic
Prefer functional, immutable approaches when not verbose
Define composing (higher-level) functions before their components
Mark issues in existing code with comments starting with "TODO:"
Favor functional and stateless approaches where they improve clarity
Keep core logic clean; push implementation details to the edges
Files:
src/lmms_engine/models/bagel/__init__.py
src/lmms_engine/models/bagel/nsa_op.py
src/lmms_engine/models/nsa/naive.py
src/lmms_engine/models/bagel/monkey_patch.py
🧬 Code graph analysis (3)
src/lmms_engine/models/bagel/__init__.py (1)
src/lmms_engine/models/bagel/monkey_patch.py (1)
apply_nsa_to_bagel(39-58)
src/lmms_engine/models/bagel/nsa_op.py (1)
src/lmms_engine/models/nsa/naive.py (1)
naive_nsa_with_compression(350-457)
src/lmms_engine/models/bagel/monkey_patch.py (4)
src/lmms_engine/utils/logging_utils.py (2)
Logging(36-86)
info(47-53)
src/lmms_engine/models/bagel/bagel.py (1)
Bagel(124-1431)
src/lmms_engine/models/bagel/nsa_op.py (1)
forward_train(9-126)
src/lmms_engine/models/bagel/qwen2_navit.py (1)
PackedAttentionMoT(443-801)
🪛 Ruff (0.13.1)
src/lmms_engine/models/bagel/nsa_op.py
13-13: Unused function argument: attention_mask
(ARG001)
96-96: Consider [0, *sample_lens] instead of concatenation
Replace with [0, *sample_lens]
(RUF005)
99-99: Unpacked variable block_indices is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
src/lmms_engine/models/nsa/naive.py
89-91: Avoid specifying long messages outside the exception class
(TRY003)
165-165: Local variable g_swa_i is assigned to but never used
Remove assignment to unused variable g_swa_i
(F841)
405-407: Avoid specifying long messages outside the exception class
(TRY003)
src/lmms_engine/models/bagel/monkey_patch.py
31-31: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
32-32: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
33-33: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
40-40: Unused function argument: kwargs
(ARG001)
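For the Ruff findings not already covered by the diffs above (RUF005 and TRY003), the intended fixes look roughly like the following; the variable names, exception class, and message are illustrative, not taken from the PR.

```python
import torch

# RUF005 (nsa_op.py:96): prefer unpacking over list concatenation.
sample_lens = [3, 5, 2]                    # placeholder per-sample lengths
lens = [0, *sample_lens]                   # instead of [0] + sample_lens
cu_seqlens = torch.tensor(lens).cumsum(0)  # e.g. cumulative sequence lengths


# TRY003 (naive.py:89-91, 405-407): move long messages into the exception class.
class InvalidBlockConfigError(ValueError):
    """Raised when the NSA block configuration is inconsistent (hypothetical)."""

    def __init__(self) -> None:
        super().__init__("block_counts * block_size exceeds the sequence length")
```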
🔇 Additional comments (1)
src/lmms_engine/models/bagel/__init__.py (1)
9-9: Public API exposure looks fine.

Import and export of apply_nsa_to_bagel are consistent with usage.

Also applies to: 27-27
I'll take a closer look at this tonight.
Switched to a better Triton kernel, so the hand-rolled one from before is no longer needed.
Motivation
Modifications
Refer to the docs for information about the changes.
Checklist
Summary by CodeRabbit