-
Notifications
You must be signed in to change notification settings - Fork 0
Fix mask support and fallback context for GPU tests #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -896,11 +896,21 @@ def forward( | |||||
| effective_dropout = dropout_p if self.training else 0.0 | ||||||
| deterministic_mode = self.deterministic if deterministic is None else deterministic | ||||||
|
|
||||||
| mask_supported = (attention_mask is None) or ( | ||||||
| attention_mask.dim() == 2 | ||||||
| and attention_mask.shape[0] == batch_size | ||||||
| and attention_mask.shape[1] == seq_len_k | ||||||
| ) | ||||||
| mask_supported = attention_mask is None | ||||||
| if not mask_supported: | ||||||
| dims = attention_mask.dim() | ||||||
| if dims == 2: | ||||||
| mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[1] == seq_len_k | ||||||
| elif dims == 3: | ||||||
| mask_supported = ( | ||||||
| attention_mask.shape[0] == batch_size | ||||||
| and attention_mask.shape[1] == seq_len_q | ||||||
| and attention_mask.shape[2] == seq_len_k | ||||||
| ) | ||||||
| elif dims == 4: | ||||||
| mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 4D attention mask validation only checks batch and K lengths, allowing masks with mismatched head or Q dimensions to slip through and fail during mask expansion. Tighten the 4D check to also validate shape[1] (allow 1 or num_heads) and shape[2] (seq_len_q). Prompt for AI agents
Suggested change
|
||||||
| else: | ||||||
| mask_supported = False | ||||||
|
|
||||||
| use_triton = ( | ||||||
| TRITON_AVAILABLE | ||||||
|
|
@@ -1010,18 +1020,13 @@ def forward( | |||||
| sdpa_ctx = nullcontext() | ||||||
| if q.is_cuda: | ||||||
| try: | ||||||
| sdpa_ctx = torch.nn.attention.sdpa_kernel( | ||||||
| SDPBackend.FLASH_ATTENTION | ||||||
| sdpa_ctx = torch.backends.cuda.sdp_kernel( | ||||||
| enable_math=True, | ||||||
| enable_flash=False, | ||||||
| enable_mem_efficient=False, | ||||||
| ) | ||||||
| except (AttributeError, TypeError): | ||||||
| try: | ||||||
| sdpa_ctx = torch.backends.cuda.sdp_kernel( | ||||||
| enable_math=True, | ||||||
| enable_flash=True, | ||||||
| enable_mem_efficient=False, | ||||||
| ) | ||||||
| except Exception: # pragma: no cover - environment dependent | ||||||
| sdpa_ctx = nullcontext() | ||||||
| except Exception: # pragma: no cover - environment dependent | ||||||
| sdpa_ctx = nullcontext() | ||||||
| with sdpa_ctx: | ||||||
| out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs) | ||||||
|
|
||||||
|
|
@@ -1649,16 +1654,13 @@ def backward(ctx, grad_output): | |||||
| sdpa_ctx = nullcontext() | ||||||
| if q_ref.is_cuda: | ||||||
| try: | ||||||
| sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) | ||||||
| except (AttributeError, TypeError): | ||||||
| try: | ||||||
| sdpa_ctx = torch.backends.cuda.sdp_kernel( | ||||||
| enable_math=True, | ||||||
| enable_flash=True, | ||||||
| enable_mem_efficient=False, | ||||||
| ) | ||||||
| except Exception: # pragma: no cover - environment dependent | ||||||
| sdpa_ctx = nullcontext() | ||||||
| sdpa_ctx = torch.backends.cuda.sdp_kernel( | ||||||
| enable_math=True, | ||||||
| enable_flash=False, | ||||||
| enable_mem_efficient=False, | ||||||
| ) | ||||||
| except Exception: # pragma: no cover - environment dependent | ||||||
| sdpa_ctx = nullcontext() | ||||||
|
|
||||||
| go = grad_output.permute(0, 2, 1, 3).reshape(bsz * nh, seq_len_q, hd) | ||||||
|
|
||||||
|
|
||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Incomplete Mask Validation Causes Runtime Failures
The 4D mask validation is incomplete. It only checks
shape[0](batch size) andshape[-1](seq_len_k), missing compatibility checks forshape[1](num_heads) andshape[2](seq_len_q). This allows incompatible masks to pass validation, leading to runtime failures during mask expansion instead of falling back to PyTorch SDPA.