diff --git a/rig-core/src/tool/mod.rs b/rig-core/src/tool/mod.rs index cffdb01b2..7f935e671 100644 --- a/rig-core/src/tool/mod.rs +++ b/rig-core/src/tool/mod.rs @@ -11,6 +11,7 @@ pub mod server; use std::collections::HashMap; +use std::fmt; use futures::Future; use serde::{Deserialize, Serialize}; @@ -25,18 +26,33 @@ use crate::{ pub enum ToolError { #[cfg(not(target_family = "wasm"))] /// Error returned by the tool - #[error("ToolCallError: {0}")] ToolCallError(#[from] Box), #[cfg(target_family = "wasm")] /// Error returned by the tool - #[error("ToolCallError: {0}")] ToolCallError(#[from] Box), - - #[error("JsonError: {0}")] + /// Error caused by a de/serialization fail JsonError(#[from] serde_json::Error), } +impl fmt::Display for ToolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ToolError::ToolCallError(e) => { + let error_str = e.to_string(); + // This is required due to being able to use agents as tools + // which means it is possible to get recursive tool call errors + if error_str.starts_with("ToolCallError: ") { + write!(f, "{}", error_str) + } else { + write!(f, "ToolCallError: {}", error_str) + } + } + ToolError::JsonError(e) => write!(f, "JsonError: {e}"), + } + } +} + /// Trait that represents a simple LLM tool /// /// # Example