@@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
8888 grid_k = tl .cdiv (K , BLOCK_SIZE_K * SPLIT_K )
8989
9090 # calculate pid_m,pid_n
91+ pid_sk = pid % SPLIT_K
92+ pid_m_n = pid // SPLIT_K
9193 num_pid_m = tl .cdiv (EM , BLOCK_SIZE_M )
9294 num_pid_n = tl .cdiv (N , BLOCK_SIZE_N )
95+
9396 num_pid_in_group = GROUP_SIZE_M * num_pid_n
94- group_id = pid // num_pid_in_group
97+ group_id = pid_m_n // num_pid_in_group
9598 first_pid_m = group_id * GROUP_SIZE_M
9699 group_size_m = min (num_pid_m - first_pid_m , GROUP_SIZE_M )
97- pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
98- pid_n = (pid % num_pid_in_group ) // group_size_m
100+ pid_m = first_pid_m + ((pid_m_n % num_pid_in_group ) % group_size_m )
101+ pid_n = (pid_m_n % num_pid_in_group ) // group_size_m
99102
100103 num_tokens_post_padded = tl .load (num_tokens_post_padded_ptr + lora_idx )
101104 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded :
@@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
113116 cur_c_ptr = c_ptr + (slice_id % num_slice_c ) * slice_c_size
114117
115118 offs_bn = (pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N ).to (tl .int64 )) % N
116- offs_k = tl .arange (0 , BLOCK_SIZE_K )
119+ offs_k = pid_sk * BLOCK_SIZE_K + tl .arange (0 , BLOCK_SIZE_K )
117120
118121 offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M ).to (tl .int64 )
119122 token_ind = stride_tl * lora_idx + offs_token_id
@@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
131134 cur_b_ptr
132135 + lora_idx * stride_bl
133136 + expert_id * stride_be
134- + (offs_k [:, None ] * stride_bk + offs_bn [None , :] * stride_bn )
137+ + offs_k [:, None ] * stride_bk
138+ + offs_bn [None , :] * stride_bn
135139 )
136140
137141 # accumulator
0 commit comments