
Commit 81c509c

realAsma and Asma Thekkumpate authored
Fixes & Simplifications for MCore KVCache QAT/QAD; Unittests; Distributed Sync of KVCache Quantizer params (#727)
## What does this PR do?

**Type of change:** Fix MCore KV Cache quantization: amax device placement bug; code cleanup; distributed sync of KVCache quantizer params; unit test expansion to hybrid models

**Overview:** Fixes bugs preventing MCore KV Cache quantization from working during checkpoint restore.

### Bug Chain

**Bug 1:** `is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False`
There is no `weight_quantizer` for KV-cache-only quantization → `is_enabled=False` → metadata not saved → `modelopt_post_restore()` never called. *(Thanks to @jenchen13)*

**Bug 2:** After fixing Bug 1, `_amax` is restored on CPU (via `_reset_pytorch_state_from_metadata`) while the module expects it on the model's device. The fallback `_calibrate_quantizers()` is never called because `_amax` already exists.

**Bug 3:** Even if it were called, `_calibrate_quantizers()` fails — `core_attention` has no parameters, so it cannot determine the device/dtype.

### The Fix

1. Remove the `is_enabled` check entirely — disabled modules may still need metadata restore. Explicitly skip `output_layer` from extra state callbacks (it is never quantized).
2. Set `dtype`/`device` on `core_attention` from the parent Attention module; `modelopt_post_restore()` calls `self.to(device, dtype)`.
3. Remove the dead `_calibrate_quantizers()` code (similar logic will return for KV cache affine quantization).

### Previous Unit Test Was Wrong

`model_test` was `mtq.quantize()`'d, not `mto.restore()`'d, so the actual restore path was never tested.

### Additional Fixes

- Amax sync across DP/TP for KV cache quantizers
- `flash_decode` auto-disabled

### Code Cleanup

Removed ~100 lines of dead code.

## Testing

1. MCore KV Cache QAD with Nano V3 + Context Parallel works.
2. Unit tests: hybrid models, KV+GEMM configs, correct restore workflow, backward pass validation.

## Before your PR is "*Ready for review*"

- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update Changelog?**: Yes

---------

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Co-authored-by: Asma Thekkumpate <akuriparambi@cw-dfw-cs-001-vscode-02.cm.cluster>
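For context, the restore-path workflow that the corrected unit tests now exercise can be sketched roughly as below. This is an illustrative sketch only: `build_model`, `forward_loop`, and `KV_CACHE_CFG` are placeholder names, not APIs introduced by this PR.

```python
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

# Placeholders (not from this PR): build_model() returns a fresh MCore model,
# forward_loop(model) runs a few calibration batches, and KV_CACHE_CFG stands
# in for whichever KV-cache quantization config is under test.
model_ref = build_model()
model_ref = mtq.quantize(model_ref, KV_CACHE_CFG, forward_loop)
mto.save(model_ref, "quantized_model.pth")

# The restore path the previous test never exercised: restore the ModelOpt
# state into a fresh model instead of re-quantizing it.
model_test = build_model()
model_test = mto.restore(model_test, "quantized_model.pth")

# model_test outputs (and, with the new test, gradients) should match model_ref.
```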
1 parent fe52b2a commit 81c509c

File tree

8 files changed: +327 -342 lines changed


CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@ NVIDIA Model Optimizer Changelog (Linux)
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
 
+**Bug Fixes**
+
+- Fix Megatron KV Cache quantization checkpoint restore for QAT/QAD (device placement, amax sync across DP/TP, flash_decode compatibility).
+
 **New Features**
 
 - Add support for Transformer Engine quantization for Megatron Core models.

modelopt/torch/quantization/model_calib.py

Lines changed: 13 additions & 1 deletion
@@ -95,6 +95,7 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
             quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
+    # Step 1:Sync amax across data parallelism
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
@@ -142,6 +143,7 @@ def sync_quantizer_amax_across_tp(
             if quantizer.axis in axes_for_sync and quantizer.amax is not None:
                 quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)
 
+    # Step 2: Sync amax across relevant parallelism (such as TP / EP)
     for name, module in model.named_modules():
         if getattr(module, "_parallel_state", None) is None:
             continue
@@ -180,10 +182,20 @@
                 parallel_state=module.parallel_state,
             )
 
-    for name, module in model.named_modules():
+        # MOE Quantization
         if hasattr(module, "sync_moe_local_experts_amax"):
             module.sync_moe_local_experts_amax()
 
+        # KV Cache Quantization
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
+            # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
+            # So we should sync amax across DP and TP for these quantizers (DP is already synced from above)
+            for quantizer in [module.k_bmm_quantizer, module.v_bmm_quantizer]:
+                if isinstance(quantizer, TensorQuantizer) and quantizer.amax is not None:
+                    quantizer.sync_amax_across_distributed_group(
+                        module.parallel_state.tensor_parallel_group
+                    )
+
 
 @torch.no_grad()
 def mse_calibrate(
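The sync added above relies on the KV-cache amax being a scalar per-tensor value, so it can simply be max-reduced over the relevant process groups. Below is a minimal sketch of that reduction, assuming an initialized `torch.distributed` process group; the actual logic lives in `TensorQuantizer.sync_amax_across_distributed_group`, and `sync_scalar_amax` is an illustrative helper, not ModelOpt API.

```python
import torch.distributed as dist


def sync_scalar_amax(quantizer, group=None):
    """Max-reduce a scalar amax across a process group (illustrative sketch)."""
    if not dist.is_initialized() or quantizer.amax is None:
        return
    amax = quantizer.amax.detach().clone()
    # Every rank contributes its local max; all ranks end up with the group-wide max.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    quantizer.amax = amax
```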

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -999,7 +999,7 @@ def forward(self, inputs):
         # Check if the input tensor is contiguous
         # Non-contiguous tensors will generate incorrect FP4 quantization results
         if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
-            inputs.data = inputs.data.contiguous()
+            inputs = inputs.contiguous()
         if self.fake_quant:
             with same_device_as(inputs):
                 outputs = self._fake_quantize(inputs)
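The one-line change above matters because assigning through `.data` mutates the caller's tensor storage in place, while rebinding the local name to `inputs.contiguous()` leaves the original tensor untouched and stays autograd-friendly. A small standalone illustration of the contiguity check (plain PyTorch, not ModelOpt code):

```python
import torch

x = torch.randn(4, 8)
y = x.t()                 # a transposed view: same storage, swapped strides
print(y.is_contiguous())  # False

y = y.contiguous()        # returns a compact copy; x and its storage are untouched
print(y.is_contiguous())  # True
```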

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 74 additions & 140 deletions
@@ -16,6 +16,7 @@
 """Support quantization for megatron linear layers."""
 
 import logging
+import types
 import warnings
 from typing import Any
 
@@ -28,6 +29,7 @@
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
+from megatron.core.transformer.attention import Attention
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
 
@@ -38,7 +40,6 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
@@ -98,11 +99,6 @@ def quant_module_get_extra_state(self) -> dict:
     """
     extra_state = {}
 
-    is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
-
-    if not is_enabled:
-        return extra_state
-
     quantizer_state = {}
     for name, module in self.named_modules():
         if isinstance(module, TensorQuantizer):
@@ -201,6 +197,19 @@ def quant_module_set_extra_state(self, state: Any):
         self.allow_post_restore = False
 
 
+def _create_incompatible_method(method_name: str):
+    """Create a method that raises an error for incompatible flash decode methods."""
+
+    def _incompatible_method(self, *args, **kwargs):
+        raise NotImplementedError(
+            f"{method_name} is not compatible with ModelOpt KV cache quantization. "
+            f"KV cache quantization requires core_attention to be called. "
+            f"Please raise an issue at https://github.com/NVIDIA/Model-Optimizer if you need this feature."
+        )
+
+    return _incompatible_method
+
+
 def megatron_replace_quant_module_hook(model: torch.nn.Module):
     """Configure Megatron-Core model quantization support.
 
@@ -211,10 +220,39 @@ def megatron_replace_quant_module_hook(model: torch.nn.Module):
     1. We change TransformerConfig to enable heterogenous distributed checkpointing.
     2. We enable all sub- QuantModule to store quantizer_state as extra_state by
        typing-matching the QuantModuleRegistry.
+    3. For Attention modules, we configure them to use core_attention path for KV cache quantization.
     """
 
+    def _configure_attention_for_kv_cache_quant(module: Attention):
+        """Configure Attention module for KV cache quantization compatibility."""
+        # Disable flash_decode if enabled - it bypasses core_attention (only called during inference)
+        if getattr(module.config, "flash_decode", False):
+            warnings.warn(
+                "flash_decode=True is incompatible with ModelOpt KV cache quantization. "
+                "Setting flash_decode=False. Flash decode bypasses core_attention during decode phase."
+            )
+            module.config.flash_decode = False
+
+        # Set dtype and device for core_attention (needed for modelopt_post_restore)
+        assert hasattr(module, "core_attention"), "Attention module must have core_attention"
+        param = next(iter(module.parameters()), None)
+        if param is not None:
+            module.core_attention.dtype = param.dtype
+            module.core_attention.device = param.device
+
+        # Patch flash_decode and flash_decode_and_prefill to raise errors
+        module.flash_decode = types.MethodType(_create_incompatible_method("flash_decode"), module)
+        module.flash_decode_and_prefill = types.MethodType(
+            _create_incompatible_method("flash_decode_and_prefill"), module
+        )
+
     def _register_extra_state_callbacks(model: torch.nn.Module):
         for name, module in model.named_modules():
+            if name.endswith("output_layer"):
+                # output_layer is not quantized,
+                # hence we don't need to register extra state callbacks for it
+                continue
+
             if type(module) in QuantModuleRegistry:
                 # This module will be replaced as a QuantModule
                 register_modelopt_extra_state_callbacks(
@@ -223,6 +261,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
                     quant_module_set_extra_state,
                 )
 
+            # Configure Attention modules for KV cache quantization
+            if isinstance(module, Attention):
+                _configure_attention_for_kv_cache_quant(module)
+
     for name, module in model.named_modules():
         if isinstance(module, MegatronModule):
             if "vision_model" not in name:
@@ -632,152 +674,44 @@ def _setup(self):
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()
 
-    def _calibrate_quantizers(self):
-        """Calibrate quantizers with minimal dummy tensors."""
-        # Get device and dtype from the parent module's parameters
-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
-
-        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
-        batch_size = 1
-        seq_len = 1
-
-        # Get dimensions from config
-        num_heads = self.config.num_attention_heads
-        head_dim = (
-            self.config.kv_channels
-            if hasattr(self.config, "kv_channels")
-            else self.config.hidden_size // num_heads
+        # Set parallel_state for distributed sync of BMM quantizers
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
         )
 
-        # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
-        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
-
-        if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(
-                seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
-            )
-        else:
-            dummy_tensor = torch.randn(
-                batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
-            )
-
-        # Calibrate each quantizer
-        quantizers = [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]
-
-        for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled():
-                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                    quantizer.reset_amax()
-                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
-
     def forward(self, query, key, value, *args, **kwargs):
-        """Apply post-RoPE quantization to KV cache.
-
-        TEDotProductAttention receives Q, K, V after RoPE is applied,
-        so we quantize them directly for KV cache quantization.
-        """
+        """Apply post-RoPE quantization to KV cache."""
         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
         value = self.v_bmm_quantizer(value)
-
         return super().forward(query, key, value, *args, **kwargs)
 
-    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        """Create a sharded state dictionary for distributed checkpointing."""
-        sharded_state_dict = {}
-
-        # First add non-quantizer parameters
-        for k, v in self.state_dict(prefix="", keep_vars=True).items():
-            if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
-                sharded_state_dict[prefix + k] = v
-
-        # Process _amax in bmm_quantizers
-        for name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
-                amax_key = f"{prefix}{name}._amax"
-                sharded_state_dict[amax_key] = quantizer._amax
-
-        # Process other quantizer parameters in bmm_quantizers
-        quantizer_state_dict = {
-            k: v
-            for k, v in self.state_dict(prefix="", keep_vars=True).items()
-            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
-        }
-
-        if quantizer_state_dict:
-            sharded_state_dict.update(
-                **make_sharded_tensors_for_checkpoint(
-                    quantizer_state_dict, prefix, {}, sharded_offsets
-                )
-            )
-
-        return sharded_state_dict
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Handle loading state dict for quantizers."""
-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
-
-        # Handle other quantizer states
-        for k in list(state_dict.keys()):
-            if "_quantizer" in k and "_amax" not in k:
-                name = k.split(prefix)[-1] if prefix else k
-                if name in self.state_dict():
-                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
-
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
     def modelopt_post_restore(self, name=""):
         """Restore quantizer states after model loading."""
-        super().modelopt_post_restore(name)
-
-        def _check_unsupported_states(quantizer):
-            """Check for unsupported quantizer states and warn if found."""
-            if not hasattr(quantizer, "state_dict"):
-                return
-
-            for k in quantizer.state_dict():
-                if k not in ["_amax", "_pre_quant_scale"]:
-                    warnings.warn(
-                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
-                        f"incorrect. Please implement a custom restore for {k}."
-                    )
-
-        calibration_needed = False
-
-        for quantizer_name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
-                continue
-
-            _check_unsupported_states(quantizer)
-
-            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                calibration_needed = True
+        for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
+            # TODO: Add support for non-scalar states such as
+            # Affine KVCache bias vector which is per head per channel
+            if not all(v.numel() == 1 for v in tq.state_dict().values()):
+                raise NotImplementedError(
+                    "Only scalar states are supported for KV Cache/BMM Quantizers"
+                )
+        # dtype and device should have been set in `megatron_replace_quant_module_hook`
+        # via `_configure_attention_for_kv_cache_quant`
+        assert hasattr(self, "device") and hasattr(self, "dtype")
+        self.to(device=self.device, dtype=self.dtype)
 
-        if calibration_needed:
-            self._calibrate_quantizers()
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        # Currently we do not need sharded_state_dict for TEDotProductAttention since the amax are scalar values.
+        # However we would need this in future to support non-scalar states such as
+        # Affine KVCache Quant bias vector.
+        state_dict = self.state_dict(prefix="", keep_vars=True)
+        return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)
 
 
 @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
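The `flash_decode` patching in the hunk above follows the standard `types.MethodType` pattern for overriding a method on a single instance, so only the configured Attention modules lose the incompatible fast path. A generic, self-contained illustration of that pattern (unrelated to Megatron; `Engine` and `_make_disabled` are made-up names):

```python
import types


class Engine:
    def fast_path(self):
        return "fast"


def _make_disabled(name: str):
    def _disabled(self, *args, **kwargs):
        raise NotImplementedError(f"{name} is disabled on this instance")
    return _disabled


e = Engine()
e.fast_path = types.MethodType(_make_disabled("fast_path"), e)
# e.fast_path() now raises NotImplementedError; other Engine instances keep the original method.
```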

tests/_test_utils/torch/megatron/models.py

Lines changed: 4 additions & 3 deletions
@@ -87,7 +87,7 @@ def __init__(
             init_method=torch.nn.init.xavier_uniform_,
             bias=True,
             gather_output=False,
-            skip_bias_add=True,
+            skip_bias_add=False,
             is_expert=False,
             tp_group=tp_group,
         )
@@ -101,7 +101,7 @@
             config=config,
             init_method=torch.nn.init.xavier_uniform_,
             bias=True,
-            skip_bias_add=True,
+            skip_bias_add=False,
             input_is_parallel=True,
             is_expert=False,
             tp_group=tp_group,
@@ -311,6 +311,7 @@ def get_mcore_mamba_hybrid_model(
     max_sequence_length: int = 4,
     vocab_size: int = 64,
     bf16: bool = True,
+    sequence_parallel: bool = False,
     # Mamba-specific parameters
     mamba_state_dim: int = 32,
     mamba_head_dim: int = 16,
@@ -337,7 +338,7 @@
     config = TransformerConfig(
         tensor_model_parallel_size=tensor_model_parallel_size,
         pipeline_model_parallel_size=pipeline_model_parallel_size,
-        sequence_parallel=False,
+        sequence_parallel=sequence_parallel,
        num_layers=num_layers,
        num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
        num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,

tests/_test_utils/torch/megatron/utils.py

Lines changed: 10 additions & 0 deletions
@@ -214,6 +214,16 @@ def convert_maybe_fp8(v):
         f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}"
     )
 
+    # Test backward pass on model_test
+    model_test.train()
+    loss = forward_fn(model_test).sum()
+    loss.backward()
+
+    # Assert that trainable parameters have gradients computed
+    for name, param in model_test.named_parameters():
+        if param.requires_grad:
+            assert param.grad is not None, f"Parameter {name} has no gradient computed"
+
 
 def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model):
     """Copy weights from TEGrouped MoE model to sequential MoE model."""
