diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs index 9e6747f2b..f6c0c8319 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher.rs @@ -1844,7 +1844,7 @@ mod tests { Ok(ToolCompletionResponse { content: None, tool_calls: vec![ToolCall { - id: format!("call_{}", uuid::Uuid::new_v4()), + id: crate::llm::generate_tool_call_id(0, 0), name: "echo".to_string(), arguments: serde_json::json!({"message": "looping"}), }], @@ -1997,7 +1997,7 @@ mod tests { Ok(ToolCompletionResponse { content: None, tool_calls: vec![ToolCall { - id: format!("call_{}", uuid::Uuid::new_v4()), + id: crate::llm::generate_tool_call_id(0, 0), name: "nonexistent_tool".to_string(), arguments: serde_json::json!({}), }], diff --git a/src/agent/session.rs b/src/agent/session.rs index 4abbea616..eb52aab35 100644 --- a/src/agent/session.rs +++ b/src/agent/session.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::channels::web::util::truncate_preview; -use crate::llm::{ChatMessage, ToolCall}; +use crate::llm::{ChatMessage, ToolCall, generate_tool_call_id}; /// A session containing one or more threads. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -352,7 +352,12 @@ impl Thread { /// completed actions in subsequent turns. pub fn messages(&self) -> Vec { let mut messages = Vec::new(); - for turn in &self.turns { + // We use the enumeration index (`turn_idx`) rather than `turn.turn_number` + // intentionally: after `truncate_turns()`, the remaining turns are + // re-numbered starting from 0, so the enumeration index and turn_number + // are equivalent. Using the index avoids coupling to the field and keeps + // tool-call ID generation deterministic for the current message window. + for (turn_idx, turn) in self.turns.iter().enumerate() { if turn.image_content_parts.is_empty() { messages.push(ChatMessage::user(&turn.user_input)); } else { @@ -363,13 +368,23 @@ impl Thread { } if !turn.tool_calls.is_empty() { - // Build ToolCall objects with synthetic stable IDs - let tool_calls: Vec = turn + // Assign synthetic call IDs for this turn's tool calls, so that + // declarations and results can be consistently correlated. + let tool_calls_with_ids: Vec<(String, &_)> = turn .tool_calls .iter() .enumerate() - .map(|(i, tc)| ToolCall { - id: format!("turn{}_{}", turn.turn_number, i), + .map(|(tc_idx, tc)| { + // Use provider-compatible tool call IDs derived from turn/tool indices. + (generate_tool_call_id(turn_idx, tc_idx), tc) + }) + .collect(); + + // Build ToolCall objects using the synthetic call IDs. + let tool_calls: Vec = tool_calls_with_ids + .iter() + .map(|(call_id, tc)| ToolCall { + id: call_id.clone(), name: tc.name.clone(), arguments: tc.parameters.clone(), }) @@ -379,8 +394,7 @@ impl Thread { messages.push(ChatMessage::assistant_with_tool_calls(None, tool_calls)); // Individual tool result messages, truncated to limit context size. - for (i, tc) in turn.tool_calls.iter().enumerate() { - let call_id = format!("turn{}_{}", turn.turn_number, i); + for (call_id, tc) in tool_calls_with_ids { let content = if let Some(ref err) = tc.error { // .error already contains the full error text; // pass through without wrapping to avoid double-prefix. diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 51309bf37..2c6b19266 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -46,7 +46,7 @@ pub use nearai_chat::{ModelInfo, NearAiChatProvider}; pub use provider::{ ChatMessage, CompletionRequest, CompletionResponse, ContentPart, FinishReason, ImageUrl, LlmProvider, ModelMetadata, Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse, - ToolDefinition, ToolResult, + ToolDefinition, ToolResult, generate_tool_call_id, }; pub use reasoning::{ ActionPlan, Reasoning, ReasoningContext, RespondOutput, RespondResult, SILENT_REPLY_TOKEN, diff --git a/src/llm/provider.rs b/src/llm/provider.rs index 8a213031c..bb45ec680 100644 --- a/src/llm/provider.rs +++ b/src/llm/provider.rs @@ -233,6 +233,32 @@ pub struct ToolCall { pub arguments: serde_json::Value, } +/// Generate a tool-call ID that satisfies all providers. +/// +/// Mistral requires exactly 9 alphanumeric characters (`[a-zA-Z0-9]{9}`). +/// Other providers accept any non-empty string. By default we produce a +/// 9-char base-62 string derived from two seed values so the ID is both +/// deterministic (for replayed history) and provider-compatible. +pub fn generate_tool_call_id(seed_a: usize, seed_b: usize) -> String { + // Mix the two seeds into a single u64 using a simple hash-like combine. + let combined = (seed_a as u64) + .wrapping_mul(6364136223846793005) + .wrapping_add(seed_b as u64); + // Format as 9-char zero-padded base-62 (0-9, a-z, A-Z). + let mut buf = [b'0'; 9]; + let mut val = combined; + for b in buf.iter_mut().rev() { + let digit = (val % 62) as u8; + *b = match digit { + 0..=9 => b'0' + digit, + 10..=35 => b'a' + (digit - 10), + _ => b'A' + (digit - 36), + }; + val /= 62; + } + buf.iter().map(|&b| b as char).collect::() +} + /// Result of a tool execution to send back to the LLM. #[derive(Debug, Clone)] pub struct ToolResult { @@ -533,6 +559,77 @@ pub fn strip_unsupported_tool_params( #[cfg(test)] mod tests { use super::*; + use std::collections::HashSet; + + #[test] + fn generate_tool_call_id_has_valid_format() { + let samples = [ + (0usize, 0usize), + (1usize, 2usize), + (42usize, 999usize), + (usize::MAX, usize::MAX), + ]; + + for (a, b) in samples { + let id = generate_tool_call_id(a, b); + assert_eq!( + id.len(), + 9, + "tool-call ID must be exactly 9 characters for seeds ({a}, {b})" + ); + assert!( + id.chars().all(|c| c.is_ascii_alphanumeric()), + "tool-call ID must be ASCII alphanumeric for seeds ({a}, {b}), got: {id}" + ); + } + } + + #[test] + fn generate_tool_call_id_is_deterministic_for_same_seeds() { + let pairs = [ + (0usize, 0usize), + (1usize, 2usize), + (123usize, 456usize), + (usize::MAX, 0usize), + ]; + + for (a, b) in pairs { + let id1 = generate_tool_call_id(a, b); + let id2 = generate_tool_call_id(a, b); + let id3 = generate_tool_call_id(a, b); + assert_eq!( + id1, id2, + "tool-call ID must be deterministic for seeds ({a}, {b})" + ); + assert_eq!( + id2, id3, + "tool-call ID must be deterministic across multiple calls for seeds ({a}, {b})" + ); + } + } + + #[test] + fn generate_tool_call_id_differs_for_different_seeds_in_small_sample() { + let seed_pairs = [ + (0usize, 1usize), + (1usize, 0usize), + (1usize, 2usize), + (2usize, 3usize), + (10usize, 20usize), + (100usize, 200usize), + ]; + + let mut ids = HashSet::new(); + for (a, b) in seed_pairs { + let id = generate_tool_call_id(a, b); + let inserted = ids.insert(id.clone()); + assert!( + inserted, + "expected distinct tool-call IDs for different seeds, \ + but duplicate ID '{id}' found for seeds ({a}, {b})" + ); + } + } #[test] fn test_sanitize_preserves_valid_pairs() { diff --git a/src/llm/reasoning.rs b/src/llm/reasoning.rs index b00948ae8..cbec297bb 100644 --- a/src/llm/reasoning.rs +++ b/src/llm/reasoning.rs @@ -23,6 +23,13 @@ You said you would perform an action, but you did not include any tool calls.\n\ Do NOT describe what you intend to do — actually call the tool now.\n\ Use the tool_calls mechanism to invoke the appropriate tool."; +/// Seed value used as the second argument to `generate_tool_call_id` when +/// recovering tool calls from malformed LLM text responses. This must differ +/// from the `0` seed used in `rig_adapter::normalized_tool_call_id` to avoid +/// ID collisions between provider-generated and text-recovered tool calls at +/// the same positional index. +const RECOVERED_TOOL_CALL_SEED: usize = 99; + /// Detect when an LLM response expresses intent to call a tool without /// actually issuing tool calls. Returns `true` if the text contains phrases /// like "Let me search …" or "I'll fetch …" outside of fenced/indented code blocks. @@ -1337,7 +1344,10 @@ fn recover_tool_calls_from_content( .cloned() .unwrap_or(serde_json::Value::Object(Default::default())); calls.push(ToolCall { - id: format!("recovered_{}", calls.len()), + id: super::provider::generate_tool_call_id( + calls.len(), + RECOVERED_TOOL_CALL_SEED, + ), name: name.to_string(), arguments, }); @@ -1348,7 +1358,10 @@ fn recover_tool_calls_from_content( let name = inner.trim(); if tool_names.contains(name) { calls.push(ToolCall { - id: format!("recovered_{}", calls.len()), + id: super::provider::generate_tool_call_id( + calls.len(), + RECOVERED_TOOL_CALL_SEED, + ), name: name.to_string(), arguments: serde_json::Value::Object(Default::default()), }); @@ -1382,7 +1395,10 @@ fn recover_tool_calls_from_content( let arguments = serde_json::from_str::(args_str) .unwrap_or(serde_json::Value::Object(Default::default())); calls.push(ToolCall { - id: format!("recovered_{}", calls.len()), + id: super::provider::generate_tool_call_id( + calls.len(), + RECOVERED_TOOL_CALL_SEED, + ), name: name.to_string(), arguments, }); @@ -1393,7 +1409,7 @@ fn recover_tool_calls_from_content( // No arguments or malformed — call with empty args calls.push(ToolCall { - id: format!("recovered_{}", calls.len()), + id: super::provider::generate_tool_call_id(calls.len(), RECOVERED_TOOL_CALL_SEED), name: name.to_string(), arguments: serde_json::Value::Object(Default::default()), }); diff --git a/src/llm/rig_adapter.rs b/src/llm/rig_adapter.rs index 5c1faef79..2f382d6ed 100644 --- a/src/llm/rig_adapter.rs +++ b/src/llm/rig_adapter.rs @@ -20,6 +20,7 @@ use rust_decimal_macros::dec; use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value as JsonValue; +use sha2::{Digest, Sha256}; use std::collections::HashSet; @@ -390,11 +391,48 @@ fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec, seed: usize) -> String { - match raw.map(str::trim).filter(|id| !id.is_empty()) { - Some(id) => id.to_string(), - None => format!("generated_tool_call_{seed}"), + // Trim and treat empty as None. + let trimmed = raw.and_then(|s| { + let t = s.trim(); + if t.is_empty() { None } else { Some(t) } + }); + + if let Some(id) = trimmed { + // If the ID already satisfies `[a-zA-Z0-9]{9}`, pass it through unchanged. + if id.len() == 9 && id.chars().all(|c| c.is_ascii_alphanumeric()) { + return id.to_string(); + } + + // Otherwise, deterministically hash the raw ID and feed the hash-derived + // seed into the provider-level generator so that the encoding and any + // provider-specific constraints remain centralized in one place. + let digest = Sha256::digest(id.as_bytes()); + // Derive a 64-bit value from the first 8 bytes of the digest, then + // split it into two usize seeds so we preserve all 64 bits of entropy + // even on 32-bit targets. + let hash64 = { + // SHA-256 always produces 32 bytes, so indexing the first 8 is safe. + let bytes: [u8; 8] = [ + digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6], + digest[7], + ]; + u64::from_be_bytes(bytes) + }; + let hi_seed: usize = (hash64 >> 32) as usize; + let lo_seed: usize = (hash64 & 0xFFFF_FFFF) as usize; + return super::provider::generate_tool_call_id(hi_seed, lo_seed); } + + // Fallback for missing/empty raw IDs: use the provider-level generator, + // which already produces compliant IDs. + super::provider::generate_tool_call_id(seed, 0) } /// Convert IronClaw tool definitions to rig-core format. @@ -792,8 +830,9 @@ mod tests { #[test] fn test_convert_messages_tool_result() { + // Use a conforming 9-char alphanumeric ID so it passes through unchanged. let messages = vec![ChatMessage::tool_result( - "call_123", + "abcDE1234", "search", "result text", )]; @@ -804,8 +843,8 @@ mod tests { match &history[0] { RigMessage::User { content } => match content.first() { UserContent::ToolResult(r) => { - assert_eq!(r.id, "call_123"); - assert_eq!(r.call_id.as_deref(), Some("call_123")); + assert_eq!(r.id, "abcDE1234"); + assert_eq!(r.call_id.as_deref(), Some("abcDE1234")); } other => panic!("Expected tool result content, got: {:?}", other), }, @@ -815,8 +854,9 @@ mod tests { #[test] fn test_convert_messages_assistant_with_tool_calls() { + // Use a conforming 9-char alphanumeric ID so it passes through unchanged. let tc = IronToolCall { - id: "call_1".to_string(), + id: "Xt7mK9pQ2".to_string(), name: "search".to_string(), arguments: serde_json::json!({"query": "test"}), }; @@ -830,7 +870,7 @@ mod tests { assert!(content.iter().count() >= 2); for item in content.iter() { if let AssistantContent::ToolCall(tc) = item { - assert_eq!(tc.call_id.as_deref(), Some("call_1")); + assert_eq!(tc.call_id.as_deref(), Some("Xt7mK9pQ2")); } } } @@ -852,7 +892,14 @@ mod tests { match &history[0] { RigMessage::User { content } => match content.first() { UserContent::ToolResult(r) => { - assert!(r.id.starts_with("generated_tool_call_")); + // Missing ID → normalized_tool_call_id generates a 9-char alphanumeric ID. + assert_eq!( + r.id.len(), + 9, + "fallback ID should be 9 chars, got: {}", + r.id + ); + assert!(r.id.chars().all(|c| c.is_ascii_alphanumeric())); assert_eq!(r.call_id.as_deref(), Some(r.id.as_str())); } other => panic!("Expected tool result content, got: {:?}", other), @@ -940,12 +987,14 @@ mod tests { _ => None, }); let tc = tool_call.expect("should have a tool call"); - assert!(!tc.id.is_empty(), "tool call id must not be empty"); - assert!( - tc.id.starts_with("generated_tool_call_"), - "empty id should be replaced with generated id, got: {}", + // Empty ID → normalized_tool_call_id generates a 9-char alphanumeric ID. + assert_eq!( + tc.id.len(), + 9, + "generated id should be 9 chars, got: {}", tc.id ); + assert!(tc.id.chars().all(|c| c.is_ascii_alphanumeric())); assert_eq!(tc.call_id.as_deref(), Some(tc.id.as_str())); } other => panic!("Expected Assistant message, got: {:?}", other), @@ -969,11 +1018,14 @@ mod tests { _ => None, }); let tc = tool_call.expect("should have a tool call"); - assert!( - tc.id.starts_with("generated_tool_call_"), - "whitespace-only id should be replaced, got: {:?}", + // Whitespace-only ID → normalized_tool_call_id generates a 9-char alphanumeric ID. + assert_eq!( + tc.id.len(), + 9, + "generated id should be 9 chars, got: {}", tc.id ); + assert!(tc.id.chars().all(|c| c.is_ascii_alphanumeric())); } other => panic!("Expected Assistant message, got: {:?}", other), } @@ -1360,4 +1412,67 @@ mod tests { // Should be 2 separate User messages (text user + tool result user) assert_eq!(history.len(), 2); } + + // -- normalized_tool_call_id tests -- + + #[test] + fn test_normalized_tool_call_id_conforming_passthrough() { + // A 9-char alphanumeric ID should pass through unchanged. + let id = normalized_tool_call_id(Some("abcDE1234"), 42); + assert_eq!(id, "abcDE1234"); + } + + #[test] + fn test_normalized_tool_call_id_non_conforming_hashed() { + // An ID that doesn't match [a-zA-Z0-9]{9} should be hashed into one. + let id = normalized_tool_call_id(Some("call_abc_long_id"), 0); + assert_eq!(id.len(), 9); + assert!(id.chars().all(|c| c.is_ascii_alphanumeric())); + // Should NOT be the raw input. + assert_ne!(id, "call_abc_l"); + } + + #[test] + fn test_normalized_tool_call_id_empty_input() { + let id = normalized_tool_call_id(Some(""), 5); + assert_eq!(id.len(), 9); + assert!(id.chars().all(|c| c.is_ascii_alphanumeric())); + } + + #[test] + fn test_normalized_tool_call_id_whitespace_input() { + let id = normalized_tool_call_id(Some(" "), 5); + assert_eq!(id.len(), 9); + assert!(id.chars().all(|c| c.is_ascii_alphanumeric())); + // Empty and whitespace-only with the same seed should produce identical results. + let id_empty = normalized_tool_call_id(Some(""), 5); + assert_eq!(id, id_empty); + } + + #[test] + fn test_normalized_tool_call_id_none_input() { + let id = normalized_tool_call_id(None, 7); + assert_eq!(id.len(), 9); + assert!(id.chars().all(|c| c.is_ascii_alphanumeric())); + // None and empty string with same seed should produce identical results. + let id_empty = normalized_tool_call_id(Some(""), 7); + assert_eq!(id, id_empty); + } + + #[test] + fn test_normalized_tool_call_id_deterministic() { + let id1 = normalized_tool_call_id(Some("call_xyz_123"), 0); + let id2 = normalized_tool_call_id(Some("call_xyz_123"), 0); + assert_eq!(id1, id2, "same input must produce same output"); + } + + #[test] + fn test_normalized_tool_call_id_different_inputs_differ() { + let id_a = normalized_tool_call_id(Some("call_aaa"), 0); + let id_b = normalized_tool_call_id(Some("call_bbb"), 0); + assert_ne!( + id_a, id_b, + "different raw IDs should produce different hashed IDs" + ); + } }