3 changes: 3 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -83,6 +83,7 @@
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -184,6 +185,7 @@ def auto_quantize(
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
]
for args.qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -766,6 +768,7 @@ def quantize_main(
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Plain quantization format {args.qformat} not supported for HF export path"
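To show where the new entry plugs in, here is a minimal sketch of driving this path from Python rather than through the example script. The model name and calibration loop are placeholders, and the `mtq.quantize`/`export_hf_checkpoint` calls follow the usual `examples/llm_ptq` flow; treat it as an illustration, not the script's exact code.

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # placeholder; any HF causal LM the example supports
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(m):
    # Stand-in calibration pass; hf_ptq.py runs a proper calibration dataset here.
    batch = tokenizer("calibration text", return_tensors="pt").to(m.device)
    m(**batch)

# "mxfp8" now resolves to mtq.MXFP8_DEFAULT_CFG in the qformat-to-config mapping above.
model = mtq.quantize(model, mtq.MXFP8_DEFAULT_CFG, forward_loop)
export_hf_checkpoint(model, export_dir="llama-3.1-8b-mxfp8")
```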
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | mxfp8) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, mxfp8]" >&2
exit 1
;;
esac
1 change: 1 addition & 0 deletions modelopt/torch/export/model_config.py
@@ -35,6 +35,7 @@
QUANTIZATION_NVFP4 = "nvfp4"
QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_MXFP8 = "mxfp8"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
21 changes: 21 additions & 0 deletions modelopt/torch/export/quant_utils.py
@@ -30,6 +30,7 @@
from modelopt.torch.quantization.qtensor import (
FP8QTensor,
MXFP4QTensor,
MXFP8QTensor,
NVFP4QTensor,
QTensorWrapper,
)
@@ -54,6 +55,7 @@
QUANTIZATION_INT8_SQ,
QUANTIZATION_INT8_WO,
QUANTIZATION_MXFP4,
QUANTIZATION_MXFP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
@@ -290,6 +292,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
1
].reshape(*weight.shape[:-1], -1)

if quantization_format == QUANTIZATION_MXFP8:
return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)
return get_scaling_factor(weight_quantizer)
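`MXFP8QTensor.get_weights_scaling_factor_from_quantizer` is introduced elsewhere in this change, so its exact behavior is not visible here. Conceptually, an MXFP8 weight scaling factor is one E8M0 value (a uint8-encoded power-of-two exponent) per 32-element block of the last weight dimension; the following is a rough reference computation under that assumption, not the function's actual implementation.

```python
import torch

E4M3_MAX_EXP = 8   # exponent of the largest E4M3 magnitude (448 = 1.75 * 2**8)
E8M0_BIAS = 127    # E8M0 stores a biased power-of-two exponent in a uint8
BLOCK_SIZE = 32    # MX formats share one scale per 32 consecutive elements

def mxfp8_e8m0_scales(weight: torch.Tensor) -> torch.Tensor:
    """One biased scale exponent per block; assumes weight.shape[-1] % 32 == 0."""
    blocks = weight.reshape(*weight.shape[:-1], -1, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1).float().clamp_min(torch.finfo(torch.float32).tiny)
    # Shared exponent per the OCP MX convention: floor(log2(amax)) minus the E4M3 max exponent.
    exp = torch.floor(torch.log2(amax)) - E4M3_MAX_EXP
    return (exp + E8M0_BIAS).clamp(0, 255).to(torch.uint8)
```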


@@ -474,6 +479,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
if weight_quantizer.num_bits == (4, 3):
if weight_quantizer.block_sizes:
assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer"
# Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0))
block_sizes = getattr(weight_quantizer, "block_sizes")
if (
isinstance(block_sizes, dict)
and block_sizes.get("type", "static") == "dynamic"
and block_sizes.get("scale_bits") == (8, 0)
):
return QUANTIZATION_MXFP8
if weight_quantizer.fake_quant:
return QUANTIZATION_FP8_PB_WO
else:
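For reference, the quantizer state this branch matches looks roughly as follows; block size 32 is the MX convention, and the exact dict shipped by `MXFP8_DEFAULT_CFG` is an assumption here rather than something shown in this diff.

```python
# A weight quantizer configured for MXFP8 has num_bits == (4, 3) plus a dynamic
# block config whose per-block scale is E8M0, i.e. scale_bits == (8, 0).
block_sizes = {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}

is_mxfp8 = (
    isinstance(block_sizes, dict)
    and block_sizes.get("type", "static") == "dynamic"
    and block_sizes.get("scale_bits") == (8, 0)
)
assert is_mxfp8
```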
@@ -669,6 +682,11 @@ def process_layer_quant_config(layer_config_dict):
"quant_algo": "W4A8_MXFP4_FP8",
"group_size": block_size_value,
}
elif v == "mxfp8":
layer_config = {
"quant_algo": "MXFP8",
"group_size": block_size_value,
}
else:
layer_config = {"quant_algo": v}
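With the 32-element MX block size, the per-layer entry produced by this branch would end up looking roughly like the following; the layer name is hypothetical and the surrounding structure of the exported config is assumed.

```python
quant_config_per_layer = {
    "model.layers.0.mlp.down_proj": {"quant_algo": "MXFP8", "group_size": 32},
}
```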

@@ -773,6 +791,9 @@ def to_quantized_weight(
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

if quantization == QUANTIZATION_MXFP8:
return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor)

if quantization == QUANTIZATION_FP8_PB_WO:
return FP8QTensor.quantize(
weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size}
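`MXFP8QTensor.quantize_with_scale` is also new in this change. Assuming it divides each 32-element block by its decoded E8M0 scale and casts to `torch.float8_e4m3fn`, a reference version would look like the sketch below; the real method may differ in details such as padding handling.

```python
import torch

E8M0_BIAS = 127
BLOCK_SIZE = 32
E4M3_MAX = 448.0

def quantize_with_e8m0_scale(weight: torch.Tensor, e8m0_scale: torch.Tensor) -> torch.Tensor:
    """weight: (..., N); e8m0_scale: (..., N // 32) biased uint8 exponents."""
    # Decode the power-of-two scale and broadcast it across each 32-element block.
    scale = torch.exp2(e8m0_scale.float() - E8M0_BIAS).repeat_interleave(BLOCK_SIZE, dim=-1)
    scaled = (weight.float() / scale).clamp(-E4M3_MAX, E4M3_MAX)
    return scaled.to(torch.float8_e4m3fn)
```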
12 changes: 11 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
@@ -32,7 +32,7 @@

from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor
from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

from .convert_hf_config import convert_hf_quant_config_format
@@ -51,6 +51,7 @@
QUANTIZATION_FP8,
QUANTIZATION_FP8_PB_REAL,
QUANTIZATION_FP8_PC_PT,
QUANTIZATION_MXFP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
@@ -297,6 +298,15 @@ def _export_quantized_weight(
weight_quantizer._scale.to(torch.float32),
)
del weight_quantizer._scale
elif quantization_format == QUANTIZATION_MXFP8:
# MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
weight = getattr(sub_module, weight_name)
e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
weight, weight_quantizer
)
sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale)
if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
del weight_quantizer._scale
else:
sub_module.register_buffer(
quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)
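After this branch runs, an MXFP8 layer in the exported checkpoint should hold an FP8 weight plus a uint8 E8M0 scale buffer. Below is a sketch of reading one such layer back and dequantizing it; the tensor names follow `quantizer_attrs.weight_scale` above, but the exact serialized keys and shard name are assumptions.

```python
import torch
from safetensors.torch import load_file

state = load_file("model.safetensors")  # hypothetical shard name
w = state["model.layers.0.mlp.down_proj.weight"]        # torch.float8_e4m3fn, shape (out, in)
s = state["model.layers.0.mlp.down_proj.weight_scale"]  # torch.uint8, shape (out, in // 32)

# Expand each per-block E8M0 exponent back over its 32 elements and dequantize.
scale = torch.exp2(s.float() - 127).repeat_interleave(32, dim=-1)
w_dequant = w.float() * scale
```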
41 changes: 27 additions & 14 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -49,6 +49,7 @@
INT4QTensor,
INT8QTensor,
MXFP4QTensor,
MXFP8QTensor,
NF4QTensor,
NVFP4QTensor,
QTensorWrapper,
@@ -649,8 +650,32 @@ def _real_quantize(self, inputs):
assert self._is_real_quantize_support(), "Real quantization not supported for this format."

buffer_to_register = {}
if self._num_bits == (4, 3):
# FP8 quantization
# Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3)
if (
self._block_sizes
and self._block_sizes.get("scale_bits") == (8, 0)
and self._block_sizes.get("type") == "dynamic"
):
# MX quantization (MXFP4/MXFP8)
if self._num_bits == (2, 1):
# MXFP4
outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
buffer_to_register["_scale"] = scales
elif self._num_bits == (4, 3):
# MXFP8
assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
f"got {self._block_sizes[-1]}"
)
outputs, scales = MXFP8QTensor.quantize(inputs)
buffer_to_register["_scale"] = scales
else:
raise ValueError(
f"Unsupported MX format: num_bits={self._num_bits}. "
f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8."
)
elif self._num_bits == (4, 3):
# FP8 quantization (non-MX)
# For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
# For blockwise quantization, amax will be recomputed in the kernel
use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
@@ -683,18 +708,6 @@
buffer_to_register["_scale"] = _scale
buffer_to_register["_double_scale"] = _double_scale
buffer_to_register["_scale_zeros"] = _scale_zeros
elif (
self._block_sizes.get("scale_bits") == (8, 0)
and self._block_sizes.get("type") == "dynamic"
):
# MX quantization
if self._num_bits == (2, 1):
outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
buffer_to_register["_scale"] = scales
else:
raise ValueError(
f"Real quantization for MX {self._num_bits} format is not supported."
)
elif self._block_sizes.get("scale_bits") == (4, 3):
# NVFP4 default quantization
# Return real quantized tensor and store scales inside TensorQuantizer
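Because MXFP8 shares `num_bits == (4, 3)` with plain FP8, the MX check has to run before the FP8 branch, as the comment in the new code notes. The self-contained snippet below sketches the fake-quant math the MXFP8 path ultimately performs; block size 32 and the E8M0/E4M3 constants are the usual MX-spec assumptions, not code from this PR.

```python
import torch

BLOCK, E8M0_BIAS, E4M3_MAX_EXP, E4M3_MAX = 32, 127, 8, 448.0

def mxfp8_fake_quant(x: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize x with one power-of-two scale per 32-element block."""
    blocks = x.reshape(*x.shape[:-1], -1, BLOCK)
    amax = blocks.abs().amax(dim=-1, keepdim=True).float().clamp_min(2.0**-126)
    exp = torch.floor(torch.log2(amax)) - E4M3_MAX_EXP      # shared E8M0 exponent
    scale = torch.exp2(exp.clamp(-E8M0_BIAS, 255 - E8M0_BIAS))
    q = (blocks / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return (q.float() * scale).reshape_as(x)

x = torch.randn(4, 128)
rel_err = ((x - mxfp8_fake_quant(x)).abs().max() / x.abs().max()).item()
print(f"max relative error: {rel_err:.3f}")  # small, dominated by the 3-bit E4M3 mantissa
```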
1 change: 1 addition & 0 deletions modelopt/torch/quantization/qtensor/__init__.py
@@ -20,5 +20,6 @@
from .int4_tensor import *
from .int8_tensor import *
from .mxfp4_tensor import *
from .mxfp8_tensor import *
from .nf4_tensor import *
from .nvfp4_tensor import *