Skip to content

Commit cfb1754

Browse files
committed
fix update_inputs_v1
1 parent 78eb318 commit cfb1754

File tree

1 file changed

+50
-53
lines changed

1 file changed

+50
-53
lines changed

custom_ops/gpu_ops/update_inputs_v1.cu

Lines changed: 50 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@
1515
#include "helper.h"
1616

1717
template <int THREADBLOCK_SIZE>
18-
__global__ void update_inputs_kernel_v1(bool *not_need_stop,
19-
int *seq_lens_this_time,
20-
int *seq_lens_encoder,
21-
int *seq_lens_decoder,
22-
int *step_seq_lens_decoder,
23-
int64_t *prompt_lens,
24-
int64_t *topk_ids,
25-
int64_t *input_ids,
26-
int *block_tables,
27-
const int64_t *stop_nums,
28-
bool *stop_flags,
29-
bool *is_block_step,
30-
const int64_t *next_tokens,
18+
__global__ void update_inputs_kernel_v1(bool* not_need_stop,
19+
int* seq_lens_this_time,
20+
int* seq_lens_encoder,
21+
int* seq_lens_decoder,
22+
int* step_seq_lens_decoder,
23+
int64_t* prompt_lens,
24+
int64_t* topk_ids,
25+
int64_t* input_ids,
26+
int* block_tables,
27+
const int64_t* stop_nums,
28+
bool* stop_flags,
29+
bool* is_block_step,
30+
const int64_t* next_tokens,
3131
const int bsz,
3232
const int max_bsz,
3333
const int input_ids_stride,
@@ -41,7 +41,6 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
4141

4242
bool stop_flag_now = false;
4343
int64_t stop_flag_now_int = 0;
44-
4544
if (thread_idx < max_bsz) {
4645
if (thread_idx < bsz) {
4746
stop_flag_now = stop_flags[thread_idx];
@@ -50,19 +49,16 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
5049
stop_flag_now_int = 1;
5150
}
5251
}
53-
5452
if (thread_idx < bsz) {
5553
if (stop_flag_now) {
56-
seq_lens_this_time[thread_idx] = 0;
54+
seq_lens_this_time[thread_idx] = 0; // stop at next step
5755
seq_lens_decoder[thread_idx] = 0;
5856
seq_lens_encoder[thread_idx] = 0;
59-
6057
} else {
6158
if (is_pooling_task) {
6259
if (seq_lens_this_time[thread_idx] > 0) {
6360
int total_processed =
6461
seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx];
65-
6662
if (total_processed >= prompt_lens[thread_idx]) {
6763
stop_flags[thread_idx] = true;
6864
seq_lens_encoder[thread_idx] = 0;
@@ -77,6 +73,7 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
7773
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
7874
prompt_lens[thread_idx]) {
7975
if (prefill_one_step_stop) {
76+
// prefill done, stop
8077
stop_flags[thread_idx] = true;
8178
seq_lens_this_time[thread_idx] = 0;
8279
seq_lens_decoder[thread_idx] = 0;
@@ -87,14 +84,16 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
8784
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
8885
seq_lens_this_time[thread_idx] = 1;
8986
seq_lens_encoder[thread_idx] = 0;
90-
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
87+
int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
9188
input_ids_now[0] = next_tokens[thread_idx];
9289

93-
int *block_table_now =
90+
// to judge whether block is not enough
91+
int* block_table_now =
9492
block_tables + thread_idx * block_num_per_seq;
9593
if (seq_lens_this_time[thread_idx] != 0 &&
9694
block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
9795
-1) {
96+
// should be scheduled by server
9897
is_block_step[thread_idx] = true;
9998
seq_lens_this_time[thread_idx] = 0;
10099
stop_flags[thread_idx] = true;
@@ -122,56 +121,54 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
122121
}
123122
}
124123

125-
void UpdateInputesV1(const paddle::Tensor &stop_flags,
126-
const paddle::Tensor &not_need_stop, // only on cpu
127-
const paddle::Tensor &seq_lens_this_time,
128-
const paddle::Tensor &seq_lens_encoder,
129-
const paddle::Tensor &seq_lens_decoder,
130-
const paddle::Tensor &step_seq_lens_decoder,
131-
const paddle::Tensor &prompt_lens,
132-
const paddle::Tensor &topk_ids,
133-
const paddle::Tensor &input_ids,
134-
const paddle::Tensor &block_tables,
135-
const paddle::Tensor &stop_nums,
136-
const paddle::Tensor &next_tokens,
137-
const paddle::Tensor &is_block_step,
138-
const int block_size,
139-
const bool is_pooling_task) {
124+
void UpdateInputsV1(const paddle::Tensor& stop_flags,
125+
const paddle::Tensor& not_need_stop, // only on cpu
126+
const paddle::Tensor& seq_lens_this_time,
127+
const paddle::Tensor& seq_lens_encoder,
128+
const paddle::Tensor& seq_lens_decoder,
129+
const paddle::Tensor& step_seq_lens_decoder,
130+
const paddle::Tensor& prompt_lens,
131+
const paddle::Tensor& topk_ids,
132+
const paddle::Tensor& input_ids,
133+
const paddle::Tensor& block_tables,
134+
const paddle::Tensor& stop_nums,
135+
const paddle::Tensor& next_tokens,
136+
const paddle::Tensor& is_block_step,
137+
const int block_size,
138+
const bool is_pooling_task) {
140139
#ifdef PADDLE_WITH_CUSTOM_DEVICE
141-
auto dev_ctx = static_cast<const phi::CustomContext *>(
140+
auto dev_ctx = static_cast<const phi::CustomContext*>(
142141
paddle::experimental::DeviceContextPool::Instance().Get(
143142
input_ids.place()));
144143
auto cu_stream = dev_ctx->stream();
145144
#else
146145
auto cu_stream = input_ids.stream();
147146
#endif
148147
bool prefill_one_step_stop = false;
149-
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
148+
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
150149
if (env_p[0] == '1') {
151150
prefill_one_step_stop = true;
152151
}
153152
}
154153
const int max_bsz = stop_flags.shape()[0];
155154
const int now_bsz = seq_lens_this_time.shape()[0];
156-
157155
const int bsz_to_process = is_pooling_task ? max_bsz : now_bsz;
158-
159156
const int input_ids_stride = input_ids.shape()[1];
160157
const int block_num_per_seq = block_tables.shape()[1];
161158
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
162159
update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
163-
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
164-
const_cast<int *>(seq_lens_this_time.data<int>()),
165-
const_cast<int *>(seq_lens_encoder.data<int>()),
166-
const_cast<int *>(seq_lens_decoder.data<int>()),
167-
const_cast<int *>(step_seq_lens_decoder.data<int>()),
168-
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
169-
const_cast<int64_t *>(topk_ids.data<int64_t>()),
170-
const_cast<int64_t *>(input_ids.data<int64_t>()),
171-
const_cast<int *>(block_tables.data<int>()),
160+
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
161+
const_cast<int*>(seq_lens_this_time.data<int>()),
162+
const_cast<int*>(seq_lens_encoder.data<int>()),
163+
const_cast<int*>(seq_lens_decoder.data<int>()),
164+
const_cast<int*>(step_seq_lens_decoder.data<int>()),
165+
const_cast<int64_t*>(prompt_lens.data<int64_t>()),
166+
const_cast<int64_t*>(topk_ids.data<int64_t>()),
167+
const_cast<int64_t*>(input_ids.data<int64_t>()),
168+
const_cast<int*>(block_tables.data<int>()),
172169
stop_nums.data<int64_t>(),
173-
const_cast<bool *>(stop_flags.data<bool>()),
174-
const_cast<bool *>(is_block_step.data<bool>()),
170+
const_cast<bool*>(stop_flags.data<bool>()),
171+
const_cast<bool*>(is_block_step.data<bool>()),
175172
next_tokens.data<int64_t>(),
176173
bsz_to_process,
177174
max_bsz,
@@ -182,7 +179,7 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
182179
is_pooling_task);
183180
auto not_need_stop_cpu =
184181
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
185-
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
182+
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
186183
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
187184
}
188185

@@ -200,7 +197,7 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
200197
"stop_nums",
201198
"next_tokens",
202199
"is_block_step"})
203-
.Attrs({"block_size: int", "is_pooling_task:bool"})
200+
.Attrs({"block_size: int", "is_pooling_task: bool"})
204201
.Outputs({"not_need_stop_out",
205202
"seq_lens_this_time_out",
206203
"seq_lens_encoder_out",
@@ -219,4 +216,4 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
219216
{"stop_flags", "stop_flags_out"},
220217
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
221218
{"is_block_step", "is_block_step_out"}})
222-
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
219+
.SetKernelFn(PD_KERNEL(UpdateInputsV1));

0 commit comments

Comments
 (0)