61 changes: 28 additions & 33 deletions lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py
@@ -77,22 +77,8 @@ def extract_reasoning_content_streaming(
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
         else:
-            # No <think> in previous or delta, also need to check for </think>.
-            # Because the model may have generated </think> without <think>
-            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.think_end_token in delta_text:
-                # </think> in delta with more tokens,
-                # extract reasoning content and content
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            elif self.think_end_token in previous_text:
-                # </think> in previous, thinking content ends
-                return DeltaMessage(content=delta_text)
-            else:
-                # no </think> in previous or delta, reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
+            # no <think> in previous or delta, all content
+            return DeltaMessage(content=delta_text)
RunningLeon (Collaborator), Jul 21, 2025:
what if the model does not have reasoning ability, but the output becomes reasoning_content?

ywx217 (Contributor, Author), Jul 21, 2025:
Any specific model name and model output for this case?

ywx217 (Contributor, Author):
When the model does not have reasoning ability, the output should be normal content, not reasoning.

Collaborator:
Hi, when enable_thinking=False, the parser is disabled in the main branch:

    if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:


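For context, a rough sketch of what the quoted guard implies (illustration only, not lmdeploy's actual serving loop; the surrounding variable names, including final_text, are assumptions):

    # Simplified sketch: when enable_thinking=False the reasoning parser is
    # never invoked, so a non-reasoning model's output cannot be routed into
    # reasoning_content. final_text stands in for the generated completion.
    if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
        reasoning_content, content = VariableInterface.reasoning_parser.extract_reasoning_content(final_text, request)
    else:
        reasoning_content, content = None, final_text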
     def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
                                   **kwargs) -> Tuple[Optional[str], Optional[str]]:

@@ -109,26 +95,35 @@
             reasoning_content (str | None): The reasoning content.
             final_output (str | None): The content.
         """
-        # DeepSeek R1 doesn't generate <think> now.
-        # Thus we assume the reasoning content is always at the start.
-        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.think_end_token not in model_output:
-            return None, model_output
-        # Add a start token if it's missing to keep compatibility.
-        if self.think_start_token not in model_output:
-            model_output = f'{self.think_start_token}{model_output}'
-        # Use a regex to find the reasoning content
-        reasoning_content = self.reasoning_regex.findall(model_output)[0]
-
-        end_index = len(f'{self.think_start_token}{reasoning_content}{self.think_end_token}')
-        final_output = model_output[end_index:]
-        if reasoning_content.startswith('\n'):
-            reasoning_content = reasoning_content[1:]
-        if reasoning_content.endswith('\n'):
-            reasoning_content = reasoning_content[:-1]
+        start_index = model_output.find(self.think_start_token)
+        end_index = model_output.find(self.think_end_token)
+        if end_index < 0:
+            # for qwen3 model, the reasoning content is wrapped by <think> </think> xml tags
+            if start_index < 0:
+                return None, model_output
+            reasoning_content = model_output[start_index + len(self.think_start_token):]
+            reasoning_content = self._trim_newlines(reasoning_content)
+            return reasoning_content, None
+
+        if start_index >= 0 and start_index < end_index:
+            reasoning_content = model_output[start_index + len(self.think_start_token):end_index]
+        else:
+            reasoning_content = model_output[:end_index]
+        reasoning_content = self._trim_newlines(reasoning_content)
+
+        final_output = model_output[end_index + len(self.think_end_token):]
+        final_output = self._trim_newlines(final_output)

         if len(final_output) == 0:
             return reasoning_content, None

         return reasoning_content, final_output
+    @classmethod
+    def _trim_newlines(cls, text: str):
Collaborator:
why perform _trim_newlines?

ywx217 (Contributor, Author):
The reasoning block is generated as <think>\n{reasoning_content}\n</think>, so the trim removes the \n before and after reasoning_content.

"""Trim newlines from the start and end of a string."""
while text.startswith('\n'):
text = text[1:]
while text.endswith('\n'):
text = text[:-1]
return text
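
To make the new extract_reasoning_content behavior concrete, an illustrative sketch (not part of the PR; import paths follow the file paths in this diff, DummyTokenizer comes from the test helpers below, and '<think>'/'</think>' are assumed as the think tokens):

    from lmdeploy.serve.openai.protocol import ChatCompletionRequest
    from lmdeploy.serve.openai.reasoning_parser.qwen_qwq_reasoning_parser import QwenQwQReasoningParser

    parser = QwenQwQReasoningParser(tokenizer=DummyTokenizer())
    req = ChatCompletionRequest(model='qwen', messages=[])

    # Complete think block: reasoning and content are split, newlines trimmed.
    assert parser.extract_reasoning_content('<think>\nplan\n</think>\nanswer', req) == ('plan', 'answer')

    # Truncated generation, <think> never closed: everything after <think>
    # becomes reasoning content and there is no final content yet.
    assert parser.extract_reasoning_content('<think>\nOK, user sends', req) == ('OK, user sends', None)

    # </think> without <think> (DeepSeek-R1-style output): the prefix is reasoning.
    assert parser.extract_reasoning_content('plan\n</think>\nanswer', req) == ('plan', 'answer')

    # No think tokens at all: the whole output is ordinary content.
    assert parser.extract_reasoning_content('plain answer', req) == (None, 'plain answer')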
45 changes: 45 additions & 0 deletions tests/test_lmdeploy/test_qwen3_parser.py
@@ -358,3 +358,48 @@ def test_no_think_nonstream():
     first_message = resp.choices[0].message
     assert first_message.content == '你好呀!✨ 很高兴见到你!'
     assert first_message.reasoning_content is None
+
+
+THINK_START_SEQUENCE = ['<think>', '\n']
+TRUNCATED_SEQUENCE = ['OK', ', ', 'user', ' ', 'sends']
+
+
+@pytest.mark.parametrize(
+    'sequence, expected_content, expected_reasoning_content',
+    [
+        # without think start token
+        (TRUNCATED_SEQUENCE, ''.join(TRUNCATED_SEQUENCE), None),
+        # with think start token
+        (THINK_START_SEQUENCE + TRUNCATED_SEQUENCE, None, ''.join(TRUNCATED_SEQUENCE)),
+    ])
+def test_truncated_think_nonstream(sequence, expected_content, expected_reasoning_content):
+
+    tokenizer = DummyTokenizer()
+    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)
+    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)
+    req = ChatCompletionRequest(model='qwen', messages=[], stream=False)
+    resp: ChatCompletionResponse = _chat_completion_v1(req, sequence)
+
+    assert len(resp.choices) == 1
+    first_message = resp.choices[0].message
+    assert first_message.content == expected_content
+    assert first_message.reasoning_content == expected_reasoning_content
+
+
+@pytest.mark.parametrize(
+    'sequence, expected_content, expected_reasoning_content',
+    [
+        # without think start token
+        (TRUNCATED_SEQUENCE, ''.join(TRUNCATED_SEQUENCE), ''),
+        # with think start token
+        (THINK_START_SEQUENCE + TRUNCATED_SEQUENCE, '', ''.join(TRUNCATED_SEQUENCE)),
+    ])
+def test_truncated_think_stream(sequence, expected_content, expected_reasoning_content):
+    tokenizer = DummyTokenizer()
+    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)
+    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)
+    req = ChatCompletionRequest(model='qwen', messages=[], stream=True)
+    content, reasoning_content, tool_calls = _stream_parse(req, sequence)
+
+    assert content == expected_content
+    assert reasoning_content.lstrip() == expected_reasoning_content