Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions rig-core/src/tool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

pub mod server;
use std::collections::HashMap;
use std::fmt;

use futures::Future;
use serde::{Deserialize, Serialize};
Expand All @@ -25,18 +26,33 @@ use crate::{
pub enum ToolError {
#[cfg(not(target_family = "wasm"))]
/// Error returned by the tool
#[error("ToolCallError: {0}")]
ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

#[cfg(target_family = "wasm")]
/// Error returned by the tool
#[error("ToolCallError: {0}")]
ToolCallError(#[from] Box<dyn std::error::Error>),

#[error("JsonError: {0}")]
/// Error caused by a de/serialization fail
JsonError(#[from] serde_json::Error),
}

impl fmt::Display for ToolError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ToolError::ToolCallError(e) => {
let error_str = e.to_string();
// This is required due to being able to use agents as tools
// which means it is possible to get recursive tool call errors
if error_str.starts_with("ToolCallError: ") {
write!(f, "{}", error_str)
} else {
write!(f, "ToolCallError: {}", error_str)
}
}
ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
}
}
}

/// Trait that represents a simple LLM tool
///
/// # Example
Expand Down