src/compressed_tensors/quantization/lifecycle
1 file changed, 11 insertions(+), 3 deletions(-)

@@ -83,7 +83,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        _initialize_attn_scales(module, scheme.output_activations)
 
     else:
 
@@ -220,10 +220,18 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
+def _initialize_attn_scales(module: Module, quantization_args: QuantizationArgs) -> None:
     """Initialize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        expected_shape = module.k_proj.out_features
+    elif quantization_args.strategy == QuantizationStrategy.TENSOR:
+        expected_shape = 1
+    else:
+        raise ValueError(
+            f"One of {(QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)} must be used "
+            f"for kv cache quantization."
+        )
 
     param = next(module.parameters())
     scale_dtype = param.dtype
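
For illustration, below is a minimal, self-contained sketch of the shape logic this diff introduces: per-tensor KV-cache quantization keeps a single scalar scale, while per-channel quantization allocates one scale per output feature of the key/value projections. ToyAttention, init_attn_scales, and the local QuantizationStrategy stand-in are hypothetical names used only for this example; they are not the library's code.

# Standalone sketch (not the library implementation) of the expected_shape
# branching added in this diff for k_scale / v_scale initialization.
from enum import Enum

import torch
from torch.nn import Linear, Module, Parameter


class QuantizationStrategy(str, Enum):  # stand-in for the library's enum
    TENSOR = "tensor"
    CHANNEL = "channel"


class ToyAttention(Module):
    """Minimal attention-like module with k_proj / v_proj, as the diff assumes."""

    def __init__(self, hidden: int = 32, kv_dim: int = 8):
        super().__init__()
        self.k_proj = Linear(hidden, kv_dim)
        self.v_proj = Linear(hidden, kv_dim)


def init_attn_scales(module: Module, strategy: QuantizationStrategy) -> None:
    # Mirror of the diff's branching: channel -> one scale per output feature,
    # tensor -> a single scale, anything else is rejected.
    if strategy == QuantizationStrategy.CHANNEL:
        expected_shape = module.k_proj.out_features
    elif strategy == QuantizationStrategy.TENSOR:
        expected_shape = 1
    else:
        raise ValueError(
            f"One of {(QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)} "
            f"must be used for kv cache quantization."
        )

    # Match dtype/device of the module's existing parameters.
    param = next(module.parameters())
    for name in ("k_scale", "v_scale"):
        scale = Parameter(
            torch.empty(expected_shape, dtype=param.dtype, device=param.device),
            requires_grad=False,
        )
        module.register_parameter(name, scale)


attn = ToyAttention()
init_attn_scales(attn, QuantizationStrategy.CHANNEL)
print(attn.k_scale.shape)  # torch.Size([8]) -> one scale per kv channel

Compared to a single per-tensor scale, channel-wise scales give one quantization scale per key/value projection output feature, which can reduce quantization error when channel magnitudes vary widely.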