Commit 7c327f6

method renaming
1 parent c1e67f6 commit 7c327f6

4 files changed, +13 -12 lines changed

csrc/kernel.cpp

Lines changed: 2 additions & 2 deletions

@@ -272,8 +272,8 @@ class KernelIrScanner : private IrVisitor {
     summary_.has_argsort = true;
   }

-  void handle(GroupedBlockScalingFactorLayoutOp* aop) final {
-    summary_.has_grouped_block_sf_layout = true;
+  void handle(PreprocessGroupedMatmulInputSf* aop) final {
+    summary_.has_preprocess_grouped_matmul_input_sf = true;
   }

   void handle(TopKOp* top) final {
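For context, this override takes effect through the usual IrVisitor double dispatch: when KernelIrScanner walks the kernel IR and reaches a PreprocessGroupedMatmulInputSf node, the matching handle() overload fires and records the flag in the summary. Below is a minimal, generic sketch of that pattern, not the actual nvFuser dispatch code; apart from PreprocessGroupedMatmulInputSf and the summary flag, all names and the accept() mechanics are illustrative assumptions.

#include <iostream>

// Sketch of visitor double dispatch; the real IrVisitor/KernelIrScanner
// machinery differs, this only shows why the handle() overload sets the flag.
struct IrVisitorSketch;

struct ExprSketch {
  virtual ~ExprSketch() = default;
  virtual void accept(IrVisitorSketch& v) = 0;
};

struct PreprocessGroupedMatmulInputSf;

struct IrVisitorSketch {
  virtual ~IrVisitorSketch() = default;
  virtual void handle(PreprocessGroupedMatmulInputSf*) {}
};

struct PreprocessGroupedMatmulInputSf final : ExprSketch {
  // Double dispatch: the node selects the handle() overload for its own type.
  void accept(IrVisitorSketch& v) override { v.handle(this); }
};

// Analogue of KernelIrScanner: walks the IR and records which ops occur.
struct ScannerSketch final : IrVisitorSketch {
  bool has_preprocess_grouped_matmul_input_sf = false;
  void handle(PreprocessGroupedMatmulInputSf*) override {
    has_preprocess_grouped_matmul_input_sf = true;
  }
};

int main() {
  PreprocessGroupedMatmulInputSf op;
  ScannerSketch scanner;
  op.accept(scanner);
  std::cout << scanner.has_preprocess_grouped_matmul_input_sf << '\n';  // prints 1
}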

csrc/kernel.h

Lines changed: 1 addition & 1 deletion

@@ -143,7 +143,7 @@ struct KernelSummary {
   bool has_argsort = false;

   //! Do we have any grouped_block_sf_layout op?
-  bool has_grouped_block_sf_layout = false;
+  bool has_preprocess_grouped_matmul_input_sf = false;

   //! Do we have any topk op?
   bool has_topk = false;

csrc/runtime/compiled_kernel.cpp

Lines changed: 1 addition & 1 deletion

@@ -1445,7 +1445,7 @@ std::string CompiledKernel::getStructuredCode() const {
       kernel()->summary().has_argsort,
       kernel()->summary().has_topk,
       kernel()->summary().has_scan,
-      kernel()->summary().has_grouped_block_sf_layout);
+      kernel()->summary().has_preprocess_grouped_matmul_input_sf);
 }

 std::string CompiledKernel::disassembledKernelSASS() const {

runtime/block_layout.cu

Lines changed: 9 additions & 8 deletions

@@ -12,7 +12,7 @@ namespace {

 // TODO: support vectorized store
 template <int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL>
-__device__ nvfuser_index_t offsetAfterSwizzlePadding(
+__device__ nvfuser_index_t outputOffsetAfterSwizzlePadding(
     const nvfuser_index_t row_idx,
     const nvfuser_index_t col_idx,
     const nvfuser_index_t padded_col_size) {
@@ -64,36 +64,37 @@ template <
     int BLOCK_ROW_INNER,
     int BLOCK_COL,
     int UNROLL_FACTOR>
-__device__ void groupedBlockLayout(
+__device__ void preprocessGroupedMatmulInputSf(
     T* output,
     const T* input,
     const nvfuser_index_t row_idx,
     const nvfuser_index_t col_idx,
-    const Index_T* expert_offsets,
+    const Index_T* input_offsets,
     const Index_T* output_offsets,
     const nvfuser_index_t col_size,
     const nvfuser_index_t group_size) {
   // find corresponding expert_id
   int expert_id = 0;
   for (int i = 0; i < group_size; ++i) {
-    if (row_idx < expert_offsets[i + 1]) {
+    if (row_idx < input_offsets[i + 1]) {
       expert_id = i;
       break;
     }
   }

   // row idx for current group
-  nvfuser_index_t c_row_idx = row_idx - expert_offsets[expert_id];
+  nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id];
   // compute output group offset for current group
   nvfuser_index_t padded_col_size =
       (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL;
   T* out_group_offset = output + output_offsets[expert_id] * padded_col_size;

   // TODO: vectorized load/store instead of for loop
   for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) {
-    nvfuser_index_t index =
-        offsetAfterSwizzlePadding<BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL>(
-            c_row_idx, col_idx + i, padded_col_size);
+    nvfuser_index_t index = outputOffsetAfterSwizzlePadding<
+        BLOCK_ROW_OUTER,
+        BLOCK_ROW_INNER,
+        BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size);
     out_group_offset[index] = input[i];
   }
 }
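To make the indexing above concrete, here is a small host-side walkthrough of the group lookup and padded-offset arithmetic that the renamed preprocessGroupedMatmulInputSf performs. The offset tables and sizes below are made-up example values, and the swizzle inside outputOffsetAfterSwizzlePadding is not reproduced because its body is not part of this diff.

#include <cstdint>
#include <cstdio>

using index_t = int64_t;

// Same rounding as the kernel: pad the column count up to a multiple of
// BLOCK_COL so every group's output block starts on an aligned boundary.
constexpr index_t paddedCols(index_t col_size, index_t block_col) {
  return (col_size + block_col - 1) / block_col * block_col;
}

int main() {
  // Hypothetical prefix-sum offsets for 3 experts (illustrative values only).
  const index_t input_offsets[] = {0, 5, 12, 20};   // input row boundaries
  const index_t output_offsets[] = {0, 8, 16, 24};  // padded output row boundaries
  const int group_size = 3;
  const index_t col_size = 100;
  const index_t BLOCK_COL = 64;

  const index_t row_idx = 7;  // a global input row handled by one thread

  // Find the expert (group) that owns this row, as in the kernel's loop.
  int expert_id = 0;
  for (int i = 0; i < group_size; ++i) {
    if (row_idx < input_offsets[i + 1]) {
      expert_id = i;
      break;
    }
  }

  // Row index within that group, and the base of the group's output block.
  const index_t c_row_idx = row_idx - input_offsets[expert_id];
  const index_t padded_col_size = paddedCols(col_size, BLOCK_COL);  // 128
  const index_t group_base = output_offsets[expert_id] * padded_col_size;

  // Prints: expert=1 local_row=2 padded_cols=128 group_base=1024
  std::printf(
      "expert=%d local_row=%lld padded_cols=%lld group_base=%lld\n",
      expert_id,
      (long long)c_row_idx,
      (long long)padded_col_size,
      (long long)group_base);
  return 0;
}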
