diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py
index 4292f1ff4..c657a3f21 100644
--- a/modelopt/torch/quantization/calib/mse.py
+++ b/modelopt/torch/quantization/calib/mse.py
@@ -15,6 +15,7 @@
 """Calibrator that returns the MSE amax of all collected tensors."""
 
+import math
 from collections.abc import Callable
 
 import torch
 
@@ -33,34 +34,68 @@
     def __init__(
         self,
         amax: torch.Tensor,
         axis: int | tuple | list | None = None,
-        num_steps: int = 10,
+        num_steps: int | None = None,
+        step_size: float | None = None,
         start_multiplier: float = 0.25,
         stop_multiplier: float = 4.0,
         quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        fp8_scale_sweep: bool = False,
     ):
         """Initialize MSE calibrator.
 
         Args:
             amax: Initial amax value (required).
             axis: Quantization axis. None means per-tensor quantization.
-            num_steps: Number of amax candidates to try.
+            num_steps: Number of amax candidates to try. Mutually exclusive with step_size.
+                If both are provided, num_steps takes precedence. If neither is provided,
+                defaults to 10. Ignored if fp8_scale_sweep is True.
+            step_size: Step size for the multiplier range [start_multiplier, stop_multiplier].
+                Mutually exclusive with num_steps. If specified, num_steps will be
+                computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
+                Ignored if fp8_scale_sweep is True.
             start_multiplier: Starting multiplier for amax search.
+                Ignored if fp8_scale_sweep is True.
             stop_multiplier: Ending multiplier for amax search.
+                Ignored if fp8_scale_sweep is True.
             quant_func: Function that quantizes input tensor given an amax value.
                 Should have signature: quant_func(x, amax) -> quantized_x.
             error_func: Function to compute error between x and xq.
                 Default is F.mse_loss(x, xq, reduction='none').
+            fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
+                instead of using multipliers. This is specifically for NVFP4
+                per-block quantization where scales are stored in FP8 format.
""" super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax - self._num_steps = num_steps self._start_multiplier = start_multiplier self._stop_multiplier = stop_multiplier self._quant_func = quant_func self._error_func = error_func - self._losses_sum = [None] * num_steps - self._candidate_amaxs = [None] * num_steps + self._fp8_scale_sweep = fp8_scale_sweep + + # Compute num_steps based on fp8_scale_sweep, num_steps, or step_size + if fp8_scale_sweep: + # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values + # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN) + self._num_steps = 126 + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps + elif num_steps is not None: + # num_steps takes precedence + self._num_steps = num_steps + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps + elif step_size is not None: + # Compute num_steps from step_size + self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1 + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps + else: + # Default to 10 steps + self._num_steps = 10 + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps self._amax = None @@ -79,24 +114,45 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier - # to ensure balanced exploration on both sides of the original amax (1.0) - steps_first_half = self._num_steps // 2 + 1 # Include 1.0 - steps_second_half = self._num_steps - self._num_steps // 2 # For second range - multipliers = torch.cat( - [ - torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device), - torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[ - 1: - ], # Skip duplicate 1.0 - ] - ) + + if self._fp8_scale_sweep: + global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + global_amax_expanded = global_amax * torch.ones_like(self._initial_amax) + + # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn) + # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32 + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + + # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + fp8_values_valid = fp8_values[valid_mask] + + candidates = fp8_values_valid / 448.0 + + print( + f"FP8 scale sweep: trying {len(candidates)} valid FP8 E4M3 multipliers (out of 128 total)" + ) + print( + f"Multiplier range: {candidates.min().item():.6e} to {candidates.max().item():.6e}" + ) + else: + multipliers = torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + ) + print(f"Multipliers: {multipliers}") + candidates = multipliers # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) - for step, multiplier in enumerate(multipliers): - candidate_amax = self._initial_amax * multiplier + for step, candidate in enumerate(candidates): + if self._fp8_scale_sweep: + multiplier = candidate + candidate_amax = global_amax_expanded * multiplier + else: + # For normal MSE 
+                candidate_amax = self._initial_amax * candidate
 
             xq = self._quant_func(x, candidate_amax)
 
             if self._error_func is not None:
@@ -107,12 +163,16 @@ def collect(self, x: torch.Tensor):
             loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
 
             if self._candidate_amaxs[step] is None:
-                self._candidate_amaxs[step] = candidate_amax
+                self._candidate_amaxs[step] = candidate_amax.detach()
 
             if self._losses_sum[step] is None:
-                self._losses_sum[step] = loss.clone()
+                self._losses_sum[step] = loss.detach().clone()
             else:
-                self._losses_sum[step] += loss
+                self._losses_sum[step] += loss.detach()
+
+            # Free GPU memory after each calibration iteration
+            del error, loss, xq
+            torch.cuda.empty_cache()
 
     def reset(self):
         """Reset the stored losses and amax value."""
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index a8261fc06..537f8a693 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -387,6 +387,29 @@
     "algorithm": "max",
 }
 
+NVFP4_WEIGHT_MSE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        # "step_size": 0.05,
+        "fp8_scale_sweep": True,
+        # "num_steps": 5,
+        # "start_multiplier": 0.25,
+        # "stop_multiplier": 2.0,
+    },
+}
+
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -987,29 +1010,58 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     reconstruction error of a tensor after uniform Q→DQ:
 
         s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
+
+    Note: You can specify either num_steps or step_size to control the amax search
+        range. If both are provided, num_steps takes precedence.
+        When fp8_scale_sweep is enabled, num_steps and step_size are ignored.
     """
 
     method: Literal["mse"] = ModeloptField("mse")
 
     num_steps: int | None = ModeloptField(
-        default=10,
+        default=None,
         ge=1,
         title="Number of amax candidates to try.",
-        description="Number of amax candidates to search over for MSE minimization.",
+        description="Number of amax candidates to search over for MSE minimization. "
+        "Mutually exclusive with step_size. If both are provided, num_steps takes precedence. "
+        "If neither num_steps nor step_size is provided, defaults to 10. "
+        "Ignored if fp8_scale_sweep is True.",
+    )
+
+    step_size: float | None = ModeloptField(
+        default=None,
+        gt=0.0,
+        title="Step size for amax search.",
+        description="Step size for the multiplier range [start_multiplier, stop_multiplier]. "
+        "Mutually exclusive with num_steps. If specified, num_steps will be computed as "
+        "ceil((stop_multiplier - start_multiplier) / step_size) + 1. "
+        "Ignored if fp8_scale_sweep is True.",
     )
 
     start_multiplier: float | None = ModeloptField(
         default=0.25,
         gt=0.0,
         title="Starting multiplier for amax search.",
-        description="Starting multiplier for amax search range (multiplies initial amax).",
+        description="Starting multiplier for amax search range (multiplies initial amax). "
+        "Ignored if fp8_scale_sweep is True.",
    )
 
     stop_multiplier: float | None = ModeloptField(
         default=4.0,
         gt=0.0,
         title="Ending multiplier for amax search.",
-        description="Ending multiplier for amax search range (multiplies initial amax).",
+        description="Ending multiplier for amax search range (multiplies initial amax). "
" + "Ignored if fp8_scale_sweep is True.", + ) + + fp8_scale_sweep: bool | None = ModeloptField( + default=False, + title="Enable FP8 scale sweep for NVFP4 per-block quantization.", + description="If True, sweep over all 128 possible FP8 E4M3 scale values " + "for NVFP4 per-block quantization instead of using multipliers. " + "This is specifically designed for optimizing the FP8-quantized per-block scales " + "in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and " + "stop_multiplier are ignored for NVFP4 per-block quantizers.", ) distributed_sync: bool | None = ModeloptField( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c1d787956..2ccdb633b 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -192,9 +192,11 @@ def mse_calibrate( model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True, - num_steps: int = 10, + num_steps: int | None = None, + step_size: float | None = None, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = False, ): """Calibrate the model using MSE-based amax search. @@ -207,13 +209,28 @@ def mse_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync amax across distributed processes. - num_steps: Number of amax candidates to try (default: 10). + num_steps: Number of amax candidates to try. Mutually exclusive with step_size. + If both are provided, num_steps takes precedence. If neither is provided, + defaults to 10. Ignored if fp8_scale_sweep is True. + step_size: Step size for the multiplier range [start_multiplier, stop_multiplier]. + Mutually exclusive with num_steps. If specified, num_steps will be + computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1. + Ignored if fp8_scale_sweep is True. start_multiplier: Starting multiplier for amax search (default: 0.25). + Ignored if fp8_scale_sweep is True. stop_multiplier: Ending multiplier for amax search (default: 4.0). + Ignored if fp8_scale_sweep is True. + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization instead of using multipliers. + This is specifically designed for optimizing the FP8-quantized + per-block scales in NVFP4 format (default: False). See :class:`MseCalibConfig ` for details on the remaining arguments. 
""" + # Set default for num_steps if neither num_steps nor step_size is provided + if num_steps is None and step_size is None: + num_steps = 10 # Step 1: First get initial amax using max calibration max_calibrate(model, forward_loop, distributed_sync) @@ -228,9 +245,20 @@ def mse_calibrate( # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() - def quant_func(x, amax, quantizer=module): + def quant_func(x, amax, quantizer=module, use_fp8_sweep=fp8_scale_sweep): original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None - quantizer._amax = amax + + # FP8 quantization of NVFP4 static per-block scales + if ( + quantizer.is_static_block_quant + and quantizer._num_bits == (2, 1) + and quantizer._block_sizes.get("scale_bits") == (4, 3) + ): + weight_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + quantizer._amax = scaled_e4m3_impl(amax, weight_amax) + else: + # For non-NVFP4 quantizers, use amax directly + quantizer._amax = amax with ( enable_quant(quantizer), @@ -241,17 +269,6 @@ def quant_func(x, amax, quantizer=module): xq = quantizer(x) quantizer._keep_shape = False - # FP8 quantization of NVFP4 static per-block scales - if ( - quantizer.is_static_block_quant - and quantizer._num_bits == (2, 1) - and quantizer._block_sizes.get("scale_bits") == (4, 3) - ): - weight_amax = reduce_amax( - x, axis=None, keepdims=False, squeeze_scalar=True - ) - quantizer._amax = scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0 - if original_amax is not None: quantizer._amax = original_amax else: @@ -259,14 +276,24 @@ def quant_func(x, amax, quantizer=module): return xq + # Determine if this is an NVFP4 per-block quantizer that should use FP8 scale sweep + is_nvfp4_per_block = ( + fp8_scale_sweep + and module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes.get("scale_bits") == (4, 3) + ) + # Create MSE calibrator with quant_func module._calibrator = MseCalibrator( amax=initial_amax, axis=module._calibrator._axis, num_steps=num_steps, + step_size=step_size, start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, + fp8_scale_sweep=is_nvfp4_per_block, ) # Identify weight quantizers by checking if they have corresponding weight parameters @@ -292,6 +319,8 @@ def quant_func(x, amax, quantizer=module): weight = getattr(parent_module, weight_name) weight_quantizer(weight) + torch.cuda.empty_cache() + # Step 4: Disable weight quantizers during forward loop for _, _, weight_quantizer in weight_quantizers: weight_quantizer.disable() @@ -320,6 +349,8 @@ def quant_func(x, amax, quantizer=module): if hasattr(module._calibrator, "clear"): module._calibrator.clear() + torch.cuda.empty_cache() + # TODO: Sync amax across distributed processes