1 file changed: +4 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
139139 offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M ).to (tl .int64 )
140140 token_ind = stride_tl * lora_id + offs_token_id
141141 offs_token = tl .load (
142- sorted_token_ids_ptr + token_ind , token_ind < max_loras * stride_tl , 0
142+ sorted_token_ids_ptr + token_ind ,
143+ mask = token_ind < max_loras * stride_tl ,
144+ other = num_valid_tokens ,
143145 )
144146 token_mask = offs_token < num_valid_tokens
145147
@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
185187 b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
186188
187189 if MUL_ROUTED_WEIGHT :
188- moe_weight = tl .load (topk_weights_ptr + offs_token , mask = token_mask , other = 0 )
190+ moe_weight = tl .load (topk_weights_ptr + offs_token , mask = token_mask , other = 0.0 )
189191 accumulator = accumulator * moe_weight [:, None ]
190192 accumulator = accumulator .to (c_ptr .dtype .element_ty )
191193 # Write back the block of the output
You can’t perform that action at this time.
0 commit comments