1616"""Support quantization for megatron linear layers."""
1717
1818import logging
19+ import types
1920import warnings
2021from typing import Any
2122
2829from megatron .core .parallel_state import get_data_parallel_group
2930from megatron .core .tensor_parallel .mappings import gather_from_sequence_parallel_region
3031from megatron .core .transformer import MegatronModule
32+ from megatron .core .transformer .attention import Attention
3133from megatron .core .transformer .utils import make_sharded_tensors_for_checkpoint
3234from megatron .core .utils import get_tensor_model_parallel_group_if_none
3335
3840)
3941from modelopt .torch .utils .distributed import ParallelState
4042
41- from ..model_calib import max_calibrate
4243from ..nn import QuantModule , QuantModuleRegistry , TensorQuantizer
4344from ..nn .modules .quant_linear import RealQuantLinear
4445from ..qtensor import QTensorWrapper
@@ -98,11 +99,6 @@ def quant_module_get_extra_state(self) -> dict:
     """
     extra_state = {}

-    is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
-
-    if not is_enabled:
-        return extra_state
-
     quantizer_state = {}
     for name, module in self.named_modules():
         if isinstance(module, TensorQuantizer):
@@ -201,6 +197,19 @@ def quant_module_set_extra_state(self, state: Any):
     self.allow_post_restore = False


+def _create_incompatible_method(method_name: str):
+    """Create a method that raises an error for incompatible flash decode methods."""
+
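+    # The returned function is bound to a specific Attention instance with types.MethodType in
+    # _configure_attention_for_kv_cache_quant, so calling the patched method raises immediately
+    # instead of silently bypassing the KV cache quantizers in core_attention.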
+    def _incompatible_method(self, *args, **kwargs):
+        raise NotImplementedError(
+            f"{method_name} is not compatible with ModelOpt KV cache quantization. "
+            f"KV cache quantization requires core_attention to be called. "
+            f"Please raise an issue at https://github.com/NVIDIA/Model-Optimizer if you need this feature."
+        )
+
+    return _incompatible_method
+
+
 def megatron_replace_quant_module_hook(model: torch.nn.Module):
     """Configure Megatron-Core model quantization support.

@@ -211,10 +220,39 @@ def megatron_replace_quant_module_hook(model: torch.nn.Module):
     1. We change TransformerConfig to enable heterogeneous distributed checkpointing.
     2. We enable all sub-QuantModule to store quantizer_state as extra_state by
        typing-matching the QuantModuleRegistry.
+    3. For Attention modules, we configure them to use the core_attention path for KV cache quantization.
     """

+    def _configure_attention_for_kv_cache_quant(module: Attention):
+        """Configure Attention module for KV cache quantization compatibility."""
+        # Disable flash_decode if enabled - it bypasses core_attention (flash decode is inference-only)
+        if getattr(module.config, "flash_decode", False):
+            warnings.warn(
+                "flash_decode=True is incompatible with ModelOpt KV cache quantization. "
+                "Setting flash_decode=False. Flash decode bypasses core_attention during decode phase."
+            )
+            module.config.flash_decode = False
+
+        # Set dtype and device for core_attention (needed for modelopt_post_restore)
+        assert hasattr(module, "core_attention"), "Attention module must have core_attention"
+        param = next(iter(module.parameters()), None)
+        if param is not None:
+            module.core_attention.dtype = param.dtype
+            module.core_attention.device = param.device
+
+        # Patch flash_decode and flash_decode_and_prefill to raise errors
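+        # (types.MethodType binds the override to this Attention instance only; the Attention
+        # class itself is not modified.)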
+        module.flash_decode = types.MethodType(_create_incompatible_method("flash_decode"), module)
+        module.flash_decode_and_prefill = types.MethodType(
+            _create_incompatible_method("flash_decode_and_prefill"), module
+        )
+
     def _register_extra_state_callbacks(model: torch.nn.Module):
         for name, module in model.named_modules():
+            if name.endswith("output_layer"):
+                # output_layer is not quantized,
+                # hence we don't need to register extra state callbacks for it
+                continue
+
             if type(module) in QuantModuleRegistry:
                 # This module will be replaced as a QuantModule
                 register_modelopt_extra_state_callbacks(
@@ -223,6 +261,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
                     quant_module_set_extra_state,
                 )

+            # Configure Attention modules for KV cache quantization
+            if isinstance(module, Attention):
+                _configure_attention_for_kv_cache_quant(module)
+
     for name, module in model.named_modules():
         if isinstance(module, MegatronModule):
             if "vision_model" not in name:
@@ -632,152 +674,44 @@ def _setup(self):
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()

-    def _calibrate_quantizers(self):
-        """Calibrate quantizers with minimal dummy tensors."""
-        # Get device and dtype from the parent module's parameters
-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
-
-        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
-        batch_size = 1
-        seq_len = 1
-
-        # Get dimensions from config
-        num_heads = self.config.num_attention_heads
-        head_dim = (
-            self.config.kv_channels
-            if hasattr(self.config, "kv_channels")
-            else self.config.hidden_size // num_heads
+        # Set parallel_state for distributed sync of BMM quantizers
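+        # `get_data_parallel_group(with_context_parallel=True)` asserts when the data parallel
+        # group with context parallelism has not been initialized, hence the fallback below.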
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            data_parallel_group = get_data_parallel_group()
+        self.parallel_state = ParallelState(
+            data_parallel_group,
+            mcore_parallel.get_tensor_model_parallel_group(),
         )

-        # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
-        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
-
-        if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(
-                seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
-            )
-        else:
-            dummy_tensor = torch.randn(
-                batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
-            )
-
-        # Calibrate each quantizer
-        quantizers = [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]
-
-        for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled():
-                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                    quantizer.reset_amax()
-                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
-
     def forward(self, query, key, value, *args, **kwargs):
-        """Apply post-RoPE quantization to KV cache.
-
-        TEDotProductAttention receives Q, K, V after RoPE is applied,
-        so we quantize them directly for KV cache quantization.
-        """
+        """Apply post-RoPE quantization to KV cache."""
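+        # TEDotProductAttention receives Q, K and V after RoPE has been applied, so quantizing
+        # them here captures the post-RoPE values that feed the KV cache.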
         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
         value = self.v_bmm_quantizer(value)
-
         return super().forward(query, key, value, *args, **kwargs)

-    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        """Create a sharded state dictionary for distributed checkpointing."""
-        sharded_state_dict = {}
-
-        # First add non-quantizer parameters
-        for k, v in self.state_dict(prefix="", keep_vars=True).items():
-            if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
-                sharded_state_dict[prefix + k] = v
-
-        # Process _amax in bmm_quantizers
-        for name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
-                amax_key = f"{prefix}{name}._amax"
-                sharded_state_dict[amax_key] = quantizer._amax
-
-        # Process other quantizer parameters in bmm_quantizers
-        quantizer_state_dict = {
-            k: v
-            for k, v in self.state_dict(prefix="", keep_vars=True).items()
-            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
-        }
-
-        if quantizer_state_dict:
-            sharded_state_dict.update(
-                **make_sharded_tensors_for_checkpoint(
-                    quantizer_state_dict, prefix, {}, sharded_offsets
-                )
-            )
-
-        return sharded_state_dict
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Handle loading state dict for quantizers."""
-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
-
-        # Handle other quantizer states
-        for k in list(state_dict.keys()):
-            if "_quantizer" in k and "_amax" not in k:
-                name = k.split(prefix)[-1] if prefix else k
-                if name in self.state_dict():
-                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
-
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
     def modelopt_post_restore(self, name=""):
         """Restore quantizer states after model loading."""
-        super().modelopt_post_restore(name)
-
-        def _check_unsupported_states(quantizer):
-            """Check for unsupported quantizer states and warn if found."""
-            if not hasattr(quantizer, "state_dict"):
-                return
-
-            for k in quantizer.state_dict():
-                if k not in ["_amax", "_pre_quant_scale"]:
-                    warnings.warn(
-                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
-                        f"incorrect. Please implement a custom restore for {k}."
-                    )
-
-        calibration_needed = False
-
-        for quantizer_name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
-                continue
-
-            _check_unsupported_states(quantizer)
-
-            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                calibration_needed = True
+        for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
+            # TODO: Add support for non-scalar states such as the affine KV cache bias vector,
+            # which is per head and per channel
+            if not all(v.numel() == 1 for v in tq.state_dict().values()):
+                raise NotImplementedError(
+                    "Only scalar states are supported for KV Cache/BMM Quantizers"
+                )
+        # dtype and device should have been set in `megatron_replace_quant_module_hook`
+        # via `_configure_attention_for_kv_cache_quant`
+        assert hasattr(self, "device") and hasattr(self, "dtype")
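+        # Move the restored quantizer states (e.g. amax) onto the dtype/device that the replace
+        # hook captured from the attention module's parameters.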
+        self.to(device=self.device, dtype=self.dtype)

-        if calibration_needed:
-            self._calibrate_quantizers()
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        # Currently we do not need sharded_state_dict for TEDotProductAttention since the amax
+        # values are scalars. However, we would need this in the future to support non-scalar
+        # states such as the affine KV cache quantization bias vector.
+        state_dict = self.state_dict(prefix="", keep_vars=True)
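+        # An empty axis map ({}) means no tensor-parallel sharding axis is specified for any
+        # entry, which is sufficient while all BMM quantizer states are scalars.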
+        return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)


 @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})