[Feature] support pooling model runner #4590
Changes from 15 commits
File 1: the `update_inputs_v1` CUDA custom op.

```diff
@@ -33,7 +33,8 @@ __global__ void update_inputs_kernel_v1(bool* not_need_stop,
                                         const int input_ids_stride,
                                         const int block_num_per_seq,
                                         const int block_size,
-                                        bool prefill_one_step_stop) {
+                                        bool prefill_one_step_stop,
+                                        bool is_pooling_task) {
   int thread_idx = threadIdx.x;
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
```
```diff
@@ -48,53 +49,75 @@ __global__ void update_inputs_kernel_v1(bool* not_need_stop,
       stop_flag_now_int = 1;
     }
   }
 
   if (thread_idx < bsz) {
     if (stop_flag_now) {
       seq_lens_this_time[thread_idx] = 0;  // stop at next step
       seq_lens_decoder[thread_idx] = 0;
       seq_lens_encoder[thread_idx] = 0;
     } else {
-      if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
-          prompt_lens[thread_idx]) {
-        if (prefill_one_step_stop) {
-          // prefill done, stop
-          stop_flags[thread_idx] = true;
-          seq_lens_this_time[thread_idx] = 0;
-          seq_lens_decoder[thread_idx] = 0;
-          seq_lens_encoder[thread_idx] = 0;
-          stop_flag_now_int = 1;
-        } else {
-          // decoding
-          seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
-          seq_lens_this_time[thread_idx] = 1;
-          seq_lens_encoder[thread_idx] = 0;
-          int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
-          input_ids_now[0] = next_tokens[thread_idx];
-
-          // to judge whether block is not enough
-          int* block_table_now = block_tables + thread_idx * block_num_per_seq;
-          if (seq_lens_this_time[thread_idx] != 0 &&
-              block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
-                  -1) {
-            // should be scheduled by server
-            is_block_step[thread_idx] = true;
-            seq_lens_this_time[thread_idx] = 0;
-            stop_flags[thread_idx] = true;
-            step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
-            seq_lens_decoder[thread_idx] = 0;
-            stop_flag_now_int = 1;
-          }
-        }
-      } else {
-        stop_flags[thread_idx] = true;
-        seq_lens_this_time[thread_idx] = 0;
-        seq_lens_decoder[thread_idx] = 0;
-        seq_lens_encoder[thread_idx] = 0;
-        topk_ids[thread_idx] = -1;
-        stop_flag_now_int = 1;
-      }
+      if (is_pooling_task) {
+        if (seq_lens_this_time[thread_idx] > 0) {
+          int total_processed =
+              seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx];
+
+          if (total_processed >= prompt_lens[thread_idx]) {
+            stop_flags[thread_idx] = true;
+            step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
+            seq_lens_encoder[thread_idx] = 0;
+            seq_lens_decoder[thread_idx] = 0;
+            seq_lens_this_time[thread_idx] = 0;
+            stop_flag_now_int = 1;
+          }
+        } else {
+          seq_lens_encoder[thread_idx] = 0;
+          stop_flag_now_int = 1;
+        }
+      } else {
+        // Normal generation task logic
+        if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >=
+            prompt_lens[thread_idx]) {
+          if (prefill_one_step_stop) {
+            // prefill done, stop
+            stop_flags[thread_idx] = true;
+            seq_lens_this_time[thread_idx] = 0;
+            seq_lens_decoder[thread_idx] = 0;
+            seq_lens_encoder[thread_idx] = 0;
+            stop_flag_now_int = 1;
+          } else {
+            // decoding
+            seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
+            seq_lens_this_time[thread_idx] = 1;
+            seq_lens_encoder[thread_idx] = 0;
+            int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride;
+            input_ids_now[0] = next_tokens[thread_idx];
+
+            // to judge whether block is not enough
+            int* block_table_now =
+                block_tables + thread_idx * block_num_per_seq;
+            if (seq_lens_this_time[thread_idx] != 0 &&
+                block_table_now[seq_lens_decoder[thread_idx] / block_size] ==
+                    -1) {
+              // should be scheduled by server
+              is_block_step[thread_idx] = true;
+              seq_lens_this_time[thread_idx] = 0;
+              stop_flags[thread_idx] = true;
+              step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
+              seq_lens_decoder[thread_idx] = 0;
+              stop_flag_now_int = 1;
+            }
+          }
+        } else {
+          stop_flags[thread_idx] = true;
+          seq_lens_this_time[thread_idx] = 0;
+          seq_lens_decoder[thread_idx] = 0;
+          seq_lens_encoder[thread_idx] = 0;
+          topk_ids[thread_idx] = -1;
+          stop_flag_now_int = 1;
+        }
+      }
     }
   }
 
   __syncthreads();
   int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
   if (thread_idx == 0) {
```
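Condensed, the kernel now branches on the task type before touching any per-sequence state: a pooling request stops as soon as the whole prompt has been processed, and inactive slots are still zeroed and counted as stopped. The Python restatement below is a reading aid written for this review (not code from the PR); each list stands in for the kernel array of the same name, and the return value is the slot's contribution to the stop reduction.

```python
def update_slot_pooling(i, seq_lens_this_time, seq_lens_decoder,
                        seq_lens_encoder, step_seq_lens_decoder,
                        prompt_lens, stop_flags):
    """Pooling branch of the kernel, one batch slot per call."""
    if seq_lens_this_time[i] > 0:  # slot is active this step
        total_processed = seq_lens_this_time[i] + seq_lens_decoder[i]
        if total_processed >= prompt_lens[i]:
            # Whole prompt consumed: a pooling request needs no decode
            # phase, so stop it and zero every length counter.
            stop_flags[i] = True
            step_seq_lens_decoder[i] = seq_lens_decoder[i]
            seq_lens_encoder[i] = 0
            seq_lens_decoder[i] = 0
            seq_lens_this_time[i] = 0
            return 1  # counted by the block-wide stop reduction
        return 0  # chunked prefill still in flight
    # Inactive (padded) slot: zero seq_lens_encoder anyway and count the
    # slot as stopped, so the reduction over max_bsz slots terminates.
    seq_lens_encoder[i] = 0
    return 1
```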
```diff
@@ -115,7 +138,8 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
                     const paddle::Tensor& stop_nums,
                     const paddle::Tensor& next_tokens,
                     const paddle::Tensor& is_block_step,
-                    const int block_size) {
+                    const int block_size,
+                    const bool is_pooling_task) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(
```
```diff
@@ -132,6 +156,7 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
   }
   const int max_bsz = stop_flags.shape()[0];
   const int now_bsz = seq_lens_this_time.shape()[0];
+  const int bsz_to_process = is_pooling_task ? max_bsz : now_bsz;
   const int input_ids_stride = input_ids.shape()[1];
   const int block_num_per_seq = block_tables.shape()[1];
   auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
```
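The batch width also changes with the task type. A one-line sketch of the choice (the rationale in the comment is inferred from the author's reply at the end of this thread, not stated in the diff):

```python
def bsz_to_process(is_pooling_task: bool, max_bsz: int, now_bsz: int) -> int:
    # Generation updates only the currently scheduled sequences (now_bsz);
    # pooling walks every slot up to max_bsz so that seq_lens_encoder is
    # zeroed over its full shape, including inactive padded slots.
    return max_bsz if is_pooling_task else now_bsz
```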
```diff
@@ -149,12 +174,13 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
       const_cast<bool*>(stop_flags.data<bool>()),
       const_cast<bool*>(is_block_step.data<bool>()),
       next_tokens.data<int64_t>(),
-      now_bsz,
+      bsz_to_process,
       max_bsz,
       input_ids_stride,
       block_num_per_seq,
       block_size,
-      prefill_one_step_stop);
+      prefill_one_step_stop,
+      is_pooling_task);
   auto not_need_stop_cpu =
       not_need_stop_gpu.copy_to(not_need_stop.place(), false);
   bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
```
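On the host side, the result of the block reduction is copied back to decide whether stepping continues. Below is a plain-Python stand-in for that decision, assuming thread 0 writes `not_need_stop = (stop_sum < stop_nums)` (the kernel body after `if (thread_idx == 0)` is truncated in this capture):

```python
# Stand-in for the CUB block reduction plus thread 0's final write: each
# slot contributes 1 once it is stopped, and the host keeps stepping only
# while fewer than stop_nums slots have stopped.
def not_need_stop(stop_contributions: list, stop_nums: int) -> bool:
    return sum(stop_contributions) < stop_nums

# Pooling over a padded batch of max_bsz = 4: two live requests finish
# their prompts in one pass and the two padded slots count as stopped,
# so the whole batch halts after a single step.
assert not_need_stop([1, 1, 1, 1], stop_nums=4) is False
```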
```diff
@@ -175,7 +201,7 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
                 "stop_nums",
                 "next_tokens",
                 "is_block_step"})
-    .Attrs({"block_size: int"})
+    .Attrs({"block_size: int", "is_pooling_task: bool"})
     .Outputs({"not_need_stop_out",
               "seq_lens_this_time_out",
               "seq_lens_encoder_out",
```
File 2: the Python engine loop (`_fetch_request`, `_zmq_send_generated_tokens`).
```diff
@@ -737,6 +737,7 @@ def _fetch_request():
                         raise
                 # 2. Schedule requests
                 tasks = self.resource_manager.schedule()
+
                 # 3. Send to engine
                 if tasks:
                     if self.cfg.scheduler_config.splitwise_role == "decode":
```
```diff
@@ -886,24 +887,27 @@ def _zmq_send_generated_tokens(self):
         for request_id, contents in results.items():
             new_contents = []
             for content in contents:
-                decode_type = content.outputs.decode_type
-                delta_text = ""
-                if decode_type == 0:
-                    delta_text, token_ids = self._decode_token(
-                        token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
-                    )
-                else:
-                    token_ids = content.outputs.token_ids
-                if len(token_ids):
-                    content.outputs.token_ids = token_ids
-                    content.outputs.text = delta_text
-                    new_contents.append(content)
-                elif content.finished:
-                    new_contents.append(content)
-                else:
-                    llm_logger.warning(
-                        f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
-                    )
+                if isinstance(content, RequestOutput):
+                    decode_type = content.outputs.decode_type
+                    delta_text = ""
+                    if decode_type == 0:
+                        delta_text, token_ids = self._decode_token(
+                            token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
+                        )
+                    else:
+                        token_ids = content.outputs.token_ids
+                    if len(token_ids):
+                        content.outputs.token_ids = token_ids
+                        content.outputs.text = delta_text
+                        new_contents.append(content)
+                    elif content.finished:
+                        new_contents.append(content)
+                    else:
+                        llm_logger.warning(
+                            f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
+                        )
+                else:
+                    new_contents.append(content)
             if len(new_contents):
                 llm_logger.info(f"Send response for request id: {request_id}")
                 self.send_response_server.send_response(request_id, new_contents)
```

Review thread on the `isinstance(content, RequestOutput):` line:

> **Collaborator:** State explicitly what type the `else` branch handles, then add a final `else` that raises an error.
>
> **Author:** Right now anything that is not a generation-type output takes the `else` path. We only have these two output types for now; later there will be reward outputs and the like, and they will all go through the `else` as well. This was adapted from Sun Lei's reference implementation.
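To make the dispatch concrete: generation outputs go through incremental detokenization and may be held back to accumulate tokens, while every other output type (pooling today, reward later, per the author's reply above) is forwarded untouched. Below is a minimal self-contained sketch of that pattern; the stub classes are illustrative, not FastDeploy's real definitions.

```python
from dataclasses import dataclass, field


@dataclass
class CompletionOutput:  # stub of the generation payload
    token_ids: list = field(default_factory=list)
    text: str = ""
    decode_type: int = 0


@dataclass
class RequestOutput:  # stub: generation result, needs detokenization
    outputs: CompletionOutput = field(default_factory=CompletionOutput)
    finished: bool = False


@dataclass
class PoolingRequestOutput:  # stub: embedding/score result, sent as-is
    data: list = field(default_factory=list)
    finished: bool = True


def collect(contents):
    """Mirrors the PR's branch: decode generation, pass everything else through."""
    new_contents = []
    for content in contents:
        if isinstance(content, RequestOutput):
            # Generation: forward only when there are decoded tokens to
            # send or the request just finished; otherwise accumulate.
            if content.outputs.token_ids or content.finished:
                new_contents.append(content)
        else:
            # Non-generation (pooling, future reward): nothing to detokenize.
            new_contents.append(content)
    return new_contents


outs = collect([RequestOutput(CompletionOutput([1, 2], "hi")),
                PoolingRequestOutput([0.1, 0.2])])
assert len(outs) == 2
```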
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对这个算子做了什么逻辑的改动?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pooling时,seq_lens_encode的全部shape的值都改成0,确保exist_prefill为0,解决hung的问题
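A sketch of why zeroing the full tensor fixes the hang, assuming the runner decides whether a prefill step exists by scanning `seq_lens_encoder` (the helper below is illustrative; the actual `exist_prefill` computation is not shown in this capture):

```python
import numpy as np


def exist_prefill(seq_lens_encoder: np.ndarray) -> bool:
    # Assumed shape of the runner's check: "is any sequence still in its
    # prefill (encoder) phase?" If pooling zeroes the whole tensor, this
    # is always False, so the runner never waits on a prefill step that
    # will never be scheduled -- the reported hang.
    return bool((seq_lens_encoder > 0).any())


# Pooling path: the kernel zeroed every slot, including padded ones.
assert exist_prefill(np.zeros(8, dtype=np.int32)) is False
# Generation path: a live prefill keeps the flag set.
assert exist_prefill(np.array([0, 128, 0], dtype=np.int32)) is True
```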