# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
-import triton
-import triton.language as tl

+from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(

    # get a_ptr,b_ptr,c_ptr
    cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
    cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

-    accumulator = accumulator.to(tl.bfloat16)
+    accumulator = accumulator.to(c_ptr.dtype.element_ty)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
@@ -205,6 +204,10 @@ def _fused_moe_lora(
    assert output.shape[0] == topk_weights.shape[0]
    assert top_k_num == topk_weights.shape[1]

+    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+        assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+        assert lora_a.dtype in [torch.float16, torch.bfloat16]
+
    device = qcurr_hidden_states.device
    num_slices = len(lora_a_stacked)

@@ -227,9 +230,9 @@ def _fused_moe_lora(
    num_tokens = M * top_k_num
    w1_output_dim_size = w1_lora_b_stacked.shape[2]

-    lora_intermediate_cache1 = torch.zeros(
+    lora_intermediate_cache1 = torch.empty(
        (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-        dtype=torch.bfloat16,
+        dtype=output.dtype,
        device=device,
    )

@@ -288,10 +291,6 @@ def _fused_moe_lora(
    K = max_lora_rank
    N = w1_output_dim_size

-    # a_intermediate_cache1 = a_intermediate_cache1.view(
-    #     M, -1, a_intermediate_cache1.shape[3]
-    # )
-
    a_intermediate_cache1 = a_intermediate_cache1.view(
        -1, a_intermediate_cache1.shape[3]
    )
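
The recurring change in this diff is replacing the hardcoded tl.bfloat16 with the element type carried by the output pointer, so the kernel and its intermediate buffers follow whatever dtype the caller passes (float16 or bfloat16). Below is a minimal standalone sketch of that pattern; the kernel and wrapper names (_scale_copy_kernel, scale_copy) are made up for illustration and are not part of this commit, and triton is imported directly here rather than through vllm.triton_utils.

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_copy_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Compute in fp32, then cast back to the output tensor's own element type
    # rather than a hardcoded tl.bfloat16, mirroring
    # accumulator.to(c_ptr.dtype.element_ty) in the fused MoE LoRA kernel.
    acc = x.to(tl.float32) * scale
    tl.store(out_ptr + offsets, acc.to(out_ptr.dtype.element_ty), mask=mask)


def scale_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Same dtype restriction as the assertions added in _fused_moe_lora.
    assert x.dtype in (torch.float16, torch.bfloat16)
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _scale_copy_kernel[grid](x, out, scale, n_elements, BLOCK_SIZE=1024)
    return out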