Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agent/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,7 @@ mod tests {
Ok(ToolCompletionResponse {
content: None,
tool_calls: vec![ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
id: crate::llm::generate_tool_call_id(0, 0),
name: "echo".to_string(),
arguments: serde_json::json!({"message": "looping"}),
}],
Expand Down Expand Up @@ -1997,7 +1997,7 @@ mod tests {
Ok(ToolCompletionResponse {
content: None,
tool_calls: vec![ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
id: crate::llm::generate_tool_call_id(0, 0),
name: "nonexistent_tool".to_string(),
arguments: serde_json::json!({}),
}],
Expand Down
25 changes: 18 additions & 7 deletions src/agent/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use uuid::Uuid;

use crate::channels::web::util::truncate_preview;
use crate::llm::{ChatMessage, ToolCall};
use crate::llm::provider::generate_tool_call_id;

/// A session containing one or more threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -352,7 +353,7 @@ impl Thread {
/// completed actions in subsequent turns.
pub fn messages(&self) -> Vec<ChatMessage> {
let mut messages = Vec::new();
for turn in &self.turns {
for (turn_idx, turn) in self.turns.iter().enumerate() {
if turn.image_content_parts.is_empty() {
messages.push(ChatMessage::user(&turn.user_input));
} else {
Expand All @@ -363,13 +364,24 @@ impl Thread {
}

if !turn.tool_calls.is_empty() {
// Build ToolCall objects with synthetic stable IDs
let tool_calls: Vec<ToolCall> = turn
// Assign synthetic call IDs for this turn's tool calls, so that
// declarations and results can be consistently correlated.
let tool_calls_with_ids: Vec<(String, &_)> = turn
.tool_calls
.iter()
.enumerate()
.map(|(i, tc)| ToolCall {
id: format!("turn{}_{}", turn.turn_number, i),
.map(|(tc_idx, tc)| {
// Use provider-compatible tool call IDs derived from turn/tool indices.
let seed = format!("{}-{}", turn_idx, tc_idx);
(generate_tool_call_id(&seed), tc)
})
.collect();

// Build ToolCall objects using the synthetic call IDs.
let tool_calls: Vec<ToolCall> = tool_calls_with_ids
.iter()
.map(|(call_id, tc)| ToolCall {
id: call_id.clone(),
name: tc.name.clone(),
arguments: tc.parameters.clone(),
})
Expand All @@ -379,8 +391,7 @@ impl Thread {
messages.push(ChatMessage::assistant_with_tool_calls(None, tool_calls));

// Individual tool result messages, truncated to limit context size.
for (i, tc) in turn.tool_calls.iter().enumerate() {
let call_id = format!("turn{}_{}", turn.turn_number, i);
for (call_id, tc) in tool_calls_with_ids {
let content = if let Some(ref err) = tc.error {
// .error already contains the full error text;
// pass through without wrapping to avoid double-prefix.
Expand Down
2 changes: 1 addition & 1 deletion src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub use nearai_chat::{ModelInfo, NearAiChatProvider};
pub use provider::{
ChatMessage, CompletionRequest, CompletionResponse, ContentPart, FinishReason, ImageUrl,
LlmProvider, ModelMetadata, Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse,
ToolDefinition, ToolResult,
ToolDefinition, ToolResult, generate_tool_call_id,
};
pub use reasoning::{
ActionPlan, Reasoning, ReasoningContext, RespondOutput, RespondResult, SILENT_REPLY_TOKEN,
Expand Down
20 changes: 20 additions & 0 deletions src/llm/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,26 @@ pub struct ToolCall {
pub arguments: serde_json::Value,
}

/// Generate a tool-call ID that satisfies all providers.
///
/// Mistral requires exactly 9 alphanumeric characters (`[a-zA-Z0-9]{9}`).
/// Other providers accept any non-empty string. By default we produce a
/// 9-char base-36 string derived from two seed values so the ID is both
/// deterministic (for replayed history) and provider-compatible.
Comment on lines +236 to +241
pub fn generate_tool_call_id(seed_a: usize, seed_b: usize) -> String {
// Mix the two seeds into a single u64 using a simple hash-like combine.
let combined = (seed_a as u64).wrapping_mul(6364136223846793005).wrapping_add(seed_b as u64);
// Format as 9-char zero-padded base-36 (digits + lowercase letters).
let mut buf = [b'0'; 9];
let mut val = combined;
for b in buf.iter_mut().rev() {
let digit = (val % 36) as u8;
*b = if digit < 10 { b'0' + digit } else { b'a' + digit - 10 };
val /= 36;
}
String::from(std::str::from_utf8(&buf).unwrap())
}

/// Result of a tool execution to send back to the LLM.
#[derive(Debug, Clone)]
pub struct ToolResult {
Expand Down
8 changes: 4 additions & 4 deletions src/llm/reasoning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ fn recover_tool_calls_from_content(
.cloned()
.unwrap_or(serde_json::Value::Object(Default::default()));
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(calls.len(), 99),
name: name.to_string(),
arguments,
});
Expand All @@ -1348,7 +1348,7 @@ fn recover_tool_calls_from_content(
let name = inner.trim();
if tool_names.contains(name) {
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(calls.len(), 99),
name: name.to_string(),
arguments: serde_json::Value::Object(Default::default()),
});
Expand Down Expand Up @@ -1382,7 +1382,7 @@ fn recover_tool_calls_from_content(
let arguments = serde_json::from_str::<serde_json::Value>(args_str)
.unwrap_or(serde_json::Value::Object(Default::default()));
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(calls.len(), 99),
name: name.to_string(),
arguments,
});
Expand All @@ -1393,7 +1393,7 @@ fn recover_tool_calls_from_content(

// No arguments or malformed — call with empty args
calls.push(ToolCall {
id: format!("recovered_{}", calls.len()),
id: super::provider::generate_tool_call_id(calls.len(), 99),
name: name.to_string(),
arguments: serde_json::Value::Object(Default::default()),
});
Expand Down
2 changes: 1 addition & 1 deletion src/llm/rig_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<RigMessage
fn normalized_tool_call_id(raw: Option<&str>, seed: usize) -> String {
match raw.map(str::trim).filter(|id| !id.is_empty()) {
Some(id) => id.to_string(),
None => format!("generated_tool_call_{seed}"),
None => super::provider::generate_tool_call_id(seed, 0),
}
}

Expand Down
Loading