[JAX] Fix bug with pre scale bias #2300
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci jax
Greptile Overview
Greptile Summary
This PR fixes a critical bug in the JAX attention implementation where pre-scale bias was being added twice to attention weights. The fix modifies transformer.py to add the bias once in the unfused attention path and then nullify it (bias = None) to prevent the downstream DPA module from adding it again. Additionally, the test suite was strengthened by replacing forward-pass mean-only comparison with full tensor comparison for non-FP8 modes, addressing inadequate test coverage that allowed the bug to persist (backward tests used rtol/atol ~0.039 when actual gradient differences were ~0.003).
The changes integrate with the existing attention bias architecture where AttnBiasType.PRE_SCALE_BIAS requires special handling—bias must be added before scaling, unlike POST_SCALE_BIAS. The fix coordinates between the transformer layer and the Softmax/DPA module to ensure bias is applied exactly once in the attention computation pipeline.
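For illustration, here is a minimal sketch of the described fix, assuming the surrounding names (`attn_weights`, `bias`, `AttnBiasType`) as used in this summary; the exact code in transformer.py may differ:

```python
# Sketch of the unfused-attention bias handling after the fix (illustrative only).
if attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
    attn_weights = attn_weights + bias  # add the pre-scale bias once, before scaling
    bias = None  # so the downstream Softmax/DPA module does not add it again
# For POST_SCALE_BIAS, bias is left as-is and passed through to the Softmax module unchanged.
```

Setting `bias = None` after the addition keeps the POST_SCALE_BIAS path untouched while guaranteeing the pre-scale bias is applied exactly once.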
Confidence score: 4/5
- This PR fixes a correctness bug and strengthens test coverage but has minor concerns about FP8 testing gaps
- Score reflects that while the bias fix is correct, the FP8 fallback in tests (lines 345-348 of test_layer.py) still uses mean-only comparison, perpetuating weak test coverage for that code path; also, parameter order inconsistency between `_output_fn` and `_loss_fn` could cause maintenance issues
- Pay close attention to tests/jax/test_layer.py lines 345-348, where FP8 mode still uses mean-only comparison instead of full tensor comparison
2 files reviewed, 1 comment
tests/jax/test_layer.py
Outdated
```python
if not get_quantize_config().is_fp8_enabled():
    assert_allclose(ref_out, test_out, **tols)
else:
    assert_allclose(ref_out.mean(), test_out.mean(), **tols)
```
logic: Only comparing means in FP8 mode loses test coverage. This doesn't validate output shape, distribution, or element-wise correctness. The PR description mentions test_layer only tests means, which was insufficient to catch the double-addition bug. What specific tolerance values were found to work for full tensor comparison in FP8 mode without false positives?
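As a toy illustration (hypothetical values, not from the test suite, using numpy's assert_allclose rather than the project's helper): two very different tensors can share the same mean and still pass a mean-only check.

```python
import numpy as np

ref_out = np.array([1.0, -1.0, 1.0, -1.0])
test_out = np.array([0.0, 0.0, 0.0, 0.0])  # element-wise completely wrong

# Mean-only comparison passes even though every element differs by 1.0.
np.testing.assert_allclose(test_out.mean(), ref_out.mean(), rtol=1e-5, atol=1e-8)
```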
The difference for fp8 is around 0.4, which is huge. Maybe this is the bug and needs investigation.
I decided to skip these tests to get this merged sooner.
Thanks for finding and raising the tolerance issue, Pawel! I'll take a look today and see how we can make these tolerances stricter.
Signed-off-by: Pawel Gadzinski <[email protected]>
1e91f44 to c0a0947
Signed-off-by: Pawel Gadzinski <[email protected]>
@jberchtold-nvidia I removed the tolerance change from this PR, since the pipeline was failing and I saw that you are working on it in the other PR. I will leave only the fix for the bug.
Greptile Overview
Greptile Summary
Fixed critical bug where pre-scale bias was incorrectly added twice in attention computation. When attn_bias_type == PRE_SCALE_BIAS, the bias was added to attn_weights at transformer.py:199 and then added again inside the Softmax module at module.py:194. The fix sets bias = None after the first addition to prevent double-addition.
Key changes:
- Added `bias = None` after the pre-scale bias addition in `_UnfusedDotProductAttention`
- Ensures bias is only applied once for the PRE_SCALE_BIAS case
- POST_SCALE_BIAS path remains unchanged (bias still passed to Softmax)
Note from PR author: The original tests were insufficient to catch this bug because they only compared means in the forward pass and used overly loose tolerances (rtol/atol ~0.039 for bf16 when actual gradient magnitudes were ~0.003) for backward-pass comparisons.
Confidence Score: 5/5
- This PR is safe to merge - it's a minimal, surgical fix for a clear bug
- The fix is a simple one-line addition that correctly prevents double-addition of bias. The logic is clear: after manually adding bias for PRE_SCALE_BIAS case, setting it to None prevents it from being added again in the Softmax module. No edge cases or side effects identified.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/transformer.py | 5/5 | Fixed double-addition bug for PRE_SCALE_BIAS by setting bias to None after first addition |
Sequence Diagram
sequenceDiagram
participant DPA as _UnfusedDotProductAttention
participant SM as Softmax (module.py)
Note over DPA: attn_bias_type = PRE_SCALE_BIAS
rect rgb(255, 200, 200)
Note over DPA: BEFORE FIX (Bug)
DPA->>DPA: Line 199: attn_weights += bias
DPA->>SM: Line 240: Softmax(attn_weights, mask, bias)
SM->>SM: Line 194: logits = logits + bias
Note over SM: ❌ Bias added TWICE!
end
rect rgb(200, 255, 200)
Note over DPA: AFTER FIX (Correct)
DPA->>DPA: Line 199: attn_weights += bias
DPA->>DPA: Line 200: bias = None
DPA->>SM: Line 240: Softmax(attn_weights, mask, None)
SM->>SM: Line 194: Skip (bias is None)
Note over SM: ✓ Bias added once
end
1 file reviewed, no comments
/te-ci jax
jberchtold-nvidia
left a comment
LGTM, thanks!
KshitijLakhani
left a comment
LGTM
* fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]>
* fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
Description
Pre-scale bias is added twice:
here for the first time:
TransformerEngine/transformer_engine/jax/flax/transformer.py
Line 199 in e2f2a0b
and here for the second time:
TransformerEngine/transformer_engine/jax/flax/module.py
Line 194 in e2f2a0b
But more importantly, there is a problem with test_layer: it only tests the mean of the forward output, which is not enough to catch most bugs. For backward, all tensors are compared, but with very loose tolerances, computed here
TransformerEngine/tests/jax/utils.py
Line 1486 in e2f2a0b
For bf16 it is `{'rtol': 0.039372532809214794, 'atol': 0.039372532809214794}`, but the max of all wgrads and dgrads is around 0.003, so it basically tests nothing. I made a temporary fix to catch this issue; we will see if anything more is caught. But these tolerances need to be lowered.
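To make the tolerance concern concrete, here is a hypothetical example using the numbers quoted above (numpy's assert_allclose stands in for the project's helper): with gradients on the order of 0.003 and atol/rtol around 0.039, even a completely wrong gradient passes the comparison.

```python
import numpy as np

tols = {"rtol": 0.039372532809214794, "atol": 0.039372532809214794}

ref_grad = np.full((8,), 0.003)  # typical gradient magnitude reported above
test_grad = np.zeros((8,))       # 100% error

# Passes: |0 - 0.003| = 0.003 <= atol + rtol * |0.003| ~= 0.0395
np.testing.assert_allclose(test_grad, ref_grad, **tols)
```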