Skip to content

Commit 29de3cd

Browse files
yugong333 and jeejeelee authored
Adding SplitK in fused_moe_lora kernel (vllm-project#27818)
Signed-off-by: Yu Gong <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent 7e2729b commit 29de3cd

File tree

1 file changed: +9 -5 lines changed

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
8888
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
8989

9090
# calculate pid_m,pid_n
91+
pid_sk = pid % SPLIT_K
92+
pid_m_n = pid // SPLIT_K
9193
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
9294
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
95+
9396
num_pid_in_group = GROUP_SIZE_M * num_pid_n
94-
group_id = pid // num_pid_in_group
97+
group_id = pid_m_n // num_pid_in_group
9598
first_pid_m = group_id * GROUP_SIZE_M
9699
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
97-
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
98-
pid_n = (pid % num_pid_in_group) // group_size_m
100+
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
101+
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
99102

100103
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
101104
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
@@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
113116
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
114117

115118
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
116-
offs_k = tl.arange(0, BLOCK_SIZE_K)
119+
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
117120

118121
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
119122
token_ind = stride_tl * lora_idx + offs_token_id
@@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
131134
cur_b_ptr
132135
+ lora_idx * stride_bl
133136
+ expert_id * stride_be
134-
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
137+
+ offs_k[:, None] * stride_bk
138+
+ offs_bn[None, :] * stride_bn
135139
)
136140

137141
# accumulator

0 commit comments

Comments
 (0)