Add support for MXFP8 PTQ #736
base: main
Conversation
16f12fa to 88b6869
Could you also add the corresponding unit tests for impacted functions in quant_utils.py here? Thanks!
# Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
scale_factor = torch.exp2(127 - e8m0_scale.float())
# NOTE: vLLM/flashinfer may require this behavior:
is this required? Should we assert e8m0_scale != 0?
AFAIU, it doesn't align with the MXFP8 specification.
But one of my teammates said that it worked for him in a certain case.
So I wanted to leave some documentation for it for future reference.
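For context, a small standalone version of the conversion under discussion (the function name is illustrative; per the E8M0 convention used elsewhere in this PR, the byte 255 encodes NaN and 0 encodes the minimum scale 2^-127, so a zero byte is a legal, if extreme, value):

```python
import torch

def e8m0_to_quant_multiplier(e8m0_scale: torch.Tensor) -> torch.Tensor:
    """The stored byte E encodes the block scale 2^(E - 127); quantization divides by it,
    i.e. multiplies by 2^(127 - E). E == 0 is simply the smallest scale, 2^-127."""
    return torch.exp2(127 - e8m0_scale.float())

# Example: a biased exponent of 119 means a block scale of 2^-8,
# so the quantization multiplier is 2^(127 - 119) = 256.
print(e8m0_to_quant_multiplier(torch.tensor([119], dtype=torch.uint8)))  # tensor([256.])
```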
# sm89
PTQCommand(quant="fp8", min_sm=89),
PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89),
# sm100
PTQCommand(quant="mxfp8", min_sm=100),
does hopper support mxfp8?
Blackwell has hardware acceleration for MXFP8.
Hopper does not.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling
NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: MXFP8.
See what we have for NVFP4 (line below the "mxfp8"):
PTQCommand(quant="nvfp4", min_sm=100),
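For reference, a rough runtime equivalent of the min_sm=100 gating (illustrative only; the test harness handles this through PTQCommand's min_sm argument):

```python
import torch

def has_mxfp8_hw_acceleration() -> bool:
    """Blackwell reports compute capability 10.x; Hopper (9.x) and earlier do not accelerate MXFP8."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 10
```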
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##            main     #736      +/-   ##
==========================================
- Coverage   74.23%   74.02%   -0.21%
==========================================
  Files         192      193       +1
  Lines       19033    19113      +80
==========================================
+ Hits        14129    14149      +20
- Misses       4904     4964      +60
View full report in Codecov by Sentry.
meenchen left a comment
LGTM
assert dequant_tensor.shape == input_shape, (
    f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}"
)
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), (
We can also compare with the fake quant here.
As far as I understand, fake quant is tested by test_qtensor_accuracy (part of class TestQTensor), in the code below the comment # compare with fake quant as well.
I added a test case for MXFP8 to test_qtensor_accuracy and checked that it passes using this command:
pytest --maxfail 1 tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_qtensor_accuracy"
All the new MXFP8 tests also pass, using this command:
pytest tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_mxfp8"
Signed-off-by: Daniel Serebrenik <[email protected]>
9b0c088 to a764b32
📝 Walkthrough
This change introduces support for the MXFP8 quantization format across the codebase. A new MXFP8QTensor class implements block-based FP8 E4M3 quantization with E8M0 shared scales. MXFP8 support is integrated into configuration, quantization utilities, export workflows, and test coverage.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
In modelopt/torch/quantization/qtensor/mxfp8_tensor.py:
- Around lines 93-138: get_weights_scaling_factor_from_quantizer currently assumes 2D weights and computes expected_shape = (out_dim, in_dim // BLOCK_SIZE), which breaks for 3D MoE weights (num_experts, out_dim, in_dim) because reduce_block_amax yields a 3D scale. Update the method to detect the MoE case by checking weight.dim() == 3 and, in that case, set expected_shape = (num_experts, out_dim, in_dim // cls.BLOCK_SIZE) (or mirror the NVFP4 transpose-guard behavior before calling this method). Then, after pulling weight_quantizer._scale, ensure scale.shape exactly equals expected_shape (allowing a reshape only when the numel matches) and raise/assert with a clear message if it does not. Reference symbols: get_weights_scaling_factor_from_quantizer, get_weights_scaling_factor, cls.BLOCK_SIZE, cls.SCALE_DTYPE, and weight_quantizer._scale.
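A minimal sketch of the shape handling this suggests, with hypothetical helper names (_expected_scale_shape, _validate_scale); the actual fix in the PR may look different:

```python
import math
import torch

def _expected_scale_shape(weight: torch.Tensor, block_size: int) -> tuple[int, ...]:
    """Expected E8M0 scale shape for dense (2D) and MoE (3D) weights."""
    if weight.dim() == 3:  # MoE: (num_experts, out_dim, in_dim)
        num_experts, out_dim, in_dim = weight.shape
        return (num_experts, out_dim, in_dim // block_size)
    out_dim, in_dim = weight.shape  # dense: (out_dim, in_dim)
    return (out_dim, in_dim // block_size)

def _validate_scale(scale: torch.Tensor, expected_shape: tuple[int, ...]) -> torch.Tensor:
    """Reshape only when the element count matches; otherwise fail loudly."""
    if tuple(scale.shape) != expected_shape:
        assert scale.numel() == math.prod(expected_shape), (
            f"Unexpected MXFP8 scale shape {tuple(scale.shape)}, expected {expected_shape}"
        )
        scale = scale.reshape(expected_shape)
    return scale
```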
🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)
301-309: Minor redundancy: weight is already available from line 250. The MXFP8 export logic is correct. However, weight is already fetched at line 250 via getattr(sub_module, weight_name), so line 303 re-fetches the same value unnecessarily.
♻️ Optional: reuse the existing weight variable:
    elif quantization_format == QUANTIZATION_MXFP8:
        # MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
-       weight = getattr(sub_module, weight_name)
        e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
            weight, weight_quantizer
        )
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
examples/llm_ptq/hf_ptq.py
examples/llm_ptq/scripts/huggingface_example.sh
modelopt/torch/export/model_config.py
modelopt/torch/export/quant_utils.py
modelopt/torch/export/unified_export_hf.py
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
modelopt/torch/quantization/qtensor/__init__.py
modelopt/torch/quantization/qtensor/mxfp8_tensor.py
tests/examples/llm_ptq/test_llm_ptq.py
tests/gpu/torch/quantization/test_qtensor_cuda.py
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2): MXFP8QTensor (26-269), quantize (195-222)
modelopt/torch/export/unified_export_hf.py (1)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2): MXFP8QTensor (26-269), get_weights_scaling_factor_from_quantizer (94-138)
tests/examples/llm_ptq/test_llm_ptq.py (2)
tests/_test_utils/examples/llm_ptq_utils.py (1): PTQCommand (28-87)
tests/_test_utils/torch/quantization/quant_utils.py (1): quant (19-30)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (3): MXFP8QTensor (26-269), get_weights_scaling_factor_from_quantizer (94-138), quantize_with_scale (141-192)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2): block_sizes (360-362), block_sizes (365-367)
🔇 Additional comments (32)
examples/llm_ptq/scripts/huggingface_example.sh (2)
56-58: LGTM! The mxfp8 format is correctly added to the valid quantization formats list, and the error message is updated consistently.
207-210: Verify whether mxfp8 should be added to the supported formats for the TRT-LLM torch runtime. The mxfp8 format is not included in the check on line 207, so MXFP8-quantized models will exit early with a message to use TensorRT-LLM for deployment. If MXFP8 should support the same workflow as fp8 and nvfp4 (continuing to run_tensorrt_llm.py), consider adding it:
- if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
+ if [[ ! " fp8 nvfp4 bf16 fp16 mxfp8 " =~ " ${QFORMAT} " ]]; then
modelopt/torch/export/model_config.py (1)
38-38: LGTM! The new QUANTIZATION_MXFP8 constant follows the established naming convention and is correctly placed among related quantization format identifiers.
tests/examples/llm_ptq/test_llm_ptq.py (1)
117-117: LGTM! The MXFP8 test case is correctly configured with min_sm=100 to ensure it only runs on Blackwell GPUs, which have hardware acceleration for MXFP8.
modelopt/torch/quantization/qtensor/__init__.py (1)
23-23: LGTM! The mxfp8_tensor module export follows the established pattern and is correctly positioned alphabetically among the other tensor module imports.
examples/llm_ptq/hf_ptq.py (3)
175-191: LGTM! The mxfp8 format is correctly added to the auto-quantize validation list, enabling MXFP8 as a valid format option for automatic per-layer quantization search.
759-774: LGTM! The mxfp8 format is correctly added to the mono-quantize validation list for the HF export path.
86-86: LGTM! The mxfp8 format is correctly mapped to mtq.MXFP8_DEFAULT_CFG in the quantization configuration choices dictionary. The constant is properly defined and exported in the mtq module.
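For illustration, a minimal sketch of using that config directly with the modelopt API (assuming a loaded Hugging Face model and a calibration forward_loop; hf_ptq.py wires this up with many more options):

```python
import modelopt.torch.quantization as mtq

def quantize_to_mxfp8(model, forward_loop):
    """Apply the MXFP8 default config; forward_loop(model) should run calibration batches."""
    return mtq.quantize(model, mtq.MXFP8_DEFAULT_CFG, forward_loop=forward_loop)
```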
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
52-52: LGTM! The import addition for MXFP8QTensor is correctly placed alongside other quantized tensor imports.
693-703: LGTM! The MXFP8 branch correctly:
- Validates that the block size matches the MXFP8 spec (32)
- Uses MXFP8QTensor.quantize(), which handles block-based quantization internally
- Stores scales in the same manner as MXFP4
The distinction between the MXFP4 (2, 1) and MXFP8 (4, 3) num_bits is clear and properly separated.
tests/gpu/torch/quantization/test_qtensor_cuda.py (9)
18-18: LGTM! The imports for math and MXFP8QTensor are correctly added to support the new MXFP8 tests. Also applies to: 27-27.
253-260: LGTM! The MXFP8 test case is correctly added to test_qtensor_accuracy with an appropriate configuration matching the MXFP8 spec (block size 32, dynamic type, scale_bits (8, 0)).
616-676: LGTM! Comprehensive test for MXFP8 quantize/dequantize covering:
- Multiple devices (cuda, cpu)
- Multiple dtypes (float32, float16, bfloat16)
- Various shapes including 3D MoE-like tensors
- Padding scenarios (dimensions not divisible by 32)
- Proper assertions for scale dtype, shapes, and quantized data format
The tolerance of rtol=5e-2, atol=5e-2 is reasonable for FP8 quantization precision.
678-716: LGTM! Excellent test for verifying E8M0 scale computation with known input values. The test validates that per-block max values are preserved through the quantize-dequantize cycle.
718-751: LGTM! Good boundary-value testing for FP8 E4M3 limits (max 448, powers of 2, positive/negative values). The # fmt: off/on markers appropriately preserve the readable tensor formatting.
753-782: LGTM! The memory usage test follows the same pattern as the existing NVFP4 test. The 3x threshold is reasonable given that MXFP8 stores FP8 data plus uint8 scales.
784-806: LGTM! Tests for get_weights_scaling_factor with proper shape and dtype validation. The check for E8M0 values ≤ 254 correctly excludes the NaN representation (255).
808-824: LGTM! Good coverage of edge cases for _compute_e8m0_exponent:
- Zero amax → minimum exponent (-127)
- E4M3_MAX (448) → exponent 0
- Normal value (1.0) → computed exponent
- Very large/small values → clamped to valid range
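As a quick worked check of the "normal value" case, using the ceil(log2(amax / E4M3_MAX)) formula from this PR: for amax = 1.0 the unbiased exponent is ceil(log2(1/448)) = ceil(-8.81) = -8, stored as the biased byte -8 + 127 = 119; quantization then multiplies by 2^(127 - 119) = 256 and dequantization by 2^(119 - 127) = 1/256.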
826-889: LGTM! Comprehensive error-handling tests covering:
- 1D tensor assertions
- Non-divisible dimensions
- Wrong scale dtype
- Empty tensor
- 0D tensor (scalar)
- Non-floating point input
- Missing scale in dequantize
This ensures robust input validation.
modelopt/torch/export/unified_export_hf.py (1)
35-35: LGTM! The imports for MXFP8QTensor and QUANTIZATION_MXFP8 are correctly added to support MXFP8 export handling. Also applies to: 54-54.
modelopt/torch/export/quant_utils.py (5)
33-33: LGTM! The imports for MXFP8QTensor and QUANTIZATION_MXFP8 are correctly added. Also applies to: 58-58.
296-297: LGTM! The MXFP8 weight scaling factor retrieval correctly delegates to MXFP8QTensor.get_weights_scaling_factor_from_quantizer, which handles both extracting existing scales and computing new ones.
482-489: LGTM! The MXFP8 detection logic correctly identifies the format by checking that:
- block_sizes is a dict
- type is "dynamic"
- scale_bits is (8, 0) (E8M0 format)
This is properly positioned before the FP8_PB_WO/FP8_PB_REAL checks at lines 490-493, ensuring MXFP8 is correctly distinguished from other FP8 block quantization formats.
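Roughly, the rule amounts to a check like the following (illustrative; the actual predicate in quant_utils.py may be structured differently):

```python
def _looks_like_mxfp8(block_sizes) -> bool:
    """MXFP8 detection rule described above: dynamic block quantization with E8M0 scales."""
    return (
        isinstance(block_sizes, dict)
        and block_sizes.get("type") == "dynamic"
        and tuple(block_sizes.get("scale_bits", ())) == (8, 0)
    )
```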
685-689: LGTM! The MXFP8 layer config processing correctly maps the "mxfp8" format to the "MXFP8" quant_algo with the appropriate group_size, following the same pattern as other quantization formats.
794-795: LGTM! The to_quantized_weight function correctly uses MXFP8QTensor.quantize_with_scale to apply the pre-computed E8M0 scale to the weight tensor.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (7)
1-23: LGTM! Clean module structure with a proper license, docstring, imports from existing utilities (reduce_block_amax, reduce_block_padding), and an explicit __all__ export.
26-40: LGTM! Class constants are correctly defined:
- E4M3_MAX = 448.0 matches the FP8 E4M3 max value
- BLOCK_SIZE = 32 per the MXFP8 specification
- SCALE_DTYPE = torch.uint8 for E8M0 biased exponent storage
42-66: LGTM! The _compute_e8m0_exponent implementation:
- Converts to float32 for numerical stability
- Handles zero values by using torch.where with a min_value fallback
- Correctly computes ceil(log2(amax / E4M3_MAX))
- Clamps to the valid E8M0 range [-127, 127]
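A standalone sketch consistent with the behavior listed above (zero amax maps to the minimum exponent, results clamped to [-127, 127]); the PR's classmethod may differ in details:

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude in FP8 E4M3

def compute_e8m0_exponent(amax: torch.Tensor) -> torch.Tensor:
    """Smallest power-of-two exponent e with amax / 2^e <= E4M3_MAX, clamped to [-127, 127]."""
    amax = amax.float()  # float32 for numerical stability
    safe = torch.where(amax > 0, amax, torch.ones_like(amax))  # avoid log2(0)
    exponent = torch.ceil(torch.log2(safe / E4M3_MAX))
    exponent = torch.where(amax > 0, exponent, torch.full_like(exponent, -127.0))
    return exponent.clamp(-127.0, 127.0)
```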
68-91: LGTM! The get_weights_scaling_factor implementation correctly:
- Validates a minimum of 2 dimensions
- Validates divisibility by BLOCK_SIZE
- Uses the existing reduce_block_amax utility
- Converts to the biased uint8 format (exponent + 127)
140-192: LGTM! The quantize_with_scale implementation is well-structured:
- Proper input validation for dimensions and dtype
- Flexible scale reshaping to handle different input shapes
- Correct scale factor computation: 2^(127 - exponent)
- Proper clamping to the E4M3 range before FP8 conversion
- The NOTE comment documents a potential vLLM/flashinfer compatibility consideration
194-222: LGTM! The quantize method correctly implements the full quantization flow:
- Input validation for empty tensors, dimensions, and dtype
- Padding alignment via reduce_block_padding
- Per-block amax computation
- E8M0 exponent computation and biasing
- Shape restoration via cropping
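For reference, a compact sketch of this flow for the simple case of a 2D tensor whose last dimension is already a multiple of 32 (the PR's quantize classmethod additionally handles padding, higher-rank tensors, and shape restoration):

```python
import torch

BLOCK_SIZE = 32
E4M3_MAX = 448.0

def mxfp8_quantize_rowwise(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to FP8 E4M3 with one biased E8M0 exponent per 32-value block."""
    rows, cols = x.shape
    blocks = x.float().reshape(rows, cols // BLOCK_SIZE, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1)                                   # per-block amax
    exponent = torch.ceil(torch.log2(amax.clamp(min=2**-127) / E4M3_MAX)).clamp(-127, 127)
    e8m0 = (exponent + 127).to(torch.uint8)                            # biased E8M0 scale
    scaled = blocks * torch.exp2(-exponent).unsqueeze(-1)              # divide by 2^exponent
    q = scaled.clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)      # requires PyTorch >= 2.1
    return q.reshape(rows, cols), e8m0
```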
224-269: LGTM! The dequantize method correctly reverses the quantization:
- Requires the scale in kwargs (enforced by assertion)
- Converts quantized data to float for computation
- Applies padding for block alignment
- Computes the descale as 2^(exponent - 127)
- Handles scale shape broadcasting
- Restores the original shape via cropping
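And the matching inverse for the same simplified case, using the 2^(exponent - 127) descale described above:

```python
import torch

BLOCK_SIZE = 32

def mxfp8_dequantize_rowwise(q: torch.Tensor, e8m0: torch.Tensor) -> torch.Tensor:
    """Undo mxfp8_quantize_rowwise: broadcast 2^(exponent - 127) over each 32-value block."""
    rows, cols = q.shape
    blocks = q.float().reshape(rows, cols // BLOCK_SIZE, BLOCK_SIZE)
    descale = torch.exp2(e8m0.float() - 127.0).unsqueeze(-1)
    return (blocks * descale).reshape(rows, cols)
```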
…nsor Signed-off-by: Daniel Serebrenik <[email protected]>
Tested by test_qtensor_accuracy. Signed-off-by: Daniel Serebrenik <[email protected]>
What does this PR do?
Type of change: new feature
Overview: Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.
Usage
The hf_quant_config.json of the output checkpoint:
{
  "producer": {
    "name": "modelopt",
    "version": "0.41.0.dev50+g7a796a875"
  },
  "quantization": {
    "quant_algo": "MXFP8",
    "kv_cache_quant_algo": "FP8",
    "group_size": 32,
    "exclude_modules": [
      "lm_head"
    ]
  }
}
config.json(only thequantization_config):Testing
Used hf_ptq.py to quantize the model nvidia/OpenMath2-Llama3.1-8B (available on Hugging Face); see the example command above.
Checked that the generated MXFP8 checkpoint can be loaded with vLLM (required changes in vLLM, not merged to main).
Added tests for MXFP8QTensor in tests/gpu/torch/quantization/test_qtensor_cuda.py.
Added "mxfp8" in tests/examples/llm_ptq/test_llm_ptq.py.
Support for Nemotron Models
Verify that Nemotron Nano V3 BF16 can be converted to MXFP8 using hf_ptq.py: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests