feat: support custom mask for flash attention with scratch #24
base: feat/eagle-support
Conversation
Summary of Changes

Hello @Iamleos, I'm Gemini Code Assist. I'm currently reviewing this pull request and will post my feedback shortly; in the meantime, here's a summary to help you and other reviewers get up to speed. This pull request enhances the Flash Attention mechanism by adding support for custom attention masks. The feature provides greater control over the attention computation, moving beyond fixed causal masking to allow arbitrary masking patterns. The changes span the attention functions, the core kernel, and the associated metadata structures, so that custom masks can be applied efficiently in both the reference and the optimized attention calculations.
Code Review
This pull request introduces support for custom attention masks in the flash attention kernel, which is a significant feature. The changes span the reference implementation, the Pallas kernel, the attention backend, and the tests.
Overall, the direction is good, but I've identified a few critical issues that must be addressed before merging:
- There's a bug in the reference implementation where using sliding_window with a custom mask (causal=False) will cause a crash.
- A data type mismatch for the custom mask buffers between HBM and VMEM in the Pallas kernel could lead to incorrect behavior.
I've also found other high-priority issues, including a bug in the test logic for custom masks and a problematic implementation of mask loading in the kernel. Please review my detailed comments for suggestions on how to resolve these issues.
if causal:
    q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
        jnp.int32, attn.shape, 1
    )
    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
    mask = q_span < kv_span
else:
    mask_start = cu_kv_lens[i]
    mask = custom_mask[mask_start : mask_start + kv_len]
if sliding_window is not None:
    mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
There's a potential UnboundLocalError here. The variables q_span and kv_span are defined only within the if causal: block (lines 150-153), but they are used on line 159 regardless of whether causal is true. If causal=False, this code will raise an error.
The sliding_window logic should likely be moved inside the if causal: block, as it seems to be a feature of causal attention. If sliding_window is intended to work with custom_mask, its implementation needs to be defined for the causal=False case.
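A stripped-down repro of the failure mode, in plain Python rather than the kernel code (the function and its arguments are illustrative only):

def _mask_repro(causal, sliding_window):
    if causal:
        q_span, kv_span = 3, 5   # stand-ins for the real iota spans
        mask = q_span < kv_span
    else:
        mask = True              # stand-in for the custom-mask branch
    if sliding_window is not None:
        # q_span / kv_span were never assigned on the causal=False path,
        # so this line raises UnboundLocalError.
        mask = mask or (q_span - sliding_window >= kv_span)
    return mask

_mask_repro(causal=False, sliding_window=4)  # UnboundLocalError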
Suggested change:

if causal:
    q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
        jnp.int32, attn.shape, 1
    )
    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
    mask = q_span < kv_span
    if sliding_window is not None:
        mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
else:
    mask_start = cu_kv_lens[i]
    mask = custom_mask[mask_start : mask_start + kv_len]
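This keeps sliding_window as a property of the causal path only and leaves the custom-mask branch untouched. If a sliding window ever needs to apply on top of a custom mask, the q_span/kv_span computation would have to be repeated in the else branch before the logical_or.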
bkvmask_double_buf = pltpu.VMEM(
    (2, bq_sz, bkv_sz, head_dim),
    jnp.bool,
There's a data type mismatch for the custom mask buffer. bkvmask_double_buf is defined with jnp.bool, but the data copied into it is jnp.int32. DMA transfers might not handle this conversion implicitly and could lead to incorrect mask values or errors.
To ensure correctness, the dtype of the VMEM buffer should match the dtype of the data in HBM. I recommend using a consistent integer type (e.g., jnp.int32 as used in the ragged_paged_attention function) for the mask throughout the pipeline.
Suggested change:

    jnp.int32,

mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
num_q_heads_per_kv_head_mask = jnp.concat(
    [mask] * num_q_heads_per_kv_head
)
num_kv_heads_mask = jnp.concat(
    [
        num_q_heads_per_kv_head_mask.reshape(
            1, *num_q_heads_per_kv_head_mask.shape
        )
    ]
    * actual_num_kv_heads
)
return num_kv_heads_mask
This load_mask function has a bug and an inefficiency.
- Bug: the mask is not sliced correctly. It should be sliced to actual_bkv_sz on the key dimension to match the shape of the attention scores s. Currently it uses bkv_sz, which can lead to shape mismatches in lax.select if actual_bkv_sz < bkv_sz.
- Inefficiency: using jnp.concat([x] * n) is inefficient. jnp.tile is the idiomatic and more performant way to repeat an array in JAX.
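For reference, a small standalone check (not kernel code) that the jnp.tile form produces the same array as the concat idiom; jnp.concatenate is used here for portability, since jnp.concat is only an alias in recent JAX versions:

import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
via_concat = jnp.concatenate([x] * 4)   # four copies stacked along axis 0
via_tile = jnp.tile(x, (4, 1))          # same result from a single op
assert bool((via_concat == via_tile).all())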
Suggested change:

mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz, :actual_bkv_sz]
num_q_heads_per_kv_head_mask = jnp.tile(
    mask, (num_q_heads_per_kv_head, 1, 1)
)
num_kv_heads_mask = jnp.tile(
    jnp.expand_dims(num_q_heads_per_kv_head_mask, 0),
    (actual_num_kv_heads, 1, 1, 1),
)
return num_kv_heads_mask

if batch.forward_mode == ForwardMode.TARGET_VERIFY:
    # convert custom_mask from bool to int8, because dma not support bool type
    if batch.spec_info.custom_mask.dtype == jnp.bool:
        metadata.custom_mask = batch.spec_info.custom_mask.astype(jnp.int32)
    else:
        metadata.custom_mask = batch.spec_info.custom_mask
There's an inconsistency in the data type used for custom_mask. The comment on line 104 mentions int8, but the code casts to jnp.int32. The Pallas kernel also has conflicting comments and usage (int8 comment, int32 usage, bool buffer). This should be standardized across the implementation to avoid errors and confusion. Using a consistent type like jnp.int32 everywhere (including in the VMEM buffer) would resolve this.
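A minimal sketch of what a standardized convention could look like, assuming the mask travels as int32 (nonzero = masked out) from the backend into the kernel; the helper names here are illustrative, not part of the PR:

import jax.numpy as jnp

MASK_DTYPE = jnp.int32  # one dtype for the HBM tensor, the DMA, and the VMEM buffer

def prepare_custom_mask(custom_mask):
    # Backend side: cast the (possibly boolean) spec-info mask once before
    # it is handed to the kernel, since the DMA path does not handle bool.
    return custom_mask.astype(MASK_DTYPE)

def mask_block_to_predicate(mask_block):
    # Kernel side: turn the int32 block back into the boolean predicate
    # expected by the masking select.
    return mask_block != 0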
custom_masks = []
for bid, seq_len in enumerate([kv_len for _, kv_len in lens]):
    prefix_len = seq_len - q_lens[bid]
    prefix_mask = jnp.full((prefix_len), True, dtype=jnp.bool)
The prefix_mask is initialized with True, which means these tokens will be masked out in the attention calculation. This is likely incorrect, as prefix tokens (past context) should be attended to. The mask for these tokens should be False.
Suggested change:

    prefix_mask = jnp.full((prefix_len), False, dtype=jnp.bool)
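To illustrate the convention (True = masked out), here is one way the full per-sequence mask could be assembled in the test, assuming each query row carries a flattened (kv_len,) mask of an unmasked prefix followed by a causal block over the draft tokens; build_verify_mask is a hypothetical helper, not part of the PR:

import jax.numpy as jnp

def build_verify_mask(prefix_len, q_len):
    # Rows = draft (query) tokens, columns = kv positions.
    prefix_part = jnp.zeros((q_len, prefix_len), dtype=jnp.bool_)      # attend to all prefix tokens
    draft_part = ~jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))  # mask strictly-future drafts
    return jnp.concatenate([prefix_part, draft_part], axis=1).reshape(-1)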
if custom_mask == None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
    raise ValueError(
        f"use custom_mask, custom_mask length must larger than total kv length"
    )
This validation logic can be improved for clarity and efficiency. jnp.cumsum(kv_lens) is computed here for validation and again on line 129. This is a minor inefficiency. Consider computing it once and reusing the result.
Additionally, it's more idiomatic and recommended by PEP 8 to use is None and is not None for singleton comparisons instead of == None and != None.
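A possible cleanup along those lines (a sketch only; cu_kv_lens is assumed to be the cumulative-sum array the later code already builds):

cu_kv_lens = jnp.cumsum(kv_lens)  # computed once and reused below
if custom_mask is None or custom_mask.size < cu_kv_lens[-1]:
    raise ValueError(
        "when custom_mask is used, its size must be at least the total kv length"
    )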
Force-pushed 44bb4e7 to eef3aae
Force-pushed eef3aae to 1bf099a
No description provided.