Commit f4e8154

[Kernel] Enable moe LoRA kernel support FP16 (vllm-project#27468)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent a663f6a commit f4e8154

2 files changed: +26 −16 lines


tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 17 additions & 6 deletions
@@ -204,6 +204,11 @@ def use_torch(
     return torch.stack(outputs, dim=0)
 
 
+DTYPES = [torch.float16, torch.bfloat16]
+DEVICES = [f"cuda:{0}"]
+SEED = [42]
+
+
 @pytest.mark.parametrize("num_tokens", [100])
 @pytest.mark.parametrize("top_k_num", [6, 12])
 @pytest.mark.parametrize("num_experts", [64])
@@ -212,6 +217,9 @@ def use_torch(
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
 def test_fused_moe_lora_kernel(
     num_tokens,
     top_k_num,
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
     K,
     max_lora_rank,
     block_size,
+    dtype,
+    device,
+    seed,
 ):
-    torch.set_default_device("cuda:0")
-    current_platform.seed_everything(42)
+    torch.set_default_device(device)
+    current_platform.seed_everything(seed)
     # the number of randomly generated sentences.
     num_sequences = 10
     # generate data
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
                 max_lora_rank,
                 K,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     lora_b_stacked = [
@@ -251,19 +262,19 @@ def test_fused_moe_lora_kernel(
                 N,
                 max_lora_rank,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     hidden_states = torch.rand(
         (
             num_tokens,
             K,
         ),
-        dtype=torch.bfloat16,
+        dtype=dtype,
     )
 
     # fused_moe_lora_kernel output
-    output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
+    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
     use_fused_moe_lora_kernel(
         topk_ids,
         topk_weights,

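With dtype, device, and seed now parametrized, every existing test combination also runs under torch.float16 in addition to torch.bfloat16. As an illustration only (not part of the commit), the updated module can be run programmatically like this, assuming a CUDA-capable GPU and a local vLLM checkout:

# Illustrative only: run the updated test module via pytest's programmatic
# entry point. A CUDA device is required because the test parametrizes
# DEVICES = [f"cuda:{0}"].
import pytest

exit_code = pytest.main(["-v", "tests/lora/test_fused_moe_lora_kernel.py"])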
vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 9 additions & 10 deletions
@@ -2,9 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-import triton
-import triton.language as tl
 
+from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op
 
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
 
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(tl.bfloat16)
+    accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
@@ -205,6 +204,10 @@ def _fused_moe_lora(
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
 
+    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+        assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+        assert lora_a.dtype in [torch.float16, torch.bfloat16]
+
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
 
@@ -227,9 +230,9 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
 
-    lora_intermediate_cache1 = torch.zeros(
+    lora_intermediate_cache1 = torch.empty(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-        dtype=torch.bfloat16,
+        dtype=output.dtype,
         device=device,
     )
 
@@ -288,10 +291,6 @@ def _fused_moe_lora(
     K = max_lora_rank
     N = w1_output_dim_size
 
-    # a_intermediate_cache1 = a_intermediate_cache1.view(
-    #     M, -1, a_intermediate_cache1.shape[3]
-    # )
-
     a_intermediate_cache1 = a_intermediate_cache1.view(
         -1, a_intermediate_cache1.shape[3]
     )

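The core kernel change is that the hard-coded tl.bfloat16 casts are replaced with c_ptr.dtype.element_ty, so the element type of the output pointer decides both how the B weight pointers are reinterpreted and what the accumulator is cast to before the store. A minimal standalone Triton sketch of that pattern follows; it is illustrative only, the kernel and helper names are made up and are not vLLM code:

# Standalone sketch of the dtype-generic cast pattern used in the diff above:
# accumulate in fp32, then cast to the output buffer's element type instead of
# a hard-coded tl.bfloat16, so one kernel serves fp16 and bf16 outputs.
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    # Each program handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    acc = x.to(tl.float32) * scale  # accumulate in fp32 for accuracy
    # Cast to whatever dtype the output buffer holds (float16 or bfloat16),
    # mirroring accumulator.to(c_ptr.dtype.element_ty) in the kernel diff.
    tl.store(out_ptr + offsets, acc.to(out_ptr.dtype.element_ty), mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)  # output inherits x.dtype (fp16 or bf16)
    grid = (triton.cdiv(x.numel(), 1024),)
    _scale_kernel[grid](x, out, x.numel(), s, BLOCK_SIZE=1024)
    return out


# The same kernel now serves both half-precision dtypes:
for dt in (torch.float16, torch.bfloat16):
    y = scale(torch.rand(4096, device="cuda", dtype=dt), 2.0)
    assert y.dtype == dt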