diff --git a/FEATURE_PARITY.md b/FEATURE_PARITY.md index 85348de539..e0002a4117 100644 --- a/FEATURE_PARITY.md +++ b/FEATURE_PARITY.md @@ -465,7 +465,7 @@ This document tracks feature parity between IronClaw (Rust implementation) and O | Device pairing | ✅ | ❌ | | | Tailscale identity | ✅ | ❌ | | | Trusted-proxy auth | ✅ | ❌ | Header-based reverse proxy auth | -| OAuth flows | ✅ | 🚧 | NEAR AI OAuth | +| OAuth flows | ✅ | 🚧 | NEAR AI OAuth plus hosted extension/MCP OAuth broker; external auth-proxy rollout still pending | | DM pairing verification | ✅ | ✅ | ironclaw pairing approve, host APIs | | Allowlist/blocklist | ✅ | 🚧 | allow_from + pairing store | | Per-group tool policies | ✅ | ❌ | | diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs index 49387e8351..d3825b2f50 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher.rs @@ -29,7 +29,7 @@ pub(super) enum AgenticLoopResult { /// A tool requires approval before continuing. NeedApproval { /// The pending approval request to store. - pending: PendingApproval, + pending: Box, }, } @@ -217,9 +217,7 @@ impl Agent { reason: format!("Exceeded maximum tool iterations ({max_tool_iterations})"), } .into()), - LoopOutcome::NeedApproval(pending) => { - Ok(AgenticLoopResult::NeedApproval { pending: *pending }) - } + LoopOutcome::NeedApproval(pending) => Ok(AgenticLoopResult::NeedApproval { pending }), } } @@ -482,6 +480,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { usize, crate::llm::ToolCall, Arc, + bool, // allow_always )> = None; for (idx, original_tc) in tool_calls.iter().enumerate() { @@ -551,7 +550,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { && let Some(tool) = tool_opt { use crate::tools::ApprovalRequirement; - let needs_approval = match tool.requires_approval(&tc.arguments) { + let requirement = tool.requires_approval(&tc.arguments); + let needs_approval = match requirement { ApprovalRequirement::Never => false, ApprovalRequirement::UnlessAutoApproved => { let sess = self.session.lock().await; @@ -586,7 +586,8 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { continue; } - approval_needed = Some((idx, tc, tool)); + let allow_always = !matches!(requirement, ApprovalRequirement::Always); + approval_needed = Some((idx, tc, tool, allow_always)); break; } } @@ -887,7 +888,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { } // Handle approval if a tool needed it - if let Some((approval_idx, tc, tool)) = approval_needed { + if let Some((approval_idx, tc, tool, allow_always)) = approval_needed { let display_params = redact_params(&tc.arguments, tool.sensitive_params()); let pending = PendingApproval { request_id: Uuid::new_v4(), @@ -899,6 +900,7 @@ impl<'a> LoopDelegate for ChatDelegate<'a> { context_messages: reason_ctx.messages.clone(), deferred_tool_calls: tool_calls[approval_idx + 1..].to_vec(), user_timezone: Some(self.user_tz.name().to_string()), + allow_always, }; return Ok(Some(LoopOutcome::NeedApproval(Box::new(pending)))); @@ -1365,6 +1367,35 @@ mod tests { assert!(always_needs, "Always must always require approval"); } + /// Regression test: `allow_always` must be `false` for `Always` and + /// `true` for `UnlessAutoApproved`, so the UI hides the "always" button + /// for tools that truly cannot be auto-approved. + #[test] + fn test_allow_always_matches_approval_requirement() { + use crate::tools::ApprovalRequirement; + + // Mirrors the expression used in dispatcher.rs and thread_ops.rs: + // let allow_always = !matches!(requirement, ApprovalRequirement::Always); + + // UnlessAutoApproved → allow_always = true + let req = ApprovalRequirement::UnlessAutoApproved; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!( + allow_always, + "UnlessAutoApproved should set allow_always = true" + ); + + // Always → allow_always = false + let req = ApprovalRequirement::Always; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!(!allow_always, "Always should set allow_always = false"); + + // Never → allow_always = true (approval is never needed, but if it were, always would be ok) + let req = ApprovalRequirement::Never; + let allow_always = !matches!(req, ApprovalRequirement::Always); + assert!(allow_always, "Never should set allow_always = true"); + } + #[test] fn test_pending_approval_serialization_backcompat_without_deferred_calls() { // PendingApproval from before the deferred_tool_calls field was added @@ -1410,6 +1441,7 @@ mod tests { }, ], user_timezone: None, + allow_always: true, }; let json = serde_json::to_string(&pending).expect("serialize"); diff --git a/src/agent/job_monitor.rs b/src/agent/job_monitor.rs index 714caeac4b..6497861a4d 100644 --- a/src/agent/job_monitor.rs +++ b/src/agent/job_monitor.rs @@ -211,6 +211,7 @@ mod tests { job_id: job_id.to_string(), status: "completed".to_string(), session_id: None, + fallback_deliverable: None, }, )) .unwrap(); diff --git a/src/agent/session.rs b/src/agent/session.rs index 4abbea6168..3e84afc0b6 100644 --- a/src/agent/session.rs +++ b/src/agent/session.rs @@ -188,6 +188,15 @@ pub struct PendingApproval { /// through the approval flow even if the approval message lacks timezone. #[serde(default)] pub user_timezone: Option, + /// Whether the "always" auto-approve option should be offered to the user. + /// `false` when the tool returned `ApprovalRequirement::Always` (e.g. + /// destructive shell commands), meaning every invocation must be confirmed. + #[serde(default = "default_true")] + pub allow_always: bool, +} + +fn default_true() -> bool { + true } /// A conversation thread within a session. @@ -1106,6 +1115,7 @@ mod tests { context_messages: vec![ChatMessage::user("do it")], deferred_tool_calls: vec![], user_timezone: None, + allow_always: false, }; thread.await_approval(approval); @@ -1132,6 +1142,7 @@ mod tests { context_messages: vec![], deferred_tool_calls: vec![], user_timezone: None, + allow_always: true, }; thread.await_approval(approval); diff --git a/src/agent/submission.rs b/src/agent/submission.rs index a3ae2524d2..8594c9690c 100644 --- a/src/agent/submission.rs +++ b/src/agent/submission.rs @@ -382,6 +382,8 @@ pub enum SubmissionResult { description: String, /// Parameters being passed. parameters: serde_json::Value, + /// Whether "always" auto-approve should be offered to the user. + allow_always: bool, }, /// Successfully processed (for control commands). diff --git a/src/agent/thread_ops.rs b/src/agent/thread_ops.rs index 877a4e2777..e8b8d09a5b 100644 --- a/src/agent/thread_ops.rs +++ b/src/agent/thread_ops.rs @@ -506,7 +506,8 @@ impl Agent { let tool_name = pending.tool_name.clone(); let description = pending.description.clone(); let parameters = pending.display_parameters.clone(); - thread.await_approval(pending); + let allow_always = pending.allow_always; + thread.await_approval(*pending); let _ = self .channels .send_status( @@ -516,6 +517,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -525,6 +527,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }) } Err(e) => { @@ -1069,28 +1072,31 @@ impl Agent { usize, crate::llm::ToolCall, Arc, + bool, // allow_always )> = None; for (idx, tc) in deferred_tool_calls.iter().enumerate() { if let Some(tool) = self.tools().get(&tc.name).await { // Match dispatcher.rs: when auto_approve_tools is true, skip // all approval checks (including ApprovalRequirement::Always). - let needs_approval = if self.config.auto_approve_tools { - false + let (needs_approval, allow_always) = if self.config.auto_approve_tools { + (false, true) } else { use crate::tools::ApprovalRequirement; - match tool.requires_approval(&tc.arguments) { + let requirement = tool.requires_approval(&tc.arguments); + let needs = match requirement { ApprovalRequirement::Never => false, ApprovalRequirement::UnlessAutoApproved => { let sess = session.lock().await; !sess.is_tool_auto_approved(&tc.name) } ApprovalRequirement::Always => true, - } + }; + (needs, !matches!(requirement, ApprovalRequirement::Always)) }; if needs_approval { - approval_needed = Some((idx, tc.clone(), tool)); + approval_needed = Some((idx, tc.clone(), tool, allow_always)); break; // remaining tools stay deferred } } @@ -1298,7 +1304,7 @@ impl Agent { } // Handle approval if a tool needed it - if let Some((approval_idx, tc, tool)) = approval_needed { + if let Some((approval_idx, tc, tool, allow_always)) = approval_needed { let new_pending = PendingApproval { request_id: Uuid::new_v4(), tool_name: tc.name.clone(), @@ -1310,6 +1316,7 @@ impl Agent { deferred_tool_calls: deferred_tool_calls[approval_idx + 1..].to_vec(), // Carry forward the resolved timezone from the original pending approval user_timezone: pending.user_timezone.clone(), + allow_always, }; let request_id = new_pending.request_id; @@ -1333,6 +1340,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -1343,6 +1351,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }); } @@ -1411,7 +1420,8 @@ impl Agent { let tool_name = new_pending.tool_name.clone(); let description = new_pending.description.clone(); let parameters = new_pending.display_parameters.clone(); - thread.await_approval(new_pending); + let allow_always = new_pending.allow_always; + thread.await_approval(*new_pending); let _ = self .channels .send_status( @@ -1421,6 +1431,7 @@ impl Agent { tool_name: tool_name.clone(), description: description.clone(), parameters: parameters.clone(), + allow_always, }, &message.metadata, ) @@ -1430,6 +1441,7 @@ impl Agent { tool_name, description, parameters, + allow_always, }) } Err(e) => { @@ -1949,6 +1961,7 @@ mod tests { context_messages: vec![], deferred_tool_calls: vec![], user_timezone: None, + allow_always: false, }; thread.await_approval(pending); diff --git a/src/app.rs b/src/app.rs index fa6675bfad..c6892477f0 100644 --- a/src/app.rs +++ b/src/app.rs @@ -25,7 +25,7 @@ use crate::tools::ToolRegistry; use crate::tools::mcp::{McpProcessManager, McpSessionManager}; use crate::tools::wasm::SharedCredentialRegistry; use crate::tools::wasm::WasmToolRuntime; -use crate::workspace::{EmbeddingProvider, Workspace}; +use crate::workspace::{EmbeddingCacheConfig, EmbeddingProvider, Workspace}; /// Fully initialized application components, ready for channel wiring /// and agent construction. @@ -313,10 +313,13 @@ impl AppBuilder { // Register memory tools if database is available let workspace = if let Some(ref db) = self.db { + let emb_cache_config = EmbeddingCacheConfig { + max_entries: self.config.embeddings.cache_size, + }; let mut ws = Workspace::new_with_db(&self.config.owner_id, db.clone()) .with_search_config(&self.config.search); if let Some(ref emb) = embeddings { - ws = ws.with_embeddings(emb.clone()); + ws = ws.with_embeddings_cached(emb.clone(), emb_cache_config); } let ws = Arc::new(ws); tools.register_memory_tools(Arc::clone(&ws)); diff --git a/src/channels/channel.rs b/src/channels/channel.rs index 43e35688cc..a85cf8c5d2 100644 --- a/src/channels/channel.rs +++ b/src/channels/channel.rs @@ -305,6 +305,11 @@ pub enum StatusUpdate { tool_name: String, description: String, parameters: serde_json::Value, + /// When `true`, the UI should offer an "always" option that auto-approves + /// future calls to this tool for the rest of the session. When `false` + /// (i.e. `ApprovalRequirement::Always`), the tool must be approved every + /// time and the "always" button should be hidden. + allow_always: bool, }, /// Extension needs user authentication (token or OAuth). AuthRequired { diff --git a/src/channels/manager.rs b/src/channels/manager.rs index b026ff850d..0c9a3da77a 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -239,6 +239,11 @@ impl ChannelManager { pub async fn get_channel(&self, name: &str) -> Option> { self.channels.read().await.get(name).cloned() } + + /// Remove a channel from the manager. + pub async fn remove(&self, name: &str) -> Option> { + self.channels.write().await.remove(name) + } } impl Default for ChannelManager { diff --git a/src/channels/relay/channel.rs b/src/channels/relay/channel.rs index 52aea478ee..3b6c337968 100644 --- a/src/channels/relay/channel.rs +++ b/src/channels/relay/channel.rs @@ -1,16 +1,16 @@ -//! Channel trait implementation for channel-relay SSE streams. +//! Channel trait implementation for channel-relay webhook callbacks. //! -//! `RelayChannel` connects to a channel-relay service via SSE, converts -//! incoming events to `IncomingMessage`s, and sends responses via the -//! relay's provider-specific proxy API (Slack). +//! `RelayChannel` receives events from channel-relay via HTTP POST callbacks +//! (pushed through an mpsc channel by the webhook handler), converts them +//! to `IncomingMessage`s, and sends responses via the relay's provider-specific +//! proxy API (Slack). use std::collections::HashMap; -use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::{RwLock, mpsc}; +use tokio::sync::mpsc; -use crate::channels::relay::client::{RelayClient, RelayError}; +use crate::channels::relay::client::{ChannelEvent, RelayClient}; use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate}; use crate::error::ChannelError; @@ -39,44 +39,34 @@ impl RelayProvider { } } -/// Channel implementation that connects to a channel-relay SSE stream. +/// Channel implementation that receives events from channel-relay via webhook callbacks. pub struct RelayChannel { client: RelayClient, provider: RelayProvider, - stream_token: Arc>, team_id: String, instance_id: String, - user_id: String, - /// SSE stream long-poll timeout in seconds. - stream_timeout_secs: u64, - /// Initial exponential backoff in milliseconds. - backoff_initial_ms: u64, - /// Maximum exponential backoff in milliseconds. - backoff_max_ms: u64, - /// Handle to the reconnect task for clean shutdown. - reconnect_handle: RwLock>>, - /// Handle to the SSE parser task for clean shutdown. - parser_handle: Arc>>>, - /// Maximum consecutive reconnect failures before giving up. - max_consecutive_failures: u64, + /// Sender side of the event channel — shared with the webhook handler. + event_tx: mpsc::Sender, + /// Receiver side — taken once by `start()`. + event_rx: tokio::sync::Mutex>>, } impl RelayChannel { /// Create a new relay channel for Slack (default provider). pub fn new( client: RelayClient, - stream_token: String, team_id: String, instance_id: String, - user_id: String, + event_tx: mpsc::Sender, + event_rx: mpsc::Receiver, ) -> Self { Self::new_with_provider( client, RelayProvider::Slack, - stream_token, team_id, instance_id, - user_id, + event_tx, + event_rx, ) } @@ -84,44 +74,24 @@ impl RelayChannel { pub fn new_with_provider( client: RelayClient, provider: RelayProvider, - stream_token: String, team_id: String, instance_id: String, - user_id: String, + event_tx: mpsc::Sender, + event_rx: mpsc::Receiver, ) -> Self { Self { client, provider, - stream_token: Arc::new(RwLock::new(stream_token)), team_id, instance_id, - user_id, - stream_timeout_secs: 86400, - backoff_initial_ms: 1000, - backoff_max_ms: 60000, - reconnect_handle: RwLock::new(None), - parser_handle: Arc::new(RwLock::new(None)), - max_consecutive_failures: 50, + event_tx, + event_rx: tokio::sync::Mutex::new(Some(event_rx)), } } - /// Set backoff/timeout parameters from relay config values. - pub fn with_timeouts( - mut self, - stream_timeout_secs: u64, - backoff_initial_ms: u64, - backoff_max_ms: u64, - ) -> Self { - self.stream_timeout_secs = stream_timeout_secs; - self.backoff_initial_ms = backoff_initial_ms; - self.backoff_max_ms = backoff_max_ms; - self - } - - /// Set the maximum number of consecutive reconnect failures before giving up. - pub fn with_max_failures(mut self, max: u64) -> Self { - self.max_consecutive_failures = max; - self + /// Get a clone of the event sender for wiring into the webhook endpoint. + pub fn event_sender(&self) -> mpsc::Sender { + self.event_tx.clone() } /// Build a provider-appropriate proxy body for sending a message. @@ -151,15 +121,9 @@ impl RelayChannel { team_id: &str, method: &str, body: serde_json::Value, - ) -> Result { + ) -> Result { self.client - .proxy_provider( - self.provider.as_str(), - team_id, - method, - body, - Some(&self.instance_id), - ) + .proxy_provider(self.provider.as_str(), team_id, method, body) .await } } @@ -172,204 +136,82 @@ impl Channel for RelayChannel { async fn start(&self) -> Result { let channel_name = self.name().to_string(); - let token = self.stream_token.read().await.clone(); - let (stream, initial_parser_handle) = self - .client - .connect_stream(&token, self.stream_timeout_secs) - .await - .map_err(|e| ChannelError::StartupFailed { - name: channel_name.clone(), - reason: e.to_string(), - })?; - *self.parser_handle.write().await = Some(initial_parser_handle); + // Take the receiver (can only start once) + let mut event_rx = + self.event_rx + .lock() + .await + .take() + .ok_or_else(|| ChannelError::StartupFailed { + name: channel_name.clone(), + reason: "RelayChannel already started".to_string(), + })?; let (tx, rx) = mpsc::channel(64); - - // Spawn the stream reader + reconnect task - let client = self.client.clone(); - let stream_token = Arc::clone(&self.stream_token); - let instance_id = self.instance_id.clone(); - let user_id = self.user_id.clone(); - let team_id = self.team_id.clone(); - let stream_timeout_secs = self.stream_timeout_secs; - let backoff_initial_ms = self.backoff_initial_ms; - let backoff_max_ms = self.backoff_max_ms; - let max_consecutive_failures = self.max_consecutive_failures; - let parser_handle = Arc::clone(&self.parser_handle); let provider_str = self.provider.as_str().to_string(); let relay_name = channel_name.clone(); - let handle = tokio::spawn(async move { - use futures::StreamExt; - - let mut current_stream = stream; - let mut backoff_ms = backoff_initial_ms; - let mut consecutive_failures: u64 = 0; - - loop { - // Read events from the current stream - while let Some(event) = current_stream.next().await { - // Reset backoff and failure count on successful event - backoff_ms = backoff_initial_ms; - consecutive_failures = 0; - - // Validate required fields - if event.sender_id.is_empty() - || event.channel_id.is_empty() - || event.provider_scope.is_empty() - { - tracing::debug!( - event_type = %event.event_type, - sender_id = %event.sender_id, - channel_id = %event.channel_id, - "Relay: skipping event with missing required fields" - ); - continue; - } - - // Skip non-message events - if !event.is_message() { - tracing::debug!( - event_type = %event.event_type, - "Relay: skipping non-message event" - ); - continue; - } - - tracing::info!( + // Spawn a task that reads events from the webhook handler and converts to IncomingMessage + tokio::spawn(async move { + while let Some(event) = event_rx.recv().await { + // Validate required fields + if event.sender_id.is_empty() + || event.channel_id.is_empty() + || event.provider_scope.is_empty() + { + tracing::debug!( event_type = %event.event_type, - sender = %event.sender_id, - channel = %event.channel_id, - provider = %provider_str, - "Relay: received message from {}", provider_str + sender_id = %event.sender_id, + channel_id = %event.channel_id, + "Relay: skipping event with missing required fields" ); - - let msg = IncomingMessage::new(&relay_name, &event.sender_id, event.text()) - .with_user_name(event.display_name()) - .with_metadata(serde_json::json!({ - "team_id": event.team_id(), - "channel_id": event.channel_id, - "sender_id": event.sender_id, - "sender_name": event.display_name(), - "event_type": event.event_type, - "thread_id": event.thread_id, - "provider": event.provider, - })); - - let msg = if let Some(ref thread_id) = event.thread_id { - msg.with_thread(thread_id) - } else { - msg.with_thread(&event.channel_id) - }; - - if tx.send(msg).await.is_err() { - tracing::info!("Relay channel receiver dropped, stopping"); - return; - } + continue; } - // Stream ended, attempt reconnect with backoff - consecutive_failures += 1; - if consecutive_failures >= max_consecutive_failures { - tracing::error!( - channel = %relay_name, - failures = consecutive_failures, - "Relay channel giving up after {} consecutive failures", - consecutive_failures + // Skip non-message events + if !event.is_message() { + tracing::debug!( + event_type = %event.event_type, + "Relay: skipping non-message event" ); - break; + continue; } - tracing::warn!( - backoff_ms = backoff_ms, - failures = consecutive_failures, - "Relay SSE stream ended, reconnecting..." + tracing::info!( + event_type = %event.event_type, + sender = %event.sender_id, + channel = %event.channel_id, + provider = %provider_str, + "Relay: received message from {}", provider_str ); - tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(backoff_max_ms); - - // Try to reconnect - let token = stream_token.read().await.clone(); - match client.connect_stream(&token, stream_timeout_secs).await { - Ok((new_stream, new_parser)) => { - tracing::info!("Relay SSE stream reconnected"); - consecutive_failures = 0; - backoff_ms = backoff_initial_ms; - current_stream = new_stream; - // Abort old parser before replacing - if let Some(old) = parser_handle.write().await.take() { - old.abort(); - } - *parser_handle.write().await = Some(new_parser); - } - Err(RelayError::TokenExpired) => { - // Attempt token renewal - tracing::info!("Relay stream token expired, renewing..."); - match client.renew_token(&instance_id, &user_id).await { - Ok(new_token) => { - *stream_token.write().await = new_token.clone(); - match client.connect_stream(&new_token, stream_timeout_secs).await { - Ok((new_stream, new_parser)) => { - tracing::info!( - "Relay SSE stream reconnected with new token" - ); - consecutive_failures = 0; - backoff_ms = backoff_initial_ms; - current_stream = new_stream; - if let Some(old) = parser_handle.write().await.take() { - old.abort(); - } - *parser_handle.write().await = Some(new_parser); - } - Err(e) => { - tracing::error!( - error = %e, - "Failed to reconnect after token renewal" - ); - } - } - } - Err(e) => { - tracing::error!( - error = %e, - "Failed to renew relay stream token" - ); - } - } - } - Err(e) => { - tracing::error!(error = %e, "Failed to reconnect relay SSE stream"); - } - } - // Check if the team is still valid (skip when team_id is unknown, - // e.g. when no DB store was available at activation time) - if !team_id.is_empty() { - match client.list_connections(&instance_id).await { - Ok(conns) => { - let has_team = - conns.iter().any(|c| c.team_id == team_id && c.connected); - if !has_team { - tracing::warn!( - team_id = %team_id, - "Team no longer connected, stopping relay channel" - ); - return; - } - } - Err(e) => { - tracing::warn!( - error = %e, - "Could not verify team connection, will retry next iteration" - ); - } - } + let msg = IncomingMessage::new(&relay_name, &event.sender_id, event.text()) + .with_user_name(event.display_name()) + .with_metadata(serde_json::json!({ + "team_id": event.team_id(), + "channel_id": event.channel_id, + "sender_id": event.sender_id, + "sender_name": event.display_name(), + "event_type": event.event_type, + "thread_id": event.thread_id, + "provider": event.provider, + })); + + let msg = if let Some(ref thread_id) = event.thread_id { + msg.with_thread(thread_id) + } else { + msg.with_thread(&event.channel_id) + }; + + if tx.send(msg).await.is_err() { + tracing::info!("Relay channel receiver dropped, stopping"); + return; } } - }); - *self.reconnect_handle.write().await = Some(handle); + tracing::info!("Relay event channel closed"); + }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); Ok(Box::pin(stream)) @@ -423,6 +265,7 @@ impl Channel for RelayChannel { tool_name, description, parameters, + allow_always: _, } = status else { return Ok(()); @@ -450,28 +293,24 @@ impl Channel for RelayChannel { name: self.name().to_string(), reason: "Missing channel_id for approval buttons".into(), })?; - let sender_id = metadata - .get("sender_id") - .and_then(|v| v.as_str()) - .ok_or_else(|| ChannelError::SendFailed { - name: self.name().to_string(), - reason: "Missing sender_id for approval buttons".into(), - })?; let thread_id = metadata.get("thread_id").and_then(|v| v.as_str()); let team_id = metadata .get("team_id") .and_then(|v| v.as_str()) .unwrap_or(&self.team_id); - // Button value payload (Slack limits button values to 2000 chars; - // safe with typical UUIDs but documented here as a constraint) + // Register server-side approval record and get opaque token. + // The button value contains ONLY the token — no routing fields. + let approval_token = self + .client + .create_approval(team_id, channel_id, thread_id, &request_id) + .await + .map_err(|e| ChannelError::SendFailed { + name: self.name().to_string(), + reason: format!("Failed to register approval: {e}"), + })?; let value_payload = serde_json::json!({ - "instance_id": self.instance_id, - "team_id": team_id, - "channel_id": channel_id, - "thread_ts": thread_id, - "request_id": request_id, - "sender_id": sender_id, + "approval_token": approval_token, }); let value_str = value_payload.to_string(); @@ -582,12 +421,8 @@ impl Channel for RelayChannel { } async fn shutdown(&self) -> Result<(), ChannelError> { - if let Some(handle) = self.reconnect_handle.write().await.take() { - handle.abort(); - } - if let Some(handle) = self.parser_handle.write().await.take() { - handle.abort(); - } + // Relay cleanup is driven by the extension manager dropping the shared + // sender and removing the channel from the channel manager. Ok(()) } } @@ -605,27 +440,20 @@ mod tests { .expect("client") } + fn make_channel() -> RelayChannel { + let (tx, rx) = mpsc::channel(64); + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx, rx) + } + #[test] fn relay_channel_name() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); assert_eq!(channel.name(), DEFAULT_RELAY_NAME); } #[test] fn conversation_context_extracts_metadata() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "sender_name": "bob", @@ -640,8 +468,6 @@ mod tests { #[test] fn metadata_shape_includes_event_type_and_sender_name() { - // Regression: metadata JSON must include event_type and sender_name - // for downstream routing (DM vs channel) and conversation_context(). let metadata = serde_json::json!({ "team_id": "T123", "channel_id": "C456", @@ -651,43 +477,19 @@ mod tests { "thread_id": null, "provider": "slack", }); - // event_type must be present for DM-vs-channel routing assert_eq!( metadata.get("event_type").and_then(|v| v.as_str()), Some("direct_message") ); - // sender_name must be present for conversation_context assert_eq!( metadata.get("sender_name").and_then(|v| v.as_str()), Some("alice") ); } - #[test] - fn with_timeouts_sets_values() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ) - .with_timeouts(43200, 2000, 120000); - - assert_eq!(channel.stream_timeout_secs, 43200); - assert_eq!(channel.backoff_initial_ms, 2000); - assert_eq!(channel.backoff_max_ms, 120000); - } - #[test] fn build_send_body_slack() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let (method, body) = channel.build_send_body("C456", "hello", Some("1234567.890")); assert_eq!(method, "chat.postMessage"); assert_eq!(body["channel"], "C456"); @@ -695,72 +497,95 @@ mod tests { assert_eq!(body["thread_ts"], "1234567.890"); } - #[test] - fn parser_handle_is_shared_arc() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); - // parser_handle should be an Arc — cloning should give a second reference - let handle_clone = Arc::clone(&channel.parser_handle); - // Both point to the same allocation - assert!(Arc::ptr_eq(&channel.parser_handle, &handle_clone)); - } - - #[test] - fn with_max_failures_sets_value() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ) - .with_max_failures(10); + #[tokio::test] + async fn start_processes_events() { + let (tx, rx) = mpsc::channel(64); + let channel = + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx); + + let mut stream = channel.start().await.unwrap(); + + // Send an event + tx.send(ChannelEvent { + id: "1".into(), + event_type: "message".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: Some("alice".into()), + content: Some("hello".into()), + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + use futures::StreamExt; + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap() + .unwrap(); - assert_eq!(channel.max_consecutive_failures, 10); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.user_id, "U789"); } - #[test] - fn default_max_failures_is_50() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); - assert_eq!(channel.max_consecutive_failures, 50); - } + #[tokio::test] + async fn start_skips_non_message_events() { + let (tx, rx) = mpsc::channel(64); + let channel = + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx); + + let mut stream = channel.start().await.unwrap(); + + // Send a non-message event (should be skipped) + tx.send(ChannelEvent { + id: "1".into(), + event_type: "reaction".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: None, + content: None, + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + // Send a real message + tx.send(ChannelEvent { + id: "2".into(), + event_type: "message".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: None, + content: Some("real message".into()), + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + use futures::StreamExt; + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap() + .unwrap(); - #[test] - fn empty_team_id_accepted_at_construction() { - // Regression: empty team_id (when no DB store is available) must not - // prevent channel construction or cause immediate shutdown. - let channel = RelayChannel::new( - test_client(), - "token".into(), - String::new(), // empty team_id - "inst1".into(), - "user1".into(), - ); - assert_eq!(channel.team_id, ""); - // The reconnect loop now skips team validation when team_id is empty, - // so the channel remains alive. + assert_eq!(msg.content, "real message"); } #[tokio::test] async fn test_send_status_non_approval_is_noop() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({}); let result = channel .send_status( @@ -775,13 +600,7 @@ mod tests { #[tokio::test] async fn test_send_status_approval_non_dm_skips() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "message", "channel_id": "C456", @@ -794,6 +613,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) @@ -804,13 +624,7 @@ mod tests { #[tokio::test] async fn test_send_status_approval_dm_missing_channel_id_errors() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "direct_message", "sender_id": "U789", @@ -822,6 +636,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) @@ -835,14 +650,8 @@ mod tests { } #[tokio::test] - async fn test_send_status_approval_dm_missing_sender_id_errors() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + async fn test_send_status_approval_dm_without_sender_id_is_ok() { + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "direct_message", "channel_id": "C456", @@ -854,6 +663,7 @@ mod tests { tool_name: "shell".into(), description: "run command".into(), parameters: serde_json::json!({}), + allow_always: true, }, &metadata, ) @@ -861,8 +671,8 @@ mod tests { assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( - err.contains("sender_id"), - "expected sender_id error, got: {err}" + !err.contains("sender_id"), + "sender_id should not be required anymore, got: {err}" ); } } diff --git a/src/channels/relay/client.rs b/src/channels/relay/client.rs index d1c03a5190..81fbb56c93 100644 --- a/src/channels/relay/client.rs +++ b/src/channels/relay/client.rs @@ -1,15 +1,10 @@ //! HTTP client for the channel-relay service. //! //! Wraps reqwest for all channel-relay API calls: OAuth initiation, -//! SSE streaming, token renewal, and Slack API proxy. +//! approvals, signing-secret fetch, and Slack API proxy. -use std::pin::Pin; -use std::task::{Context, Poll}; - -use futures::Stream; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc; /// Known relay event types. pub mod event_types { @@ -18,7 +13,7 @@ pub mod event_types { pub const MENTION: &str = "mention"; } -/// A parsed SSE event from the channel-relay stream. +/// A parsed event from the channel-relay webhook callback. /// /// Field names match the channel-relay `ChannelEvent` struct exactly. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -123,21 +118,19 @@ impl RelayClient { /// /// Calls `GET /oauth/slack/auth` with `redirect(Policy::none())` and /// returns the `Location` header (Slack OAuth URL) without following it. - pub async fn initiate_oauth( - &self, - instance_id: &str, - user_id: &str, - callback_url: &str, - ) -> Result { + /// Initiate Slack OAuth. Channel-relay derives all URLs from the trusted + /// instance_url in chat-api. IronClaw only passes an optional CSRF nonce + /// for validating the callback — no URLs. + pub async fn initiate_oauth(&self, state_nonce: Option<&str>) -> Result { + let mut query: Vec<(&str, &str)> = vec![]; + if let Some(nonce) = state_nonce { + query.push(("state_nonce", nonce)); + } let resp = self .http .get(format!("{}/oauth/slack/auth", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) - .query(&[ - ("instance_id", instance_id), - ("user_id", user_id), - ("callback", callback_url), - ]) + .bearer_auth(self.api_key.expose_secret()) + .query(&query) .send() .await .map_err(|e| RelayError::Network(e.to_string()))?; @@ -173,104 +166,69 @@ impl RelayClient { } } - /// Connect to the SSE event stream. + /// Register a pending approval and return the opaque approval token. /// - /// Returns a stream of parsed `ChannelEvent`s and the `JoinHandle` of the - /// background SSE parser task. The caller is responsible for reconnection - /// logic on stream end/error and for aborting the handle on shutdown. - pub async fn connect_stream( + /// Calls `POST /approvals` with the target team/channel/request identifiers. + /// The returned token is embedded in Slack button values instead of routing fields. + /// The relay derives the authorized approver from the connection's authed_user_id. + pub async fn create_approval( &self, - stream_token: &str, - stream_timeout_secs: u64, - ) -> Result<(ChannelEventStream, tokio::task::JoinHandle<()>), RelayError> { - let resp = self - .http - .get(format!("{}/stream", self.base_url)) - .query(&[("token", stream_token)]) - .timeout(std::time::Duration::from_secs(stream_timeout_secs)) - .send() - .await - .map_err(|e| RelayError::Network(e.to_string()))?; - - let status = resp.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - return Err(RelayError::TokenExpired); - } - if !status.is_success() { - let body = resp.text().await.unwrap_or_default(); - return Err(RelayError::Api { - status: status.as_u16(), - message: body, - }); + team_id: &str, + channel_id: &str, + thread_ts: Option<&str>, + request_id: &str, + ) -> Result { + let mut body = serde_json::json!({ + "team_id": team_id, + "channel_id": channel_id, + "request_id": request_id, + }); + if let Some(ts) = thread_ts { + body["thread_ts"] = serde_json::Value::String(ts.to_string()); } - // Spawn a background task that reads the SSE stream and sends parsed events - let (tx, rx) = mpsc::channel(64); - let byte_stream = resp.bytes_stream(); - let handle = tokio::spawn(parse_sse_stream(byte_stream, tx)); - - Ok((ChannelEventStream { rx }, handle)) - } - - /// Renew an expired stream token. - /// - /// Calls `POST /stream/renew` with API key auth, returns a new stream token. - pub async fn renew_token( - &self, - instance_id: &str, - user_id: &str, - ) -> Result { let resp = self .http - .post(format!("{}/stream/renew", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) - .json(&serde_json::json!({ - "instance_id": instance_id, - "user_id": user_id, - })) + .post(format!("{}/approvals", self.base_url)) + .bearer_auth(self.api_key.expose_secret()) + .json(&body) .send() .await .map_err(|e| RelayError::Network(e.to_string()))?; - let status = resp.status(); - if !status.is_success() { + if !resp.status().is_success() { + let status = resp.status().as_u16(); let body = resp.text().await.unwrap_or_default(); return Err(RelayError::Api { - status: status.as_u16(), + status, message: body, }); } - let body: serde_json::Value = resp + let result: serde_json::Value = resp .json() .await .map_err(|e| RelayError::Protocol(e.to_string()))?; - body.get("stream_token") - .or_else(|| body.get("token")) + + result + .get("approval_token") .and_then(|v| v.as_str()) .map(|s| s.to_string()) - .ok_or_else(|| RelayError::Protocol("Response missing stream_token field".to_string())) + .ok_or_else(|| RelayError::Protocol("missing approval_token in response".to_string())) } - /// Proxy an API call through channel-relay for any provider. - /// - /// Calls `POST /proxy/{provider}/{method}?team_id=X&instance_id=Y` with the given JSON body. pub async fn proxy_provider( &self, provider: &str, team_id: &str, method: &str, body: serde_json::Value, - instance_id: Option<&str>, ) -> Result { - let mut query: Vec<(&str, &str)> = vec![("team_id", team_id)]; - if let Some(iid) = instance_id { - query.push(("instance_id", iid)); - } + let query: Vec<(&str, &str)> = vec![("team_id", team_id)]; let resp = self .http .post(format!("{}/proxy/{}/{}", self.base_url, provider, method)) - .header("X-API-Key", self.api_key.expose_secret()) + .bearer_auth(self.api_key.expose_secret()) .query(&query) .json(&body) .send() @@ -291,12 +249,58 @@ impl RelayClient { .map_err(|e| RelayError::Protocol(e.to_string())) } + /// Fetch the per-instance callback signing secret from channel-relay. + /// + /// Calls `GET /relay/signing-secret` (authenticated) and returns the decoded + /// 32-byte secret. Called once at activation time; the result is cached in the + /// extension manager so subsequent calls to `relay_signing_secret()` use it. + pub async fn get_signing_secret(&self, team_id: &str) -> Result, RelayError> { + let resp = self + .http + .get(format!("{}/relay/signing-secret", self.base_url)) + .bearer_auth(self.api_key.expose_secret()) + .query(&[("team_id", team_id)]) + .send() + .await + .map_err(|e| RelayError::Network(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status().as_u16(); + let body = resp.text().await.unwrap_or_default(); + return Err(RelayError::Api { + status, + message: body, + }); + } + + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| RelayError::Protocol(e.to_string()))?; + + body.get("signing_secret") + .and_then(|v| v.as_str()) + .ok_or_else(|| RelayError::Protocol("missing signing_secret in response".to_string())) + .and_then(|raw| { + let decoded = hex::decode(raw).map_err(|e| { + RelayError::Protocol(format!("invalid signing_secret hex: {e}")) + })?; + if decoded.len() != 32 { + return Err(RelayError::Protocol(format!( + "invalid signing_secret length: expected 32 bytes, got {}", + decoded.len() + ))); + } + Ok(decoded) + }) + } + /// List active connections for an instance. pub async fn list_connections(&self, instance_id: &str) -> Result, RelayError> { let resp = self .http .get(format!("{}/connections", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) + .bearer_auth(self.api_key.expose_secret()) .query(&[("instance_id", instance_id)]) .send() .await @@ -317,91 +321,6 @@ impl RelayClient { } } -/// Async stream of parsed channel events from SSE. -pub struct ChannelEventStream { - rx: mpsc::Receiver, -} - -impl Stream for ChannelEventStream { - type Item = ChannelEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) - } -} - -/// Parse SSE format from a reqwest bytes stream. -/// -/// SSE format: -/// ```text -/// event: message -/// data: {"key": "value"} -/// -/// ``` -/// Blank line terminates an event. -async fn parse_sse_stream( - byte_stream: impl futures::Stream> + Send + 'static, - tx: mpsc::Sender, -) { - use futures::StreamExt; - - let mut buffer = Vec::::new(); - let mut event_type = String::new(); - let mut data_lines = Vec::new(); - - let mut byte_stream = std::pin::pin!(byte_stream); - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - tracing::debug!(error = %e, "SSE stream chunk error"); - break; - } - }; - - buffer.extend_from_slice(&chunk); - - // Process complete lines (decode UTF-8 only on full lines to avoid - // corruption when multi-byte characters span chunk boundaries) - while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') { - let line = String::from_utf8_lossy(&buffer[..newline_pos]) - .trim_end_matches('\r') - .to_string(); - buffer.drain(..=newline_pos); - - if line.is_empty() { - // Blank line = end of event - if !data_lines.is_empty() { - let data = data_lines.join("\n"); - if let Ok(mut event) = serde_json::from_str::(&data) { - if event.event_type.is_empty() && !event_type.is_empty() { - event.event_type = event_type.clone(); - } - if tx.send(event).await.is_err() { - return; // receiver dropped - } - } else { - tracing::debug!( - event_type = %event_type, - data_len = data.len(), - "Failed to parse SSE event data as ChannelEvent" - ); - } - } - event_type.clear(); - data_lines.clear(); - } else if let Some(value) = line.strip_prefix("event:") { - event_type = value.trim().to_string(); - } else if let Some(value) = line.strip_prefix("data:") { - data_lines.push(value.trim().to_string()); - } - // Ignore other fields (id:, retry:, comments) - } - } - - tracing::debug!("SSE stream ended"); -} - /// Errors from relay client operations. #[derive(Debug, thiserror::Error)] pub enum RelayError { @@ -413,9 +332,6 @@ pub enum RelayError { #[error("Protocol error: {0}")] Protocol(String), - - #[error("Stream token expired")] - TokenExpired, } #[cfg(test)] @@ -494,9 +410,6 @@ mod tests { message: "unauthorized".into(), }; assert_eq!(err.to_string(), "API error (HTTP 401): unauthorized"); - - let err = RelayError::TokenExpired; - assert_eq!(err.to_string(), "Stream token expired"); } #[test] @@ -518,32 +431,4 @@ mod tests { assert!(make(event_types::DIRECT_MESSAGE).is_message()); assert!(make(event_types::MENTION).is_message()); } - - #[tokio::test] - async fn parse_sse_handles_multibyte_utf8_across_chunks() { - // The crab emoji (🦀) is 4 bytes: [0xF0, 0x9F, 0xA6, 0x80]. - // Split it across two chunks to verify no U+FFFD corruption. - let event_json = r#"{"event_type":"message","content":"hello 🦀 world","provider_scope":"T1","channel_id":"C1","sender_id":"U1"}"#; - let full = format!("event: message\ndata: {}\n\n", event_json); - let bytes = full.as_bytes(); - - // Find the crab emoji and split mid-character - let crab_pos = bytes - .windows(4) - .position(|w| w == [0xF0, 0x9F, 0xA6, 0x80]) - .expect("crab emoji not found"); - let split_at = crab_pos + 2; // split in the middle of the 4-byte emoji - - let chunk1 = bytes::Bytes::copy_from_slice(&bytes[..split_at]); - let chunk2 = bytes::Bytes::copy_from_slice(&bytes[split_at..]); - - let chunks: Vec> = vec![Ok(chunk1), Ok(chunk2)]; - let stream = futures::stream::iter(chunks); - - let (tx, mut rx) = mpsc::channel(8); - parse_sse_stream(stream, tx).await; - - let event = rx.recv().await.expect("should receive event"); - assert_eq!(event.text(), "hello 🦀 world"); - } } diff --git a/src/channels/relay/mod.rs b/src/channels/relay/mod.rs index 1582319fa6..05f5870c40 100644 --- a/src/channels/relay/mod.rs +++ b/src/channels/relay/mod.rs @@ -1,12 +1,13 @@ //! Channel-relay integration for connecting to external messaging platforms //! (Slack) via the channel-relay service. //! -//! The relay service handles OAuth, credential storage, webhook ingestion, -//! and SSE event streaming. IronClaw consumes the SSE stream and sends -//! messages via the relay's proxy API. +//! The relay service handles OAuth, credential storage, and webhook ingestion. +//! IronClaw receives events via webhook callbacks and sends messages via the +//! relay's proxy API. pub mod channel; pub mod client; +pub mod webhook; pub use channel::{DEFAULT_RELAY_NAME, RelayChannel}; pub use client::RelayClient; diff --git a/src/channels/relay/webhook.rs b/src/channels/relay/webhook.rs new file mode 100644 index 0000000000..c5a9f82a65 --- /dev/null +++ b/src/channels/relay/webhook.rs @@ -0,0 +1,66 @@ +//! Shared relay webhook signature verification helpers. + +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; + +/// Verify a relay callback HMAC signature. +pub fn verify_relay_signature( + secret: &[u8], + timestamp: &str, + body: &[u8], + signature: &str, +) -> bool { + verify_signature(secret, timestamp, body, signature) +} + +fn verify_signature(secret: &[u8], timestamp: &str, body: &[u8], signature: &str) -> bool { + let mut mac = match HmacSha256::new_from_slice(secret) { + Ok(m) => m, + Err(_) => return false, + }; + mac.update(timestamp.as_bytes()); + mac.update(b"."); + mac.update(body); + let expected = format!("sha256={}", hex::encode(mac.finalize().into_bytes())); + subtle::ConstantTimeEq::ct_eq(expected.as_bytes(), signature.as_bytes()).into() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_signature(secret: &[u8], timestamp: &str, body: &[u8]) -> String { + let mut mac = HmacSha256::new_from_slice(secret).unwrap(); + mac.update(timestamp.as_bytes()); + mac.update(b"."); + mac.update(body); + format!("sha256={}", hex::encode(mac.finalize().into_bytes())) + } + + #[test] + fn verify_valid_signature() { + let secret = b"test-secret"; + let body = b"hello"; + let ts = "1234567890"; + let sig = make_signature(secret, ts, body); + assert!(verify_signature(secret, ts, body, &sig)); + } + + #[test] + fn verify_wrong_secret_fails() { + let body = b"hello"; + let ts = "1234567890"; + let sig = make_signature(b"correct", ts, body); + assert!(!verify_signature(b"wrong", ts, body, &sig)); + } + + #[test] + fn verify_tampered_body_fails() { + let secret = b"secret"; + let ts = "1234567890"; + let sig = make_signature(secret, ts, b"original"); + assert!(!verify_signature(secret, ts, b"tampered", &sig)); + } +} diff --git a/src/channels/repl.rs b/src/channels/repl.rs index 40d669198c..36ca7c28a0 100644 --- a/src/channels/repl.rs +++ b/src/channels/repl.rs @@ -539,6 +539,7 @@ impl Channel for ReplChannel { tool_name, description, parameters, + allow_always, } => { let term_width = crossterm::terminal::size() .map(|(w, _)| w as usize) @@ -582,9 +583,13 @@ impl Channel for ReplChannel { } eprintln!(" \u{2502}"); - eprintln!( - " \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)" - ); + if allow_always { + eprintln!( + " \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[34malways\x1b[0m (a) / \x1b[31mno\x1b[0m (n)" + ); + } else { + eprintln!(" \u{2502} \x1b[32myes\x1b[0m (y) / \x1b[31mno\x1b[0m (n)"); + } eprintln!(" {bot_border}"); eprintln!(); } diff --git a/src/channels/signal.rs b/src/channels/signal.rs index b8934c5cb1..84afccd5fb 100644 --- a/src/channels/signal.rs +++ b/src/channels/signal.rs @@ -915,20 +915,28 @@ impl Channel for SignalChannel { tool_name, description: _, parameters, + allow_always, } = &status && let Some(target_str) = metadata.get("signal_target").and_then(|v| v.as_str()) { let params_json = serde_json::to_string_pretty(parameters).unwrap_or_default(); + let always_line = if *allow_always { + format!( + "\n• `always` or `a` - Approve and auto-approve future {} requests", + tool_name + ) + } else { + String::new() + }; let message = format!( "⚠️ *Approval Required*\n\n\ *Request ID:* `{}`\n\ *Tool:* {}\n\ *Parameters:*\n```\n{}\n```\n\n\ Reply with:\n\ - • `yes` or `y` - Approve this request\n\ - • `always` or `a` - Approve and auto-approve future {} requests\n\ + • `yes` or `y` - Approve this request{}\n\ • `no` or `n` - Deny", - request_id, tool_name, params_json, tool_name + request_id, tool_name, params_json, always_line ); self.send_status_message(target_str, &message).await; } diff --git a/src/channels/wasm/wrapper.rs b/src/channels/wasm/wrapper.rs index 65f978ac80..8f0c9db4b5 100644 --- a/src/channels/wasm/wrapper.rs +++ b/src/channels/wasm/wrapper.rs @@ -2043,6 +2043,7 @@ impl WasmChannel { tool_name, description, parameters, + allow_always, .. } => { // WASM channels (Telegram, Slack, etc.) cannot render @@ -2081,6 +2082,11 @@ impl WasmChannel { }) .unwrap_or_default(); + let reply_hint = if *allow_always { + "Reply \"yes\" to approve, \"no\" to deny, or \"always\" to auto-approve." + } else { + "Reply \"yes\" to approve or \"no\" to deny." + }; let prompt = format!( "Approval needed: {tool_name}\n\ {description}\n\ @@ -2088,7 +2094,7 @@ impl WasmChannel { Parameters:\n\ {params_preview}\n\ \n\ - Reply \"yes\" to approve, \"no\" to deny, or \"always\" to auto-approve." + {reply_hint}" ); let metadata_json = serde_json::to_string(metadata).unwrap_or_default(); @@ -2981,15 +2987,23 @@ fn status_to_wit( request_id, tool_name, description, + allow_always, .. - } => wit_channel::StatusUpdate { - status: wit_channel::StatusType::ApprovalNeeded, - message: format!( - "Approval needed for tool '{}'. {}\nRequest ID: {}\nReply with: yes (or /approve), no (or /deny), or always (or /always).", - tool_name, description, request_id - ), - metadata_json, - }, + } => { + let reply_hint = if *allow_always { + "yes (or /approve), no (or /deny), or always (or /always)" + } else { + "yes (or /approve) or no (or /deny)" + }; + wit_channel::StatusUpdate { + status: wit_channel::StatusType::ApprovalNeeded, + message: format!( + "Approval needed for tool '{}'. {}\nRequest ID: {}\nReply with: {}.", + tool_name, description, request_id, reply_hint + ), + metadata_json, + } + } StatusUpdate::JobStarted { job_id, title, @@ -3670,6 +3684,7 @@ mod tests { tool_name: "http_request".into(), description: "Fetch weather".into(), parameters: serde_json::json!({"url": "https://wttr.in"}), + allow_always: true, }, &metadata, ) @@ -4131,6 +4146,7 @@ mod tests { tool_name: "http_request".to_string(), description: "Fetch weather data".to_string(), parameters: serde_json::json!({"url": "https://api.weather.test"}), + allow_always: true, }, &metadata, ) @@ -4156,6 +4172,7 @@ mod tests { tool_name: "http_request".to_string(), description: "Fetch weather data".to_string(), parameters: serde_json::json!({"url": "https://api.weather.test"}), + allow_always: true, }, &metadata, ) diff --git a/src/channels/web/mod.rs b/src/channels/web/mod.rs index a96f7c7b2d..bfefc5c4cd 100644 --- a/src/channels/web/mod.rs +++ b/src/channels/web/mod.rs @@ -374,6 +374,7 @@ impl Channel for GatewayChannel { tool_name, description, parameters, + allow_always, } => SseEvent::ApprovalNeeded { request_id, tool_name, @@ -381,6 +382,7 @@ impl Channel for GatewayChannel { parameters: serde_json::to_string_pretty(¶meters) .unwrap_or_else(|_| parameters.to_string()), thread_id, + allow_always, }, StatusUpdate::AuthRequired { extension_name, diff --git a/src/channels/web/server.rs b/src/channels/web/server.rs index 9a182c6cdf..ea3341c0d0 100644 --- a/src/channels/web/server.rs +++ b/src/channels/web/server.rs @@ -19,6 +19,7 @@ use axum::{ routing::{get, post}, }; use serde::Deserialize; +use sha2::{Digest, Sha256}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::StreamExt; use tower_http::cors::{AllowHeaders, CorsLayer}; @@ -63,6 +64,16 @@ pub type PromptQueue = Arc< pub type RoutineEngineSlot = Arc>>>; +fn redact_oauth_state_for_logs(state: &str) -> String { + let digest = Sha256::digest(state.as_bytes()); + let mut short_hash = String::with_capacity(12); + for byte in &digest[..6] { + use std::fmt::Write as _; + let _ = write!(&mut short_hash, "{byte:02x}"); + } + format!("sha256:{short_hash}:len={}", state.len()) +} + /// Simple sliding-window rate limiter. /// /// Tracks the number of requests in the current window. Resets when the window expires. @@ -218,7 +229,8 @@ pub async fn start_server( .route( "/oauth/slack/callback", get(slack_relay_oauth_callback_handler), - ); + ) + .route("/relay/events", post(relay_events_handler)); // Protected routes (require auth) let auth_state = AuthState { token: auth_token }; @@ -565,22 +577,35 @@ async fn oauth_callback_handler( } }; - // Strip instance prefix from state for registry lookup. - // Platform nginx sends `state=instance:nonce` but flows are keyed by nonce only. - let lookup_key = oauth_defaults::strip_instance_prefix(&state_param); + let decoded_state = match oauth_defaults::decode_hosted_oauth_state(&state_param) { + Ok(decoded) => decoded, + Err(error) => { + let redacted_state = redact_oauth_state_for_logs(&state_param); + tracing::warn!( + state = %redacted_state, + error = %error, + "OAuth callback received with malformed state" + ); + clear_auth_mode(&state).await; + return oauth_error_page("IronClaw"); + } + }; + let lookup_key = decoded_state.flow_id.clone(); let flow = ext_mgr .pending_oauth_flows() .write() .await - .remove(lookup_key); + .remove(&lookup_key); let flow = match flow { Some(f) => f, None => { + let redacted_state = redact_oauth_state_for_logs(&state_param); + let redacted_lookup_key = redact_oauth_state_for_logs(&lookup_key); tracing::warn!( - state = %state_param, - lookup_key = %lookup_key, + state = %redacted_state, + lookup_key = %redacted_lookup_key, "OAuth callback received with unknown or expired state" ); clear_auth_mode(&state).await; @@ -607,33 +632,29 @@ async fn oauth_callback_handler( } // Exchange the authorization code for tokens. - // Use the platform exchange proxy when configured (keeps client_secret off container), - // otherwise call the provider's token URL directly. - let exchange_proxy_url = std::env::var("IRONCLAW_OAUTH_EXCHANGE_URL").ok(); + // Use the platform exchange proxy when configured, otherwise call the + // provider's token URL directly. + let exchange_proxy_url = oauth_defaults::exchange_proxy_url(); let result: Result<(), String> = async { - let token_response = if let (Some(proxy_url), None) = (&exchange_proxy_url, &flow.resource) - { - // Use the platform exchange proxy when configured and no resource - // parameter is needed. The proxy holds client_secret server-side so - // the container never sees it. MCP flows (resource.is_some()) bypass - // the proxy because it doesn't forward the RFC 8707 resource param. + let token_response = if let Some(proxy_url) = &exchange_proxy_url { let gateway_token = flow.gateway_token.as_deref().unwrap_or_default(); - oauth_defaults::exchange_via_proxy( + oauth_defaults::exchange_via_proxy(oauth_defaults::ProxyTokenExchangeRequest { proxy_url, gateway_token, - &code, - &flow.redirect_uri, - flow.code_verifier.as_deref(), - &flow.access_token_field, - ) + token_url: &flow.token_url, + client_id: &flow.client_id, + client_secret: flow.client_secret.as_deref(), + code: &code, + redirect_uri: &flow.redirect_uri, + code_verifier: flow.code_verifier.as_deref(), + access_token_field: &flow.access_token_field, + extra_token_params: &flow.token_exchange_extra_params, + }) .await .map_err(|e| e.to_string())? } else { - // Direct token exchange: uses exchange_oauth_code_with_resource so MCP - // flows can include the RFC 8707 `resource` parameter to scope the - // issued token to the specific MCP server. - oauth_defaults::exchange_oauth_code_with_resource( + oauth_defaults::exchange_oauth_code_with_params( &flow.token_url, &flow.client_id, flow.client_secret.as_deref(), @@ -641,7 +662,7 @@ async fn oauth_callback_handler( &flow.redirect_uri, flow.code_verifier.as_deref(), &flow.access_token_field, - flow.resource.as_deref(), + &flow.token_exchange_extra_params, ) .await .map_err(|e| e.to_string())? @@ -668,10 +689,8 @@ async fn oauth_callback_handler( .await .map_err(|e| e.to_string())?; - // For MCP OAuth flows (identified by resource field), persist the - // client_id so token refresh works without re-authentication. - // The CLI flow stores this in authorize_mcp_server(); the gateway - // callback must do the same. + // Persist the client_id for flows that need it after the session ends + // (for example DCR-based MCP refresh). if let Some(ref client_id_secret) = flow.client_id_secret_name { let params = crate::secrets::CreateSecretParams::new(client_id_secret, &flow.client_id) .with_provider(flow.provider.as_ref().cloned().unwrap_or_default()); @@ -752,11 +771,103 @@ async fn oauth_callback_handler( axum::response::Html(html).into_response() } +/// Webhook endpoint for receiving relay events from channel-relay. +/// +/// PUBLIC route — authenticated via HMAC signature (X-Relay-Signature header). +async fn relay_events_handler( + State(state): State>, + headers: axum::http::HeaderMap, + body: axum::body::Bytes, +) -> impl IntoResponse { + let ext_mgr = match state.extension_manager.as_ref() { + Some(mgr) => mgr, + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "not ready").into_response(); + } + }; + + let signing_secret = match ext_mgr.relay_signing_secret() { + Some(s) => s, + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "relay not configured").into_response(); + } + }; + + // Verify signature + let signature = match headers + .get("x-relay-signature") + .and_then(|v| v.to_str().ok()) + { + Some(s) => s.to_string(), + None => { + return (StatusCode::UNAUTHORIZED, "missing signature").into_response(); + } + }; + + let timestamp = match headers + .get("x-relay-timestamp") + .and_then(|v| v.to_str().ok()) + { + Some(t) => t.to_string(), + None => { + return (StatusCode::UNAUTHORIZED, "missing timestamp").into_response(); + } + }; + + // Check timestamp freshness (5 min window) + let ts: i64 = match timestamp.parse() { + Ok(t) => t, + Err(_) => { + return (StatusCode::BAD_REQUEST, "malformed timestamp").into_response(); + } + }; + let now = chrono::Utc::now().timestamp(); + if (now - ts).abs() > 300 { + return (StatusCode::UNAUTHORIZED, "stale timestamp").into_response(); + } + + // Verify HMAC: sha256(secret, timestamp + "." + body) + if !crate::channels::relay::webhook::verify_relay_signature( + &signing_secret, + ×tamp, + &body, + &signature, + ) { + return (StatusCode::UNAUTHORIZED, "invalid signature").into_response(); + } + + // Parse event + let event: crate::channels::relay::client::ChannelEvent = match serde_json::from_slice(&body) { + Ok(e) => e, + Err(e) => { + tracing::warn!(error = %e, "relay callback invalid JSON"); + return (StatusCode::BAD_REQUEST, "invalid JSON").into_response(); + } + }; + + // Push to relay channel + let event_tx_guard = ext_mgr.relay_event_tx(); + let event_tx = event_tx_guard.lock().await; + match event_tx.as_ref() { + Some(tx) => { + if let Err(e) = tx.try_send(event) { + tracing::warn!(error = %e, "relay event channel full or closed"); + return (StatusCode::SERVICE_UNAVAILABLE, "event queue full").into_response(); + } + } + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "relay channel not active").into_response(); + } + } + + Json(serde_json::json!({"ok": true})).into_response() +} + /// OAuth callback for Slack via channel-relay. /// /// This is a PUBLIC route (no Bearer token required) because channel-relay /// redirects the user's browser here after Slack OAuth completes. -/// Query params: `stream_token`, `provider`, `team_id`. +/// Query params: `provider`, `team_id`. async fn slack_relay_oauth_callback_handler( State(state): State>, Query(params): Query>, @@ -773,27 +884,6 @@ async fn slack_relay_oauth_callback_handler( .into_response(); } - // Validate stream_token: required, non-empty, max 2048 bytes - let stream_token = match params.get("stream_token") { - Some(t) if !t.is_empty() && t.len() <= 2048 => t.clone(), - Some(t) if t.len() > 2048 => { - return axum::response::Html( - "\ -

Error

Invalid callback parameters.

" - .to_string(), - ) - .into_response(); - } - _ => { - return axum::response::Html( - "\ -

Error

Invalid callback parameters.

" - .to_string(), - ) - .into_response(); - } - }; - // Validate team_id format: empty or T followed by alphanumeric (max 20 chars) let team_id = params.get("team_id").cloned().unwrap_or_default(); if !team_id.is_empty() { @@ -879,30 +969,16 @@ async fn slack_relay_oauth_callback_handler( let _ = ext_mgr.secrets().delete(&state.user_id, &state_key).await; let result: Result<(), String> = async { - // Store the stream token as a secret - let token_key = format!("relay:{}:stream_token", DEFAULT_RELAY_NAME); - let _ = ext_mgr.secrets().delete(&state.user_id, &token_key).await; - ext_mgr - .secrets() - .create( - &state.user_id, - crate::secrets::CreateSecretParams { - name: token_key, - value: secrecy::SecretString::from(stream_token), - provider: Some(provider.clone()), - expires_at: None, - }, - ) - .await - .map_err(|e| format!("Failed to store stream token: {}", e))?; + let store = state.store.as_ref().ok_or_else(|| { + "Relay activation requires persistent settings storage; no-db mode is unsupported." + .to_string() + })?; // Store team_id in settings - if let Some(ref store) = state.store { - let team_id_key = format!("relay:{}:team_id", DEFAULT_RELAY_NAME); - let _ = store - .set_setting(&state.user_id, &team_id_key, &serde_json::json!(team_id)) - .await; - } + let team_id_key = format!("relay:{}:team_id", DEFAULT_RELAY_NAME); + let _ = store + .set_setting(&state.user_id, &team_id_key, &serde_json::json!(team_id)) + .await; // Activate the relay channel ext_mgr @@ -3253,7 +3329,7 @@ mod tests { secrets, sse_sender: None, gateway_token: None, - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, created_at, }; @@ -3321,7 +3397,7 @@ mod tests { secrets, sse_sender: Some(sender), gateway_token: None, - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, created_at, }; @@ -3424,7 +3500,7 @@ mod tests { secrets, sse_sender: None, gateway_token: None, - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, // Expired — handler will reject after lookup (no network I/O) created_at, @@ -3476,6 +3552,85 @@ mod tests { ); } + #[tokio::test] + async fn test_oauth_callback_accepts_versioned_hosted_state() { + use axum::body::Body; + use tower::ServiceExt; + + let secrets: Arc = + Arc::new(crate::secrets::InMemorySecretsStore::new(Arc::new( + crate::secrets::SecretsCrypto::new(secrecy::SecretString::from( + TEST_GATEWAY_CRYPTO_KEY.to_string(), + )) + .expect("crypto"), + ))); + let (ext_mgr, _wasm_tools_dir, _wasm_channels_dir) = test_ext_mgr(secrets.clone()); + + let Some(created_at) = expired_flow_created_at() else { + eprintln!("Skipping versioned OAuth state test: monotonic uptime below expiry window"); + return; + }; + let flow = crate::cli::oauth_defaults::PendingOAuthFlow { + extension_name: "test_tool".to_string(), + display_name: "Test Tool".to_string(), + token_url: "https://example.com/token".to_string(), + client_id: "client123".to_string(), + client_secret: None, + redirect_uri: "https://example.com/oauth/callback".to_string(), + code_verifier: None, + access_token_field: "access_token".to_string(), + secret_name: "test_token".to_string(), + provider: None, + validation_endpoint: None, + scopes: vec![], + user_id: "test".to_string(), + secrets, + sse_sender: None, + gateway_token: None, + token_exchange_extra_params: std::collections::HashMap::new(), + client_id_secret_name: None, + created_at, + }; + + ext_mgr + .pending_oauth_flows() + .write() + .await + .insert("test_nonce".to_string(), flow); + + let state = test_gateway_state(Some(ext_mgr.clone())); + let app = test_oauth_router(state); + let versioned_state = + crate::cli::oauth_defaults::encode_hosted_oauth_state("test_nonce", Some("myinstance")); + + let req = axum::http::Request::builder() + .uri(format!( + "/oauth/callback?code=fake_code&state={}", + urlencoding::encode(&versioned_state) + )) + .body(Body::empty()) + .expect("request"); + + let resp = ServiceExt::>::oneshot(app, req) + .await + .expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), 1024 * 64) + .await + .expect("body"); + let html = String::from_utf8_lossy(&body); + assert!(html.contains("Authorization Failed")); + assert!( + ext_mgr + .pending_oauth_flows() + .read() + .await + .get("test_nonce") + .is_none() + ); + } + // --- Slack relay OAuth CSRF tests --- fn test_relay_oauth_router(state: Arc) -> Router { @@ -3533,7 +3688,7 @@ mod tests { // Callback without state param should be rejected let req = axum::http::Request::builder() - .uri("/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack") + .uri("/oauth/slack/callback?team_id=T123&provider=slack") .body(Body::empty()) .expect("request"); @@ -3577,7 +3732,7 @@ mod tests { // Callback with wrong state param let req = axum::http::Request::builder() - .uri("/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack&state=wrong-nonce") + .uri("/oauth/slack/callback?team_id=T123&provider=slack&state=wrong-nonce") .body(Body::empty()) .expect("request"); @@ -3625,7 +3780,7 @@ mod tests { // we just verify it doesn't return a CSRF error. let req = axum::http::Request::builder() .uri(format!( - "/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack&state={}", + "/oauth/slack/callback?team_id=T123&provider=slack&state={}", nonce )) .body(Body::empty()) diff --git a/src/channels/web/static/app.js b/src/channels/web/static/app.js index 82b033b2d4..bc23d68c4d 100644 --- a/src/channels/web/static/app.js +++ b/src/channels/web/static/app.js @@ -1138,18 +1138,19 @@ function showApproval(data) { approveBtn.textContent = I18n.t('approval.approve'); approveBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'approve')); - const alwaysBtn = document.createElement('button'); - alwaysBtn.className = 'always'; - alwaysBtn.textContent = I18n.t('approval.always'); - alwaysBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'always')); - const denyBtn = document.createElement('button'); denyBtn.className = 'deny'; denyBtn.textContent = I18n.t('approval.deny'); denyBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'deny')); actions.appendChild(approveBtn); - actions.appendChild(alwaysBtn); + if (data.allow_always !== false) { + const alwaysBtn = document.createElement('button'); + alwaysBtn.className = 'always'; + alwaysBtn.textContent = I18n.t('approval.always'); + alwaysBtn.addEventListener('click', () => sendApprovalAction(data.request_id, 'always')); + actions.appendChild(alwaysBtn); + } actions.appendChild(denyBtn); card.appendChild(actions); diff --git a/src/channels/web/types.rs b/src/channels/web/types.rs index 3fad9f3525..861b5bd2d4 100644 --- a/src/channels/web/types.rs +++ b/src/channels/web/types.rs @@ -177,6 +177,8 @@ pub enum SseEvent { parameters: String, #[serde(skip_serializing_if = "Option::is_none")] thread_id: Option, + /// Whether the "always" auto-approve option should be shown. + allow_always: bool, }, #[serde(rename = "auth_required")] AuthRequired { @@ -230,6 +232,8 @@ pub enum SseEvent { status: String, #[serde(skip_serializing_if = "Option::is_none")] session_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + fallback_deliverable: Option, }, /// An image was generated by a tool. @@ -1080,6 +1084,7 @@ mod tests { description: "Run ls".to_string(), parameters: "{}".to_string(), thread_id: Some("t1".to_string()), + allow_always: true, }; let ws = WsServerMessage::from_sse_event(&sse); match ws { diff --git a/src/cli/memory.rs b/src/cli/memory.rs index a3df3625f5..2d0606a854 100644 --- a/src/cli/memory.rs +++ b/src/cli/memory.rs @@ -7,17 +7,18 @@ use std::sync::Arc; use clap::Subcommand; -use crate::workspace::{EmbeddingProvider, SearchConfig, Workspace}; +use crate::workspace::{EmbeddingCacheConfig, EmbeddingProvider, SearchConfig, Workspace}; /// Run a memory command using the Database trait (works with any backend). pub async fn run_memory_command_with_db( cmd: MemoryCommand, db: std::sync::Arc, embeddings: Option>, + cache_config: EmbeddingCacheConfig, ) -> anyhow::Result<()> { let mut workspace = Workspace::new_with_db("default", db); if let Some(emb) = embeddings { - workspace = workspace.with_embeddings(emb); + workspace = workspace.with_embeddings_cached(emb, cache_config); } match cmd { @@ -85,10 +86,11 @@ pub async fn run_memory_command( cmd: MemoryCommand, pool: deadpool_postgres::Pool, embeddings: Option>, + cache_config: EmbeddingCacheConfig, ) -> anyhow::Result<()> { let mut workspace = Workspace::new("default", pool); if let Some(emb) = embeddings { - workspace = workspace.with_embeddings(emb); + workspace = workspace.with_embeddings_cached(emb, cache_config); } match cmd { diff --git a/src/cli/mod.rs b/src/cli/mod.rs index cf3c793e81..54779ae19a 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -336,7 +336,10 @@ pub async fn run_memory_command(mem_cmd: &MemoryCommand) -> anyhow::Result<()> { .await .map_err(|e| anyhow::anyhow!("{}", e))?; - run_memory_command_with_db(mem_cmd.clone(), db, embeddings).await + let cache_config = crate::workspace::EmbeddingCacheConfig { + max_entries: config.embeddings.cache_size, + }; + run_memory_command_with_db(mem_cmd.clone(), db, embeddings, cache_config).await } #[cfg(test)] diff --git a/src/cli/oauth_defaults.rs b/src/cli/oauth_defaults.rs index a625f718f5..874cff987b 100644 --- a/src/cli/oauth_defaults.rs +++ b/src/cli/oauth_defaults.rs @@ -5,17 +5,10 @@ //! //! # Built-in Credentials //! -//! Many CLI tools (gcloud, rclone, gdrive) ship with default OAuth credentials -//! so users don't need to register their own OAuth app. Google explicitly -//! documents that client_secret for "Desktop App" / "Installed App" types -//! is NOT actually secret. -//! -//! Default credentials are hardcoded below. They can be overridden at: -//! -//! - **Compile time**: Set IRONCLAW_GOOGLE_CLIENT_ID / IRONCLAW_GOOGLE_CLIENT_SECRET -//! env vars before building to replace the hardcoded defaults. -//! - **Runtime**: Users can set GOOGLE_OAUTH_CLIENT_ID / GOOGLE_OAUTH_CLIENT_SECRET -//! env vars, which take priority over built-in defaults. +//! Some providers ship with built-in OAuth credentials so users don't need to +//! register their own OAuth app just to get started. Today this module only +//! includes built-in defaults for Google-family tools, and those defaults can +//! be overridden by provider-specific environment variables when needed. use std::collections::HashMap; use std::sync::Arc; @@ -23,6 +16,7 @@ use std::time::Duration; use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use rand::RngCore; +use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::sync::RwLock; @@ -60,6 +54,14 @@ pub fn builtin_credentials(secret_name: &str) -> Option { } } +/// Returns the compile-time override env var name, if this provider supports one. +pub fn builtin_client_id_override_env(secret_name: &str) -> Option<&'static str> { + match secret_name { + "google_oauth_token" => Some("IRONCLAW_GOOGLE_CLIENT_ID"), + _ => None, + } +} + // ── Shared callback server ────────────────────────────────────────────── // Core OAuth callback infrastructure is defined in `crate::llm::oauth_helpers` @@ -173,9 +175,8 @@ pub async fn exchange_oauth_code( code_verifier: Option<&str>, access_token_field: &str, ) -> Result { - // Delegates to exchange_oauth_code_with_resource with resource=None. - // Non-MCP OAuth flows don't need the RFC 8707 resource parameter. - exchange_oauth_code_with_resource( + let extra_token_params = HashMap::new(); + exchange_oauth_code_with_params( token_url, client_id, client_secret, @@ -183,16 +184,14 @@ pub async fn exchange_oauth_code( redirect_uri, code_verifier, access_token_field, - None, + &extra_token_params, ) .await } -/// Exchange an OAuth authorization code for tokens, with optional RFC 8707 `resource` parameter. -/// -/// The `resource` parameter scopes the issued token to a specific server (used by MCP OAuth). +/// Exchange an OAuth authorization code for tokens with generic extra form parameters. #[allow(clippy::too_many_arguments)] -pub async fn exchange_oauth_code_with_resource( +pub async fn exchange_oauth_code_with_params( token_url: &str, client_id: &str, client_secret: Option<&str>, @@ -200,7 +199,7 @@ pub async fn exchange_oauth_code_with_resource( redirect_uri: &str, code_verifier: Option<&str>, access_token_field: &str, - resource: Option<&str>, + extra_token_params: &HashMap, ) -> Result { let client = reqwest::Client::new(); let mut token_params = vec![ @@ -213,10 +212,8 @@ pub async fn exchange_oauth_code_with_resource( token_params.push(("code_verifier", verifier.to_string())); } - // RFC 8707: include the `resource` parameter so the authorization server - // scopes the issued token to the specific MCP server (protected resource). - if let Some(resource) = resource { - token_params.push(("resource", resource.to_string())); + for (key, value) in extra_token_params { + token_params.push((key.as_str(), value.clone())); } let mut request = client.post(token_url); @@ -276,6 +273,37 @@ pub async fn exchange_oauth_code_with_resource( }) } +/// Exchange an OAuth authorization code for tokens, with optional RFC 8707 `resource` parameter. +/// +/// The `resource` parameter scopes the issued token to a specific server (used by MCP OAuth). +#[allow(clippy::too_many_arguments)] +pub async fn exchange_oauth_code_with_resource( + token_url: &str, + client_id: &str, + client_secret: Option<&str>, + code: &str, + redirect_uri: &str, + code_verifier: Option<&str>, + access_token_field: &str, + resource: Option<&str>, +) -> Result { + let mut extra_token_params = HashMap::new(); + if let Some(resource) = resource { + extra_token_params.insert("resource".to_string(), resource.to_string()); + } + exchange_oauth_code_with_params( + token_url, + client_id, + client_secret, + code, + redirect_uri, + code_verifier, + access_token_field, + &extra_token_params, + ) + .await +} + /// Store OAuth tokens (access + refresh) in the secrets store. /// /// Also stores the granted scopes as `{secret_name}_scopes` so that scope @@ -423,9 +451,9 @@ pub struct PendingOAuthFlow { pub sse_sender: Option>, /// Gateway auth token for authenticating with the platform token exchange proxy. pub gateway_token: Option, - /// RFC 8707 resource parameter (MCP OAuth only). - /// Sent during token exchange to scope the token to a specific MCP server. - pub resource: Option, + /// Additional form params for the token exchange request. + /// Used for provider-specific requirements such as RFC 8707 `resource`. + pub token_exchange_extra_params: HashMap, /// Secret name for persisting the client ID (MCP OAuth only). /// Needed so token refresh can find the client_id after the session ends. pub client_id_secret_name: Option, @@ -459,9 +487,7 @@ pub fn new_pending_oauth_registry() -> PendingOAuthRegistry { /// URL, meaning the user's browser will redirect to a hosted gateway rather than /// localhost. pub fn use_gateway_callback() -> bool { - std::env::var("IRONCLAW_OAUTH_CALLBACK_URL") - .ok() - .filter(|v| !v.is_empty()) + crate::config::helpers::env_or_override("IRONCLAW_OAUTH_CALLBACK_URL") .map(|raw| { url::Url::parse(&raw) .ok() @@ -472,6 +498,13 @@ pub fn use_gateway_callback() -> bool { .unwrap_or(false) } +/// Returns the configured OAuth token-exchange proxy URL, if any. +pub fn exchange_proxy_url() -> Option { + crate::config::helpers::env_or_override("IRONCLAW_OAUTH_EXCHANGE_URL") + .map(|url| url.trim().to_string()) + .filter(|url| !url.is_empty()) +} + /// Maximum age for pending OAuth flows (5 minutes, matching TCP listener timeout). pub const OAUTH_FLOW_EXPIRY: Duration = Duration::from_secs(300); @@ -486,23 +519,117 @@ pub async fn sweep_expired_flows(registry: &PendingOAuthRegistry) { // ── Platform routing helpers ──────────────────────────────────────── -/// Prepend instance name to CSRF state for platform routing. +const HOSTED_STATE_PREFIX: &str = "ic2"; +const HOSTED_STATE_CHECKSUM_BYTES: usize = 12; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DecodedHostedOAuthState { + pub flow_id: String, + pub instance_name: Option, + pub is_legacy: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct HostedOAuthStatePayload { + flow_id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + instance_name: Option, + issued_at: u64, +} + +fn current_instance_name() -> Option { + crate::config::helpers::env_or_override("IRONCLAW_INSTANCE_NAME") + .or_else(|| crate::config::helpers::env_or_override("OPENCLAW_INSTANCE_NAME")) + .filter(|v| !v.is_empty()) +} + +fn hosted_state_checksum(payload_bytes: &[u8]) -> String { + let digest = Sha256::digest(payload_bytes); + URL_SAFE_NO_PAD.encode(&digest[..HOSTED_STATE_CHECKSUM_BYTES]) +} + +/// Build a versioned hosted OAuth state envelope. /// -/// The NEAR AI platform nginx proxy at `auth.DOMAIN` parses the instance name -/// from the `state` query parameter (format: `instance:nonce`) to route the -/// OAuth callback to the correct container. +/// The encoded value is opaque to providers and can be decoded by both +/// IronClaw and the external auth proxy for routing and callback lookup. +pub fn encode_hosted_oauth_state(flow_id: &str, instance_name: Option<&str>) -> String { + let payload = HostedOAuthStatePayload { + flow_id: flow_id.to_string(), + instance_name: instance_name + .map(str::trim) + .filter(|v| !v.is_empty()) + .map(str::to_string), + issued_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + let payload_json = match serde_json::to_vec(&payload) { + Ok(payload_json) => payload_json, + Err(error) => { + tracing::warn!(%error, flow_id, "Failed to serialize hosted OAuth state payload"); + return payload.flow_id; + } + }; + let payload = URL_SAFE_NO_PAD.encode(&payload_json); + let checksum = hosted_state_checksum(&payload_json); + format!("{HOSTED_STATE_PREFIX}.{payload}.{checksum}") +} + +/// Decode hosted OAuth state in either the new versioned format or the +/// legacy `instance:nonce`/`nonce` forms. +pub fn decode_hosted_oauth_state(state: &str) -> Result { + if let Some(rest) = state.strip_prefix(&format!("{HOSTED_STATE_PREFIX}.")) + && let Some((payload_b64, checksum)) = rest.rsplit_once('.') + && let Ok(payload_json) = URL_SAFE_NO_PAD.decode(payload_b64) + { + let expected_checksum = hosted_state_checksum(&payload_json); + if checksum != expected_checksum { + return Err("Hosted OAuth state checksum mismatch".to_string()); + } + if let Ok(payload) = serde_json::from_slice::(&payload_json) + && !payload.flow_id.trim().is_empty() + { + return Ok(DecodedHostedOAuthState { + flow_id: payload.flow_id, + instance_name: payload.instance_name.filter(|v| !v.is_empty()), + is_legacy: false, + }); + } + } + + if let Some((instance_name, flow_id)) = state.split_once(':') { + if flow_id.is_empty() { + return Err("Hosted OAuth legacy state is missing flow_id".to_string()); + } + return Ok(DecodedHostedOAuthState { + flow_id: flow_id.to_string(), + instance_name: if instance_name.is_empty() { + None + } else { + Some(instance_name.to_string()) + }, + is_legacy: true, + }); + } + + if state.is_empty() { + return Err("Hosted OAuth state is empty".to_string()); + } + + Ok(DecodedHostedOAuthState { + flow_id: state.to_string(), + instance_name: None, + is_legacy: true, + }) +} + +/// Build the hosted callback state used by the public OAuth callback endpoint. /// -/// Returns the nonce unchanged when `IRONCLAW_INSTANCE_NAME` is not set -/// (local/non-platform mode). +/// New flows emit a versioned opaque envelope, while callback decoding accepts +/// both the envelope and the legacy `instance:nonce` contract. pub fn build_platform_state(nonce: &str) -> String { - let instance = std::env::var("IRONCLAW_INSTANCE_NAME") - .or_else(|_| std::env::var("OPENCLAW_INSTANCE_NAME")) - .ok() - .filter(|v| !v.is_empty()); - match instance { - Some(name) => format!("{}:{}", name, nonce), - None => nonce.to_string(), - } + encode_hosted_oauth_state(nonce, current_instance_name().as_deref()) } /// Strip the instance prefix from a state parameter to recover the lookup nonce. @@ -517,43 +644,62 @@ pub fn strip_instance_prefix(state: &str) -> &str { .unwrap_or(state) } +pub struct ProxyTokenExchangeRequest<'a> { + pub proxy_url: &'a str, + pub gateway_token: &'a str, + pub token_url: &'a str, + pub client_id: &'a str, + pub client_secret: Option<&'a str>, + pub code: &'a str, + pub redirect_uri: &'a str, + pub code_verifier: Option<&'a str>, + pub access_token_field: &'a str, + pub extra_token_params: &'a HashMap, +} + /// Exchange an OAuth authorization code via the platform's token exchange proxy. /// -/// The proxy holds `client_secret` server-side so the container never sees it. -/// Authenticated via the gateway auth token (Bearer header). +/// Authenticated via the gateway auth token (Bearer header). The caller may +/// either rely on proxy-side secret lookup or forward a `client_secret` when +/// the provider requires it. /// -/// The proxy expects form params `{code, redirect_uri, code_verifier}` and -/// returns a standard Google token response `{access_token, refresh_token, expires_in}`. +/// The proxy expects standard OAuth form params plus optional provider-specific +/// token params and returns a standard token response such as +/// `{access_token, refresh_token, expires_in}`. pub async fn exchange_via_proxy( - proxy_url: &str, - gateway_token: &str, - code: &str, - redirect_uri: &str, - code_verifier: Option<&str>, - access_token_field: &str, + request: ProxyTokenExchangeRequest<'_>, ) -> Result { - if gateway_token.is_empty() { + if request.gateway_token.is_empty() { return Err(OAuthCallbackError::Io( "Gateway auth token is required for proxy token exchange".to_string(), )); } - let exchange_url = format!("{}/oauth/exchange", proxy_url.trim_end_matches('/')); + let exchange_url = format!("{}/oauth/exchange", request.proxy_url.trim_end_matches('/')); let client = reqwest::Client::builder() .timeout(Duration::from_secs(60)) .build() .map_err(|e| OAuthCallbackError::Io(format!("Failed to build HTTP client: {}", e)))?; let mut params = vec![ - ("code", code.to_string()), - ("redirect_uri", redirect_uri.to_string()), + ("code", request.code.to_string()), + ("redirect_uri", request.redirect_uri.to_string()), + ("token_url", request.token_url.to_string()), + ("client_id", request.client_id.to_string()), + ("access_token_field", request.access_token_field.to_string()), ]; - if let Some(verifier) = code_verifier { + if let Some(verifier) = request.code_verifier { params.push(("code_verifier", verifier.to_string())); } + if let Some(secret) = request.client_secret { + params.push(("client_secret", secret.to_string())); + } + for (key, value) in request.extra_token_params { + params.push((key.as_str(), value.clone())); + } let response = client .post(&exchange_url) - .bearer_auth(gateway_token) + .bearer_auth(request.gateway_token) .form(¶ms) .send() .await @@ -576,7 +722,7 @@ pub async fn exchange_via_proxy( .map_err(|e| OAuthCallbackError::Io(format!("Failed to parse proxy response: {}", e)))?; let access_token = token_data - .get(access_token_field) + .get(request.access_token_field) .and_then(|v| v.as_str()) .ok_or_else(|| { let fields: Vec<&str> = token_data @@ -585,7 +731,7 @@ pub async fn exchange_via_proxy( .unwrap_or_default(); OAuthCallbackError::Io(format!( "No '{}' field in proxy response (fields present: {:?})", - access_token_field, fields + request.access_token_field, fields )) })? .to_string(); @@ -605,14 +751,10 @@ pub async fn exchange_via_proxy( #[cfg(test)] mod tests { - use std::sync::Mutex; - use crate::cli::oauth_defaults::{ builtin_credentials, callback_host, callback_url, is_loopback_host, landing_html, }; - - /// Serializes env-mutating tests to prevent parallel races. - static ENV_MUTEX: Mutex<()> = Mutex::new(()); + use crate::config::helpers::ENV_MUTEX; #[test] fn test_is_loopback_host() { @@ -935,7 +1077,7 @@ mod tests { #[test] fn test_build_platform_state_with_instance() { - use crate::cli::oauth_defaults::build_platform_state; + use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state}; let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_INSTANCE_NAME").ok(); @@ -943,7 +1085,11 @@ mod tests { unsafe { std::env::set_var("IRONCLAW_INSTANCE_NAME", "kind-deer"); } - assert_eq!(build_platform_state("abc123"), "kind-deer:abc123"); + let encoded = build_platform_state("abc123"); + let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state"); + assert_eq!(decoded.flow_id, "abc123"); + assert_eq!(decoded.instance_name.as_deref(), Some("kind-deer")); + assert!(!decoded.is_legacy); unsafe { if let Some(val) = original { std::env::set_var("IRONCLAW_INSTANCE_NAME", val); @@ -955,7 +1101,7 @@ mod tests { #[test] fn test_build_platform_state_without_instance() { - use crate::cli::oauth_defaults::build_platform_state; + use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state}; let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_INSTANCE_NAME").ok(); @@ -965,7 +1111,11 @@ mod tests { std::env::remove_var("IRONCLAW_INSTANCE_NAME"); std::env::remove_var("OPENCLAW_INSTANCE_NAME"); } - assert_eq!(build_platform_state("abc123"), "abc123"); + let encoded = build_platform_state("abc123"); + let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state"); + assert_eq!(decoded.flow_id, "abc123"); + assert_eq!(decoded.instance_name, None); + assert!(!decoded.is_legacy); unsafe { if let Some(val) = original { std::env::set_var("IRONCLAW_INSTANCE_NAME", val); @@ -978,7 +1128,7 @@ mod tests { #[test] fn test_build_platform_state_with_openclaw_instance() { - use crate::cli::oauth_defaults::build_platform_state; + use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state}; let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); let original_ic = std::env::var("IRONCLAW_INSTANCE_NAME").ok(); @@ -988,7 +1138,11 @@ mod tests { std::env::remove_var("IRONCLAW_INSTANCE_NAME"); std::env::set_var("OPENCLAW_INSTANCE_NAME", "quiet-lion"); } - assert_eq!(build_platform_state("xyz789"), "quiet-lion:xyz789"); + let encoded = build_platform_state("xyz789"); + let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state"); + assert_eq!(decoded.flow_id, "xyz789"); + assert_eq!(decoded.instance_name.as_deref(), Some("quiet-lion")); + assert!(!decoded.is_legacy); unsafe { if let Some(val) = original_ic { std::env::set_var("IRONCLAW_INSTANCE_NAME", val); @@ -1017,6 +1171,42 @@ mod tests { assert_eq!(strip_instance_prefix(""), ""); } + #[test] + fn test_decode_hosted_oauth_state_accepts_legacy_formats() { + use crate::cli::oauth_defaults::decode_hosted_oauth_state; + + let decoded = decode_hosted_oauth_state("kind-deer:abc123").expect("legacy prefixed"); + assert_eq!(decoded.flow_id, "abc123"); + assert_eq!(decoded.instance_name.as_deref(), Some("kind-deer")); + assert!(decoded.is_legacy); + + let decoded = decode_hosted_oauth_state("abc123").expect("legacy raw"); + assert_eq!(decoded.flow_id, "abc123"); + assert_eq!(decoded.instance_name, None); + assert!(decoded.is_legacy); + } + + #[test] + fn test_decode_hosted_oauth_state_falls_back_for_non_envelope_ic2_prefix() { + use crate::cli::oauth_defaults::decode_hosted_oauth_state; + + let decoded = + decode_hosted_oauth_state("ic2.provider-owned-state").expect("prefixed fallback"); + assert_eq!(decoded.flow_id, "ic2.provider-owned-state"); + assert_eq!(decoded.instance_name, None); + assert!(decoded.is_legacy); + } + + #[test] + fn test_decode_hosted_oauth_state_rejects_tampered_checksum() { + use crate::cli::oauth_defaults::{decode_hosted_oauth_state, encode_hosted_oauth_state}; + + let encoded = encode_hosted_oauth_state("abc123", Some("kind-deer")); + let tampered = format!("{encoded}broken"); + let err = decode_hosted_oauth_state(&tampered).expect_err("tampered state should fail"); + assert!(err.contains("checksum"), "unexpected error: {err}"); + } + /// Verify that `build_oauth_url` includes the RFC 8707 `resource` parameter /// when passed through `extra_params`, which is how MCP OAuth gateway mode /// scopes tokens to a specific MCP server. diff --git a/src/cli/tool.rs b/src/cli/tool.rs index ac5d1b374e..be6845807c 100644 --- a/src/cli/tool.rs +++ b/src/cli/tool.rs @@ -651,8 +651,8 @@ async fn auth_tool(name: String, dir: Option, user_id: String) -> anyho // Check for OAuth configuration if let Some(ref oauth) = auth.oauth { - // For providers with shared tokens (e.g., all Google tools share google_oauth_token), - // combine scopes from all installed tools so one auth covers everything. + // For providers with shared tokens, combine scopes from all installed + // tools so one auth covers everything. let combined = combine_provider_scopes(&tools_dir, &auth.secret_name, oauth).await; if combined.scopes.len() > oauth.scopes.len() { let extra = combined.scopes.len() - oauth.scopes.len(); @@ -670,8 +670,8 @@ async fn auth_tool(name: String, dir: Option, user_id: String) -> anyho } /// Scan the tools directory for all capabilities files sharing the same secret_name -/// and combine their OAuth scopes. This way, authing any Google tool requests scopes -/// for ALL installed Google tools, so one login covers everything. +/// and combine their OAuth scopes so one authorization covers the full shared +/// credential set. async fn combine_provider_scopes( tools_dir: &Path, secret_name: &str, @@ -736,11 +736,18 @@ async fn auth_tool_oauth( }) .or_else(|| builtin.as_ref().map(|c| c.client_id.to_string())) .ok_or_else(|| { - anyhow::anyhow!( + let mut message = format!( "OAuth client_id not configured.\n\ - Set {} env var, or build with IRONCLAW_GOOGLE_CLIENT_ID.", + Set {} env var", oauth.client_id_env.as_deref().unwrap_or("the client_id") - ) + ); + if let Some(override_env) = + oauth_defaults::builtin_client_id_override_env(&auth.secret_name) + { + message.push_str(&format!(", or build with {override_env}")); + } + message.push('.'); + anyhow::anyhow!(message) })?; // Get client_secret: capabilities file > runtime env var > built-in defaults diff --git a/src/config/embeddings.rs b/src/config/embeddings.rs index a1c3ecd7ed..43fea73a29 100644 --- a/src/config/embeddings.rs +++ b/src/config/embeddings.rs @@ -8,6 +8,9 @@ use crate::llm::SessionManager; use crate::settings::Settings; use crate::workspace::EmbeddingProvider; +/// Default maximum number of cached embeddings. +pub const DEFAULT_EMBEDDING_CACHE_SIZE: usize = 10_000; + /// Embeddings provider configuration. #[derive(Debug, Clone)] pub struct EmbeddingsConfig { @@ -26,6 +29,12 @@ pub struct EmbeddingsConfig { /// Custom base URL for OpenAI-compatible embedding providers. /// When set, overrides the default `https://api.openai.com`. pub openai_base_url: Option, + /// Maximum entries in the embedding LRU cache (default 10,000). + /// + /// Approximate raw embedding payload: `cache_size × dimension × 4 bytes`. + /// 10,000 × 1536 floats ≈ 58 MB (payload only; actual memory is higher + /// due to HashMap buckets, per-entry Vec/timestamp overhead). + pub cache_size: usize, } impl Default for EmbeddingsConfig { @@ -40,6 +49,7 @@ impl Default for EmbeddingsConfig { ollama_base_url: "http://localhost:11434".to_string(), dimension, openai_base_url: None, + cache_size: DEFAULT_EMBEDDING_CACHE_SIZE, } } } @@ -80,6 +90,15 @@ impl EmbeddingsConfig { let openai_base_url = optional_env("EMBEDDING_BASE_URL")?; + let cache_size = parse_optional_env("EMBEDDING_CACHE_SIZE", DEFAULT_EMBEDDING_CACHE_SIZE)?; + + if cache_size == 0 { + return Err(ConfigError::InvalidValue { + key: "EMBEDDING_CACHE_SIZE".to_string(), + message: "must be at least 1".to_string(), + }); + } + Ok(Self { enabled, provider, @@ -88,6 +107,7 @@ impl EmbeddingsConfig { ollama_base_url, dimension, openai_base_url, + cache_size, }) } @@ -183,13 +203,13 @@ mod tests { std::env::remove_var("EMBEDDING_MODEL"); std::env::remove_var("OPENAI_API_KEY"); std::env::remove_var("EMBEDDING_BASE_URL"); + std::env::remove_var("EMBEDDING_CACHE_SIZE"); } } #[test] fn embeddings_disabled_not_overridden_by_openai_key() { let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); - clear_embedding_env(); // SAFETY: Under ENV_MUTEX, no concurrent env access. unsafe { @@ -240,7 +260,6 @@ mod tests { #[test] fn embeddings_env_override_takes_precedence() { let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); - clear_embedding_env(); // SAFETY: Under ENV_MUTEX. unsafe { @@ -281,10 +300,8 @@ mod tests { let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed"); assert_eq!( config.openai_base_url.as_deref(), - Some("https://custom.example.com"), - "EMBEDDING_BASE_URL env var should be parsed into openai_base_url" + Some("https://custom.example.com") ); - // SAFETY: Under ENV_MUTEX. unsafe { std::env::remove_var("EMBEDDING_BASE_URL"); @@ -303,4 +320,24 @@ mod tests { "openai_base_url should be None when EMBEDDING_BASE_URL is not set" ); } + + #[test] + fn cache_size_zero_rejected() { + let _guard = ENV_MUTEX.lock().expect("env mutex poisoned"); + clear_embedding_env(); + // SAFETY: Under ENV_MUTEX. + unsafe { + std::env::set_var("EMBEDDING_CACHE_SIZE", "0"); + } + + let settings = Settings::default(); + let result = EmbeddingsConfig::resolve(&settings); + assert!(result.is_err(), "cache_size=0 should be rejected"); + let err = result.unwrap_err().to_string(); + assert!(err.contains("at least 1"), "should mention minimum: {err}"); + // SAFETY: Under ENV_MUTEX. + unsafe { + std::env::remove_var("EMBEDDING_CACHE_SIZE"); + } + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 38c8088050..300fb08e71 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -38,7 +38,7 @@ pub use self::channels::{ ChannelsConfig, CliConfig, DEFAULT_GATEWAY_PORT, GatewayConfig, HttpConfig, SignalConfig, }; pub use self::database::{DatabaseBackend, DatabaseConfig, SslMode, default_libsql_path}; -pub use self::embeddings::EmbeddingsConfig; +pub use self::embeddings::{DEFAULT_EMBEDDING_CACHE_SIZE, EmbeddingsConfig}; pub use self::heartbeat::HeartbeatConfig; pub use self::hygiene::HygieneConfig; pub use self::llm::default_session_path; diff --git a/src/config/relay.rs b/src/config/relay.rs index d45de18820..e1ba82216a 100644 --- a/src/config/relay.rs +++ b/src/config/relay.rs @@ -7,7 +7,7 @@ use secrecy::SecretString; pub struct RelayConfig { /// Base URL of the channel-relay service (e.g., `http://localhost:3001`). pub url: String, - /// API key for authenticated channel-relay endpoints. + /// Bearer token for authenticated channel-relay endpoints (`sk-agent-*`). pub api_key: SecretString, /// Override for the OAuth callback URL (e.g., a tunnel URL). pub callback_url: Option, @@ -15,12 +15,8 @@ pub struct RelayConfig { pub instance_id: Option, /// HTTP request timeout in seconds (default: 30). pub request_timeout_secs: u64, - /// SSE stream long-poll timeout in seconds (default: 86400 = 24 h). - pub stream_timeout_secs: u64, - /// Initial exponential backoff in milliseconds (default: 1000). - pub backoff_initial_ms: u64, - /// Maximum exponential backoff in milliseconds (default: 60000). - pub backoff_max_ms: u64, + /// Path for the webhook callback endpoint (default: `/relay/events`). + pub webhook_path: String, } impl std::fmt::Debug for RelayConfig { @@ -31,9 +27,7 @@ impl std::fmt::Debug for RelayConfig { .field("callback_url", &self.callback_url) .field("instance_id", &self.instance_id) .field("request_timeout_secs", &self.request_timeout_secs) - .field("stream_timeout_secs", &self.stream_timeout_secs) - .field("backoff_initial_ms", &self.backoff_initial_ms) - .field("backoff_max_ms", &self.backoff_max_ms) + .field("webhook_path", &self.webhook_path) .finish() } } @@ -41,8 +35,10 @@ impl std::fmt::Debug for RelayConfig { impl RelayConfig { /// Load relay config from environment variables. /// - /// Returns `None` if either `CHANNEL_RELAY_URL` or `CHANNEL_RELAY_API_KEY` - /// is not set, making the relay integration opt-in. + /// Returns `None` if either of the required env vars (`CHANNEL_RELAY_URL`, + /// `CHANNEL_RELAY_API_KEY`) is not set, making the relay integration opt-in. + /// The signing secret is fetched from channel-relay at activation time via + /// the authenticated `/relay/signing-secret` endpoint — no env var required. pub fn from_env() -> Option { Self::from_env_reader(|key| std::env::var(key).ok()) } @@ -55,9 +51,7 @@ impl RelayConfig { callback_url: None, instance_id: None, request_timeout_secs: 30, - stream_timeout_secs: 86400, - backoff_initial_ms: 1000, - backoff_max_ms: 60000, + webhook_path: "/relay/events".into(), } } @@ -73,15 +67,7 @@ impl RelayConfig { request_timeout_secs: env("RELAY_REQUEST_TIMEOUT_SECS") .and_then(|v| v.parse().ok()) .unwrap_or(30), - stream_timeout_secs: env("RELAY_STREAM_TIMEOUT_SECS") - .and_then(|v| v.parse().ok()) - .unwrap_or(86400), - backoff_initial_ms: env("RELAY_BACKOFF_INITIAL_MS") - .and_then(|v| v.parse().ok()) - .unwrap_or(1000), - backoff_max_ms: env("RELAY_BACKOFF_MAX_MS") - .and_then(|v| v.parse().ok()) - .unwrap_or(60000), + webhook_path: env("RELAY_WEBHOOK_PATH").unwrap_or_else(|| "/relay/events".into()), }) } } @@ -97,7 +83,21 @@ mod tests { } #[test] - fn from_env_reader_loads_defaults() { + fn from_env_reader_requires_only_url_and_api_key() { + // Signing secret is fetched at activation time — only URL + API key needed. + let config = RelayConfig::from_env_reader(|key| match key { + "CHANNEL_RELAY_URL" => Some("http://localhost:3001".into()), + "CHANNEL_RELAY_API_KEY" => Some("test-key".into()), + _ => None, + }); + assert!( + config.is_some(), + "relay config should load with just URL + API key" + ); + } + + #[test] + fn from_env_reader_loads_all_required() { let config = RelayConfig::from_env_reader(|key| match key { "CHANNEL_RELAY_URL" => Some("http://localhost:3001".into()), "CHANNEL_RELAY_API_KEY" => Some("test-key".into()), @@ -107,9 +107,7 @@ mod tests { assert_eq!(config.url, "http://localhost:3001"); assert_eq!(config.request_timeout_secs, 30); - assert_eq!(config.stream_timeout_secs, 86400); - assert_eq!(config.backoff_initial_ms, 1000); - assert_eq!(config.backoff_max_ms, 60000); + assert_eq!(config.webhook_path, "/relay/events"); assert!(config.callback_url.is_none()); assert!(config.instance_id.is_none()); } @@ -122,9 +120,7 @@ mod tests { "IRONCLAW_OAUTH_CALLBACK_URL" => Some("https://tunnel.example.com".into()), "IRONCLAW_INSTANCE_ID" => Some("my-instance".into()), "RELAY_REQUEST_TIMEOUT_SECS" => Some("60".into()), - "RELAY_STREAM_TIMEOUT_SECS" => Some("43200".into()), - "RELAY_BACKOFF_INITIAL_MS" => Some("2000".into()), - "RELAY_BACKOFF_MAX_MS" => Some("120000".into()), + "RELAY_WEBHOOK_PATH" => Some("/custom/events".into()), _ => None, }) .expect("config should be Some"); @@ -135,9 +131,7 @@ mod tests { ); assert_eq!(config.instance_id.as_deref(), Some("my-instance")); assert_eq!(config.request_timeout_secs, 60); - assert_eq!(config.stream_timeout_secs, 43200); - assert_eq!(config.backoff_initial_ms, 2000); - assert_eq!(config.backoff_max_ms, 120000); + assert_eq!(config.webhook_path, "/custom/events"); } #[test] @@ -148,7 +142,7 @@ mod tests { } #[test] - fn debug_redacts_api_key() { + fn debug_redacts_secrets() { let config = RelayConfig::from_values("http://localhost:3001", "super-secret"); let debug = format!("{:?}", config); assert!(debug.contains("[REDACTED]")); diff --git a/src/context/fallback.rs b/src/context/fallback.rs new file mode 100644 index 0000000000..6e76557392 --- /dev/null +++ b/src/context/fallback.rs @@ -0,0 +1,319 @@ +//! Structured fallback deliverables for failed or stuck jobs. +//! +//! When a job fails or is detected as stuck, a [`FallbackDeliverable`] captures +//! what was accomplished before the failure: partial results, action statistics, +//! cost, and timing. This gives users visibility into terminal jobs instead of +//! just an error string. +//! +//! Fallback deliverables are stored in `JobContext.metadata["fallback_deliverable"]` +//! and surfaced through the `job_status` tool. + +use serde::{Deserialize, Serialize}; + +use crate::context::memory::Memory; +use crate::context::state::JobContext; + +/// Structured summary of a failed or stuck job. +/// +/// Stored in `JobContext.metadata["fallback_deliverable"]` when a job fails +/// or is marked stuck. Surfaced through the `job_status` tool. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FallbackDeliverable { + /// True if at least one action succeeded before failure. + pub partial: bool, + /// Why the job failed. + pub failure_reason: String, + /// Last action taken before failure. + pub last_action: Option, + /// Aggregate action statistics. + pub action_stats: ActionStats, + /// Total tokens consumed. + pub tokens_used: u64, + /// Total cost incurred (decimal as string for JSON safety). + pub cost: String, + /// Wall-clock elapsed time in seconds. + pub elapsed_secs: f64, + /// Number of self-repair attempts. + pub repair_attempts: u32, +} + +/// Summary of the last action taken before failure. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LastAction { + pub tool_name: String, + /// Truncated to 200 bytes (UTF-8 safe). + pub output_preview: String, + pub success: bool, +} + +/// Aggregate action counts. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActionStats { + pub total: u32, + pub successful: u32, + pub failed: u32, +} + +impl FallbackDeliverable { + /// Build a fallback deliverable from a job context and its memory. + pub fn build(ctx: &JobContext, memory: &Memory, reason: &str) -> Self { + let successful = memory.successful_actions() as u32; + let failed = memory.failed_actions() as u32; + let total = memory.actions.len() as u32; + + let last_action = memory.last_action().map(|a| { + // Use sanitized output to avoid leaking secrets through the fallback API surface. + // For failed actions (no sanitized output), fall back to the error message. + // Borrow the string slice directly when possible to avoid cloning + // potentially large outputs just for truncation. + let owned_fallback; + let preview_str: &str = if let Some(v) = a.output_sanitized.as_ref() { + match v { + serde_json::Value::String(s) => s.as_str(), + other => { + owned_fallback = serde_json::to_string(other).unwrap_or_default(); + &owned_fallback + } + } + } else if let Some(ref err) = a.error { + err.as_str() + } else { + "" + }; + let preview = truncate_str(preview_str, 200); + LastAction { + tool_name: a.tool_name.clone(), + output_preview: preview.to_string(), + success: a.success, + } + }); + + let elapsed_secs = ctx.elapsed().map_or(0.0, |d| d.as_secs_f64()); + + Self { + partial: successful > 0, + failure_reason: truncate_str(reason, 1000).to_string(), + last_action, + action_stats: ActionStats { + total, + successful, + failed, + }, + tokens_used: ctx.total_tokens_used, + cost: ctx.actual_cost.to_string(), + elapsed_secs, + repair_attempts: ctx.repair_attempts, + } + } +} + +/// Truncate a string to at most `max_len` bytes on a char boundary. +fn truncate_str(s: &str, max_len: usize) -> &str { + &s[..crate::util::floor_char_boundary(s, max_len)] +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::context::memory::Memory; + use crate::context::state::JobContext; + use chrono::{Duration, Utc}; + use rust_decimal::Decimal; + use std::time::Duration as StdDuration; + + #[test] + fn test_fallback_zero_actions() { + let ctx = JobContext::new("Test", "Empty job"); + let memory = Memory::new(ctx.job_id); + + let fb = FallbackDeliverable::build(&ctx, &memory, "timed out"); + + assert!(!fb.partial); // safety: test + assert_eq!(fb.failure_reason, "timed out"); // safety: test + assert!(fb.last_action.is_none()); // safety: test + assert_eq!(fb.action_stats.total, 0); // safety: test + assert_eq!(fb.action_stats.successful, 0); // safety: test + assert_eq!(fb.action_stats.failed, 0); // safety: test + assert_eq!(fb.tokens_used, 0); // safety: test + assert_eq!(fb.cost, "0"); // safety: test + assert_eq!(fb.repair_attempts, 0); // safety: test + } + + #[test] + fn test_fallback_mixed_actions() { + let mut ctx = JobContext::new("Test", "Mixed job"); + ctx.total_tokens_used = 5000; + ctx.actual_cost = Decimal::new(42, 2); // 0.42 + ctx.repair_attempts = 1; + + let mut memory = Memory::new(ctx.job_id); + + // 3 successes + for _ in 0..3 { + let action = memory + .create_action("tool_a", serde_json::json!({})) + .succeed( + Some("output".to_string()), + serde_json::json!({}), + StdDuration::from_secs(1), + ); + memory.record_action(action); + } + // 2 failures + for _ in 0..2 { + let action = memory + .create_action("tool_b", serde_json::json!({})) + .fail("broke", StdDuration::from_secs(1)); + memory.record_action(action); + } + + let fb = FallbackDeliverable::build(&ctx, &memory, "max iterations"); + + assert!(fb.partial); // safety: test + assert_eq!(fb.action_stats.total, 5); // safety: test + assert_eq!(fb.action_stats.successful, 3); // safety: test + assert_eq!(fb.action_stats.failed, 2); // safety: test + assert_eq!(fb.tokens_used, 5000); // safety: test + assert_eq!(fb.cost, "0.42"); // safety: test + assert_eq!(fb.repair_attempts, 1); // safety: test + assert!(fb.last_action.is_some()); // safety: test + let la = fb.last_action.unwrap(); // safety: test + assert_eq!(la.tool_name, "tool_b"); // safety: test + assert!(!la.success); // safety: test + // Failed actions should surface the error message as the output preview + assert_eq!(la.output_preview, "broke"); // safety: test + } + + #[test] + fn test_fallback_failed_action_shows_error() { + let ctx = JobContext::new("Test", "Error preview"); + let mut memory = Memory::new(ctx.job_id); + + let action = memory + .create_action("broken_tool", serde_json::json!({})) + .fail("connection timed out after 30s", StdDuration::from_secs(30)); + memory.record_action(action); + + let fb = FallbackDeliverable::build(&ctx, &memory, "tool failure"); + let la = fb.last_action.unwrap(); // safety: test + assert!(!la.success); // safety: test + assert_eq!(la.output_preview, "connection timed out after 30s"); // safety: test + } + + #[test] + fn test_fallback_last_action_truncation() { + let ctx = JobContext::new("Test", "Truncation"); + let mut memory = Memory::new(ctx.job_id); + + let long_output = "x".repeat(500); + let action = memory + .create_action("tool_c", serde_json::json!({})) + .succeed( + Some(long_output.clone()), + serde_json::Value::String(long_output), + StdDuration::from_secs(1), + ); + memory.record_action(action); + + let fb = FallbackDeliverable::build(&ctx, &memory, "failed"); + let la = fb.last_action.unwrap(); // safety: test + assert!(la.output_preview.len() <= 200); // safety: test + assert!(!la.output_preview.is_empty()); // safety: test + } + + #[test] + fn test_fallback_uses_sanitized_output() { + let ctx = JobContext::new("Test", "Sanitized"); + let mut memory = Memory::new(ctx.job_id); + + let action = memory + .create_action("tool_d", serde_json::json!({})) + .succeed( + Some("[REDACTED]".to_string()), + serde_json::json!({"api_key": "sk-secret-key-12345"}), + StdDuration::from_secs(1), + ); + memory.record_action(action); + + let fb = FallbackDeliverable::build(&ctx, &memory, "failed"); + let la = fb.last_action.unwrap(); // safety: test + // Must use sanitized output, not raw + assert!(!la.output_preview.contains("sk-secret")); // safety: test + assert!(la.output_preview.contains("REDACTED")); // safety: test + } + + #[test] + fn test_fallback_elapsed_time() { + let mut ctx = JobContext::new("Test", "Timing"); + let now = Utc::now(); + ctx.started_at = Some(now - Duration::seconds(10)); + ctx.completed_at = Some(now); + + let memory = Memory::new(ctx.job_id); + let fb = FallbackDeliverable::build(&ctx, &memory, "failed"); + + // Should be approximately 10 seconds + assert!((fb.elapsed_secs - 10.0).abs() < 0.1); // safety: test + } + + #[test] + fn test_fallback_no_started_at() { + let ctx = JobContext::new("Test", "Never started"); + let memory = Memory::new(ctx.job_id); + + let fb = FallbackDeliverable::build(&ctx, &memory, "failed"); + assert!((fb.elapsed_secs - 0.0).abs() < 0.001); // safety: test + } + + #[test] + fn test_fallback_elapsed_time_no_completed_at() { + let mut ctx = JobContext::new("Test", "Still running"); + ctx.started_at = Some(Utc::now() - Duration::seconds(5)); + // completed_at is None — should use Utc::now() as fallback + + let memory = Memory::new(ctx.job_id); + let fb = FallbackDeliverable::build(&ctx, &memory, "stuck"); + + // Should be approximately 5 seconds (using now as end time) + assert!(fb.elapsed_secs >= 4.0 && fb.elapsed_secs <= 7.0); // safety: test + } + + #[test] + fn test_fallback_failure_reason_truncation() { + let ctx = JobContext::new("Test", "Long reason"); + let memory = Memory::new(ctx.job_id); + + let long_reason = "x".repeat(5000); + let fb = FallbackDeliverable::build(&ctx, &memory, &long_reason); + + assert!(fb.failure_reason.len() <= 1000); // safety: test + assert!(!fb.failure_reason.is_empty()); // safety: test + } + + #[test] + fn test_truncate_str_ascii() { + assert_eq!(truncate_str("hello", 10), "hello"); // safety: test + assert_eq!(truncate_str("hello world", 5), "hello"); // safety: test + } + + #[test] + fn test_truncate_str_unicode() { + // "é" is 2 bytes in UTF-8 + let s = "café"; + assert_eq!(truncate_str(s, 10), "café"); // safety: test + // Truncating at 4 would split "é", should back up to 3 + assert_eq!(truncate_str(s, 4), "caf"); // safety: test + } + + #[test] + fn test_fallback_serialization() { + let ctx = JobContext::new("Test", "Serialize"); + let memory = Memory::new(ctx.job_id); + let fb = FallbackDeliverable::build(&ctx, &memory, "test error"); + + // Should serialize to JSON and back without error + let json = serde_json::to_value(&fb).unwrap(); // safety: test + let deserialized: FallbackDeliverable = serde_json::from_value(json).unwrap(); // safety: test + assert_eq!(deserialized.failure_reason, "test error"); // safety: test + } +} diff --git a/src/context/memory.rs b/src/context/memory.rs index 9452c64964..05313e6787 100644 --- a/src/context/memory.rs +++ b/src/context/memory.rs @@ -58,15 +58,19 @@ impl ActionRecord { } /// Mark the action as successful. + /// + /// `output_sanitized` is the tool output after safety processing (string). + /// `output_raw` is the original tool result (JSON value, stored as a + /// pretty-printed JSON string in `ActionRecord.output_raw`). pub fn succeed( mut self, - output_raw: Option, - output_sanitized: serde_json::Value, + output_sanitized: Option, + output_raw: serde_json::Value, duration: Duration, ) -> Self { self.success = true; - self.output_raw = output_raw; - self.output_sanitized = Some(output_sanitized); + self.output_raw = Some(serde_json::to_string_pretty(&output_raw).unwrap_or_default()); + self.output_sanitized = output_sanitized.map(serde_json::Value::String); self.duration = duration; self } @@ -248,15 +252,15 @@ mod tests { #[test] fn test_action_record() { let action = ActionRecord::new(0, "test", serde_json::json!({"key": "value"})); - assert_eq!(action.sequence, 0); - assert!(!action.success); + assert_eq!(action.sequence, 0); // safety: test + assert!(!action.success); // safety: test let action = action.succeed( Some("raw".to_string()), serde_json::json!({"result": "ok"}), Duration::from_millis(100), ); - assert!(action.success); + assert!(action.success); // safety: test } #[test] @@ -267,7 +271,7 @@ mod tests { memory.add(ChatMessage::user("How are you?")); memory.add(ChatMessage::assistant("Good!")); - assert_eq!(memory.len(), 3); // Oldest removed + assert_eq!(memory.len(), 3); // Oldest removed // safety: test } #[test] @@ -286,9 +290,9 @@ mod tests { .with_cost(Decimal::new(20, 1)); memory.record_action(action2); - assert_eq!(memory.total_cost(), Decimal::new(30, 1)); - assert_eq!(memory.total_duration(), Duration::from_secs(3)); - assert_eq!(memory.successful_actions(), 2); + assert_eq!(memory.total_cost(), Decimal::new(30, 1)); // safety: test + assert_eq!(memory.total_duration(), Duration::from_secs(3)); // safety: test + assert_eq!(memory.successful_actions(), 2); // safety: test } #[test] @@ -296,11 +300,11 @@ mod tests { let action = ActionRecord::new(1, "broken_tool", serde_json::json!({"x": 1})); let action = action.fail("something went wrong", Duration::from_millis(50)); - assert!(!action.success); - assert_eq!(action.error.as_deref(), Some("something went wrong")); - assert_eq!(action.duration, Duration::from_millis(50)); - assert!(action.output_raw.is_none()); - assert!(action.output_sanitized.is_none()); + assert!(!action.success); // safety: test + assert_eq!(action.error.as_deref(), Some("something went wrong")); // safety: test + assert_eq!(action.duration, Duration::from_millis(50)); // safety: test + assert!(action.output_raw.is_none()); // safety: test + assert!(action.output_sanitized.is_none()); // safety: test } #[test] @@ -308,9 +312,9 @@ mod tests { let action = ActionRecord::new(0, "risky_tool", serde_json::json!({})); let action = action.with_warnings(vec!["suspicious pattern".into(), "possible xss".into()]); - assert_eq!(action.sanitization_warnings.len(), 2); - assert_eq!(action.sanitization_warnings[0], "suspicious pattern"); - assert_eq!(action.sanitization_warnings[1], "possible xss"); + assert_eq!(action.sanitization_warnings.len(), 2); // safety: test + assert_eq!(action.sanitization_warnings[0], "suspicious pattern"); // safety: test + assert_eq!(action.sanitization_warnings[1], "possible xss"); // safety: test } #[test] @@ -319,41 +323,46 @@ mod tests { let cost = Decimal::new(42, 2); // 0.42 let action = action.with_cost(cost); - assert_eq!(action.cost, Some(Decimal::new(42, 2))); + assert_eq!(action.cost, Some(Decimal::new(42, 2))); // safety: test } #[test] fn test_action_record_new_defaults() { let action = ActionRecord::new(5, "my_tool", serde_json::json!({"key": "val"})); - assert_eq!(action.sequence, 5); - assert_eq!(action.tool_name, "my_tool"); - assert_eq!(action.input, serde_json::json!({"key": "val"})); - assert!(!action.success); - assert!(action.output_raw.is_none()); - assert!(action.output_sanitized.is_none()); - assert!(action.sanitization_warnings.is_empty()); - assert!(action.cost.is_none()); - assert_eq!(action.duration, Duration::ZERO); - assert!(action.error.is_none()); + assert_eq!(action.sequence, 5); // safety: test + assert_eq!(action.tool_name, "my_tool"); // safety: test + assert_eq!(action.input, serde_json::json!({"key": "val"})); // safety: test + assert!(!action.success); // safety: test + assert!(action.output_raw.is_none()); // safety: test + assert!(action.output_sanitized.is_none()); // safety: test + assert!(action.sanitization_warnings.is_empty()); // safety: test + assert!(action.cost.is_none()); // safety: test + assert_eq!(action.duration, Duration::ZERO); // safety: test + assert!(action.error.is_none()); // safety: test } #[test] fn test_action_record_succeed_sets_fields() { let action = ActionRecord::new(0, "tool", serde_json::json!({})); let action = action.succeed( - Some("raw output here".into()), + Some("sanitized output".into()), serde_json::json!({"clean": true}), Duration::from_secs(7), ); - assert!(action.success); - assert_eq!(action.output_raw.as_deref(), Some("raw output here")); + assert!(action.success); // safety: test + // output_raw is the JSON value pretty-printed + let expected_raw = + serde_json::to_string_pretty(&serde_json::json!({"clean": true})).unwrap(); // safety: test + assert_eq!(action.output_raw.as_deref(), Some(expected_raw.as_str())); // safety: test + // output_sanitized wraps the string in a JSON string value assert_eq!( + /* safety: test */ action.output_sanitized, - Some(serde_json::json!({"clean": true})) + Some(serde_json::json!("sanitized output")) ); - assert_eq!(action.duration, Duration::from_secs(7)); + assert_eq!(action.duration, Duration::from_secs(7)); // safety: test } #[test] @@ -361,13 +370,13 @@ mod tests { let mut mem = ConversationMemory::new(10); mem.add(ChatMessage::user("hello")); mem.add(ChatMessage::assistant("hi")); - assert_eq!(mem.len(), 2); - assert!(!mem.is_empty()); + assert_eq!(mem.len(), 2); // safety: test + assert!(!mem.is_empty()); // safety: test mem.clear(); - assert_eq!(mem.len(), 0); - assert!(mem.is_empty()); - assert!(mem.messages().is_empty()); + assert_eq!(mem.len(), 0); // safety: test + assert!(mem.is_empty()); // safety: test + assert!(mem.messages().is_empty()); // safety: test } #[test] @@ -379,20 +388,20 @@ mod tests { mem.add(ChatMessage::assistant("four")); let last_2 = mem.last_n(2); - assert_eq!(last_2.len(), 2); - assert_eq!(last_2[0].content, "three"); - assert_eq!(last_2[1].content, "four"); + assert_eq!(last_2.len(), 2); // safety: test + assert_eq!(last_2[0].content, "three"); // safety: test + assert_eq!(last_2[1].content, "four"); // safety: test // Requesting more than available returns all let last_100 = mem.last_n(100); - assert_eq!(last_100.len(), 4); + assert_eq!(last_100.len(), 4); // safety: test } #[test] fn test_conversation_memory_last_n_empty() { let mem = ConversationMemory::new(10); let result = mem.last_n(5); - assert!(result.is_empty()); + assert!(result.is_empty()); // safety: test } #[test] @@ -405,13 +414,13 @@ mod tests { // At capacity (3). Adding one more should trim, but keep system. mem.add(ChatMessage::user("msg3")); - assert_eq!(mem.len(), 3); + assert_eq!(mem.len(), 3); // safety: test // System message must survive - assert_eq!(mem.messages()[0].role, crate::llm::Role::System); - assert_eq!(mem.messages()[0].content, "You are helpful"); + assert_eq!(mem.messages()[0].role, crate::llm::Role::System); // safety: test + assert_eq!(mem.messages()[0].content, "You are helpful"); // safety: test // Oldest non-system message (msg1) should be gone - assert_eq!(mem.messages()[1].content, "msg2"); - assert_eq!(mem.messages()[2].content, "msg3"); + assert_eq!(mem.messages()[1].content, "msg2"); // safety: test + assert_eq!(mem.messages()[2].content, "msg3"); // safety: test } #[test] @@ -422,9 +431,9 @@ mod tests { // Now at capacity. Add another. mem.add(ChatMessage::user("b")); - assert_eq!(mem.len(), 2); - assert_eq!(mem.messages()[0].role, crate::llm::Role::System); - assert_eq!(mem.messages()[1].content, "b"); + assert_eq!(mem.len(), 2); // safety: test + assert_eq!(mem.messages()[0].role, crate::llm::Role::System); // safety: test + assert_eq!(mem.messages()[1].content, "b"); // safety: test } #[test] @@ -440,7 +449,7 @@ mod tests { mem.add(ChatMessage::user("hello")); // Should have broken out rather than looping forever. // The system message is protected, so len may exceed max. - assert!(mem.len() <= 2); + assert!(mem.len() <= 2); // safety: test } #[test] @@ -459,14 +468,14 @@ mod tests { .fail("oops", Duration::from_millis(2)); memory.record_action(err); - assert_eq!(memory.successful_actions(), 1); - assert_eq!(memory.failed_actions(), 1); + assert_eq!(memory.successful_actions(), 1); // safety: test + assert_eq!(memory.failed_actions(), 1); // safety: test } #[test] fn test_memory_last_action() { let mut memory = Memory::new(Uuid::new_v4()); - assert!(memory.last_action().is_none()); + assert!(memory.last_action().is_none()); // safety: test let a1 = memory .create_action("first", serde_json::json!({})) @@ -478,8 +487,8 @@ mod tests { .fail("nope", Duration::ZERO); memory.record_action(a2); - let last = memory.last_action().unwrap(); - assert_eq!(last.tool_name, "second"); + let last = memory.last_action().unwrap(); // safety: test + assert_eq!(last.tool_name, "second"); // safety: test } #[test] @@ -499,9 +508,9 @@ mod tests { ); memory.record_action(a); - assert_eq!(memory.actions_by_tool("shell").len(), 3); - assert_eq!(memory.actions_by_tool("http").len(), 1); - assert_eq!(memory.actions_by_tool("nonexistent").len(), 0); + assert_eq!(memory.actions_by_tool("shell").len(), 3); // safety: test + assert_eq!(memory.actions_by_tool("http").len(), 1); // safety: test + assert_eq!(memory.actions_by_tool("nonexistent").len(), 0); // safety: test } #[test] @@ -509,25 +518,25 @@ mod tests { let mut memory = Memory::new(Uuid::new_v4()); let a0 = memory.create_action("t", serde_json::json!({})); - assert_eq!(a0.sequence, 0); + assert_eq!(a0.sequence, 0); // safety: test let a1 = memory.create_action("t", serde_json::json!({})); - assert_eq!(a1.sequence, 1); + assert_eq!(a1.sequence, 1); // safety: test let a2 = memory.create_action("t", serde_json::json!({})); - assert_eq!(a2.sequence, 2); + assert_eq!(a2.sequence, 2); // safety: test } #[test] fn test_memory_add_message_delegates_to_conversation() { let mut memory = Memory::new(Uuid::new_v4()); - assert!(memory.conversation.is_empty()); + assert!(memory.conversation.is_empty()); // safety: test memory.add_message(ChatMessage::user("hello")); memory.add_message(ChatMessage::assistant("hi")); - assert_eq!(memory.conversation.len(), 2); - assert_eq!(memory.conversation.messages()[0].content, "hello"); + assert_eq!(memory.conversation.len(), 2); // safety: test + assert_eq!(memory.conversation.messages()[0].content, "hello"); // safety: test } #[test] @@ -540,7 +549,7 @@ mod tests { .succeed(None, serde_json::json!({}), Duration::ZERO); memory.record_action(a); - assert_eq!(memory.total_cost(), Decimal::ZERO); + assert_eq!(memory.total_cost(), Decimal::ZERO); // safety: test } #[test] @@ -560,6 +569,6 @@ mod tests { memory.record_action(a2); // Both successful and failed actions contribute to total duration - assert_eq!(memory.total_duration(), Duration::from_millis(300)); + assert_eq!(memory.total_duration(), Duration::from_millis(300)); // safety: test } } diff --git a/src/context/mod.rs b/src/context/mod.rs index a7dd61ded4..4b48203842 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -6,10 +6,12 @@ //! - State machine //! - Resource tracking +pub mod fallback; mod manager; mod memory; mod state; +pub use fallback::FallbackDeliverable; pub use manager::ContextManager; pub use memory::{ActionRecord, ConversationMemory, Memory}; pub use state::{JobContext, JobState, StateTransition, TokenBudgetExceeded}; diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index 00d787a5a3..0762f3ed3f 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -45,6 +45,56 @@ struct PendingAuth { task_handle: Option>, } +struct HostedOAuthFlowStart { + name: String, + kind: ExtensionKind, + auth_url: String, + expected_state: String, + flow: crate::cli::oauth_defaults::PendingOAuthFlow, +} + +fn hosted_proxy_client_secret( + client_secret: &Option, + builtin: Option<&crate::cli::oauth_defaults::OAuthCredentials>, + exchange_proxy_configured: bool, +) -> Option { + if !exchange_proxy_configured { + return client_secret.clone(); + } + + let builtin_secret = builtin.map(|credentials| credentials.client_secret); + match (client_secret, builtin_secret) { + (Some(resolved), Some(baked_in)) if resolved == baked_in => None, + _ => client_secret.clone(), + } +} + +fn normalize_oauth_callback_path(path: &str) -> String { + let trimmed_path = path.trim_end_matches('/'); + if trimmed_path.is_empty() { + "/oauth/callback".to_string() + } else if trimmed_path.ends_with("/oauth/callback") { + trimmed_path.to_string() + } else { + format!("{trimmed_path}/oauth/callback") + } +} + +fn normalize_hosted_callback_url(callback_url: &str) -> String { + if let Ok(mut parsed) = url::Url::parse(callback_url) { + let normalized_path = normalize_oauth_callback_path(parsed.path()); + parsed.set_path(&normalized_path); + return parsed.to_string(); + } + + let normalized_callback_url = callback_url.trim_end_matches('/'); + if normalized_callback_url.ends_with("/oauth/callback") { + normalized_callback_url.to_string() + } else { + format!("{normalized_callback_url}/oauth/callback") + } +} + /// Runtime infrastructure needed for hot-activating WASM channels. /// /// Set after construction via [`ExtensionManager::set_channel_runtime`] once the @@ -361,6 +411,18 @@ pub struct ExtensionManager { /// Relay config captured at startup. Used by `auth_channel_relay` and /// `activate_channel_relay` instead of re-reading env vars. relay_config: Option, + /// Shared event sender for the relay webhook endpoint. + /// Populated by `activate_channel_relay`, consumed by the web gateway's + /// `/relay/events` handler. + relay_event_tx: Arc< + tokio::sync::Mutex< + Option>, + >, + >, + /// Per-instance callback signing secret fetched from channel-relay at activation. + /// Stored here so the web gateway can verify incoming callbacks without + /// any env var or shared secret. + relay_signing_secret_cache: Arc>>>, /// When `true`, OAuth flows always return an auth URL to the caller /// instead of opening a browser on the server via `open::that()`. /// Set by the web gateway at startup via `enable_gateway_mode()`. @@ -446,6 +508,8 @@ impl ExtensionManager { pending_oauth_flows: crate::cli::oauth_defaults::new_pending_oauth_registry(), gateway_token: std::env::var("GATEWAY_AUTH_TOKEN").ok(), relay_config: crate::config::RelayConfig::from_env(), + relay_event_tx: Arc::new(tokio::sync::Mutex::new(None)), + relay_signing_secret_cache: Arc::new(std::sync::Mutex::new(None)), gateway_mode: std::sync::atomic::AtomicBool::new(false), gateway_base_url: RwLock::new(None), pending_telegram_verification: RwLock::new(HashMap::new()), @@ -533,7 +597,9 @@ impl ExtensionManager { async fn gateway_callback_redirect_uri(&self) -> Option { use crate::cli::oauth_defaults; if oauth_defaults::use_gateway_callback() { - return Some(format!("{}/oauth/callback", oauth_defaults::callback_url())); + return Some(normalize_hosted_callback_url( + &oauth_defaults::callback_url(), + )); } // Use gateway_base_url from enable_gateway_mode() if let Some(ref base) = *self.gateway_base_url.read().await { @@ -564,6 +630,33 @@ impl ExtensionManager { }) } + /// Get the shared relay event sender for the webhook endpoint. + pub fn relay_event_tx( + &self, + ) -> Arc< + tokio::sync::Mutex< + Option>, + >, + > { + Arc::clone(&self.relay_event_tx) + } + + /// Get the per-instance callback signing secret for webhook signature verification. + /// + /// Returns the secret that was fetched from channel-relay's + /// `/relay/signing-secret` endpoint during `activate_channel_relay`. + /// Returns `None` if the relay channel has not been activated yet. + pub fn relay_signing_secret(&self) -> Option> { + self.relay_signing_secret_cache.lock().ok()?.clone() + } + + async fn clear_relay_webhook_state(&self) { + *self.relay_event_tx.lock().await = None; + if let Ok(mut cache) = self.relay_signing_secret_cache.lock() { + *cache = None; + } + } + /// Inject a registry entry for testing. The entry is added to the discovery /// cache so it appears in search results alongside built-in entries. pub async fn inject_registry_entry(&self, entry: crate::extensions::RegistryEntry) { @@ -753,12 +846,25 @@ impl ExtensionManager { *self.relay_channel_manager.write().await = Some(channel_manager); } - /// Check if a channel name corresponds to a relay extension (has stored stream token). + /// Check if a channel name corresponds to a relay extension (has stored team_id + /// or is tracked in the installed relay extensions set). pub async fn is_relay_channel(&self, name: &str) -> bool { - self.secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false) + // Check in-memory installed set first (supports no-store mode) + if self.installed_relay_extensions.read().await.contains(name) { + return true; + } + // Then check persistent settings + if let Some(ref store) = self.store { + let team_id_key = format!("relay:{}:team_id", name); + store + .get_setting(&self.user_id, &team_id_key) + .await + .ok() + .flatten() + .is_some() + } else { + false + } } /// Restore persisted relay channels after startup. @@ -870,6 +976,98 @@ impl ExtensionManager { &self.pending_oauth_flows } + async fn clear_pending_extension_auth(&self, name: &str) { + { + let mut pending = self.pending_auth.write().await; + if let Some(old) = pending.remove(name) + && let Some(handle) = old.task_handle + { + handle.abort(); + } + } + + let mut flows = self.pending_oauth_flows.write().await; + flows.retain(|_, flow| flow.extension_name != name); + } + + fn rewrite_oauth_state_param( + auth_url: String, + expected_state: &str, + hosted_state: &str, + ) -> String { + if hosted_state == expected_state { + return auth_url; + } + + let Ok(mut parsed) = url::Url::parse(&auth_url) else { + return auth_url.replace( + &format!("state={}", urlencoding::encode(expected_state)), + &format!("state={}", urlencoding::encode(hosted_state)), + ); + }; + + let mut replaced = false; + let pairs: Vec<(String, String)> = parsed + .query_pairs() + .map(|(key, value)| { + if key == "state" { + replaced = true; + (key.into_owned(), hosted_state.to_string()) + } else { + (key.into_owned(), value.into_owned()) + } + }) + .collect(); + + { + let mut query_pairs = parsed.query_pairs_mut(); + query_pairs.clear(); + for (key, value) in pairs { + query_pairs.append_pair(&key, &value); + } + if !replaced { + query_pairs.append_pair("state", hosted_state); + } + } + + parsed.to_string() + } + + async fn start_gateway_oauth_flow(&self, request: HostedOAuthFlowStart) -> AuthResult { + use crate::cli::oauth_defaults; + + oauth_defaults::sweep_expired_flows(&self.pending_oauth_flows).await; + + let hosted_state = oauth_defaults::build_platform_state(&request.expected_state); + let auth_url = Self::rewrite_oauth_state_param( + request.auth_url, + &request.expected_state, + &hosted_state, + ); + + self.pending_oauth_flows + .write() + .await + .insert(request.expected_state, request.flow); + + self.pending_auth.write().await.insert( + request.name.clone(), + PendingAuth { + _name: request.name.clone(), + _kind: request.kind, + created_at: std::time::Instant::now(), + task_handle: None, + }, + ); + + AuthResult::awaiting_authorization( + request.name, + request.kind, + auth_url, + "gateway".to_string(), + ) + } + /// Broadcast an extension status change to the web UI via SSE. async fn broadcast_extension_status(&self, name: &str, status: &str, message: Option<&str>) { if let Some(ref sender) = *self.sse_sender.read().await { @@ -1167,11 +1365,7 @@ impl ExtensionManager { let active_names = self.active_channel_names.read().await; for name in installed.iter() { let active = active_names.contains(name); - let has_token = self - .secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false); + let has_token = self.is_relay_channel(name).await; let registry_entry = self .registry .get_with_kind(name, Some(ExtensionKind::ChannelRelay)) @@ -1365,19 +1559,26 @@ impl ExtensionManager { // Remove from active channels self.active_channel_names.write().await.remove(name); self.persist_active_channels().await; + self.activation_errors.write().await.remove(name); - // Remove stored stream token - let _ = self - .secrets - .delete(&self.user_id, &format!("relay:{}:stream_token", name)) - .await; + // Remove stored team_id + if let Some(ref store) = self.store { + let _ = store + .delete_setting(&self.user_id, &format!("relay:{}:team_id", name)) + .await; + } + + // Stop webhook traffic before removing the channel from the managers. + self.clear_relay_webhook_state().await; - // Shut down the channel (check both runtime paths for WASM+relay and relay-only modes) + // Shut down and remove the channel (check both runtime paths for + // WASM+relay and relay-only modes). let mut shut_down = false; if let Some(ref rt) = *self.channel_runtime.read().await && let Some(channel) = rt.channel_manager.get_channel(name).await { let _ = channel.shutdown().await; + rt.channel_manager.remove(name).await; shut_down = true; } if !shut_down @@ -1385,6 +1586,7 @@ impl ExtensionManager { && let Some(channel) = cm.get_channel(name).await { let _ = channel.shutdown().await; + cm.remove(name).await; } Ok(format!("Removed channel relay '{}'", name)) @@ -2325,6 +2527,7 @@ impl ExtensionManager { use crate::cli::oauth_defaults; let is_gateway = self.should_use_gateway_mode(); + self.clear_pending_extension_auth(name).await; // Build redirect URI: gateway uses the public callback URL, // local mode binds a random port. @@ -2382,19 +2585,8 @@ impl ExtensionManager { let code_verifier = oauth_result.code_verifier; if is_gateway { - // Gateway mode: store pending flow for the /oauth/callback handler. - oauth_defaults::sweep_expired_flows(&self.pending_oauth_flows).await; - - // Platform routing: prepend instance name to state - let platform_state = oauth_defaults::build_platform_state(&expected_state); - let auth_url = if platform_state != expected_state { - oauth_result.url.replace( - &format!("state={}", urlencoding::encode(&expected_state)), - &format!("state={}", urlencoding::encode(&platform_state)), - ) - } else { - oauth_result.url - }; + let mut token_exchange_extra_params = HashMap::new(); + token_exchange_extra_params.insert("resource".to_string(), resource.clone()); let flow = oauth_defaults::PendingOAuthFlow { extension_name: name.to_string(), @@ -2413,7 +2605,7 @@ impl ExtensionManager { secrets: Arc::clone(&self.secrets), sse_sender: self.sse_sender.read().await.clone(), gateway_token: self.gateway_token.clone(), - resource: Some(resource), + token_exchange_extra_params, client_id_secret_name: if server.oauth.is_none() { Some(server.client_id_secret_name()) } else { @@ -2422,27 +2614,15 @@ impl ExtensionManager { created_at: std::time::Instant::now(), }; - self.pending_oauth_flows - .write() - .await - .insert(expected_state, flow); - - self.pending_auth.write().await.insert( - name.to_string(), - PendingAuth { - _name: name.to_string(), - _kind: ExtensionKind::McpServer, - created_at: std::time::Instant::now(), - task_handle: None, - }, - ); - - Ok(AuthResult::awaiting_authorization( - name, - ExtensionKind::McpServer, - auth_url, - "gateway".to_string(), - )) + Ok(self + .start_gateway_oauth_flow(HostedOAuthFlowStart { + name: name.to_string(), + kind: ExtensionKind::McpServer, + auth_url: oauth_result.url, + expected_state, + flow, + }) + .await) } else { // Local mode: return URL for manual opening self.pending_auth.write().await.insert( @@ -2843,9 +3023,10 @@ impl ExtensionManager { Enter it in the Setup tab or set {} env var", name, env_name ); - // Only mention the Google-specific build flag for Google providers - if auth.secret_name.to_lowercase().contains("google") { - msg.push_str(", or build with IRONCLAW_GOOGLE_CLIENT_ID"); + if let Some(override_env) = + crate::cli::oauth_defaults::builtin_client_id_override_env(&auth.secret_name) + { + msg.push_str(&format!(", or build with {override_env}")); } msg.push('.'); msg @@ -2861,20 +3042,7 @@ impl ExtensionManager { ) .await; - // Cancel any existing pending auth for this tool (frees port 9876 in TCP mode) - { - let mut pending = self.pending_auth.write().await; - if let Some(old) = pending.remove(name) - && let Some(handle) = old.task_handle - { - handle.abort(); - } - } - // Also clean up any gateway-mode pending flows for this tool - { - let mut flows = self.pending_oauth_flows.write().await; - flows.retain(|_, flow| flow.extension_name != name); - } + self.clear_pending_extension_auth(name).await; let redirect_uri = self .gateway_callback_redirect_uri() @@ -2905,30 +3073,24 @@ impl ExtensionManager { .unwrap_or_else(|| name.to_string()); if self.should_use_gateway_mode() { - // Gateway mode: store pending flow state for the web gateway's - // `/oauth/callback` handler to complete the exchange. No TCP listener - // needed — the OAuth provider redirects to the gateway URL. - oauth_defaults::sweep_expired_flows(&self.pending_oauth_flows).await; - - // Wrap the CSRF nonce with instance name for platform routing. - // Nginx at auth.DOMAIN parses `instance:nonce` to route the callback - // to the correct container. The flow is keyed by the raw nonce. - let platform_state = oauth_defaults::build_platform_state(&expected_state); - let auth_url = if platform_state != expected_state { - auth_url.replace( - &format!("state={}", urlencoding::encode(&expected_state)), - &format!("state={}", urlencoding::encode(&platform_state)), - ) - } else { - auth_url - }; + // When an exchange proxy is configured, omit the client_secret if it + // was resolved from built-in defaults (desktop app credentials). The + // proxy holds the correct web-app secret for platform-registered OAuth + // apps. Sending the desktop secret would cause a client_id/secret + // mismatch because the container's GOOGLE_OAUTH_CLIENT_ID is the web + // app, not the desktop app. + let proxy_client_secret = hosted_proxy_client_secret( + &client_secret, + builtin.as_ref(), + oauth_defaults::exchange_proxy_url().is_some(), + ); let flow = oauth_defaults::PendingOAuthFlow { extension_name: name.to_string(), display_name: display_name.clone(), token_url: oauth.token_url.clone(), client_id: client_id.clone(), - client_secret: client_secret.clone(), + client_secret: proxy_client_secret, redirect_uri: redirect_uri.clone(), code_verifier, access_token_field: oauth.access_token_field.clone(), @@ -2940,35 +3102,20 @@ impl ExtensionManager { secrets: Arc::clone(&self.secrets), sse_sender: self.sse_sender.read().await.clone(), gateway_token: self.gateway_token.clone(), - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, created_at: std::time::Instant::now(), }; - // Key by raw nonce (without instance prefix) — the callback handler - // strips the prefix before lookup. - self.pending_oauth_flows - .write() - .await - .insert(expected_state, flow); - - // Register pending auth without a task handle (gateway handles completion) - self.pending_auth.write().await.insert( - name.to_string(), - PendingAuth { - _name: name.to_string(), - _kind: ExtensionKind::WasmTool, - created_at: std::time::Instant::now(), - task_handle: None, - }, - ); - - Ok(AuthResult::awaiting_authorization( - name, - ExtensionKind::WasmTool, - auth_url, - "gateway".to_string(), - )) + Ok(self + .start_gateway_oauth_flow(HostedOAuthFlowStart { + name: name.to_string(), + kind: ExtensionKind::WasmTool, + auth_url, + expected_state, + flow, + }) + .await) } else { // TCP listener mode: bind port 9876 and spawn a background task // to wait for the callback. This is the original flow for local/desktop use. @@ -3880,25 +4027,14 @@ impl ExtensionManager { /// For Telegram: accepts a bot token, registers it with channel-relay, /// and stores the returned stream token. async fn auth_channel_relay(&self, name: &str) -> Result { - // Check if already authenticated (stream token exists) - let token_key = format!("relay:{}:stream_token", name); - if self - .secrets - .exists(&self.user_id, &token_key) - .await - .unwrap_or(false) - { + // Check if already authenticated (has stored team_id) + if self.is_relay_channel(name).await { return Ok(AuthResult::authenticated(name, ExtensionKind::ChannelRelay)); } // Use relay config captured at startup let relay_config = self.relay_config()?; - let instance_id = self.relay_instance_id(relay_config); - let user_id_uuid = std::env::var("IRONCLAW_USER_ID").unwrap_or_else(|_| { - uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_DNS, self.user_id.as_bytes()).to_string() - }); - let client = crate::channels::relay::RelayClient::new( relay_config.url.clone(), relay_config.api_key.clone(), @@ -3906,22 +4042,11 @@ impl ExtensionManager { ) .map_err(|e| ExtensionError::Config(e.to_string()))?; - // OAuth redirect flow - let callback_base = self - .tunnel_url - .clone() - .or_else(|| relay_config.callback_url.clone()) - .unwrap_or_else(|| { - let host = std::env::var("GATEWAY_HOST").unwrap_or_else(|_| "127.0.0.1".into()); - let port = std::env::var("GATEWAY_PORT") - .unwrap_or_else(|_| crate::config::DEFAULT_GATEWAY_PORT.to_string()); - format!("http://{}:{}", host, port) - }); - - // Generate CSRF nonce for OAuth state parameter + // Generate CSRF nonce — IronClaw validates this on the callback to ensure + // the OAuth completion is legitimate. Channel-relay embeds it in the signed + // state and appends it to the post-OAuth redirect URL. let state_nonce = uuid::Uuid::new_v4().to_string(); let state_key = format!("relay:{}:oauth_state", name); - // Delete any stale nonce before storing the new one let _ = self.secrets.delete(&self.user_id, &state_key).await; self.secrets .create( @@ -3931,15 +4056,9 @@ impl ExtensionManager { .await .map_err(|e| ExtensionError::AuthFailed(format!("Failed to store OAuth state: {e}")))?; - let callback_url = format!( - "{}/oauth/slack/callback?state={}", - callback_base, state_nonce - ); - - match client - .initiate_oauth(&instance_id, &user_id_uuid, &callback_url) - .await - { + // Channel-relay derives all URLs from trusted instance_url in chat-api. + // We only pass the nonce for CSRF validation on the callback. + match client.initiate_oauth(Some(&state_nonce)).await { Ok(auth_url) => Ok(AuthResult::awaiting_authorization( name, ExtensionKind::ChannelRelay, @@ -3952,29 +4071,17 @@ impl ExtensionManager { /// Activate a channel-relay extension. async fn activate_channel_relay(&self, name: &str) -> Result { - let token_key = format!("relay:{}:stream_token", name); let team_id_key = format!("relay:{}:team_id", name); - // Check if we have a stream token - let stream_token = match self.secrets.get_decrypted(&self.user_id, &token_key).await { - Ok(secret) => secret.expose().to_string(), - Err(_) => { - return Err(ExtensionError::AuthRequired); - } - }; - - // Get team_id from settings - let team_id = if let Some(ref store) = self.store { - store - .get_setting(&self.user_id, &team_id_key) - .await - .ok() - .flatten() - .and_then(|v| v.as_str().map(|s| s.to_string())) - .unwrap_or_default() - } else { - String::new() - }; + let store = self.store.as_ref().ok_or(ExtensionError::AuthRequired)?; + let team_id = store + .get_setting(&self.user_id, &team_id_key) + .await + .ok() + .flatten() + .and_then(|v| v.as_str().map(|s| s.to_string())) + .filter(|s| !s.is_empty()) + .ok_or(ExtensionError::AuthRequired)?; // Use relay config captured at startup let relay_config = self.relay_config()?; @@ -3988,18 +4095,29 @@ impl ExtensionManager { ) .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))?; + // Fetch the per-instance signing secret from channel-relay. + // This must succeed — there is no fallback. + let signing_secret = client.get_signing_secret(&team_id).await.map_err(|e| { + ExtensionError::Config(format!("Failed to fetch relay signing secret: {e}")) + })?; + + // Create the event channel for webhook callbacks + let (event_tx, event_rx) = tokio::sync::mpsc::channel(64); + let channel = crate::channels::relay::RelayChannel::new_with_provider( - client, + client.clone(), crate::channels::relay::channel::RelayProvider::Slack, - stream_token, - team_id, - instance_id, - self.user_id.clone(), - ) - .with_timeouts( - relay_config.stream_timeout_secs, - relay_config.backoff_initial_ms, - relay_config.backoff_max_ms, + team_id.clone(), + instance_id.clone(), + event_tx.clone(), + event_rx, + ); + + // Callback URL is now set during OAuth flow, not via PUT /callbacks. + // The relay webhook endpoint path is still needed for the web gateway. + tracing::info!( + webhook_path = %relay_config.webhook_path, + "Relay channel activated (callback URL set during OAuth)" ); // Hot-add to channel manager @@ -4013,6 +4131,13 @@ impl ExtensionManager { .await .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))?; + if let Ok(mut cache) = self.relay_signing_secret_cache.lock() { + *cache = Some(signing_secret); + } + + // Store the event sender so the web gateway's relay webhook endpoint can push events + *self.relay_event_tx.lock().await = Some(event_tx); + // Mark as active self.active_channel_names .write() @@ -4035,11 +4160,11 @@ impl ExtensionManager { /// Activate a channel-relay extension from stored credentials (for startup reconnect). pub async fn activate_stored_relay(&self, name: &str) -> Result<(), ExtensionError> { + self.activate_channel_relay(name).await?; self.installed_relay_extensions .write() .await .insert(name.to_string()); - self.activate_channel_relay(name).await?; Ok(()) } @@ -4070,13 +4195,8 @@ impl ExtensionManager { if self.installed_relay_extensions.read().await.contains(name) { return Ok(ExtensionKind::ChannelRelay); } - // Also check if there's a stored stream token (persisted across restarts) - if self - .secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false) - { + // Also check if there's a stored team_id (persisted across restarts) + if self.is_relay_channel(name).await { return Ok(ExtensionKind::ChannelRelay); } @@ -5210,7 +5330,8 @@ mod tests { use crate::extensions::manager::{ ChannelRuntimeState, FallbackDecision, TelegramBindingData, TelegramBindingResult, TelegramOwnerBindingState, build_wasm_channel_runtime_config_updates, - combine_install_errors, fallback_decision, infer_kind_from_url, send_telegram_text_message, + combine_install_errors, fallback_decision, hosted_proxy_client_secret, infer_kind_from_url, + normalize_hosted_callback_url, send_telegram_text_message, telegram_message_matches_verification_code, }; use crate::extensions::{ @@ -6351,24 +6472,24 @@ mod tests { } #[tokio::test] - async fn test_is_relay_channel_detects_stored_token() { + async fn test_is_relay_channel_returns_false_without_store() { let dir = tempfile::tempdir().expect("temp dir"); let mgr = make_test_manager(None, dir.path().to_path_buf()); - // No token stored → not a relay channel + // With no DB store, is_relay_channel always returns false assert!(!mgr.is_relay_channel("slack-relay").await); + } - // Store a stream token - mgr.secrets - .create( - "test", - crate::secrets::CreateSecretParams::new("relay:slack-relay:stream_token", "tok123"), - ) - .await - .expect("store token"); + #[tokio::test] + async fn test_activate_channel_relay_without_store_returns_auth_required() { + let dir = tempfile::tempdir().expect("temp dir"); + let mgr = make_test_manager(None, dir.path().to_path_buf()); - // Now it's detected as a relay channel - assert!(mgr.is_relay_channel("slack-relay").await); + let err = mgr.activate_channel_relay("slack-relay").await.unwrap_err(); + assert!( + matches!(err, ExtensionError::AuthRequired), + "expected AuthRequired, got: {err:?}" + ); } #[tokio::test] @@ -6384,18 +6505,25 @@ mod tests { cm.add(Box::new(stub)).await; mgr.set_relay_channel_manager(Arc::clone(&cm)).await; - // Mark as installed + store a token so determine_installed_kind finds it + // Mark as installed + store team_id so determine_installed_kind finds it mgr.installed_relay_extensions .write() .await .insert("slack-relay".to_string()); - mgr.secrets - .create( - "test", - crate::secrets::CreateSecretParams::new("relay:slack-relay:stream_token", "tok123"), - ) - .await - .expect("store token"); + *mgr.relay_event_tx.lock().await = Some(tokio::sync::mpsc::channel(1).0); + if let Ok(mut cache) = mgr.relay_signing_secret_cache.lock() { + *cache = Some(vec![9u8; 32]); + } + if let Some(ref store) = mgr.store { + store + .set_setting( + "test", + "relay:slack-relay:team_id", + &serde_json::json!("T123"), + ) + .await + .expect("store team_id"); + } // Verify channel exists before removal assert!(cm.get_channel("slack-relay").await.is_some()); @@ -6412,6 +6540,18 @@ mod tests { .contains("slack-relay"), "Should be removed from installed set" ); + assert!( + mgr.relay_event_tx.lock().await.is_none(), + "relay event sender should be cleared on remove" + ); + assert!( + mgr.relay_signing_secret().is_none(), + "relay signing secret cache should be cleared on remove" + ); + assert!( + cm.get_channel("slack-relay").await.is_none(), + "relay channel should be removed from the channel manager" + ); } #[tokio::test] @@ -6460,7 +6600,7 @@ mod tests { secrets: Arc::clone(&secrets), sse_sender: None, gateway_token: None, - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, created_at: std::time::Instant::now(), }, @@ -6484,7 +6624,7 @@ mod tests { secrets, sse_sender: None, gateway_token: None, - resource: None, + token_exchange_extra_params: std::collections::HashMap::new(), client_id_secret_name: None, created_at: std::time::Instant::now(), }, @@ -6651,9 +6791,6 @@ mod tests { // The root cause was that `should_use_gateway_mode()` only checked the // `IRONCLAW_OAUTH_CALLBACK_URL` env var, ignoring `self.tunnel_url`. - /// Serializes env-mutating tests to prevent parallel races. - static GATEWAY_ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); - /// Build a minimal ExtensionManager with a custom tunnel_url. fn make_manager_with_tunnel(tunnel_url: Option) -> ExtensionManager { use crate::secrets::{InMemorySecretsStore, SecretsCrypto}; @@ -6686,9 +6823,11 @@ mod tests { #[test] fn should_use_gateway_mode_true_for_tunnel_url() { - let _guard = GATEWAY_ENV_MUTEX.lock().expect("env mutex poisoned"); + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); - // SAFETY: Under GATEWAY_ENV_MUTEX, no concurrent env access. + // SAFETY: Under ENV_MUTEX, no concurrent env access. unsafe { std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); } @@ -6708,7 +6847,9 @@ mod tests { #[test] fn should_use_gateway_mode_false_without_tunnel() { - let _guard = GATEWAY_ENV_MUTEX.lock().expect("env mutex poisoned"); + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); unsafe { std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); @@ -6729,7 +6870,9 @@ mod tests { #[test] fn should_use_gateway_mode_false_for_loopback_tunnel() { - let _guard = GATEWAY_ENV_MUTEX.lock().expect("env mutex poisoned"); + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); unsafe { std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); @@ -6757,9 +6900,11 @@ mod tests { impl EnvGuard { fn new() -> Self { - let guard = GATEWAY_ENV_MUTEX.lock().expect("env mutex poisoned"); + let guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); - // SAFETY: Under GATEWAY_ENV_MUTEX, no concurrent env access. + // SAFETY: Under ENV_MUTEX, no concurrent env access. unsafe { std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); } @@ -6772,7 +6917,7 @@ mod tests { impl Drop for EnvGuard { fn drop(&mut self) { - // SAFETY: Under GATEWAY_ENV_MUTEX (still held by _mutex), no concurrent env access. + // SAFETY: Under ENV_MUTEX (still held by _mutex), no concurrent env access. unsafe { if let Some(ref val) = self.original { std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val); @@ -6813,6 +6958,90 @@ mod tests { ); } + #[test] + fn gateway_callback_redirect_uri_does_not_duplicate_callback_path_from_env() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); + unsafe { + std::env::set_var( + "IRONCLAW_OAUTH_CALLBACK_URL", + "https://oauth.test.example/oauth/callback", + ); + } + + let mgr = make_manager_with_tunnel(None); + assert_eq!( + tokio_test::block_on(mgr.gateway_callback_redirect_uri()), + Some("https://oauth.test.example/oauth/callback".to_string()), + ); + + unsafe { + if let Some(val) = original { + std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val); + } else { + std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); + } + } + } + + #[test] + fn gateway_callback_redirect_uri_trims_trailing_slash_from_env_callback() { + let _guard = crate::config::helpers::ENV_MUTEX + .lock() + .expect("env mutex poisoned"); + let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok(); + unsafe { + std::env::set_var( + "IRONCLAW_OAUTH_CALLBACK_URL", + "https://oauth.test.example/oauth/callback/", + ); + } + + let mgr = make_manager_with_tunnel(None); + assert_eq!( + tokio_test::block_on(mgr.gateway_callback_redirect_uri()), + Some("https://oauth.test.example/oauth/callback".to_string()), + ); + + unsafe { + if let Some(val) = original { + std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val); + } else { + std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL"); + } + } + } + + #[test] + fn normalize_hosted_callback_url_preserves_query_params() { + assert_eq!( + normalize_hosted_callback_url("https://oauth.test.example?source=hosted"), + "https://oauth.test.example/oauth/callback?source=hosted" + ); + assert_eq!( + normalize_hosted_callback_url( + "https://oauth.test.example/oauth/callback?source=hosted" + ), + "https://oauth.test.example/oauth/callback?source=hosted" + ); + } + + #[test] + fn rewrite_oauth_state_param_updates_only_state_query_param() { + let auth_url = + "https://auth.example.com/authorize?client_id=abc&state=old-state&hint=state%3Dkeep"; + assert_eq!( + ExtensionManager::rewrite_oauth_state_param( + auth_url.to_string(), + "old-state", + "new-hosted-state", + ), + "https://auth.example.com/authorize?client_id=abc&state=new-hosted-state&hint=state%3Dkeep" + ); + } + #[tokio::test] async fn gateway_mode_enabled_explicitly() { let _env = EnvGuard::new(); @@ -7167,4 +7396,71 @@ mod tests { panic!("URL missing token: {url}"); // safety: test assertion } } + + // ── proxy_client_secret suppression ───────────────────────────── + + #[test] + fn test_proxy_client_secret_suppressed_when_builtin_matches_with_exchange_proxy() { + let builtin = crate::cli::oauth_defaults::builtin_credentials("google_oauth_token"); + let builtin_ref = builtin.as_ref(); + let secret = Some(builtin_ref.unwrap().client_secret.to_string()); + + let result = hosted_proxy_client_secret(&secret, builtin_ref, true); + assert_eq!( + result, None, + "built-in desktop secret must be suppressed when the exchange proxy is configured" + ); + } + + #[test] + fn test_proxy_client_secret_kept_when_not_builtin_with_exchange_proxy() { + let builtin = crate::cli::oauth_defaults::builtin_credentials("google_oauth_token"); + let secret = Some("user-entered-custom-secret".to_string()); + + let result = hosted_proxy_client_secret(&secret, builtin.as_ref(), true); + assert_eq!( + result, + Some("user-entered-custom-secret".to_string()), + "non-builtin secret must be kept even when the exchange proxy is configured" + ); + } + + #[test] + fn test_proxy_client_secret_kept_without_exchange_proxy_even_for_builtin_secret() { + let builtin = crate::cli::oauth_defaults::builtin_credentials("google_oauth_token"); + let builtin_ref = builtin.as_ref(); + let secret = Some(builtin_ref.unwrap().client_secret.to_string()); + + let result = hosted_proxy_client_secret(&secret, builtin_ref, false); + assert_eq!( + result, secret, + "built-in secret must be kept when the callback will exchange directly" + ); + } + + #[test] + fn test_proxy_client_secret_none_stays_none() { + let builtin = crate::cli::oauth_defaults::builtin_credentials("google_oauth_token"); + + let result = hosted_proxy_client_secret(&None, builtin.as_ref(), true); + assert_eq!( + result, None, + "None secret stays None even when the exchange proxy is configured" + ); + } + + #[test] + fn test_proxy_client_secret_no_builtin_provider() { + // MCP/non-Google providers have no builtin credentials + let builtin = crate::cli::oauth_defaults::builtin_credentials("mcp_notion_access_token"); + assert!(builtin.is_none()); + + let secret = Some("dcr-secret".to_string()); + let result = hosted_proxy_client_secret(&secret, builtin.as_ref(), true); + assert_eq!( + result, + Some("dcr-secret".to_string()), + "non-builtin provider secret must be kept" + ); + } } diff --git a/src/llm/oauth_helpers.rs b/src/llm/oauth_helpers.rs index 551fc04b6a..b63457fdfd 100644 --- a/src/llm/oauth_helpers.rs +++ b/src/llm/oauth_helpers.rs @@ -39,9 +39,7 @@ pub enum OAuthCallbackError { /// deployments where `127.0.0.1` is unreachable from the user's browser), /// then falls back to `http://{callback_host()}:{OAUTH_CALLBACK_PORT}`. pub fn callback_url() -> String { - std::env::var("IRONCLAW_OAUTH_CALLBACK_URL") - .ok() - .filter(|v| !v.is_empty()) + crate::config::helpers::env_or_override("IRONCLAW_OAUTH_CALLBACK_URL") .unwrap_or_else(|| format!("http://{}:{}", callback_host(), OAUTH_CALLBACK_PORT)) } @@ -57,7 +55,8 @@ pub fn callback_url() -> String { /// Note: this transmits the session token over plain HTTP — prefer SSH port /// forwarding (`ssh -L 9876:127.0.0.1:9876 user@host`) when possible. pub fn callback_host() -> String { - std::env::var("OAUTH_CALLBACK_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()) + crate::config::helpers::env_or_override("OAUTH_CALLBACK_HOST") + .unwrap_or_else(|| "127.0.0.1".to_string()) } /// Returns `true` if `host` is a loopback address that only accepts local connections. diff --git a/src/orchestrator/api.rs b/src/orchestrator/api.rs index b46aa8c68b..8d77c581ec 100644 --- a/src/orchestrator/api.rs +++ b/src/orchestrator/api.rs @@ -333,6 +333,12 @@ async fn job_event_handler( .get("session_id") .and_then(|v| v.as_str()) .map(|s| s.to_string()), + // NOTE: `fallback_deliverable` is currently always None in SSE events. + // In-memory jobs store fallback data in JobContext.metadata (accessed via job_status tool). + // Sandbox containers don't yet emit fallback data in their event payloads. + // This field is forward-compatible infrastructure for when container workers + // gain context/memory tracking capabilities. + fallback_deliverable: payload.data.get("fallback_deliverable").cloned(), }, _ => SseEvent::JobStatus { job_id: job_id_str, diff --git a/src/tools/builtin/http.rs b/src/tools/builtin/http.rs index 9d7af888da..0bd8eb37cb 100644 --- a/src/tools/builtin/http.rs +++ b/src/tools/builtin/http.rs @@ -837,7 +837,7 @@ impl Tool for HttpTool { })); if has_credentials { - return ApprovalRequirement::Always; + return ApprovalRequirement::UnlessAutoApproved; } // GET requests (or missing method, since GET is the default) are low-risk @@ -1093,25 +1093,31 @@ mod tests { } #[test] - fn test_auth_header_object_format_returns_always() { + fn test_auth_header_object_format_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data", "headers": {"Authorization": "Bearer token123"} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] - fn test_auth_header_array_format_returns_always() { + fn test_auth_header_array_format_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data", "headers": [{"name": "Authorization", "value": "Bearer token123"}] }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1124,7 +1130,10 @@ mod tests { "url": "https://example.com", "headers": {"AUTHORIZATION": "Bearer x"} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); // Array format with mixed case let params = serde_json::json!({ @@ -1132,7 +1141,10 @@ mod tests { "url": "https://example.com", "headers": [{"name": "X-Api-Key", "value": "key123"}] }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1161,8 +1173,8 @@ mod tests { }); assert_eq!( tool.requires_approval(¶ms), - ApprovalRequirement::Always, - "Header '{}' should trigger Always approval", + ApprovalRequirement::UnlessAutoApproved, + "Header '{}' should trigger UnlessAutoApproved approval", header_name ); } @@ -1203,7 +1215,7 @@ mod tests { // ── Credential registry approval tests ───────────────────────────── #[test] - fn test_host_with_credential_mapping_returns_always() { + fn test_host_with_credential_mapping_returns_unless_auto_approved() { use crate::secrets::CredentialMapping; use crate::tools::wasm::SharedCredentialRegistry; @@ -1223,7 +1235,10 @@ mod tests { "method": "GET", "url": "https://api.openai.com/v1/models" }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] @@ -1243,24 +1258,55 @@ mod tests { } #[test] - fn test_url_query_param_credential_returns_always() { + fn test_url_query_param_credential_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://api.example.com/data?api_key=secret123" }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); } #[test] - fn test_bearer_value_in_custom_header_returns_always() { + fn test_bearer_value_in_custom_header_returns_unless_auto_approved() { let tool = HttpTool::new(); let params = serde_json::json!({ "method": "GET", "url": "https://example.com", "headers": {"X-Custom": format!("Bearer {TEST_OPENAI_API_KEY}")} }); - assert_eq!(tool.requires_approval(¶ms), ApprovalRequirement::Always); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved + ); + } + + /// Regression test: credentialed HTTP requests must return + /// `UnlessAutoApproved` (not `Always`) so that the session auto-approve + /// set is respected when the user says "always". + #[test] + fn test_credentialed_requests_respect_auto_approve() { + let tool = HttpTool::new(); + + // Manual credentials (Authorization header) + let params = serde_json::json!({ + "method": "GET", + "url": "https://api.github.com/orgs/Casa", + "headers": {"Authorization": "Bearer ghp_abc123"} + }); + // Must NOT be Always — Always ignores the session auto-approve set + assert_ne!( + tool.requires_approval(¶ms), + ApprovalRequirement::Always, + "Credentialed HTTP requests must not return Always; use UnlessAutoApproved" + ); + assert_eq!( + tool.requires_approval(¶ms), + ApprovalRequirement::UnlessAutoApproved, + ); } #[test] diff --git a/src/tools/builtin/job.rs b/src/tools/builtin/job.rs index 9346d14ab1..ea7e53054d 100644 --- a/src/tools/builtin/job.rs +++ b/src/tools/builtin/job.rs @@ -1005,7 +1005,8 @@ impl Tool for JobStatusTool { "created_at": job_ctx.created_at.to_rfc3339(), "started_at": job_ctx.started_at.map(|t| t.to_rfc3339()), "completed_at": job_ctx.completed_at.map(|t| t.to_rfc3339()), - "actual_cost": job_ctx.actual_cost.to_string() + "actual_cost": job_ctx.actual_cost.to_string(), + "fallback_deliverable": job_ctx.metadata.get("fallback_deliverable"), }); Ok(ToolOutput::success(result, start.elapsed())) } @@ -1384,7 +1385,7 @@ mod tests { let tool = CreateJobTool::new(manager.clone()); // Without sandbox deps, it should use the local path - assert!(!tool.sandbox_enabled()); + assert!(!tool.sandbox_enabled()); // safety: test let params = serde_json::json!({ "title": "Test Job", @@ -1392,12 +1393,13 @@ mod tests { }); let ctx = JobContext::default(); - let result = tool.execute(params, &ctx).await.unwrap(); + let result = tool.execute(params, &ctx).await.unwrap(); // safety: test - let job_id = result.result.get("job_id").unwrap().as_str().unwrap(); - assert!(!job_id.is_empty()); + let job_id = result.result.get("job_id").unwrap().as_str().unwrap(); // safety: test + assert!(!job_id.is_empty()); // safety: test assert_eq!( - result.result.get("status").unwrap().as_str().unwrap(), + /* safety: test */ + result.result.get("status").unwrap().as_str().unwrap(), // safety: test "pending" ); } @@ -1409,11 +1411,11 @@ mod tests { // Without sandbox let tool = CreateJobTool::new(Arc::clone(&manager)); let schema = tool.parameters_schema(); - let props = schema.get("properties").unwrap().as_object().unwrap(); - assert!(props.contains_key("title")); - assert!(props.contains_key("description")); - assert!(!props.contains_key("wait")); - assert!(!props.contains_key("mode")); + let props = schema.get("properties").unwrap().as_object().unwrap(); // safety: test + assert!(props.contains_key("title")); // safety: test + assert!(props.contains_key("description")); // safety: test + assert!(!props.contains_key("wait")); // safety: test + assert!(!props.contains_key("mode")); // safety: test } #[test] @@ -1422,7 +1424,7 @@ mod tests { // Without sandbox: default timeout let tool = CreateJobTool::new(Arc::clone(&manager)); - assert_eq!(tool.execution_timeout(), Duration::from_secs(30)); + assert_eq!(tool.execution_timeout(), Duration::from_secs(30)); // safety: test } #[tokio::test] @@ -1455,23 +1457,23 @@ mod tests { let manager = Arc::new(ContextManager::new(5)); // Create some jobs - manager.create_job("Job 1", "Desc 1").await.unwrap(); - manager.create_job("Job 2", "Desc 2").await.unwrap(); + manager.create_job("Job 1", "Desc 1").await.unwrap(); // safety: test + manager.create_job("Job 2", "Desc 2").await.unwrap(); // safety: test let tool = ListJobsTool::new(manager); let params = serde_json::json!({}); let ctx = JobContext::default(); - let result = tool.execute(params, &ctx).await.unwrap(); + let result = tool.execute(params, &ctx).await.unwrap(); // safety: test - let jobs = result.result.get("jobs").unwrap().as_array().unwrap(); - assert_eq!(jobs.len(), 2); + let jobs = result.result.get("jobs").unwrap().as_array().unwrap(); // safety: test + assert_eq!(jobs.len(), 2); // safety: test } #[tokio::test] async fn test_job_status_tool() { let manager = Arc::new(ContextManager::new(5)); - let job_id = manager.create_job("Test Job", "Description").await.unwrap(); + let job_id = manager.create_job("Test Job", "Description").await.unwrap(); // safety: test let tool = JobStatusTool::new(manager); @@ -1479,10 +1481,11 @@ mod tests { "job_id": job_id.to_string() }); let ctx = JobContext::default(); - let result = tool.execute(params, &ctx).await.unwrap(); + let result = tool.execute(params, &ctx).await.unwrap(); // safety: test assert_eq!( - result.result.get("title").unwrap().as_str().unwrap(), + /* safety: test */ + result.result.get("title").unwrap().as_str().unwrap(), // safety: test "Test Job" ); } @@ -1496,8 +1499,9 @@ mod tests { let missing_title = tool .execute(serde_json::json!({ "description": "A test job" }), &ctx) .await; - assert!(missing_title.is_err()); + assert!(missing_title.is_err()); // safety: test assert!( + /* safety: test */ missing_title .unwrap_err() .to_string() @@ -1507,8 +1511,9 @@ mod tests { let missing_description = tool .execute(serde_json::json!({ "title": "Test Job" }), &ctx) .await; - assert!(missing_description.is_err()); + assert!(missing_description.is_err()); // safety: test assert!( + /* safety: test */ missing_description .unwrap_err() .to_string() @@ -1522,19 +1527,19 @@ mod tests { let pending_id = manager .create_job_for_user("default", "Pending Job", "Todo") .await - .unwrap(); + .unwrap(); // safety: test let completed_id = manager .create_job_for_user("default", "Completed Job", "Done") .await - .unwrap(); + .unwrap(); // safety: test let failed_id = manager .create_job_for_user("default", "Failed Job", "Oops") .await - .unwrap(); + .unwrap(); // safety: test manager .create_job_for_user("other-user", "Other User Job", "Ignore") .await - .unwrap(); + .unwrap(); // safety: test manager .update_context(completed_id, |ctx| { @@ -1542,41 +1547,44 @@ mod tests { ctx.transition_to(JobState::Completed, Some("done".to_string())) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test manager .update_context(failed_id, |ctx| { ctx.transition_to(JobState::InProgress, None)?; ctx.transition_to(JobState::Failed, Some("boom".to_string())) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test let tool = ListJobsTool::new(Arc::clone(&manager)); let ctx = JobContext::default(); - let result = tool.execute(serde_json::json!({}), &ctx).await.unwrap(); + let result = tool.execute(serde_json::json!({}), &ctx).await.unwrap(); // safety: test - let jobs = result.result.get("jobs").unwrap().as_array().unwrap(); - assert_eq!(jobs.len(), 3); + let jobs = result.result.get("jobs").unwrap().as_array().unwrap(); // safety: test + assert_eq!(jobs.len(), 3); // safety: test assert!(jobs.iter().any(|job| { + // safety: test job.get("job_id").and_then(|v| v.as_str()) == Some(&pending_id.to_string()) && job.get("status").and_then(|v| v.as_str()) == Some("Pending") })); assert!(jobs.iter().any(|job| { + // safety: test job.get("job_id").and_then(|v| v.as_str()) == Some(&completed_id.to_string()) && job.get("status").and_then(|v| v.as_str()) == Some("Completed") })); assert!(jobs.iter().any(|job| { + // safety: test job.get("job_id").and_then(|v| v.as_str()) == Some(&failed_id.to_string()) && job.get("status").and_then(|v| v.as_str()) == Some("Failed") })); - let summary = result.result.get("summary").unwrap(); - assert_eq!(summary.get("total").and_then(|v| v.as_u64()), Some(3)); - assert_eq!(summary.get("pending").and_then(|v| v.as_u64()), Some(1)); - assert_eq!(summary.get("completed").and_then(|v| v.as_u64()), Some(1)); - assert_eq!(summary.get("failed").and_then(|v| v.as_u64()), Some(1)); + let summary = result.result.get("summary").unwrap(); // safety: test + assert_eq!(summary.get("total").and_then(|v| v.as_u64()), Some(3)); // safety: test + assert_eq!(summary.get("pending").and_then(|v| v.as_u64()), Some(1)); // safety: test + assert_eq!(summary.get("completed").and_then(|v| v.as_u64()), Some(1)); // safety: test + assert_eq!(summary.get("failed").and_then(|v| v.as_u64()), Some(1)); // safety: test } #[tokio::test] @@ -1585,29 +1593,30 @@ mod tests { let job_id = manager .create_job_for_user("default", "Transition Job", "Track me") .await - .unwrap(); + .unwrap(); // safety: test manager .update_context(job_id, |ctx| { ctx.transition_to(JobState::InProgress, Some("started".to_string()))?; ctx.transition_to(JobState::Completed, Some("finished".to_string())) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test let tool = JobStatusTool::new(Arc::clone(&manager)); let ctx = JobContext::default(); let result = tool .execute(serde_json::json!({ "job_id": job_id.to_string() }), &ctx) .await - .unwrap(); + .unwrap(); // safety: test assert_eq!( + /* safety: test */ result.result.get("status").and_then(|v| v.as_str()), Some("Completed") ); - assert!(result.result.get("started_at").unwrap().is_string()); - assert!(result.result.get("completed_at").unwrap().is_string()); + assert!(result.result.get("started_at").unwrap().is_string()); // safety: test + assert!(result.result.get("completed_at").unwrap().is_string()); // safety: test } #[tokio::test] @@ -1616,26 +1625,27 @@ mod tests { let job_id = manager .create_job_for_user("default", "Running Job", "In progress") .await - .unwrap(); + .unwrap(); // safety: test manager .update_context(job_id, |ctx| ctx.transition_to(JobState::InProgress, None)) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test let tool = CancelJobTool::new(Arc::clone(&manager)); let ctx = JobContext::default(); let result = tool .execute(serde_json::json!({ "job_id": job_id.to_string() }), &ctx) .await - .unwrap(); + .unwrap(); // safety: test assert_eq!( + /* safety: test */ result.result.get("status").and_then(|v| v.as_str()), Some("cancelled") ); - let updated = manager.get_context(job_id).await.unwrap(); - assert_eq!(updated.state, JobState::Cancelled); + let updated = manager.get_context(job_id).await.unwrap(); // safety: test + assert_eq!(updated.state, JobState::Cancelled); // safety: test } #[tokio::test] @@ -1644,39 +1654,81 @@ mod tests { let job_id = manager .create_job_for_user("default", "Completed Job", "Already done") .await - .unwrap(); + .unwrap(); // safety: test manager .update_context(job_id, |ctx| { ctx.transition_to(JobState::InProgress, None)?; ctx.transition_to(JobState::Completed, Some("done".to_string())) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test let tool = CancelJobTool::new(Arc::clone(&manager)); let ctx = JobContext::default(); let result = tool .execute(serde_json::json!({ "job_id": job_id.to_string() }), &ctx) .await - .unwrap(); + .unwrap(); // safety: test + + let error = result.result.get("error").and_then(|v| v.as_str()).unwrap(); // safety: test + assert!(error.contains("Cannot cancel job")); // safety: test + assert!(error.contains("completed")); // safety: test + } + + #[tokio::test] + async fn test_job_status_includes_fallback_deliverable() { + let manager = Arc::new(ContextManager::new(5)); + let job_id = manager + .create_job_for_user("default", "Failing Job", "Will fail") + .await + .unwrap(); // safety: test + + // Inject a real FallbackDeliverable into the job metadata. + let fallback = serde_json::json!({ + "partial": true, + "failure_reason": "max iterations", + "last_action": null, + "action_stats": { "total": 5, "successful": 3, "failed": 2 }, + "tokens_used": 1000, + "cost": "0.05", + "elapsed_secs": 12.5, + "repair_attempts": 1, + }); + manager + .update_context(job_id, |ctx| { + ctx.metadata = serde_json::json!({ "fallback_deliverable": fallback.clone() }); + Ok::<(), String>(()) + }) + .await + .unwrap() // safety: test + .unwrap(); // safety: test + + let tool = JobStatusTool::new(manager); + let params = serde_json::json!({ "job_id": job_id.to_string() }); + let ctx = JobContext::default(); + let result = tool.execute(params, &ctx).await.unwrap(); // safety: test - let error = result.result.get("error").and_then(|v| v.as_str()).unwrap(); - assert!(error.contains("Cannot cancel job")); - assert!(error.contains("completed")); + let fb = result.result.get("fallback_deliverable").unwrap(); // safety: test + assert_eq!(fb.get("partial").unwrap(), true); // safety: test + assert_eq!(fb.get("failure_reason").unwrap(), "max iterations"); // safety: test + let stats = fb.get("action_stats").unwrap(); // safety: test + assert_eq!(stats.get("total").unwrap(), 5); // safety: test + assert_eq!(stats.get("successful").unwrap(), 3); // safety: test + assert_eq!(stats.get("failed").unwrap(), 2); // safety: test } #[test] fn test_resolve_project_dir_auto() { let project_id = Uuid::new_v4(); - let (dir, browse_id) = resolve_project_dir(None, project_id).unwrap(); - assert!(dir.exists()); - assert!(dir.ends_with(project_id.to_string())); - assert_eq!(browse_id, project_id.to_string()); + let (dir, browse_id) = resolve_project_dir(None, project_id).unwrap(); // safety: test + assert!(dir.exists()); // safety: test + assert!(dir.ends_with(project_id.to_string())); // safety: test + assert_eq!(browse_id, project_id.to_string()); // safety: test // Must be under the projects base - let base = projects_base().canonicalize().unwrap(); - assert!(dir.starts_with(&base)); + let base = projects_base().canonicalize().unwrap(); // safety: test + assert!(dir.starts_with(&base)); // safety: test let _ = std::fs::remove_dir_all(&dir); } @@ -1684,33 +1736,34 @@ mod tests { #[test] fn test_resolve_project_dir_explicit_under_base() { let base = projects_base(); - std::fs::create_dir_all(&base).unwrap(); + std::fs::create_dir_all(&base).unwrap(); // safety: test let explicit = base.join("test_explicit_project"); // Explicit paths must already exist (no auto-create). - std::fs::create_dir_all(&explicit).unwrap(); + std::fs::create_dir_all(&explicit).unwrap(); // safety: test let project_id = Uuid::new_v4(); - let (dir, browse_id) = resolve_project_dir(Some(explicit.clone()), project_id).unwrap(); - assert!(dir.exists()); - assert_eq!(browse_id, "test_explicit_project"); + let (dir, browse_id) = resolve_project_dir(Some(explicit.clone()), project_id).unwrap(); // safety: test + assert!(dir.exists()); // safety: test + assert_eq!(browse_id, "test_explicit_project"); // safety: test - let canonical_base = base.canonicalize().unwrap(); - assert!(dir.starts_with(&canonical_base)); + let canonical_base = base.canonicalize().unwrap(); // safety: test + assert!(dir.starts_with(&canonical_base)); // safety: test let _ = std::fs::remove_dir_all(&explicit); } #[test] fn test_resolve_project_dir_rejects_outside_base() { - let tmp = tempfile::tempdir().unwrap(); + let tmp = tempfile::tempdir().unwrap(); // safety: test let escape_attempt = tmp.path().join("evil_project"); // Don't create it: explicit paths that don't exist are rejected // before the prefix check even runs. let result = resolve_project_dir(Some(escape_attempt), Uuid::new_v4()); - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("does not exist"), "expected 'does not exist' error, got: {}", err @@ -1720,13 +1773,14 @@ mod tests { #[test] fn test_resolve_project_dir_rejects_outside_base_existing() { // A directory that exists but is outside the projects base. - let tmp = tempfile::tempdir().unwrap(); + let tmp = tempfile::tempdir().unwrap(); // safety: test let outside = tmp.path().to_path_buf(); let result = resolve_project_dir(Some(outside), Uuid::new_v4()); - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("must be under"), "expected 'must be under' error, got: {}", err @@ -1740,7 +1794,7 @@ mod tests { let traversal = base.join("legit").join("..").join("..").join(".ssh"); let result = resolve_project_dir(Some(traversal), Uuid::new_v4()); - assert!(result.is_err(), "traversal path should be rejected"); + assert!(result.is_err(), "traversal path should be rejected"); // safety: test // Traversal path that actually resolves gets the prefix check. // `base/../` resolves to the parent of projects base, which is outside. @@ -1748,7 +1802,7 @@ mod tests { std::fs::create_dir_all(&base_parent).ok(); if base_parent.exists() { let result = resolve_project_dir(Some(base_parent.clone()), Uuid::new_v4()); - assert!(result.is_err(), "path outside base should be rejected"); + assert!(result.is_err(), "path outside base should be rejected"); // safety: test let _ = std::fs::remove_dir_all(&base_parent); } } @@ -1762,8 +1816,9 @@ mod tests { )); let tool = CreateJobTool::new(manager).with_sandbox(jm, None); let schema = tool.parameters_schema(); - let props = schema.get("properties").unwrap().as_object().unwrap(); + let props = schema.get("properties").unwrap().as_object().unwrap(); // safety: test assert!( + /* safety: test */ props.contains_key("project_dir"), "sandbox schema must expose project_dir" ); @@ -1778,8 +1833,9 @@ mod tests { )); let tool = CreateJobTool::new(manager).with_sandbox(jm, None); let schema = tool.parameters_schema(); - let props = schema.get("properties").unwrap().as_object().unwrap(); + let props = schema.get("properties").unwrap().as_object().unwrap(); // safety: test assert!( + /* safety: test */ props.contains_key("credentials"), "sandbox schema must expose credentials" ); @@ -1792,13 +1848,13 @@ mod tests { // No credentials parameter let params = serde_json::json!({"title": "t", "description": "d"}); - let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); - assert!(grants.is_empty()); + let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); // safety: test + assert!(grants.is_empty()); // safety: test // Empty credentials object let params = serde_json::json!({"credentials": {}}); - let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); - assert!(grants.is_empty()); + let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); // safety: test + assert!(grants.is_empty()); // safety: test } #[tokio::test] @@ -1808,9 +1864,10 @@ mod tests { let params = serde_json::json!({"credentials": {"my_secret": "MY_SECRET"}}); let result = tool.parse_credentials(¶ms, "user1").await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("no secrets store"), "expected 'no secrets store' error, got: {}", err @@ -1828,9 +1885,10 @@ mod tests { let params = serde_json::json!({"credentials": {"nonexistent_secret": "SOME_VAR"}}); let result = tool.parse_credentials(¶ms, "user1").await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("not found"), "expected 'not found' error, got: {}", err @@ -1852,17 +1910,17 @@ mod tests { CreateSecretParams::new("github_token", TEST_GITHUB_TOKEN), ) .await - .unwrap(); + .unwrap(); // safety: test let tool = CreateJobTool::new(manager).with_secrets(Arc::clone(&secrets)); let params = serde_json::json!({ "credentials": {"github_token": "GITHUB_TOKEN"} }); - let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); - assert_eq!(grants.len(), 1); - assert_eq!(grants[0].secret_name, "github_token"); - assert_eq!(grants[0].env_var, "GITHUB_TOKEN"); + let grants = tool.parse_credentials(¶ms, "user1").await.unwrap(); // safety: test + assert_eq!(grants.len(), 1); // safety: test + assert_eq!(grants[0].secret_name, "github_token"); // safety: test + assert_eq!(grants[0].env_var, "GITHUB_TOKEN"); // safety: test } fn test_prompt_tool(queue: PromptQueue) -> JobPromptTool { @@ -1876,7 +1934,7 @@ mod tests { let job_id = cm .create_job_for_user("default", "Test Job", "desc") .await - .unwrap(); + .unwrap(); // safety: test let queue: PromptQueue = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); @@ -1889,18 +1947,19 @@ mod tests { }); let ctx = JobContext::default(); - let result = tool.execute(params, &ctx).await.unwrap(); + let result = tool.execute(params, &ctx).await.unwrap(); // safety: test assert_eq!( - result.result.get("status").unwrap().as_str().unwrap(), + /* safety: test */ + result.result.get("status").unwrap().as_str().unwrap(), // safety: test "queued" ); let q = queue.lock().await; - let prompts = q.get(&job_id).unwrap(); - assert_eq!(prompts.len(), 1); - assert_eq!(prompts[0].content, "What's the status?"); - assert!(!prompts[0].done); + let prompts = q.get(&job_id).unwrap(); // safety: test + assert_eq!(prompts.len(), 1); // safety: test + assert_eq!(prompts[0].content, "What's the status?"); // safety: test + assert!(!prompts[0].done); // safety: test } #[tokio::test] @@ -1910,6 +1969,7 @@ mod tests { Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); let tool = test_prompt_tool(queue); assert_eq!( + /* safety: test */ tool.requires_approval(&serde_json::json!({})), ApprovalRequirement::UnlessAutoApproved ); @@ -1928,7 +1988,7 @@ mod tests { let ctx = JobContext::default(); let result = tool.execute(params, &ctx).await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test } #[tokio::test] @@ -1943,7 +2003,7 @@ mod tests { let ctx = JobContext::default(); let result = tool.execute(params, &ctx).await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test } #[tokio::test] @@ -1958,7 +2018,7 @@ mod tests { let job_id = cm .create_job_for_user("owner-user", "Secret Job", "classified") .await - .unwrap(); + .unwrap(); // safety: test // We need a Store to construct the tool, but creating one requires // a database URL. Instead, test the ownership logic directly: @@ -1968,9 +2028,9 @@ mod tests { ..Default::default() }; - let job_ctx = cm.get_context(job_id).await.unwrap(); - assert_ne!(job_ctx.user_id, attacker_ctx.user_id); - assert_eq!(job_ctx.user_id, "owner-user"); + let job_ctx = cm.get_context(job_id).await.unwrap(); // safety: test + assert_ne!(job_ctx.user_id, attacker_ctx.user_id); // safety: test + assert_eq!(job_ctx.user_id, "owner-user"); // safety: test } #[test] @@ -1991,12 +2051,12 @@ mod tests { "required": ["job_id"] }); - let props = schema.get("properties").unwrap().as_object().unwrap(); - assert!(props.contains_key("job_id")); - assert!(props.contains_key("limit")); - let required = schema.get("required").unwrap().as_array().unwrap(); - assert_eq!(required.len(), 1); - assert_eq!(required[0].as_str().unwrap(), "job_id"); + let props = schema.get("properties").unwrap().as_object().unwrap(); // safety: test + assert!(props.contains_key("job_id")); // safety: test + assert!(props.contains_key("limit")); // safety: test + let required = schema.get("required").unwrap().as_array().unwrap(); // safety: test + assert_eq!(required.len(), 1); // safety: test + assert_eq!(required[0].as_str().unwrap(), "job_id"); // safety: test } #[tokio::test] @@ -2005,7 +2065,7 @@ mod tests { let job_id = cm .create_job_for_user("owner-user", "Test Job", "desc") .await - .unwrap(); + .unwrap(); // safety: test let queue: PromptQueue = Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())); @@ -2023,9 +2083,10 @@ mod tests { }; let result = tool.execute(params, &ctx).await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("does not belong to current user"), "expected ownership error, got: {}", err @@ -2035,33 +2096,34 @@ mod tests { #[tokio::test] async fn test_resolve_job_id_full_uuid() { let cm = ContextManager::new(5); - let job_id = cm.create_job("Test", "Desc").await.unwrap(); + let job_id = cm.create_job("Test", "Desc").await.unwrap(); // safety: test - let resolved = resolve_job_id(&job_id.to_string(), &cm).await.unwrap(); - assert_eq!(resolved, job_id); + let resolved = resolve_job_id(&job_id.to_string(), &cm).await.unwrap(); // safety: test + assert_eq!(resolved, job_id); // safety: test } #[tokio::test] async fn test_resolve_job_id_short_prefix() { let cm = ContextManager::new(5); - let job_id = cm.create_job("Test", "Desc").await.unwrap(); + let job_id = cm.create_job("Test", "Desc").await.unwrap(); // safety: test // Use first 8 hex chars (without dashes) let hex = job_id.to_string().replace('-', ""); let prefix = &hex[..8]; - let resolved = resolve_job_id(prefix, &cm).await.unwrap(); - assert_eq!(resolved, job_id); + let resolved = resolve_job_id(prefix, &cm).await.unwrap(); // safety: test + assert_eq!(resolved, job_id); // safety: test } #[tokio::test] async fn test_resolve_job_id_no_match() { let cm = ContextManager::new(5); - cm.create_job("Test", "Desc").await.unwrap(); + cm.create_job("Test", "Desc").await.unwrap(); // safety: test let result = resolve_job_id("00000000", &cm).await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test let err = result.unwrap_err().to_string(); assert!( + /* safety: test */ err.contains("no job found"), "expected 'no job found', got: {}", err @@ -2072,6 +2134,6 @@ mod tests { async fn test_resolve_job_id_invalid_input() { let cm = ContextManager::new(5); let result = resolve_job_id("not-hex-at-all!", &cm).await; - assert!(result.is_err()); + assert!(result.is_err()); // safety: test } } diff --git a/src/worker/job.rs b/src/worker/job.rs index 0f0e969ee7..87b9cfeb9f 100644 --- a/src/worker/job.rs +++ b/src/worker/job.rs @@ -196,6 +196,7 @@ impl Worker { .get("session_id") .and_then(|v| v.as_str()) .map(|s| s.to_string()), + fallback_deliverable: data.get("fallback_deliverable").cloned(), }), _ => None, }; @@ -960,9 +961,14 @@ Report when the job is complete or if you encounter issues you cannot resolve."# } async fn mark_failed(&self, reason: &str) -> Result<(), Error> { + // Build fallback deliverable from memory before transitioning. + let fallback = self.build_fallback(reason).await; + self.context_manager() .update_context(self.job_id, |ctx| { - ctx.transition_to(JobState::Failed, Some(reason.to_string())) + ctx.transition_to(JobState::Failed, Some(reason.to_string()))?; + store_fallback_in_metadata(ctx, fallback.as_ref()); + Ok(()) }) .await? .map_err(|s| crate::error::JobError::ContextError { @@ -983,8 +989,15 @@ Report when the job is complete or if you encounter issues you cannot resolve."# } async fn mark_stuck(&self, reason: &str) -> Result<(), Error> { + // Build fallback deliverable from memory before transitioning. + let fallback = self.build_fallback(reason).await; + self.context_manager() - .update_context(self.job_id, |ctx| ctx.mark_stuck(reason)) + .update_context(self.job_id, |ctx| { + ctx.mark_stuck(reason)?; + store_fallback_in_metadata(ctx, fallback.as_ref()); + Ok(()) + }) .await? .map_err(|s| crate::error::JobError::ContextError { id: self.job_id, @@ -1002,6 +1015,57 @@ Report when the job is complete or if you encounter issues you cannot resolve."# self.persist_status(JobState::Stuck, Some(reason.to_string())); Ok(()) } + + /// Build a [`FallbackDeliverable`] from the current job context and memory. + async fn build_fallback(&self, reason: &str) -> Option { + let memory = match self.context_manager().get_memory(self.job_id).await { + Ok(memory) => memory, + Err(e) => { + tracing::warn!( + job_id = %self.job_id, + "Failed to load memory while building fallback deliverable: {e}" + ); + return None; + } + }; + let ctx = match self.context_manager().get_context(self.job_id).await { + Ok(ctx) => ctx, + Err(e) => { + tracing::warn!( + job_id = %self.job_id, + "Failed to load context while building fallback deliverable: {e}" + ); + return None; + } + }; + Some(crate::context::FallbackDeliverable::build( + &ctx, &memory, reason, + )) + } +} + +/// Store a fallback deliverable in the job context's metadata. +fn store_fallback_in_metadata( + ctx: &mut crate::context::JobContext, + fallback: Option<&crate::context::FallbackDeliverable>, +) { + let Some(fb) = fallback else { + return; + }; + match serde_json::to_value(fb) { + Ok(val) => { + if !ctx.metadata.is_object() { + ctx.metadata = serde_json::json!({}); + } + ctx.metadata["fallback_deliverable"] = val; + } + Err(e) => { + tracing::warn!( + "Failed to serialize fallback deliverable for job {}: {e}", + ctx.job_id + ); + } + } } /// Job delegate: implements `LoopDelegate` for the background job context. @@ -1440,7 +1504,7 @@ mod tests { } let cm = Arc::new(crate::context::ContextManager::new(5)); - let job_id = cm.create_job("test", "test job").await.unwrap(); + let job_id = cm.create_job("test", "test job").await.unwrap(); // safety: test let deps = WorkerDeps { context_manager: cm, @@ -1472,8 +1536,9 @@ mod tests { tool_call_id: "call_abc123".to_string(), }; - assert_eq!(selection.tool_call_id, "call_abc123"); + assert_eq!(selection.tool_call_id, "call_abc123"); // safety: test assert_ne!( + /* safety: test */ selection.tool_call_id, "tool_call_id", "tool_call_id must not be the hardcoded placeholder string" ); @@ -1509,11 +1574,12 @@ mod tests { let results = worker.execute_tools_parallel(&selections).await; let elapsed = start.elapsed(); - assert_eq!(results.len(), 3); + assert_eq!(results.len(), 3); // safety: test for r in &results { - assert!(r.result.is_ok(), "Tool should succeed"); + assert!(r.result.is_ok(), "Tool should succeed"); // safety: test } assert!( + /* safety: test */ elapsed < Duration::from_millis(800), "Parallel execution took {:?}, expected < 800ms (sequential would be ~600ms)", elapsed @@ -1565,9 +1631,9 @@ mod tests { let results = worker.execute_tools_parallel(&selections).await; - assert!(results[0].result.as_ref().unwrap().contains("done_tool_a")); - assert!(results[1].result.as_ref().unwrap().contains("done_tool_b")); - assert!(results[2].result.as_ref().unwrap().contains("done_tool_c")); + assert!(results[0].result.as_ref().unwrap().contains("done_tool_a")); // safety: test + assert!(results[1].result.as_ref().unwrap().contains("done_tool_b")); // safety: test + assert!(results[2].result.as_ref().unwrap().contains("done_tool_c")); // safety: test } #[tokio::test] @@ -1583,8 +1649,9 @@ mod tests { }]; let results = worker.execute_tools_parallel(&selections).await; - assert_eq!(results.len(), 1); + assert_eq!(results.len(), 1); // safety: test assert!( + /* safety: test */ results[0].result.is_err(), "Missing tool should produce an error, not a panic" ); @@ -1600,23 +1667,24 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test - worker.mark_completed().await.unwrap(); + worker.mark_completed().await.unwrap(); // safety: test let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); - assert_eq!(ctx.state, JobState::Completed); + .unwrap(); // safety: test + assert_eq!(ctx.state, JobState::Completed); // safety: test // Second mark_completed should succeed (idempotent) rather than // erroring, matching the fix for the execution_loop / worker wrapper // race condition. let result = worker.mark_completed().await; assert!( + /* safety: test */ result.is_ok(), "Completed -> Completed transition should be idempotent" ); @@ -1641,7 +1709,7 @@ mod tests { } let cm = Arc::new(crate::context::ContextManager::new(5)); - let job_id = cm.create_job("test", "test job").await.unwrap(); + let job_id = cm.create_job("test", "test job").await.unwrap(); // safety: test let deps = WorkerDeps { context_manager: cm, @@ -1740,6 +1808,7 @@ mod tests { .execute_tool("needs_approval", &serde_json::json!({})) .await; assert!( + /* safety: test */ result.is_err(), "Should be blocked without approval context" ); @@ -1752,7 +1821,7 @@ mod tests { let result = worker_allowed .execute_tool("needs_approval", &serde_json::json!({})) .await; - assert!(result.is_ok(), "Should be allowed with autonomous context"); + assert!(result.is_ok(), "Should be allowed with autonomous context"); // safety: test } #[tokio::test] @@ -1766,6 +1835,7 @@ mod tests { .execute_tool("always_approval", &serde_json::json!({})) .await; assert!( + /* safety: test */ result.is_err(), "Always tool should be blocked without permission" ); @@ -1781,6 +1851,7 @@ mod tests { .execute_tool("always_approval", &serde_json::json!({})) .await; assert!( + /* safety: test */ result.is_ok(), "Always tool should be allowed with permission" ); @@ -1797,8 +1868,8 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test // Set a token budget worker @@ -1807,16 +1878,17 @@ mod tests { ctx.max_tokens = 100; }) .await - .unwrap(); + .unwrap(); // safety: test // Simulate adding tokens that exceed the budget let budget_result = worker .context_manager() .update_context(worker.job_id, |ctx| ctx.add_tokens(200)) .await - .unwrap(); + .unwrap(); // safety: test assert!( + /* safety: test */ budget_result.is_err(), "Should return error when token budget exceeded" ); @@ -1825,13 +1897,13 @@ mod tests { worker .mark_failed(&budget_result.unwrap_err().to_string()) .await - .unwrap(); + .unwrap(); // safety: test let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); - assert_eq!(ctx.state, JobState::Failed); + .unwrap(); // safety: test + assert_eq!(ctx.state, JobState::Failed); // safety: test } #[tokio::test] @@ -1845,21 +1917,22 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .unwrap() // safety: test + .unwrap(); // safety: test // Simulate what the execution loop does when max_iterations is exceeded worker .mark_failed("Maximum iterations exceeded: job hit the iteration cap") .await - .unwrap(); + .unwrap(); // safety: test let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); + .unwrap(); // safety: test assert_eq!( + /* safety: test */ ctx.state, JobState::Failed, "Iteration cap should transition to Failed, not Stuck" @@ -1989,4 +2062,52 @@ mod tests { "Should skip empty first reasoning and return the first non-empty one" ); } + + #[test] + fn test_store_fallback_in_metadata_roundtrip() { + use crate::context::FallbackDeliverable; + + let mut ctx = JobContext::new("Test", "fallback roundtrip"); + let memory = crate::context::Memory::new(ctx.job_id); + let fb = FallbackDeliverable::build(&ctx, &memory, "test failure"); + + // Store into metadata + store_fallback_in_metadata(&mut ctx, Some(&fb)); + + // Verify it's stored and can be deserialized back + let stored = ctx.metadata.get("fallback_deliverable"); + assert!(stored.is_some(), "fallback missing from metadata"); // safety: test + + let recovered: FallbackDeliverable = + serde_json::from_value(stored.unwrap().clone()).expect("deserialize fallback"); // safety: test + assert_eq!(recovered.failure_reason, "test failure"); // safety: test + assert!(!recovered.partial); // safety: test + } + + #[test] + fn test_store_fallback_handles_non_object_metadata() { + use crate::context::FallbackDeliverable; + + let mut ctx = JobContext::new("Test", "non-object metadata"); + ctx.metadata = serde_json::json!("not an object"); + + let memory = crate::context::Memory::new(ctx.job_id); + let fb = FallbackDeliverable::build(&ctx, &memory, "failed"); + + store_fallback_in_metadata(&mut ctx, Some(&fb)); + + // Must normalize to object and store + assert!(ctx.metadata.is_object()); // safety: test + assert!(ctx.metadata.get("fallback_deliverable").is_some()); // safety: test + } + + #[test] + fn test_store_fallback_none_is_noop() { + let mut ctx = JobContext::new("Test", "noop"); + let original = ctx.metadata.clone(); + + store_fallback_in_metadata(&mut ctx, None); + + assert_eq!(ctx.metadata, original); // safety: test + } } diff --git a/src/workspace/README.md b/src/workspace/README.md index 2b3ee5b48b..db65294d42 100644 --- a/src/workspace/README.md +++ b/src/workspace/README.md @@ -38,12 +38,17 @@ workspace/ ## Using the Workspace ```rust +use std::sync::Arc; use crate::workspace::{Workspace, OpenAiEmbeddings, paths}; -// Create workspace for a user +// Create workspace for a user (wraps embeddings in a default LRU cache) let workspace = Workspace::new("user_123", pool) .with_embeddings(Arc::new(OpenAiEmbeddings::new(api_key))); +// For tests: skip the cache layer (avoids unnecessary overhead with mocks) +// let workspace = Workspace::new("user_123", pool) +// .with_embeddings_uncached(Arc::new(MockEmbeddings::new(1536))); + // Read/write any path let doc = workspace.read("projects/alpha/notes.md").await?; workspace.write("context/priorities.md", "# Priorities\n\n1. Feature X").await?; diff --git a/src/workspace/embedding_cache.rs b/src/workspace/embedding_cache.rs new file mode 100644 index 0000000000..848bd2e501 --- /dev/null +++ b/src/workspace/embedding_cache.rs @@ -0,0 +1,613 @@ +//! LRU embedding cache wrapping any [`EmbeddingProvider`]. +//! +//! Avoids redundant HTTP calls for identical texts by caching embeddings +//! in memory keyed by `SHA-256(model_name + "\0" + text)`. +//! +//! Follows the same cache pattern as `llm::response_cache::CachedProvider`: +//! `HashMap` + `last_accessed` tracking + manual LRU eviction. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use async_trait::async_trait; +use sha2::{Digest, Sha256}; + +use crate::workspace::embeddings::{EmbeddingError, EmbeddingProvider}; + +/// Configuration for the embedding cache. +#[derive(Debug, Clone)] +pub struct EmbeddingCacheConfig { + /// Maximum number of cached embeddings (default 10,000). + /// + /// Approximate raw embedding payload: `max_entries × dimension × 4 bytes`. + /// At 10,000 entries × 1536 floats ≈ 58 MB (payload only; actual memory + /// is higher due to HashMap buckets, `[u8; 32]` hash keys, `Vec`/`Instant` + /// per-entry overhead). + pub max_entries: usize, +} + +impl Default for EmbeddingCacheConfig { + fn default() -> Self { + Self { + max_entries: crate::config::DEFAULT_EMBEDDING_CACHE_SIZE, + } + } +} + +struct CacheEntry { + embedding: Vec, + last_accessed: Instant, +} + +/// Embedding provider wrapper that caches results in memory. +/// +/// Thread-safe via `std::sync::Mutex`. The lock is **never held** +/// across `.await` points (all critical sections are scoped blocks), +/// so a synchronous mutex is cheaper than `tokio::sync::Mutex`. +pub struct CachedEmbeddingProvider { + inner: Arc, + cache: Mutex>, + config: EmbeddingCacheConfig, +} + +impl CachedEmbeddingProvider { + /// Wrap a provider with LRU caching. + /// + /// `config.max_entries` is clamped to at least 1. + pub fn new(inner: Arc, config: EmbeddingCacheConfig) -> Self { + let config = EmbeddingCacheConfig { + max_entries: config.max_entries.max(1), + }; + if config.max_entries > 100_000 { + tracing::warn!( + max_entries = config.max_entries, + "Embedding cache size exceeds 100,000 entries; memory usage may be significant" + ); + } + Self { + inner, + cache: Mutex::new(HashMap::with_capacity(config.max_entries.min(1024))), + config, + } + } + + /// Number of entries currently in the cache. + pub fn len(&self) -> usize { + self.cache.lock().unwrap_or_else(|e| e.into_inner()).len() + } + + /// Whether the cache is empty. + pub fn is_empty(&self) -> bool { + self.cache + .lock() + .unwrap_or_else(|e| e.into_inner()) + .is_empty() + } + + /// Clear all cached entries. + pub fn clear(&self) { + self.cache.lock().unwrap_or_else(|e| e.into_inner()).clear(); + } + + /// Build a deterministic cache key: `SHA-256(model_name + "\0" + text)`. + /// + /// Returns raw 32-byte hash to avoid a 64-char hex String allocation per lookup. + fn cache_key(&self, text: &str) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(self.inner.model_name().as_bytes()); + hasher.update(b"\0"); + hasher.update(text.as_bytes()); + hasher.finalize().into() + } + + /// Evict the least-recently-used entry if at capacity (single-entry path). + // TODO: O(n) scan per eviction. If max_entries grows large, switch to + // an ordered data structure (e.g. `IndexMap` with swap_remove, or a + // linked-list LRU like the `lru` crate). + fn evict_lru(cache: &mut HashMap<[u8; 32], CacheEntry>, max_entries: usize) { + while cache.len() >= max_entries { + let oldest_key = cache + .iter() + .min_by_key(|(_, entry)| entry.last_accessed) + .map(|(k, _)| *k); + + if let Some(k) = oldest_key { + cache.remove(&k); + } else { + break; + } + } + } + + /// Evict the `k` oldest entries in O(n) average time via partial selection. + /// + /// Used by `embed_batch` to avoid the O(n×m) cost of calling + /// `evict_lru` per insert. + fn evict_k_oldest(cache: &mut HashMap<[u8; 32], CacheEntry>, k: usize) { + if k == 0 || cache.is_empty() { + return; + } + if k >= cache.len() { + cache.clear(); + return; + } + // Partial selection: find the k oldest in O(n) average via + // select_nth_unstable_by_key, then remove the first k entries. + let mut entries: Vec<([u8; 32], Instant)> = cache + .iter() + .map(|(key, entry)| (*key, entry.last_accessed)) + .collect(); + entries.select_nth_unstable_by_key(k - 1, |(_, t)| *t); + for (key, _) in entries.into_iter().take(k) { + cache.remove(&key); + } + } +} + +#[async_trait] +impl EmbeddingProvider for CachedEmbeddingProvider { + fn dimension(&self) -> usize { + self.inner.dimension() + } + + fn model_name(&self) -> &str { + self.inner.model_name() + } + + fn max_input_length(&self) -> usize { + self.inner.max_input_length() + } + + async fn embed(&self, text: &str) -> Result, EmbeddingError> { + let key = self.cache_key(text); + + // Check cache (short critical section) + { + let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(entry) = guard.get_mut(&key) { + entry.last_accessed = Instant::now(); + tracing::trace!("embedding cache hit"); + return Ok(entry.embedding.clone()); + } + } + // Lock released before HTTP call. + // NOTE: Thundering herd — multiple concurrent callers with the same + // uncached key will each call the inner provider. This is acceptable: + // embeddings are idempotent and the last writer wins in the HashMap. + + let embedding = self.inner.embed(text).await?; + + // Store result. Re-check under lock: another concurrent caller may + // have inserted this key while the lock was released for the HTTP call. + { + let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(entry) = guard.get_mut(&key) { + // Key already present (thundering herd) — just update, no eviction needed. + entry.embedding = embedding.clone(); + entry.last_accessed = Instant::now(); + } else { + Self::evict_lru(&mut guard, self.config.max_entries); + guard.insert( + key, + CacheEntry { + embedding: embedding.clone(), + last_accessed: Instant::now(), + }, + ); + } + } + + tracing::trace!("embedding cache miss"); + Ok(embedding) + } + + async fn embed_batch(&self, texts: &[String]) -> Result>, EmbeddingError> { + if texts.is_empty() { + return Ok(Vec::new()); + } + + // Partition into hits and misses + let keys: Vec<[u8; 32]> = texts.iter().map(|t| self.cache_key(t)).collect(); + let mut results: Vec>> = vec![None; texts.len()]; + let mut miss_indices: Vec = Vec::new(); + + { + let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner()); + let now = Instant::now(); + for (i, key) in keys.iter().enumerate() { + if let Some(entry) = guard.get_mut(key) { + entry.last_accessed = now; + results[i] = Some(entry.embedding.clone()); + } else { + miss_indices.push(i); + } + } + } + // Lock released before HTTP call + + if miss_indices.is_empty() { + tracing::trace!(count = texts.len(), "embedding batch: all cache hits"); + // All slots populated from cache hits + return results + .into_iter() + .enumerate() + .map(|(i, slot)| { + slot.ok_or_else(|| { + EmbeddingError::InvalidResponse(format!( + "embedding slot {i} was not populated" + )) + }) + }) + .collect::, _>>(); + } + + // Fetch missing embeddings + let miss_texts: Vec = miss_indices.iter().map(|&i| texts[i].clone()).collect(); + let new_embeddings = self.inner.embed_batch(&miss_texts).await?; + + if new_embeddings.len() != miss_indices.len() { + return Err(EmbeddingError::InvalidResponse(format!( + "embed_batch returned {} embeddings, expected {}", + new_embeddings.len(), + miss_indices.len() + ))); + } + + tracing::trace!( + hits = texts.len() - miss_indices.len(), + misses = miss_indices.len(), + "embedding batch: partial cache" + ); + + // Assemble results first (all misses, regardless of cache capacity). + for (orig_idx, emb) in miss_indices.iter().copied().zip(&new_embeddings) { + results[orig_idx] = Some(emb.clone()); + } + + // Cache the new embeddings, respecting max_entries. + { + let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner()); + // When misses exceed capacity, clear and only cache the tail. + let cacheable = miss_indices.len().min(self.config.max_entries); + let skip = miss_indices.len() - cacheable; + let need_to_evict = (guard.len() + cacheable).saturating_sub(self.config.max_entries); + if need_to_evict > 0 { + Self::evict_k_oldest(&mut guard, need_to_evict); + } + let now = Instant::now(); + for (&orig_idx, emb) in miss_indices[skip..].iter().zip(&new_embeddings[skip..]) { + guard.insert( + keys[orig_idx], + CacheEntry { + embedding: emb.clone(), + last_accessed: now, + }, + ); + } + } + + results + .into_iter() + .enumerate() + .map(|(i, slot)| { + slot.ok_or_else(|| { + EmbeddingError::InvalidResponse(format!("embedding slot {i} was not populated")) + }) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicU32, Ordering}; + + /// Mock embedding provider that counts calls. + struct CountingMock { + dimension: usize, + model: String, + embed_calls: AtomicU32, + batch_calls: AtomicU32, + } + + impl CountingMock { + fn new(dimension: usize, model: &str) -> Self { + Self { + dimension, + model: model.to_string(), + embed_calls: AtomicU32::new(0), + batch_calls: AtomicU32::new(0), + } + } + + fn embed_calls(&self) -> u32 { + self.embed_calls.load(Ordering::SeqCst) + } + + fn batch_calls(&self) -> u32 { + self.batch_calls.load(Ordering::SeqCst) + } + } + + #[async_trait] + impl EmbeddingProvider for CountingMock { + fn dimension(&self) -> usize { + self.dimension + } + fn model_name(&self) -> &str { + &self.model + } + fn max_input_length(&self) -> usize { + 10_000 + } + async fn embed(&self, text: &str) -> Result, EmbeddingError> { + self.embed_calls.fetch_add(1, Ordering::SeqCst); + // Simple deterministic embedding: val = text.len() / 100.0 + let val = text.len() as f32 / 100.0; + Ok(vec![val; self.dimension]) + } + async fn embed_batch(&self, texts: &[String]) -> Result>, EmbeddingError> { + self.batch_calls.fetch_add(1, Ordering::SeqCst); + texts + .iter() + .map(|t| { + let val = t.len() as f32 / 100.0; + Ok(vec![val; self.dimension]) + }) + .collect() + } + } + + #[tokio::test] + async fn cache_hit_avoids_inner_call() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + let r1 = cached.embed("hello").await.unwrap(); + assert_eq!(inner.embed_calls(), 1); + + let r2 = cached.embed("hello").await.unwrap(); + assert_eq!(inner.embed_calls(), 1); // still 1 -- cache hit + assert_eq!(r1, r2); + + assert_eq!(cached.len(), 1); + } + + #[tokio::test] + async fn cache_miss_calls_inner() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + cached.embed("hello").await.unwrap(); + cached.embed("world").await.unwrap(); + assert_eq!(inner.embed_calls(), 2); + assert_eq!(cached.len(), 2); + } + + #[tokio::test] + async fn cache_key_includes_model() { + let inner_a = Arc::new(CountingMock::new(4, "model-a")); + let inner_b = Arc::new(CountingMock::new(4, "model-b")); + + let cached_a = CachedEmbeddingProvider::new( + inner_a.clone(), + EmbeddingCacheConfig { max_entries: 100 }, + ); + let cached_b = CachedEmbeddingProvider::new( + inner_b.clone(), + EmbeddingCacheConfig { max_entries: 100 }, + ); + + // Same text, different models -> different cache keys + let key_a = cached_a.cache_key("hello"); + let key_b = cached_b.cache_key("hello"); + assert_ne!(key_a, key_b); + } + + #[tokio::test] + async fn lru_eviction() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 2 }); + + cached.embed("first").await.unwrap(); + cached.embed("second").await.unwrap(); + assert_eq!(cached.len(), 2); + + // Third entry should evict the oldest ("first") + cached.embed("third").await.unwrap(); + assert_eq!(cached.len(), 2); + assert_eq!(inner.embed_calls(), 3); + + // "first" should be a cache miss now + cached.embed("first").await.unwrap(); + assert_eq!(inner.embed_calls(), 4); + } + + #[tokio::test] + async fn embed_batch_partial_hits() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + // Pre-cache one text + cached.embed("cached").await.unwrap(); + assert_eq!(inner.embed_calls(), 1); + + // Batch with 1 cached + 2 new + let texts = vec![ + "cached".to_string(), + "new_one".to_string(), + "new_two".to_string(), + ]; + let results = cached.embed_batch(&texts).await.unwrap(); + + // Should have called embed_batch on inner for 2 misses + assert_eq!(inner.batch_calls(), 1); + assert_eq!(results.len(), 3); + assert_eq!(cached.len(), 3); + } + + #[tokio::test] + async fn batch_preserves_order() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + // Pre-cache "bb" (len 2) + cached.embed("bb").await.unwrap(); + + // Batch: "a" (miss, len 1), "bb" (hit, len 2), "ccc" (miss, len 3) + let texts = vec!["a".to_string(), "bb".to_string(), "ccc".to_string()]; + let results = cached.embed_batch(&texts).await.unwrap(); + + assert_eq!(results.len(), 3); + let expected_a = vec![1.0_f32 / 100.0; 4]; + let expected_bb = vec![2.0_f32 / 100.0; 4]; + let expected_ccc = vec![3.0_f32 / 100.0; 4]; + assert_eq!(results[0], expected_a); + assert_eq!(results[1], expected_bb); + assert_eq!(results[2], expected_ccc); + } + + #[tokio::test] + async fn batch_exceeding_capacity_respects_max_entries() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 3 }); + + // Batch with 5 misses but cache capacity is 3 + let texts: Vec = (0..5).map(|i| format!("text_{i}")).collect(); + let results = cached.embed_batch(&texts).await.unwrap(); + + assert_eq!(results.len(), 5); + let len = cached.len(); + assert!(len <= 3, "cache len {len} exceeds max 3"); + } + + /// Mock embedding provider that fails the first N calls, then succeeds. + struct FailThenSucceedMock { + dimension: usize, + model: String, + remaining_failures: AtomicU32, + } + + impl FailThenSucceedMock { + fn new(dimension: usize, fail_count: u32) -> Self { + Self { + dimension, + model: "fail-mock".to_string(), + remaining_failures: AtomicU32::new(fail_count), + } + } + } + + #[async_trait] + impl EmbeddingProvider for FailThenSucceedMock { + fn dimension(&self) -> usize { + self.dimension + } + fn model_name(&self) -> &str { + &self.model + } + fn max_input_length(&self) -> usize { + 10_000 + } + async fn embed(&self, text: &str) -> Result, EmbeddingError> { + let prev = + self.remaining_failures + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| { + if v > 0 { Some(v - 1) } else { None } + }); + if prev.is_ok() { + return Err(EmbeddingError::HttpError("simulated failure".to_string())); + } + let val = text.len() as f32 / 100.0; + Ok(vec![val; self.dimension]) + } + async fn embed_batch(&self, texts: &[String]) -> Result>, EmbeddingError> { + let prev = + self.remaining_failures + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| { + if v > 0 { Some(v - 1) } else { None } + }); + if prev.is_ok() { + return Err(EmbeddingError::HttpError("simulated failure".to_string())); + } + texts + .iter() + .map(|t| { + let val = t.len() as f32 / 100.0; + Ok(vec![val; self.dimension]) + }) + .collect() + } + } + + #[tokio::test] + async fn error_does_not_pollute_cache() { + let inner = Arc::new(FailThenSucceedMock::new(4, 1)); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + // First call fails + let err = cached.embed("hello").await; + assert!(err.is_err()); + assert!(cached.is_empty(), "cache should be empty after error"); + + // Second call succeeds and should call the inner provider (not serve stale error) + let result = cached.embed("hello").await; + assert!(result.is_ok()); + assert_eq!(cached.len(), 1); + } + + #[tokio::test] + async fn embed_batch_empty_input() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + let results = cached.embed_batch(&[]).await.unwrap(); + assert!(results.is_empty()); + assert_eq!(inner.batch_calls(), 0); + } + + #[tokio::test] + async fn embed_batch_all_misses() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 }); + + // Nothing cached — every text is a miss + let texts: Vec = vec!["alpha".into(), "beta".into(), "gamma".into()]; + let results = cached.embed_batch(&texts).await.unwrap(); + assert_eq!(results.len(), 3); + assert_eq!(inner.batch_calls(), 1, "inner called once for misses"); + assert_eq!(cached.len(), 3, "all results should be cached"); + + // Second call should be all hits — no new inner calls + let results2 = cached.embed_batch(&texts).await.unwrap(); + assert_eq!(results2.len(), 3); + assert_eq!(inner.batch_calls(), 1, "no new inner calls"); + } + + #[tokio::test] + async fn zero_max_entries_clamped_to_one() { + let inner = Arc::new(CountingMock::new(4, "test-model")); + let cached = + CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 0 }); + + // Should behave as max_entries=1 (clamped in constructor) + cached.embed("hello").await.unwrap(); + assert_eq!(cached.len(), 1); + + // Second entry evicts the first + cached.embed("world").await.unwrap(); + assert_eq!(cached.len(), 1); + assert_eq!(inner.embed_calls(), 2); + } +} diff --git a/src/workspace/mod.rs b/src/workspace/mod.rs index ad233caf77..f2a59809d2 100644 --- a/src/workspace/mod.rs +++ b/src/workspace/mod.rs @@ -42,6 +42,7 @@ mod chunker; mod document; +mod embedding_cache; mod embeddings; pub mod hygiene; #[cfg(feature = "postgres")] @@ -50,6 +51,7 @@ mod search; pub use chunker::{ChunkConfig, chunk_document}; pub use document::{MemoryChunk, MemoryDocument, WorkspaceEntry, paths}; +pub use embedding_cache::{CachedEmbeddingProvider, EmbeddingCacheConfig}; pub use embeddings::{ EmbeddingProvider, MockEmbeddings, NearAiEmbeddings, OllamaEmbeddings, OpenAiEmbeddings, }; @@ -371,7 +373,33 @@ impl Workspace { } /// Set the embedding provider for semantic search. + /// + /// The provider is automatically wrapped in a [`CachedEmbeddingProvider`] + /// with the default cache size (10,000 entries; payload ~58 MB for 1536-dim, + /// actual memory higher due to per-entry overhead). pub fn with_embeddings(mut self, provider: Arc) -> Self { + self.embeddings = Some(Arc::new(CachedEmbeddingProvider::new( + provider, + EmbeddingCacheConfig::default(), + ))); + self + } + + /// Set the embedding provider with a custom cache configuration. + pub fn with_embeddings_cached( + mut self, + provider: Arc, + cache_config: EmbeddingCacheConfig, + ) -> Self { + self.embeddings = Some(Arc::new(CachedEmbeddingProvider::new( + provider, + cache_config, + ))); + self + } + + /// Set the embedding provider **without** caching (for tests). + pub fn with_embeddings_uncached(mut self, provider: Arc) -> Self { self.embeddings = Some(provider); self } diff --git a/tests/e2e/mock_llm.py b/tests/e2e/mock_llm.py index c27f276265..359c22d58f 100644 --- a/tests/e2e/mock_llm.py +++ b/tests/e2e/mock_llm.py @@ -267,14 +267,24 @@ async def _stream_tool_call(request: web.Request, cid: str, tc: dict) -> web.Str async def oauth_exchange(request: web.Request) -> web.Response: """Mock OAuth token exchange proxy for E2E tests. - Accepts form params (code, redirect_uri, code_verifier) and returns - a fake token response. Called by ironclaw's exchange_via_proxy() when - IRONCLAW_OAUTH_EXCHANGE_URL is set. + Accepts the generic hosted OAuth proxy contract used by IronClaw and + returns a fake token response. MCP callback tests assert that provider- + specific token params such as RFC 8707 `resource` are forwarded here. """ data = await request.post() code = data.get("code", "") + access_token_field = data.get("access_token_field", "access_token") + + if code == "mock_mcp_code": + if not data.get("token_url", "").endswith("/oauth/token"): + return web.json_response({"error": "missing_token_url"}, status=400) + if not data.get("client_id"): + return web.json_response({"error": "missing_client_id"}, status=400) + if not data.get("resource"): + return web.json_response({"error": "missing_resource"}, status=400) + return web.json_response({ - "access_token": f"mock-token-{code}", + access_token_field: f"mock-token-{code}", "refresh_token": "mock-refresh-token", "expires_in": 3600, }) diff --git a/tests/e2e/scenarios/test_mcp_auth_flow.py b/tests/e2e/scenarios/test_mcp_auth_flow.py index 7de2bbe689..cc36aa2edd 100644 --- a/tests/e2e/scenarios/test_mcp_auth_flow.py +++ b/tests/e2e/scenarios/test_mcp_auth_flow.py @@ -99,6 +99,10 @@ async def test_mcp_activate_triggers_auth(ironclaw_server): assert auth_url is not None or awaiting_token, ( f"Activate should require auth, got: {data}" ) + if auth_url is not None: + assert _extract_state(auth_url).startswith("ic2."), ( + f"Hosted MCP OAuth should emit versioned state, got: {auth_url}" + ) # ── Section C: OAuth Round-Trip ────────────────────────────────────────── diff --git a/tests/relay_integration.rs b/tests/relay_integration.rs index 8479cd6730..0a053885d9 100644 --- a/tests/relay_integration.rs +++ b/tests/relay_integration.rs @@ -2,18 +2,12 @@ //! //! Uses real HTTP servers on random ports (no mock framework). -use std::convert::Infallible; -use std::sync::atomic::{AtomicUsize, Ordering}; - use axum::{ Json, Router, extract::Query, - http::StatusCode, - response::sse::{Event, KeepAlive, Sse}, routing::{get, post}, }; -use futures::stream; -use ironclaw::channels::relay::client::{RelayClient, RelayError}; +use ironclaw::channels::relay::client::{ChannelEvent, RelayClient}; use secrecy::SecretString; use serde::Deserialize; use tokio::net::TcpListener; @@ -37,109 +31,79 @@ fn test_client(base_url: &str) -> RelayClient { .expect("client build") } -// ── SSE stream mock ───────────────────────────────────────────────────── +// ── Signing secret fetch ───────────────────────────────────────────────── #[tokio::test] -async fn test_sse_stream_receives_events() { +async fn test_get_signing_secret_returns_decoded_bytes() { + let secret_hex = hex::encode([1u8; 32]); + let secret_hex_clone = secret_hex.clone(); let app = Router::new().route( - "/stream", - get( - |Query(params): Query>| async move { - // Verify token is passed - assert!(params.contains_key("token")); - - let events = vec![ - Ok::<_, Infallible>( - Event::default().event("message").data( - serde_json::json!({ - "event_type": "message", - "provider": "slack", - "provider_scope": "T123", - "channel_id": "C456", - "sender_id": "U789", - "content": "hello world" - }) - .to_string(), - ), - ), - Ok(Event::default().event("message").data( - serde_json::json!({ - "event_type": "direct_message", - "provider": "slack", - "provider_scope": "T123", - "channel_id": "D001", - "sender_id": "U789", - "content": "dm text" - }) - .to_string(), - )), - ]; - - Sse::new(stream::iter(events)).keep_alive(KeepAlive::default()) - }, - ), + "/relay/signing-secret", + get(move || { + let s = secret_hex_clone.clone(); + async move { Json(serde_json::json!({"signing_secret": s})) } + }), ); let base_url = start_server(app).await; let client = test_client(&base_url); - let (mut event_stream, handle) = client.connect_stream("test-token", 30).await.unwrap(); + let secret = client.get_signing_secret("T123").await.unwrap(); + assert_eq!(secret, vec![1u8; 32]); +} - use futures::StreamExt; - let first = event_stream.next().await.expect("first event"); - assert_eq!(first.event_type, "message"); - assert_eq!(first.text(), "hello world"); - assert_eq!(first.team_id(), "T123"); +#[tokio::test] +async fn test_get_signing_secret_404_returns_error() { + let app = Router::new().route( + "/relay/signing-secret", + get(|| async { (axum::http::StatusCode::NOT_FOUND, "not found") }), + ); - let second = event_stream.next().await.expect("second event"); - assert_eq!(second.event_type, "direct_message"); - assert_eq!(second.text(), "dm text"); + let base_url = start_server(app).await; + let client = test_client(&base_url); - handle.abort(); + let result = client.get_signing_secret("T123").await; + assert!(result.is_err()); } -// ── Token renewal flow ────────────────────────────────────────────────── - #[tokio::test] -async fn test_token_expired_returns_error() { - let app = Router::new().route("/stream", get(|| async { StatusCode::UNAUTHORIZED })); +async fn test_get_signing_secret_invalid_hex_returns_protocol_error() { + let app = Router::new().route( + "/relay/signing-secret", + get(|| async { Json(serde_json::json!({"signing_secret": "not-hex"})) }), + ); let base_url = start_server(app).await; let client = test_client(&base_url); - match client.connect_stream("expired-token", 30).await { - Err(RelayError::TokenExpired) => {} // expected - Err(other) => panic!("expected TokenExpired, got: {other}"), - Ok(_) => panic!("expected error, got Ok"), - } + let err = client + .get_signing_secret("T123") + .await + .unwrap_err() + .to_string(); + assert!(err.contains("invalid signing_secret hex"), "got: {err}"); } #[tokio::test] -async fn test_token_renewal() { - let call_count = std::sync::Arc::new(AtomicUsize::new(0)); - let call_count_clone = call_count.clone(); - +async fn test_get_signing_secret_wrong_length_returns_protocol_error() { + let short_secret_hex = hex::encode([7u8; 31]); let app = Router::new().route( - "/stream/renew", - post(move |Json(body): Json| { - let count = call_count_clone.clone(); - async move { - count.fetch_add(1, Ordering::SeqCst); - assert!(body.get("instance_id").is_some()); - assert!(body.get("user_id").is_some()); - Json(serde_json::json!({ - "stream_token": "renewed-token-123" - })) - } + "/relay/signing-secret", + get(move || { + let s = short_secret_hex.clone(); + async move { Json(serde_json::json!({"signing_secret": s})) } }), ); let base_url = start_server(app).await; let client = test_client(&base_url); - let new_token = client.renew_token("inst-1", "user-1").await.unwrap(); - assert_eq!(new_token, "renewed-token-123"); - assert_eq!(call_count.load(Ordering::SeqCst), 1); + let err = client + .get_signing_secret("T123") + .await + .unwrap_err() + .to_string(); + assert!(err.contains("expected 32 bytes"), "got: {err}"); } // ── Proxy call ────────────────────────────────────────────────────────── @@ -171,7 +135,7 @@ async fn test_proxy_provider_sends_correct_payload() { "text": "Hello from test", }); let resp = client - .proxy_provider("slack", "T123", "chat.postMessage", body, None) + .proxy_provider("slack", "T123", "chat.postMessage", body) .await .unwrap(); assert_eq!(resp["ok"], true); @@ -200,18 +164,18 @@ async fn test_list_connections() { assert!(!conns[1].connected); } -// ── API key header ────────────────────────────────────────────────────── +// ── Bearer token auth ──────────────────────────────────────────────────── #[tokio::test] -async fn test_api_key_sent_in_header() { +async fn test_bearer_token_sent_in_header() { let app = Router::new().route( "/connections", get(|headers: axum::http::HeaderMap| async move { - let key = headers - .get("X-API-Key") + let auth = headers + .get("authorization") .and_then(|v| v.to_str().ok()) .unwrap_or(""); - assert_eq!(key, "test-api-key"); + assert_eq!(auth, "Bearer test-api-key"); Json(serde_json::json!([])) }), ); @@ -233,82 +197,10 @@ fn test_relay_client_new_succeeds() { assert!(client.is_ok()); } -// ── SSE UTF-8 chunk boundary ──────────────────────────────────────────── - -/// Verify that multi-byte UTF-8 characters split across SSE chunks are -/// not corrupted (no U+FFFD replacement characters). -#[tokio::test] -async fn test_sse_stream_preserves_multibyte_utf8_across_chunks() { - use std::sync::atomic::{AtomicBool, Ordering}; - - let sent = std::sync::Arc::new(AtomicBool::new(false)); - let sent_clone = sent.clone(); - - let app = Router::new().route( - "/stream", - get(move |_: Query>| { - let sent = sent_clone.clone(); - async move { - // Build SSE payload with emoji that will be split mid-character - let event_data = serde_json::json!({ - "event_type": "message", - "provider": "slack", - "provider_scope": "T1", - "channel_id": "C1", - "sender_id": "U1", - "content": "hello 🦀 world" - }); - let payload = format!("event: message\ndata: {}\n\n", event_data); - let bytes = payload.into_bytes(); - - // Split in the middle of the 4-byte crab emoji - let crab_pos = bytes - .windows(4) - .position(|w| w == [0xF0, 0x9F, 0xA6, 0x80]) - .unwrap(); - let split_at = crab_pos + 2; - - let chunk1 = bytes[..split_at].to_vec(); - let chunk2 = bytes[split_at..].to_vec(); - - sent.store(true, Ordering::SeqCst); - - let events = vec![ - Ok::<_, Infallible>(axum::body::Bytes::from(chunk1)), - Ok(axum::body::Bytes::from(chunk2)), - ]; - - axum::response::Response::builder() - .header("content-type", "text/event-stream") - .body(axum::body::Body::from_stream(stream::iter(events))) - .unwrap() - } - }), - ); - - let base_url = start_server(app).await; - let client = test_client(&base_url); - - let (mut event_stream, handle) = client.connect_stream("tok", 30).await.unwrap(); - - use futures::StreamExt; - let event = event_stream.next().await.expect("should get event"); - assert_eq!( - event.text(), - "hello 🦀 world", - "emoji should not be corrupted" - ); - assert!(sent.load(Ordering::SeqCst)); - - handle.abort(); -} - // ── Channel event field validation ────────────────────────────────────── #[test] fn test_channel_event_missing_fields_detected() { - use ironclaw::channels::relay::client::ChannelEvent; - // Event with empty sender_id should be detectable let json = r#"{"event_type": "message", "provider_scope": "T1", "channel_id": "C1", "sender_id": "", "content": "test"}"#; let event: ChannelEvent = serde_json::from_str(json).unwrap(); diff --git a/tests/workspace_integration.rs b/tests/workspace_integration.rs index dddd95e94e..2182fc38a1 100644 --- a/tests/workspace_integration.rs +++ b/tests/workspace_integration.rs @@ -308,7 +308,7 @@ async fn test_workspace_hybrid_search_with_mock_embeddings() { // Create workspace with mock embeddings (1536 dimensions to match OpenAI) let embeddings = Arc::new(MockEmbeddings::new(1536)); - let workspace = Workspace::new(user_id, pool.clone()).with_embeddings(embeddings); + let workspace = Workspace::new(user_id, pool.clone()).with_embeddings_uncached(embeddings); // Write documents workspace