diff --git a/csrc/vertical_slash_index.cu b/csrc/vertical_slash_index.cu index 45af042..f69d7f7 100644 --- a/csrc/vertical_slash_index.cu +++ b/csrc/vertical_slash_index.cu @@ -2,22 +2,12 @@ // Licensed under the MIT license. #include - #include #include #include #include - #include -// __device__ int min(int x, int y) { -// return x < y ? x : y; -// } - -// __device__ int max(int x, int y) { -// return x > y ? x : y; -// } - __device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { for (int idx = range_start; idx < range_end; idx += block_size) { block_offset[block_count++] = idx; @@ -50,8 +40,7 @@ __global__ void convert_vertical_slash_indexes_kernel( return; } int end_m = start_m + BLOCK_SIZE_M; - vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; - slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; block_count += row_offset; block_offset += row_offset * NNZ_S; @@ -60,26 +49,28 @@ __global__ void convert_vertical_slash_indexes_kernel( int tmp_col_cnt = 0, tmp_blk_cnt = 0; int s = 0, v = 0; - int v_idx = vertical_indexes[v++]; - int s_idx = slash_indexes[s++]; + int v_idx = vertical_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_V + v++]; + int s_idx = slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++]; + while (s_idx >= end_m) { - s_idx = slash_indexes[s++]; + s_idx = slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++]; } s_idx = max(end_m - s_idx, BLOCK_SIZE_M); int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { if (v_idx < range_end) { if (v_idx < range_start) { column_index[tmp_col_cnt++] = v_idx; } if (v < NNZ_V) { - v_idx = vertical_indexes[v++]; + v_idx = vertical_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_V + v++]; } else { v_idx = end_m + BLOCK_SIZE_M; } } else { if (s < NNZ_S) { - s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + s_idx = max(end_m - slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++], BLOCK_SIZE_M); } else { save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); break; @@ -98,6 +89,7 @@ __global__ void convert_vertical_slash_indexes_kernel( column_count[0] = tmp_col_cnt; } + void convert_vertical_slash_indexes_64x64( const int* seqlens, // [BATCH, ] const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] @@ -165,3 +157,4 @@ std::vector convert_vertical_slash_indexes( return { block_count, block_offset, column_count, column_index }; } +