diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742..1dba72324 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -83,6 +83,7 @@ "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, + "mxfp8": mtq.MXFP8_DEFAULT_CFG, } KV_QUANT_CFG_CHOICES = { @@ -184,6 +185,7 @@ def auto_quantize( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" @@ -766,6 +768,7 @@ def quantize_main( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Plain quantization format {args.qformat} not supported for HF export path" diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 043b690e5..ec415b2f7 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | mxfp8) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, mxfp8]" >&2 exit 1 ;; esac diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index 306348f2c..9553b4fcf 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -35,6 +35,7 @@ QUANTIZATION_NVFP4 = "nvfp4" QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8" QUANTIZATION_MXFP4 = "mxfp4" +QUANTIZATION_MXFP8 = "mxfp8" QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8" QUANTIZATION_NVFP4_AWQ = "nvfp4_awq" QUANTIZATION_FP8_PB_REAL = "fp8_pb_real" diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc51..87b1018c3 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -30,6 +30,7 @@ from modelopt.torch.quantization.qtensor import ( FP8QTensor, MXFP4QTensor, + MXFP8QTensor, NVFP4QTensor, QTensorWrapper, ) @@ -54,6 +55,7 @@ QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO, QUANTIZATION_MXFP4, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -290,6 +292,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[ 1 ].reshape(*weight.shape[:-1], -1) + + if quantization_format == QUANTIZATION_MXFP8: + return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer) return get_scaling_factor(weight_quantizer) @@ -474,6 +479,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if weight_quantizer.num_bits == (4, 3): if weight_quantizer.block_sizes: assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer" + # Check if this is MXFP8 
(dynamic block quantization with scale_bits (8, 0)) + block_sizes = getattr(weight_quantizer, "block_sizes") + if ( + isinstance(block_sizes, dict) + and block_sizes.get("type", "static") == "dynamic" + and block_sizes.get("scale_bits") == (8, 0) + ): + return QUANTIZATION_MXFP8 if weight_quantizer.fake_quant: return QUANTIZATION_FP8_PB_WO else: @@ -669,6 +682,11 @@ def process_layer_quant_config(layer_config_dict): "quant_algo": "W4A8_MXFP4_FP8", "group_size": block_size_value, } + elif v == "mxfp8": + layer_config = { + "quant_algo": "MXFP8", + "group_size": block_size_value, + } else: layer_config = {"quant_algo": v} @@ -773,6 +791,9 @@ def to_quantized_weight( if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]: return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8) + if quantization == QUANTIZATION_MXFP8: + return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor) + if quantization == QUANTIZATION_FP8_PB_WO: return FP8QTensor.quantize( weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size} diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 1dd1c1822..6296bb816 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -32,7 +32,7 @@ from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format @@ -51,6 +51,7 @@ QUANTIZATION_FP8, QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PC_PT, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -297,6 +298,15 @@ def _export_quantized_weight( weight_quantizer._scale.to(torch.float32), ) del weight_quantizer._scale + elif quantization_format == QUANTIZATION_MXFP8: + # MXFP8 uses dynamic block quantization with E8M0 scales (uint8) + weight = getattr(sub_module, weight_name) + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, weight_quantizer + ) + sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + del weight_quantizer._scale else: sub_module.register_buffer( quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 7d3fa1251..0dde20eec 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -49,6 +49,7 @@ INT4QTensor, INT8QTensor, MXFP4QTensor, + MXFP8QTensor, NF4QTensor, NVFP4QTensor, QTensorWrapper, @@ -649,8 +650,32 @@ def _real_quantize(self, inputs): assert self._is_real_quantize_support(), "Real quantization not supported for this format." 
buffer_to_register = {} - if self._num_bits == (4, 3): - # FP8 quantization + # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3) + if ( + self._block_sizes + and self._block_sizes.get("scale_bits") == (8, 0) + and self._block_sizes.get("type") == "dynamic" + ): + # MX quantization (MXFP4/MXFP8) + if self._num_bits == (2, 1): + # MXFP4 + outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) + buffer_to_register["_scale"] = scales + elif self._num_bits == (4, 3): + # MXFP8 + assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( + f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " + f"got {self._block_sizes[-1]}" + ) + outputs, scales = MXFP8QTensor.quantize(inputs) + buffer_to_register["_scale"] = scales + else: + raise ValueError( + f"Unsupported MX format: num_bits={self._num_bits}. " + f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8." + ) + elif self._num_bits == (4, 3): + # FP8 quantization (non-MX) # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks # For blockwise quantization, amax will be recomputed in the kernel use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1) @@ -683,18 +708,6 @@ def _real_quantize(self, inputs): buffer_to_register["_scale"] = _scale buffer_to_register["_double_scale"] = _double_scale buffer_to_register["_scale_zeros"] = _scale_zeros - elif ( - self._block_sizes.get("scale_bits") == (8, 0) - and self._block_sizes.get("type") == "dynamic" - ): - # MX quantization - if self._num_bits == (2, 1): - outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) - buffer_to_register["_scale"] = scales - else: - raise ValueError( - f"Real quantization for MX {self._num_bits} format is not supported." - ) elif self._block_sizes.get("scale_bits") == (4, 3): # NVFP4 default quantization # Return real quantized tensor and store scales inside TensorQuantizer diff --git a/modelopt/torch/quantization/qtensor/__init__.py b/modelopt/torch/quantization/qtensor/__init__.py index c4ed88f87..9c623c1bd 100644 --- a/modelopt/torch/quantization/qtensor/__init__.py +++ b/modelopt/torch/quantization/qtensor/__init__.py @@ -20,5 +20,6 @@ from .int4_tensor import * from .int8_tensor import * from .mxfp4_tensor import * +from .mxfp8_tensor import * from .nf4_tensor import * from .nvfp4_tensor import * diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py new file mode 100644 index 000000000..e87612f3a --- /dev/null +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implements MXFP8 quantization for efficient tensor storage and computation.""" + +import torch + +from ..qtensor.base_qtensor import BaseQuantizedTensor +from ..utils import reduce_block_amax, reduce_block_padding + +__all__ = ["MXFP8QTensor"] + + +class MXFP8QTensor(BaseQuantizedTensor): + """Implements the MXFP8 quantization on tensors for more efficient storage or computation. + + MXFP8 uses: + - FP8 E4M3 format for elements + - E8M0 format for shared scales (power-of-2 only, stored as biased uint8 exponent) + - Block size of 32 elements along the last dimension + + Attributes: + quantized_data (torch.Tensor): The quantized data stored as float8_e4m3fn tensor. + """ + + E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + BLOCK_SIZE = 32 + SCALE_DTYPE = torch.uint8 # E8M0 format stores biased exponent as uint8 + + @classmethod + def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: + """Compute E8M0 exponent from per-block amax values. + + Args: + amax: Per-block absolute max values. + + Returns: + torch.Tensor: Float tensor of E8M0 exponents (unbiased, range [-127, 127]). + """ + # Compute E8M0 scale: scale = 2^ceil(log2(amax / E4M3_max)) + descale = amax.float() / cls.E4M3_MAX + + # Handle zero/inf/nan cases + min_value = torch.tensor(-127.0, device=descale.device) + log2_descale = torch.where( + descale > 0, + torch.log2(descale), + min_value, + ) + + e8m0_exponent = torch.ceil(log2_descale) + + # Clamp exponent to valid E8M0 range + return torch.clamp(e8m0_exponent, min=-127, max=127) + + @classmethod + def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: + """Returns E8M0 scale (uint8 biased exponent) for weight tensor. + + Args: + weight: The weight tensor to compute scale for. Must be at least 2D. + Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim). + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. + For 2D input: (out_dim, in_dim // 32) + For 3D MoE input: (num_experts, out_dim, in_dim // 32) + """ + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + in_dim = weight.shape[-1] + + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" + ) + + # Compute amax per block (reduce_block_amax handles N-dimensional tensors) + amax = reduce_block_amax(weight, block_sizes={-1: cls.BLOCK_SIZE}) + + # Compute E8M0 exponent and convert to biased uint8 (bias = 127) + e8m0_exponent = cls._compute_e8m0_exponent(amax) + return (e8m0_exponent + 127).to(cls.SCALE_DTYPE) + + @classmethod + def get_weights_scaling_factor_from_quantizer( + cls, + weight: torch.Tensor, + weight_quantizer, + ) -> torch.Tensor: + """Returns E8M0 scale from quantizer or computes from weight. + + This method handles extracting the scale from a weight quantizer, + with proper format conversion and shape correction. + + Args: + weight: The weight tensor. Can be 2D (out_dim, in_dim) or + 3D for MoE (num_experts, out_dim, in_dim). + weight_quantizer: The weight quantizer with block_sizes and optional _scale. + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. 
+ """ + assert hasattr(weight_quantizer, "block_sizes"), ( + "weight_quantizer must have 'block_sizes' attribute" + ) + assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, ( + f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}" + ) + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + in_dim = weight.shape[-1] + # Expected scale shape: all dims except last, with last dim reduced by block size + # For 2D: (out_dim, in_dim // 32) + # For 3D MoE: (num_experts, out_dim, in_dim // 32) + expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE) + + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + scale = weight_quantizer._scale + + assert scale.dtype == cls.SCALE_DTYPE, ( + f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}" + ) + assert scale.shape == expected_shape, ( + f"Scale shape {scale.shape} does not match expected shape {expected_shape}" + ) + return scale + + # No scale in quantizer, compute from weight + return cls.get_weights_scaling_factor(weight) + + @classmethod + def quantize_with_scale( + cls, + weight: torch.Tensor, + e8m0_scale: torch.Tensor, + ) -> torch.Tensor: + """Quantize weight tensor using a pre-computed E8M0 scale. + + This method is useful for export paths where the scale has already been computed. + + Args: + weight: The weight tensor to quantize. Must be at least 1D. + e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127). + Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors, + or [in_dim // 32] for 1D tensors. + + Returns: + torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input. + """ + assert e8m0_scale.dtype == cls.SCALE_DTYPE, ( + f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}" + ) + + in_dim = weight.shape[-1] + num_blocks = in_dim // cls.BLOCK_SIZE + + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" + ) + + # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent) + scale_factor = torch.exp2(127 - e8m0_scale.float()) + + # NOTE: vLLM/flashinfer may require this behavior: + # scale_factor = torch.where( + # e8m0_scale == 0, + # 1.0, + # torch.exp2(127 - e8m0_scale.float()) + # ) + + weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE) + scale_factor_expanded = scale_factor.unsqueeze(-1) + scaled_weight = weight_reshaped * scale_factor_expanded + scaled_weight = torch.clamp(scaled_weight, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) + quantized_weight = scaled_weight.to(torch.float8_e4m3fn) + + return quantized_weight.view(weight.shape) + + @classmethod + def quantize(cls, input: torch.Tensor) -> tuple: + """Convert a tensor to MXFP8 quantized format. + + Args: + input (torch.Tensor): The input tensor to be quantized. + + Returns: + tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent. 
+ """ + original_shape = input.shape + original_dtype = input.dtype + + input = reduce_block_padding(input, block_sizes={-1: cls.BLOCK_SIZE}) + input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE}) + + e8m0_exponent = cls._compute_e8m0_exponent(input_amax) + e8m0_scale = (e8m0_exponent + 127).to(cls.SCALE_DTYPE) + + quantized_data = cls.quantize_with_scale(input, e8m0_scale) + + # Crop back to original shape + quantized_data = quantized_data[..., : original_shape[-1]] + + return cls(original_shape, original_dtype, quantized_data), e8m0_scale + + def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: + """Dequantize MXFP8 tensor back to the target dtype. + + Args: + dtype (torch.dtype | None): Target dtype for dequantization. Defaults to original dtype. + **kwargs: Must contain 'scale' (E8M0 biased uint8). + + Returns: + torch.Tensor: Dequantized tensor in the target dtype. + """ + assert "scale" in kwargs, "dequantize requires 'scale' in kwargs" + + e8m0_scale = kwargs["scale"] + + if dtype is None: + dtype = self.metadata["dtype"] + + original_shape = self.metadata["shape"] + quantized_data = self._quantized_data.float() + quantized_data = reduce_block_padding(quantized_data, block_sizes={-1: self.BLOCK_SIZE}) + + num_blocks = quantized_data.shape[-1] // self.BLOCK_SIZE + quantized_blocked = quantized_data.view( + *quantized_data.shape[:-1], num_blocks, self.BLOCK_SIZE + ) + + # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) + descale = torch.exp2(e8m0_scale.float() - 127) + + dequantized = quantized_blocked * descale.unsqueeze(-1) + + # Reshape and crop back to original shape + dequantized = dequantized.view(*quantized_data.shape[:-1], quantized_data.shape[-1]) + dequantized = dequantized[..., : original_shape[-1]] + + return dequantized.to(dtype) diff --git a/tests/examples/llm_ptq/test_llm_ptq.py b/tests/examples/llm_ptq/test_llm_ptq.py index 6ba23cc04..4fc39f5ec 100644 --- a/tests/examples/llm_ptq/test_llm_ptq.py +++ b/tests/examples/llm_ptq/test_llm_ptq.py @@ -114,6 +114,7 @@ def test_ptq_whisper(self, command): # sm89 PTQCommand(quant="fp8", min_sm=89), PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89), # sm100 + PTQCommand(quant="mxfp8", min_sm=100), PTQCommand(quant="nvfp4", min_sm=100), # # multi_gpu diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 26df7a8c8..269f0fa63 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -15,6 +15,8 @@ """Unit tests for quantized tensors.""" +import math + import pytest import torch from _test_utils.torch.misc import set_seed @@ -22,7 +24,7 @@ from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor set_seed() @@ -248,6 +250,14 @@ def test_amax_from_tensor_quantizer( torch.randn([512, 512], dtype=torch.float32), None, ), + # MXFP8 + ( + (4, 3), + {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + None, + torch.randn([512, 512], dtype=torch.float32), + None, + ), ], ) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -602,3 +612,300 @@ def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, b assert torch.allclose(deq_x, x, 
rtol=1e-1, atol=1e-1) assert hasattr(quantizer, "_scale") assert quantizer._scale.numel() > 1 + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "input_shape", + [ + (128, 128), + (256, 64), + (512, 512), + # 3D shapes (MoE): (num_experts, out_dim, in_dim) + (4, 64, 128), + (1, 64, 128), # single expert edge case + (32, 256, 512), # large-scale MoE + # Shapes requiring padding (last dim not divisible by block size 32) + (8, 128, 65), # odd in_dim + (128, 65), + (256, 100), + (64, 33), + ], + ) + def test_mxfp8_quantize_dequantize(self, device, input_dtype, input_shape): + """Test MXFP8 quantization and dequantization produces correct E8M0 scales.""" + # Create test tensor + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize using MXFP8QTensor + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + # Verify scale shape: last dim is ceil(in_dim / 32), other dims preserved + expected_scale_shape = ( + *input_shape[:-1], + math.ceil(input_shape[-1] / MXFP8QTensor.BLOCK_SIZE), + ) + assert e8m0_scale.shape == expected_scale_shape, ( + f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" + ) + + # Verify quantized data is FP8 E4M3 and preserves original shape + assert qtensor._quantized_data.dtype == torch.float8_e4m3fn, ( + f"Expected float8_e4m3fn, got {qtensor._quantized_data.dtype}" + ) + assert qtensor._quantized_data.shape == input_shape, ( + f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" + ) + + # Dequantize + dequant_tensor = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + ) + + # Verify dequantized tensor shape and values match original + assert dequant_tensor.shape == input_shape, ( + f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}" + ) + assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: " + f"max diff = {(dequant_tensor - test_tensor).abs().max()}" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_e8m0_scale_values(self, device): + """Test that MXFP8 produces correct E8M0 scale values (power-of-2 only).""" + # Create a tensor with known amax values per block + # MXFP8 block size is 32, so create a 2x64 tensor (2 rows, 2 blocks per row) + test_tensor = torch.zeros((2, 64), dtype=torch.float32, device=device) + + # First block (row 0, elements 0-31): max abs = 1.0, should give exponent ~127-8 = 119 + # (since E4M3 max is 448, log2(1/448) ≈ -8.8, ceil = -8, biased = 127 + (-8) = 119) + test_tensor[0, :32] = 1.0 + + # Second block (row 0, elements 32-63): max abs = 448.0, should give exponent = 127 + # (since 448/448 = 1, log2(1) = 0, biased = 127) + test_tensor[0, 32:64] = 448.0 + + # Third block (row 1, elements 0-31): max abs = 2.0 + test_tensor[1, :32] = 2.0 + + # Fourth block (row 1, elements 32-63): max abs = 0.5 + test_tensor[1, 32:64] = 0.5 + + # Quantize + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor) + + # Verify all scales are valid uint8 values + assert e8m0_scale.dtype == torch.uint8 + assert e8m0_scale.shape == (2, 2) + + # Verify dequantization works + dequant = qtensor.dequantize( + dtype=torch.float32, + scale=e8m0_scale, + ) + + # Check that the dequantized max values per block are 
close to original + assert torch.allclose(dequant[0, :32].max(), torch.tensor(1.0, device=device), rtol=0.1) + assert torch.allclose(dequant[0, 32:64].max(), torch.tensor(448.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, :32].max(), torch.tensor(2.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, 32:64].max(), torch.tensor(0.5, device=device), rtol=0.1) + + # fmt: off + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "test_input", + [ + # FP8 E4M3 boundary test values (max is 448, various powers of 2) + torch.tensor([[1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0, 0.5, 0.25, + 0.125, 0.0625, 0.03125, 0.015625, -1.0, -2.0, -4.0, -8.0, -16.0, -32.0, + -64.0, -128.0, -256.0, -448.0, -0.5, -0.25, -0.125, -0.0625, -0.03125, -0.015625]]), + # Mix of positive and negative values near E4M3 boundaries + torch.tensor([[448.0, 416.0, 384.0, 352.0, 320.0, 288.0, 256.0, 224.0, 192.0, 160.0, + 128.0, 96.0, 64.0, 48.0, 32.0, 24.0, -448.0, -416.0, -384.0, -352.0, -320.0, + -288.0, -256.0, -224.0, -192.0, -160.0, -128.0, -96.0, -64.0, -48.0, -32.0, -24.0]]), + ], + ) + def test_mxfp8_quantize_boundary_values(self, test_input, device, input_dtype): + # fmt: on + """Test MXFP8 quantization with E4M3 boundary values.""" + x = test_input.to(input_dtype).to(device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(x) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + dequant = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + ) + + # FP8 E4M3 has limited precision, allow reasonable tolerance + assert torch.allclose(dequant, x, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: max diff = {(dequant - x).abs().max()}" + ) + + @pytest.mark.parametrize( + "input_shape", + [(1600, 1600)], + ) + def test_mxfp8_quantize_gpu_mem(self, input_shape): + """Test MXFP8 GPU memory usage during quantization.""" + + def _get_gpu_mem_used(): + device = torch.device("cuda:0") + free, total = torch.cuda.mem_get_info(device) + return total - free + + # Warmup + test_input = torch.rand((32, 32), dtype=torch.float32, device="cuda") + MXFP8QTensor.quantize(test_input) + + test_input = torch.rand(input_shape, dtype=torch.float32, device="cuda") + torch.cuda.empty_cache() + + input_size = test_input.element_size() * test_input.numel() + before_quantize = _get_gpu_mem_used() + MXFP8QTensor.quantize(test_input) + after_quantize = _get_gpu_mem_used() + + # Memory increase should be reasonable (less than 3x input size) + # MXFP8 stores FP8 data (1 byte) + uint8 scales, so should be efficient + assert (after_quantize - before_quantize) < input_size * 3, ( + f"Memory increase too large: {after_quantize - before_quantize} bytes " + f"for input size {input_size} bytes" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize( + "input_shape", + [(128, 64), (256, 128), (512, 256)], + ) + def test_mxfp8_get_weights_scaling_factor(self, device, input_shape): + """Test MXFP8 get_weights_scaling_factor returns correct E8M0 scales.""" + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + # Get scaling factor + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight) + + # Verify dtype and shape + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + expected_shape = (input_shape[0], 
input_shape[1] // MXFP8QTensor.BLOCK_SIZE) + assert e8m0_scale.shape == expected_shape, ( + f"Expected scale shape {expected_shape}, got {e8m0_scale.shape}" + ) + + # Verify E8M0 values are in valid range [0, 254] (biased exponent = unbiased + 127) + # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254] + # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights + assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)" + + @pytest.mark.parametrize( + ("amax_value", "expected_exponent"), + [ + (0.0, -127.0), # Zero amax: minimum exponent + (448.0, 0.0), # E4M3_MAX: exponent 0 + (1.0, -8.0), # log2(1/448) ~ -8.8, ceil = -8 + (1e40, 127.0), # Very large amax: clamps to max + (1e-50, -127.0), # Very small amax: clamps to min + ], + ) + def test_mxfp8_compute_e8m0_exponent_edge_cases(self, amax_value, expected_exponent): + """Test _compute_e8m0_exponent handles edge cases correctly.""" + amax = torch.tensor([amax_value], device="cuda") + exponent = MXFP8QTensor._compute_e8m0_exponent(amax) + assert exponent.item() == expected_exponent, ( + f"amax={amax_value} should give exponent {expected_exponent}, got {exponent.item()}" + ) + + def test_mxfp8_get_weights_scaling_factor_asserts_1d_weight(self): + """Test get_weights_scaling_factor raises assertion for 1D tensor.""" + weight_1d = torch.randn(64, device="cuda") + with pytest.raises(AssertionError, match="Weight must be at least 2D"): + MXFP8QTensor.get_weights_scaling_factor(weight_1d) + + def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self): + """Test get_weights_scaling_factor raises assertion when dim not divisible by 32.""" + # 33 is not divisible by 32 + weight = torch.randn(64, 33, device="cuda") + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.get_weights_scaling_factor(weight) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_quantize_with_scale_asserts(self, device): + """Test quantize_with_scale raises assertions for invalid inputs.""" + # Test wrong scale dtype assertion + weight = torch.randn(64, 64, dtype=torch.float32, device=device) + wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="e8m0_scale must be"): + MXFP8QTensor.quantize_with_scale(weight, wrong_dtype_scale) + + # Test non-divisible dimension assertion + weight_bad_dim = torch.randn(64, 33, dtype=torch.float32, device=device) + scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device) + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_3d_moe(self, device): + """Test get_weights_scaling_factor_from_quantizer handles 3D MoE tensors.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + _scale = None + + quantizer = MockQuantizer() + + # Test when _scale is None (should compute from weight) + scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + expected_shape = ( + input_shape[0], + input_shape[1], + input_shape[2] // MXFP8QTensor.BLOCK_SIZE, + ) + assert scale.shape == expected_shape + + # Test when _scale is provided with correct 3D shape + 
quantizer._scale = torch.randint(0, 255, expected_shape, dtype=torch.uint8, device=device) + scale_from_quantizer = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, quantizer + ) + assert torch.equal(scale_from_quantizer, quantizer._scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_scale_shape_mismatch(self, device): + """Test get_weights_scaling_factor_from_quantizer raises assertion on shape mismatch.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + # Wrong shape: 2D instead of 3D (missing num_experts dimension) + _scale = torch.randint( + 0, 255, (64, 4), dtype=torch.uint8, device=device + ) + + quantizer = MockQuantizer() + + with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"): + MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_mxfp8_dequantize_default_dtype(self, device, input_dtype): + """Test dequantize uses original dtype when dtype=None.""" + input_tensor = torch.randn(64, 64, dtype=input_dtype, device=device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor) + + # Dequantize without specifying dtype + dequant = qtensor.dequantize(scale=e8m0_scale) + + assert dequant.dtype == input_dtype
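
A minimal usage sketch of the tensor-level API this diff introduces (not part of the diff itself). It follows the import path and round-trip pattern used in tests/gpu/torch/quantization/test_qtensor_cuda.py; the shapes, tolerances, and CUDA device below are illustrative assumptions, and the higher-level PTQ entry point is selected with QFORMAT=mxfp8 via examples/llm_ptq/scripts/huggingface_example.sh as added above.

import torch

from modelopt.torch.quantization.qtensor import MXFP8QTensor

# Example weight; MXFP8 blocks run along the last dimension, 32 elements per block.
weight = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")

# quantize() returns the FP8 E4M3 payload wrapped in an MXFP8QTensor plus a uint8
# E8M0 scale holding one biased exponent per 32-element block.
qtensor, e8m0_scale = MXFP8QTensor.quantize(weight)
assert qtensor._quantized_data.dtype == torch.float8_e4m3fn
assert e8m0_scale.shape == (128, 128 // MXFP8QTensor.BLOCK_SIZE)

# dequantize() reconstructs each block as value * 2^(scale - 127); the round trip
# is lossy only at E4M3 precision, matching the tolerances used in the new tests.
recovered = qtensor.dequantize(dtype=torch.bfloat16, scale=e8m0_scale)
assert torch.allclose(recovered, weight, rtol=5e-2, atol=5e-2)

# The export path can also derive the E8M0 scale straight from an unquantized weight.
scale_only = MXFP8QTensor.get_weights_scaling_factor(weight)
assert scale_only.dtype == torch.uint8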