Add support for MXFP8 PTQ #736
base: main
Conversation
Force-pushed from 7454f24 to 16f12fa, then from 16f12fa to 88b6869 (commits signed off by Daniel Serebrenik).
Could you also add the corresponding unit tests for impacted functions in quant_utils.py here? Thanks!
# Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
scale_factor = torch.exp2(127 - e8m0_scale.float())

# NOTE: vLLM/flashinfer may require this behavior:
is this required? Should we assert e8m0_scale != 0?
AFAIU, it doesn't align with the MXFP8 specification. But one of my teammates said that it worked for him in a certain case, so I wanted to leave some documentation here for future reference.
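For reference, a minimal Python/PyTorch sketch of the conversion being discussed, with the reviewer's proposed zero check included as an optional guard; e8m0_to_scale and reject_zero are illustrative names, not functions from this PR:

import torch

def e8m0_to_scale(e8m0_scale: torch.Tensor, reject_zero: bool = False) -> torch.Tensor:
    # Mirrors the line under review: scale_factor = 2^(127 - exponent).
    if reject_zero:
        # Optional guard corresponding to the "assert e8m0_scale != 0" question above.
        assert torch.all(e8m0_scale != 0), "unexpected zero E8M0 scale"
    return torch.exp2(127 - e8m0_scale.float())

# A biased exponent of 127 maps to a scale factor of 1.0.
print(e8m0_to_scale(torch.tensor([127], dtype=torch.uint8)))  # tensor([1.])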
# sm89
PTQCommand(quant="fp8", min_sm=89),
PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89),
# sm100
PTQCommand(quant="mxfp8", min_sm=100),
does hopper support mxfp8?
Blackwell has hardware acceleration for MXFP8.
Hopper does not.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling
"NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: MXFP8."
See what we have for NVFP4 (the line below the "mxfp8" one):
PTQCommand(quant="nvfp4", min_sm=100),
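As context for the min_sm=100 gating above, here is a standalone sketch of a compute-capability check (illustrative only; min_sm_satisfied is not the helper PTQCommand actually uses):

import torch

def min_sm_satisfied(min_sm: int) -> bool:
    # e.g. min_sm=100 selects Blackwell (SM100); Hopper is SM90 and would be skipped.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= min_sm

run_mxfp8_case = min_sm_satisfied(100)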
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
##             main     #736      +/-   ##
==========================================
- Coverage   74.69%   74.42%   -0.27%
==========================================
  Files         192      193       +1
  Lines       18948    19043      +95
==========================================
+ Hits        14153    14173      +20
- Misses       4795     4870      +75
==========================================

☔ View full report in Codecov by Sentry.
meenchen left a comment:
LGTM
assert dequant_tensor.shape == input_shape, (
    f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}"
)
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), (
We can also compare with the fake quant here.
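One possible shape for that extra check (a sketch only; fp8_e4m3_fake_quant is a hypothetical pure-PyTorch reference, not an API from this repo, and it ignores the per-block scaling a full MXFP8 reference would apply):

import torch

def fp8_e4m3_fake_quant(x: torch.Tensor) -> torch.Tensor:
    # Round-trip through the E4M3 element dtype; a full MXFP8 reference would
    # divide by the block scale before the cast and multiply it back afterwards.
    return x.to(torch.float8_e4m3fn).to(x.dtype)

test_tensor = torch.tensor([0.1, 0.5, 1.0, 448.0])
fake_ref = fp8_e4m3_fake_quant(test_tensor)

# In the actual test, dequant_tensor would come from the quantize/dequantize path
# under review; here the reference stands in to show the intended assertion.
dequant_tensor = fake_ref.clone()
assert torch.allclose(dequant_tensor, fake_ref, rtol=5e-2, atol=5e-2)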
| "test_input", | ||
| [ | ||
| # FP8 E4M3 boundary test values (max is 448, various powers of 2) | ||
| torch.tensor( |
The format looks weird; we can turn off auto-formatting for the tensors and define them at the top.
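For example, with formatter directives around a module-level constant (assuming a Black-style # fmt: off / # fmt: on toggle; the values below are illustrative, not the exact ones from the test):

import torch

# fmt: off
FP8_E4M3_BOUNDARY_VALUES = torch.tensor(
    [448.0, 256.0, 128.0, 1.0, 0.5, -448.0],
)
# fmt: on

The parametrize list can then reference the named constant instead of an inline tensor literal.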
What does this PR do?
Type of change: new feature
Overview: Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.
Usage
The hf_quant_config.json of the output checkpoint:

{
  "producer": {
    "name": "modelopt",
    "version": "0.41.0.dev50+g7a796a875"
  },
  "quantization": {
    "quant_algo": "MXFP8",
    "kv_cache_quant_algo": "FP8",
    "group_size": 32,
    "exclude_modules": [
      "lm_head"
    ]
  }
}
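For illustration, a minimal sketch of reading this quantization section from the exported checkpoint (not how vLLM or ModelOpt actually consume it; the file path is assumed):

import json

with open("hf_quant_config.json") as f:
    quant_cfg = json.load(f)["quantization"]

# For the checkpoint above this prints: MXFP8 FP8 32
print(quant_cfg["quant_algo"], quant_cfg["kv_cache_quant_algo"], quant_cfg["group_size"])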
And config.json (only the quantization_config):

Testing
Used hf_ptq.py to quantize the model nvidia/OpenMath2-Llama3.1-8B (available on Hugging Face); see the example command above.
Checked that the generated MXFP8 checkpoint can be loaded with vLLM (required changes in vLLM, not merged to main).
Added tests for MXFP8QTensor in tests/gpu/torch/quantization/test_qtensor_cuda.py.
Added "mxfp8" in tests/examples/llm_ptq/test_llm_ptq.py.

Before your PR is "Ready for review"
Additional Information