Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored vertical_slash_index.cu for performance improvement #72

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions csrc/vertical_slash_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,12 @@
// Licensed under the MIT license.

#include <assert.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>

#include <cuda.h>

// __device__ int min(int x, int y) {
// return x < y ? x : y;
// }

// __device__ int max(int x, int y) {
// return x > y ? x : y;
// }

__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[block_count++] = idx;
Expand Down Expand Up @@ -50,8 +40,7 @@ __global__ void convert_vertical_slash_indexes_kernel(
return;
}
int end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;

int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
Expand All @@ -60,26 +49,28 @@ __global__ void convert_vertical_slash_indexes_kernel(

int tmp_col_cnt = 0, tmp_blk_cnt = 0;
int s = 0, v = 0;
int v_idx = vertical_indexes[v++];
int s_idx = slash_indexes[s++];
int v_idx = vertical_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_V + v++];
int s_idx = slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++];

while (s_idx >= end_m) {
s_idx = slash_indexes[s++];
s_idx = slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++];
}
s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;

while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
v_idx = vertical_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_V + v++];
} else {
v_idx = end_m + BLOCK_SIZE_M;
}
} else {
if (s < NNZ_S) {
s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
s_idx = max(end_m - slash_indexes[(batch_idx * N_HEADS + head_idx) * NNZ_S + s++], BLOCK_SIZE_M);
} else {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
break;
Expand All @@ -98,6 +89,7 @@ __global__ void convert_vertical_slash_indexes_kernel(
column_count[0] = tmp_col_cnt;
}


void convert_vertical_slash_indexes_64x64(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
Expand Down Expand Up @@ -165,3 +157,4 @@ std::vector<at::Tensor> convert_vertical_slash_indexes(

return { block_count, block_offset, column_count, column_index };
}