 #include "helper.h"

 template <int THREADBLOCK_SIZE>
-__global__ void update_inputs_kernel_v1(bool *not_need_stop,
-                                        int *seq_lens_this_time,
-                                        int *seq_lens_encoder,
-                                        int *seq_lens_decoder,
-                                        int *step_seq_lens_decoder,
-                                        int64_t *prompt_lens,
-                                        int64_t *topk_ids,
-                                        int64_t *input_ids,
-                                        int *block_tables,
-                                        const int64_t *stop_nums,
-                                        bool *stop_flags,
-                                        bool *is_block_step,
-                                        const int64_t *next_tokens,
+__global__ void update_inputs_kernel_v1(bool* not_need_stop,
+                                        int* seq_lens_this_time,
+                                        int* seq_lens_encoder,
+                                        int* seq_lens_decoder,
+                                        int* step_seq_lens_decoder,
+                                        int64_t* prompt_lens,
+                                        int64_t* topk_ids,
+                                        int64_t* input_ids,
+                                        int* block_tables,
+                                        const int64_t* stop_nums,
+                                        bool* stop_flags,
+                                        bool* is_block_step,
+                                        const int64_t* next_tokens,
                                         const int bsz,
                                         const int max_bsz,
                                         const int input_ids_stride,
@@ -41,7 +41,6 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,

   bool stop_flag_now = false;
   int64_t stop_flag_now_int = 0;
-
   if (thread_idx < max_bsz) {
     if (thread_idx < bsz) {
       stop_flag_now = stop_flags[thread_idx];
@@ -50,19 +49,16 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
       stop_flag_now_int = 1;
     }
   }
-
   if (thread_idx < bsz) {
     if (stop_flag_now) {
-      seq_lens_this_time[thread_idx] = 0;
+      seq_lens_this_time[thread_idx] = 0;  // stop at next step
       seq_lens_decoder[thread_idx] = 0;
       seq_lens_encoder[thread_idx] = 0;
-
     } else {
       if (is_pooling_task) {
         if (seq_lens_this_time[thread_idx] > 0) {
           int total_processed =
               seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx];
-
           if (total_processed >= prompt_lens[thread_idx]) {
             stop_flags[thread_idx] = true;
             seq_lens_encoder[thread_idx] = 0;
@@ -77,6 +73,7 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
         if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
             prompt_lens[thread_idx]) {
           if (prefill_one_step_stop) {
+            // prefill done, stop
             stop_flags[thread_idx] = true;
             seq_lens_this_time[thread_idx] = 0;
             seq_lens_decoder[thread_idx] = 0;
@@ -87,14 +84,16 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
             seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
             seq_lens_this_time[thread_idx] = 1;
             seq_lens_encoder[thread_idx] = 0;
-            int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
+            int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
             input_ids_now[0] = next_tokens[thread_idx];

-            int *block_table_now =
+            // to judge whether block is not enough
+            int* block_table_now =
                 block_tables + thread_idx * block_num_per_seq;
             if (seq_lens_this_time[thread_idx] != 0 &&
                 block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
                     -1) {
+              // should be scheduled by server
               is_block_step[thread_idx] = true;
               seq_lens_this_time[thread_idx] = 0;
               stop_flags[thread_idx] = true;
@@ -122,56 +121,54 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
   }
 }

-void UpdateInputesV1(const paddle::Tensor &stop_flags,
-                     const paddle::Tensor &not_need_stop,  // only on cpu
-                     const paddle::Tensor &seq_lens_this_time,
-                     const paddle::Tensor &seq_lens_encoder,
-                     const paddle::Tensor &seq_lens_decoder,
-                     const paddle::Tensor &step_seq_lens_decoder,
-                     const paddle::Tensor &prompt_lens,
-                     const paddle::Tensor &topk_ids,
-                     const paddle::Tensor &input_ids,
-                     const paddle::Tensor &block_tables,
-                     const paddle::Tensor &stop_nums,
-                     const paddle::Tensor &next_tokens,
-                     const paddle::Tensor &is_block_step,
-                     const int block_size,
-                     const bool is_pooling_task) {
+void UpdateInputsV1(const paddle::Tensor& stop_flags,
+                    const paddle::Tensor& not_need_stop,  // only on cpu
+                    const paddle::Tensor& seq_lens_this_time,
+                    const paddle::Tensor& seq_lens_encoder,
+                    const paddle::Tensor& seq_lens_decoder,
+                    const paddle::Tensor& step_seq_lens_decoder,
+                    const paddle::Tensor& prompt_lens,
+                    const paddle::Tensor& topk_ids,
+                    const paddle::Tensor& input_ids,
+                    const paddle::Tensor& block_tables,
+                    const paddle::Tensor& stop_nums,
+                    const paddle::Tensor& next_tokens,
+                    const paddle::Tensor& is_block_step,
+                    const int block_size,
+                    const bool is_pooling_task) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-  auto dev_ctx = static_cast<const phi::CustomContext *>(
+  auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(
           input_ids.place()));
   auto cu_stream = dev_ctx->stream();
 #else
   auto cu_stream = input_ids.stream();
 #endif
   bool prefill_one_step_stop = false;
-  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
+  if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
     if (env_p[0] == '1') {
       prefill_one_step_stop = true;
     }
   }
   const int max_bsz = stop_flags.shape()[0];
   const int now_bsz = seq_lens_this_time.shape()[0];
-
   const int bsz_to_process = is_pooling_task ? max_bsz : now_bsz;
-
   const int input_ids_stride = input_ids.shape()[1];
   const int block_num_per_seq = block_tables.shape()[1];
   auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
   update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
-      const_cast<bool *>(not_need_stop_gpu.data<bool>()),
-      const_cast<int *>(seq_lens_this_time.data<int>()),
-      const_cast<int *>(seq_lens_encoder.data<int>()),
-      const_cast<int *>(seq_lens_decoder.data<int>()),
-      const_cast<int *>(step_seq_lens_decoder.data<int>()),
-      const_cast<int64_t *>(prompt_lens.data<int64_t>()),
-      const_cast<int64_t *>(topk_ids.data<int64_t>()),
-      const_cast<int64_t *>(input_ids.data<int64_t>()),
-      const_cast<int *>(block_tables.data<int>()),
+      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
+      const_cast<int*>(seq_lens_this_time.data<int>()),
+      const_cast<int*>(seq_lens_encoder.data<int>()),
+      const_cast<int*>(seq_lens_decoder.data<int>()),
+      const_cast<int*>(step_seq_lens_decoder.data<int>()),
+      const_cast<int64_t*>(prompt_lens.data<int64_t>()),
+      const_cast<int64_t*>(topk_ids.data<int64_t>()),
+      const_cast<int64_t*>(input_ids.data<int64_t>()),
+      const_cast<int*>(block_tables.data<int>()),
       stop_nums.data<int64_t>(),
-      const_cast<bool *>(stop_flags.data<bool>()),
-      const_cast<bool *>(is_block_step.data<bool>()),
+      const_cast<bool*>(stop_flags.data<bool>()),
+      const_cast<bool*>(is_block_step.data<bool>()),
       next_tokens.data<int64_t>(),
       bsz_to_process,
       max_bsz,
@@ -182,7 +179,7 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
       is_pooling_task);
   auto not_need_stop_cpu =
       not_need_stop_gpu.copy_to(not_need_stop.place(), false);
-  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
+  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
   not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
 }

@@ -200,7 +197,7 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
              "stop_nums",
              "next_tokens",
              "is_block_step"})
-    .Attrs({"block_size: int", "is_pooling_task:bool"})
+    .Attrs({"block_size: int", "is_pooling_task: bool"})
     .Outputs({"not_need_stop_out",
               "seq_lens_this_time_out",
               "seq_lens_encoder_out",
@@ -219,4 +216,4 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
                     {"stop_flags", "stop_flags_out"},
                     {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
                     {"is_block_step", "is_block_step_out"}})
-    .SetKernelFn(PD_KERNEL(UpdateInputesV1));
+    .SetKernelFn(PD_KERNEL(UpdateInputsV1));
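// Note (illustrative, not part of the diff above): the block-reduction tail of
// update_inputs_kernel_v1 that writes `not_need_stop` falls outside these hunks.
// A minimal sketch of the usual pattern, assuming "helper.h" pulls in cub and the
// kernel keeps being launched as one block of THREADBLOCK_SIZE (1024) threads,
// which also implies max_bsz must not exceed 1024 so every slot gets a thread:
//
//   __syncthreads();
//   // Sum the per-slot stop flags (0 = still running, 1 = stopped) across the block.
//   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
//   __shared__ typename BlockReduce::TempStorage temp_storage;
//   int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
//   if (thread_idx == 0) {
//     // Keep the generation loop alive only while fewer than stop_nums[0]
//     // sequences have stopped.
//     not_need_stop[0] = stop_sum < stop_nums[0];
//   }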