Commit 2e96809

feat: support attention indexer on mlu for deepseek v3.2 prerequisite. (#311)
Co-authored-by: phantomlei <[email protected]>
1 parent 3028d49 commit 2e96809

15 files changed: +1023 -45 lines changed

cibuild/build_mlu.sh

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ function error() {
   exit 1
 }

-IMAGE="cambricon-base/pytorch:v25.06.0-torch2.7.1-torchmlu1.27.2-ubuntu22.04-py310_xllm251016"
+IMAGE="cambricon-base/pytorch:v25.06.0-torch2.7.1-torchmlu1.27.2-ubuntu22.04-py310_xllm251104"

 RUN_OPTS=(
   --rm

xllm/core/framework/model/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ set(BASE_DEPS
   :chat_template
   glog::glog
   torch
+  torch_python
 )

 if(USE_NPU)

xllm/core/kernels/mlu/attention.cpp

Lines changed: 38 additions & 2 deletions
@@ -18,9 +18,9 @@ limitations under the License.
 namespace xllm::kernel::mlu {

 void reshape_paged_cache(torch::Tensor& key,
-                         torch::Tensor& value,
+                         const std::optional<torch::Tensor>& value,
                          torch::Tensor& k_cache,
-                         torch::Tensor& v_cache,
+                         const std::optional<torch::Tensor>& v_cache,
                          const torch::Tensor& slot_mapping,
                          bool direction) {
   tmo::torch_api::reshape_paged_cache(
@@ -115,4 +115,40 @@ void batch_decode(const torch::Tensor& query,
       kv_cache_quant_bit_size);
 }

+void masked_indexer_select_paged_kv(const bool is_prefill,
+                                    const torch::Tensor& query,
+                                    const torch::Tensor& cu_seq_q_lens,
+                                    const torch::Tensor& cu_seq_k_lens,
+                                    const torch::Tensor& q_scale,
+                                    const torch::Tensor& weights,
+                                    const double softmax_scale,
+                                    const torch::Tensor& k_cache,
+                                    const torch::Tensor& k_context_lens,
+                                    const torch::Tensor& k_cache_block_table,
+                                    const torch::Tensor& k_scale_cache,
+                                    const int64_t index_topk,
+                                    const torch::Tensor& kv_cache_block_table,
+                                    const int64_t kv_cache_block_size,
+                                    const torch::Tensor& new_block_table,
+                                    const torch::Tensor& new_context_lens,
+                                    const int64_t quant_block_size) {
+  tmo::torch_api::masked_indexer_select_paged_kv(is_prefill,
+                                                 query,
+                                                 cu_seq_q_lens,
+                                                 cu_seq_k_lens,
+                                                 q_scale,
+                                                 weights,
+                                                 softmax_scale,
+                                                 k_cache,
+                                                 k_context_lens,
+                                                 k_cache_block_table,
+                                                 k_scale_cache,
+                                                 index_topk,
+                                                 kv_cache_block_table,
+                                                 kv_cache_block_size,
+                                                 new_block_table,
+                                                 new_context_lens,
+                                                 quant_block_size);
+}
+
 }  // namespace xllm::kernel::mlu
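
The first hunk relaxes reshape_paged_cache so that value and v_cache are optional. A minimal sketch of a call site that exercises this, assuming a keys-only cache such as the indexer's (the helper name and use case are illustrative, not taken from this commit):

#include <optional>

#include <torch/torch.h>

#include "mlu_ops_api.h"

// Hypothetical helper: with the relaxed signature, passing std::nullopt for
// the value-side arguments writes only keys into the paged cache, which is
// what a cache with no separate V (e.g. the indexer's k_cache) needs.
void cache_keys_only(torch::Tensor& key,
                     torch::Tensor& k_cache,
                     const torch::Tensor& slot_mapping) {
  xllm::kernel::mlu::reshape_paged_cache(key,
                                         /*value=*/std::nullopt,
                                         k_cache,
                                         /*v_cache=*/std::nullopt,
                                         slot_mapping,
                                         /*direction=*/false);
}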

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 9 additions & 3 deletions
@@ -166,12 +166,15 @@ torch::Tensor fused_moe(
       /*b_scale=*/w1_scale.has_value() ? std::make_optional(w1_scale.value())
                                        : std::nullopt,
       /*bias=*/std::nullopt,
+      /*a_calibration=*/std::nullopt,
+      /*b_calibration=*/std::nullopt,
       /*quant_flag=*/w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt,
       /*b_offset=*/std::nullopt,
       /*tile_config=*/std::nullopt,
       /*max_dim=*/tokens,
       /*trans_a=*/false,
-      /*trans_b=*/true);
+      /*trans_b=*/true,
+      /*a_quant_bit=*/is_smoothquant ? 8 : -1);

   // prepare the parameters for the second group gemm
   torch::Tensor act_out;
@@ -231,12 +234,15 @@ torch::Tensor fused_moe(
       w2_scale.has_value() ? std::make_optional(w2_scale.value())
                            : std::nullopt,  // b_scale
       /*bias=*/std::nullopt,
+      /*a_calibration=*/std::nullopt,
+      /*b_calibration=*/std::nullopt,
       w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt,  // quant_flag
       /*b_offset=*/std::nullopt,
       /*tile_config=*/std::nullopt,
-      tokens,  // max_dim
+      /*max_dim=*/tokens,
       /*trans_a=*/false,
-      /*trans_b=*/true);
+      /*trans_b=*/true,
+      /*a_quant_bit=*/is_smoothquant ? 8 : -1);

   auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)},
                              gemm2_out.options());
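
Both group-gemm calls gain two unused calibration slots and an explicit a_quant_bit argument. A one-liner capturing the convention the new argument encodes (the helper name is hypothetical):

// Hypothetical helper spelling out the a_quant_bit convention used above:
// 8 means activations are int8-quantized (smoothquant), -1 means they are
// left unquantized.
inline int64_t activation_quant_bit(bool is_smoothquant) {
  return is_smoothquant ? 8 : -1;
}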

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 20 additions & 2 deletions
@@ -52,9 +52,9 @@ void active(const torch::Tensor& input,
             int expert_size);

 void reshape_paged_cache(torch::Tensor& key,
-                         torch::Tensor& value,
+                         const std::optional<torch::Tensor>& value,
                          torch::Tensor& k_cache,
-                         torch::Tensor& v_cache,
+                         const std::optional<torch::Tensor>& v_cache,
                          const torch::Tensor& slot_mapping,
                          bool direction);

@@ -102,6 +102,24 @@ void batch_decode(const torch::Tensor& query,
                   bool return_lse,
                   int kv_cache_quant_bit_size);

+void masked_indexer_select_paged_kv(const bool is_prefill,
+                                    const torch::Tensor& query,
+                                    const torch::Tensor& cu_seq_q_lens,
+                                    const torch::Tensor& cu_seq_k_lens,
+                                    const torch::Tensor& q_scale,
+                                    const torch::Tensor& weights,
+                                    const double softmax_scale,
+                                    const torch::Tensor& k_cache,
+                                    const torch::Tensor& k_context_lens,
+                                    const torch::Tensor& k_cache_block_table,
+                                    const torch::Tensor& k_scale_cache,
+                                    const int64_t index_topk,
+                                    const torch::Tensor& kv_cache_block_table,
+                                    const int64_t kv_cache_block_size,
+                                    const torch::Tensor& new_block_table,
+                                    const torch::Tensor& new_context_lens,
+                                    const int64_t quant_block_size);
+
 void fused_layernorm(const torch::Tensor& input,
                      torch::Tensor& output,
                      const std::optional<torch::Tensor>& residual,

xllm/core/kernels/mlu/scaled_matmul.cpp

Lines changed: 10 additions & 12 deletions
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <glog/logging.h>
+
 #include "mlu_ops_api.h"

 namespace xllm::kernel::mlu {
@@ -36,14 +38,12 @@ torch::Tensor scaled_matmul(
     const std::optional<torch::Tensor>& output /* = std::nullopt */
 ) {
   // Check: only support w8a8 quantization for now.
-  TORCH_CHECK(quant_bit_size == 8 && a_quant_bit_size == 8,
-              "scaled_matmul only supports w8a8 quantization (quant_bit_size "
-              "== 8, a_quant_bit_size == 8) for now. "
-              "Got quant_bit_size = ",
-              quant_bit_size,
-              ", a_quant_bit_size = ",
-              a_quant_bit_size,
-              ".");
+  CHECK(quant_bit_size == 8 && a_quant_bit_size == 8)
+      << "scaled_matmul only supports w8a8 quantization (quant_bit_size "
+         "== 8, a_quant_bit_size == 8) for now. "
+         "Got quant_bit_size = "
+      << quant_bit_size << ", a_quant_bit_size = " << a_quant_bit_size;

   // Only support smooth_quant algorithm for now
   std::string quant_algo = "smooth_quant";
@@ -63,10 +63,8 @@ torch::Tensor scaled_matmul(
   at::ScalarType torch_half = at::ScalarType::Half;
   at::ScalarType torch_bfloat16 = at::ScalarType::BFloat16;

-  TORCH_CHECK(output_dtype == torch_half || output_dtype == torch_bfloat16,
-              "output dtype must be half or bfloat16, but got: ",
-              output_dtype,
-              ".");
+  CHECK(output_dtype == torch_half || output_dtype == torch_bfloat16)
+      << "output dtype must be half or bfloat16, but got: " << output_dtype;

   // Select output tensor
   torch::Tensor output_tensor;
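
This file trades TORCH_CHECK for glog's CHECK. The observable difference, sketched below with made-up values: TORCH_CHECK throws a catchable c10::Error and builds its message from variadic arguments, while CHECK logs and aborts the process, streaming its message with operator<<.

#include <glog/logging.h>

#include <torch/torch.h>

// Illustration only: contrasts the two check styles swapped in this diff.
void check_styles(int64_t quant_bit_size) {
  // Old style: throws c10::Error on failure; message from variadic args.
  TORCH_CHECK(quant_bit_size == 8, "got quant_bit_size = ", quant_bit_size);
  // New style: aborts the process on failure; message via operator<<.
  CHECK(quant_bit_size == 8) << "got quant_bit_size = " << quant_bit_size;
}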

xllm/core/kernels/ops_api.cpp

Lines changed: 25 additions & 0 deletions
@@ -246,5 +246,30 @@ torch::Tensor random_sample(RandomSampleParams& params) {
   throw std::runtime_error("random_sample not implemented");
 #endif
 }
+
+void masked_indexer_select_paged_kv(MaskedIndexerSelectPagedKVParams& params) {
+#if defined(USE_MLU)
+  mlu::masked_indexer_select_paged_kv(params.is_prefill,
+                                      params.query,
+                                      params.cu_seq_q_lens,
+                                      params.cu_seq_k_lens,
+                                      params.q_scale,
+                                      params.weights,
+                                      params.softmax_scale,
+                                      params.k_cache,
+                                      params.k_context_lens,
+                                      params.k_cache_block_table,
+                                      params.k_scale_cache,
+                                      params.index_topk,
+                                      params.kv_cache_block_table,
+                                      params.kv_cache_block_size,
+                                      params.new_block_table,
+                                      params.new_context_lens,
+                                      params.quant_block_size);
+#else
+  throw std::runtime_error("masked_indexer_select_paged_kv not implemented");
+#endif
+}
+
 }  // namespace kernel
 }  // namespace xllm
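
The new dispatcher follows the pattern used throughout this file: compile-time backend selection, with a runtime error on builds that lack the kernel. A stripped-down sketch of that pattern (the op and params names are placeholders):

#include <stdexcept>

struct SomeOpParams {};  // placeholder for a real params struct

void some_op(SomeOpParams& params) {
#if defined(USE_MLU)
  // Unpack params field by field and forward to the MLU kernel wrapper.
  (void)params;
#else
  (void)params;
  throw std::runtime_error("some_op not implemented");
#endif
}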

xllm/core/kernels/ops_api.h

Lines changed: 2 additions & 0 deletions
@@ -49,5 +49,7 @@ torch::Tensor apply_top_k_top_p(TopKPParams& params);

 torch::Tensor random_sample(RandomSampleParams& params);

+void masked_indexer_select_paged_kv(MaskedIndexerSelectPagedKVParams& params);
+
 }  // namespace kernel
 }  // namespace xllm

xllm/core/kernels/param.h

Lines changed: 24 additions & 2 deletions
@@ -56,9 +56,9 @@ struct ActivationParams {
 // Reshape paged cache parameters
 struct ReshapePagedCacheParams {
   torch::Tensor key;
-  torch::Tensor value;
+  std::optional<torch::Tensor> value;
   torch::Tensor k_cache;
-  torch::Tensor v_cache;
+  std::optional<torch::Tensor> v_cache;
   torch::Tensor slot_mapping;
   bool direction = false;
 };
@@ -220,5 +220,27 @@ struct TopKPParams {
 struct RandomSampleParams {
   torch::Tensor logits;
 };
+
+// Masked indexer select paged kv parameters
+struct MaskedIndexerSelectPagedKVParams {
+  bool is_prefill;
+  torch::Tensor query;
+  torch::Tensor cu_seq_q_lens;
+  torch::Tensor cu_seq_k_lens;
+  torch::Tensor q_scale;
+  torch::Tensor weights;
+  double softmax_scale;
+  torch::Tensor k_cache;
+  torch::Tensor k_context_lens;
+  torch::Tensor k_cache_block_table;
+  torch::Tensor k_scale_cache;
+  int64_t index_topk;
+  torch::Tensor kv_cache_block_table;
+  int64_t kv_cache_block_size;
+  torch::Tensor new_block_table;
+  torch::Tensor new_context_lens;
+  int64_t quant_block_size;
+};
+
 }  // namespace kernel
 }  // namespace xllm
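
A hedged sketch of how a caller might fill the new struct and go through the dispatcher. Every scalar below is an invented placeholder and most tensors are left to the caller, so read this as an illustration of the wiring, not a working configuration (the include paths are also assumptions):

#include <cmath>

#include <torch/torch.h>

#include "kernels/ops_api.h"  // assumed include paths
#include "kernels/param.h"

void select_indexer_topk(const torch::Tensor& query,
                         const torch::Tensor& k_cache) {
  xllm::kernel::MaskedIndexerSelectPagedKVParams params;
  params.is_prefill = true;
  params.query = query;
  params.k_cache = k_cache;
  params.softmax_scale = 1.0 / std::sqrt(128.0);  // head_dim = 128 assumed
  params.index_topk = 2048;                       // invented value
  params.kv_cache_block_size = 16;                // invented value
  params.quant_block_size = 128;                  // invented value
  // cu_seq_q_lens, cu_seq_k_lens, q_scale, weights, k_context_lens,
  // k_cache_block_table, k_scale_cache, kv_cache_block_table,
  // new_block_table, and new_context_lens would be filled analogously.
  xllm::kernel::masked_indexer_select_paged_kv(params);
}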

xllm/core/layers/common/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
@@ -17,6 +17,7 @@ cc_library(
     linear_impl.h
     word_embedding_impl.h
     layer_utils.h
+    indexer.h
   SRCS
     qwen3_attention.cpp
     attention.cpp
@@ -28,6 +29,7 @@ cc_library(
     qwen3_moe_decoder_layer.cpp
     linear_impl.cpp
     layer_utils.cpp
+    indexer.cpp
   DEPS
     "-Wl,--whole-archive"
     "-Wl,--no-whole-archive"
@@ -76,3 +78,20 @@ cc_test(
     torch
     GTest::gtest_main
 )
+
+# Add test for Indexer
+cc_test(
+  NAME
+    indexer_test
+  SRCS
+    tests/indexer_tests.cpp
+    tests/tests_utils.cpp
+  DEPS
+    :common_layers
+    :parallel_state
+    :model
+    :state_dict
+    glog::glog
+    torch
+    GTest::gtest_main
+)
