diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs index 4301c09af3..991bb5819d 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher.rs @@ -29,7 +29,7 @@ pub(super) enum AgenticLoopResult { /// A tool requires approval before continuing. NeedApproval { /// The pending approval request to store. - pending: PendingApproval, + pending: Box, }, } @@ -216,9 +216,7 @@ impl Agent { reason: format!("Exceeded maximum tool iterations ({max_tool_iterations})"), } .into()), - LoopOutcome::NeedApproval(pending) => { - Ok(AgenticLoopResult::NeedApproval { pending: *pending }) - } + LoopOutcome::NeedApproval(pending) => Ok(AgenticLoopResult::NeedApproval { pending }), } } @@ -481,6 +479,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { usize, crate::llm::ToolCall, Arc, + bool, // allow_always )> = None; for (idx, original_tc) in tool_calls.iter().enumerate() { @@ -550,7 +549,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { && let Some(tool) = tool_opt { use crate::tools::ApprovalRequirement; - let needs_approval = match tool.requires_approval(&tc.arguments) { + let requirement = tool.requires_approval(&tc.arguments); + let needs_approval = match requirement { ApprovalRequirement::Never => false, ApprovalRequirement::UnlessAutoApproved => { let sess = self.session.lock().await; @@ -585,7 +585,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { continue; } - approval_needed = Some((idx, tc, tool)); + let allow_always = !matches!(requirement, ApprovalRequirement::Always); + approval_needed = Some((idx, tc, tool, allow_always)); break; } } @@ -886,7 +887,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { } // Handle approval if a tool needed it - if let Some((approval_idx, tc, tool)) = approval_needed { + if let Some((approval_idx, tc, tool, allow_always)) = approval_needed { let display_params = redact_params(&tc.arguments, tool.sensitive_params()); let pending = PendingApproval { request_id: Uuid::new_v4(), @@ -898,6 +899,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { context_messages: reason_ctx.messages.clone(), deferred_tool_calls: tool_calls[approval_idx + 1..].to_vec(), user_timezone: Some(self.user_tz.name().to_string()), + allow_always, }; return Ok(Some(LoopOutcome::NeedApproval(Box::new(pending)))); @@ -1363,6 +1365,35 @@ mod tests { assert!(always_needs, "Always must always require approval"); } + /// Regression test: `allow_always` must be `false` for `Always` and + /// `true` for `UnlessAutoApproved`, so the UI hides the "always" button + /// for tools that truly cannot be auto-approved. + #[test] + fn test_allow_always_matches_approval_requirement() { + use crate::tools::ApprovalRequirement; + + // Mirrors the expression used in dispatcher.rs and thread_ops.rs: + // let allow_always = !matches!(requirement, ApprovalRequirement::Always); + + // UnlessAutoApproved → allow_always = true + let req = ApprovalRequirement::UnlessAutoApproved; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!( + allow_always, + "UnlessAutoApproved should set allow_always = true" + ); + + // Always → allow_always = false + let req = ApprovalRequirement::Always; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!(!allow_always, "Always should set allow_always = false"); + + // Never → allow_always = true (approval is never needed, but if it were, always would be ok) + let req = ApprovalRequirement::Never; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!(allow_always, "Never should set allow_always = true"); + } + #[test] fn test_pending_approval_serialization_backcompat_without_deferred_calls() { // PendingApproval from before the deferred_tool_calls field was added @@ -1408,6 +1439,7 @@ mod tests { }, ], user_timezone: None, + allow_always: true, }; let json = serde_json::to_string(&pending).expect("serialize"); diff --git a/src/agent/session.rs b/src/agent/session.rs index 4abbea6168..3e84afc0b6 100644 --- a/src/agent/session.rs +++ b/src/agent/session.rs @@ -188,6 +188,15 @@ pub struct PendingApproval { /// through the approval flow even if the approval message lacks timezone. #[serde(default)] pub user_timezone: Option, + /// Whether the "always" auto-approve option should be offered to the user. + /// `false` when the tool returned `ApprovalRequirement::Always` (e.g. + /// destructive shell commands), meaning every invocation must be confirmed. + #[serde(default = "default_true")] + pub allow_always: bool, +} + +fn default_true() -> bool { + true } /// A conversation thread within a session. @@ -1106,6 +1115,7 @@ mod tests { context_messages: vec![ChatMessage::user("do it")], deferred_tool_calls: vec![], user_timezone: None, + allow_always: false, }; thread.await_approval(approval); @@ -1132,6 +1142,7 @@ mod tests { context_messages: vec![], deferred_tool_calls: vec![], user_timezone: None, + allow_always: true, }; thread.await_approval(approval); diff --git a/src/agent/submission.rs b/src/agent/submission.rs index 463361330d..8fcf5bf8cc 100644 --- a/src/agent/submission.rs +++ b/src/agent/submission.rs @@ -382,6 +382,8 @@ pub enum SubmissionResult { description: String, /// Parameters being passed. parameters: serde_json::Value, + /// Whether "always" auto-approve should be offered to the user. + allow_always: bool, }, /// Successfully processed (for control commands). diff --git a/src/agent/thread_ops.rs b/src/agent/thread_ops.rs index e5f2005d25..16c117c243 100644 --- a/src/agent/thread_ops.rs +++ b/src/agent/thread_ops.rs @@ -497,7 +497,8 @@ impl Agent { let tool_name = pending.tool_name.clone(); let description = pending.description.clone(); let parameters = pending.display_parameters.clone(); - thread.await_approval(pending); + let allow_always = pending.allow_always; + thread.await_approval(*pending); let _ = self .channels .send_status( @@ -507,6 +508,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -516,6 +518,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }) } Err(e) => { @@ -1060,28 +1063,31 @@ impl Agent { usize, crate::llm::ToolCall, Arc, + bool, // allow_always )> = None; for (idx, tc) in deferred_tool_calls.iter().enumerate() { if let Some(tool) = self.tools().get(&tc.name).await { // Match dispatcher.rs: when auto_approve_tools is true, skip // all approval checks (including ApprovalRequirement::Always). - let needs_approval = if self.config.auto_approve_tools { - false + let (needs_approval, allow_always) = if self.config.auto_approve_tools { + (false, true) } else { use crate::tools::ApprovalRequirement; - match tool.requires_approval(&tc.arguments) { + let requirement = tool.requires_approval(&tc.arguments); + let needs = match requirement { ApprovalRequirement::Never => false, ApprovalRequirement::UnlessAutoApproved => { let sess = session.lock().await; !sess.is_tool_auto_approved(&tc.name) } ApprovalRequirement::Always => true, - } + }; + (needs, !matches!(requirement, ApprovalRequirement::Always)) }; if needs_approval { - approval_needed = Some((idx, tc.clone(), tool)); + approval_needed = Some((idx, tc.clone(), tool, allow_always)); break; // remaining tools stay deferred } } @@ -1289,7 +1295,7 @@ impl Agent { } // Handle approval if a tool needed it - if let Some((approval_idx, tc, tool)) = approval_needed { + if let Some((approval_idx, tc, tool, allow_always)) = approval_needed { let new_pending = PendingApproval { request_id: Uuid::new_v4(), tool_name: tc.name.clone(), @@ -1301,6 +1307,7 @@ impl Agent { deferred_tool_calls: deferred_tool_calls[approval_idx + 1..].to_vec(), // Carry forward the resolved timezone from the original pending approval user_timezone: pending.user_timezone.clone(), + allow_always, }; let request_id = new_pending.request_id; @@ -1324,6 +1331,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -1334,6 +1342,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }); } @@ -1402,7 +1411,8 @@ impl Agent { let tool_name = new_pending.tool_name.clone(); let description = new_pending.description.clone(); let parameters = new_pending.display_parameters.clone(); - thread.await_approval(new_pending); + let allow_always = new_pending.allow_always; + thread.await_approval(*new_pending); let _ = self .channels .send_status( @@ -1412,6 +1422,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -1421,6 +1432,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }) } Err(e) => { diff --git a/src/channels/channel.rs b/src/channels/channel.rs index 43e35688cc..a85cf8c5d2 100644 --- a/src/channels/channel.rs +++ b/src/channels/channel.rs @@ -305,6 +305,11 @@ pub enum StatusUpdate { tool_name: String, description: String, parameters: serde_json::Value, + /// When `true`, the UI should offer an "always" option that auto-approves + /// future calls to this tool for the rest of the session. When `false` + /// (i.e. `ApprovalRequirement::Always`), the tool must be approved every + /// time and the "always" button should be hidden. + allow_always: bool, }, /// Extension needs user authentication (token or OAuth). AuthRequired { diff --git a/src/channels/relay/channel.rs b/src/channels/relay/channel.rs index 52aea478ee..9216e9b8a7 100644 --- a/src/channels/relay/channel.rs +++ b/src/channels/relay/channel.rs @@ -423,6 +423,7 @@ impl Channel for RelayChannel { tool_name, description, parameters, + allow_always: _, } = status else { return Ok(()); @@ -794,6 +795,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) @@ -822,6 +824,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) @@ -854,6 +857,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) diff --git a/src/channels/repl.rs b/src/channels/repl.rs index 40d669198c..36ca7c28a0 100644 --- a/src/channels/repl.rs +++ b/src/channels/repl.rs @@ -539,6 +539,7 @@ impl Channel for ReplChannel { tool_name, description, parameters, + allow_always, } => { let term_width = crossterm::terminal::size() .map(|(w, _)| w as usize) @@ -582,9 +583,13 @@ impl Channel for ReplChannel { } eprintln!(" \u{2502}"); - eprintln!( - " \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)" - ); + if allow_always { + eprintln!( + " \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)" + ); + } else { + eprintln!(" \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[31mno\x1b[0m (n)"); + } eprintln!(" {bot_border}"); eprintln!(); } diff --git a/src/channels/signal.rs b/src/channels/signal.rs index b8934c5cb1..84afccd5fb 100644 --- a/src/channels/signal.rs +++ b/src/channels/signal.rs @@ -915,20 +915,28 @@ impl Channel for SignalChannel { tool_name, description: _, parameters, + allow_always, } = &status && let Some(target_str) = metadata.get("signal_target").and_then(|v| v.as_str()) { let params_json = serde_json::to_string_pretty(parameters).unwrap_or_default(); + let always_line = if *allow_always { + format!( + "\n• `always` or `a` - Approve and auto-approve future {} requests", + tool_name + ) + } else { + String::new() + }; let message = format!( "⚠️ *Approval Required*\n\n\ *Request ID:* `{}`\n\ *Tool:* {}\n\ *Parameters:*\n```\n{}\n```\n\n\ Reply with:\n\ - • `yes` or `y` - Approve this request\n\ - • `always` or `a` - Approve and auto-approve future {} requests\n\ + • `yes` or `y` - Approve this request{}\n\ • `no` or `n` - Deny", - request_id, tool_name, params_json, tool_name + request_id, tool_name, params_json, always_line ); self.send_status_message(target_str, &message).await; } diff --git a/src/channels/wasm/wrapper.rs b/src/channels/wasm/wrapper.rs index 0be8756b1a..ae38c9a66a 100644 --- a/src/channels/wasm/wrapper.rs +++ b/src/channels/wasm/wrapper.rs @@ -1997,6 +1997,7 @@ impl WasmChannel { tool_name, description, parameters, + allow_always, .. } => { // WASM channels (Telegram, Slack, etc.) cannot render @@ -2035,6 +2036,11 @@ impl WasmChannel { }) .unwrap_or_default(); + let reply_hint = if *allow_always { + "Reply \"yes\" to approve, \"no\" to deny, or \"always\" to auto-approve." + } else { + "Reply \"yes\" to approve or \"no\" to deny." + }; let prompt = format!( "Approval needed: {tool_name}\n\ {description}\n\ @@ -2042,7 +2048,7 @@ impl WasmChannel { Parameters:\n\ {params_preview}\n\ \n\ - Reply \"yes\" to approve, \"no\" to deny, or \"always\" to auto-approve." + {reply_hint}" ); let metadata_json = serde_json::to_string(metadata).unwrap_or_default(); @@ -2935,15 +2941,23 @@ fn status_to_wit( request_id, tool_name, description, + allow_always, .. - } => wit_channel::StatusUpdate { - status: wit_channel::StatusType::ApprovalNeeded, - message: format!( - "Approval needed for tool '{}'. {}\nRequest ID: {}\nReply with: yes (or /approve), no (or /deny), or always (or /always).", - tool_name, description, request_id - ), - metadata_json, - }, + } => { + let reply_hint = if *allow_always { + "yes (or /approve), no (or /deny), or always (or /always)" + } else { + "yes (or /approve) or no (or /deny)" + }; + wit_channel::StatusUpdate { + status: wit_channel::StatusType::ApprovalNeeded, + message: format!( + "Approval needed for tool '{}'. {}\nRequest ID: {}\nReply with: {}.", + tool_name, description, request_id, reply_hint + ), + metadata_json, + } + } StatusUpdate::JobStarted { job_id, title, @@ -3611,6 +3625,7 @@ mod tests { tool_name: "http_request".into(), description: "Fetch weather".into(), parameters: serde_json::json!({"url": "https://wttr.in"}), + allow_always: true, }, &metadata, ) @@ -4072,6 +4087,7 @@ mod tests { tool_name: "http_request".to_string(), description: "Fetch weather data".to_string(), parameters: serde_json::json!({"url": "https://api.weather.test"}), + allow_always: true, }, &metadata, ) @@ -4097,6 +4113,7 @@ mod tests { tool_name: "http_request".to_string(), description: "Fetch weather data".to_string(), parameters: serde_json::json!({"url": "https://api.weather.test"}), + allow_always: true, }, &metadata, ) diff --git a/src/channels/web/mod.rs b/src/channels/web/mod.rs index 0d970569a9..9acfedeb40 100644 --- a/src/channels/web/mod.rs +++ b/src/channels/web/mod.rs @@ -366,6 +366,7 @@ impl Channel for GatewayChannel { tool_name, description, parameters, + allow_always, } => SseEvent::ApprovalNeeded { request_id, tool_name, @@ -373,6 +374,7 @@ impl Channel for GatewayChannel { parameters: serde_json::to_string_pretty(¶meters) .unwrap_or_else(|_| parameters.to_string()), thread_id, + allow_always, }, StatusUpdate::AuthRequired { extension_name, diff --git a/src/channels/web/static/app.js b/src/channels/web/static/app.js index 127c18fa0c..a82222471b 100644 --- a/src/channels/web/static/app.js +++ b/src/channels/web/static/app.js @@ -1136,18 +1136,19 @@ function showApproval(data) { approveBtn.textContent = I18n.t('approval.approve'); approveBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'approve')); - const alwaysBtn = document.createElement('button'); - alwaysBtn.className = 'always'; - alwaysBtn.textContent = I18n.t('approval.always'); - alwaysBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'always')); - const denyBtn = document.createElement('button'); denyBtn.className = 'deny'; denyBtn.textContent = I18n.t('approval.deny'); denyBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'deny')); actions.appendChild(approveBtn); - actions.appendChild(alwaysBtn); + if (data.allow_always !== false) { + const alwaysBtn = document.createElement('button'); + alwaysBtn.className = 'always'; + alwaysBtn.textContent = I18n.t('approval.always'); + alwaysBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'always')); + actions.appendChild(alwaysBtn); + } actions.appendChild(denyBtn); card.appendChild(actions); diff --git a/src/channels/web/types.rs b/src/channels/web/types.rs index 3fad9f3525..b2c060c9c9 100644 --- a/src/channels/web/types.rs +++ b/src/channels/web/types.rs @@ -177,6 +177,8 @@ pub enum SseEvent { parameters: String, #[serde(skip_serializing_if = "Option::is_none")] thread_id: Option, + /// Whether the "always" auto-approve option should be shown. + allow_always: bool, }, #[serde(rename = "auth_required")] AuthRequired { @@ -1080,6 +1082,7 @@ mod tests { description: "Run ls".to_string(), parameters: "{}".to_string(), thread_id: Some("t1".to_string()), + allow_always: true, }; let ws = WsServerMessage::from_sse_event(&sse); match ws { diff --git a/src/tools/builtin/http.rs b/src/tools/builtin/http.rs index 9d7af888da..0bd8eb37cb 100644 --- a/src/tools/builtin/http.rs +++ b/src/tools/builtin/http.rs @@ -837,7 +837,7 @@ impl Tool for HttpTool { })); if has_credentials { - return ApprovalRequirement::Always; + return ApprovalRequirement::UnlessAutoApproved; } // GET requests (or missing method, since GET is the default) are low-risk @@ -1093,25 +1093,31 @@ mod tests { } #[test] - fn test_auth_header_object_format_returns_always() { + fn test_auth_header_object_format_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data", "headers": {"Authorization": "Bearer token123"} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] - fn test_auth_header_array_format_returns_always() { + fn test_auth_header_array_format_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data", "headers": [{"name": "Authorization", "value": "Bearer token123"}] }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1124,7 +1130,10 @@ mod tests { "url": "https://example.com", "headers": {"AUTHORIZATION": "Bearer x"} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); // Array format with mixed case let params = serde_json::json!({ @@ -1132,7 +1141,10 @@ mod tests { "url": "https://example.com", "headers": [{"name": "X-Api-Key", "value": "key123"}] }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1161,8 +1173,8 @@ mod tests { }); assert_eq!( tool.requires_approval(¶ms), - ApprovalRequirement::Always, - "Header '{}' should trigger Always approval", + ApprovalRequirement::UnlessAutoApproved, + "Header '{}' should trigger UnlessAutoApproved approval", header_name ); } @@ -1203,7 +1215,7 @@ mod tests { // ── Credential registry approval tests ───────────────────────────── #[test] - fn test_host_with_credential_mapping_returns_always() { + fn test_host_with_credential_mapping_returns_unless_auto_approved() { use crate::secrets::CredentialMapping; use crate::tools::wasm::SharedCredentialRegistry; @@ -1223,7 +1235,10 @@ mod tests { "method": "GET", "url": "https://api.openai.com/v1/models" }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1243,24 +1258,55 @@ mod tests { } #[test] - fn test_url_query_param_credential_returns_always() { + fn test_url_query_param_credential_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data?api_key=secret123" }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] - fn test_bearer_value_in_custom_header_returns_always() { + fn test_bearer_value_in_custom_header_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://example.com", "headers": {"X-Custom": format!("Bearer {TEST_OPENAI_API_KEY}")} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); + } + + /// Regression test: credentialed HTTP requests must return + /// `UnlessAutoApproved` (not `Always`) so that the session auto-approve + /// set is respected when the user says "always". + #[test] + fn test_credentialed_requests_respect_auto_approve() { + let tool = HttpTool::new(); + + // Manual credentials (Authorization header) + let params = serde_json::json!({ + "method": "GET", + "url": "https://api.github.com/orgs/Casa", + "headers": {"Authorization": "Bearer ghp_abc123"} + }); + // Must NOT be Always — Always ignores the session auto-approve set + assert_ne!( + tool.requires_approval(¶ms), + ApprovalRequirement::Always, + "Credentialed HTTP requests must not return Always; use UnlessAutoApproved" + ); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved, + ); } #[test]