Skip to content

Commit cc410e8

Browse files
authored
[Bugfix] Fix weight_loader v1 block scale (vllm-project#31103)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 825c2dc commit cc410e8

File tree

1 file changed

+40
-27
lines changed

1 file changed

+40
-27
lines changed

vllm/model_executor/layers/linear.py

Lines changed: 40 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
8080
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
8181

8282

83+
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a weight shard's (size, offset) into block-scale units.

    A block-quantized weight stores one scale entry per ``block_n`` output
    rows, where ``block_n`` is the first element of ``weight_block_size``.
    Both the shard size and the shard offset are therefore ceil-divided by
    ``block_n`` so they index into the scale tensor rather than the weight.

    Args:
        weight_block_size: Sequence whose first element is the block size
            along the output dimension. Must not be ``None``.
        shard_size: Shard extent in weight rows.
        shard_offset: Shard start in weight rows.

    Returns:
        Tuple ``(shard_size, shard_offset)`` expressed in scale rows.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]

    def _ceil_div(value):
        # -(-v // b) is integer ceil(v / b) for positive b — equivalent to
        # the (v + b - 1) // b idiom.
        return -(-value // block_n)

    return _ceil_div(shard_size), _ceil_div(shard_offset)
89+
90+
8391
def adjust_bitsandbytes_4bit_shard(
8492
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
8593
) -> tuple[int, int]:
@@ -763,8 +771,18 @@ def weight_loader(
763771

764772
assert loaded_shard_id < len(self.output_sizes)
765773
if output_dim is not None:
766-
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
767-
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
774+
shard_offset = sum(self.output_sizes[:loaded_shard_id])
775+
shard_size = self.output_sizes[loaded_shard_id]
776+
777+
if isinstance(param, BlockQuantScaleParameter):
778+
weight_block_size = getattr(self, "weight_block_size", None)
779+
shard_size, shard_offset = adjust_block_scale_shard(
780+
weight_block_size, shard_size, shard_offset
781+
)
782+
783+
shard_offset //= self.tp_size
784+
shard_size //= self.tp_size
785+
768786
# Special case for quantization.
769787
# If quantized, we need to adjust the offset and size to account
770788
# for the packing.
@@ -867,24 +885,17 @@ def weight_loader_v2(
867885

868886
assert loaded_shard_id < len(self.output_sizes)
869887

888+
shard_offset = sum(self.output_sizes[:loaded_shard_id])
889+
shard_size = self.output_sizes[loaded_shard_id]
890+
870891
if isinstance(param, BlockQuantScaleParameter):
871-
assert self.quant_method is not None
872-
# Assume the weight block size has been set by quant method
873-
assert hasattr(self, "weight_block_size")
874-
weight_block_size = self.weight_block_size
875-
assert weight_block_size is not None
876-
block_n, _ = weight_block_size[0], weight_block_size[1]
877-
shard_offset = (
878-
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
879-
) // self.tp_size
880-
shard_size = (
881-
(self.output_sizes[loaded_shard_id] + block_n - 1)
882-
// block_n
883-
// self.tp_size
892+
weight_block_size = getattr(self, "weight_block_size", None)
893+
shard_size, shard_offset = adjust_block_scale_shard(
894+
weight_block_size, shard_size, shard_offset
884895
)
885-
else:
886-
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
887-
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
896+
897+
shard_offset //= self.tp_size
898+
shard_size //= self.tp_size
888899

889900
param.load_merged_column_weight(
890901
loaded_weight=loaded_weight,
@@ -1066,16 +1077,11 @@ def weight_loader_v2(
10661077
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
10671078
shard_size = self._get_shard_size_mapping(loaded_shard_id)
10681079

1069-
# Note(simon): This is needed for Qwen3's fp8 quantization.
10701080
if isinstance(param, BlockQuantScaleParameter):
1071-
assert self.quant_method is not None
1072-
# Assume the weight block size has been set by quant method
1073-
assert hasattr(self, "weight_block_size")
1074-
weight_block_size = self.weight_block_size
1075-
assert weight_block_size is not None
1076-
block_n, _ = weight_block_size[0], weight_block_size[1]
1077-
shard_offset = (shard_offset + block_n - 1) // block_n
1078-
shard_size = (shard_size + block_n - 1) // block_n
1081+
weight_block_size = getattr(self, "weight_block_size", None)
1082+
shard_size, shard_offset = adjust_block_scale_shard(
1083+
weight_block_size, shard_size, shard_offset
1084+
)
10791085

10801086
param.load_qkv_weight(
10811087
loaded_weight=loaded_weight,
@@ -1208,6 +1214,13 @@ def weight_loader(
12081214
elif loaded_shard_id == "v":
12091215
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
12101216
shard_size = self.num_kv_heads * self.v_head_size
1217+
1218+
if isinstance(param, BlockQuantScaleParameter):
1219+
weight_block_size = getattr(self, "weight_block_size", None)
1220+
shard_size, shard_offset = adjust_block_scale_shard(
1221+
weight_block_size, shard_size, shard_offset
1222+
)
1223+
12111224
# Special case for Quantized Weights.
12121225
# If quantized, we need to adjust the offset and size to account
12131226
# for the packing.

0 commit comments

Comments (0)