[PyTorch] FSDP2 Support for TE #2245
Conversation
Greptile Overview
Greptile Summary
This PR enables end-to-end FSDP2 training with FP8/MXFP8 initialized weights in Transformer Engine, addressing three critical issues: memory footprint with FP8 initialization, FP8 weight update correctness, and enabling 8-bit weight all-gather for efficient training.
Key Implementation Details:
- FSDP All-Gather Hooks: Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods for Float8/MXFP8 tensors to convert between FP8 and uint8 representations. The hooks determine forward vs. backward pass from the FSDP state and selectively all-gather rowwise (forward) or columnwise (backward) usages for MXFP8 tensors when `reshard_after_forward=True`. A minimal sketch of this hook pattern follows this list.
- Torch Dispatch Functions: Adds handlers for `aten.split.Tensor`, `aten.new_zeros`, `aten.copy_`, `aten.as_strided`, and `aten.slice.Tensor` to enable FSDP2 weight-sharding operations on quantized tensors. Falls back to the high-precision dequantization path when padding requirements aren't met.
- Usage Validation Refactoring: Moves quantization usage validation from the module forward pass to the functional layer methods, allowing FSDP2 to all-gather different usages in forward vs. backward passes.
- DTensor Support: Updates `reset_parameters` to properly handle DTensor shards in deferred initialization, setting amax reduction groups for current-scaling quantization to ensure a consistent `scale_inv` across shards.
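The following is a minimal, hypothetical sketch of the FSDP2 tensor-subclass hook shape described above; it is not TE's Float8Tensor/MXFP8Tensor code. It assumes the simple `fsdp_pre_all_gather(self, mesh)` signature (newer PyTorch versions pass extra arguments such as the owning module and mixed-precision policy), and the `QuantizedWeightSketch` class and its fields are illustrative stand-ins.

```python
import torch


class QuantizedWeightSketch:
    """Illustrative stand-in for a quantized tensor that FSDP2 shards and gathers."""

    def __init__(self, data_uint8: torch.Tensor, scale_inv: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data_uint8      # raw bytes FSDP can all-gather natively
        self._scale_inv = scale_inv  # reconstruction metadata, passed through untouched
        self._orig_dtype = orig_dtype

    def fsdp_pre_all_gather(self, mesh):
        # Hand FSDP the uint8 shard(s) to gather, plus metadata for the post hook.
        return (self._data,), (self._scale_inv, self._orig_dtype)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (gathered,) = all_gather_outputs
        scale_inv, orig_dtype = metadata
        if out is not None:
            # FSDP reuses a preallocated unsharded tensor: update it in place.
            out._data.copy_(gathered)
            return None
        # Rebuild an unsharded quantized weight around the gathered bytes and tell
        # FSDP which inner tensors it may free when resharding.
        unsharded = QuantizedWeightSketch(gathered, scale_inv, orig_dtype)
        return unsharded, (gathered,)
```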
Critical Issue Found:
- Line 414 in `mxfp8_tensor.py`: `AttributeError` when only columnwise usage is enabled - the code attempts `.size()` on `splitted_tensor_data[0]`, which will be `None` when `_rowwise_data` doesn't exist.
Confidence Score: 3/5
- This PR has a critical bug that will cause runtime errors in certain MXFP8 configurations, but is otherwise well-designed for FSDP2 integration
- Score reflects one definite runtime bug in MXFP8 split logic (line 414) that will fail when only columnwise usage is enabled. The rest of the implementation is solid with comprehensive test coverage, proper fallback paths, and well-thought-out handling of forward/backward pass differences.
- transformer_engine/pytorch/tensor/mxfp8_tensor.py requires immediate attention - line 414 will cause AttributeError in production
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Adds FSDP2 torch dispatch handlers (split, copy, new_zeros, as_strided, slice) and all-gather hooks for MXFP8 tensors. Critical bug on line 414 where `.size()` is called on a potentially `None` tensor. |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 torch dispatch handlers (split, new_zeros) and allgather hooks. Improves view/reshape handling for transpose caching. |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds `_get_module_fsdp_state` helper to find the FSDP state for a module, with LRU caching for performance (a hypothetical lookup sketch follows this table). |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor shards in FSDP2 deferred initialization and sets amax reduction groups for current scaling. |
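As a rough illustration of the module-to-FSDP-state lookup mentioned in the table row for `distributed.py` (a hypothetical sketch, not TE's actual helper): it assumes the FSDP2 composable API exposes a `fully_shard.state(module)` accessor, whose import path varies across PyTorch releases, and caches the result per module.

```python
from functools import lru_cache
from typing import Optional

import torch
from torch.distributed.fsdp import fully_shard  # import location varies across PyTorch versions


@lru_cache(maxsize=None)
def get_module_fsdp_state(module: torch.nn.Module) -> Optional[object]:
    """Find the FSDP2 state managing `module` by checking it and its submodules."""
    for submodule in module.modules():  # includes `module` itself first
        state = fully_shard.state(submodule)  # assumed composable-API state accessor
        if state is not None:
            return state
    return None
```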
Sequence Diagram
sequenceDiagram
participant App as Application
participant FSDP as FSDP2 Framework
participant Tensor as Float8/MXFP8Tensor
participant Quantizer as Quantizer
participant Dist as Distributed Ops
Note over App,Dist: Model Initialization Phase
App->>Tensor: fp8_model_init(enabled=True)
Tensor->>Tensor: Create FP8/MXFP8 weights
App->>FSDP: fully_shard(module, mesh)
FSDP->>Tensor: split() - torch dispatch
Tensor->>Tensor: Split rowwise/columnwise data
Tensor-->>FSDP: Return weight shards
Note over App,Dist: Forward Pass - All-Gather
FSDP->>Tensor: fsdp_pre_all_gather(module, mesh)
Tensor->>Dist: _get_module_fsdp_state(module)
Dist-->>Tensor: FSDP state with TrainingState
Tensor->>Quantizer: set_usage(rowwise=True, columnwise=False)
Tensor-->>FSDP: (rowwise_data, rowwise_scale_inv), metadata
FSDP->>Dist: all_gather(uint8 data)
Dist-->>FSDP: Gathered uint8 tensors
FSDP->>Tensor: fsdp_post_all_gather(outputs, metadata)
Tensor->>Tensor: Reconstruct FP8/MXFP8 from uint8
Tensor-->>FSDP: Full unsharded FP8 weight
Note over App,Dist: Backward Pass - All-Gather
FSDP->>Tensor: fsdp_pre_all_gather(module, mesh)
Tensor->>Dist: _get_module_fsdp_state(module)
Dist-->>Tensor: FSDP state (PRE_BACKWARD)
Tensor->>Quantizer: set_usage(rowwise=False, columnwise=True)
Tensor-->>FSDP: (columnwise_data, columnwise_scale_inv), metadata
FSDP->>Dist: all_gather(uint8 data)
Dist-->>FSDP: Gathered uint8 tensors
FSDP->>Tensor: fsdp_post_all_gather(outputs, metadata)
Tensor->>Tensor: Reconstruct FP8/MXFP8 from uint8
Tensor-->>FSDP: Full unsharded FP8 weight
Note over App,Dist: Optimizer Step
App->>Tensor: optimizer.step()
Tensor->>Tensor: Update FP8 weight shards in-place
Tensor->>Quantizer: Sync amax across shards (if current_scaling)
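The final step in the diagram (amax sync across shards under current scaling) amounts to a MAX all-reduce over the sharding group so every rank derives the same scale. A hedged sketch, assuming a generic `shard_group` process group and the FP8 E4M3 max of 448, not TE's actual recipe code:

```python
import torch
import torch.distributed as dist


def sync_amax_and_compute_scale(
    local_shard: torch.Tensor,
    shard_group: dist.ProcessGroup,
    fp8_max: float = 448.0,  # float8_e4m3fn max; an assumption for illustration
) -> torch.Tensor:
    """Return a quantization scale that is identical on every shard of one weight."""
    amax = local_shard.detach().abs().amax().float()
    # MAX-reduce so each rank sees the amax of the full, unsharded weight.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=shard_group)
    # Current-scaling recipe: map the observed amax to the FP8 representable max.
    return fp8_max / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)
```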
11 files reviewed, 1 comment
/te-ci L1 pytorch
/te-ci L1 pytorch
/te-ci L1 pytorch
11 files reviewed, 2 comments
out_data.append(scale_inv_out)
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),
logic: AttributeError when only columnwise data exists - splitted_tensor_data[0] will be None when _rowwise_data is None (lines 379-390 append None when data is None), causing .size() to fail
Suggested change:
-        shape=splitted_tensor_data[0].size(),
+        shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),
11 files reviewed, 1 comment
11 files reviewed, 1 comment
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),
        dtype=tensor.dtype,
        rowwise_data=splitted_tensor_data[0],
        rowwise_scale_inv=splitted_tensor_data[2],
        columnwise_data=splitted_tensor_data[1],
        columnwise_scale_inv=splitted_tensor_data[3],
        quantizer=tensor._quantizer,
        requires_grad=False,
        fp8_dtype=tensor._fp8_dtype,
    )
    for splitted_tensor_data in zip(*out_data)
]
logic: AttributeError when _rowwise_data is None - splitted_tensor_data[0] will be None when only columnwise usage is enabled (lines 383-394 append None when data is None), causing .size() call to fail on line 423
Suggested change (only the shape argument differs from the quoted block above):
-        shape=splitted_tensor_data[0].size(),
+        shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),
11 files reviewed, 1 comment
            splitted_tensor_data[0].size()
            if splitted_tensor_data[0] is not None
            else None
logic: AttributeError when only columnwise_usage is enabled - splitted_tensor_data[0] (rowwise_data) will be None when tensor._rowwise_data is None (lines 383-394 append None when data is None), causing .size() to fail
Suggested change:
-            splitted_tensor_data[0].size()
-            if splitted_tensor_data[0] is not None
-            else None
+        shape=(
+            splitted_tensor_data[0].size()
+            if splitted_tensor_data[0] is not None
+            else splitted_tensor_data[1].size()
+        ),
timmoon10 left a comment
LGTM
11 files reviewed, no comments
Description
Motivation:
What does this PR do?
Type of change
Changes
FSDP All-Gather Hooks for FP8/MXFP8: Adds `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods for FP8/MXFP8 tensors, since all-gather is only supported for native torch tensors with uint8/fp16/bf16/fp32 data types. `fsdp_pre_all_gather` returns the uint8 sharded tensors to be all-gathered together with the metadata needed to reconstruct the FP8/MXFP8 tensor afterwards; `fsdp_post_all_gather` reconstructs the Float8/MXFP8 tensor from the all-gathered uint8 data.
FP8/MXFP8 Torch Dispatch Functions for FSDP2: Handles ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). Note: scale-inv padding is also handled for MXFP8 pre and post all-gather. (An illustrative `__torch_dispatch__` sketch appears after this list.)
Quantized Tensor Class Issues:
Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers
Resetting Parameters for Deferred Initialization (meta device)
Test and Miscellaneous issues
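The torch-dispatch item above refers to handlers like the following hypothetical sketch (not TE's implementation): a wrapper subclass intercepts `aten.split.Tensor`, splits the raw uint8 payload, and rewraps each chunk; scale-inverse and transpose handling is deliberately omitted.

```python
import torch
from torch.utils._pytree import tree_map


class Uint8BackedTensor(torch.Tensor):
    """Toy quantized wrapper: logical shape/dtype outside, uint8 payload inside."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, orig_dtype: torch.dtype):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.split.Tensor:
            tensor, split_size = args[0], args[1]
            dim = kwargs.get("dim", args[2] if len(args) > 2 else 0)
            # Split the raw payload and rewrap each chunk as a quantized shard.
            chunks = tensor._data.split(split_size, dim=dim)
            return [Uint8BackedTensor(chunk, tensor.dtype) for chunk in chunks]
        # Fallback for every other op: unwrap to the raw payload (a real
        # implementation would dequantize or add more handlers instead).
        unwrapped_args = tree_map(
            lambda t: t._data if isinstance(t, Uint8BackedTensor) else t, args
        )
        return func(*unwrapped_args, **kwargs)
```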
Checklist: