@@ -12,7 +12,7 @@ namespace {
12
12
13
13
// TODO: support vectorized store
14
14
template <int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL>
15
- __device__ nvfuser_index_t offsetAfterSwizzlePadding (
15
+ __device__ nvfuser_index_t outputOffsetAfterSwizzlePadding (
16
16
const nvfuser_index_t row_idx,
17
17
const nvfuser_index_t col_idx,
18
18
const nvfuser_index_t padded_col_size) {
@@ -64,36 +64,37 @@ template <
64
64
int BLOCK_ROW_INNER,
65
65
int BLOCK_COL,
66
66
int UNROLL_FACTOR>
67
- __device__ void groupedBlockLayout (
67
+ __device__ void preprocessGroupedMatmulInputSf (
68
68
T* output,
69
69
const T* input,
70
70
const nvfuser_index_t row_idx,
71
71
const nvfuser_index_t col_idx,
72
- const Index_T* expert_offsets ,
72
+ const Index_T* input_offsets ,
73
73
const Index_T* output_offsets,
74
74
const nvfuser_index_t col_size,
75
75
const nvfuser_index_t group_size) {
76
76
// find corresponding expert_id
77
77
int expert_id = 0 ;
78
78
for (int i = 0 ; i < group_size; ++i) {
79
- if (row_idx < expert_offsets [i + 1 ]) {
79
+ if (row_idx < input_offsets [i + 1 ]) {
80
80
expert_id = i;
81
81
break ;
82
82
}
83
83
}
84
84
85
85
// row idx for current group
86
- nvfuser_index_t c_row_idx = row_idx - expert_offsets [expert_id];
86
+ nvfuser_index_t c_row_idx = row_idx - input_offsets [expert_id];
87
87
// compute output group offset for current group
88
88
nvfuser_index_t padded_col_size =
89
89
(col_size + BLOCK_COL - 1 ) / BLOCK_COL * BLOCK_COL;
90
90
T* out_group_offset = output + output_offsets[expert_id] * padded_col_size;
91
91
92
92
// TODO: vectorized load/store instead of for loop
93
93
for (int i = 0 ; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) {
94
- nvfuser_index_t index =
95
- offsetAfterSwizzlePadding<BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL>(
96
- c_row_idx, col_idx + i, padded_col_size);
94
+ nvfuser_index_t index = outputOffsetAfterSwizzlePadding<
95
+ BLOCK_ROW_OUTER,
96
+ BLOCK_ROW_INNER,
97
+ BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size);
97
98
out_group_offset[index] = input[i];
98
99
}
99
100
}
0 commit comments