
Conversation


@pggPL pggPL commented Oct 23, 2025

Description

Pre-scale bias is added twice:

here for the first time:

here for the second:

logits = logits + bias.astype(input_dtype)
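To make the failure mode concrete, here is a minimal sketch (names are illustrative, not the actual TE code paths): the unfused attention path adds the pre-scale bias itself and then still forwards the same bias to the softmax helper, which adds it again.

    import jax
    import jax.numpy as jnp

    def softmax_with_optional_bias(logits, bias=None):
        # Stand-in for the Softmax module: adds the bias whenever one is passed in.
        if bias is not None:
            logits = logits + bias.astype(logits.dtype)
        return jax.nn.softmax(logits, axis=-1)

    def buggy_pre_scale_bias_attention(scores, bias, scale):
        # First addition, before scaling -- this is the intended one ...
        scores = scores + bias
        # ... but the bias is still handed to the helper, which adds it a second time.
        return softmax_with_optional_bias(scores * scale, bias)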

But what's more important, there is a problem with test_layer. The forward output is compared only by its mean, which is not enough to catch most bugs. For backward, all tensors are compared, but with very loose tolerances computed here:

def dtype_tols(

For bf16 it is {'rtol': 0.039372532809214794, 'atol': 0.039372532809214794}, but the maximum of all wgrads and dgrads is around 0.003, so the comparison effectively tests nothing.
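As a rough sanity check (a sketch; the eps**(2/3) derivation is an assumption that happens to reproduce the quoted value, and the bound is the standard atol + rtol * |expected| used by allclose-style checks):

    bf16_eps = 2.0 ** -7              # machine epsilon of bfloat16 (8 significand bits incl. the implicit one)
    tol = bf16_eps ** (2.0 / 3.0)     # 0.039372532809214794 -- matches the rtol/atol above
    max_grad = 3e-3                   # rough magnitude of the largest wgrad/dgrad in these tests

    # allclose-style bound: |ref - test| <= atol + rtol * |ref|
    allowed = tol + tol * max_grad    # ~0.0395, i.e. more than 13x any gradient value,
                                      # so even a completely wrong gradient passes
    print(tol, allowed)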

I made a temporary fix to catch this issue; we will see if anything more gets caught. But these tolerances need to be lowered.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 2 commits October 23, 2025 21:15
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>

pggPL commented Oct 23, 2025

/te-ci jax


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR fixes a critical bug in the JAX attention implementation where pre-scale bias was being added twice to attention weights. The fix modifies transformer.py to add the bias once in the unfused attention path and then nullify it (bias = None) to prevent the downstream DPA module from adding it again. Additionally, the test suite was strengthened by replacing forward-pass mean-only comparison with full tensor comparison for non-FP8 modes, addressing inadequate test coverage that allowed the bug to persist (backward tests used rtol/atol ~0.039 when actual gradient differences were ~0.003).

The changes integrate with the existing attention bias architecture where AttnBiasType.PRE_SCALE_BIAS requires special handling—bias must be added before scaling, unlike POST_SCALE_BIAS. The fix coordinates between the transformer layer and the Softmax/DPA module to ensure bias is applied exactly once in the attention computation pipeline.
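For context, the difference between the two bias types comes down to where the bias lands relative to the 1/sqrt(d) scaling; a minimal sketch (illustrative signature, not the actual TE enum or API):

    import jax.numpy as jnp

    def attention_logits(q, k, bias, scale, attn_bias_type):
        # scale is typically 1.0 / sqrt(head_dim)
        scores = jnp.einsum("...qd,...kd->...qk", q, k)
        if attn_bias_type == "pre_scale_bias":
            return (scores + bias) * scale   # bias is scaled together with the raw scores
        if attn_bias_type == "post_scale_bias":
            return scores * scale + bias     # bias is applied after the scaling
        return scores * scale                # no bias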

Confidence score: 4/5

  • This PR fixes a correctness bug and strengthens test coverage but has minor concerns about FP8 testing gaps
  • Score reflects that while the bias fix is correct, the FP8 fallback in tests (lines 345-348 of test_layer.py) still uses mean-only comparison, perpetuating weak test coverage for that code path; also, parameter order inconsistency between _output_fn and _loss_fn could cause maintenance issues
  • Pay close attention to tests/jax/test_layer.py lines 345-348, where FP8 mode still uses mean-only comparison instead of full tensor comparison

2 files reviewed, 1 comment


Comment on lines 345 to 348
if not get_quantize_config().is_fp8_enabled():
    assert_allclose(ref_out, test_out, **tols)
else:
    assert_allclose(ref_out.mean(), test_out.mean(), **tols)
@greptile-apps greptile-apps bot commented:

logic: Only comparing means in FP8 mode loses test coverage. This doesn't validate output shape, distribution, or element-wise correctness. The PR description mentions test_layer only tests means, which was insufficient to catch the double-addition bug. What specific tolerance values were found to work for full tensor comparison in FP8 mode without false positives?
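One way to answer that empirically (a sketch, not part of this PR; it reuses ref_out/test_out from the snippet above) is to measure the worst-case errors first and only then choose tolerances:

    import jax.numpy as jnp

    def max_errors(ref, test):
        # Worst-case element-wise absolute and relative error between two tensors.
        ref32, test32 = ref.astype(jnp.float32), test.astype(jnp.float32)
        abs_err = jnp.abs(ref32 - test32)
        rel_err = abs_err / jnp.maximum(jnp.abs(ref32), 1e-12)
        return float(abs_err.max()), float(rel_err.max())

    # max_abs, max_rel = max_errors(ref_out, test_out)
    # Tolerances just above these maxima would allow a full tensor comparison
    # without false positives -- provided the maxima are acceptable at all.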


@pggPL (Collaborator, Author) replied:

The difference for fp8 is around 0.4, which is huge. Maybe this is a bug and needs investigation.

@pggPL (Collaborator, Author) replied:

I decided to skip these tests to get this merged sooner.

@pggPL pggPL requested a review from KshitijLakhani October 23, 2025 21:57
@jberchtold-nvidia jberchtold-nvidia self-requested a review October 24, 2025 15:04
@jberchtold-nvidia jberchtold-nvidia (Collaborator) commented:

Thanks for finding and raising the tolerance issue, Pawel! I'll take a look today and see how we can make these tolerances stricter.

Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL pggPL force-pushed the inspect_jax_bias_add branch from 1e91f44 to c0a0947 on October 31, 2025 09:49
Signed-off-by: Pawel Gadzinski <[email protected]>

pggPL commented Oct 31, 2025

@jberchtold-nvidia I removed the tolerance change from this PR, since the pipeline was failing and I saw that you are working on it in the other PR. I will leave only the bug fix here.


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Fixed critical bug where pre-scale bias was incorrectly added twice in attention computation. When attn_bias_type == PRE_SCALE_BIAS, the bias was added to attn_weights at transformer.py:199 and then added again inside the Softmax module at module.py:194. The fix sets bias = None after the first addition to prevent double-addition.
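As a sketch of the fixed control flow (structure only, not the verbatim TE code):

    import jax
    import jax.numpy as jnp

    def softmax_module(logits, bias=None):
        # Stand-in for the Softmax module: only adds the bias if one is given.
        if bias is not None:
            logits = logits + bias.astype(logits.dtype)
        return jax.nn.softmax(logits, axis=-1)

    def unfused_attention_pre_scale_bias(attn_weights, bias, scale):
        attn_weights = attn_weights + bias   # the single intended addition, before scaling
        bias = None                          # the fix: downstream Softmax must not add it again
        return softmax_module(attn_weights * scale, bias)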

Key changes:

  • Added bias = None after pre-scale bias addition in _UnfusedDotProductAttention
  • Ensures bias is only applied once for PRE_SCALE_BIAS case
  • POST_SCALE_BIAS path remains unchanged (bias still passed to Softmax)

Note from PR author: The original tests were insufficient to catch this bug because they only compared means in forward pass and used overly loose tolerances (rtol/atol ~0.039 for bf16 when actual gradient magnitudes were ~0.003) for backward pass comparisons.

Confidence Score: 5/5

  • This PR is safe to merge - it's a minimal, surgical fix for a clear bug
  • The fix is a simple one-line addition that correctly prevents double-addition of bias. The logic is clear: after manually adding bias for PRE_SCALE_BIAS case, setting it to None prevents it from being added again in the Softmax module. No edge cases or side effects identified.
  • No files require special attention

Important Files Changed

File Analysis

Filename: transformer_engine/jax/flax/transformer.py
Score: 5/5
Overview: Fixed double-addition bug for PRE_SCALE_BIAS by setting bias to None after first addition

Sequence Diagram

sequenceDiagram
    participant DPA as _UnfusedDotProductAttention
    participant SM as Softmax (module.py)
    
    Note over DPA: attn_bias_type = PRE_SCALE_BIAS
    
    rect rgb(255, 200, 200)
        Note over DPA: BEFORE FIX (Bug)
        DPA->>DPA: Line 199: attn_weights += bias
        DPA->>SM: Line 240: Softmax(attn_weights, mask, bias)
        SM->>SM: Line 194: logits = logits + bias
        Note over SM: ❌ Bias added TWICE!
    end
    
    rect rgb(200, 255, 200)
        Note over DPA: AFTER FIX (Correct)
        DPA->>DPA: Line 199: attn_weights += bias
        DPA->>DPA: Line 200: bias = None
        DPA->>SM: Line 240: Softmax(attn_weights, mask, None)
        SM->>SM: Line 194: Skip (bias is None)
        Note over SM: ✓ Bias added once
    end

1 file reviewed, no comments



pggPL commented Oct 31, 2025

/te-ci jax


@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!


@KshitijLakhani KshitijLakhani left a comment


LGTM

@pggPL pggPL merged commit b6020e3 into NVIDIA:main Nov 5, 2025
23 checks passed
pggPL added a commit to pggPL/TransformerEngine that referenced this pull request Nov 6, 2025
* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request Nov 12, 2025
* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>