@@ -18,9 +18,9 @@ limitations under the License.
1818namespace xllm ::kernel::mlu {
1919
2020void reshape_paged_cache (torch::Tensor& key,
21- torch::Tensor& value,
21+ const std::optional< torch::Tensor> & value,
2222 torch::Tensor& k_cache,
23- torch::Tensor& v_cache,
23+ const std::optional< torch::Tensor> & v_cache,
2424 const torch::Tensor& slot_mapping,
2525 bool direction) {
2626 tmo::torch_api::reshape_paged_cache (
@@ -115,4 +115,40 @@ void batch_decode(const torch::Tensor& query,
115115 kv_cache_quant_bit_size);
116116}
117117
118+ void masked_indexer_select_paged_kv (const bool is_prefill,
119+ const torch::Tensor& query,
120+ const torch::Tensor& cu_seq_q_lens,
121+ const torch::Tensor& cu_seq_k_lens,
122+ const torch::Tensor& q_scale,
123+ const torch::Tensor& weights,
124+ const double softmax_scale,
125+ const torch::Tensor& k_cache,
126+ const torch::Tensor& k_context_lens,
127+ const torch::Tensor& k_cache_block_table,
128+ const torch::Tensor& k_scale_cache,
129+ const int64_t index_topk,
130+ const torch::Tensor& kv_cache_block_table,
131+ const int64_t kv_cache_block_size,
132+ const torch::Tensor& new_block_table,
133+ const torch::Tensor& new_context_lens,
134+ const int64_t quant_block_size) {
135+ tmo::torch_api::masked_indexer_select_paged_kv (is_prefill,
136+ query,
137+ cu_seq_q_lens,
138+ cu_seq_k_lens,
139+ q_scale,
140+ weights,
141+ softmax_scale,
142+ k_cache,
143+ k_context_lens,
144+ k_cache_block_table,
145+ k_scale_cache,
146+ index_topk,
147+ kv_cache_block_table,
148+ kv_cache_block_size,
149+ new_block_table,
150+ new_context_lens,
151+ quant_block_size);
152+ }
153+
118154} // namespace xllm::kernel::mlu
0 commit comments