Skip to content

Commit c4d3366

Browse files
committed
refactor!: replace reasoning trace extraction with LangChain additional_kwargs
BREAKING CHANGE: Remove reasoning_config and apply_to_reasoning_traces config options - Remove ReasoningModelConfig, ParsedTaskOutput, and related extraction logic - Replace with automatic extraction from response.additional_kwargs['reasoning_content'] - Remove 1,085 lines of test code and related infrastructure - Simplify parse_task_output() to return str directly - Update DeepSeek-R1 and Nemotron example configs Reasoning traces are now automatically extracted by LangChain and stored in additional_kwargs, eliminating the need for token-based parsing and complex configuration options.
1 parent 89225dc commit c4d3366

File tree

13 files changed

+20
-1906
lines changed

13 files changed

+20
-1906
lines changed

examples/configs/nemotron/config.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,3 @@ models:
22
- type: main
33
engine: nim
44
model: nvidia/llama-3.1-nemotron-ultra-253b-v1
5-
reasoning_config:
6-
remove_reasoning_traces: False # Set True to remove traces from the internal tasks

nemoguardrails/actions/llm/generation.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
5858
from nemoguardrails.kb.kb import KnowledgeBase
5959
from nemoguardrails.llm.prompts import get_prompt
60-
from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
60+
from nemoguardrails.llm.taskmanager import LLMTaskManager
6161
from nemoguardrails.llm.types import Task
6262
from nemoguardrails.logging.explain import LLMCallInfo
6363
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
@@ -496,7 +496,6 @@ async def generate_user_intent(
496496
result = self.llm_task_manager.parse_task_output(
497497
Task.GENERATE_USER_INTENT, output=result
498498
)
499-
result = result.text
500499

501500
user_intent = get_first_nonempty_line(result)
502501
if user_intent is None:
@@ -594,10 +593,6 @@ async def generate_user_intent(
594593
Task.GENERAL, output=text
595594
)
596595

597-
text = _process_parsed_output(
598-
text, self._include_reasoning_traces()
599-
)
600-
601596
else:
602597
# Initialize the LLMCallInfo object
603598
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -639,8 +634,6 @@ async def generate_user_intent(
639634
text = self.llm_task_manager.parse_task_output(
640635
Task.GENERAL, output=result
641636
)
642-
643-
text = _process_parsed_output(text, self._include_reasoning_traces())
644637
text = text.strip()
645638
if text.startswith('"'):
646639
text = text[1:-1]
@@ -750,7 +743,6 @@ async def generate_next_step(
750743
result = self.llm_task_manager.parse_task_output(
751744
Task.GENERATE_NEXT_STEPS, output=result
752745
)
753-
result = result.text
754746

755747
# If we don't have multi-step generation enabled, we only look at the first line.
756748
if not self.config.enable_multi_step_generation:
@@ -1036,10 +1028,6 @@ async def generate_bot_message(
10361028
Task.GENERAL, output=result
10371029
)
10381030

1039-
result = _process_parsed_output(
1040-
result, self._include_reasoning_traces()
1041-
)
1042-
10431031
log.info(
10441032
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
10451033
time() - t0,
@@ -1111,10 +1099,6 @@ async def generate_bot_message(
11111099
Task.GENERATE_BOT_MESSAGE, output=result
11121100
)
11131101

1114-
result = _process_parsed_output(
1115-
result, self._include_reasoning_traces()
1116-
)
1117-
11181102
# TODO: catch openai.error.InvalidRequestError from exceeding max token length
11191103

11201104
result = get_multiline_response(result)
@@ -1212,7 +1196,6 @@ async def generate_value(
12121196
result = self.llm_task_manager.parse_task_output(
12131197
Task.GENERATE_VALUE, output=result
12141198
)
1215-
result = result.text
12161199

12171200
# We only use the first line for now
12181201
# TODO: support multi-line values?
@@ -1433,7 +1416,6 @@ async def generate_intent_steps_message(
14331416
result = self.llm_task_manager.parse_task_output(
14341417
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
14351418
)
1436-
result = result.text
14371419

14381420
# TODO: Implement logic for generating more complex Colang next steps (multi-step),
14391421
# not just a single bot intent.
@@ -1516,7 +1498,6 @@ async def generate_intent_steps_message(
15161498
result = self.llm_task_manager.parse_task_output(
15171499
Task.GENERAL, output=result
15181500
)
1519-
result = _process_parsed_output(result, self._include_reasoning_traces())
15201501
text = result.strip()
15211502
if text.startswith('"'):
15221503
text = text[1:-1]
@@ -1529,10 +1510,6 @@ async def generate_intent_steps_message(
15291510
events=[new_event_dict("BotMessage", text=text)],
15301511
)
15311512

1532-
def _include_reasoning_traces(self) -> bool:
1533-
"""Get the configuration value for whether to include reasoning traces in output."""
1534-
return _get_apply_to_reasoning_traces(self.config)
1535-
15361513

15371514
def clean_utterance_content(utterance: str) -> str:
15381515
"""
@@ -1550,27 +1527,3 @@ def clean_utterance_content(utterance: str) -> str:
15501527
# It should be translated to an actual \n character.
15511528
utterance = utterance.replace("\\n", "\n")
15521529
return utterance
1553-
1554-
1555-
def _record_reasoning_trace(trace: str) -> None:
1556-
"""Store the reasoning trace in context for later retrieval."""
1557-
reasoning_trace_var.set(trace)
1558-
1559-
1560-
def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
1561-
"""Combine trace and text if requested, otherwise just return text."""
1562-
return (trace + text) if (trace and include_reasoning) else text
1563-
1564-
1565-
def _process_parsed_output(
1566-
output: ParsedTaskOutput, include_reasoning_trace: bool
1567-
) -> str:
1568-
"""Record trace, then assemble the final LLM response."""
1569-
if reasoning_trace := output.reasoning_trace:
1570-
_record_reasoning_trace(reasoning_trace)
1571-
return _assemble_response(output.text, reasoning_trace, include_reasoning_trace)
1572-
1573-
1574-
def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
1575-
"""Get the configuration value for whether to include reasoning traces in output."""
1576-
return config.rails.output.apply_to_reasoning_traces

nemoguardrails/actions/llm/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ def _store_tool_calls(response) -> None:
179179

180180

181181
def _store_response_metadata(response) -> None:
182-
"""Store response metadata excluding content for metadata preservation."""
182+
"""Store response metadata excluding content for metadata preservation.
183+
184+
Also extracts reasoning content from additional_kwargs if available from LangChain.
185+
"""
183186
if hasattr(response, "model_fields"):
184187
metadata = {}
185188
for field_name in response.model_fields:
@@ -188,6 +191,16 @@ def _store_response_metadata(response) -> None:
188191
): # Exclude content since it may be modified by rails
189192
metadata[field_name] = getattr(response, field_name)
190193
llm_response_metadata_var.set(metadata)
194+
195+
if hasattr(response, "additional_kwargs"):
196+
additional_kwargs = response.additional_kwargs
197+
if (
198+
isinstance(additional_kwargs, dict)
199+
and "reasoning_content" in additional_kwargs
200+
):
201+
reasoning_content = additional_kwargs["reasoning_content"]
202+
if reasoning_content:
203+
reasoning_trace_var.set(reasoning_content)
191204
else:
192205
llm_response_metadata_var.set(None)
193206

nemoguardrails/actions/v2_x/generation.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,6 @@ async def generate_user_intent( # pyright: ignore (TODO - Signature completely
318318
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
319319
)
320320

321-
result = result.text
322-
323321
user_intent = get_first_nonempty_line(result)
324322
# GPT-4o often adds 'user intent: ' in front
325323
if user_intent and ":" in user_intent:
@@ -401,8 +399,6 @@ async def generate_user_intent_and_bot_action(
401399
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
402400
)
403401

404-
result = result.text
405-
406402
user_intent = get_first_nonempty_line(result)
407403

408404
if user_intent and ":" in user_intent:
@@ -578,8 +574,6 @@ async def generate_flow_from_instructions(
578574
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
579575
)
580576

581-
result = result.text
582-
583577
# TODO: why this is not part of a filter or output_parser?
584578
#
585579
lines = _remove_leading_empty_lines(result).split("\n")
@@ -660,8 +654,6 @@ async def generate_flow_from_name(
660654
task=Task.GENERATE_FLOW_FROM_NAME, output=result
661655
)
662656

663-
result = result.text
664-
665657
lines = _remove_leading_empty_lines(result).split("\n")
666658

667659
if lines[0].startswith("flow"):
@@ -736,8 +728,6 @@ async def generate_flow_continuation(
736728
task=Task.GENERATE_FLOW_CONTINUATION, output=result
737729
)
738730

739-
result = result.text
740-
741731
lines = _remove_leading_empty_lines(result).split("\n")
742732

743733
if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -869,8 +859,6 @@ async def generate_value( # pyright: ignore (TODO - different arguments to base
869859
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
870860
)
871861

872-
result = result.text
873-
874862
# We only use the first line for now
875863
# TODO: support multi-line values?
876864
value = result.strip().split("\n")[0]
@@ -994,8 +982,6 @@ async def generate_flow(
994982
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
995983
)
996984

997-
result = result.text
998-
999985
result = _remove_leading_empty_lines(result)
1000986
lines = result.split("\n")
1001987
if "codeblock" in lines[0]:

nemoguardrails/llm/filters.py

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,6 @@
2424
)
2525

2626

27-
@dataclass
28-
class ReasoningExtractionResult:
29-
"""
30-
Holds cleaned response text and optional chain-of-thought reasoning trace extracted from LLM output.
31-
"""
32-
33-
text: str
34-
reasoning_trace: Optional[str] = None
35-
36-
3727
def colang(events: List[dict]) -> str:
3828
"""Filter that turns an array of events into a colang history."""
3929
return get_colang_history(events)
@@ -448,100 +438,3 @@ def conversation_to_events(conversation: List) -> List[dict]:
448438
)
449439

450440
return events
451-
452-
453-
def _find_token_positions_for_removal(
454-
response: str, start_token: Optional[str], end_token: Optional[str]
455-
) -> Tuple[int, int]:
456-
"""Helper function to find token positions specifically for text removal.
457-
458-
This is useful, for example, to remove reasoning traces from a reasoning LLM response.
459-
460-
This is optimized for the removal use case:
461-
1. Uses find() for first start token
462-
2. Uses rfind() for last end token
463-
3. Sets start_index to 0 if start token is missing
464-
465-
Args:
466-
response(str): The text to search in
467-
start_token(str): The token marking the start of text to remove
468-
end_token(str): The token marking the end of text to remove
469-
470-
Returns:
471-
A tuple of (start_index, end_index) marking the span to remove;
472-
both indices are -1 if start_token and end_token are not provided.
473-
"""
474-
if not start_token or not end_token:
475-
return -1, -1
476-
477-
start_index = response.find(start_token)
478-
# if the start index is missing, this is probably a continuation of a bot message
479-
# started in the prompt.
480-
if start_index == -1:
481-
start_index = 0
482-
483-
end_index = response.rfind(end_token)
484-
485-
return start_index, end_index
486-
487-
488-
def find_reasoning_tokens_position(
489-
response: str, start_token: Optional[str], end_token: Optional[str]
490-
) -> Tuple[int, int]:
491-
"""Finds the positions of the first start token and the last end token.
492-
493-
This is intended to find the outermost boundaries of potential
494-
reasoning sections, typically for removal.
495-
496-
Args:
497-
response(str): The text to search in.
498-
start_token(Optional[str]): The token marking the start of reasoning.
499-
end_token(Optional[str]): The token marking the end of reasoning.
500-
501-
Returns:
502-
A tuple (start_index, end_index).
503-
- start_index: Position of the first `start_token`, or 0 if not found.
504-
- end_index: Position of the last `end_token`, or -1 if not found.
505-
"""
506-
507-
return _find_token_positions_for_removal(response, start_token, end_token)
508-
509-
510-
def extract_and_strip_trace(
511-
response: str, start_token: str, end_token: str
512-
) -> ReasoningExtractionResult:
513-
"""Extracts and removes reasoning traces from the given text.
514-
515-
This function identifies reasoning traces in the text that are marked
516-
by specific start and end tokens. It extracts these traces, removes
517-
them from the original text, and returns both the cleaned text and
518-
the extracted reasoning trace.
519-
520-
Args:
521-
response (str): The text to process.
522-
start_token (str): The token marking the start of a reasoning trace.
523-
end_token (str): The token marking the end of a reasoning trace.
524-
525-
Returns:
526-
ReasoningExtractionResult: An object containing the cleaned text
527-
without reasoning traces and the extracted reasoning trace, if any.
528-
"""
529-
530-
start_index, end_index = find_reasoning_tokens_position(
531-
response, start_token, end_token
532-
)
533-
# handles invalid/empty tokens returned as (-1, -1)
534-
if start_index == -1 and end_index == -1:
535-
return ReasoningExtractionResult(text=response, reasoning_trace=None)
536-
# end token is missing
537-
if end_index == -1:
538-
return ReasoningExtractionResult(text=response, reasoning_trace=None)
539-
# extract if tokens are present and start < end
540-
if start_index < end_index:
541-
reasoning_trace = response[start_index : end_index + len(end_token)]
542-
cleaned_text = response[:start_index] + response[end_index + len(end_token) :]
543-
return ReasoningExtractionResult(
544-
text=cleaned_text, reasoning_trace=reasoning_trace
545-
)
546-
547-
return ReasoningExtractionResult(text=response, reasoning_trace=None)

0 commit comments

Comments
 (0)