diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 22b6ae343b6..616e34d10d2 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -27,7 +27,6 @@ impl AgentControl { Self { manager } } - #[allow(dead_code)] // Used by upcoming multi-agent tooling. /// Spawn a new agent thread and submit the initial prompt. /// /// If `headless` is true, a background drain task is spawned to prevent unbounded event growth @@ -50,7 +49,6 @@ impl AgentControl { Ok(new_thread.thread_id) } - #[allow(dead_code)] // Used by upcoming multi-agent tooling. /// Send a `user` prompt to an existing agent thread. pub(crate) async fn send_prompt( &self, @@ -69,7 +67,7 @@ impl AgentControl { .await } - #[allow(dead_code)] // Used by upcoming multi-agent tooling. + #[allow(dead_code)] // Will be used for collab tools. /// Fetch the last known status for `agent_id`, returning `NotFound` when unavailable. pub(crate) async fn get_status(&self, agent_id: ThreadId) -> AgentStatus { let Ok(state) = self.upgrade() else { @@ -114,13 +112,63 @@ fn spawn_headless_drain(thread: Arc) { #[cfg(test)] mod tests { use super::*; + use crate::CodexAuth; + use crate::ThreadManager; use crate::agent::agent_status_from_event; + use crate::config::Config; + use crate::config::ConfigBuilder; + use assert_matches::assert_matches; use codex_protocol::protocol::ErrorEvent; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::TurnCompleteEvent; use codex_protocol::protocol::TurnStartedEvent; use pretty_assertions::assert_eq; + use tempfile::TempDir; + + async fn test_config() -> (TempDir, Config) { + let home = TempDir::new().expect("create temp dir"); + let config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await + .expect("load default test config"); + (home, config) + } + + struct AgentControlHarness { + _home: TempDir, + config: Config, + manager: ThreadManager, + control: AgentControl, + } + + impl AgentControlHarness { + async fn new() -> Self { + let (home, config) = test_config().await; + let manager = ThreadManager::with_models_provider_and_home( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + Self { + _home: home, + config, + manager, + control, + } + } + + async fn start_thread(&self) -> (ThreadId, Arc) { + let new_thread = self + .manager + .start_thread(self.config.clone()) + .await + .expect("start thread"); + (new_thread.thread_id, new_thread.thread) + } + } #[tokio::test] async fn send_prompt_errors_when_manager_dropped() { @@ -185,4 +233,103 @@ mod tests { let status = agent_status_from_event(&EventMsg::ShutdownComplete); assert_eq!(status, Some(AgentStatus::Shutdown)); } + + #[tokio::test] + async fn spawn_agent_errors_when_manager_dropped() { + let control = AgentControl::default(); + let (_home, config) = test_config().await; + let err = control + .spawn_agent(config, "hello".to_string(), false) + .await + .expect_err("spawn_agent should fail without a manager"); + assert_eq!( + err.to_string(), + "unsupported operation: thread manager dropped" + ); + } + + #[tokio::test] + async fn send_prompt_errors_when_thread_missing() { + let harness = AgentControlHarness::new().await; + let thread_id = ThreadId::new(); + let err = harness + .control + .send_prompt(thread_id, "hello".to_string()) + .await + .expect_err("send_prompt should fail for missing thread"); + assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id); + } + + #[tokio::test] + async fn get_status_returns_not_found_for_missing_thread() { + let harness = AgentControlHarness::new().await; + let status = harness.control.get_status(ThreadId::new()).await; + assert_eq!(status, AgentStatus::NotFound); + } + + #[tokio::test] + async fn get_status_returns_pending_init_for_new_thread() { + let harness = AgentControlHarness::new().await; + let (thread_id, _) = harness.start_thread().await; + let status = harness.control.get_status(thread_id).await; + assert_eq!(status, AgentStatus::PendingInit); + } + + #[tokio::test] + async fn send_prompt_submits_user_message() { + let harness = AgentControlHarness::new().await; + let (thread_id, _thread) = harness.start_thread().await; + + let submission_id = harness + .control + .send_prompt(thread_id, "hello from tests".to_string()) + .await + .expect("send_prompt should succeed"); + assert!(!submission_id.is_empty()); + let expected = ( + thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "hello from tests".to_string(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); + } + + #[tokio::test] + async fn spawn_agent_creates_thread_and_sends_prompt() { + let harness = AgentControlHarness::new().await; + let thread_id = harness + .control + .spawn_agent(harness.config.clone(), "spawned".to_string(), false) + .await + .expect("spawn_agent should succeed"); + let _thread = harness + .manager + .get_thread(thread_id) + .await + .expect("thread should be registered"); + let expected = ( + thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "spawned".to_string(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); + } } diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index a4e8f9c34cf..6124f678745 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -56,6 +56,10 @@ pub(crate) struct ThreadManagerState { models_manager: Arc, skills_manager: Arc, session_source: SessionSource, + #[cfg(any(test, feature = "test-support"))] + #[allow(dead_code)] + // Captures submitted ops for testing purpose. + ops_log: Arc>>, } impl ThreadManager { @@ -74,6 +78,8 @@ impl ThreadManager { skills_manager: Arc::new(SkillsManager::new(codex_home)), auth_manager, session_source, + #[cfg(any(test, feature = "test-support"))] + ops_log: Arc::new(std::sync::Mutex::new(Vec::new())), }), #[cfg(any(test, feature = "test-support"))] _test_codex_home_guard: None, @@ -111,6 +117,8 @@ impl ThreadManager { skills_manager: Arc::new(SkillsManager::new(codex_home)), auth_manager, session_source: SessionSource::Exec, + #[cfg(any(test, feature = "test-support"))] + ops_log: Arc::new(std::sync::Mutex::new(Vec::new())), }), _test_codex_home_guard: None, } @@ -202,9 +210,19 @@ impl ThreadManager { .await } - fn agent_control(&self) -> AgentControl { + pub(crate) fn agent_control(&self) -> AgentControl { AgentControl::new(Arc::downgrade(&self.state)) } + + #[cfg(any(test, feature = "test-support"))] + #[allow(dead_code)] + pub(crate) fn captured_ops(&self) -> Vec<(ThreadId, Op)> { + self.state + .ops_log + .lock() + .map(|log| log.clone()) + .unwrap_or_default() + } } impl ThreadManagerState { @@ -217,7 +235,14 @@ impl ThreadManagerState { } pub(crate) async fn send_op(&self, thread_id: ThreadId, op: Op) -> CodexResult { - self.get_thread(thread_id).await?.submit(op).await + let thread = self.get_thread(thread_id).await?; + #[cfg(any(test, feature = "test-support"))] + { + if let Ok(mut log) = self.ops_log.lock() { + log.push((thread_id, op.clone())); + } + } + thread.submit(op).await } #[allow(dead_code)] // Used by upcoming multi-agent tooling. diff --git a/codex-rs/core/src/tools/handlers/collab.rs b/codex-rs/core/src/tools/handlers/collab.rs index e59e15cbc06..a2f69a759bb 100644 --- a/codex-rs/core/src/tools/handlers/collab.rs +++ b/codex-rs/core/src/tools/handlers/collab.rs @@ -1,3 +1,4 @@ +use crate::agent::AgentStatus; use crate::codex::TurnContext; use crate::config::Config; use crate::error::CodexErr; @@ -11,29 +12,13 @@ use crate::tools::registry::ToolKind; use async_trait::async_trait; use codex_protocol::ThreadId; use serde::Deserialize; +use serde::Serialize; pub struct CollabHandler; pub(crate) const DEFAULT_WAIT_TIMEOUT_MS: i64 = 30_000; pub(crate) const MAX_WAIT_TIMEOUT_MS: i64 = 300_000; -#[derive(Debug, Deserialize)] -struct SpawnAgentArgs { - message: String, -} - -#[derive(Debug, Deserialize)] -struct SendInputArgs { - id: String, - message: String, -} - -#[derive(Debug, Deserialize)] -struct WaitArgs { - id: String, - timeout_ms: Option, -} - #[derive(Debug, Deserialize)] struct CloseAgentArgs { id: String, @@ -68,10 +53,10 @@ impl ToolHandler for CollabHandler { }; match tool_name.as_str() { - "spawn_agent" => handle_spawn_agent(session, turn, arguments).await, - "send_input" => handle_send_input(session, arguments).await, - "wait" => handle_wait(arguments).await, - "close_agent" => handle_close_agent(arguments).await, + "spawn_agent" => spawn::handle(session, turn, arguments).await, + "send_input" => send_input::handle(session, arguments).await, + "wait" => wait::handle(session, arguments).await, + "close_agent" => close_agent::handle(arguments).await, other => Err(FunctionCallError::RespondToModel(format!( "unsupported collab tool {other}" ))), @@ -79,84 +64,160 @@ impl ToolHandler for CollabHandler { } } -async fn handle_spawn_agent( - session: std::sync::Arc, - turn: std::sync::Arc, - arguments: String, -) -> Result { - let args: SpawnAgentArgs = parse_arguments(&arguments)?; - if args.message.trim().is_empty() { - return Err(FunctionCallError::RespondToModel( - "Empty message can't be send to an agent".to_string(), - )); - } - let config = build_agent_spawn_config(turn.as_ref())?; - let result = session - .services - .agent_control - .spawn_agent(config, args.message, true) - .await - .map_err(|err| FunctionCallError::Fatal(err.to_string()))?; - - Ok(ToolOutput::Function { - content: format!("agent_id: {result}"), - success: Some(true), - content_items: None, - }) +mod spawn { + use super::*; + use crate::codex::Session; + use std::sync::Arc; + + #[derive(Debug, Deserialize)] + struct SpawnAgentArgs { + message: String, + } + + pub async fn handle( + session: Arc, + turn: Arc, + arguments: String, + ) -> Result { + let args: SpawnAgentArgs = parse_arguments(&arguments)?; + if args.message.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string(), + )); + } + let config = build_agent_spawn_config(turn.as_ref())?; + let result = session + .services + .agent_control + .spawn_agent(config, args.message, true) + .await + .map_err(|err| FunctionCallError::Fatal(err.to_string()))?; + + Ok(ToolOutput::Function { + content: format!("agent_id: {result}"), + success: Some(true), + content_items: None, + }) + } } -async fn handle_send_input( - session: std::sync::Arc, - arguments: String, -) -> Result { - let args: SendInputArgs = parse_arguments(&arguments)?; - let agent_id = agent_id(&args.id)?; - if args.message.trim().is_empty() { - return Err(FunctionCallError::RespondToModel( - "Empty message can't be send to an agent".to_string(), - )); - } - let content = session - .services - .agent_control - .send_prompt(agent_id, args.message) - .await - .map_err(|err| match err { - CodexErr::ThreadNotFound(id) => { - FunctionCallError::RespondToModel(format!("agent with id {id} not found")) - } - err => FunctionCallError::Fatal(err.to_string()), - })?; +mod send_input { + use super::*; + use crate::codex::Session; + use std::sync::Arc; + + #[derive(Debug, Deserialize)] + struct SendInputArgs { + id: String, + message: String, + } + + pub async fn handle( + session: Arc, + arguments: String, + ) -> Result { + let args: SendInputArgs = parse_arguments(&arguments)?; + let agent_id = agent_id(&args.id)?; + if args.message.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string(), + )); + } + let content = session + .services + .agent_control + .send_prompt(agent_id, args.message) + .await + .map_err(|err| match err { + CodexErr::ThreadNotFound(id) => { + FunctionCallError::RespondToModel(format!("agent with id {id} not found")) + } + err => FunctionCallError::Fatal(err.to_string()), + })?; - Ok(ToolOutput::Function { - content, - success: Some(true), - content_items: None, - }) + Ok(ToolOutput::Function { + content, + success: Some(true), + content_items: None, + }) + } } -async fn handle_wait(arguments: String) -> Result { - let args: WaitArgs = parse_arguments(&arguments)?; - let _agent_id = agent_id(&args.id)?; +#[allow(unused_variables)] +mod wait { + use super::*; + use crate::codex::Session; + use std::sync::Arc; + + #[derive(Debug, Deserialize)] + struct WaitArgs { + id: String, + timeout_ms: Option, + } + + #[derive(Debug, Serialize)] + struct WaitResult { + status: AgentStatus, + timed_out: bool, + } + + pub async fn handle( + session: Arc, + arguments: String, + ) -> Result { + let args: WaitArgs = parse_arguments(&arguments)?; + let agent_id = agent_id(&args.id)?; + + let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); + if timeout_ms <= 0 { + return Err(FunctionCallError::RespondToModel( + "timeout_ms must be greater than zero".to_string(), + )); + } + let timeout_ms = timeout_ms.min(MAX_WAIT_TIMEOUT_MS); + // TODO(jif) actual implementation + let outcome = WaitResult { + status: Default::default(), + timed_out: false, + }; + + if matches!(outcome.status, AgentStatus::NotFound) { + return Err(FunctionCallError::RespondToModel(format!( + "agent with id {agent_id} not found" + ))); + } - let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); - if timeout_ms <= 0 { - return Err(FunctionCallError::RespondToModel( - "timeout_ms must be greater than zero".to_string(), - )); + let message = outcome.timed_out.then(|| { + format!( + "Timed out after {timeout_ms}ms waiting for agent {agent_id}. The agent may still be running." + ) + }); + let result = WaitResult { + status: outcome.status, + timed_out: outcome.timed_out, + }; + let content = serde_json::to_string(&result).map_err(|err| { + FunctionCallError::Fatal(format!("failed to serialize wait result: {err}")) + })?; + Ok(ToolOutput::Function { + content, + success: Some(!outcome.timed_out), + content_items: None, + }) } - let _timeout_ms = timeout_ms.min(MAX_WAIT_TIMEOUT_MS); - // TODO(jif): implement agent wait once lifecycle tracking is wired up. - Err(FunctionCallError::Fatal("wait not implemented".to_string())) } -async fn handle_close_agent(arguments: String) -> Result { - let args: CloseAgentArgs = parse_arguments(&arguments)?; - let _agent_id = agent_id(&args.id)?; - // TODO(jif): implement agent shutdown and return the final status. - Err(FunctionCallError::Fatal( - "close_agent not implemented".to_string(), - )) +pub mod close_agent { + use super::*; + + pub async fn handle(arguments: String) -> Result { + let args: CloseAgentArgs = parse_arguments(&arguments)?; + let _agent_id = agent_id(&args.id)?; + // TODO(jif): implement agent shutdown and return the final status. + Err(FunctionCallError::Fatal( + "close_agent not implemented".to_string(), + )) + } } fn agent_id(id: &str) -> Result { @@ -192,3 +253,284 @@ fn build_agent_spawn_config(turn: &TurnContext) -> Result, + turn: Arc, + tool_name: &str, + payload: ToolPayload, + ) -> ToolInvocation { + ToolInvocation { + session, + turn, + tracker: Arc::new(Mutex::new(TurnDiffTracker::default())), + call_id: "call-1".to_string(), + tool_name: tool_name.to_string(), + payload, + } + } + + fn function_payload(args: serde_json::Value) -> ToolPayload { + ToolPayload::Function { + arguments: args.to_string(), + } + } + + fn thread_manager() -> ThreadManager { + ThreadManager::with_models_provider( + CodexAuth::from_api_key("dummy"), + built_in_model_providers()["openai"].clone(), + ) + } + + #[tokio::test] + async fn handler_rejects_non_function_payloads() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + ToolPayload::Custom { + input: "hello".to_string(), + }, + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("payload should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "collab handler received unsupported payload".to_string() + ) + ); + } + + #[tokio::test] + async fn handler_rejects_unknown_tool() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "unknown_tool", + function_payload(json!({})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("tool should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("unsupported collab tool unknown_tool".to_string()) + ); + } + + #[tokio::test] + async fn spawn_agent_rejects_empty_message() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": " "})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("empty message should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string() + ) + ); + } + + #[tokio::test] + async fn spawn_agent_errors_when_manager_dropped() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": "hello"})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("spawn should fail without a manager"); + }; + assert_eq!( + err, + FunctionCallError::Fatal("unsupported operation: thread manager dropped".to_string()) + ); + } + + #[tokio::test] + async fn send_input_rejects_empty_message() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": ThreadId::new().to_string(), "message": ""})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("empty message should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string() + ) + ); + } + + #[tokio::test] + async fn send_input_rejects_invalid_id() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": "not-a-uuid", "message": "hi"})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("invalid id should be rejected"); + }; + let FunctionCallError::RespondToModel(msg) = err else { + panic!("expected respond-to-model error"); + }; + assert!(msg.starts_with("invalid agent id not-a-uuid:")); + } + + #[tokio::test] + async fn send_input_reports_missing_agent() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let agent_id = ThreadId::new(); + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": agent_id.to_string(), "message": "hi"})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("missing agent should be reported"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found")) + ); + } + + #[tokio::test] + async fn wait_rejects_non_positive_timeout() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({"id": ThreadId::new().to_string(), "timeout_ms": 0})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("non-positive timeout should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("timeout_ms must be greater than zero".to_string()) + ); + } + + #[tokio::test] + async fn wait_rejects_invalid_id() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({"id": "invalid"})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("invalid id should be rejected"); + }; + let FunctionCallError::RespondToModel(msg) = err else { + panic!("expected respond-to-model error"); + }; + assert!(msg.starts_with("invalid agent id invalid:")); + } + + #[tokio::test] + async fn close_agent_reports_not_implemented() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "close_agent", + function_payload(json!({"id": ThreadId::new().to_string()})), + ); + let Err(err) = CollabHandler.handle(invocation).await else { + panic!("close_agent should fail"); + }; + assert_eq!( + err, + FunctionCallError::Fatal("close_agent not implemented".to_string()) + ); + } + + #[tokio::test] + async fn build_agent_spawn_config_uses_turn_context_values() { + let (_session, mut turn) = make_session_and_context().await; + turn.developer_instructions = Some("dev".to_string()); + turn.base_instructions = Some("base".to_string()); + turn.compact_prompt = Some("compact".to_string()); + turn.user_instructions = Some("user".to_string()); + turn.shell_environment_policy = ShellEnvironmentPolicy { + use_profile: true, + ..ShellEnvironmentPolicy::default() + }; + let temp_dir = tempfile::tempdir().expect("temp dir"); + turn.cwd = temp_dir.path().to_path_buf(); + turn.codex_linux_sandbox_exe = Some(PathBuf::from("/bin/echo")); + turn.approval_policy = AskForApproval::Never; + turn.sandbox_policy = SandboxPolicy::DangerFullAccess; + + let config = build_agent_spawn_config(&turn).expect("spawn config"); + let mut expected = (*turn.client.config()).clone(); + expected.model = Some(turn.client.get_model()); + expected.model_provider = turn.client.get_provider(); + expected.model_reasoning_effort = turn.client.get_reasoning_effort(); + expected.model_reasoning_summary = turn.client.get_reasoning_summary(); + expected.developer_instructions = turn.developer_instructions.clone(); + expected.base_instructions = turn.base_instructions.clone(); + expected.compact_prompt = turn.compact_prompt.clone(); + expected.user_instructions = turn.user_instructions.clone(); + expected.shell_environment_policy = turn.shell_environment_policy.clone(); + expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); + expected.cwd = turn.cwd.clone(); + expected + .approval_policy + .set(turn.approval_policy) + .expect("approval policy set"); + expected + .sandbox_policy + .set(turn.sandbox_policy) + .expect("sandbox policy set"); + assert_eq!(config, expected); + } +}