Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions src/agent/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(super) enum AgenticLoopResult {
/// A tool requires approval before continuing.
NeedApproval {
/// The pending approval request to store.
pending: PendingApproval,
pending: Box<PendingApproval>,
},
}

Expand Down Expand Up @@ -216,9 +216,7 @@ impl Agent {
reason: format!("Exceeded maximum tool iterations ({max_tool_iterations})"),
}
.into()),
LoopOutcome::NeedApproval(pending) => {
Ok(AgenticLoopResult::NeedApproval { pending: *pending })
}
LoopOutcome::NeedApproval(pending) => Ok(AgenticLoopResult::NeedApproval { pending }),
}
}

Expand Down Expand Up @@ -481,6 +479,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
usize,
crate::llm::ToolCall,
Arc<dyn crate::tools::Tool>,
bool, // allow_always
)> = None;

for (idx, original_tc) in tool_calls.iter().enumerate() {
Expand Down Expand Up @@ -550,7 +549,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
&& let Some(tool) = tool_opt
{
use crate::tools::ApprovalRequirement;
let needs_approval = match tool.requires_approval(&tc.arguments) {
let requirement = tool.requires_approval(&tc.arguments);
let needs_approval = match requirement {
ApprovalRequirement::Never => false,
ApprovalRequirement::UnlessAutoApproved => {
let sess = self.session.lock().await;
Expand Down Expand Up @@ -585,7 +585,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
continue;
}

approval_needed = Some((idx, tc, tool));
let allow_always = !matches!(requirement, ApprovalRequirement::Always);
approval_needed = Some((idx, tc, tool, allow_always));
break;
}
}
Expand Down Expand Up @@ -886,7 +887,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
}

// Handle approval if a tool needed it
if let Some((approval_idx, tc, tool)) = approval_needed {
if let Some((approval_idx, tc, tool, allow_always)) = approval_needed {
let display_params = redact_params(&tc.arguments, tool.sensitive_params());
let pending = PendingApproval {
request_id: Uuid::new_v4(),
Expand All @@ -898,6 +899,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
context_messages: reason_ctx.messages.clone(),
deferred_tool_calls: tool_calls[approval_idx + 1..].to_vec(),
user_timezone: Some(self.user_tz.name().to_string()),
allow_always,
};

return Ok(Some(LoopOutcome::NeedApproval(Box::new(pending))));
Expand Down Expand Up @@ -1363,6 +1365,35 @@ mod tests {
assert!(always_needs, "Always must always require approval");
}

/// Regression test: `allow_always` must be `false` for `Always` and
/// `true` for `UnlessAutoApproved`, so the UI hides the "always" button
/// for tools that truly cannot be auto-approved.
#[test]
fn test_allow_always_matches_approval_requirement() {
use crate::tools::ApprovalRequirement;

// Mirrors the expression used in dispatcher.rs and thread_ops.rs:
// let allow_always = !matches!(requirement, ApprovalRequirement::Always);

// UnlessAutoApproved → allow_always = true
let req = ApprovalRequirement::UnlessAutoApproved;
let allow_always = !matches!(req, ApprovalRequirement::Always);
assert!(
allow_always,
"UnlessAutoApproved should set allow_always = true"
);

// Always → allow_always = false
let req = ApprovalRequirement::Always;
let allow_always = !matches!(req, ApprovalRequirement::Always);
assert!(!allow_always, "Always should set allow_always = false");

// Never → allow_always = true (approval is never needed, but if it were, always would be ok)
let req = ApprovalRequirement::Never;
let allow_always = !matches!(req, ApprovalRequirement::Always);
assert!(allow_always, "Never should set allow_always = true");
}

#[test]
fn test_pending_approval_serialization_backcompat_without_deferred_calls() {
// PendingApproval from before the deferred_tool_calls field was added
Expand Down Expand Up @@ -1408,6 +1439,7 @@ mod tests {
},
],
user_timezone: None,
allow_always: true,
};

let json = serde_json::to_string(&pending).expect("serialize");
Expand Down
11 changes: 11 additions & 0 deletions src/agent/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ pub struct PendingApproval {
/// through the approval flow even if the approval message lacks timezone.
#[serde(default)]
pub user_timezone: Option<String>,
/// Whether the "always" auto-approve option should be offered to the user.
/// `false` when the tool returned `ApprovalRequirement::Always` (e.g.
/// destructive shell commands), meaning every invocation must be confirmed.
#[serde(default = "default_true")]
pub allow_always: bool,
}

fn default_true() -> bool {
true
}

/// A conversation thread within a session.
Expand Down Expand Up @@ -1106,6 +1115,7 @@ mod tests {
context_messages: vec![ChatMessage::user("do it")],
deferred_tool_calls: vec![],
user_timezone: None,
allow_always: false,
};

thread.await_approval(approval);
Expand All @@ -1132,6 +1142,7 @@ mod tests {
context_messages: vec![],
deferred_tool_calls: vec![],
user_timezone: None,
allow_always: true,
};

thread.await_approval(approval);
Expand Down
2 changes: 2 additions & 0 deletions src/agent/submission.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ pub enum SubmissionResult {
description: String,
/// Parameters being passed.
parameters: serde_json::Value,
/// Whether "always" auto-approve should be offered to the user.
allow_always: bool,
},

/// Successfully processed (for control commands).
Expand Down
28 changes: 20 additions & 8 deletions src/agent/thread_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ impl Agent {
let tool_name = pending.tool_name.clone();
let description = pending.description.clone();
let parameters = pending.display_parameters.clone();
thread.await_approval(pending);
let allow_always = pending.allow_always;
thread.await_approval(*pending);
let _ = self
.channels
.send_status(
Expand All @@ -507,6 +508,7 @@ impl Agent {
tool_name: tool_name.clone(),
description: description.clone(),
parameters: parameters.clone(),
allow_always,
},
&message.metadata,
)
Expand All @@ -516,6 +518,7 @@ impl Agent {
tool_name,
description,
parameters,
allow_always,
})
}
Err(e) => {
Expand Down Expand Up @@ -1060,28 +1063,31 @@ impl Agent {
usize,
crate::llm::ToolCall,
Arc<dyn crate::tools::Tool>,
bool, // allow_always
)> = None;

for (idx, tc) in deferred_tool_calls.iter().enumerate() {
if let Some(tool) = self.tools().get(&tc.name).await {
// Match dispatcher.rs: when auto_approve_tools is true, skip
// all approval checks (including ApprovalRequirement::Always).
let needs_approval = if self.config.auto_approve_tools {
false
let (needs_approval, allow_always) = if self.config.auto_approve_tools {
(false, true)
} else {
use crate::tools::ApprovalRequirement;
match tool.requires_approval(&tc.arguments) {
let requirement = tool.requires_approval(&tc.arguments);
let needs = match requirement {
ApprovalRequirement::Never => false,
ApprovalRequirement::UnlessAutoApproved => {
let sess = session.lock().await;
!sess.is_tool_auto_approved(&tc.name)
}
ApprovalRequirement::Always => true,
}
};
(needs, !matches!(requirement, ApprovalRequirement::Always))
};

if needs_approval {
approval_needed = Some((idx, tc.clone(), tool));
approval_needed = Some((idx, tc.clone(), tool, allow_always));
break; // remaining tools stay deferred
}
}
Expand Down Expand Up @@ -1289,7 +1295,7 @@ impl Agent {
}

// Handle approval if a tool needed it
if let Some((approval_idx, tc, tool)) = approval_needed {
if let Some((approval_idx, tc, tool, allow_always)) = approval_needed {
let new_pending = PendingApproval {
request_id: Uuid::new_v4(),
tool_name: tc.name.clone(),
Expand All @@ -1301,6 +1307,7 @@ impl Agent {
deferred_tool_calls: deferred_tool_calls[approval_idx + 1..].to_vec(),
// Carry forward the resolved timezone from the original pending approval
user_timezone: pending.user_timezone.clone(),
allow_always,
};

let request_id = new_pending.request_id;
Expand All @@ -1324,6 +1331,7 @@ impl Agent {
tool_name: tool_name.clone(),
description: description.clone(),
parameters: parameters.clone(),
allow_always,
},
&message.metadata,
)
Expand All @@ -1334,6 +1342,7 @@ impl Agent {
tool_name,
description,
parameters,
allow_always,
});
}

Expand Down Expand Up @@ -1402,7 +1411,8 @@ impl Agent {
let tool_name = new_pending.tool_name.clone();
let description = new_pending.description.clone();
let parameters = new_pending.display_parameters.clone();
thread.await_approval(new_pending);
let allow_always = new_pending.allow_always;
thread.await_approval(*new_pending);
let _ = self
.channels
.send_status(
Expand All @@ -1412,6 +1422,7 @@ impl Agent {
tool_name: tool_name.clone(),
description: description.clone(),
parameters: parameters.clone(),
allow_always,
},
&message.metadata,
)
Expand All @@ -1421,6 +1432,7 @@ impl Agent {
tool_name,
description,
parameters,
allow_always,
})
}
Err(e) => {
Expand Down
5 changes: 5 additions & 0 deletions src/channels/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ pub enum StatusUpdate {
tool_name: String,
description: String,
parameters: serde_json::Value,
/// When `true`, the UI should offer an "always" option that auto-approves
/// future calls to this tool for the rest of the session. When `false`
/// (i.e. `ApprovalRequirement::Always`), the tool must be approved every
/// time and the "always" button should be hidden.
allow_always: bool,
},
/// Extension needs user authentication (token or OAuth).
AuthRequired {
Expand Down
4 changes: 4 additions & 0 deletions src/channels/relay/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ impl Channel for RelayChannel {
tool_name,
description,
parameters,
allow_always: _,
} = status
else {
return Ok(());
Expand Down Expand Up @@ -794,6 +795,7 @@ mod tests {
tool_name: "shell".into(),
description: "run command".into(),
parameters: serde_json::json!({}),
allow_always: true,
},
&metadata,
)
Expand Down Expand Up @@ -822,6 +824,7 @@ mod tests {
tool_name: "shell".into(),
description: "run command".into(),
parameters: serde_json::json!({}),
allow_always: true,
},
&metadata,
)
Expand Down Expand Up @@ -854,6 +857,7 @@ mod tests {
tool_name: "shell".into(),
description: "run command".into(),
parameters: serde_json::json!({}),
allow_always: true,
},
&metadata,
)
Expand Down
11 changes: 8 additions & 3 deletions src/channels/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ impl Channel for ReplChannel {
tool_name,
description,
parameters,
allow_always,
} => {
let term_width = crossterm::terminal::size()
.map(|(w, _)| w as usize)
Expand Down Expand Up @@ -582,9 +583,13 @@ impl Channel for ReplChannel {
}

eprintln!(" \u{2502}");
eprintln!(
" \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)"
);
if allow_always {
eprintln!(
" \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)"
);
} else {
eprintln!(" \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[31mno\x1b[0m (n)");
}
eprintln!(" {bot_border}");
eprintln!();
}
Expand Down
14 changes: 11 additions & 3 deletions src/channels/signal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,20 +915,28 @@ impl Channel for SignalChannel {
tool_name,
description: _,
parameters,
allow_always,
} = &status
&& let Some(target_str) = metadata.get("signal_target").and_then(|v| v.as_str())
{
let params_json = serde_json::to_string_pretty(parameters).unwrap_or_default();
let always_line = if *allow_always {
format!(
"\n• `always` or `a` - Approve and auto-approve future {} requests",
tool_name
)
} else {
String::new()
};
let message = format!(
"⚠️ *Approval Required*\n\n\
*Request ID:* `{}`\n\
*Tool:* {}\n\
*Parameters:*\n```\n{}\n```\n\n\
Reply with:\n\
• `yes` or `y` - Approve this request\n\
• `always` or `a` - Approve and auto-approve future {} requests\n\
• `yes` or `y` - Approve this request{}\n\
• `no` or `n` - Deny",
request_id, tool_name, params_json, tool_name
request_id, tool_name, params_json, always_line
);
self.send_status_message(target_str, &message).await;
}
Expand Down
Loading
Loading