diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 0b38db1eb49..5e0656fbe27 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -887,6 +887,7 @@ dependencies = [ "codex-file-search", "codex-login", "codex-protocol", + "codex-rmcp-client", "codex-utils-json-to-toml", "core_test_support", "mcp-types", diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 28583667393..c62acc88324 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -139,6 +139,11 @@ client_request_definitions! { response: v2::ModelListResponse, }, + McpServerOauthLogin => "mcpServer/oauth/login" { + params: v2::McpServerOauthLoginParams, + response: v2::McpServerOauthLoginResponse, + }, + McpServersList => "mcpServers/list" { params: v2::ListMcpServersParams, response: v2::ListMcpServersResponse, @@ -524,6 +529,7 @@ server_notification_definitions! { CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification), FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification), McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification), + McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification), AccountUpdated => "account/updated" (v2::AccountUpdatedNotification), AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification), ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification), diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index ea70b805b0a..dbef55ed159 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -688,6 +688,26 @@ pub struct ListMcpServersResponse { pub next_cursor: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginParams { + pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub scopes: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub timeout_secs: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginResponse { + pub authorization_url: String, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification { pub message: String, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginCompletedNotification { + pub name: String, + pub success: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub error: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 99d5a7a1410..e4a326a2c3e 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -26,6 +26,7 @@ codex-login = { workspace = true } codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } codex-feedback = { workspace = true } +codex-rmcp-client = { workspace = true } codex-utils-json-to-toml = { workspace = true } chrono = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 65721a698ef..79d11e16548 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse; use codex_app_server_protocol::LogoutAccountResponse; use codex_app_server_protocol::LogoutChatGptResponse; use codex_app_server_protocol::McpServer; +use codex_app_server_protocol::McpServerOauthLoginCompletedNotification; +use codex_app_server_protocol::McpServerOauthLoginParams; +use codex_app_server_protocol::McpServerOauthLoginResponse; use codex_app_server_protocol::ModelListParams; use codex_app_server_protocol::ModelListResponse; use codex_app_server_protocol::NewConversationParams; @@ -115,6 +118,7 @@ use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; use codex_core::config::edit::ConfigEditsBuilder; +use codex_core::config::types::McpServerTransportConfig; use codex_core::config_loader::load_config_as_toml; use codex_core::default_client::get_codex_user_agent; use codex_core::exec::ExecParams; @@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::SessionMetaLine; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use codex_protocol::user_input::UserInput as CoreInputItem; +use codex_rmcp_client::perform_oauth_login_return_url; use codex_utils_json_to_toml::json_to_toml; use std::collections::HashMap; use std::collections::HashSet; @@ -161,6 +166,7 @@ use std::time::Duration; use tokio::select; use tokio::sync::Mutex; use tokio::sync::oneshot; +use toml::Value as TomlValue; use tracing::error; use tracing::info; use tracing::warn; @@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, + cli_overrides: Vec<(String, TomlValue)>, conversation_listeners: HashMap>, active_login: Arc>>, // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. @@ -244,6 +251,7 @@ impl CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, + cli_overrides: Vec<(String, TomlValue)>, feedback: CodexFeedback, ) -> Self { Self { @@ -252,6 +260,7 @@ impl CodexMessageProcessor { outgoing, codex_linux_sandbox_exe, config, + cli_overrides, conversation_listeners: HashMap::new(), active_login: Arc::new(Mutex::new(None)), pending_interrupts: Arc::new(Mutex::new(HashMap::new())), @@ -261,6 +270,16 @@ impl CodexMessageProcessor { } } + async fn load_latest_config(&self) -> Result { + Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default()) + .await + .map_err(|err| JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to reload config: {err}"), + data: None, + }) + } + fn review_request_from_target( target: ApiReviewTarget, ) -> Result<(ReviewRequest, String), JSONRPCErrorError> { @@ -369,6 +388,9 @@ impl CodexMessageProcessor { ClientRequest::ModelList { request_id, params } => { self.list_models(request_id, params).await; } + ClientRequest::McpServerOauthLogin { request_id, params } => { + self.mcp_server_oauth_login(request_id, params).await; + } ClientRequest::McpServersList { request_id, params } => { self.list_mcp_servers(request_id, params).await; } @@ -1916,6 +1938,115 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } + async fn mcp_server_oauth_login( + &self, + request_id: RequestId, + params: McpServerOauthLoginParams, + ) { + let config = match self.load_latest_config().await { + Ok(config) => config, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + if !config.features.enabled(Feature::RmcpClient) { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + + let McpServerOauthLoginParams { + name, + scopes, + timeout_secs, + } = params; + + let Some(server) = config.mcp_servers.get(&name) else { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("No MCP server named '{name}' found."), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + }; + + let (url, http_headers, env_http_headers) = match &server.transport { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => (url.clone(), http_headers.clone(), env_http_headers.clone()), + _ => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: "OAuth login is only supported for streamable HTTP servers." + .to_string(), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + match perform_oauth_login_return_url( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers, + env_http_headers, + scopes.as_deref().unwrap_or_default(), + timeout_secs, + ) + .await + { + Ok(handle) => { + let authorization_url = handle.authorization_url().to_string(); + let notification_name = name.clone(); + let outgoing = Arc::clone(&self.outgoing); + let conversation_manager = Arc::clone(&self.conversation_manager); + + tokio::spawn(async move { + let (success, error) = match handle.wait().await { + Ok(()) => (true, None), + Err(err) => (false, Some(err.to_string())), + }; + + if success { + conversation_manager.mark_mcp_oauth_success(Utc::now().timestamp()); + } + + let notification = ServerNotification::McpServerOauthLoginCompleted( + McpServerOauthLoginCompletedNotification { + name: notification_name, + success, + error, + }, + ); + outgoing.send_server_notification(notification).await; + }); + + let response = McpServerOauthLoginResponse { authorization_url }; + self.outgoing.send_response(request_id, response).await; + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to login to MCP server '{name}': {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + } + } + } + async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) { let snapshot = collect_mcp_snapshot(self.config.as_ref()).await; diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 90560e9b3c5..6a6cf5edb25 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -59,6 +59,7 @@ impl MessageProcessor { outgoing.clone(), codex_linux_sandbox_exe, Arc::clone(&config), + cli_overrides.clone(), feedback, ); let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides); diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 042ae1a37a5..279e3db0c7c 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -55,6 +55,8 @@ use mcp_types::ReadResourceResult; use mcp_types::RequestId; use serde_json; use serde_json::Value; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::sync::oneshot; @@ -170,6 +172,7 @@ impl Codex { models_manager: Arc, conversation_history: InitialHistory, session_source: SessionSource, + mcp_oauth_refresh_clock: Arc, ) -> CodexResult { let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY); let (tx_event, rx_event) = async_channel::unbounded(); @@ -210,6 +213,7 @@ impl Codex { tx_event.clone(), conversation_history, session_source_clone, + mcp_oauth_refresh_clock.clone(), ) .await .map_err(|e| { @@ -466,6 +470,7 @@ impl Session { } } + #[allow(clippy::too_many_arguments)] async fn new( session_configuration: SessionConfiguration, config: Arc, @@ -474,6 +479,7 @@ impl Session { tx_event: Sender, initial_history: InitialHistory, session_source: SessionSource, + mcp_oauth_refresh_clock: Arc, ) -> anyhow::Result> { debug!( "Configuring session: model={}; provider={:?}", @@ -583,8 +589,11 @@ impl Session { let state = SessionState::new(session_configuration.clone()); let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new( + mcp_oauth_refresh_clock.clone(), + ))), mcp_startup_cancellation_token: CancellationToken::new(), + mcp_oauth_refresh_clock, unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(config.notify.clone()), rollout: Mutex::new(Some(rollout_recorder)), @@ -1386,6 +1395,7 @@ impl Session { server: &str, params: Option, ) -> anyhow::Result { + self.refresh_mcp_clients_if_needed().await?; self.services .mcp_connection_manager .read() @@ -1399,6 +1409,7 @@ impl Session { server: &str, params: Option, ) -> anyhow::Result { + self.refresh_mcp_clients_if_needed().await?; self.services .mcp_connection_manager .read() @@ -1412,6 +1423,7 @@ impl Session { server: &str, params: ReadResourceRequestParams, ) -> anyhow::Result { + self.refresh_mcp_clients_if_needed().await?; self.services .mcp_connection_manager .read() @@ -1426,6 +1438,7 @@ impl Session { tool: &str, arguments: Option, ) -> anyhow::Result { + self.refresh_mcp_clients_if_needed().await?; self.services .mcp_connection_manager .read() @@ -1435,6 +1448,7 @@ impl Session { } pub(crate) async fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> { + self.refresh_mcp_clients_if_needed().await.ok()?; self.services .mcp_connection_manager .read() @@ -1443,6 +1457,42 @@ impl Session { .await } + async fn refresh_mcp_clients_if_needed(&self) -> anyhow::Result<()> { + let current_clock = self.services.mcp_oauth_refresh_clock.load(Ordering::SeqCst); + let last_seen = { + let manager = self.services.mcp_connection_manager.read().await; + manager.last_refresh_seen() + }; + if current_clock <= last_seen { + return Ok(()); + } + + let config = { + let state = self.state.lock().await; + state + .session_configuration + .original_config_do_not_use + .clone() + }; + let store_mode = config.mcp_oauth_credentials_store_mode; + let auth_statuses = compute_auth_statuses(config.mcp_servers.iter(), store_mode).await; + + { + let mut manager = self.services.mcp_connection_manager.write().await; + manager + .refresh_if_needed( + &config.mcp_servers, + store_mode, + auth_statuses, + self.tx_event.clone(), + self.services.mcp_startup_cancellation_token.clone(), + ) + .await; + } + + Ok(()) + } + pub async fn interrupt_task(self: &Arc) { info!("interrupt received: abort current task, if any"); let has_active_turn = { self.active_turn.lock().await.is_some() }; @@ -2882,9 +2932,13 @@ mod tests { let state = SessionState::new(session_configuration.clone()); + let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0)); let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new( + mcp_oauth_refresh_clock.clone(), + ))), mcp_startup_cancellation_token: CancellationToken::new(), + mcp_oauth_refresh_clock, unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), @@ -2964,9 +3018,13 @@ mod tests { let state = SessionState::new(session_configuration.clone()); + let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0)); let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())), + mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new( + mcp_oauth_refresh_clock.clone(), + ))), mcp_startup_cancellation_token: CancellationToken::new(), + mcp_oauth_refresh_clock, unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 670225ead06..efadc90029c 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -51,6 +51,7 @@ pub(crate) async fn run_codex_conversation_interactive( models_manager, initial_history.unwrap_or(InitialHistory::New), SessionSource::SubAgent(SubAgentSource::Review), + parent_session.services.mcp_oauth_refresh_clock.clone(), ) .await?; let codex = Arc::new(codex); diff --git a/codex-rs/core/src/conversation_manager.rs b/codex-rs/core/src/conversation_manager.rs index b1818849eb4..570020e6226 100644 --- a/codex-rs/core/src/conversation_manager.rs +++ b/codex-rs/core/src/conversation_manager.rs @@ -22,6 +22,8 @@ use codex_protocol::protocol::SessionSource; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; use tokio::sync::RwLock; /// Represents a newly created Codex conversation, including the first event @@ -39,6 +41,7 @@ pub struct ConversationManager { auth_manager: Arc, models_manager: Arc, session_source: SessionSource, + mcp_oauth_refresh_clock: Arc, } impl ConversationManager { @@ -48,6 +51,7 @@ impl ConversationManager { auth_manager: auth_manager.clone(), session_source, models_manager: Arc::new(ModelsManager::new(auth_manager)), + mcp_oauth_refresh_clock: Arc::new(AtomicI64::new(0)), } } @@ -65,6 +69,15 @@ impl ConversationManager { self.session_source.clone() } + pub fn mcp_oauth_refresh_clock(&self) -> Arc { + self.mcp_oauth_refresh_clock.clone() + } + + pub fn mark_mcp_oauth_success(&self, timestamp_secs: i64) { + self.mcp_oauth_refresh_clock + .store(timestamp_secs, Ordering::SeqCst); + } + pub async fn new_conversation(&self, config: Config) -> CodexResult { self.spawn_conversation( config, @@ -89,6 +102,7 @@ impl ConversationManager { models_manager, InitialHistory::New, self.session_source.clone(), + self.mcp_oauth_refresh_clock.clone(), ) .await?; self.finalize_spawn(codex, conversation_id).await @@ -166,6 +180,7 @@ impl ConversationManager { self.models_manager.clone(), initial_history, self.session_source.clone(), + self.mcp_oauth_refresh_clock.clone(), ) .await?; self.finalize_spawn(codex, conversation_id).await @@ -207,6 +222,7 @@ impl ConversationManager { self.models_manager.clone(), history, self.session_source.clone(), + self.mcp_oauth_refresh_clock.clone(), ) .await?; diff --git a/codex-rs/core/src/mcp/mod.rs b/codex-rs/core/src/mcp/mod.rs index ed5f2ea69f6..8c538f7153d 100644 --- a/codex-rs/core/src/mcp/mod.rs +++ b/codex-rs/core/src/mcp/mod.rs @@ -1,5 +1,7 @@ pub mod auth; use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicI64; use async_channel::unbounded; use codex_protocol::protocol::McpListToolsResponseEvent; @@ -29,7 +31,8 @@ pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent ) .await; - let mut mcp_connection_manager = McpConnectionManager::default(); + let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0)); + let mut mcp_connection_manager = McpConnectionManager::new(mcp_oauth_refresh_clock); let (tx_event, rx_event) = unbounded(); drop(rx_event); let cancel_token = CancellationToken::new(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 11a90f77a81..b081787517b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -12,6 +12,8 @@ use std::env; use std::ffi::OsString; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; use std::time::Duration; use crate::mcp::auth::McpAuthStatusEntry; @@ -260,13 +262,70 @@ pub struct SandboxState { } /// A thin wrapper around a set of running [`RmcpClient`] instances. -#[derive(Default)] pub(crate) struct McpConnectionManager { clients: HashMap, elicitation_requests: ElicitationRequestManager, + mcp_oauth_refresh_clock: Arc, + last_refresh_seen: AtomicI64, + config_snapshot: HashMap, + store_mode_snapshot: Option, + auth_entries_snapshot: HashMap, } impl McpConnectionManager { + pub(crate) fn new(mcp_oauth_refresh_clock: Arc) -> Self { + Self { + clients: HashMap::new(), + elicitation_requests: ElicitationRequestManager::default(), + mcp_oauth_refresh_clock, + last_refresh_seen: AtomicI64::new(0), + config_snapshot: HashMap::new(), + store_mode_snapshot: None, + auth_entries_snapshot: HashMap::new(), + } + } + + fn update_snapshots( + &mut self, + mcp_servers: &HashMap, + store_mode: OAuthCredentialsStoreMode, + auth_entries: &HashMap, + ) { + self.config_snapshot = mcp_servers.clone(); + self.store_mode_snapshot = Some(store_mode); + self.auth_entries_snapshot = auth_entries.clone(); + let now = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst); + self.last_refresh_seen.store(now, Ordering::SeqCst); + } + + pub(crate) fn last_refresh_seen(&self) -> i64 { + self.last_refresh_seen.load(Ordering::SeqCst) + } + + pub(crate) async fn refresh_if_needed( + &mut self, + config: &HashMap, + store_mode: OAuthCredentialsStoreMode, + auth_entries: HashMap, + tx_event: Sender, + cancel_token: CancellationToken, + ) { + let current = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst); + if current <= self.last_refresh_seen() { + return; + } + + self.initialize( + config.clone(), + store_mode, + auth_entries, + tx_event, + cancel_token, + ) + .await; + self.last_refresh_seen.store(current, Ordering::SeqCst); + } + pub async fn initialize( &mut self, mcp_servers: HashMap, @@ -281,7 +340,9 @@ impl McpConnectionManager { let mut clients = HashMap::new(); let mut join_set = JoinSet::new(); let elicitation_requests = ElicitationRequestManager::default(); - for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { + for (server_name, cfg) in mcp_servers.iter().filter(|(_, cfg)| cfg.enabled) { + let server_name = server_name.to_string(); + let cfg = cfg.clone(); let cancel_token = cancel_token.child_token(); let _ = emit_update( &tx_event, @@ -333,6 +394,7 @@ impl McpConnectionManager { } self.clients = clients; self.elicitation_requests = elicitation_requests.clone(); + self.update_snapshots(&mcp_servers, store_mode, &auth_entries); tokio::spawn(async move { let outcomes = join_set.join_all().await; let mut summary = McpStartupCompleteEvent::default(); diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 7387bcedae0..4410304277c 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -8,6 +8,7 @@ use crate::tools::sandboxing::ApprovalStore; use crate::unified_exec::UnifiedExecSessionManager; use crate::user_notification::UserNotifier; use codex_otel::otel_event_manager::OtelEventManager; +use std::sync::atomic::AtomicI64; use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; @@ -15,6 +16,7 @@ use tokio_util::sync::CancellationToken; pub(crate) struct SessionServices { pub(crate) mcp_connection_manager: Arc>, pub(crate) mcp_startup_cancellation_token: CancellationToken, + pub(crate) mcp_oauth_refresh_clock: Arc, pub(crate) unified_exec_manager: UnifiedExecSessionManager, pub(crate) notifier: UserNotifier, pub(crate) rollout: Mutex>, diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index ac617f3d29c..954898cea49 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse; pub use oauth::delete_oauth_tokens; pub(crate) use oauth::load_oauth_tokens; pub use oauth::save_oauth_tokens; +pub use perform_oauth_login::OauthLoginHandle; pub use perform_oauth_login::perform_oauth_login; +pub use perform_oauth_login::perform_oauth_login_return_url; pub use rmcp::model::ElicitationAction; pub use rmcp_client::Elicitation; pub use rmcp_client::ElicitationResponse; diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index d8ffdd3949a..9815a3a22d6 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -22,6 +22,11 @@ use crate::save_oauth_tokens; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; +struct OauthHeaders { + http_headers: Option>, + env_http_headers: Option>, +} + struct CallbackServerGuard { server: Arc, } @@ -40,70 +45,52 @@ pub async fn perform_oauth_login( env_http_headers: Option>, scopes: &[String], ) -> Result<()> { - let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); - let guard = CallbackServerGuard { - server: Arc::clone(&server), + let headers = OauthHeaders { + http_headers, + env_http_headers, }; + OauthLoginFlow::new( + server_name, + server_url, + store_mode, + headers, + scopes, + true, + None, + ) + .await? + .finish() + .await +} - let redirect_uri = match server.server_addr() { - tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => { - format!("http://{}:{}/callback", addr.ip(), addr.port()) - } - tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => { - format!("http://[{}]:{}/callback", addr.ip(), addr.port()) - } - #[cfg(not(target_os = "windows"))] - _ => return Err(anyhow!("unable to determine callback address")), +pub async fn perform_oauth_login_return_url( + server_name: &str, + server_url: &str, + store_mode: OAuthCredentialsStoreMode, + http_headers: Option>, + env_http_headers: Option>, + scopes: &[String], + timeout_secs: Option, +) -> Result { + let headers = OauthHeaders { + http_headers, + env_http_headers, }; + let flow = OauthLoginFlow::new( + server_name, + server_url, + store_mode, + headers, + scopes, + false, + timeout_secs, + ) + .await?; - let (tx, rx) = oneshot::channel(); - spawn_callback_server(server, tx); - - let default_headers = build_default_headers(http_headers, env_http_headers)?; - let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; - - let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; - let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect(); - oauth_state - .start_authorization(&scope_refs, &redirect_uri, Some("Codex")) - .await?; - let auth_url = oauth_state.get_authorization_url().await?; - - println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"); - - if webbrowser::open(&auth_url).is_err() { - println!("(Browser launch failed; please copy the URL above manually.)"); - } - - let (code, csrf_state) = timeout(Duration::from_secs(300), rx) - .await - .context("timed out waiting for OAuth callback")? - .context("OAuth callback was cancelled")?; - - oauth_state - .handle_callback(&code, &csrf_state) - .await - .context("failed to handle OAuth callback")?; - - let (client_id, credentials_opt) = oauth_state - .get_credentials() - .await - .context("failed to retrieve OAuth credentials")?; - let credentials = - credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; - - let expires_at = compute_expires_at_millis(&credentials); - let stored = StoredOAuthTokens { - server_name: server_name.to_string(), - url: server_url.to_string(), - client_id, - token_response: WrappedOAuthTokenResponse(credentials), - expires_at, - }; - save_oauth_tokens(server_name, &stored, store_mode)?; + let authorization_url = flow.authorization_url(); + let completion = flow.spawn(); - drop(guard); - Ok(()) + Ok(OauthLoginHandle::new(authorization_url, completion)) } fn spawn_callback_server(server: Arc, tx: oneshot::Sender<(String, String)>) { @@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option { state: state?, }) } + +pub struct OauthLoginHandle { + authorization_url: String, + completion: oneshot::Receiver>, +} + +impl OauthLoginHandle { + fn new(authorization_url: String, completion: oneshot::Receiver>) -> Self { + Self { + authorization_url, + completion, + } + } + + pub fn authorization_url(&self) -> &str { + &self.authorization_url + } + + pub fn into_parts(self) -> (String, oneshot::Receiver>) { + (self.authorization_url, self.completion) + } + + pub async fn wait(self) -> Result<()> { + self.completion + .await + .map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))? + } +} + +struct OauthLoginFlow { + auth_url: String, + oauth_state: OAuthState, + rx: oneshot::Receiver<(String, String)>, + guard: CallbackServerGuard, + server_name: String, + server_url: String, + store_mode: OAuthCredentialsStoreMode, + launch_browser: bool, + timeout: Duration, +} + +impl OauthLoginFlow { + async fn new( + server_name: &str, + server_url: &str, + store_mode: OAuthCredentialsStoreMode, + headers: OauthHeaders, + scopes: &[String], + launch_browser: bool, + timeout_secs: Option, + ) -> Result { + const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300; + + let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); + let guard = CallbackServerGuard { + server: Arc::clone(&server), + }; + + let redirect_uri = match server.server_addr() { + tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => { + let ip = addr.ip(); + let port = addr.port(); + format!("http://{ip}:{port}/callback") + } + tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => { + let ip = addr.ip(); + let port = addr.port(); + format!("http://[{ip}]:{port}/callback") + } + #[cfg(not(target_os = "windows"))] + _ => return Err(anyhow!("unable to determine callback address")), + }; + + let (tx, rx) = oneshot::channel(); + spawn_callback_server(server, tx); + + let OauthHeaders { + http_headers, + env_http_headers, + } = headers; + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; + + let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; + let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect(); + oauth_state + .start_authorization(&scope_refs, &redirect_uri, Some("Codex")) + .await?; + let auth_url = oauth_state.get_authorization_url().await?; + let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1); + let timeout = Duration::from_secs(timeout_secs as u64); + + Ok(Self { + auth_url, + oauth_state, + rx, + guard, + server_name: server_name.to_string(), + server_url: server_url.to_string(), + store_mode, + launch_browser, + timeout, + }) + } + + fn authorization_url(&self) -> String { + self.auth_url.clone() + } + + async fn finish(mut self) -> Result<()> { + if self.launch_browser { + let server_name = &self.server_name; + let auth_url = &self.auth_url; + println!( + "Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n" + ); + + if webbrowser::open(auth_url).is_err() { + println!("(Browser launch failed; please copy the URL above manually.)"); + } + } + + let result = async { + let (code, csrf_state) = timeout(self.timeout, &mut self.rx) + .await + .context("timed out waiting for OAuth callback")? + .context("OAuth callback was cancelled")?; + + self.oauth_state + .handle_callback(&code, &csrf_state) + .await + .context("failed to handle OAuth callback")?; + + let (client_id, credentials_opt) = self + .oauth_state + .get_credentials() + .await + .context("failed to retrieve OAuth credentials")?; + let credentials = credentials_opt + .ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; + + let expires_at = compute_expires_at_millis(&credentials); + let stored = StoredOAuthTokens { + server_name: self.server_name.clone(), + url: self.server_url.clone(), + client_id, + token_response: WrappedOAuthTokenResponse(credentials), + expires_at, + }; + save_oauth_tokens(&self.server_name, &stored, self.store_mode)?; + + Ok(()) + } + .await; + + drop(self.guard); + result + } + + fn spawn(self) -> oneshot::Receiver> { + let server_name_for_logging = self.server_name.clone(); + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + let result = self.finish().await; + + if let Err(err) = &result { + eprintln!( + "Failed to complete OAuth login for '{server_name_for_logging}': {err:#}" + ); + } + + let _ = tx.send(result); + }); + + rx + } +}