Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agent/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,7 @@ mod tests {
Ok(ToolCompletionResponse {
content: None,
tool_calls: vec![ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
id: crate::llm::generate_tool_call_id(0, 0),
name: "echo".to_string(),
arguments: serde_json::json!({"message": "looping"}),
}],
Expand Down Expand Up @@ -1997,7 +1997,7 @@ mod tests {
Ok(ToolCompletionResponse {
content: None,
tool_calls: vec![ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
id: crate::llm::generate_tool_call_id(0, 0),
name: "nonexistent_tool".to_string(),
arguments: serde_json::json!({}),
}],
Expand Down
30 changes: 22 additions & 8 deletions src/agent/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::channels::web::util::truncate_preview;
use crate::llm::{ChatMessage, ToolCall};
use crate::llm::{ChatMessage, ToolCall, generate_tool_call_id};

/// A session containing one or more threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -352,7 +352,12 @@ impl Thread {
/// completed actions in subsequent turns.
pub fn messages(&self) -> Vec<ChatMessage> {
let mut messages = Vec::new();
for turn in &self.turns {
// We use the enumeration index (`turn_idx`) rather than `turn.turn_number`
// intentionally: after `truncate_turns()`, the remaining turns are
// re-numbered starting from 0, so the enumeration index and turn_number
// are equivalent. Using the index avoids coupling to the field and keeps
// tool-call ID generation deterministic for the current message window.
for (turn_idx, turn) in self.turns.iter().enumerate() {
if turn.image_content_parts.is_empty() {
messages.push(ChatMessage::user(&turn.user_input));
} else {
Expand All @@ -363,13 +368,23 @@ impl Thread {
}

if !turn.tool_calls.is_empty() {
// Build ToolCall objects with synthetic stable IDs
let tool_calls: Vec<ToolCall> = turn
// Assign synthetic call IDs for this turn's tool calls, so that
// declarations and results can be consistently correlated.
let tool_calls_with_ids: Vec<(String, &_)> = turn
.tool_calls
.iter()
.enumerate()
.map(|(i, tc)| ToolCall {
id: format!("turn{}_{}", turn.turn_number, i),
.map(|(tc_idx, tc)| {
// Use provider-compatible tool call IDs derived from turn/tool indices.
(generate_tool_call_id(turn_idx, tc_idx), tc)
})
.collect();

// Build ToolCall objects using the synthetic call IDs.
let tool_calls: Vec<ToolCall> = tool_calls_with_ids
.iter()
.map(|(call_id, tc)| ToolCall {
id: call_id.clone(),
name: tc.name.clone(),
arguments: tc.parameters.clone(),
})
Expand All @@ -379,8 +394,7 @@ impl Thread {
messages.push(ChatMessage::assistant_with_tool_calls(None, tool_calls));

// Individual tool result messages, truncated to limit context size.
for (i, tc) in turn.tool_calls.iter().enumerate() {
let call_id = format!("turn{}_{}", turn.turn_number, i);
for (call_id, tc) in tool_calls_with_ids {
let content = if let Some(ref err) = tc.error {
// .error already contains the full error text;
// pass through without wrapping to avoid double-prefix.
Expand Down
2 changes: 1 addition & 1 deletion src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub use nearai_chat::{ModelInfo, NearAiChatProvider};
pub use provider::{
ChatMessage, CompletionRequest, CompletionResponse, ContentPart, FinishReason, ImageUrl,
LlmProvider, ModelMetadata, Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse,
ToolDefinition, ToolResult,
ToolDefinition, ToolResult, generate_tool_call_id,
};
pub use reasoning::{
ActionPlan, Reasoning, ReasoningContext, RespondOutput, RespondResult, SILENT_REPLY_TOKEN,
Expand Down
97 changes: 97 additions & 0 deletions src/llm/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,32 @@ pub struct ToolCall {
pub arguments: serde_json::Value,
}

/// Generate a tool-call ID that satisfies all providers.
///
/// Mistral requires exactly 9 alphanumeric characters (`[a-zA-Z0-9]{9}`).
/// Other providers accept any non-empty string. By default we produce a
/// 9-char base-62 string derived from two seed values so the ID is both
/// deterministic (for replayed history) and provider-compatible.
Comment on lines +236 to +241
pub fn generate_tool_call_id(seed_a: usize, seed_b: usize) -> String {
// Mix the two seeds into a single u64 using a simple hash-like combine.
let combined = (seed_a as u64)
.wrapping_mul(6364136223846793005)
.wrapping_add(seed_b as u64);
// Format as 9-char zero-padded base-62 (0-9, a-z, A-Z).
let mut buf = [b'0'; 9];
let mut val = combined;
for b in buf.iter_mut().rev() {
let digit = (val % 62) as u8;
*b = match digit {
0..=9 => b'0' + digit,
10..=35 => b'a' + (digit - 10),
_ => b'A' + (digit - 36),
};
val /= 62;
}
buf.iter().map(|&b| b as char).collect::<String>()
}

/// Result of a tool execution to send back to the LLM.
#[derive(Debug, Clone)]
pub struct ToolResult {
Expand Down Expand Up @@ -533,6 +559,77 @@ pub fn strip_unsupported_tool_params(
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;

#[test]
fn generate_tool_call_id_has_valid_format() {
let samples = [
(0usize, 0usize),
(1usize, 2usize),
(42usize, 999usize),
(usize::MAX, usize::MAX),
];

for (a, b) in samples {
let id = generate_tool_call_id(a, b);
assert_eq!(
id.len(),
9,
"tool-call ID must be exactly 9 characters for seeds ({a}, {b})"
);
assert!(
id.chars().all(|c| c.is_ascii_alphanumeric()),
"tool-call ID must be ASCII alphanumeric for seeds ({a}, {b}), got: {id}"
);
}
}

#[test]
fn generate_tool_call_id_is_deterministic_for_same_seeds() {
let pairs = [
(0usize, 0usize),
(1usize, 2usize),
(123usize, 456usize),
(usize::MAX, 0usize),
];

for (a, b) in pairs {
let id1 = generate_tool_call_id(a, b);
let id2 = generate_tool_call_id(a, b);
let id3 = generate_tool_call_id(a, b);
assert_eq!(
id1, id2,
"tool-call ID must be deterministic for seeds ({a}, {b})"
);
assert_eq!(
id2, id3,
"tool-call ID must be deterministic across multiple calls for seeds ({a}, {b})"
);
}
}

#[test]
fn generate_tool_call_id_differs_for_different_seeds_in_small_sample() {
let seed_pairs = [
(0usize, 1usize),
(1usize, 0usize),
(1usize, 2usize),
(2usize, 3usize),
(10usize, 20usize),
(100usize, 200usize),
];

let mut ids = HashSet::new();
for (a, b) in seed_pairs {
let id = generate_tool_call_id(a, b);
let inserted = ids.insert(id.clone());
assert!(
inserted,
"expected distinct tool-call IDs for different seeds, \
but duplicate ID '{id}' found for seeds ({a}, {b})"
);
}
}

#[test]
fn test_sanitize_preserves_valid_pairs() {
Expand Down
24 changes: 20 additions & 4 deletions src/llm/reasoning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ You said you would perform an action, but you did not include any tool calls.\n\
Do NOT describe what you intend to do — actually call the tool now.\n\
Use the tool_calls mechanism to invoke the appropriate tool.";

/// Seed value used as the second argument to `generate_tool_call_id` when
/// recovering tool calls from malformed LLM text responses. This must differ
/// from the `0` seed used in `rig_adapter::normalized_tool_call_id` to avoid
/// ID collisions between provider-generated and text-recovered tool calls at
/// the same positional index.
const RECOVERED_TOOL_CALL_SEED: usize = 99;

/// Detect when an LLM response expresses intent to call a tool without
/// actually issuing tool calls. Returns `true` if the text contains phrases
/// like "Let me search …" or "I'll fetch …" outside of fenced/indented code blocks.
Expand Down Expand Up @@ -1337,7 +1344,10 @@ fn recover_tool_calls_from_content(
.cloned()
.unwrap_or(serde_json::Value::Object(Default::default()));
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(
calls.len(),
RECOVERED_TOOL_CALL_SEED,
),
name: name.to_string(),
arguments,
});
Expand All @@ -1348,7 +1358,10 @@ fn recover_tool_calls_from_content(
let name = inner.trim();
if tool_names.contains(name) {
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(
calls.len(),
RECOVERED_TOOL_CALL_SEED,
),
name: name.to_string(),
arguments: serde_json::Value::Object(Default::default()),
});
Expand Down Expand Up @@ -1382,7 +1395,10 @@ fn recover_tool_calls_from_content(
let arguments = serde_json::from_str::<serde_json::Value>(args_str)
.unwrap_or(serde_json::Value::Object(Default::default()));
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(
calls.len(),
RECOVERED_TOOL_CALL_SEED,
),
name: name.to_string(),
arguments,
});
Expand All @@ -1393,7 +1409,7 @@ fn recover_tool_calls_from_content(

// No arguments or malformed — call with empty args
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(calls.len(), RECOVERED_TOOL_CALL_SEED),
name: name.to_string(),
arguments: serde_json::Value::Object(Default::default()),
});
Expand Down
Loading
Loading