
Commit 26b8569

Merge branch 'develop' into suport_spoling
2 parents: ad2f7b6 + 5c63a08

File tree

9 files changed: +130 -5 lines changed

docs/parameters.md

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probability of each output token is returned in the content of the message. If logprob is not used, this parameter can be omitted at startup. |
+| ```logprobs_mode``` | `str` | Indicates the content returned in the logprobs. Supported modes: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Raw means the values before applying logit processors, such as bad words; processed means the values after applying such processors. |
 | ```served_model_name``` | `str` | The model name used in the API. If not specified, the model name will be the same as the --model argument |
 | ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, the default version is used. |
 | ```chat_template``` | `str` | Specifies the template used for model concatenation. It supports both string input and file path input. The default value is None; if not specified, the model's default template is used. |
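
The four modes differ along two axes: whether the values are normalized (logprobs vs. logits) and whether they are captured before or after logit processing. A minimal sketch of the distinction, assuming a hypothetical select_logprobs helper and a placeholder temperature step (neither is FastDeploy API):

import paddle
import paddle.nn.functional as F

def select_logprobs(logits: paddle.Tensor, logprobs_mode: str) -> paddle.Tensor:
    # Placeholder for the real pipeline's temperature/penalty/bad-words step.
    processed = logits / 0.7
    if logprobs_mode == "raw_logits":
        return logits.clone()                  # unnormalized, before processing
    if logprobs_mode == "raw_logprobs":
        return F.log_softmax(logits, axis=-1)  # normalized, before processing
    if logprobs_mode == "processed_logits":
        return processed                       # unnormalized, after processing
    return F.log_softmax(processed, axis=-1)   # processed_logprobs

print(select_logprobs(paddle.rand([2, 8]), "processed_logprobs").shape)  # [2, 8]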

docs/zh/parameters.md

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
 | ```dynamic_load_weight``` | `int` | Whether to load weights dynamically; default 0 |
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to return logprobs for output tokens. If logprob is not used, this parameter can be omitted at startup. |
+| ```logprobs_mode``` | `str` | Indicates the content returned in the logprobs. Supported modes: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Processed means the logprobs are computed after temperature, penalties, and bad-words processing have been applied to the logits. |
 | ```served_model_name``` | `str` | The model name used in the API. If not specified, it is the same as the --model argument |
 | ```revision``` | `str` | The Git version of the model to use when downloading automatically; can be a branch name or tag |
 | ```chat_template``` | `str` | The template used for model concatenation; supports both a string and a file path. Default None; if unspecified, the model's default template is used |

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ def __init__(
         self.max_model_len = 0
         self.dtype = "bfloat16"
         self.enable_logprob = False
+        self.logprobs_mode = "raw_logprobs"
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
         self.seed = 0
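
The config change is a one-line default on the model config. As a hypothetical illustration only (the real ModelConfig, as the diff shows, assigns the attribute without validating it), a guard against unsupported values could look like:

SUPPORTED_LOGPROBS_MODES = ("raw_logprobs", "processed_logprobs", "raw_logits", "processed_logits")

class IllustrativeModelConfig:
    """Stand-in for fastdeploy.config.ModelConfig; the validation here is assumed, not real."""

    def __init__(self, logprobs_mode: str = "raw_logprobs"):
        if logprobs_mode not in SUPPORTED_LOGPROBS_MODES:
            raise ValueError(f"unsupported logprobs_mode: {logprobs_mode!r}")
        self.logprobs_mode = logprobs_mode

print(IllustrativeModelConfig("raw_logits").logprobs_mode)  # raw_logits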

fastdeploy/engine/args_utils.py

Lines changed: 18 additions & 0 deletions
@@ -367,6 +367,15 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """

+    logprobs_mode: str = "raw_logprobs"
+    """
+    Indicates the content returned in the logprobs.
+    Supported modes:
+    1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
+    Raw means the values before applying logit processors, like bad words.
+    Processed means the values after applying such processors.
+    """
+
     seed: int = 0
     """
     Random seed to use for initialization. If not set, defaults to 0.
@@ -412,6 +421,8 @@ def __post_init__(self):
         if self.enable_logprob:
             if not current_platform.is_cuda():
                 raise NotImplementedError("Only CUDA platform supports logprob.")
+            if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
+                raise NotImplementedError("processed_logprobs is not supported with speculative decoding.")
         if self.speculative_config is not None:
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
@@ -610,6 +621,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--logprobs-mode",
+            type=str,
+            choices=["raw_logprobs", "processed_logprobs", "raw_logits", "processed_logits"],
+            default=EngineArgs.logprobs_mode,
+            help="Indicates the content returned in the logprobs.",
+        )
         model_group.add_argument(
             "--seed",
             type=int,
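
A self-contained sketch of the new CLI surface and the __post_init__ guard, using the standard library's argparse in place of FastDeploy's FlexibleArgumentParser; the speculative_config stand-in below is hypothetical:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--logprobs-mode",
    type=str,
    choices=["raw_logprobs", "processed_logprobs", "raw_logits", "processed_logits"],
    default="raw_logprobs",
    help="Indicates the content returned in the logprobs.",
)
args = parser.parse_args(["--logprobs-mode", "processed_logprobs"])

# Mirrors the guard above: processed_* modes need the post-penalty logits,
# which the speculative path does not expose.
speculative_config = None  # stand-in for EngineArgs.speculative_config
if speculative_config is not None and args.logprobs_mode.startswith("processed"):
    raise NotImplementedError("processed_logprobs is not supported with speculative decoding.")
print(args.logprobs_mode)  # processed_logprobs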

fastdeploy/engine/async_llm.py

Lines changed: 1 addition & 0 deletions
@@ -830,6 +830,7 @@ def _start_worker_service(self):
             f" --runner {self.cfg.model_config.runner}"
             f" --convert {self.cfg.model_config.convert}"
             f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
+            f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
         )

         worker_append_flag = {

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -532,6 +532,7 @@ def _start_worker_service(self):
             f" --runner {self.cfg.model_config.runner}"
             f" --convert {self.cfg.model_config.convert}"
             f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
+            f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
         )

         worker_append_flag = {
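
engine.py and async_llm.py make the same change: the mode travels to the worker process as a command-line flag rather than shared state. A minimal sketch of that hand-off (the ModelConfig stand-in and the command prefix are illustrative, not FastDeploy internals):

class ModelConfig:  # stand-in for self.cfg.model_config
    logprobs_mode = "raw_logprobs"

# The engine appends one flag per config field to the worker launch command;
# the worker's parse_args() (worker_process.py below) reads it back.
worker_cmd = "python worker_process.py" + f" --logprobs_mode {ModelConfig().logprobs_mode}"
print(worker_cmd)  # python worker_process.py --logprobs_mode raw_logprobs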

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 24 additions & 4 deletions
@@ -199,7 +199,7 @@ class Sampler(nn.Layer):
     Sampler for normal generation.
     """

-    def __init__(self, fd_config: FDConfig = None):
+    def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprobs"):
         """ """
         super().__init__()
         if (
@@ -217,6 +217,7 @@ def __init__(self, fd_config: FDConfig = None):
             raise NotImplementedError

         self.processor = SamplerProcessor()
+        self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
         # Can only be created when fd_config.early_stopper_config.enable_early_stop = True
         if (
             fd_config is not None
@@ -335,7 +336,10 @@ def forward_cuda(

         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
+            if self.logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
+            elif self.logprobs_mode == "raw_logits":
+                raw_logprobs = logits.clone()

         logits = apply_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
@@ -352,6 +356,12 @@ def forward_cuda(
             sampling_metadata.eos_token_ids,
         )

+        if num_logprobs is not None:
+            if self.logprobs_mode == "processed_logprobs":
+                raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
+            elif self.logprobs_mode == "processed_logits":
+                raw_logprobs = logits.clone()
+
         probs = F.softmax(logits)

         probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
@@ -437,6 +447,7 @@ def __init__(self, fd_config: FDConfig):
             self.forward = self.forward_cuda
         else:
             raise NotImplementedError
+        self.logprobs_mode = fd_config.model_config.logprobs_mode
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
         self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
@@ -644,7 +655,10 @@ def forward_cuda(
             share_inputs["seq_lens_encoder"],
             share_inputs["accept_num"],
         )
-        raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
+        if self.logprobs_mode == "raw_logprobs":
+            raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
+        elif self.logprobs_mode == "raw_logits":
+            raw_logprobs = target_logtis.clone()

         logprobs_tensors = None
         token_ids = share_inputs["accept_tokens"]
@@ -677,6 +691,7 @@ def __init__(self, fd_config: FDConfig):
             self.forward = self.forward_cuda
         else:
             raise NotImplementedError
+        self.logprobs_mode = fd_config.model_config.logprobs_mode

     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -808,7 +823,12 @@ def forward_cuda(
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
         if num_logprobs is not None and share_inputs["substep"] == 0:
             real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
-            raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
+            if self.logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.compute_logprobs(
+                    share_inputs["draft_logits"][:real_token_num, :], sampling_metadata
+                )
+            elif self.logprobs_mode == "raw_logits":
+                raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()

         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
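
Across all three samplers the edit follows one pattern: a two-point capture, where raw_* modes snapshot the logits before penalties and processed_* modes snapshot them after. A condensed sketch, assuming compute_logprobs reduces to a log-softmax (the test baseline below makes the same assumption) and with apply_penalties standing in for apply_penalty_multi_scores:

import paddle
import paddle.nn.functional as F

def capture_logprobs(logits: paddle.Tensor, logprobs_mode: str, apply_penalties):
    raw_logprobs = None
    # Capture point 1: before penalties, temperature, and bad words.
    if logprobs_mode == "raw_logprobs":
        raw_logprobs = F.log_softmax(logits, axis=-1)
    elif logprobs_mode == "raw_logits":
        raw_logprobs = logits.clone()

    logits = apply_penalties(logits)

    # Capture point 2: after processing.
    if logprobs_mode == "processed_logprobs":
        raw_logprobs = F.log_softmax(logits, axis=-1)
    elif logprobs_mode == "processed_logits":
        raw_logprobs = logits.clone()
    return logits, raw_logprobs

_, out = capture_logprobs(paddle.rand([2, 8]), "processed_logits", lambda x: x / 0.7)
print(out.shape)  # [2, 8]

Note that the speculative samplers wire up only the raw_* branches, consistent with the __post_init__ guard that rejects processed_* modes under speculative decoding.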

fastdeploy/worker/worker_process.py

Lines changed: 6 additions & 0 deletions
@@ -614,6 +614,12 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--logprobs_mode",
+        type=str,
+        default="raw_logprobs",
+        help="Indicates the content returned in the logprobs.",
+    )
     parser.add_argument(
         "--reasoning_parser",
         type=str,

tests/layers/test_sampler.py

Lines changed: 77 additions & 1 deletion
@@ -15,13 +15,14 @@
 """

 import paddle
+import paddle.nn.functional as F

 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler


 def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
-    fake_logits = paddle.full(shape=[batch_size, vocab_size], fill_value=1e-2, dtype="float32")
+    fake_logits = paddle.rand(shape=[batch_size, vocab_size], dtype="float32")
     return fake_logits

@@ -41,6 +42,7 @@ def _create_default_sampling_metadata(
     batch_size: int,
     min_seq_len: int,
     max_seq_len: int,
+    max_num_logprobs: int = None,
 ) -> SamplingMetadata:

     fake_sampling_metadata = SamplingMetadata(
@@ -59,6 +61,8 @@ def _create_default_sampling_metadata(
         min_p=paddle.randn([batch_size]),
         seed=paddle.to_tensor([[2025]]),
     )
+    if max_num_logprobs is not None:
+        fake_sampling_metadata.max_num_logprobs = max_num_logprobs
     return fake_sampling_metadata

@@ -75,5 +79,77 @@ def test_sampler():
     print(next_tokens)


+def get_baseline_logprobs(logits, sampling_metadata, logprobs_mode, token_ids):
+    if logprobs_mode == "raw_logprobs":
+        logprobs = F.log_softmax(logits, axis=-1)
+    elif logprobs_mode == "raw_logits":
+        logprobs = logits.clone()
+    elif logprobs_mode == "processed_logprobs":
+        from fastdeploy.model_executor.layers.sample.ops import (
+            apply_penalty_multi_scores,
+        )
+
+        logits = apply_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            sampling_metadata.prompt_ids,
+            sampling_metadata.prompt_lens,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+        )
+        logprobs = F.log_softmax(logits, axis=-1)
+    else:
+        from fastdeploy.model_executor.layers.sample.ops import (
+            apply_penalty_multi_scores,
+        )
+
+        logits = apply_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            sampling_metadata.prompt_ids,
+            sampling_metadata.prompt_lens,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+        )
+        logprobs = logits
+    token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
+    return token_logprobs
+
+
+def test_sampler_logprobs():
+    batch_size = 32
+    vocab_size = 1024
+    min_seq_len = 1
+    max_seq_len = 1024
+    logprobs_mode_list = ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]
+    logits = _create_fake_logits(batch_size, vocab_size)
+    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
+    for logprobs_mode in logprobs_mode_list:
+        sampler = Sampler(logprobs_mode=logprobs_mode)
+        sampler_output = sampler(logits.clone(), sampling_metadata)
+        baseline_logprobs = get_baseline_logprobs(
+            logits.clone(), sampling_metadata, logprobs_mode=logprobs_mode, token_ids=sampler_output.sampled_token_ids
+        )
+        logprobs = sampler_output.logprobs_tensors.logprobs
+        print(f"baseline_logprobs = {baseline_logprobs}")
+        print(f"logprobs = {logprobs}")
+        equal = paddle.allclose(baseline_logprobs, logprobs, atol=1e-03, rtol=1e-03).item()
+        print(f"logprobs_mode: {logprobs_mode} equal={equal}")
+        assert equal
+
+
 if __name__ == "__main__":
     test_sampler()
+    test_sampler_logprobs()
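
The baseline's final step gathers one value per row with paddle.take_along_axis: for each batch element it picks the logprob at the sampled token's index. A small worked example of that call:

import paddle

logprobs = paddle.to_tensor([[0.1, 0.7, 0.2],
                             [0.5, 0.3, 0.2]])
token_ids = paddle.to_tensor([[1], [0]])  # one sampled token id per row
picked = paddle.take_along_axis(logprobs, token_ids, axis=-1)
print(picked)  # [[0.7], [0.5]]

Running python tests/layers/test_sampler.py executes both test_sampler and the new test_sampler_logprobs via the __main__ block.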
