src/compressed_tensors/quantization/lifecycle (1 file changed: +11 −3 lines)

@@ -83,7 +83,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        _initialize_attn_scales(module, scheme.output_activations)
 
     else:
 
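With this hunk, the attention branch forwards the scheme's output-activation quantization args into the scale initializer instead of calling it with the module alone. A minimal sketch of a scheme that would exercise the new per-channel path is below; the import path follows compressed-tensors' public quantization API, but the target regex and num_bits are illustrative assumptions, not values taken from this PR.

# Hedged sketch (not from this PR): a scheme whose output_activations carry the
# strategy that _initialize_attn_scales now receives.
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)

kv_cache_scheme = QuantizationScheme(
    targets=["re:.*self_attn$"],  # hypothetical target pattern
    output_activations=QuantizationArgs(
        num_bits=8,
        strategy=QuantizationStrategy.CHANNEL,  # per-channel k/v scales
    ),
)
# initialize_module_for_quantization would then pass
# kv_cache_scheme.output_activations into _initialize_attn_scales.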
@@ -220,10 +220,18 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
+def _initialize_attn_scales(module: Module, quantization_args: QuantizationArgs) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        expected_shape = module.k_proj.out_features
+    elif quantization_args.strategy == QuantizationStrategy.TENSOR:
+        expected_shape = 1
+    else:
+        raise ValueError(
+            f"One of {(QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL)} must be specified "
+            f"for kv cache quantization."
+        )
 
     param = next(module.parameters())
     scale_dtype = param.dtype
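To see the new shape selection in isolation, here is a self-contained sketch: a toy attention module with k_proj/v_proj, a local enum standing in for QuantizationStrategy, and plain nn.Parameter registration standing in for register_offload_parameter. All names and shapes are illustrative, not the library's.

# Self-contained sketch of the shape logic this hunk introduces; assumptions only.
import torch
from enum import Enum
from torch.nn import Linear, Module, Parameter


class Strategy(str, Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"


class ToyAttention(Module):
    def __init__(self, hidden: int = 64, kv_dim: int = 16):
        super().__init__()
        self.k_proj = Linear(hidden, kv_dim)
        self.v_proj = Linear(hidden, kv_dim)


def init_attn_scales(module: Module, strategy: Strategy) -> None:
    # Per-channel: one scale per kv output channel; per-tensor: a single scale.
    if strategy == Strategy.CHANNEL:
        expected_shape = module.k_proj.out_features
    elif strategy == Strategy.TENSOR:
        expected_shape = 1
    else:
        raise ValueError("strategy must be TENSOR or CHANNEL for kv cache quantization")

    param = next(module.parameters())
    scale_dtype = param.dtype
    for name in ("k_scale", "v_scale"):
        module.register_parameter(
            name,
            Parameter(torch.zeros(expected_shape, dtype=scale_dtype), requires_grad=False),
        )


attn = ToyAttention()
init_attn_scales(attn, Strategy.CHANNEL)
print(attn.k_scale.shape)  # torch.Size([16]) for the per-channel case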