104 changes: 82 additions & 22 deletions modelopt/torch/quantization/calib/mse.py
@@ -15,6 +15,7 @@

"""Calibrator that returns the MSE amax of all collected tensors."""

import math
from collections.abc import Callable

import torch
@@ -33,34 +34,68 @@ def __init__(
self,
amax: torch.Tensor,
axis: int | tuple | list | None = None,
num_steps: int = 10,
num_steps: int | None = None,
step_size: float | None = None,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
fp8_scale_sweep: bool = False,
):
"""Initialize MSE calibrator.

Args:
amax: Initial amax value (required).
axis: Quantization axis. None means per-tensor quantization.
num_steps: Number of amax candidates to try.
num_steps: Number of amax candidates to try. Mutually exclusive with step_size.
If both are provided, num_steps takes precedence. If neither is provided,
defaults to 10. Ignored if fp8_scale_sweep is True.
step_size: Step size for the multiplier range [start_multiplier, stop_multiplier].
Mutually exclusive with num_steps. If specified, num_steps will be
computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
Ignored if fp8_scale_sweep is True.
start_multiplier: Starting multiplier for amax search.
Ignored if fp8_scale_sweep is True.
stop_multiplier: Ending multiplier for amax search.
Ignored if fp8_scale_sweep is True.
quant_func: Function that quantizes input tensor given an amax value.
Should have signature: quant_func(x, amax) -> quantized_x.
error_func: Function to compute error between x and xq.
Default is F.mse_loss(x, xq, reduction='none').
fp8_scale_sweep: If True, sweep over the valid FP8 E4M3 scale values (126 of the 128 encodings)
instead of using multipliers. This is specifically for NVFP4
per-block quantization where scales are stored in FP8 format.
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
self._num_steps = num_steps
self._start_multiplier = start_multiplier
self._stop_multiplier = stop_multiplier
self._quant_func = quant_func
self._error_func = error_func
self._losses_sum = [None] * num_steps
self._candidate_amaxs = [None] * num_steps
self._fp8_scale_sweep = fp8_scale_sweep

# Compute num_steps based on fp8_scale_sweep, num_steps, or step_size
if fp8_scale_sweep:
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
self._num_steps = 126
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
elif num_steps is not None:
# num_steps takes precedence
self._num_steps = num_steps
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
elif step_size is not None:
# Compute num_steps from step_size
self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
else:
# Default to 10 steps
self._num_steps = 10
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None
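For a quick sense of the step_size path in __init__ above, a minimal worked example; the step_size value here is illustrative, not taken from this PR:

import math

# With the default multiplier range [0.25, 4.0] and a hypothetical step_size of 0.25,
# the calibrator evaluates ceil((4.0 - 0.25) / 0.25) + 1 = 16 candidate amax values.
start_multiplier, stop_multiplier, step_size = 0.25, 4.0, 0.25
num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1
assert num_steps == 16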

@@ -79,24 +114,45 @@ def collect(self, x: torch.Tensor):
x = x.detach().to(dtype=torch.float32)

device = x.device
# Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier
# to ensure balanced exploration on both sides of the original amax (1.0)
steps_first_half = self._num_steps // 2 + 1 # Include 1.0
steps_second_half = self._num_steps - self._num_steps // 2 # For second range
multipliers = torch.cat(
[
torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device),
torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[
1:
], # Skip duplicate 1.0
]
)

if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)

# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values_valid = fp8_values[valid_mask]

candidates = fp8_values_valid / 448.0

print(
f"FP8 scale sweep: trying {len(candidates)} valid FP8 E4M3 multipliers (out of 128 total)"
)
print(
f"Multiplier range: {candidates.min().item():.6e} to {candidates.max().item():.6e}"
)
else:
multipliers = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
print(f"Multipliers: {multipliers}")
candidates = multipliers

# Get reduce axis for per-channel quantization
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

for step, multiplier in enumerate(multipliers):
candidate_amax = self._initial_amax * multiplier
for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
multiplier = candidate
candidate_amax = global_amax_expanded * multiplier
else:
# For normal MSE calibration, multiply initial amax by the multiplier
candidate_amax = self._initial_amax * candidate
xq = self._quant_func(x, candidate_amax)

if self._error_func is not None:
@@ -107,12 +163,16 @@
loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)

if self._candidate_amaxs[step] is None:
self._candidate_amaxs[step] = candidate_amax
self._candidate_amaxs[step] = candidate_amax.detach()

if self._losses_sum[step] is None:
self._losses_sum[step] = loss.clone()
self._losses_sum[step] = loss.detach().clone()
else:
self._losses_sum[step] += loss
self._losses_sum[step] += loss.detach()

# Free GPU memory after each calibration iteration
del error, loss, xq
torch.cuda.empty_cache()

def reset(self):
"""Reset the stored losses and amax value."""
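For reference, a self-contained sketch of the FP8 E4M3 candidate generation used by collect() when fp8_scale_sweep is enabled; it assumes a PyTorch build that exposes torch.float8_e4m3fn:

import torch

# Reinterpret bytes 0-127 as FP8 E4M3 values, then drop the two unusable
# encodings (byte 0 decodes to 0.0, byte 127 decodes to NaN).
uint8_values = torch.arange(0, 128, dtype=torch.uint8)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)

# Normalize by the E4M3 maximum (448) so candidates act as multipliers in (0, 1].
candidates = fp8_values[valid_mask] / 448.0
assert candidates.numel() == 126
print(candidates.min().item(), candidates.max().item())  # ~4.36e-06 and 1.0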
60 changes: 56 additions & 4 deletions modelopt/torch/quantization/config.py
@@ -387,6 +387,29 @@
"algorithm": "max",
}

NVFP4_WEIGHT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
# "step_size": 0.05,
"fp8_scale_sweep": True,
# "num_steps": 5,
# "start_multiplier": 0.25,
# "stop_multiplier": 2.0,
},
}

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
@@ -987,29 +1010,58 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
reconstruction error of a tensor after uniform Q→DQ:

s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}

Note: You can specify either num_steps or step_size (but not both) to control
the amax search range. If both are provided, num_steps takes precedence.
When fp8_scale_sweep is enabled, num_steps and step_size are ignored.
"""

method: Literal["mse"] = ModeloptField("mse")

num_steps: int | None = ModeloptField(
default=10,
default=None,
ge=1,
title="Number of amax candidates to try.",
description="Number of amax candidates to search over for MSE minimization.",
description="Number of amax candidates to search over for MSE minimization. "
"Mutually exclusive with step_size. If both are provided, num_steps takes precedence. "
"If neither num_steps nor step_size is provided, defaults to 10. "
"Ignored if fp8_scale_sweep is True.",
)

step_size: float | None = ModeloptField(
default=None,
gt=0.0,
title="Step size for amax search.",
description="Step size for the multiplier range [start_multiplier, stop_multiplier]. "
"Mutually exclusive with num_steps. If specified, num_steps will be computed as "
"ceil((stop_multiplier - start_multiplier) / step_size) + 1. "
"Ignored if fp8_scale_sweep is True.",
)

start_multiplier: float | None = ModeloptField(
default=0.25,
gt=0.0,
title="Starting multiplier for amax search.",
description="Starting multiplier for amax search range (multiplies initial amax).",
description="Starting multiplier for amax search range (multiplies initial amax). "
"Ignored if fp8_scale_sweep is True.",
)

stop_multiplier: float | None = ModeloptField(
default=4.0,
gt=0.0,
title="Ending multiplier for amax search.",
description="Ending multiplier for amax search range (multiplies initial amax).",
description="Ending multiplier for amax search range (multiplies initial amax). "
"Ignored if fp8_scale_sweep is True.",
)

fp8_scale_sweep: bool | None = ModeloptField(
default=False,
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
description="If True, sweep over all 128 possible FP8 E4M3 scale values "
"for NVFP4 per-block quantization instead of using multipliers. "
"This is specifically designed for optimizing the FP8-quantized per-block scales "
"in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and "
"stop_multiplier are ignored for NVFP4 per-block quantizers.",
)

distributed_sync: bool | None = ModeloptField(
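For context, a minimal sketch of driving the new preset end to end. It assumes NVFP4_WEIGHT_MSE_CFG is exported alongside the existing presets and uses the standard mtq.quantize entry point; model and calib_loader are placeholders:

import modelopt.torch.quantization as mtq

# Placeholder calibration loop; calib_loader and model are assumed to exist.
def forward_loop(model):
    for batch in calib_loader:
        model(batch)

# Weight-only NVFP4 with MSE calibration; the algorithm section above turns on
# the FP8 E4M3 sweep for the per-block scales (assuming the preset is exported).
model = mtq.quantize(model, mtq.NVFP4_WEIGHT_MSE_CFG, forward_loop)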
61 changes: 46 additions & 15 deletions modelopt/torch/quantization/model_calib.py
@@ -192,9 +192,11 @@ def mse_calibrate(
model: nn.Module,
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
num_steps: int = 10,
num_steps: int | None = None,
step_size: float | None = None,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
):
"""Calibrate the model using MSE-based amax search.

@@ -207,13 +209,28 @@
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
distributed_sync: Whether to sync amax across distributed processes.
num_steps: Number of amax candidates to try (default: 10).
num_steps: Number of amax candidates to try. Mutually exclusive with step_size.
If both are provided, num_steps takes precedence. If neither is provided,
defaults to 10. Ignored if fp8_scale_sweep is True.
step_size: Step size for the multiplier range [start_multiplier, stop_multiplier].
Mutually exclusive with num_steps. If specified, num_steps will be
computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
Ignored if fp8_scale_sweep is True.
start_multiplier: Starting multiplier for amax search (default: 0.25).
Ignored if fp8_scale_sweep is True.
stop_multiplier: Ending multiplier for amax search (default: 4.0).
Ignored if fp8_scale_sweep is True.
fp8_scale_sweep: If True, sweep over the valid FP8 E4M3 scale values (126 of the 128 encodings)
for NVFP4 per-block quantization instead of using multipliers.
This is specifically designed for optimizing the FP8-quantized
per-block scales in NVFP4 format (default: False).

See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
"""
# Set default for num_steps if neither num_steps nor step_size is provided
if num_steps is None and step_size is None:
num_steps = 10
# Step 1: First get initial amax using max calibration
max_calibrate(model, forward_loop, distributed_sync)

@@ -228,9 +245,20 @@
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()

def quant_func(x, amax, quantizer=module):
def quant_func(x, amax, quantizer=module, use_fp8_sweep=fp8_scale_sweep):
original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
quantizer._amax = amax

# FP8 quantization of NVFP4 static per-block scales
if (
quantizer.is_static_block_quant
and quantizer._num_bits == (2, 1)
and quantizer._block_sizes.get("scale_bits") == (4, 3)
):
weight_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
quantizer._amax = scaled_e4m3_impl(amax, weight_amax)
else:
# For non-NVFP4 quantizers, use amax directly
quantizer._amax = amax

with (
enable_quant(quantizer),
@@ -241,32 +269,31 @@
xq = quantizer(x)
quantizer._keep_shape = False

# FP8 quantization of NVFP4 static per-block scales
if (
quantizer.is_static_block_quant
and quantizer._num_bits == (2, 1)
and quantizer._block_sizes.get("scale_bits") == (4, 3)
):
weight_amax = reduce_amax(
x, axis=None, keepdims=False, squeeze_scalar=True
)
quantizer._amax = scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0

if original_amax is not None:
quantizer._amax = original_amax
else:
delattr(quantizer, "_amax")

return xq

# Determine if this is an NVFP4 per-block quantizer that should use FP8 scale sweep
is_nvfp4_per_block = (
fp8_scale_sweep
and module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes.get("scale_bits") == (4, 3)
)

# Create MSE calibrator with quant_func
module._calibrator = MseCalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
num_steps=num_steps,
step_size=step_size,
start_multiplier=start_multiplier,
stop_multiplier=stop_multiplier,
quant_func=quant_func,
fp8_scale_sweep=is_nvfp4_per_block,
)

# Identify weight quantizers by checking if they have corresponding weight parameters
@@ -292,6 +319,8 @@
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

torch.cuda.empty_cache()

# Step 4: Disable weight quantizers during forward loop
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.disable()
@@ -320,6 +349,8 @@
if hasattr(module._calibrator, "clear"):
module._calibrator.clear()

torch.cuda.empty_cache()

# TODO: Sync amax across distributed processes


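To round out the model_calib.py changes, a hedged sketch of calling mse_calibrate directly with the new arguments; it assumes the model already contains quantizer modules (normally inserted by the quantize mode) and uses a placeholder calib_loader:

from modelopt.torch.quantization.model_calib import mse_calibrate

# Placeholder calibration loop; calib_loader and the quantized model are assumed to exist.
def forward_loop(model):
    for batch in calib_loader:
        model(batch)

# With fp8_scale_sweep=True, NVFP4 per-block weight quantizers sweep the 126 valid
# FP8 E4M3 scale candidates; other quantizers fall back to the default 10-step
# multiplier search because num_steps and step_size are left unset.
mse_calibrate(model, forward_loop, fp8_scale_sweep=True)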