Skip to content

Commit ecc3dd6

Browse files
authored
[Bugfix] Fix FusedMoE LoRA kernel offs_token out of bound value (vllm-project#32279)
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent 7e1f10d commit ecc3dd6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
139139
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
140140
token_ind = stride_tl * lora_id + offs_token_id
141141
offs_token = tl.load(
142-
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
142+
sorted_token_ids_ptr + token_ind,
143+
mask=token_ind < max_loras * stride_tl,
144+
other=num_valid_tokens,
143145
)
144146
token_mask = offs_token < num_valid_tokens
145147

@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
185187
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
186188

187189
if MUL_ROUTED_WEIGHT:
188-
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
190+
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
189191
accumulator = accumulator * moe_weight[:, None]
190192
accumulator = accumulator.to(c_ptr.dtype.element_ty)
191193
# Write back the block of the output

0 commit comments

Comments
 (0)