
Conversation


@vthumbe1503 vthumbe1503 commented Oct 7, 2025

Description

Motivation:

  • FSDP2 training currently does not work with models initialized with FP8 weights. If high-precision weights are used with TE layers under a low-precision autocast instead, the model consumes more memory than it would with plain BF16, which makes it difficult to adopt TE for FP8-based FSDP2 training (issue). It is therefore useful to make FSDP2 work with FP8-initialized weights (issue).
  • Along with fixing the memory usage for models initialized with FP8 weight tensors, we also want FSDP2 training to be correct, i.e. the FP8 weight tensors must be updated correctly after every training step. Currently the Float8Tensor weights don't get updated at all. This is not specific to FSDP; it also affects DDP with FP8-initialized weights (issue).
  • We also want the FSDP weight all-gather to happen in FP8 instead of high precision for efficient training. Currently, for FP8-initialized weights, TE performs the all-gather in high precision (issue).

What does this PR do?

  • Enables end-to-end FSDP2 training for any PyTorch model with TE layers and FP8-initialized weights
    • Solves the memory-footprint issue with FP8-initialized weights. Initialization with FP8 (per-tensor scaling) on Blackwell takes half the memory footprint of BF16, as expected. MXFP8 and BF16 consume the same amount of memory because MXFP8 needs both the rowwise and columnwise usages.
    • Fixes the FP8 weight updates when the model is initialized with FP8 weights to ensure training correctness
    • Enables 8-bit weight all-gather for both FP8 and MXFP8 tensors.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • FSDP All-Gather Hooks for FP8/MXFP8: Adds fsdp_pre_all_gather and fsdp_post_all_gather methods for FP8/MXFP8 tensors, since all-gather is only supported for native torch tensors with uint8/fp16/bf16/fp32 data types. fsdp_pre_all_gather returns the sharded uint8 tensors that need to be all-gathered together with the metadata needed to reconstruct the FP8/MXFP8 tensor afterwards; fsdp_post_all_gather reconstructs the Float8/MXFP8 tensor from the all-gathered uint8 data. A minimal sketch of this hook contract follows this group of bullets.

    • Handling quantization usages in all-gather: The assumption is that fsdp_pre_all_gather and fsdp_post_all_gather are only called for weight tensors, which is fair since FSDP is only used to shard weights. This means the rowwise usage is needed for the forward pass and the columnwise usage for the backward pass.
    • Identifying forward/backward pass during all-gather: This is needed because only one of the rowwise/columnwise usages has to be all-gathered, depending on whether we are in the forward or backward pass of the training step. fsdp_pre_all_gather receives the module as an argument, i.e. the nn.Module that has the quantized tensor registered as a parameter. This module is not necessarily an FSDP module, since the model might be wrapped at a much higher level in the hierarchy (e.g. the TransformerLayer rather than its Linear submodules). Hence we have a method that finds the lowest-common-ancestor FSDP module and uses its FSDP state, which records whether we are in the forward or backward pass. NOTE: The return value is cached with lru_cache since we don't want to recompute it on every iteration/all-gather during training; the cached value is a reference that FSDP mutates internally over the course of training.
    • Reshard After Forward: FSDP2 exposes a configuration that controls whether parameters are resharded after the forward pass (meaning the weights are all-gathered again for the backward pass). By default it is False for the root module and True for submodules. The setting is obtainable from the FSDP state of the module the parameter belongs to, and it is used to decide whether to send both rowwise and columnwise data in one go or only one of them depending on the forward/backward pass. This matters more for MXFP8, where we may want to send both usages rather than send one, dequantize, and re-quantize to recover the other (which would introduce quantization error).
    • Current Scaling Quantization: For current-scaling quantization we need a single amax/scale inverse shared across all shards, which holds when the model is initialized. However, each quantized weight shard is updated independently by the optimizer during training, so we need to set the amax reduction group on the quantizer if it is not already set. This is done in the forward-pass all-gather itself (using the FSDP mesh information), so that when a weight shard is updated, the quantizer synchronizes across shards to compute a single amax and every weight shard uses the same scale inverse.
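
A minimal sketch of the hook contract described above, using a toy FP8-like container instead of TE's Float8Tensor/MXFP8Tensor. The field names and the simplified hook signatures are illustrative assumptions (recent torch versions pass additional arguments such as outer_size, outer_stride, module, and mp_policy to fsdp_pre_all_gather), not TE's actual code:

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyFP8Weight:
    data: torch.Tensor       # uint8 payload that FSDP2 can all-gather natively
    scale_inv: torch.Tensor  # metadata needed to rebuild the quantized tensor
    hp_dtype: torch.dtype    # nominal high-precision dtype

    def fsdp_pre_all_gather(self, mesh, *args, **kwargs):
        # Hand FSDP2 plain uint8 shard(s) plus whatever metadata is needed
        # to reconstruct the quantized tensor after the all-gather.
        return (self.data,), (self.scale_inv, self.hp_dtype)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (gathered_data,) = all_gather_outputs
        scale_inv, hp_dtype = metadata
        # Rebuild the quantized tensor from the gathered uint8 payload; the
        # second element lists the inner tensors that FSDP2 should track.
        return ToyFP8Weight(gathered_data, scale_inv, hp_dtype), (gathered_data,)
```
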
  • FP8/MXFP8 Torch Dispatch Functions for FSDP2 to handle ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). NOTE: Scale-inverse padding is also handled for MXFP8 pre and post all-gather. A simplified split sketch follows this group of bullets.

    • Split Function: If the model is initialized on a CUDA device from the start, torch chunk/split is called on our custom quantized tensors to split them along dimension 0. After the split, FSDP2 keeps the shard needed by the current process and discards everything else to free memory before training. NOTE: With meta-device (deferred) initialization this method isn't called; quantized tensors are instantiated directly for the weight shard owned by the process rather than initializing everything and discarding the unneeded shards.
    • new_zeros: Implementing this function ensures a new tensor is created with the shape the shard is supposed to have. The original implementation in Float8Tensor didn't create a deep copy of the scale inverses; that is fixed now.
    • copy_: The split/sharded tensor is then copied into the zero tensor created above.
    • as_strided: FSDP2 allows one shard to have fewer elements than the others when dimension 0 is not divisible by the number of shards; it pads the smaller shard and calls as_strided after the all-gather to remove the padding. We currently don't handle the non-divisible case (it would be complicated for MXFP8 and is beyond the scope of this PR), so as_strided is essentially a no-op for us.
    • view: In FSDP2, sharded parameters are flattened with view, and that flattened view is used for the all-gather when compiled autograd is enabled. For MXFP8, however, we raise an error if the tensor is flattened, since the last dimension of MXFP8 must never change. In that case we currently fall back to dequantization followed by a high-precision view so that FSDP2 doesn't fail, and emit a warning when it happens. This is not a concern at the moment since we don't use compiled autograd, so this view path is effectively unused.
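
A simplified sketch of the dim-0 split described above, operating on hypothetical rowwise/columnwise uint8 payloads rather than TE's real tensor classes (scale-inverse handling and padding are omitted):

```python
import torch


def split_fp8_payloads(rowwise_data, columnwise_data, split_size):
    """Split an FP8-like weight along dim 0 into FSDP2 shards.

    rowwise_data:    (rows, cols) uint8 buffer, split along dim 0.
    columnwise_data: (cols, rows) uint8 transpose buffer (or None) whose
                     dim 1 corresponds to dim 0 of the rowwise buffer.
    """
    row_chunks = torch.split(rowwise_data, split_size, dim=0)
    if columnwise_data is not None:
        col_chunks = torch.split(columnwise_data, split_size, dim=1)
    else:
        col_chunks = [None] * len(row_chunks)
    return list(zip(row_chunks, col_chunks))


# Example: an 8x4 weight split across 2 ranks -> each shard pair is (4x4, 4x4).
w, wt = torch.zeros(8, 4, dtype=torch.uint8), torch.zeros(4, 8, dtype=torch.uint8)
shards = split_fp8_payloads(w, wt, split_size=4)
assert shards[0][0].shape == (4, 4) and shards[0][1].shape == (4, 4)
```
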
  • Quantized Tensor Class Issues:

    • Missing Dequantize/Compute/Quantize Pathway: When the optimizer is applied to FP8/MXFP8 weights, it issues its ops (e.g. lerp for the weight update) on a list of Float8 weights rather than on each Float8 weight individually. Our usual dequantize → compute → quantize route didn't handle a list of Float8 tensors, so the weights were not getting updated in place. This PR fixes that (a sketch of the list-handling fallback follows this group of bullets).
    • make_like API relying on the data attribute: make_like in the QuantizedTensor base class should not set the data attribute, since that is specific to Float8Tensor; that logic is moved into the Float8Tensor class instead.
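
A hedged illustration of the list-handling fallback mentioned above: fused optimizers issue ops such as torch._foreach_lerp_ on a whole list of weights, so the dequantize/compute/quantize route has to write each result back into its original quantized storage. The dequantize and quantize_into helpers here are hypothetical stand-ins for the quantized tensor's own methods:

```python
import torch


def foreach_lerp_fallback(quantized_weights, ends, weight_factor,
                          dequantize, quantize_into):
    # 1) Dequantize every FP8 weight into a high-precision working copy.
    highp = [dequantize(w) for w in quantized_weights]
    # 2) Run the optimizer op in place on the whole list at once.
    torch._foreach_lerp_(highp, ends, weight_factor)
    # 3) Re-quantize each result back into the original FP8 storage so the
    #    update is visible in place to the sharded parameter.
    for qweight, updated in zip(quantized_weights, highp):
        quantize_into(updated, out=qweight)
```
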
  • Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers

    • Weight Tensor Usage Validation: We used to validate the presence of all desired rowwise/columnwise usages for weight tensors in the forward pass of the layer itself. With FSDP2, however, different usages are all-gathered in the forward and backward passes, so the validation is moved into the forward and backward functions of the layers: rowwise usage is required in forward and columnwise usage in backward (see the sketch after this group of bullets).
    • Quantizer Usage Validation: We also used to update the weight quantizer's usages even when the weights were already in FP8. In that case there is no need to update the quantizer, since the quantization has already happened and that quantizer is never going to be used, so this update is removed from the code.
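
A minimal sketch of validating usages where they are consumed, as described above; the attribute names _rowwise_data/_columnwise_data are illustrative and the actual GEMMs are elided:

```python
import torch


class _ToyLinearFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # The forward GEMM only needs the rowwise usage of the weight.
        assert getattr(weight, "_rowwise_data", None) is not None, (
            "rowwise usage must have been all-gathered for the forward pass"
        )
        ctx.save_for_backward(weight)
        return inp  # placeholder for the real forward GEMM

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # The dgrad GEMM needs the columnwise (transpose) usage instead.
        assert getattr(weight, "_columnwise_data", None) is not None, (
            "columnwise usage must have been all-gathered for the backward pass"
        )
        return grad_output, None  # placeholders for the real gradients
```
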
  • Resetting Parameters for Deferred Initialization(meta device)

    • Updating DTensors instead of regular tensors: With deferred initialization under FSDP2, the parameters are DTensors that just hold the unmaterialized shard needed by the process, so the local tensor of the DTensor needs to be updated with quantized weights initialized via param_init_fn (see the sketch after this group of bullets).
    • Current scaling quantization: Here the amax reduction group needs to be passed to the quantizer so that all initialized weight shards share a single scale inverse.
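
A hedged sketch of materializing a DTensor-held shard during deferred initialization. quantize_fn stands in for the quantizer call, and the amax-reduction-group setup is only indicated in a comment because the actual quantizer attributes are TE-internal:

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older torch


def reset_dtensor_param(param: DTensor, param_init_fn, quantize_fn):
    # Materialize the local (per-rank) shard in high precision first.
    local_hp = torch.empty(param.to_local().shape, dtype=torch.bfloat16, device="cuda")
    param_init_fn(local_hp)
    # Quantize the local shard. For current scaling, the quantizer would be
    # given an amax reduction group here so all shards share one scale_inv.
    local_q = quantize_fn(local_hp)
    # Wrap the quantized shard back into a DTensor with the same layout.
    return DTensor.from_local(
        local_q, device_mesh=param.device_mesh, placements=param.placements
    )
```
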
  • Test and Miscellaneous issues

    • More complete Test Cases for FSDP2: Originally the test only covered a single Linear layer. It can now test models built from different TE layers, with and without FP8 model init, and across quantization recipes (FP8/MXFP8). NOTE: NVFP4 is pending.
    • View and Reshape not handling Columnwise elegantly: When the columnwise data is present and valid, view and reshape ops are now also applied to the transpose (FP8) / columnwise data (MXFP8) instead of invalidating them.
    • Float8 make_empty API: For make_empty with a transpose, the transpose shape used to be (shape[-1], math.prod(shape[:-1])). It is now consistent with the transpose shape created on the C++ side, i.e. (shape[-1], shape[0], shape[1], ..., shape[-2]). This is needed because we handle transpose ops in the torch dispatch required for FSDP2 and want consistency everywhere (see the small helper after this list).
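
For the make_empty change above, the old and new transpose layouts can be summarized with a tiny illustrative helper:

```python
import math


def old_transpose_shape(shape):
    # Previous behavior: flatten everything but the last dimension.
    return (shape[-1], math.prod(shape[:-1]))


def new_transpose_shape(shape):
    # Now consistent with the C++ side: last dimension first, rest unchanged.
    return (shape[-1], *shape[:-1])


assert old_transpose_shape((4, 6, 8)) == (8, 24)
assert new_transpose_shape((4, 6, 8)) == (8, 4, 6)
```
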

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Summary by CodeRabbit

Release Notes

  • New Features
    • Added FP8 mixed-precision training support with FSDP2/HSDP distributed sharding.
    • Introduced multiple FP8 quantization scaling recipes: delayed scaling, current scaling, and MX_FP8 block scaling.
    • Expanded distributed training configuration options: batch size, sequence length, data type, layer configuration, number of layers, device placement, and sharding specification.
    • Improved distributed tensor parameter support and synchronization for FSDP integration.

@vthumbe1503 vthumbe1503 changed the title FSDP2 Weight Update Fix [Pytorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [Pytorch] FSDP2 Weight Update Fix [PyTorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [PyTorch] FSDP2 Weight Update Fix [PyTorch] TE FSDP2 Support for FP8/MXFP8 Oct 17, 2025
@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enables end-to-end FSDP2 training with FP8/MXFP8 initialized weights in Transformer Engine, addressing three critical issues: memory footprint with FP8 initialization, FP8 weight update correctness, and enabling 8-bit weight all-gather for efficient training.

Key Implementation Details:

  • FSDP All-Gather Hooks: Implements fsdp_pre_all_gather and fsdp_post_all_gather methods for Float8/MXFP8 tensors to convert between FP8 and uint8 representations. The hooks intelligently determine forward vs backward pass using FSDP state and selectively all-gather rowwise (forward) or columnwise (backward) usages for MXFP8 tensors when reshard_after_forward=True.

  • Torch Dispatch Functions: Adds handlers for aten.split.Tensor, aten.new_zeros, aten.copy_, aten.as_strided, and aten.slice.Tensor to enable FSDP2 weight sharding operations on quantized tensors. Falls back to high-precision dequantization path when padding requirements aren't met.

  • Usage Validation Refactoring: Moves quantization usage validation from module forward pass to the functional layer methods, allowing FSDP2 to all-gather different usages in forward vs backward passes.

  • DTensor Support: Updates reset_parameters to properly handle DTensor shards in deferred initialization, setting amax reduction groups for current scaling quantization to ensure consistent scale_inv across shards.

Critical Issue Found:

  • Line 414 in mxfp8_tensor.py: AttributeError when only columnwise usage is enabled - the code attempts .size() on splitted_tensor_data[0] which will be None when _rowwise_data doesn't exist.

Confidence Score: 3/5

  • This PR has a critical bug that will cause runtime errors in certain MXFP8 configurations, but is otherwise well-designed for FSDP2 integration
  • Score reflects one definite runtime bug in MXFP8 split logic (line 414) that will fail when only columnwise usage is enabled. The rest of the implementation is solid with comprehensive test coverage, proper fallback paths, and well-thought-out handling of forward/backward pass differences.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py requires immediate attention - line 414 will cause AttributeError in production

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Adds FSDP2 torch dispatch handlers (split, copy, new_zeros, as_strided, slice) and allgather hooks for MXFP8 tensors. Critical bug on line 414 where .size() is called on a potentially None tensor. |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 torch dispatch handlers (split, new_zeros) and allgather hooks. Improves view/reshape handling for transpose caching. |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper to find FSDP state for modules, with LRU caching for performance. |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor shards in FSDP2 deferred initialization and sets amax reduction groups for current scaling. |

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant FSDP as FSDP2 Framework
    participant Tensor as Float8/MXFP8Tensor
    participant Quantizer as Quantizer
    participant Dist as Distributed Ops

    Note over App,Dist: Model Initialization Phase
    App->>Tensor: fp8_model_init(enabled=True)
    Tensor->>Tensor: Create FP8/MXFP8 weights
    App->>FSDP: fully_shard(module, mesh)
    FSDP->>Tensor: split() - torch dispatch
    Tensor->>Tensor: Split rowwise/columnwise data
    Tensor-->>FSDP: Return weight shards

    Note over App,Dist: Forward Pass - All-Gather
    FSDP->>Tensor: fsdp_pre_all_gather(module, mesh)
    Tensor->>Dist: _get_module_fsdp_state(module)
    Dist-->>Tensor: FSDP state with TrainingState
    Tensor->>Quantizer: set_usage(rowwise=True, columnwise=False)
    Tensor-->>FSDP: (rowwise_data, rowwise_scale_inv), metadata
    FSDP->>Dist: all_gather(uint8 data)
    Dist-->>FSDP: Gathered uint8 tensors
    FSDP->>Tensor: fsdp_post_all_gather(outputs, metadata)
    Tensor->>Tensor: Reconstruct FP8/MXFP8 from uint8
    Tensor-->>FSDP: Full unsharded FP8 weight

    Note over App,Dist: Backward Pass - All-Gather
    FSDP->>Tensor: fsdp_pre_all_gather(module, mesh)
    Tensor->>Dist: _get_module_fsdp_state(module)
    Dist-->>Tensor: FSDP state (PRE_BACKWARD)
    Tensor->>Quantizer: set_usage(rowwise=False, columnwise=True)
    Tensor-->>FSDP: (columnwise_data, columnwise_scale_inv), metadata
    FSDP->>Dist: all_gather(uint8 data)
    Dist-->>FSDP: Gathered uint8 tensors
    FSDP->>Tensor: fsdp_post_all_gather(outputs, metadata)
    Tensor->>Tensor: Reconstruct FP8/MXFP8 from uint8
    Tensor-->>FSDP: Full unsharded FP8 weight

    Note over App,Dist: Optimizer Step
    App->>Tensor: optimizer.step()
    Tensor->>Tensor: Update FP8 weight shards in-place
    Tensor->>Quantizer: Sync amax across shards (if current_scaling)

11 files reviewed, 1 comment


@vthumbe1503

/te-ci L1 pytorch

@vthumbe1503

/te-ci L1 pytorch

@vthumbe1503

/te-ci L1 pytorch

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 2 comments


out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size(),

logic: AttributeError when only columnwise data exists - splitted_tensor_data[0] will be None when _rowwise_data is None (lines 379-390 append None when data is None), causing .size() to fail

Suggested change
shape=splitted_tensor_data[0].size(),
shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment


@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment


Comment on lines 421 to 434
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size(),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]

logic: AttributeError when _rowwise_data is None - splitted_tensor_data[0] will be None when only columnwise usage is enabled (lines 383-394 append None when data is None), causing .size() call to fail on line 423

Suggested change
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size(),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment


Comment on lines 424 to 426
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else None

logic: AttributeError when only columnwise_usage is enabled - splitted_tensor_data[0] (rowwise_data) will be None when tensor._rowwise_data is None (lines 383-394 append None when data is None), causing .size() to fail

Suggested change
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else None
shape=(
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else splitted_tensor_data[1].size()
),

Signed-off-by: Varun Thumbe <[email protected]>
…ngine into fsdp2_issue_fix

Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
@timmoon10 timmoon10 left a comment

LGTM

@greptile-apps greptile-apps bot left a comment

11 files reviewed, no comments


@vthumbe1503 vthumbe1503 merged commit 29537c9 into NVIDIA:main Nov 11, 2025
12 checks passed
