diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 96e90c5cfa6..ebbb0078863 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -485,7 +485,7 @@ pub(crate) struct SessionSettingsUpdate { impl Session { /// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it. - fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config { + pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config { // todo(aibrahim): store this state somewhere else so we don't need to mut config let config = session_configuration.original_config_do_not_use.clone(); let mut per_turn_config = (*config).clone(); diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index cfa5a0acc61..b268bf6d782 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -92,6 +92,8 @@ pub enum Feature { PowershellUtf8, /// Compress request bodies (zstd) when sending streaming requests to codex-backend. EnableRequestCompression, + /// Enable collab tools. + Collab, } impl Feature { @@ -398,6 +400,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Experimental, default_enabled: false, }, + FeatureSpec { + id: Feature::Collab, + key: "collab", + stage: Stage::Experimental, + default_enabled: false, + }, FeatureSpec { id: Feature::Tui2, key: "tui2", diff --git a/codex-rs/core/src/tools/handlers/collab.rs b/codex-rs/core/src/tools/handlers/collab.rs new file mode 100644 index 00000000000..e59e15cbc06 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/collab.rs @@ -0,0 +1,194 @@ +use crate::codex::TurnContext; +use crate::config::Config; +use crate::error::CodexErr; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use async_trait::async_trait; +use codex_protocol::ThreadId; +use serde::Deserialize; + +pub struct CollabHandler; + +pub(crate) const DEFAULT_WAIT_TIMEOUT_MS: i64 = 30_000; +pub(crate) const MAX_WAIT_TIMEOUT_MS: i64 = 300_000; + +#[derive(Debug, Deserialize)] +struct SpawnAgentArgs { + message: String, +} + +#[derive(Debug, Deserialize)] +struct SendInputArgs { + id: String, + message: String, +} + +#[derive(Debug, Deserialize)] +struct WaitArgs { + id: String, + timeout_ms: Option, +} + +#[derive(Debug, Deserialize)] +struct CloseAgentArgs { + id: String, +} + +#[async_trait] +impl ToolHandler for CollabHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tool_name, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "collab handler received unsupported payload".to_string(), + )); + } + }; + + match tool_name.as_str() { + "spawn_agent" => handle_spawn_agent(session, turn, arguments).await, + "send_input" => handle_send_input(session, arguments).await, + "wait" => handle_wait(arguments).await, + "close_agent" => handle_close_agent(arguments).await, + other => Err(FunctionCallError::RespondToModel(format!( + "unsupported collab tool {other}" + ))), + } + } +} + +async fn handle_spawn_agent( + session: std::sync::Arc, + turn: std::sync::Arc, + arguments: String, +) -> Result { + let args: SpawnAgentArgs = parse_arguments(&arguments)?; + if args.message.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string(), + )); + } + let config = build_agent_spawn_config(turn.as_ref())?; + let result = session + .services + .agent_control + .spawn_agent(config, args.message, true) + .await + .map_err(|err| FunctionCallError::Fatal(err.to_string()))?; + + Ok(ToolOutput::Function { + content: format!("agent_id: {result}"), + success: Some(true), + content_items: None, + }) +} + +async fn handle_send_input( + session: std::sync::Arc, + arguments: String, +) -> Result { + let args: SendInputArgs = parse_arguments(&arguments)?; + let agent_id = agent_id(&args.id)?; + if args.message.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "Empty message can't be send to an agent".to_string(), + )); + } + let content = session + .services + .agent_control + .send_prompt(agent_id, args.message) + .await + .map_err(|err| match err { + CodexErr::ThreadNotFound(id) => { + FunctionCallError::RespondToModel(format!("agent with id {id} not found")) + } + err => FunctionCallError::Fatal(err.to_string()), + })?; + + Ok(ToolOutput::Function { + content, + success: Some(true), + content_items: None, + }) +} + +async fn handle_wait(arguments: String) -> Result { + let args: WaitArgs = parse_arguments(&arguments)?; + let _agent_id = agent_id(&args.id)?; + + let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); + if timeout_ms <= 0 { + return Err(FunctionCallError::RespondToModel( + "timeout_ms must be greater than zero".to_string(), + )); + } + let _timeout_ms = timeout_ms.min(MAX_WAIT_TIMEOUT_MS); + // TODO(jif): implement agent wait once lifecycle tracking is wired up. + Err(FunctionCallError::Fatal("wait not implemented".to_string())) +} + +async fn handle_close_agent(arguments: String) -> Result { + let args: CloseAgentArgs = parse_arguments(&arguments)?; + let _agent_id = agent_id(&args.id)?; + // TODO(jif): implement agent shutdown and return the final status. + Err(FunctionCallError::Fatal( + "close_agent not implemented".to_string(), + )) +} + +fn agent_id(id: &str) -> Result { + ThreadId::from_string(id) + .map_err(|e| FunctionCallError::RespondToModel(format!("invalid agent id {id}: {e:?}"))) +} + +fn build_agent_spawn_config(turn: &TurnContext) -> Result { + let base_config = turn.client.config(); + let mut config = (*base_config).clone(); + config.model = Some(turn.client.get_model()); + config.model_provider = turn.client.get_provider(); + config.model_reasoning_effort = turn.client.get_reasoning_effort(); + config.model_reasoning_summary = turn.client.get_reasoning_summary(); + config.developer_instructions = turn.developer_instructions.clone(); + config.base_instructions = turn.base_instructions.clone(); + config.compact_prompt = turn.compact_prompt.clone(); + config.user_instructions = turn.user_instructions.clone(); + config.shell_environment_policy = turn.shell_environment_policy.clone(); + config.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); + config.cwd = turn.cwd.clone(); + config + .approval_policy + .set(turn.approval_policy) + .map_err(|err| { + FunctionCallError::RespondToModel(format!("approval_policy is invalid: {err}")) + })?; + config + .sandbox_policy + .set(turn.sandbox_policy.clone()) + .map_err(|err| { + FunctionCallError::RespondToModel(format!("sandbox_policy is invalid: {err}")) + })?; + Ok(config) +} diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index d9f6859c641..ab8123df1e1 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -1,4 +1,5 @@ pub mod apply_patch; +pub(crate) mod collab; mod grep_files; mod list_dir; mod mcp; @@ -15,6 +16,7 @@ use serde::Deserialize; use crate::function_tool::FunctionCallError; pub use apply_patch::ApplyPatchHandler; +pub use collab::CollabHandler; pub use grep_files::GrepFilesHandler; pub use list_dir::ListDirHandler; pub use mcp::McpHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 846025d58e2..48c4da71fd9 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -5,6 +5,8 @@ use crate::features::Features; use crate::tools::handlers::PLAN_TOOL; use crate::tools::handlers::apply_patch::create_apply_patch_freeform_tool; use crate::tools::handlers::apply_patch::create_apply_patch_json_tool; +use crate::tools::handlers::collab::DEFAULT_WAIT_TIMEOUT_MS; +use crate::tools::handlers::collab::MAX_WAIT_TIMEOUT_MS; use crate::tools::registry::ToolRegistryBuilder; use codex_protocol::openai_models::ApplyPatchToolType; use codex_protocol::openai_models::ConfigShellToolType; @@ -22,6 +24,7 @@ pub(crate) struct ToolsConfig { pub apply_patch_tool_type: Option, pub web_search_request: bool, pub web_search_cached: bool, + pub collab_tools: bool, pub experimental_supported_tools: Vec, } @@ -39,6 +42,7 @@ impl ToolsConfig { let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform); let include_web_search_request = features.enabled(Feature::WebSearchRequest); let include_web_search_cached = features.enabled(Feature::WebSearchCached); + let include_collab_tools = features.enabled(Feature::Collab); let shell_type = if !features.enabled(Feature::ShellTool) { ConfigShellToolType::Disabled @@ -70,6 +74,7 @@ impl ToolsConfig { apply_patch_tool_type, web_search_request: include_web_search_request, web_search_cached: include_web_search_cached, + collab_tools: include_collab_tools, experimental_supported_tools: model_info.experimental_supported_tools.clone(), } } @@ -416,6 +421,104 @@ fn create_view_image_tool() -> ToolSpec { }) } +fn create_spawn_agent_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "message".to_string(), + JsonSchema::String { + description: Some("Initial message to send to the new agent.".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "spawn_agent".to_string(), + description: "Spawn a new agent and return its id.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["message".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_send_input_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "id".to_string(), + JsonSchema::String { + description: Some("Identifier of the agent to message.".to_string()), + }, + ); + properties.insert( + "message".to_string(), + JsonSchema::String { + description: Some("Message to send to the agent.".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "send_input".to_string(), + description: "Send a message to an existing agent.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["id".to_string(), "message".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_wait_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "id".to_string(), + JsonSchema::String { + description: Some("Identifier of the agent to wait on.".to_string()), + }, + ); + properties.insert( + "timeout_ms".to_string(), + JsonSchema::Number { + description: Some(format!( + "Optional timeout in milliseconds. Defaults to {DEFAULT_WAIT_TIMEOUT_MS} and max {MAX_WAIT_TIMEOUT_MS}." + )), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "wait".to_string(), + description: "Wait for an agent and return its status.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["id".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_close_agent_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "id".to_string(), + JsonSchema::String { + description: Some("Identifier of the agent to close.".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "close_agent".to_string(), + description: "Close an agent and return its last known status.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["id".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + fn create_test_sync_tool() -> ToolSpec { let mut properties = BTreeMap::new(); properties.insert( @@ -981,6 +1084,7 @@ pub(crate) fn build_specs( mcp_tools: Option>, ) -> ToolRegistryBuilder { use crate::tools::handlers::ApplyPatchHandler; + use crate::tools::handlers::CollabHandler; use crate::tools::handlers::GrepFilesHandler; use crate::tools::handlers::ListDirHandler; use crate::tools::handlers::McpHandler; @@ -1107,6 +1211,18 @@ pub(crate) fn build_specs( builder.push_spec_with_parallel_support(create_view_image_tool(), true); builder.register_handler("view_image", view_image_handler); + if config.collab_tools { + let collab_handler = Arc::new(CollabHandler); + builder.push_spec(create_spawn_agent_tool()); + builder.push_spec(create_send_input_tool()); + builder.push_spec(create_wait_tool()); + builder.push_spec(create_close_agent_tool()); + builder.register_handler("spawn_agent", collab_handler.clone()); + builder.register_handler("send_input", collab_handler.clone()); + builder.register_handler("wait", collab_handler.clone()); + builder.register_handler("close_agent", collab_handler); + } + if let Some(mcp_tools) = mcp_tools { let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect(); entries.sort_by(|a, b| a.0.cmp(&b.0)); @@ -1286,6 +1402,23 @@ mod tests { } } + #[test] + fn test_build_specs_collab_tools_enabled() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Collab); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + }); + let (tools, _) = build_specs(&tools_config, None).build(); + assert_contains_tool_names( + &tools, + &["spawn_agent", "send_input", "wait", "close_agent"], + ); + } + fn assert_model_tools(model_slug: &str, features: &Features, expected_tools: &[&str]) { let config = test_config(); let model_info = ModelsManager::construct_model_info_offline(model_slug, &config);