@@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
8080 return shard_size * marlin_tile_size , shard_offset * marlin_tile_size
8181
8282
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a weight-shard's extent from weight rows to scale blocks.

    Block-quantized weights store one scale per ``block_n`` output rows, so a
    shard spanning ``shard_size`` rows starting at ``shard_offset`` must be
    expressed in units of scale blocks before indexing the scale tensor.

    Args:
        weight_block_size: Sequence whose first element is ``block_n`` — the
            number of output rows that share a single quantization scale.
            Must not be ``None``.
        shard_size: Shard extent, in weight rows.
        shard_offset: Shard start position, in weight rows.

    Returns:
        A ``(shard_size, shard_offset)`` pair measured in scale blocks, with
        each value rounded up so partially covered blocks are included.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    # Ceiling division: any block the shard touches, even partially, counts.
    blocks_size = (shard_size + block_n - 1) // block_n
    blocks_offset = (shard_offset + block_n - 1) // block_n
    return blocks_size, blocks_offset
89+
90+
8391def adjust_bitsandbytes_4bit_shard (
8492 param : Parameter , shard_offsets : dict [str , tuple [int , int ]], loaded_shard_id : str
8593) -> tuple [int , int ]:
@@ -763,8 +771,18 @@ def weight_loader(
763771
764772 assert loaded_shard_id < len (self .output_sizes )
765773 if output_dim is not None :
766- shard_offset = sum (self .output_sizes [:loaded_shard_id ]) // self .tp_size
767- shard_size = self .output_sizes [loaded_shard_id ] // self .tp_size
774+ shard_offset = sum (self .output_sizes [:loaded_shard_id ])
775+ shard_size = self .output_sizes [loaded_shard_id ]
776+
777+ if isinstance (param , BlockQuantScaleParameter ):
778+ weight_block_size = getattr (self , "weight_block_size" , None )
779+ shard_size , shard_offset = adjust_block_scale_shard (
780+ weight_block_size , shard_size , shard_offset
781+ )
782+
783+ shard_offset //= self .tp_size
784+ shard_size //= self .tp_size
785+
768786 # Special case for quantization.
769787 # If quantized, we need to adjust the offset and size to account
770788 # for the packing.
@@ -867,24 +885,17 @@ def weight_loader_v2(
867885
868886 assert loaded_shard_id < len (self .output_sizes )
869887
888+ shard_offset = sum (self .output_sizes [:loaded_shard_id ])
889+ shard_size = self .output_sizes [loaded_shard_id ]
890+
870891 if isinstance (param , BlockQuantScaleParameter ):
871- assert self .quant_method is not None
872- # Assume the weight block size has been set by quant method
873- assert hasattr (self , "weight_block_size" )
874- weight_block_size = self .weight_block_size
875- assert weight_block_size is not None
876- block_n , _ = weight_block_size [0 ], weight_block_size [1 ]
877- shard_offset = (
878- (sum (self .output_sizes [:loaded_shard_id ]) + block_n - 1 ) // block_n
879- ) // self .tp_size
880- shard_size = (
881- (self .output_sizes [loaded_shard_id ] + block_n - 1 )
882- // block_n
883- // self .tp_size
892+ weight_block_size = getattr (self , "weight_block_size" , None )
893+ shard_size , shard_offset = adjust_block_scale_shard (
894+ weight_block_size , shard_size , shard_offset
884895 )
885- else :
886- shard_offset = sum ( self . output_sizes [: loaded_shard_id ]) // self .tp_size
887- shard_size = self . output_sizes [ loaded_shard_id ] // self .tp_size
896+
897+ shard_offset //= self .tp_size
898+ shard_size //= self .tp_size
888899
889900 param .load_merged_column_weight (
890901 loaded_weight = loaded_weight ,
@@ -1066,16 +1077,11 @@ def weight_loader_v2(
10661077 shard_offset = self ._get_shard_offset_mapping (loaded_shard_id )
10671078 shard_size = self ._get_shard_size_mapping (loaded_shard_id )
10681079
1069- # Note(simon): This is needed for Qwen3's fp8 quantization.
10701080 if isinstance (param , BlockQuantScaleParameter ):
1071- assert self .quant_method is not None
1072- # Assume the weight block size has been set by quant method
1073- assert hasattr (self , "weight_block_size" )
1074- weight_block_size = self .weight_block_size
1075- assert weight_block_size is not None
1076- block_n , _ = weight_block_size [0 ], weight_block_size [1 ]
1077- shard_offset = (shard_offset + block_n - 1 ) // block_n
1078- shard_size = (shard_size + block_n - 1 ) // block_n
1081+ weight_block_size = getattr (self , "weight_block_size" , None )
1082+ shard_size , shard_offset = adjust_block_scale_shard (
1083+ weight_block_size , shard_size , shard_offset
1084+ )
10791085
10801086 param .load_qkv_weight (
10811087 loaded_weight = loaded_weight ,
@@ -1208,6 +1214,13 @@ def weight_loader(
12081214 elif loaded_shard_id == "v" :
12091215 shard_offset = (self .num_heads + self .num_kv_heads ) * self .head_size
12101216 shard_size = self .num_kv_heads * self .v_head_size
1217+
1218+ if isinstance (param , BlockQuantScaleParameter ):
1219+ weight_block_size = getattr (self , "weight_block_size" , None )
1220+ shard_size , shard_offset = adjust_block_scale_shard (
1221+ weight_block_size , shard_size , shard_offset
1222+ )
1223+
12111224 # Special case for Quantized Weights.
12121225 # If quantized, we need to adjust the offset and size to account
12131226 # for the packing.
0 commit comments