[Feature] support pooling model runner #4590
|
Thanks for your contribution! |
cfb1754 to 5832cc4
26b8569 to 955fac1
…into suport_spoling
| const_cast<bool*>(is_block_step.data<bool>()), | ||
| next_tokens.data<int64_t>(), | ||
| now_bsz, | ||
| bsz_to_process, |
Why does this need to be changed to max_bsz?
All max-num-seqs entries of seq_lens_encoder need to be set to 0 here.
|
|
||
| pooler_output = pooler_output.numpy() | ||
| if pooler_output.dtype != np.float32: | ||
| pooler_output = pooler_output.astype(np.float32) |
Is the output here always returned as fp32?
Fixed. Only bfloat16 is converted this way, because NumPy has no such type.
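As background for the reply above: a bfloat16 value is the upper 16 bits of a float32, so widening it to float32 loses nothing. The following is a minimal sketch of that fact only, not the PR's code; the helper name `bf16_bits_to_float32` is hypothetical, and it assumes the bfloat16 values are available as raw uint16 bit patterns.

```python
import numpy as np


def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Widen raw bfloat16 bit patterns (stored as uint16) to float32.

    bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7
    mantissa bits of a float32, so shifting left by 16 bits yields an
    exactly equal float32 value.
    """
    widened = bits.astype(np.uint32) << 16
    return widened.view(np.float32)


# 0x3F80 and 0xC000 are the bfloat16 encodings of 1.0 and -2.0.
example = bf16_bits_to_float32(np.array([0x3F80, 0xC000], dtype=np.uint16))
```

Because the widening is exact, casting only bfloat16 tensors to float32 before calling `.numpy()` keeps the other dtypes untouched without changing any values.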
| save_embedding_baseline(embedding, baseline_file) | ||
| else: | ||
| print(f"Comparing with baseline: {baseline_file}") | ||
| check_embedding_against_baseline(embedding, baseline_file, threshold=0.01) |
Does this mean each CI run only saves the dump and never performs the accuracy comparison?
We will have 宝库 run it once in the CI environment first, and then merge.
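The concern raised above is that a test which silently writes a baseline whenever none exists never compares anything on a fresh CI machine. One possible way to make that case visible is sketched below. This is an illustration under stated assumptions, not the PR's test helper: the function name `compare_or_create_baseline`, the `allow_create` flag, and the max-absolute-difference criterion are all hypothetical.

```python
import os
import tempfile

import numpy as np


def compare_or_create_baseline(embedding, baseline_file, threshold=0.01, allow_create=False):
    """Compare an embedding with a stored baseline (hypothetical helper).

    Returns "created" when a new baseline was written and "matched" when
    the comparison passed. Raises when the baseline is missing and
    creation is not allowed, or when the comparison fails.
    """
    embedding = np.asarray(embedding, dtype=np.float32)
    if not os.path.exists(baseline_file):
        if not allow_create:
            # Fail loudly instead of silently skipping the comparison.
            raise FileNotFoundError(f"baseline not found: {baseline_file}")
        with open(baseline_file, "wb") as f:
            np.save(f, embedding)
        return "created"
    baseline = np.load(baseline_file)
    if baseline.shape != embedding.shape:
        raise AssertionError(f"shape mismatch: {embedding.shape} vs {baseline.shape}")
    max_diff = float(np.max(np.abs(embedding - baseline)))
    if max_diff > threshold:
        raise AssertionError(f"max abs diff {max_diff} exceeds {threshold}")
    return "matched"


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "baseline.npy")
    first = compare_or_create_baseline([0.1, 0.2], path, allow_create=True)
    second = compare_or_create_baseline([0.1, 0.205], path)
```

With `allow_create` left at its default, a missing baseline becomes a test failure rather than a silent pass.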
| if save_each_rank or model_output.mp_rank == 0: | ||
| output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs) | ||
|
|
||
| async_output_queue.put(output) |
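If FD_USE_GET_SAVE_OUTPUT_V1 is not enabled here, is the behavior undefined? This environment variable defaults to 0, and users are unlikely to notice that they need to turn it on. Could it be enabled automatically for pooling models? If something goes wrong, it would also be best to raise a prominent error with a hint on how to fix it.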
+1
The environment variable has simply been removed here. Pooling always takes this path; there is no second option.
|
What is the reason CUDA Graph cannot be used? |
|
| delta_text, token_ids = self._decode_token( | ||
| token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished | ||
| ) | ||
| if isinstance(content, RequestOutput): |
State explicitly which type the else branch handles, then add a final else that raises an error.
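Currently, anything that is not generative goes through the else branch. We only have these two cases for now; reward and other task types will be added later and will also go through else. This change was made following 孙磊's approach.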
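The reviewer's suggestion (name each handled type explicitly and let a final else raise) can be sketched as follows. The classes `GenerationOutput` and `PoolingOutput` and the function `describe_output` are hypothetical stand-ins for illustration; they are not the project's `RequestOutput` or pooling types.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationOutput:  # hypothetical stand-in for a generative result
    token_ids: List[int] = field(default_factory=list)


@dataclass
class PoolingOutput:  # hypothetical stand-in for a pooling result
    embedding: List[float] = field(default_factory=list)


def describe_output(content) -> str:
    """Dispatch on the concrete output type and reject anything unknown."""
    if isinstance(content, GenerationOutput):
        return f"generation:{len(content.token_ids)} tokens"
    elif isinstance(content, PoolingOutput):
        return f"pooling:{len(content.embedding)} dims"
    else:
        # An unexpected type fails here instead of being treated as pooling.
        raise TypeError(f"unsupported output type: {type(content).__name__}")
```

The trade-off is that every new task type (for example reward) needs its own branch, whereas a catch-all else accepts new types without changes but also accepts mistakes silently.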
| seq_lens_encoder[thread_idx] = 0; | ||
| int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride; | ||
| input_ids_now[0] = next_tokens[thread_idx]; | ||
| if (is_pooling_task) { |
What logic changes were made to this operator?
For pooling, every value across the full shape of seq_lens_encoder is set to 0 so that exist_prefill is 0, which fixes the hang.
|
|
||
|
|
||
| def _build_stream_transfer_data(output_tokens: np.ndarray): | ||
| def _build_stream_transfer_data(output_tokens: np.ndarray, pooler_outputs: None): |
If pooler_outputs defaults to None, is this the right way to write it? Shouldn't it be pooler_outputs: type = None?
Fixed.
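For reference, `name: None` annotates the parameter's type as `None` and leaves it required, whereas `name: Optional[T] = None` gives it a default. A minimal sketch of the second form follows; the function name `build_transfer_payload`, the `np.ndarray` type for `pooler_outputs`, and the returned dictionary are assumptions for illustration, not the PR's actual signature.

```python
from typing import Optional

import numpy as np


def build_transfer_payload(
    output_tokens: Optional[np.ndarray] = None,
    pooler_outputs: Optional[np.ndarray] = None,
) -> dict:
    """Hypothetical helper: wrap exactly one kind of output in a payload."""
    if (output_tokens is None) == (pooler_outputs is None):
        raise ValueError("pass exactly one of output_tokens or pooler_outputs")
    if pooler_outputs is not None:
        return {"kind": "pooling", "data": pooler_outputs}
    return {"kind": "tokens", "data": output_tokens}


payload = build_transfer_payload(pooler_outputs=np.zeros(4, dtype=np.float32))
```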
| print("num_tokens", num_tokens) | ||
| print("max_num_seqs", max_num_seqs) | ||
| print("num_reqs", num_reqs) | ||
| print("min_tokens_per_req", min_tokens_per_req) | ||
| print("num_scheduled_token_list", num_scheduled_tokens_list) |
delete it
done
| ) | ||
| return None | ||
|
|
||
| def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: |
Please give this function a name with clearer semantics.
This name also follows vLLM's naming.
b4a8a1a to f439ca2
…into suport_spoling
a0e0452 to 15a0df8
1169a1b to 7ca73ba
gongshaotian left a comment
LGTM
This PR supports actual inference for pooling models
Usage
Online serving
Request Method (curl example)
A. EmbeddingCompletionRequest example (standard text input)
B. EmbeddingChatRequest example (message-sequence input)
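The original curl examples are not preserved in this text. As a rough illustration of how the two request shapes differ, the sketch below builds two JSON bodies. Everything in it is an assumption: the field names `model`, `input`, and `messages` follow the OpenAI-style embeddings convention used by vLLM's request classes of the same names, the model name is a placeholder, and the serving endpoint is deliberately not shown.

```python
import json

# Assumed shape A: plain text input (completion-style embedding request).
completion_style = {
    "model": "my-embedding-model",  # placeholder, not a real model name
    "input": ["What is a pooling model?"],
}

# Assumed shape B: a chat message sequence (chat-style embedding request).
chat_style = {
    "model": "my-embedding-model",  # placeholder, not a real model name
    "messages": [{"role": "user", "content": "What is a pooling model?"}],
}

body_a = json.dumps(completion_style)
body_b = json.dumps(chat_style)
```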
Currently, there are bugs when enabling CUDA Graph and custom all-reduce, so they are temporarily disabled.
TODO:
Offline interface support for pooling is temporarily unavailable and will be supported in the future.