diff --git a/src/channels/manager.rs b/src/channels/manager.rs index b026ff850d..0c9a3da77a 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -239,6 +239,11 @@ impl ChannelManager { pub async fn get_channel(&self, name: &str) -> Option> { self.channels.read().await.get(name).cloned() } + + /// Remove a channel from the manager. + pub async fn remove(&self, name: &str) -> Option> { + self.channels.write().await.remove(name) + } } impl Default for ChannelManager { diff --git a/src/channels/relay/channel.rs b/src/channels/relay/channel.rs index 52aea478ee..ab61c14714 100644 --- a/src/channels/relay/channel.rs +++ b/src/channels/relay/channel.rs @@ -1,16 +1,16 @@ -//! Channel trait implementation for channel-relay SSE streams. +//! Channel trait implementation for channel-relay webhook callbacks. //! -//! `RelayChannel` connects to a channel-relay service via SSE, converts -//! incoming events to `IncomingMessage`s, and sends responses via the -//! relay's provider-specific proxy API (Slack). +//! `RelayChannel` receives events from channel-relay via HTTP POST callbacks +//! (pushed through an mpsc channel by the webhook handler), converts them +//! to `IncomingMessage`s, and sends responses via the relay's provider-specific +//! proxy API (Slack). use std::collections::HashMap; -use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::{RwLock, mpsc}; +use tokio::sync::mpsc; -use crate::channels::relay::client::{RelayClient, RelayError}; +use crate::channels::relay::client::{ChannelEvent, RelayClient}; use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate}; use crate::error::ChannelError; @@ -39,44 +39,34 @@ impl RelayProvider { } } -/// Channel implementation that connects to a channel-relay SSE stream. +/// Channel implementation that receives events from channel-relay via webhook callbacks. pub struct RelayChannel { client: RelayClient, provider: RelayProvider, - stream_token: Arc>, team_id: String, instance_id: String, - user_id: String, - /// SSE stream long-poll timeout in seconds. - stream_timeout_secs: u64, - /// Initial exponential backoff in milliseconds. - backoff_initial_ms: u64, - /// Maximum exponential backoff in milliseconds. - backoff_max_ms: u64, - /// Handle to the reconnect task for clean shutdown. - reconnect_handle: RwLock>>, - /// Handle to the SSE parser task for clean shutdown. - parser_handle: Arc>>>, - /// Maximum consecutive reconnect failures before giving up. - max_consecutive_failures: u64, + /// Sender side of the event channel — shared with the webhook handler. + event_tx: mpsc::Sender, + /// Receiver side — taken once by `start()`. + event_rx: tokio::sync::Mutex>>, } impl RelayChannel { /// Create a new relay channel for Slack (default provider). pub fn new( client: RelayClient, - stream_token: String, team_id: String, instance_id: String, - user_id: String, + event_tx: mpsc::Sender, + event_rx: mpsc::Receiver, ) -> Self { Self::new_with_provider( client, RelayProvider::Slack, - stream_token, team_id, instance_id, - user_id, + event_tx, + event_rx, ) } @@ -84,44 +74,24 @@ impl RelayChannel { pub fn new_with_provider( client: RelayClient, provider: RelayProvider, - stream_token: String, team_id: String, instance_id: String, - user_id: String, + event_tx: mpsc::Sender, + event_rx: mpsc::Receiver, ) -> Self { Self { client, provider, - stream_token: Arc::new(RwLock::new(stream_token)), team_id, instance_id, - user_id, - stream_timeout_secs: 86400, - backoff_initial_ms: 1000, - backoff_max_ms: 60000, - reconnect_handle: RwLock::new(None), - parser_handle: Arc::new(RwLock::new(None)), - max_consecutive_failures: 50, + event_tx, + event_rx: tokio::sync::Mutex::new(Some(event_rx)), } } - /// Set backoff/timeout parameters from relay config values. - pub fn with_timeouts( - mut self, - stream_timeout_secs: u64, - backoff_initial_ms: u64, - backoff_max_ms: u64, - ) -> Self { - self.stream_timeout_secs = stream_timeout_secs; - self.backoff_initial_ms = backoff_initial_ms; - self.backoff_max_ms = backoff_max_ms; - self - } - - /// Set the maximum number of consecutive reconnect failures before giving up. - pub fn with_max_failures(mut self, max: u64) -> Self { - self.max_consecutive_failures = max; - self + /// Get a clone of the event sender for wiring into the webhook endpoint. + pub fn event_sender(&self) -> mpsc::Sender { + self.event_tx.clone() } /// Build a provider-appropriate proxy body for sending a message. @@ -151,15 +121,9 @@ impl RelayChannel { team_id: &str, method: &str, body: serde_json::Value, - ) -> Result { + ) -> Result { self.client - .proxy_provider( - self.provider.as_str(), - team_id, - method, - body, - Some(&self.instance_id), - ) + .proxy_provider(self.provider.as_str(), team_id, method, body) .await } } @@ -172,204 +136,82 @@ impl Channel for RelayChannel { async fn start(&self) -> Result { let channel_name = self.name().to_string(); - let token = self.stream_token.read().await.clone(); - let (stream, initial_parser_handle) = self - .client - .connect_stream(&token, self.stream_timeout_secs) - .await - .map_err(|e| ChannelError::StartupFailed { - name: channel_name.clone(), - reason: e.to_string(), - })?; - *self.parser_handle.write().await = Some(initial_parser_handle); + // Take the receiver (can only start once) + let mut event_rx = + self.event_rx + .lock() + .await + .take() + .ok_or_else(|| ChannelError::StartupFailed { + name: channel_name.clone(), + reason: "RelayChannel already started".to_string(), + })?; let (tx, rx) = mpsc::channel(64); - - // Spawn the stream reader + reconnect task - let client = self.client.clone(); - let stream_token = Arc::clone(&self.stream_token); - let instance_id = self.instance_id.clone(); - let user_id = self.user_id.clone(); - let team_id = self.team_id.clone(); - let stream_timeout_secs = self.stream_timeout_secs; - let backoff_initial_ms = self.backoff_initial_ms; - let backoff_max_ms = self.backoff_max_ms; - let max_consecutive_failures = self.max_consecutive_failures; - let parser_handle = Arc::clone(&self.parser_handle); let provider_str = self.provider.as_str().to_string(); let relay_name = channel_name.clone(); - let handle = tokio::spawn(async move { - use futures::StreamExt; - - let mut current_stream = stream; - let mut backoff_ms = backoff_initial_ms; - let mut consecutive_failures: u64 = 0; - - loop { - // Read events from the current stream - while let Some(event) = current_stream.next().await { - // Reset backoff and failure count on successful event - backoff_ms = backoff_initial_ms; - consecutive_failures = 0; - - // Validate required fields - if event.sender_id.is_empty() - || event.channel_id.is_empty() - || event.provider_scope.is_empty() - { - tracing::debug!( - event_type = %event.event_type, - sender_id = %event.sender_id, - channel_id = %event.channel_id, - "Relay: skipping event with missing required fields" - ); - continue; - } - - // Skip non-message events - if !event.is_message() { - tracing::debug!( - event_type = %event.event_type, - "Relay: skipping non-message event" - ); - continue; - } - - tracing::info!( + // Spawn a task that reads events from the webhook handler and converts to IncomingMessage + tokio::spawn(async move { + while let Some(event) = event_rx.recv().await { + // Validate required fields + if event.sender_id.is_empty() + || event.channel_id.is_empty() + || event.provider_scope.is_empty() + { + tracing::debug!( event_type = %event.event_type, - sender = %event.sender_id, - channel = %event.channel_id, - provider = %provider_str, - "Relay: received message from {}", provider_str + sender_id = %event.sender_id, + channel_id = %event.channel_id, + "Relay: skipping event with missing required fields" ); - - let msg = IncomingMessage::new(&relay_name, &event.sender_id, event.text()) - .with_user_name(event.display_name()) - .with_metadata(serde_json::json!({ - "team_id": event.team_id(), - "channel_id": event.channel_id, - "sender_id": event.sender_id, - "sender_name": event.display_name(), - "event_type": event.event_type, - "thread_id": event.thread_id, - "provider": event.provider, - })); - - let msg = if let Some(ref thread_id) = event.thread_id { - msg.with_thread(thread_id) - } else { - msg.with_thread(&event.channel_id) - }; - - if tx.send(msg).await.is_err() { - tracing::info!("Relay channel receiver dropped, stopping"); - return; - } + continue; } - // Stream ended, attempt reconnect with backoff - consecutive_failures += 1; - if consecutive_failures >= max_consecutive_failures { - tracing::error!( - channel = %relay_name, - failures = consecutive_failures, - "Relay channel giving up after {} consecutive failures", - consecutive_failures + // Skip non-message events + if !event.is_message() { + tracing::debug!( + event_type = %event.event_type, + "Relay: skipping non-message event" ); - break; + continue; } - tracing::warn!( - backoff_ms = backoff_ms, - failures = consecutive_failures, - "Relay SSE stream ended, reconnecting..." + tracing::info!( + event_type = %event.event_type, + sender = %event.sender_id, + channel = %event.channel_id, + provider = %provider_str, + "Relay: received message from {}", provider_str ); - tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(backoff_max_ms); - - // Try to reconnect - let token = stream_token.read().await.clone(); - match client.connect_stream(&token, stream_timeout_secs).await { - Ok((new_stream, new_parser)) => { - tracing::info!("Relay SSE stream reconnected"); - consecutive_failures = 0; - backoff_ms = backoff_initial_ms; - current_stream = new_stream; - // Abort old parser before replacing - if let Some(old) = parser_handle.write().await.take() { - old.abort(); - } - *parser_handle.write().await = Some(new_parser); - } - Err(RelayError::TokenExpired) => { - // Attempt token renewal - tracing::info!("Relay stream token expired, renewing..."); - match client.renew_token(&instance_id, &user_id).await { - Ok(new_token) => { - *stream_token.write().await = new_token.clone(); - match client.connect_stream(&new_token, stream_timeout_secs).await { - Ok((new_stream, new_parser)) => { - tracing::info!( - "Relay SSE stream reconnected with new token" - ); - consecutive_failures = 0; - backoff_ms = backoff_initial_ms; - current_stream = new_stream; - if let Some(old) = parser_handle.write().await.take() { - old.abort(); - } - *parser_handle.write().await = Some(new_parser); - } - Err(e) => { - tracing::error!( - error = %e, - "Failed to reconnect after token renewal" - ); - } - } - } - Err(e) => { - tracing::error!( - error = %e, - "Failed to renew relay stream token" - ); - } - } - } - Err(e) => { - tracing::error!(error = %e, "Failed to reconnect relay SSE stream"); - } - } - // Check if the team is still valid (skip when team_id is unknown, - // e.g. when no DB store was available at activation time) - if !team_id.is_empty() { - match client.list_connections(&instance_id).await { - Ok(conns) => { - let has_team = - conns.iter().any(|c| c.team_id == team_id && c.connected); - if !has_team { - tracing::warn!( - team_id = %team_id, - "Team no longer connected, stopping relay channel" - ); - return; - } - } - Err(e) => { - tracing::warn!( - error = %e, - "Could not verify team connection, will retry next iteration" - ); - } - } + let msg = IncomingMessage::new(&relay_name, &event.sender_id, event.text()) + .with_user_name(event.display_name()) + .with_metadata(serde_json::json!({ + "team_id": event.team_id(), + "channel_id": event.channel_id, + "sender_id": event.sender_id, + "sender_name": event.display_name(), + "event_type": event.event_type, + "thread_id": event.thread_id, + "provider": event.provider, + })); + + let msg = if let Some(ref thread_id) = event.thread_id { + msg.with_thread(thread_id) + } else { + msg.with_thread(&event.channel_id) + }; + + if tx.send(msg).await.is_err() { + tracing::info!("Relay channel receiver dropped, stopping"); + return; } } - }); - *self.reconnect_handle.write().await = Some(handle); + tracing::info!("Relay event channel closed"); + }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); Ok(Box::pin(stream)) @@ -450,28 +292,24 @@ impl Channel for RelayChannel { name: self.name().to_string(), reason: "Missing channel_id for approval buttons".into(), })?; - let sender_id = metadata - .get("sender_id") - .and_then(|v| v.as_str()) - .ok_or_else(|| ChannelError::SendFailed { - name: self.name().to_string(), - reason: "Missing sender_id for approval buttons".into(), - })?; let thread_id = metadata.get("thread_id").and_then(|v| v.as_str()); let team_id = metadata .get("team_id") .and_then(|v| v.as_str()) .unwrap_or(&self.team_id); - // Button value payload (Slack limits button values to 2000 chars; - // safe with typical UUIDs but documented here as a constraint) + // Register server-side approval record and get opaque token. + // The button value contains ONLY the token — no routing fields. + let approval_token = self + .client + .create_approval(team_id, channel_id, thread_id, &request_id) + .await + .map_err(|e| ChannelError::SendFailed { + name: self.name().to_string(), + reason: format!("Failed to register approval: {e}"), + })?; let value_payload = serde_json::json!({ - "instance_id": self.instance_id, - "team_id": team_id, - "channel_id": channel_id, - "thread_ts": thread_id, - "request_id": request_id, - "sender_id": sender_id, + "approval_token": approval_token, }); let value_str = value_payload.to_string(); @@ -582,12 +420,8 @@ impl Channel for RelayChannel { } async fn shutdown(&self) -> Result<(), ChannelError> { - if let Some(handle) = self.reconnect_handle.write().await.take() { - handle.abort(); - } - if let Some(handle) = self.parser_handle.write().await.take() { - handle.abort(); - } + // Relay cleanup is driven by the extension manager dropping the shared + // sender and removing the channel from the channel manager. Ok(()) } } @@ -605,27 +439,20 @@ mod tests { .expect("client") } + fn make_channel() -> RelayChannel { + let (tx, rx) = mpsc::channel(64); + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx, rx) + } + #[test] fn relay_channel_name() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); assert_eq!(channel.name(), DEFAULT_RELAY_NAME); } #[test] fn conversation_context_extracts_metadata() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "sender_name": "bob", @@ -640,8 +467,6 @@ mod tests { #[test] fn metadata_shape_includes_event_type_and_sender_name() { - // Regression: metadata JSON must include event_type and sender_name - // for downstream routing (DM vs channel) and conversation_context(). let metadata = serde_json::json!({ "team_id": "T123", "channel_id": "C456", @@ -651,43 +476,19 @@ mod tests { "thread_id": null, "provider": "slack", }); - // event_type must be present for DM-vs-channel routing assert_eq!( metadata.get("event_type").and_then(|v| v.as_str()), Some("direct_message") ); - // sender_name must be present for conversation_context assert_eq!( metadata.get("sender_name").and_then(|v| v.as_str()), Some("alice") ); } - #[test] - fn with_timeouts_sets_values() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ) - .with_timeouts(43200, 2000, 120000); - - assert_eq!(channel.stream_timeout_secs, 43200); - assert_eq!(channel.backoff_initial_ms, 2000); - assert_eq!(channel.backoff_max_ms, 120000); - } - #[test] fn build_send_body_slack() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let (method, body) = channel.build_send_body("C456", "hello", Some("1234567.890")); assert_eq!(method, "chat.postMessage"); assert_eq!(body["channel"], "C456"); @@ -695,72 +496,95 @@ mod tests { assert_eq!(body["thread_ts"], "1234567.890"); } - #[test] - fn parser_handle_is_shared_arc() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); - // parser_handle should be an Arc — cloning should give a second reference - let handle_clone = Arc::clone(&channel.parser_handle); - // Both point to the same allocation - assert!(Arc::ptr_eq(&channel.parser_handle, &handle_clone)); - } - - #[test] - fn with_max_failures_sets_value() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ) - .with_max_failures(10); + #[tokio::test] + async fn start_processes_events() { + let (tx, rx) = mpsc::channel(64); + let channel = + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx); + + let mut stream = channel.start().await.unwrap(); + + // Send an event + tx.send(ChannelEvent { + id: "1".into(), + event_type: "message".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: Some("alice".into()), + content: Some("hello".into()), + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + use futures::StreamExt; + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap() + .unwrap(); - assert_eq!(channel.max_consecutive_failures, 10); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.user_id, "U789"); } - #[test] - fn default_max_failures_is_50() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); - assert_eq!(channel.max_consecutive_failures, 50); - } + #[tokio::test] + async fn start_skips_non_message_events() { + let (tx, rx) = mpsc::channel(64); + let channel = + RelayChannel::new(test_client(), "T123".into(), "inst1".into(), tx.clone(), rx); + + let mut stream = channel.start().await.unwrap(); + + // Send a non-message event (should be skipped) + tx.send(ChannelEvent { + id: "1".into(), + event_type: "reaction".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: None, + content: None, + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + // Send a real message + tx.send(ChannelEvent { + id: "2".into(), + event_type: "message".into(), + provider: "slack".into(), + provider_scope: "T123".into(), + channel_id: "C456".into(), + sender_id: "U789".into(), + sender_name: None, + content: Some("real message".into()), + thread_id: None, + raw: serde_json::Value::Null, + timestamp: None, + }) + .await + .unwrap(); + + use futures::StreamExt; + let msg = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()) + .await + .unwrap() + .unwrap(); - #[test] - fn empty_team_id_accepted_at_construction() { - // Regression: empty team_id (when no DB store is available) must not - // prevent channel construction or cause immediate shutdown. - let channel = RelayChannel::new( - test_client(), - "token".into(), - String::new(), // empty team_id - "inst1".into(), - "user1".into(), - ); - assert_eq!(channel.team_id, ""); - // The reconnect loop now skips team validation when team_id is empty, - // so the channel remains alive. + assert_eq!(msg.content, "real message"); } #[tokio::test] async fn test_send_status_non_approval_is_noop() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({}); let result = channel .send_status( @@ -775,13 +599,7 @@ mod tests { #[tokio::test] async fn test_send_status_approval_non_dm_skips() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "message", "channel_id": "C456", @@ -804,13 +622,7 @@ mod tests { #[tokio::test] async fn test_send_status_approval_dm_missing_channel_id_errors() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "direct_message", "sender_id": "U789", @@ -835,14 +647,8 @@ mod tests { } #[tokio::test] - async fn test_send_status_approval_dm_missing_sender_id_errors() { - let channel = RelayChannel::new( - test_client(), - "token".into(), - "T123".into(), - "inst1".into(), - "user1".into(), - ); + async fn test_send_status_approval_dm_without_sender_id_is_ok() { + let channel = make_channel(); let metadata = serde_json::json!({ "event_type": "direct_message", "channel_id": "C456", @@ -861,8 +667,8 @@ mod tests { assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( - err.contains("sender_id"), - "expected sender_id error, got: {err}" + !err.contains("sender_id"), + "sender_id should not be required anymore, got: {err}" ); } } diff --git a/src/channels/relay/client.rs b/src/channels/relay/client.rs index d1c03a5190..81fbb56c93 100644 --- a/src/channels/relay/client.rs +++ b/src/channels/relay/client.rs @@ -1,15 +1,10 @@ //! HTTP client for the channel-relay service. //! //! Wraps reqwest for all channel-relay API calls: OAuth initiation, -//! SSE streaming, token renewal, and Slack API proxy. +//! approvals, signing-secret fetch, and Slack API proxy. -use std::pin::Pin; -use std::task::{Context, Poll}; - -use futures::Stream; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc; /// Known relay event types. pub mod event_types { @@ -18,7 +13,7 @@ pub mod event_types { pub const MENTION: &str = "mention"; } -/// A parsed SSE event from the channel-relay stream. +/// A parsed event from the channel-relay webhook callback. /// /// Field names match the channel-relay `ChannelEvent` struct exactly. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -123,21 +118,19 @@ impl RelayClient { /// /// Calls `GET /oauth/slack/auth` with `redirect(Policy::none())` and /// returns the `Location` header (Slack OAuth URL) without following it. - pub async fn initiate_oauth( - &self, - instance_id: &str, - user_id: &str, - callback_url: &str, - ) -> Result { + /// Initiate Slack OAuth. Channel-relay derives all URLs from the trusted + /// instance_url in chat-api. IronClaw only passes an optional CSRF nonce + /// for validating the callback — no URLs. + pub async fn initiate_oauth(&self, state_nonce: Option<&str>) -> Result { + let mut query: Vec<(&str, &str)> = vec![]; + if let Some(nonce) = state_nonce { + query.push(("state_nonce", nonce)); + } let resp = self .http .get(format!("{}/oauth/slack/auth", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) - .query(&[ - ("instance_id", instance_id), - ("user_id", user_id), - ("callback", callback_url), - ]) + .bearer_auth(self.api_key.expose_secret()) + .query(&query) .send() .await .map_err(|e| RelayError::Network(e.to_string()))?; @@ -173,104 +166,69 @@ impl RelayClient { } } - /// Connect to the SSE event stream. + /// Register a pending approval and return the opaque approval token. /// - /// Returns a stream of parsed `ChannelEvent`s and the `JoinHandle` of the - /// background SSE parser task. The caller is responsible for reconnection - /// logic on stream end/error and for aborting the handle on shutdown. - pub async fn connect_stream( + /// Calls `POST /approvals` with the target team/channel/request identifiers. + /// The returned token is embedded in Slack button values instead of routing fields. + /// The relay derives the authorized approver from the connection's authed_user_id. + pub async fn create_approval( &self, - stream_token: &str, - stream_timeout_secs: u64, - ) -> Result<(ChannelEventStream, tokio::task::JoinHandle<()>), RelayError> { - let resp = self - .http - .get(format!("{}/stream", self.base_url)) - .query(&[("token", stream_token)]) - .timeout(std::time::Duration::from_secs(stream_timeout_secs)) - .send() - .await - .map_err(|e| RelayError::Network(e.to_string()))?; - - let status = resp.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - return Err(RelayError::TokenExpired); - } - if !status.is_success() { - let body = resp.text().await.unwrap_or_default(); - return Err(RelayError::Api { - status: status.as_u16(), - message: body, - }); + team_id: &str, + channel_id: &str, + thread_ts: Option<&str>, + request_id: &str, + ) -> Result { + let mut body = serde_json::json!({ + "team_id": team_id, + "channel_id": channel_id, + "request_id": request_id, + }); + if let Some(ts) = thread_ts { + body["thread_ts"] = serde_json::Value::String(ts.to_string()); } - // Spawn a background task that reads the SSE stream and sends parsed events - let (tx, rx) = mpsc::channel(64); - let byte_stream = resp.bytes_stream(); - let handle = tokio::spawn(parse_sse_stream(byte_stream, tx)); - - Ok((ChannelEventStream { rx }, handle)) - } - - /// Renew an expired stream token. - /// - /// Calls `POST /stream/renew` with API key auth, returns a new stream token. - pub async fn renew_token( - &self, - instance_id: &str, - user_id: &str, - ) -> Result { let resp = self .http - .post(format!("{}/stream/renew", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) - .json(&serde_json::json!({ - "instance_id": instance_id, - "user_id": user_id, - })) + .post(format!("{}/approvals", self.base_url)) + .bearer_auth(self.api_key.expose_secret()) + .json(&body) .send() .await .map_err(|e| RelayError::Network(e.to_string()))?; - let status = resp.status(); - if !status.is_success() { + if !resp.status().is_success() { + let status = resp.status().as_u16(); let body = resp.text().await.unwrap_or_default(); return Err(RelayError::Api { - status: status.as_u16(), + status, message: body, }); } - let body: serde_json::Value = resp + let result: serde_json::Value = resp .json() .await .map_err(|e| RelayError::Protocol(e.to_string()))?; - body.get("stream_token") - .or_else(|| body.get("token")) + + result + .get("approval_token") .and_then(|v| v.as_str()) .map(|s| s.to_string()) - .ok_or_else(|| RelayError::Protocol("Response missing stream_token field".to_string())) + .ok_or_else(|| RelayError::Protocol("missing approval_token in response".to_string())) } - /// Proxy an API call through channel-relay for any provider. - /// - /// Calls `POST /proxy/{provider}/{method}?team_id=X&instance_id=Y` with the given JSON body. pub async fn proxy_provider( &self, provider: &str, team_id: &str, method: &str, body: serde_json::Value, - instance_id: Option<&str>, ) -> Result { - let mut query: Vec<(&str, &str)> = vec![("team_id", team_id)]; - if let Some(iid) = instance_id { - query.push(("instance_id", iid)); - } + let query: Vec<(&str, &str)> = vec![("team_id", team_id)]; let resp = self .http .post(format!("{}/proxy/{}/{}", self.base_url, provider, method)) - .header("X-API-Key", self.api_key.expose_secret()) + .bearer_auth(self.api_key.expose_secret()) .query(&query) .json(&body) .send() @@ -291,12 +249,58 @@ impl RelayClient { .map_err(|e| RelayError::Protocol(e.to_string())) } + /// Fetch the per-instance callback signing secret from channel-relay. + /// + /// Calls `GET /relay/signing-secret` (authenticated) and returns the decoded + /// 32-byte secret. Called once at activation time; the result is cached in the + /// extension manager so subsequent calls to `relay_signing_secret()` use it. + pub async fn get_signing_secret(&self, team_id: &str) -> Result, RelayError> { + let resp = self + .http + .get(format!("{}/relay/signing-secret", self.base_url)) + .bearer_auth(self.api_key.expose_secret()) + .query(&[("team_id", team_id)]) + .send() + .await + .map_err(|e| RelayError::Network(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status().as_u16(); + let body = resp.text().await.unwrap_or_default(); + return Err(RelayError::Api { + status, + message: body, + }); + } + + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| RelayError::Protocol(e.to_string()))?; + + body.get("signing_secret") + .and_then(|v| v.as_str()) + .ok_or_else(|| RelayError::Protocol("missing signing_secret in response".to_string())) + .and_then(|raw| { + let decoded = hex::decode(raw).map_err(|e| { + RelayError::Protocol(format!("invalid signing_secret hex: {e}")) + })?; + if decoded.len() != 32 { + return Err(RelayError::Protocol(format!( + "invalid signing_secret length: expected 32 bytes, got {}", + decoded.len() + ))); + } + Ok(decoded) + }) + } + /// List active connections for an instance. pub async fn list_connections(&self, instance_id: &str) -> Result, RelayError> { let resp = self .http .get(format!("{}/connections", self.base_url)) - .header("X-API-Key", self.api_key.expose_secret()) + .bearer_auth(self.api_key.expose_secret()) .query(&[("instance_id", instance_id)]) .send() .await @@ -317,91 +321,6 @@ impl RelayClient { } } -/// Async stream of parsed channel events from SSE. -pub struct ChannelEventStream { - rx: mpsc::Receiver, -} - -impl Stream for ChannelEventStream { - type Item = ChannelEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) - } -} - -/// Parse SSE format from a reqwest bytes stream. -/// -/// SSE format: -/// ```text -/// event: message -/// data: {"key": "value"} -/// -/// ``` -/// Blank line terminates an event. -async fn parse_sse_stream( - byte_stream: impl futures::Stream> + Send + 'static, - tx: mpsc::Sender, -) { - use futures::StreamExt; - - let mut buffer = Vec::::new(); - let mut event_type = String::new(); - let mut data_lines = Vec::new(); - - let mut byte_stream = std::pin::pin!(byte_stream); - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - tracing::debug!(error = %e, "SSE stream chunk error"); - break; - } - }; - - buffer.extend_from_slice(&chunk); - - // Process complete lines (decode UTF-8 only on full lines to avoid - // corruption when multi-byte characters span chunk boundaries) - while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') { - let line = String::from_utf8_lossy(&buffer[..newline_pos]) - .trim_end_matches('\r') - .to_string(); - buffer.drain(..=newline_pos); - - if line.is_empty() { - // Blank line = end of event - if !data_lines.is_empty() { - let data = data_lines.join("\n"); - if let Ok(mut event) = serde_json::from_str::(&data) { - if event.event_type.is_empty() && !event_type.is_empty() { - event.event_type = event_type.clone(); - } - if tx.send(event).await.is_err() { - return; // receiver dropped - } - } else { - tracing::debug!( - event_type = %event_type, - data_len = data.len(), - "Failed to parse SSE event data as ChannelEvent" - ); - } - } - event_type.clear(); - data_lines.clear(); - } else if let Some(value) = line.strip_prefix("event:") { - event_type = value.trim().to_string(); - } else if let Some(value) = line.strip_prefix("data:") { - data_lines.push(value.trim().to_string()); - } - // Ignore other fields (id:, retry:, comments) - } - } - - tracing::debug!("SSE stream ended"); -} - /// Errors from relay client operations. #[derive(Debug, thiserror::Error)] pub enum RelayError { @@ -413,9 +332,6 @@ pub enum RelayError { #[error("Protocol error: {0}")] Protocol(String), - - #[error("Stream token expired")] - TokenExpired, } #[cfg(test)] @@ -494,9 +410,6 @@ mod tests { message: "unauthorized".into(), }; assert_eq!(err.to_string(), "API error (HTTP 401): unauthorized"); - - let err = RelayError::TokenExpired; - assert_eq!(err.to_string(), "Stream token expired"); } #[test] @@ -518,32 +431,4 @@ mod tests { assert!(make(event_types::DIRECT_MESSAGE).is_message()); assert!(make(event_types::MENTION).is_message()); } - - #[tokio::test] - async fn parse_sse_handles_multibyte_utf8_across_chunks() { - // The crab emoji (🦀) is 4 bytes: [0xF0, 0x9F, 0xA6, 0x80]. - // Split it across two chunks to verify no U+FFFD corruption. - let event_json = r#"{"event_type":"message","content":"hello 🦀 world","provider_scope":"T1","channel_id":"C1","sender_id":"U1"}"#; - let full = format!("event: message\ndata: {}\n\n", event_json); - let bytes = full.as_bytes(); - - // Find the crab emoji and split mid-character - let crab_pos = bytes - .windows(4) - .position(|w| w == [0xF0, 0x9F, 0xA6, 0x80]) - .expect("crab emoji not found"); - let split_at = crab_pos + 2; // split in the middle of the 4-byte emoji - - let chunk1 = bytes::Bytes::copy_from_slice(&bytes[..split_at]); - let chunk2 = bytes::Bytes::copy_from_slice(&bytes[split_at..]); - - let chunks: Vec> = vec![Ok(chunk1), Ok(chunk2)]; - let stream = futures::stream::iter(chunks); - - let (tx, mut rx) = mpsc::channel(8); - parse_sse_stream(stream, tx).await; - - let event = rx.recv().await.expect("should receive event"); - assert_eq!(event.text(), "hello 🦀 world"); - } } diff --git a/src/channels/relay/mod.rs b/src/channels/relay/mod.rs index 1582319fa6..05f5870c40 100644 --- a/src/channels/relay/mod.rs +++ b/src/channels/relay/mod.rs @@ -1,12 +1,13 @@ //! Channel-relay integration for connecting to external messaging platforms //! (Slack) via the channel-relay service. //! -//! The relay service handles OAuth, credential storage, webhook ingestion, -//! and SSE event streaming. IronClaw consumes the SSE stream and sends -//! messages via the relay's proxy API. +//! The relay service handles OAuth, credential storage, and webhook ingestion. +//! IronClaw receives events via webhook callbacks and sends messages via the +//! relay's proxy API. pub mod channel; pub mod client; +pub mod webhook; pub use channel::{DEFAULT_RELAY_NAME, RelayChannel}; pub use client::RelayClient; diff --git a/src/channels/relay/webhook.rs b/src/channels/relay/webhook.rs new file mode 100644 index 0000000000..c5a9f82a65 --- /dev/null +++ b/src/channels/relay/webhook.rs @@ -0,0 +1,66 @@ +//! Shared relay webhook signature verification helpers. + +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; + +/// Verify a relay callback HMAC signature. +pub fn verify_relay_signature( + secret: &[u8], + timestamp: &str, + body: &[u8], + signature: &str, +) -> bool { + verify_signature(secret, timestamp, body, signature) +} + +fn verify_signature(secret: &[u8], timestamp: &str, body: &[u8], signature: &str) -> bool { + let mut mac = match HmacSha256::new_from_slice(secret) { + Ok(m) => m, + Err(_) => return false, + }; + mac.update(timestamp.as_bytes()); + mac.update(b"."); + mac.update(body); + let expected = format!("sha256={}", hex::encode(mac.finalize().into_bytes())); + subtle::ConstantTimeEq::ct_eq(expected.as_bytes(), signature.as_bytes()).into() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_signature(secret: &[u8], timestamp: &str, body: &[u8]) -> String { + let mut mac = HmacSha256::new_from_slice(secret).unwrap(); + mac.update(timestamp.as_bytes()); + mac.update(b"."); + mac.update(body); + format!("sha256={}", hex::encode(mac.finalize().into_bytes())) + } + + #[test] + fn verify_valid_signature() { + let secret = b"test-secret"; + let body = b"hello"; + let ts = "1234567890"; + let sig = make_signature(secret, ts, body); + assert!(verify_signature(secret, ts, body, &sig)); + } + + #[test] + fn verify_wrong_secret_fails() { + let body = b"hello"; + let ts = "1234567890"; + let sig = make_signature(b"correct", ts, body); + assert!(!verify_signature(b"wrong", ts, body, &sig)); + } + + #[test] + fn verify_tampered_body_fails() { + let secret = b"secret"; + let ts = "1234567890"; + let sig = make_signature(secret, ts, b"original"); + assert!(!verify_signature(secret, ts, b"tampered", &sig)); + } +} diff --git a/src/channels/web/server.rs b/src/channels/web/server.rs index 27ef7cdce9..2648dea286 100644 --- a/src/channels/web/server.rs +++ b/src/channels/web/server.rs @@ -208,7 +208,8 @@ pub async fn start_server( .route( "/oauth/slack/callback", get(slack_relay_oauth_callback_handler), - ); + ) + .route("/relay/events", post(relay_events_handler)); // Protected routes (require auth) let auth_state = AuthState { token: auth_token }; @@ -742,11 +743,103 @@ async fn oauth_callback_handler( axum::response::Html(html).into_response() } +/// Webhook endpoint for receiving relay events from channel-relay. +/// +/// PUBLIC route — authenticated via HMAC signature (X-Relay-Signature header). +async fn relay_events_handler( + State(state): State>, + headers: axum::http::HeaderMap, + body: axum::body::Bytes, +) -> impl IntoResponse { + let ext_mgr = match state.extension_manager.as_ref() { + Some(mgr) => mgr, + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "not ready").into_response(); + } + }; + + let signing_secret = match ext_mgr.relay_signing_secret() { + Some(s) => s, + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "relay not configured").into_response(); + } + }; + + // Verify signature + let signature = match headers + .get("x-relay-signature") + .and_then(|v| v.to_str().ok()) + { + Some(s) => s.to_string(), + None => { + return (StatusCode::UNAUTHORIZED, "missing signature").into_response(); + } + }; + + let timestamp = match headers + .get("x-relay-timestamp") + .and_then(|v| v.to_str().ok()) + { + Some(t) => t.to_string(), + None => { + return (StatusCode::UNAUTHORIZED, "missing timestamp").into_response(); + } + }; + + // Check timestamp freshness (5 min window) + let ts: i64 = match timestamp.parse() { + Ok(t) => t, + Err(_) => { + return (StatusCode::BAD_REQUEST, "malformed timestamp").into_response(); + } + }; + let now = chrono::Utc::now().timestamp(); + if (now - ts).abs() > 300 { + return (StatusCode::UNAUTHORIZED, "stale timestamp").into_response(); + } + + // Verify HMAC: sha256(secret, timestamp + "." + body) + if !crate::channels::relay::webhook::verify_relay_signature( + &signing_secret, + ×tamp, + &body, + &signature, + ) { + return (StatusCode::UNAUTHORIZED, "invalid signature").into_response(); + } + + // Parse event + let event: crate::channels::relay::client::ChannelEvent = match serde_json::from_slice(&body) { + Ok(e) => e, + Err(e) => { + tracing::warn!(error = %e, "relay callback invalid JSON"); + return (StatusCode::BAD_REQUEST, "invalid JSON").into_response(); + } + }; + + // Push to relay channel + let event_tx_guard = ext_mgr.relay_event_tx(); + let event_tx = event_tx_guard.lock().await; + match event_tx.as_ref() { + Some(tx) => { + if let Err(e) = tx.try_send(event) { + tracing::warn!(error = %e, "relay event channel full or closed"); + return (StatusCode::SERVICE_UNAVAILABLE, "event queue full").into_response(); + } + } + None => { + return (StatusCode::SERVICE_UNAVAILABLE, "relay channel not active").into_response(); + } + } + + Json(serde_json::json!({"ok": true})).into_response() +} + /// OAuth callback for Slack via channel-relay. /// /// This is a PUBLIC route (no Bearer token required) because channel-relay /// redirects the user's browser here after Slack OAuth completes. -/// Query params: `stream_token`, `provider`, `team_id`. +/// Query params: `provider`, `team_id`. async fn slack_relay_oauth_callback_handler( State(state): State>, Query(params): Query>, @@ -763,27 +856,6 @@ async fn slack_relay_oauth_callback_handler( .into_response(); } - // Validate stream_token: required, non-empty, max 2048 bytes - let stream_token = match params.get("stream_token") { - Some(t) if !t.is_empty() && t.len() <= 2048 => t.clone(), - Some(t) if t.len() > 2048 => { - return axum::response::Html( - "\ -

Error

Invalid callback parameters.

" - .to_string(), - ) - .into_response(); - } - _ => { - return axum::response::Html( - "\ -

Error

Invalid callback parameters.

" - .to_string(), - ) - .into_response(); - } - }; - // Validate team_id format: empty or T followed by alphanumeric (max 20 chars) let team_id = params.get("team_id").cloned().unwrap_or_default(); if !team_id.is_empty() { @@ -869,30 +941,16 @@ async fn slack_relay_oauth_callback_handler( let _ = ext_mgr.secrets().delete(&state.user_id, &state_key).await; let result: Result<(), String> = async { - // Store the stream token as a secret - let token_key = format!("relay:{}:stream_token", DEFAULT_RELAY_NAME); - let _ = ext_mgr.secrets().delete(&state.user_id, &token_key).await; - ext_mgr - .secrets() - .create( - &state.user_id, - crate::secrets::CreateSecretParams { - name: token_key, - value: secrecy::SecretString::from(stream_token), - provider: Some(provider.clone()), - expires_at: None, - }, - ) - .await - .map_err(|e| format!("Failed to store stream token: {}", e))?; + let store = state.store.as_ref().ok_or_else(|| { + "Relay activation requires persistent settings storage; no-db mode is unsupported." + .to_string() + })?; // Store team_id in settings - if let Some(ref store) = state.store { - let team_id_key = format!("relay:{}:team_id", DEFAULT_RELAY_NAME); - let _ = store - .set_setting(&state.user_id, &team_id_key, &serde_json::json!(team_id)) - .await; - } + let team_id_key = format!("relay:{}:team_id", DEFAULT_RELAY_NAME); + let _ = store + .set_setting(&state.user_id, &team_id_key, &serde_json::json!(team_id)) + .await; // Activate the relay channel ext_mgr @@ -3516,7 +3574,7 @@ mod tests { // Callback without state param should be rejected let req = axum::http::Request::builder() - .uri("/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack") + .uri("/oauth/slack/callback?team_id=T123&provider=slack") .body(Body::empty()) .expect("request"); @@ -3560,7 +3618,7 @@ mod tests { // Callback with wrong state param let req = axum::http::Request::builder() - .uri("/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack&state=wrong-nonce") + .uri("/oauth/slack/callback?team_id=T123&provider=slack&state=wrong-nonce") .body(Body::empty()) .expect("request"); @@ -3608,7 +3666,7 @@ mod tests { // we just verify it doesn't return a CSRF error. let req = axum::http::Request::builder() .uri(format!( - "/oauth/slack/callback?stream_token=tok123&team_id=T123&provider=slack&state={}", + "/oauth/slack/callback?team_id=T123&provider=slack&state={}", nonce )) .body(Body::empty()) diff --git a/src/config/relay.rs b/src/config/relay.rs index d45de18820..e1ba82216a 100644 --- a/src/config/relay.rs +++ b/src/config/relay.rs @@ -7,7 +7,7 @@ use secrecy::SecretString; pub struct RelayConfig { /// Base URL of the channel-relay service (e.g., `http://localhost:3001`). pub url: String, - /// API key for authenticated channel-relay endpoints. + /// Bearer token for authenticated channel-relay endpoints (`sk-agent-*`). pub api_key: SecretString, /// Override for the OAuth callback URL (e.g., a tunnel URL). pub callback_url: Option, @@ -15,12 +15,8 @@ pub struct RelayConfig { pub instance_id: Option, /// HTTP request timeout in seconds (default: 30). pub request_timeout_secs: u64, - /// SSE stream long-poll timeout in seconds (default: 86400 = 24 h). - pub stream_timeout_secs: u64, - /// Initial exponential backoff in milliseconds (default: 1000). - pub backoff_initial_ms: u64, - /// Maximum exponential backoff in milliseconds (default: 60000). - pub backoff_max_ms: u64, + /// Path for the webhook callback endpoint (default: `/relay/events`). + pub webhook_path: String, } impl std::fmt::Debug for RelayConfig { @@ -31,9 +27,7 @@ impl std::fmt::Debug for RelayConfig { .field("callback_url", &self.callback_url) .field("instance_id", &self.instance_id) .field("request_timeout_secs", &self.request_timeout_secs) - .field("stream_timeout_secs", &self.stream_timeout_secs) - .field("backoff_initial_ms", &self.backoff_initial_ms) - .field("backoff_max_ms", &self.backoff_max_ms) + .field("webhook_path", &self.webhook_path) .finish() } } @@ -41,8 +35,10 @@ impl std::fmt::Debug for RelayConfig { impl RelayConfig { /// Load relay config from environment variables. /// - /// Returns `None` if either `CHANNEL_RELAY_URL` or `CHANNEL_RELAY_API_KEY` - /// is not set, making the relay integration opt-in. + /// Returns `None` if either of the required env vars (`CHANNEL_RELAY_URL`, + /// `CHANNEL_RELAY_API_KEY`) is not set, making the relay integration opt-in. + /// The signing secret is fetched from channel-relay at activation time via + /// the authenticated `/relay/signing-secret` endpoint — no env var required. pub fn from_env() -> Option { Self::from_env_reader(|key| std::env::var(key).ok()) } @@ -55,9 +51,7 @@ impl RelayConfig { callback_url: None, instance_id: None, request_timeout_secs: 30, - stream_timeout_secs: 86400, - backoff_initial_ms: 1000, - backoff_max_ms: 60000, + webhook_path: "/relay/events".into(), } } @@ -73,15 +67,7 @@ impl RelayConfig { request_timeout_secs: env("RELAY_REQUEST_TIMEOUT_SECS") .and_then(|v| v.parse().ok()) .unwrap_or(30), - stream_timeout_secs: env("RELAY_STREAM_TIMEOUT_SECS") - .and_then(|v| v.parse().ok()) - .unwrap_or(86400), - backoff_initial_ms: env("RELAY_BACKOFF_INITIAL_MS") - .and_then(|v| v.parse().ok()) - .unwrap_or(1000), - backoff_max_ms: env("RELAY_BACKOFF_MAX_MS") - .and_then(|v| v.parse().ok()) - .unwrap_or(60000), + webhook_path: env("RELAY_WEBHOOK_PATH").unwrap_or_else(|| "/relay/events".into()), }) } } @@ -97,7 +83,21 @@ mod tests { } #[test] - fn from_env_reader_loads_defaults() { + fn from_env_reader_requires_only_url_and_api_key() { + // Signing secret is fetched at activation time — only URL + API key needed. + let config = RelayConfig::from_env_reader(|key| match key { + "CHANNEL_RELAY_URL" => Some("http://localhost:3001".into()), + "CHANNEL_RELAY_API_KEY" => Some("test-key".into()), + _ => None, + }); + assert!( + config.is_some(), + "relay config should load with just URL + API key" + ); + } + + #[test] + fn from_env_reader_loads_all_required() { let config = RelayConfig::from_env_reader(|key| match key { "CHANNEL_RELAY_URL" => Some("http://localhost:3001".into()), "CHANNEL_RELAY_API_KEY" => Some("test-key".into()), @@ -107,9 +107,7 @@ mod tests { assert_eq!(config.url, "http://localhost:3001"); assert_eq!(config.request_timeout_secs, 30); - assert_eq!(config.stream_timeout_secs, 86400); - assert_eq!(config.backoff_initial_ms, 1000); - assert_eq!(config.backoff_max_ms, 60000); + assert_eq!(config.webhook_path, "/relay/events"); assert!(config.callback_url.is_none()); assert!(config.instance_id.is_none()); } @@ -122,9 +120,7 @@ mod tests { "IRONCLAW_OAUTH_CALLBACK_URL" => Some("https://tunnel.example.com".into()), "IRONCLAW_INSTANCE_ID" => Some("my-instance".into()), "RELAY_REQUEST_TIMEOUT_SECS" => Some("60".into()), - "RELAY_STREAM_TIMEOUT_SECS" => Some("43200".into()), - "RELAY_BACKOFF_INITIAL_MS" => Some("2000".into()), - "RELAY_BACKOFF_MAX_MS" => Some("120000".into()), + "RELAY_WEBHOOK_PATH" => Some("/custom/events".into()), _ => None, }) .expect("config should be Some"); @@ -135,9 +131,7 @@ mod tests { ); assert_eq!(config.instance_id.as_deref(), Some("my-instance")); assert_eq!(config.request_timeout_secs, 60); - assert_eq!(config.stream_timeout_secs, 43200); - assert_eq!(config.backoff_initial_ms, 2000); - assert_eq!(config.backoff_max_ms, 120000); + assert_eq!(config.webhook_path, "/custom/events"); } #[test] @@ -148,7 +142,7 @@ mod tests { } #[test] - fn debug_redacts_api_key() { + fn debug_redacts_secrets() { let config = RelayConfig::from_values("http://localhost:3001", "super-secret"); let debug = format!("{:?}", config); assert!(debug.contains("[REDACTED]")); diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index 00d787a5a3..fbc06d5d93 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -361,6 +361,18 @@ pub struct ExtensionManager { /// Relay config captured at startup. Used by `auth_channel_relay` and /// `activate_channel_relay` instead of re-reading env vars. relay_config: Option, + /// Shared event sender for the relay webhook endpoint. + /// Populated by `activate_channel_relay`, consumed by the web gateway's + /// `/relay/events` handler. + relay_event_tx: Arc< + tokio::sync::Mutex< + Option>, + >, + >, + /// Per-instance callback signing secret fetched from channel-relay at activation. + /// Stored here so the web gateway can verify incoming callbacks without + /// any env var or shared secret. + relay_signing_secret_cache: Arc>>>, /// When `true`, OAuth flows always return an auth URL to the caller /// instead of opening a browser on the server via `open::that()`. /// Set by the web gateway at startup via `enable_gateway_mode()`. @@ -446,6 +458,8 @@ impl ExtensionManager { pending_oauth_flows: crate::cli::oauth_defaults::new_pending_oauth_registry(), gateway_token: std::env::var("GATEWAY_AUTH_TOKEN").ok(), relay_config: crate::config::RelayConfig::from_env(), + relay_event_tx: Arc::new(tokio::sync::Mutex::new(None)), + relay_signing_secret_cache: Arc::new(std::sync::Mutex::new(None)), gateway_mode: std::sync::atomic::AtomicBool::new(false), gateway_base_url: RwLock::new(None), pending_telegram_verification: RwLock::new(HashMap::new()), @@ -564,6 +578,33 @@ impl ExtensionManager { }) } + /// Get the shared relay event sender for the webhook endpoint. + pub fn relay_event_tx( + &self, + ) -> Arc< + tokio::sync::Mutex< + Option>, + >, + > { + Arc::clone(&self.relay_event_tx) + } + + /// Get the per-instance callback signing secret for webhook signature verification. + /// + /// Returns the secret that was fetched from channel-relay's + /// `/relay/signing-secret` endpoint during `activate_channel_relay`. + /// Returns `None` if the relay channel has not been activated yet. + pub fn relay_signing_secret(&self) -> Option> { + self.relay_signing_secret_cache.lock().ok()?.clone() + } + + async fn clear_relay_webhook_state(&self) { + *self.relay_event_tx.lock().await = None; + if let Ok(mut cache) = self.relay_signing_secret_cache.lock() { + *cache = None; + } + } + /// Inject a registry entry for testing. The entry is added to the discovery /// cache so it appears in search results alongside built-in entries. pub async fn inject_registry_entry(&self, entry: crate::extensions::RegistryEntry) { @@ -753,12 +794,25 @@ impl ExtensionManager { *self.relay_channel_manager.write().await = Some(channel_manager); } - /// Check if a channel name corresponds to a relay extension (has stored stream token). + /// Check if a channel name corresponds to a relay extension (has stored team_id + /// or is tracked in the installed relay extensions set). pub async fn is_relay_channel(&self, name: &str) -> bool { - self.secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false) + // Check in-memory installed set first (supports no-store mode) + if self.installed_relay_extensions.read().await.contains(name) { + return true; + } + // Then check persistent settings + if let Some(ref store) = self.store { + let team_id_key = format!("relay:{}:team_id", name); + store + .get_setting(&self.user_id, &team_id_key) + .await + .ok() + .flatten() + .is_some() + } else { + false + } } /// Restore persisted relay channels after startup. @@ -1167,11 +1221,7 @@ impl ExtensionManager { let active_names = self.active_channel_names.read().await; for name in installed.iter() { let active = active_names.contains(name); - let has_token = self - .secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false); + let has_token = self.is_relay_channel(name).await; let registry_entry = self .registry .get_with_kind(name, Some(ExtensionKind::ChannelRelay)) @@ -1365,19 +1415,26 @@ impl ExtensionManager { // Remove from active channels self.active_channel_names.write().await.remove(name); self.persist_active_channels().await; + self.activation_errors.write().await.remove(name); - // Remove stored stream token - let _ = self - .secrets - .delete(&self.user_id, &format!("relay:{}:stream_token", name)) - .await; + // Remove stored team_id + if let Some(ref store) = self.store { + let _ = store + .delete_setting(&self.user_id, &format!("relay:{}:team_id", name)) + .await; + } - // Shut down the channel (check both runtime paths for WASM+relay and relay-only modes) + // Stop webhook traffic before removing the channel from the managers. + self.clear_relay_webhook_state().await; + + // Shut down and remove the channel (check both runtime paths for + // WASM+relay and relay-only modes). let mut shut_down = false; if let Some(ref rt) = *self.channel_runtime.read().await && let Some(channel) = rt.channel_manager.get_channel(name).await { let _ = channel.shutdown().await; + rt.channel_manager.remove(name).await; shut_down = true; } if !shut_down @@ -1385,6 +1442,7 @@ impl ExtensionManager { && let Some(channel) = cm.get_channel(name).await { let _ = channel.shutdown().await; + cm.remove(name).await; } Ok(format!("Removed channel relay '{}'", name)) @@ -3880,25 +3938,14 @@ impl ExtensionManager { /// For Telegram: accepts a bot token, registers it with channel-relay, /// and stores the returned stream token. async fn auth_channel_relay(&self, name: &str) -> Result { - // Check if already authenticated (stream token exists) - let token_key = format!("relay:{}:stream_token", name); - if self - .secrets - .exists(&self.user_id, &token_key) - .await - .unwrap_or(false) - { + // Check if already authenticated (has stored team_id) + if self.is_relay_channel(name).await { return Ok(AuthResult::authenticated(name, ExtensionKind::ChannelRelay)); } // Use relay config captured at startup let relay_config = self.relay_config()?; - let instance_id = self.relay_instance_id(relay_config); - let user_id_uuid = std::env::var("IRONCLAW_USER_ID").unwrap_or_else(|_| { - uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_DNS, self.user_id.as_bytes()).to_string() - }); - let client = crate::channels::relay::RelayClient::new( relay_config.url.clone(), relay_config.api_key.clone(), @@ -3906,22 +3953,11 @@ impl ExtensionManager { ) .map_err(|e| ExtensionError::Config(e.to_string()))?; - // OAuth redirect flow - let callback_base = self - .tunnel_url - .clone() - .or_else(|| relay_config.callback_url.clone()) - .unwrap_or_else(|| { - let host = std::env::var("GATEWAY_HOST").unwrap_or_else(|_| "127.0.0.1".into()); - let port = std::env::var("GATEWAY_PORT") - .unwrap_or_else(|_| crate::config::DEFAULT_GATEWAY_PORT.to_string()); - format!("http://{}:{}", host, port) - }); - - // Generate CSRF nonce for OAuth state parameter + // Generate CSRF nonce — IronClaw validates this on the callback to ensure + // the OAuth completion is legitimate. Channel-relay embeds it in the signed + // state and appends it to the post-OAuth redirect URL. let state_nonce = uuid::Uuid::new_v4().to_string(); let state_key = format!("relay:{}:oauth_state", name); - // Delete any stale nonce before storing the new one let _ = self.secrets.delete(&self.user_id, &state_key).await; self.secrets .create( @@ -3931,15 +3967,9 @@ impl ExtensionManager { .await .map_err(|e| ExtensionError::AuthFailed(format!("Failed to store OAuth state: {e}")))?; - let callback_url = format!( - "{}/oauth/slack/callback?state={}", - callback_base, state_nonce - ); - - match client - .initiate_oauth(&instance_id, &user_id_uuid, &callback_url) - .await - { + // Channel-relay derives all URLs from trusted instance_url in chat-api. + // We only pass the nonce for CSRF validation on the callback. + match client.initiate_oauth(Some(&state_nonce)).await { Ok(auth_url) => Ok(AuthResult::awaiting_authorization( name, ExtensionKind::ChannelRelay, @@ -3952,29 +3982,17 @@ impl ExtensionManager { /// Activate a channel-relay extension. async fn activate_channel_relay(&self, name: &str) -> Result { - let token_key = format!("relay:{}:stream_token", name); let team_id_key = format!("relay:{}:team_id", name); - // Check if we have a stream token - let stream_token = match self.secrets.get_decrypted(&self.user_id, &token_key).await { - Ok(secret) => secret.expose().to_string(), - Err(_) => { - return Err(ExtensionError::AuthRequired); - } - }; - - // Get team_id from settings - let team_id = if let Some(ref store) = self.store { - store - .get_setting(&self.user_id, &team_id_key) - .await - .ok() - .flatten() - .and_then(|v| v.as_str().map(|s| s.to_string())) - .unwrap_or_default() - } else { - String::new() - }; + let store = self.store.as_ref().ok_or(ExtensionError::AuthRequired)?; + let team_id = store + .get_setting(&self.user_id, &team_id_key) + .await + .ok() + .flatten() + .and_then(|v| v.as_str().map(|s| s.to_string())) + .filter(|s| !s.is_empty()) + .ok_or(ExtensionError::AuthRequired)?; // Use relay config captured at startup let relay_config = self.relay_config()?; @@ -3988,18 +4006,29 @@ impl ExtensionManager { ) .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))?; + // Fetch the per-instance signing secret from channel-relay. + // This must succeed — there is no fallback. + let signing_secret = client.get_signing_secret(&team_id).await.map_err(|e| { + ExtensionError::Config(format!("Failed to fetch relay signing secret: {e}")) + })?; + + // Create the event channel for webhook callbacks + let (event_tx, event_rx) = tokio::sync::mpsc::channel(64); + let channel = crate::channels::relay::RelayChannel::new_with_provider( - client, + client.clone(), crate::channels::relay::channel::RelayProvider::Slack, - stream_token, - team_id, - instance_id, - self.user_id.clone(), - ) - .with_timeouts( - relay_config.stream_timeout_secs, - relay_config.backoff_initial_ms, - relay_config.backoff_max_ms, + team_id.clone(), + instance_id.clone(), + event_tx.clone(), + event_rx, + ); + + // Callback URL is now set during OAuth flow, not via PUT /callbacks. + // The relay webhook endpoint path is still needed for the web gateway. + tracing::info!( + webhook_path = %relay_config.webhook_path, + "Relay channel activated (callback URL set during OAuth)" ); // Hot-add to channel manager @@ -4013,6 +4042,13 @@ impl ExtensionManager { .await .map_err(|e| ExtensionError::ActivationFailed(e.to_string()))?; + if let Ok(mut cache) = self.relay_signing_secret_cache.lock() { + *cache = Some(signing_secret); + } + + // Store the event sender so the web gateway's relay webhook endpoint can push events + *self.relay_event_tx.lock().await = Some(event_tx); + // Mark as active self.active_channel_names .write() @@ -4035,11 +4071,11 @@ impl ExtensionManager { /// Activate a channel-relay extension from stored credentials (for startup reconnect). pub async fn activate_stored_relay(&self, name: &str) -> Result<(), ExtensionError> { + self.activate_channel_relay(name).await?; self.installed_relay_extensions .write() .await .insert(name.to_string()); - self.activate_channel_relay(name).await?; Ok(()) } @@ -4070,13 +4106,8 @@ impl ExtensionManager { if self.installed_relay_extensions.read().await.contains(name) { return Ok(ExtensionKind::ChannelRelay); } - // Also check if there's a stored stream token (persisted across restarts) - if self - .secrets - .exists(&self.user_id, &format!("relay:{}:stream_token", name)) - .await - .unwrap_or(false) - { + // Also check if there's a stored team_id (persisted across restarts) + if self.is_relay_channel(name).await { return Ok(ExtensionKind::ChannelRelay); } @@ -6351,24 +6382,24 @@ mod tests { } #[tokio::test] - async fn test_is_relay_channel_detects_stored_token() { + async fn test_is_relay_channel_returns_false_without_store() { let dir = tempfile::tempdir().expect("temp dir"); let mgr = make_test_manager(None, dir.path().to_path_buf()); - // No token stored → not a relay channel + // With no DB store, is_relay_channel always returns false assert!(!mgr.is_relay_channel("slack-relay").await); + } - // Store a stream token - mgr.secrets - .create( - "test", - crate::secrets::CreateSecretParams::new("relay:slack-relay:stream_token", "tok123"), - ) - .await - .expect("store token"); + #[tokio::test] + async fn test_activate_channel_relay_without_store_returns_auth_required() { + let dir = tempfile::tempdir().expect("temp dir"); + let mgr = make_test_manager(None, dir.path().to_path_buf()); - // Now it's detected as a relay channel - assert!(mgr.is_relay_channel("slack-relay").await); + let err = mgr.activate_channel_relay("slack-relay").await.unwrap_err(); + assert!( + matches!(err, ExtensionError::AuthRequired), + "expected AuthRequired, got: {err:?}" + ); } #[tokio::test] @@ -6384,18 +6415,25 @@ mod tests { cm.add(Box::new(stub)).await; mgr.set_relay_channel_manager(Arc::clone(&cm)).await; - // Mark as installed + store a token so determine_installed_kind finds it + // Mark as installed + store team_id so determine_installed_kind finds it mgr.installed_relay_extensions .write() .await .insert("slack-relay".to_string()); - mgr.secrets - .create( - "test", - crate::secrets::CreateSecretParams::new("relay:slack-relay:stream_token", "tok123"), - ) - .await - .expect("store token"); + *mgr.relay_event_tx.lock().await = Some(tokio::sync::mpsc::channel(1).0); + if let Ok(mut cache) = mgr.relay_signing_secret_cache.lock() { + *cache = Some(vec![9u8; 32]); + } + if let Some(ref store) = mgr.store { + store + .set_setting( + "test", + "relay:slack-relay:team_id", + &serde_json::json!("T123"), + ) + .await + .expect("store team_id"); + } // Verify channel exists before removal assert!(cm.get_channel("slack-relay").await.is_some()); @@ -6412,6 +6450,18 @@ mod tests { .contains("slack-relay"), "Should be removed from installed set" ); + assert!( + mgr.relay_event_tx.lock().await.is_none(), + "relay event sender should be cleared on remove" + ); + assert!( + mgr.relay_signing_secret().is_none(), + "relay signing secret cache should be cleared on remove" + ); + assert!( + cm.get_channel("slack-relay").await.is_none(), + "relay channel should be removed from the channel manager" + ); } #[tokio::test] diff --git a/tests/relay_integration.rs b/tests/relay_integration.rs index 8479cd6730..0a053885d9 100644 --- a/tests/relay_integration.rs +++ b/tests/relay_integration.rs @@ -2,18 +2,12 @@ //! //! Uses real HTTP servers on random ports (no mock framework). -use std::convert::Infallible; -use std::sync::atomic::{AtomicUsize, Ordering}; - use axum::{ Json, Router, extract::Query, - http::StatusCode, - response::sse::{Event, KeepAlive, Sse}, routing::{get, post}, }; -use futures::stream; -use ironclaw::channels::relay::client::{RelayClient, RelayError}; +use ironclaw::channels::relay::client::{ChannelEvent, RelayClient}; use secrecy::SecretString; use serde::Deserialize; use tokio::net::TcpListener; @@ -37,109 +31,79 @@ fn test_client(base_url: &str) -> RelayClient { .expect("client build") } -// ── SSE stream mock ───────────────────────────────────────────────────── +// ── Signing secret fetch ───────────────────────────────────────────────── #[tokio::test] -async fn test_sse_stream_receives_events() { +async fn test_get_signing_secret_returns_decoded_bytes() { + let secret_hex = hex::encode([1u8; 32]); + let secret_hex_clone = secret_hex.clone(); let app = Router::new().route( - "/stream", - get( - |Query(params): Query>| async move { - // Verify token is passed - assert!(params.contains_key("token")); - - let events = vec![ - Ok::<_, Infallible>( - Event::default().event("message").data( - serde_json::json!({ - "event_type": "message", - "provider": "slack", - "provider_scope": "T123", - "channel_id": "C456", - "sender_id": "U789", - "content": "hello world" - }) - .to_string(), - ), - ), - Ok(Event::default().event("message").data( - serde_json::json!({ - "event_type": "direct_message", - "provider": "slack", - "provider_scope": "T123", - "channel_id": "D001", - "sender_id": "U789", - "content": "dm text" - }) - .to_string(), - )), - ]; - - Sse::new(stream::iter(events)).keep_alive(KeepAlive::default()) - }, - ), + "/relay/signing-secret", + get(move || { + let s = secret_hex_clone.clone(); + async move { Json(serde_json::json!({"signing_secret": s})) } + }), ); let base_url = start_server(app).await; let client = test_client(&base_url); - let (mut event_stream, handle) = client.connect_stream("test-token", 30).await.unwrap(); + let secret = client.get_signing_secret("T123").await.unwrap(); + assert_eq!(secret, vec![1u8; 32]); +} - use futures::StreamExt; - let first = event_stream.next().await.expect("first event"); - assert_eq!(first.event_type, "message"); - assert_eq!(first.text(), "hello world"); - assert_eq!(first.team_id(), "T123"); +#[tokio::test] +async fn test_get_signing_secret_404_returns_error() { + let app = Router::new().route( + "/relay/signing-secret", + get(|| async { (axum::http::StatusCode::NOT_FOUND, "not found") }), + ); - let second = event_stream.next().await.expect("second event"); - assert_eq!(second.event_type, "direct_message"); - assert_eq!(second.text(), "dm text"); + let base_url = start_server(app).await; + let client = test_client(&base_url); - handle.abort(); + let result = client.get_signing_secret("T123").await; + assert!(result.is_err()); } -// ── Token renewal flow ────────────────────────────────────────────────── - #[tokio::test] -async fn test_token_expired_returns_error() { - let app = Router::new().route("/stream", get(|| async { StatusCode::UNAUTHORIZED })); +async fn test_get_signing_secret_invalid_hex_returns_protocol_error() { + let app = Router::new().route( + "/relay/signing-secret", + get(|| async { Json(serde_json::json!({"signing_secret": "not-hex"})) }), + ); let base_url = start_server(app).await; let client = test_client(&base_url); - match client.connect_stream("expired-token", 30).await { - Err(RelayError::TokenExpired) => {} // expected - Err(other) => panic!("expected TokenExpired, got: {other}"), - Ok(_) => panic!("expected error, got Ok"), - } + let err = client + .get_signing_secret("T123") + .await + .unwrap_err() + .to_string(); + assert!(err.contains("invalid signing_secret hex"), "got: {err}"); } #[tokio::test] -async fn test_token_renewal() { - let call_count = std::sync::Arc::new(AtomicUsize::new(0)); - let call_count_clone = call_count.clone(); - +async fn test_get_signing_secret_wrong_length_returns_protocol_error() { + let short_secret_hex = hex::encode([7u8; 31]); let app = Router::new().route( - "/stream/renew", - post(move |Json(body): Json| { - let count = call_count_clone.clone(); - async move { - count.fetch_add(1, Ordering::SeqCst); - assert!(body.get("instance_id").is_some()); - assert!(body.get("user_id").is_some()); - Json(serde_json::json!({ - "stream_token": "renewed-token-123" - })) - } + "/relay/signing-secret", + get(move || { + let s = short_secret_hex.clone(); + async move { Json(serde_json::json!({"signing_secret": s})) } }), ); let base_url = start_server(app).await; let client = test_client(&base_url); - let new_token = client.renew_token("inst-1", "user-1").await.unwrap(); - assert_eq!(new_token, "renewed-token-123"); - assert_eq!(call_count.load(Ordering::SeqCst), 1); + let err = client + .get_signing_secret("T123") + .await + .unwrap_err() + .to_string(); + assert!(err.contains("expected 32 bytes"), "got: {err}"); } // ── Proxy call ────────────────────────────────────────────────────────── @@ -171,7 +135,7 @@ async fn test_proxy_provider_sends_correct_payload() { "text": "Hello from test", }); let resp = client - .proxy_provider("slack", "T123", "chat.postMessage", body, None) + .proxy_provider("slack", "T123", "chat.postMessage", body) .await .unwrap(); assert_eq!(resp["ok"], true); @@ -200,18 +164,18 @@ async fn test_list_connections() { assert!(!conns[1].connected); } -// ── API key header ────────────────────────────────────────────────────── +// ── Bearer token auth ──────────────────────────────────────────────────── #[tokio::test] -async fn test_api_key_sent_in_header() { +async fn test_bearer_token_sent_in_header() { let app = Router::new().route( "/connections", get(|headers: axum::http::HeaderMap| async move { - let key = headers - .get("X-API-Key") + let auth = headers + .get("authorization") .and_then(|v| v.to_str().ok()) .unwrap_or(""); - assert_eq!(key, "test-api-key"); + assert_eq!(auth, "Bearer test-api-key"); Json(serde_json::json!([])) }), ); @@ -233,82 +197,10 @@ fn test_relay_client_new_succeeds() { assert!(client.is_ok()); } -// ── SSE UTF-8 chunk boundary ──────────────────────────────────────────── - -/// Verify that multi-byte UTF-8 characters split across SSE chunks are -/// not corrupted (no U+FFFD replacement characters). -#[tokio::test] -async fn test_sse_stream_preserves_multibyte_utf8_across_chunks() { - use std::sync::atomic::{AtomicBool, Ordering}; - - let sent = std::sync::Arc::new(AtomicBool::new(false)); - let sent_clone = sent.clone(); - - let app = Router::new().route( - "/stream", - get(move |_: Query>| { - let sent = sent_clone.clone(); - async move { - // Build SSE payload with emoji that will be split mid-character - let event_data = serde_json::json!({ - "event_type": "message", - "provider": "slack", - "provider_scope": "T1", - "channel_id": "C1", - "sender_id": "U1", - "content": "hello 🦀 world" - }); - let payload = format!("event: message\ndata: {}\n\n", event_data); - let bytes = payload.into_bytes(); - - // Split in the middle of the 4-byte crab emoji - let crab_pos = bytes - .windows(4) - .position(|w| w == [0xF0, 0x9F, 0xA6, 0x80]) - .unwrap(); - let split_at = crab_pos + 2; - - let chunk1 = bytes[..split_at].to_vec(); - let chunk2 = bytes[split_at..].to_vec(); - - sent.store(true, Ordering::SeqCst); - - let events = vec![ - Ok::<_, Infallible>(axum::body::Bytes::from(chunk1)), - Ok(axum::body::Bytes::from(chunk2)), - ]; - - axum::response::Response::builder() - .header("content-type", "text/event-stream") - .body(axum::body::Body::from_stream(stream::iter(events))) - .unwrap() - } - }), - ); - - let base_url = start_server(app).await; - let client = test_client(&base_url); - - let (mut event_stream, handle) = client.connect_stream("tok", 30).await.unwrap(); - - use futures::StreamExt; - let event = event_stream.next().await.expect("should get event"); - assert_eq!( - event.text(), - "hello 🦀 world", - "emoji should not be corrupted" - ); - assert!(sent.load(Ordering::SeqCst)); - - handle.abort(); -} - // ── Channel event field validation ────────────────────────────────────── #[test] fn test_channel_event_missing_fields_detected() { - use ironclaw::channels::relay::client::ChannelEvent; - // Event with empty sender_id should be detectable let json = r#"{"event_type": "message", "provider_scope": "T1", "channel_id": "C1", "sender_id": "", "content": "test"}"#; let event: ChannelEvent = serde_json::from_str(json).unwrap();