Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,28 +192,20 @@ impl HarmonyEncoding {
let render_options = RenderOptions {
conversation_has_function_tools: has_function_tools,
};
let last_assistant_is_final = messages
.iter()
.rev()
.find_map(|msg| {
(msg.author.role == Role::Assistant)
.then(|| msg.channel.as_deref() == Some("final"))
})
.unwrap_or(false);

let should_drop_analysis =
config.is_some_and(|c| c.auto_drop_analysis && last_assistant_is_final);
config.is_some_and(|c| c.auto_drop_analysis);

let first_final_idx = messages
let last_final_idx = messages
.iter()
.position(|msg| msg.channel.as_deref() == Some("final"));
.rposition(|msg| msg.channel.as_deref() == Some("final"));

let result = messages
.iter()
.enumerate()
.filter(|(idx, msg)| {
!(should_drop_analysis
&& first_final_idx.is_some_and(|first| *idx < first)
&& last_final_idx.is_some_and(|last| *idx < last)
&& msg.channel.as_deref() == Some("analysis"))
})
.try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
Expand Down
71 changes: 71 additions & 0 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,3 +886,74 @@ fn test_parse_completion_with_invalid_content_token_errors_on_eos() {
.with_channel("analysis");
assert_eq!(parsed_message, &expected_message);
}

#[test]
fn test_multi_turn_auto_drop_analysis() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output = load_test_data("../test-data/test_multi_turn_auto_drop_analysis.txt");

let convo = Conversation::from_messages([
Message::from_role_and_content(
Role::Developer,
DeveloperContent::new().with_instructions(
"You are a helpful assistant that analyzes code and provides detailed feedback.",
),
),
Message::from_role_and_content(
Role::User,
"Can you help me optimize this Python function?\n\n\
def fibonacci(n):\n\
if n <= 1:\n\
return n\n\
return fibonacci(n-1) + fibonacci(n-2)",
),
// Turn 1: analysis + final
Message::from_role_and_content(
Role::Assistant,
"The user provided a recursive Fibonacci implementation. O(2^n) complexity.",
)
.with_channel("analysis"),
Message::from_role_and_content(
Role::Assistant,
"This recursive Fibonacci has exponential time complexity. \
Would you like me to show optimized versions?",
)
.with_channel("final"),
Message::from_role_and_content(Role::User, "Yes, and benchmark them"),
// Turn 2: analysis + final
Message::from_role_and_content(
Role::Assistant,
"User wants benchmarks. I should run some Python code to compare performance.",
)
.with_channel("analysis"),
Message::from_role_and_content(Role::Assistant, "I'll benchmark both versions for you.")
.with_channel("final"),
Message::from_role_and_content(Role::User, "Run the benchmark for n=30"),
// Turn 3: analysis + tool call (no final after)
Message::from_role_and_content(
Role::Assistant,
"I need to execute Python code to run the benchmark for n=30.",
)
.with_channel("analysis"),
Message::from_role_and_content(
Role::Assistant,
r#"{"code": "import timeit\n\ndef fib_recursive(n):\n if n <= 1: return n\n return fib_recursive(n-1) + fib_recursive(n-2)\n\ndef fib_iter(n):\n if n <= 1: return n\n a, b = 0, 1\n for _ in range(2, n+1): a, b = b, a+b\n return b\n\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}"#,
)
.with_channel("commentary")
.with_recipient("functions.python")
.with_content_type("json"),
]);

let tokens = encoding
.render_conversation_for_completion(
&convo,
Role::Assistant,
Some(&crate::encoding::RenderConversationConfig {
auto_drop_analysis: true,
}),
)
.unwrap();

let decoded = encoding.tokenizer.decode_utf8(&tokens).unwrap();
assert_eq!(decoded, expected_output);
}
6 changes: 6 additions & 0 deletions test-data/test_multi_turn_auto_drop_analysis.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<|start|>developer<|message|>You are a helpful assistant that analyzes code and provides detailed feedback.<|end|><|start|>user<|message|>Can you help me optimize this Python function?

def fibonacci(n):
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)<|end|><|start|>assistant<|channel|>final<|message|>This recursive Fibonacci has exponential time complexity. Would you like me to show optimized versions?<|end|><|start|>user<|message|>Yes, and benchmark them<|end|><|start|>assistant<|channel|>final<|message|>I'll benchmark both versions for you.<|end|><|start|>user<|message|>Run the benchmark for n=30<|end|><|start|>assistant<|channel|>analysis<|message|>I need to execute Python code to run the benchmark for n=30.<|end|><|start|>assistant to=functions.python<|channel|>commentary<|message|>{"code": "import timeit\n\ndef fib_recursive(n):\n if n <= 1: return n\n return fib_recursive(n-1) + fib_recursive(n-2)\n\ndef fib_iter(n):\n if n <= 1: return n\n a, b = 0, 1\n for _ in range(2, n+1): a, b = b, a+b\n return b\n\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}<|call|><|start|>assistant
74 changes: 74 additions & 0 deletions tests/test_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,3 +1244,77 @@ def test_streamable_parser_tricky_utf8_decoding():

# Ensure if we're accumulating content deltas we still get the full utf-8 text
assert "".join(content_deltas) == tricky_utf8_text


def test_multi_turn_auto_drop_analysis():
"""
In multi-turn conversations with auto_drop_analysis=True,
all analysis messages before the last final message should be dropped.

This test ensures that we use last_final_idx instead of first_final_idx
when determining which analysis messages to drop.
"""
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

expected_output = (
(ROOT_DIR / "test-data" / "test_multi_turn_auto_drop_analysis.txt")
.read_text(encoding="utf-8")
.rstrip()
)

convo = Conversation.from_messages(
[
Message.from_role_and_content(
Role.DEVELOPER,
DeveloperContent.new().with_instructions(
"You are a helpful assistant that analyzes code and provides detailed feedback."
),
),
Message.from_role_and_content(
Role.USER,
"Can you help me optimize this Python function?\n\n"
"def fibonacci(n):\n"
" if n <= 1:\n"
" return n\n"
" return fibonacci(n-1) + fibonacci(n-2)",
),
# Turn 1: analysis + final
Message.from_role_and_content(
Role.ASSISTANT,
"The user provided a recursive Fibonacci implementation. O(2^n) complexity.",
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
"This recursive Fibonacci has exponential time complexity. "
"Would you like me to show optimized versions?",
).with_channel("final"),
Message.from_role_and_content(Role.USER, "Yes, and benchmark them"),
# Turn 2: analysis + final
Message.from_role_and_content(
Role.ASSISTANT,
"User wants benchmarks. I should run some Python code to compare performance.",
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'll benchmark both versions for you."
).with_channel("final"),
Message.from_role_and_content(Role.USER, "Run the benchmark for n=30"),
# Turn 3: analysis + tool call (no final after)
Message.from_role_and_content(
Role.ASSISTANT,
"I need to execute Python code to run the benchmark for n=30.",
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"code": "import timeit\\n\\ndef fib_recursive(n):\\n if n <= 1: return n\\n return fib_recursive(n-1) + fib_recursive(n-2)\\n\\ndef fib_iter(n):\\n if n <= 1: return n\\n a, b = 0, 1\\n for _ in range(2, n+1): a, b = b, a+b\\n return b\\n\\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}',
)
.with_channel("commentary")
.with_recipient("functions.python")
.with_content_type("json"),
]
)

tokens = encoding.render_conversation_for_completion(
convo, Role.ASSISTANT, RenderConversationConfig(auto_drop_analysis=True)
)

assert encoding.decode_utf8(tokens) == expected_output