Conversation

@Iamleos commented Oct 13, 2025

No description provided.

@gemini-code-assist

Summary of Changes

Hello @Iamleos, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Flash Attention mechanism by integrating support for custom attention masks. This feature provides greater control over the attention computation, moving beyond fixed causal masking to allow for arbitrary masking patterns. The changes involve updates across the attention functions, the core kernel, and associated metadata structures, ensuring that custom masks can be efficiently applied during both reference and optimized attention calculations. This flexibility is crucial for advanced attention models and research.

Highlights

  • Custom Mask Support: Introduced the ability to use custom attention masks in the Flash Attention kernel, allowing for more flexible attention patterns beyond standard causal masking (a mask-layout sketch follows this list).
  • API Changes: The ref_ragged_paged_attention and ragged_paged_attention functions now accept custom_mask and causal parameters, enabling users to specify their own masks or toggle causal masking.
  • Kernel Enhancements: The underlying _ragged_paged_attention_kernel was updated to handle the new custom mask input, including new scratch space for mask data and asynchronous fetching mechanisms for efficient processing.
  • Metadata and Backend Updates: The FlashAttentionMetadata and FlashAttentionBackend classes were extended to properly manage and propagate the custom_mask information throughout the attention pipeline.
  • Testing: New test cases were added to test_flashattention.py to validate the correctness of the custom mask functionality, specifically for multi-head attention (MHA) prefill scenarios.
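
To make the mask layout concrete, here is a minimal, self-contained sketch. It is illustrative only: the names kv_lens and cu_kv_lens and the flattened per-sequence layout are assumptions based on the reference implementation quoted in the review below, not the exact API of this PR. True marks a kv position that should be masked out.

import jax.numpy as jnp

kv_lens = jnp.array([7, 5], dtype=jnp.int32)                   # per-sequence kv lengths
cu_kv_lens = jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(kv_lens)])

custom_mask = jnp.zeros((int(cu_kv_lens[-1]),), dtype=bool)    # attend to everything by default
custom_mask = custom_mask.at[cu_kv_lens[1:] - 1].set(True)     # example: hide each sequence's last token

# Per-sequence slice for sequence i, mirroring how the reference implementation indexes the mask:
i = 0
mask_i = custom_mask[cu_kv_lens[i] : cu_kv_lens[i] + kv_lens[i]]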

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for custom attention masks in the flash attention kernel, which is a significant feature. The changes span the reference implementation, the Pallas kernel, the attention backend, and the tests.

Overall, the direction is good, but I've identified a few critical issues that must be addressed before merging:

  • There's a bug in the reference implementation where using sliding_window with a custom mask (causal=False) will cause a crash.
  • A data type mismatch for the custom mask buffers between HBM and VMEM in the Pallas kernel could lead to incorrect behavior.

I've also found other high-priority issues, including a bug in the test logic for custom masks and a problematic implementation of mask loading in the kernel. Please review my detailed comments for suggestions on how to resolve these issues.

Comment on lines 149 to 176
if causal:
    q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
        jnp.int32, attn.shape, 1
    )
    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
    mask = q_span < kv_span
else:
    mask_start = cu_kv_lens[i]
    mask = custom_mask[mask_start : mask_start + kv_len]
if sliding_window is not None:
    mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)

critical

There's a potential UnboundLocalError here. The variables q_span and kv_span are defined only within the if causal: block (lines 150-153), but they are used on line 159 regardless of whether causal is true. If causal=False, this code will raise an error.

The sliding_window logic should likely be moved inside the if causal: block, as it seems to be a feature of causal attention. If sliding_window is intended to work with custom_mask, its implementation needs to be defined for the causal=False case.

Suggested change
-if causal:
-    q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
-        jnp.int32, attn.shape, 1
-    )
-    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-    mask = q_span < kv_span
-else:
-    mask_start = cu_kv_lens[i]
-    mask = custom_mask[mask_start : mask_start + kv_len]
-if sliding_window is not None:
-    mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
+if causal:
+    q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
+        jnp.int32, attn.shape, 1
+    )
+    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
+    mask = q_span < kv_span
+    if sliding_window is not None:
+        mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
+else:
+    mask_start = cu_kv_lens[i]
+    mask = custom_mask[mask_start : mask_start + kv_len]


bkvmask_double_buf = pltpu.VMEM(
    (2, bq_sz, bkv_sz, head_dim),
    jnp.bool,

critical

There's a data type mismatch for the custom mask buffer. bkvmask_double_buf is defined with jnp.bool, but the data copied into it is jnp.int32. DMA transfers might not handle this conversion implicitly and could lead to incorrect mask values or errors.

To ensure correctness, the dtype of the VMEM buffer should match the dtype of the data in HBM. I recommend using a consistent integer type (e.g., jnp.int32 as used in the ragged_paged_attention function) for the mask throughout the pipeline.

Suggested change
-    jnp.bool,
+    jnp.int32,
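
For completeness, a self-contained sketch of the recommendation (block sizes here are illustrative placeholders, not values from the PR; in the kernel they come from the surrounding scope): the scratch buffer is declared with the same integer dtype that is DMA'd in from HBM.

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

bq_sz, bkv_sz, head_dim = 128, 128, 128   # illustrative sizes only

# Double-buffered mask scratch whose dtype matches the int32 mask in HBM,
# so the DMA copies bits without relying on an implicit bool/int conversion.
bkvmask_double_buf = pltpu.VMEM((2, bq_sz, bkv_sz, head_dim), jnp.int32)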

Comment on lines 748 to 758
mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
num_q_heads_per_kv_head_mask = jnp.concat(
    [mask] * num_q_heads_per_kv_head
)
num_kv_heads_mask = jnp.concat(
    [
        num_q_heads_per_kv_head_mask.reshape(
            1, *num_q_heads_per_kv_head_mask.shape
        )
    ]
    * actual_num_kv_heads
)
return num_kv_heads_mask

high

This load_mask function has a bug and an inefficiency.

  1. Bug: The mask is not sliced correctly. It should be sliced to actual_bkv_sz on the key dimension to match the shape of the attention scores s. Currently, it uses bkv_sz, which can lead to shape mismatches in lax.select if actual_bkv_sz < bkv_sz.
  2. Inefficiency: Using jnp.concat([x] * n) is inefficient. jnp.tile is the idiomatic and more performant way to repeat an array in JAX.
                    mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz, :actual_bkv_sz]
                    num_q_heads_per_kv_head_mask = jnp.tile(
                        mask, (num_q_heads_per_kv_head, 1, 1)
                    )
                    num_kv_heads_mask = jnp.tile(
                        jnp.expand_dims(num_q_heads_per_kv_head_mask, 0),
                        (actual_num_kv_heads, 1, 1, 1),
                    )
                    return num_kv_heads_mask
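
As a quick sanity check of point 2, a small standalone example (toy shapes only, unrelated to the kernel) showing that tiling along the leading axis produces the same result as concatenating repeated copies, without first building a Python list of arrays:

import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
n = 4
via_concat = jnp.concatenate([x] * n)      # builds an n-element Python list first
via_tile = jnp.tile(x, (n, 1))             # single op, same (8, 3) result
assert bool(jnp.array_equal(via_concat, via_tile))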

Comment on lines 103 to 108
if batch.forward_mode == ForwardMode.TARGET_VERIFY:
    # convert custom_mask from bool to int8, because dma not support bool type
    if batch.spec_info.custom_mask.dtype == jnp.bool:
        metadata.custom_mask = batch.spec_info.custom_mask.astype(jnp.int32)
    else:
        metadata.custom_mask = batch.spec_info.custom_mask

high

There's an inconsistency in the data type used for custom_mask. The comment on line 104 mentions int8, but the code casts to jnp.int32. The Pallas kernel also has conflicting comments and usage (int8 comment, int32 usage, bool buffer). This should be standardized across the implementation to avoid errors and confusion. Using a consistent type like jnp.int32 everywhere (including in the VMEM buffer) would resolve this.
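
A minimal sketch of one way to standardize, assuming int32 is the dtype chosen end to end (the helper name is hypothetical):

import jax.numpy as jnp

def to_mask_dtype(custom_mask):
    # Cast boolean masks to int32 once at the metadata boundary; integer masks
    # pass through unchanged, so HBM, DMA, and the VMEM scratch all see int32.
    if custom_mask.dtype == jnp.bool_:
        return custom_mask.astype(jnp.int32)
    return custom_mask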

custom_masks = []
for bid, seq_len in enumerate([kv_len for _, kv_len in lens]):
    prefix_len = seq_len - q_lens[bid]
    prefix_mask = jnp.full((prefix_len), True, dtype=jnp.bool)

high

The prefix_mask is initialized with True, which means these tokens will be masked out in the attention calculation. This is likely incorrect, as prefix tokens (past context) should be attended to. The mask for these tokens should be False.

Suggested change
-    prefix_mask = jnp.full((prefix_len), True, dtype=jnp.bool)
+    prefix_mask = jnp.full((prefix_len), False, dtype=jnp.bool)
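
A toy, self-contained illustration of the convention this comment relies on (assuming scores are masked with a large negative value, as is usual for this kind of kernel): True entries are the positions that get hidden, so a True prefix would hide past context.

import jax.numpy as jnp

scores = jnp.ones((2, 3))                          # toy attention scores (q=2, kv=3)
mask = jnp.array([[False, False, True],
                  [False, False, False]])          # True = masked out
masked = jnp.where(mask, jnp.finfo(scores.dtype).min, scores)
# masked[0, 2] is now a large negative value and vanishes under softmax.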

Comment on lines +118 to +121
if custom_mask == None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
    raise ValueError(
        f"use custom_mask, custom_mask length must larger than total kv length"
    )

medium

This validation logic can be improved for clarity and efficiency. jnp.cumsum(kv_lens) is computed here for validation and again on line 129. This is a minor inefficiency. Consider computing it once and reusing the result.

Additionally, it's more idiomatic and recommended by PEP 8 to use is None and is not None for singleton comparisons instead of == None and != None.
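
A sketch of the cleaned-up check (the helper is hypothetical; names mirror the snippet above): compute the cumulative kv lengths once, return them for reuse, and compare with is None.

import jax.numpy as jnp

def validate_custom_mask(custom_mask, kv_lens):
    cu_kv_lens = jnp.cumsum(kv_lens)               # computed once, reusable by the caller
    total_kv_len = int(cu_kv_lens[-1])
    if custom_mask is None or custom_mask.size < total_kv_len:
        raise ValueError(
            f"custom_mask must cover at least the total kv length ({total_kv_len})"
        )
    return cu_kv_lens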

@Iamleos force-pushed the feat/mask_attn_scratch branch 4 times, most recently from 44bb4e7 to eef3aae on October 14, 2025 at 13:17
@Iamleos force-pushed the feat/mask_attn_scratch branch from eef3aae to 1bf099a on October 14, 2025 at 13:41