From 2071e47026111bed4f579238021658e2c6b2b963 Mon Sep 17 00:00:00 2001 From: Priyash Patil <38959321+priyashpatil@users.noreply.github.com> Date: Sat, 25 Apr 2026 00:15:58 +0530 Subject: [PATCH] feat: add plan mode support --- README.md | 1 + src/thread.rs | 1429 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 1403 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 34070a9..9694a44 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ This tool implements an ACP adapter around the Codex CLI, supporting: - Tool calls (with permission requests) - Following - Edit review +- Plan mode - TODO lists - Slash commands: - /review (with optional instructions) diff --git a/src/thread.rs b/src/thread.rs index c45189a..462c014 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -38,9 +38,10 @@ use codex_protocol::{ ElicitationRequest, ElicitationRequestEvent, GuardianAssessmentAction, GuardianCommandSource, }, - config_types::TrustLevel, + config_types::{CollaborationMode, CollaborationModeMask, ModeKind, Settings, TrustLevel}, dynamic_tools::{DynamicToolCallOutputContentItem, DynamicToolCallRequest}, error::CodexErr, + items::TurnItem, mcp::CallToolResult, models::{PermissionProfile, ResponseItem, WebSearchAction}, openai_models::{ModelPreset, ReasoningEffort}, @@ -69,6 +70,10 @@ use codex_protocol::{ PermissionGrantScope, RequestPermissionProfile, RequestPermissionsEvent, RequestPermissionsResponse, }, + request_user_input::{ + RequestUserInputAnswer, RequestUserInputEvent, RequestUserInputQuestion, + RequestUserInputResponse, + }, user_input::UserInput, }; use codex_shell_command::parse_command::parse_command; @@ -135,6 +140,9 @@ pub trait ModelsManagerImpl: Send + Sync { model_id: &Option, ) -> Pin + Send + '_>>; fn list_models(&self) -> Pin> + Send + '_>>; + fn list_collaboration_modes( + &self, + ) -> Pin> + Send + '_>>; } impl ModelsManagerImpl for ModelsManager { @@ -152,6 +160,12 @@ impl ModelsManagerImpl for ModelsManager { fn list_models(&self) -> Pin> + Send + '_>> { Box::pin(self.list_models(RefreshStrategy::OnlineIfUncached)) } + + fn list_collaboration_modes( + &self, + ) -> Pin> + Send + '_>> { + Box::pin(async move { self.list_collaboration_modes() }) + } } pub trait Auth { @@ -200,6 +214,9 @@ enum ThreadMessage { history: Vec, response_tx: oneshot::Sender>, }, + SubmitPlanImplementation { + approval_preset_id: String, + }, PermissionRequestResolved { submission_id: String, request_key: String, @@ -391,6 +408,15 @@ enum PendingPermissionRequest { request_id: codex_protocol::mcp::RequestId, option_map: HashMap, }, + PlanImplementation, + UserInput { + turn_id: String, + call_id: String, + questions: Vec, + question_index: usize, + answers: HashMap, + option_map: HashMap, + }, } struct PendingPermissionInteraction { @@ -443,6 +469,30 @@ fn permissions_request_key(call_id: &str) -> String { format!("permissions:{call_id}") } +fn user_input_request_key(call_id: &str, question_index: usize) -> String { + format!("user-input:{call_id}:{question_index}") +} + +fn plan_implementation_request_key(submission_id: &str) -> String { + format!("plan-implementation:{submission_id}") +} + +fn mode_kind_as_id(mode: ModeKind) -> &'static str { + match mode { + ModeKind::Plan => "plan", + ModeKind::Default => "default", + ModeKind::PairProgramming => "pair_programming", + ModeKind::Execute => "execute", + } +} + +fn collaboration_mode_description(mode: ModeKind) -> Option<&'static str> { + match mode { + ModeKind::Plan => Some(PLAN_MODE_DESCRIPTION), + ModeKind::Default | ModeKind::PairProgramming | ModeKind::Execute => None, + } +} + fn mcp_elicitation_request_key( server_name: &str, request_id: &codex_protocol::mcp::RequestId, @@ -466,6 +516,12 @@ const MCP_TOOL_APPROVAL_ALLOW_OPTION_ID: &str = "approved"; const MCP_TOOL_APPROVAL_ALLOW_SESSION_OPTION_ID: &str = "approved-for-session"; const MCP_TOOL_APPROVAL_ALLOW_ALWAYS_OPTION_ID: &str = "approved-always"; const MCP_TOOL_APPROVAL_CANCEL_OPTION_ID: &str = "cancel"; +const PLAN_MODE_DESCRIPTION: &str = + "Codex can help create and refine a plan before implementation."; +const PLAN_IMPLEMENTATION_ACCEPT_DEFAULT_OPTION_ID: &str = "accept-plan-default"; +const PLAN_IMPLEMENTATION_ACCEPT_FULL_ACCESS_OPTION_ID: &str = "accept-plan-full-access"; +const PLAN_IMPLEMENTATION_STAY_OPTION_ID: &str = "stay-in-plan"; +const PLAN_IMPLEMENTATION_CODING_MESSAGE: &str = "Implement the plan."; struct SupportedMcpElicitationPermissionRequest { request_key: String, @@ -744,6 +800,11 @@ struct PromptState { response_tx: Option>>, seen_message_deltas: bool, seen_reasoning_deltas: bool, + turn_complete: bool, + turn_collaboration_mode_kind: ModeKind, + saw_plan_output: bool, + plan_output_text: Option, + prompted_for_plan_implementation: bool, } impl PromptState { @@ -765,10 +826,18 @@ impl PromptState { response_tx: Some(response_tx), seen_message_deltas: false, seen_reasoning_deltas: false, + turn_complete: false, + turn_collaboration_mode_kind: ModeKind::Default, + saw_plan_output: false, + plan_output_text: None, + prompted_for_plan_implementation: false, } } fn is_active(&self) -> bool { + if !self.turn_complete || !self.pending_permission_interactions.is_empty() { + return true; + } let Some(response_tx) = &self.response_tx else { return false; }; @@ -815,9 +884,202 @@ impl PromptState { } } + fn spawn_plan_implementation_request(&mut self, client: &SessionClient) { + self.prompted_for_plan_implementation = true; + let request_key = plan_implementation_request_key(&self.submission_id); + let call_id = request_key.clone(); + let prompt_content = self.plan_output_text.clone().unwrap_or_else(|| { + "The plan is ready. Choose whether to switch to Default mode and start implementation, or stay in Plan mode to keep discussing it.".to_string() + }); + self.spawn_permission_request( + client, + request_key, + PendingPermissionRequest::PlanImplementation, + ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new() + .kind(ToolKind::Think) + .status(ToolCallStatus::Pending) + .title("Implement this plan?") + .raw_input(serde_json::json!({ + "request_type": "plan_implementation", + })) + .content(vec![ToolCallContent::Content(Content::new( + ContentBlock::Text(TextContent::new(prompt_content)), + ))]), + ), + vec![ + PermissionOption::new( + PLAN_IMPLEMENTATION_ACCEPT_DEFAULT_OPTION_ID, + "Accept and continue with Default profile", + PermissionOptionKind::AllowOnce, + ), + PermissionOption::new( + PLAN_IMPLEMENTATION_ACCEPT_FULL_ACCESS_OPTION_ID, + "Accept and continue with Full Access profile", + PermissionOptionKind::AllowOnce, + ), + PermissionOption::new( + PLAN_IMPLEMENTATION_STAY_OPTION_ID, + "Reject and continue planning", + PermissionOptionKind::RejectOnce, + ), + ], + ); + } + + async fn submit_user_input_answers( + &self, + turn_id: String, + answers: HashMap, + ) -> Result<(), Error> { + self.thread + .submit(Op::UserInputAnswer { + id: turn_id, + response: RequestUserInputResponse { answers }, + }) + .await + .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + Ok(()) + } + + fn complete_user_input_tool_call( + &self, + client: &SessionClient, + call_id: impl Into, + ) { + client.send_tool_call_update(ToolCallUpdate::new( + call_id, + ToolCallUpdateFields::new().status(ToolCallStatus::Completed), + )); + } + + fn mark_plan_implementation_decision( + &self, + client: &SessionClient, + call_id: impl Into, + status: ToolCallStatus, + title: &'static str, + decision: &'static str, + ) { + let content = self.plan_output_text.as_ref().map(|text| { + vec![ToolCallContent::Content(Content::new(ContentBlock::Text( + TextContent::new(text.clone()), + )))] + }); + client.send_tool_call_update(ToolCallUpdate::new( + call_id, + ToolCallUpdateFields::new() + .status(status) + .title(title) + .content(content) + .raw_output(serde_json::json!({ + "decision": decision, + })), + )); + } + + async fn finalize_user_input_answers( + &self, + client: &SessionClient, + call_id: String, + turn_id: String, + answers: HashMap, + ) -> Result<(), Error> { + self.submit_user_input_answers(turn_id, answers).await?; + self.complete_user_input_tool_call(client, call_id); + Ok(()) + } + + fn spawn_user_input_question_request( + &mut self, + client: &SessionClient, + turn_id: String, + call_id: String, + questions: Vec, + question_index: usize, + answers: HashMap, + ) { + let Some(question) = questions.get(question_index).cloned() else { + let thread = self.thread.clone(); + let client = client.clone(); + tokio::spawn(async move { + if let Err(err) = thread + .submit(Op::UserInputAnswer { + id: turn_id, + response: RequestUserInputResponse { answers }, + }) + .await + { + warn!("Failed to submit UserInputAnswer fallback: {err}"); + return; + } + + client.send_tool_call_update(ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new().status(ToolCallStatus::Completed), + )); + }); + return; + }; + + let (options, option_map) = build_user_input_permission_options(&question); + let mut content_lines = vec![question.question.clone()]; + + if let Some(question_options) = question.options.as_ref() { + content_lines.extend( + question_options + .iter() + .map(|option| format!("- {}: {}", option.label, option.description)), + ); + } + if question.is_other { + content_lines.push( + "- Other: custom answer will be available when the client UI supports structured input" + .to_string(), + ); + } + + let title = if question.header.is_empty() { + "Need user input".to_string() + } else { + format!("Need user input: {}", question.header) + }; + + let request_key = user_input_request_key(&call_id, question_index); + self.spawn_permission_request( + client, + request_key, + PendingPermissionRequest::UserInput { + turn_id, + call_id: call_id.clone(), + questions, + question_index, + answers, + option_map, + }, + ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new() + .kind(ToolKind::Think) + .status(ToolCallStatus::Pending) + .title(title) + .raw_input(serde_json::json!({ + "request_type": "request_user_input", + "question": question, + "fallback": "session/request_permission", + })) + .content(vec![ToolCallContent::Content(Content::new( + ContentBlock::Text(TextContent::new(content_lines.join("\n"))), + ))]), + ), + options, + ); + } + async fn handle_permission_request_resolved( &mut self, - _client: &SessionClient, + client: &SessionClient, request_key: String, response: Result, ) -> Result<(), Error> { @@ -944,6 +1206,96 @@ impl PromptState { .await .map_err(|e| Error::from(anyhow::anyhow!(e)))?; } + PendingPermissionRequest::PlanImplementation => { + let selected_option_id = match response.outcome { + RequestPermissionOutcome::Selected(SelectedPermissionOutcome { + option_id, + .. + }) => Some(option_id.0.to_string()), + RequestPermissionOutcome::Cancelled | _ => None, + }; + + let approval_preset_id = match selected_option_id.as_deref() { + Some(PLAN_IMPLEMENTATION_ACCEPT_DEFAULT_OPTION_ID) => Some("auto"), + Some(PLAN_IMPLEMENTATION_ACCEPT_FULL_ACCESS_OPTION_ID) => Some("full-access"), + _ => None, + }; + + if let Some(approval_preset_id) = approval_preset_id { + self.mark_plan_implementation_decision( + client, + request_key, + ToolCallStatus::Completed, + "User accepted plan", + "accept_plan", + ); + drop( + self.resolution_tx + .send(ThreadMessage::SubmitPlanImplementation { + approval_preset_id: approval_preset_id.to_string(), + }), + ); + } else { + self.mark_plan_implementation_decision( + client, + request_key, + ToolCallStatus::Failed, + "Plan rejected", + "stay_in_plan", + ); + } + } + PendingPermissionRequest::UserInput { + turn_id, + call_id, + questions, + question_index, + mut answers, + option_map, + } => { + let Some(question) = questions.get(question_index) else { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + }; + + let selected_answer = match response.outcome { + RequestPermissionOutcome::Selected(SelectedPermissionOutcome { + option_id, + .. + }) => option_map.get(option_id.0.as_ref()).cloned(), + RequestPermissionOutcome::Cancelled | _ => None, + }; + + let Some(answer) = selected_answer else { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + }; + + answers.insert( + question.id.clone(), + RequestUserInputAnswer { + answers: vec![answer], + }, + ); + + let next_question_index = question_index + 1; + if next_question_index >= questions.len() { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + } + + self.spawn_user_input_question_request( + client, + turn_id, + call_id, + questions, + next_question_index, + answers, + ); + } } Ok(()) @@ -974,6 +1326,7 @@ impl PromptState { | EventMsg::TurnAborted(..) | EventMsg::EnteredReviewMode(..) | EventMsg::ExitedReviewMode(..) + | EventMsg::RequestUserInput(..) | EventMsg::ShutdownComplete => { self.complete_web_search(client); } @@ -988,6 +1341,11 @@ impl PromptState { started_at: _, }) => { info!("Task started with context window of {turn_id} {model_context_window:?} {collaboration_mode_kind:?}"); + self.turn_complete = false; + self.turn_collaboration_mode_kind = collaboration_mode_kind; + self.saw_plan_output = false; + self.plan_output_text = None; + self.prompted_for_plan_implementation = false; } EventMsg::TokenCount(TokenCountEvent { info, .. }) => { if let Some(info) = info @@ -1072,6 +1430,9 @@ impl PromptState { EventMsg::PlanUpdate(UpdatePlanArgs { explanation, plan }) => { // Send this to the client via session/update notification info!("Agent plan updated. Explanation: {:?}", explanation); + if !plan.is_empty() { + self.saw_plan_output = true; + } client.update_plan(plan); } EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }) => { @@ -1200,6 +1561,10 @@ impl PromptState { item, }) => { info!("Item completed: thread_id={}, turn_id={}, item={:?}", thread_id, turn_id, item); + if let TurnItem::Plan(plan_item) = item { + self.saw_plan_output = true; + self.plan_output_text = Some(plan_item.text); + } } EventMsg::TurnComplete(TurnCompleteEvent { last_agent_message, turn_id, completed_at: _, duration_ms: _, time_to_first_token_ms: _, }) => { info!( @@ -1207,6 +1572,13 @@ impl PromptState { self.event_count ); self.abort_pending_interactions(); + self.turn_complete = true; + if self.turn_collaboration_mode_kind == ModeKind::Plan + && self.saw_plan_output + && !self.prompted_for_plan_implementation + { + self.spawn_plan_implementation_request(client); + } if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::EndTurn)).ok(); } @@ -1234,6 +1606,7 @@ impl PromptState { error!( "Handled error during turn: {message} {codex_error_info:?} {additional_details:?}" ); + self.turn_complete = true; } EventMsg::Error(ErrorEvent { message, @@ -1241,6 +1614,7 @@ impl PromptState { }) => { error!("Unhandled error during turn: {message} {codex_error_info:?}"); self.abort_pending_interactions(); + self.turn_complete = true; if let Some(response_tx) = self.response_tx.take() { response_tx .send(Err(Error::internal_error().data( @@ -1252,6 +1626,7 @@ impl PromptState { EventMsg::TurnAborted(TurnAbortedEvent { reason, turn_id, completed_at: _, duration_ms: _ }) => { info!("Turn {turn_id:?} aborted: {reason:?}"); self.abort_pending_interactions(); + self.turn_complete = true; if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::Cancelled)).ok(); } @@ -1259,6 +1634,7 @@ impl PromptState { EventMsg::ShutdownComplete => { info!("Agent shutting down"); self.abort_pending_interactions(); + self.turn_complete = true; if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::Cancelled)).ok(); } @@ -1332,6 +1708,19 @@ impl PromptState { drop(response_tx.send(Err(err))); } } + EventMsg::RequestUserInput(event) => { + info!( + "Request user input: call_id={}, turn_id={}, questions={}", + event.call_id, + event.turn_id, + event.questions.len() + ); + if let Err(err) = self.request_user_input(client, event).await + && let Some(response_tx) = self.response_tx.take() + { + drop(response_tx.send(Err(err))); + } + } EventMsg::GuardianAssessment(event) => { info!( "Guardian assessment: id={}, status={:?}, turn_id={}", @@ -1379,13 +1768,41 @@ impl PromptState { | EventMsg::RealtimeConversationListVoicesResponse(..) // Used for returning a single history entry | EventMsg::GetHistoryEntryResponse(..) - | EventMsg::DeprecationNotice(..) - | EventMsg::RequestUserInput(..)) => { + | EventMsg::DeprecationNotice(..)) => { warn!("Unexpected event: {:?}", e); } } } + async fn request_user_input( + &mut self, + client: &SessionClient, + event: RequestUserInputEvent, + ) -> Result<(), Error> { + let RequestUserInputEvent { + call_id, + turn_id, + questions, + } = event; + + if questions.is_empty() { + self.finalize_user_input_answers(client, call_id, turn_id, HashMap::new()) + .await?; + return Ok(()); + } + + self.spawn_user_input_question_request( + client, + turn_id, + call_id, + questions, + 0, + HashMap::new(), + ); + + Ok(()) + } + async fn mcp_elicitation( &mut self, client: &SessionClient, @@ -2358,6 +2775,53 @@ fn build_exec_permission_options( .collect() } +fn build_user_input_permission_options( + question: &RequestUserInputQuestion, +) -> (Vec, HashMap) { + let mut option_map = HashMap::new(); + let mut options = Vec::new(); + + if let Some(question_options) = question.options.as_ref() { + for (index, option) in question_options.iter().enumerate() { + let option_id = format!("answer-{index}"); + option_map.insert(option_id.clone(), option.label.clone()); + options.push(PermissionOption::new( + option_id, + option.label.clone(), + PermissionOptionKind::AllowOnce, + )); + } + } + + if question.is_other { + let option_id = "answer-other".to_string(); + option_map.insert(option_id.clone(), "other".to_string()); + options.push(PermissionOption::new( + option_id, + "Other", + PermissionOptionKind::AllowOnce, + )); + } + + if options.is_empty() { + let option_id = "answer-continue".to_string(); + option_map.insert(option_id.clone(), "continue".to_string()); + options.push(PermissionOption::new( + option_id, + "Continue", + PermissionOptionKind::AllowOnce, + )); + } + + options.push(PermissionOption::new( + "cancel", + "Cancel", + PermissionOptionKind::RejectOnce, + )); + + (options, option_map) +} + struct ParseCommandToolCall { title: String, file_extension: Option, @@ -2592,6 +3056,8 @@ struct ThreadActor { resolution_rx: mpsc::UnboundedReceiver, /// Last config options state we emitted to the client, used for deduping updates. last_sent_config_options: Option>, + /// Current collaboration mode kind for this session. + current_collaboration_mode_kind: ModeKind, } impl ThreadActor { @@ -2617,6 +3083,7 @@ impl ThreadActor { message_rx, resolution_rx, last_sent_config_options: None, + current_collaboration_mode_kind: ModeKind::Default, } } @@ -2709,6 +3176,14 @@ impl ThreadActor { let result = self.handle_replay_history(history); drop(response_tx.send(result)); } + ThreadMessage::SubmitPlanImplementation { approval_preset_id } => { + if let Err(err) = self.submit_plan_implementation(&approval_preset_id).await { + error!("Failed to submit accepted plan implementation: {err:?}"); + self.client + .send_agent_text(format!("Failed to start implementation: {err}")); + } + self.maybe_emit_config_options_update().await; + } ThreadMessage::PermissionRequestResolved { submission_id, request_key, @@ -2766,7 +3241,42 @@ impl ThreadActor { ] } - fn modes(&self) -> Option { + async fn modes(&self) -> Option { + let approval_modes = self.approval_modes()?; + let collaboration_modes = self.models_manager.list_collaboration_modes().await; + let mut available_modes = approval_modes.available_modes; + + for mask in collaboration_modes { + let Some(mode) = mask.mode else { + continue; + }; + if !mode.is_tui_visible() || mode == ModeKind::Default { + continue; + } + let mode_id = mode_kind_as_id(mode); + if available_modes + .iter() + .any(|available_mode: &SessionMode| available_mode.id.0.as_ref() == mode_id) + { + continue; + } + let mut session_mode = SessionMode::new(mode_id, mask.name); + if let Some(description) = collaboration_mode_description(mode) { + session_mode = session_mode.description(description); + } + available_modes.push(session_mode); + } + + let current_mode_id = if self.current_collaboration_mode_kind == ModeKind::Default { + approval_modes.current_mode_id + } else { + SessionModeId::new(mode_kind_as_id(self.current_collaboration_mode_kind)) + }; + + Some(SessionModeState::new(current_mode_id, available_modes)) + } + + fn approval_modes(&self) -> Option { let current_mode_id = APPROVAL_PRESETS .iter() .find(|preset| { @@ -2837,7 +3347,7 @@ impl ThreadActor { async fn config_options(&self) -> Result, Error> { let mut options = Vec::new(); - if let Some(modes) = self.modes() { + if let Some(modes) = self.modes().await { let select_options = modes .available_modes .into_iter() @@ -2847,12 +3357,14 @@ impl ThreadActor { options.push( SessionConfigOption::select( "mode", - "Approval Preset", + "Mode", modes.current_mode_id.0, select_options, ) - .category(SessionConfigOptionCategory::Mode) - .description("Choose an approval and sandboxing preset for your session"), + .category(SessionConfigOptionCategory::Other( + "_approval_preset".to_string(), + )) + .description("Choose a session mode or approval preset"), ); } @@ -2958,24 +3470,85 @@ impl ThreadActor { }; match config_id.0.as_ref() { "mode" => self.handle_set_mode(SessionModeId::new(value.0)).await, + "collaboration_mode" => self.handle_set_collaboration_mode(value).await, "model" => self.handle_set_config_model(value).await, "reasoning_effort" => self.handle_set_config_reasoning_effort(value).await, _ => Err(Error::invalid_params().data("Unsupported config option")), } } - async fn handle_set_config_model(&mut self, value: SessionConfigValueId) -> Result<(), Error> { - let model_id = value.0; + async fn handle_set_collaboration_mode( + &mut self, + value: SessionConfigValueId, + ) -> Result<(), Error> { + let mode: ModeKind = serde_json::from_value(value.0.as_ref().into()) + .map_err(|_| Error::invalid_params().data("Unsupported collaboration mode"))?; + let next_mode = self.collaboration_mode_for(mode).await?; - let presets = self.models_manager.list_models().await; - let preset = presets.iter().find(|p| p.id.as_str() == &*model_id); + self.thread + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(next_mode.clone()), + personality: None, + windows_sandbox_level: None, + service_tier: None, + approvals_reviewer: None, + permission_profile: None, + }) + .await + .map_err(|e| Error::from(anyhow::anyhow!(e)))?; - let model_to_use = preset - .map(|p| p.model.clone()) - .unwrap_or_else(|| model_id.to_string()); + self.apply_collaboration_mode(next_mode); - if model_to_use.is_empty() { - return Err(Error::invalid_params().data("No model selected")); + Ok(()) + } + + async fn collaboration_mode_for(&self, mode: ModeKind) -> Result { + if !mode.is_tui_visible() { + return Err(Error::invalid_params().data("Unsupported collaboration mode")); + } + + let masks = self.models_manager.list_collaboration_modes().await; + let Some(mask) = masks.iter().find(|mask| mask.mode == Some(mode)) else { + return Err(Error::invalid_params().data("Collaboration mode is unavailable")); + }; + + let current_mode = CollaborationMode { + mode: self.current_collaboration_mode_kind, + settings: Settings { + model: self.get_current_model().await, + reasoning_effort: self.config.model_reasoning_effort, + developer_instructions: self.config.developer_instructions.clone(), + }, + }; + + Ok(current_mode.apply_mask(mask)) + } + + fn apply_collaboration_mode(&mut self, next_mode: CollaborationMode) { + self.current_collaboration_mode_kind = next_mode.mode; + self.config.model = Some(next_mode.settings.model); + self.config.model_reasoning_effort = next_mode.settings.reasoning_effort; + self.config.developer_instructions = next_mode.settings.developer_instructions; + } + + async fn handle_set_config_model(&mut self, value: SessionConfigValueId) -> Result<(), Error> { + let model_id = value.0; + + let presets = self.models_manager.list_models().await; + let preset = presets.iter().find(|p| p.id.as_str() == &*model_id); + + let model_to_use = preset + .map(|p| p.model.clone()) + .unwrap_or_else(|| model_id.to_string()); + + if model_to_use.is_empty() { + return Err(Error::invalid_params().data("No model selected")); } let effort_to_use = if let Some(preset) = preset { @@ -3102,7 +3675,7 @@ impl ThreadActor { async fn handle_load(&mut self) -> Result { Ok(LoadSessionResponse::new() .models(self.models().await?) - .modes(self.modes()) + .modes(self.modes().await) .config_options(self.config_options().await?)) } @@ -3212,11 +3785,67 @@ impl ThreadActor { Ok(response_rx) } + async fn submit_plan_implementation(&mut self, approval_preset_id: &str) -> Result<(), Error> { + let preset = Self::approval_preset(approval_preset_id)?; + let collaboration_mode = self.collaboration_mode_for(ModeKind::Default).await?; + let op = Op::UserInputWithTurnContext { + items: vec![UserInput::Text { + text: PLAN_IMPLEMENTATION_CODING_MESSAGE.to_string(), + text_elements: vec![], + }], + environments: None, + final_output_json_schema: None, + responsesapi_client_metadata: None, + cwd: None, + approval_policy: Some(preset.approval), + approvals_reviewer: None, + sandbox_policy: Some(preset.sandbox.clone()), + permission_profile: None, + windows_sandbox_level: None, + model: None, + effort: None, + summary: None, + service_tier: None, + collaboration_mode: Some(collaboration_mode.clone()), + personality: None, + }; + + let submission_id = self + .thread + .submit(op) + .await + .map_err(|e| Error::internal_error().data(e.to_string()))?; + + let (response_tx, _response_rx) = oneshot::channel(); + let state = SubmissionState::Prompt(PromptState::new( + submission_id.clone(), + self.thread.clone(), + self.resolution_tx.clone(), + response_tx, + )); + + self.submissions.insert(submission_id, state); + self.apply_collaboration_mode(collaboration_mode); + self.apply_approval_preset(preset)?; + + Ok(()) + } + async fn handle_set_mode(&mut self, mode: SessionModeId) -> Result<(), Error> { - let preset = APPROVAL_PRESETS + if APPROVAL_PRESETS .iter() - .find(|preset| mode.0.as_ref() == preset.id) - .ok_or_else(Error::invalid_params)?; + .any(|preset| mode.0.as_ref() == preset.id) + { + return self.handle_set_approval_preset(mode).await; + } + + self.handle_set_collaboration_mode(SessionConfigValueId::new(mode.0)) + .await + } + + async fn handle_set_approval_preset(&mut self, mode: SessionModeId) -> Result<(), Error> { + let preset = Self::approval_preset(mode.0.as_ref())?; + let collaboration_mode = self.collaboration_mode_for(ModeKind::Default).await?; self.thread .submit(Op::OverrideTurnContext { @@ -3226,7 +3855,7 @@ impl ThreadActor { model: None, effort: None, summary: None, - collaboration_mode: None, + collaboration_mode: Some(collaboration_mode.clone()), personality: None, windows_sandbox_level: None, service_tier: None, @@ -3236,6 +3865,20 @@ impl ThreadActor { .await .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + self.apply_collaboration_mode(collaboration_mode); + self.apply_approval_preset(preset)?; + + Ok(()) + } + + fn approval_preset(mode: &str) -> Result<&'static ApprovalPreset, Error> { + APPROVAL_PRESETS + .iter() + .find(|preset| mode == preset.id) + .ok_or_else(Error::invalid_params) + } + + fn apply_approval_preset(&mut self, preset: &ApprovalPreset) -> Result<(), Error> { self.config .permissions .approval_policy @@ -3247,7 +3890,7 @@ impl ThreadActor { .set(preset.sandbox.clone()) .map_err(|e| Error::from(anyhow::anyhow!(e)))?; - match preset.sandbox { + match &preset.sandbox { // Treat this user action as a trusted dir SandboxPolicy::DangerFullAccess | SandboxPolicy::WorkspaceWrite { .. } @@ -3645,6 +4288,14 @@ impl ThreadActor { } async fn handle_event(&mut self, Event { id, msg }: Event) { + if let EventMsg::TurnStarted(TurnStartedEvent { + collaboration_mode_kind, + .. + }) = &msg + { + self.current_collaboration_mode_kind = *collaboration_mode_kind; + } + if let Some(submission) = self.submissions.get_mut(&id) { submission.handle_event(&self.client, msg).await; } else { @@ -4029,9 +4680,12 @@ mod tests { use std::sync::atomic::AtomicUsize; use std::time::Duration; - use agent_client_protocol::schema::{RequestPermissionResponse, TextContent}; + use agent_client_protocol::schema::{ + RequestPermissionResponse, SessionConfigKind, SessionConfigSelectOptions, TextContent, + }; use codex_core::{config::ConfigOverrides, test_support::all_model_presets}; use codex_protocol::config_types::ModeKind; + use codex_protocol::request_user_input::RequestUserInputQuestionOption; use tokio::sync::{Mutex, Notify, mpsc::UnboundedSender}; use super::*; @@ -4411,15 +5065,468 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_set_collaboration_mode_submits_override() -> anyhow::Result<()> { + let (_session_id, _client, thread, message_tx, _handle) = setup().await?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::SetConfigOption { + config_id: SessionConfigId::new("collaboration_mode"), + value: SessionConfigOptionValue::ValueId { + value: SessionConfigValueId::new("plan"), + }, + response_tx, + })?; + + response_rx.await??; + drop(message_tx); + + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.last(), + Some(Op::OverrideTurnContext { + collaboration_mode: Some(mode), + .. + }) if mode.mode == ModeKind::Plan + )); + + Ok(()) + } + + #[tokio::test] + async fn test_set_session_mode_submits_collaboration_override() -> anyhow::Result<()> { + let (_session_id, _client, thread, message_tx, _handle) = setup().await?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("plan"), + response_tx, + })?; + + response_rx.await??; + drop(message_tx); + + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.last(), + Some(Op::OverrideTurnContext { + collaboration_mode: Some(mode), + .. + }) if mode.mode == ModeKind::Plan + )); + + Ok(()) + } + + #[tokio::test] + async fn test_set_approval_session_mode_resets_collaboration_mode() -> anyhow::Result<()> { + let (_session_id, _client, thread, message_tx, _handle) = setup().await?; + let (plan_response_tx, plan_response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("plan"), + response_tx: plan_response_tx, + })?; + plan_response_rx.await??; + + let (approval_response_tx, approval_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("full-access"), + response_tx: approval_response_tx, + })?; + approval_response_rx.await??; + drop(message_tx); + + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.last(), + Some(Op::OverrideTurnContext { + collaboration_mode: Some(mode), + approval_policy: Some(_), + sandbox_policy: Some(_), + .. + }) if mode.mode == ModeKind::Default + )); + + Ok(()) + } + + #[tokio::test] + async fn test_load_exposes_combined_modes_in_mode_selector() -> anyhow::Result<()> { + let (_session_id, _client, _thread, message_tx, _handle) = setup().await?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Load { response_tx })?; + + let response = response_rx.await??; + drop(message_tx); + + let modes = response.modes.expect("mode selector should be present"); + assert_eq!(modes.current_mode_id.0.as_ref(), "full-access"); + assert_eq!( + modes + .available_modes + .iter() + .map(|mode| mode.id.0.as_ref()) + .collect::>(), + vec!["read-only", "auto", "full-access", "plan"] + ); + assert_eq!( + modes + .available_modes + .iter() + .map(|mode| mode.name.as_str()) + .collect::>(), + vec!["Read Only", "Default", "Full Access", "Plan"] + ); + assert_eq!( + modes + .available_modes + .iter() + .find(|mode| mode.id.0.as_ref() == "plan") + .and_then(|mode| mode.description.as_deref()), + Some(PLAN_MODE_DESCRIPTION) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_combined_modes_are_exposed_as_config_option() -> anyhow::Result<()> { + let (_session_id, _client, _thread, message_tx, _handle) = setup().await?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::GetConfigOptions { response_tx })?; + + let options = response_rx.await??; + drop(message_tx); + + let Some(primary_option) = options.first() else { + anyhow::bail!("expected config options"); + }; + assert_eq!(primary_option.id.0.as_ref(), "mode"); + assert!(matches!( + primary_option.category.as_ref(), + Some(SessionConfigOptionCategory::Other(category)) + if category == "_approval_preset" + )); + assert!( + options + .iter() + .all(|option| option.id.0.as_ref() != "collaboration_mode") + ); + let SessionConfigKind::Select(select) = &primary_option.kind else { + anyhow::bail!("expected mode config option to be a select"); + }; + assert_eq!(select.current_value.0.as_ref(), "full-access"); + let SessionConfigSelectOptions::Ungrouped(select_options) = &select.options else { + anyhow::bail!("expected ungrouped mode options"); + }; + assert_eq!( + select_options + .iter() + .map(|option| ( + option.value.0.as_ref(), + option.name.as_str(), + option.description.as_deref() + )) + .collect::>(), + vec![ + ( + "read-only", + "Read Only", + Some( + "Codex can read files in the current workspace. Approval is required to edit files or access the internet." + ), + ), + ( + "auto", + "Default", + Some( + "Codex can read and edit files in the current workspace, and run commands. Approval is required to access the internet or edit other files. (Identical to Agent mode)" + ), + ), + ( + "full-access", + "Full Access", + Some( + "Codex can edit files outside this workspace and access the internet without asking for approval. Exercise caution when using." + ), + ), + ("plan", "Plan", Some(PLAN_MODE_DESCRIPTION)), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_plan_completion_prompts_and_accept_submits_default_mode_implementation() + -> anyhow::Result<()> { + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new(PLAN_IMPLEMENTATION_ACCEPT_DEFAULT_OPTION_ID), + )), + ])); + let (session_id, client, thread, message_tx, _handle) = setup_with_client(client).await?; + + let (mode_response_tx, mode_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("plan"), + response_tx: mode_response_tx, + })?; + mode_response_rx.await??; + + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id, vec!["plan-turn".into()]), + response_tx: prompt_response_tx, + })?; + let stop_reason_rx = prompt_response_rx.await??; + assert_eq!(stop_reason_rx.await??, StopReason::EndTurn); + + tokio::time::timeout(Duration::from_millis(500), async { + loop { + let has_implementation = thread.ops.lock().unwrap().iter().any(|op| { + matches!( + op, + Op::UserInputWithTurnContext { + items, + collaboration_mode: Some(mode), + approval_policy: Some(_), + sandbox_policy: Some(_), + .. + } if mode.mode == ModeKind::Default + && matches!( + items.as_slice(), + [UserInput::Text { text, .. }] + if text == PLAN_IMPLEMENTATION_CODING_MESSAGE + ) + ) + }); + if has_implementation { + break; + } + tokio::task::yield_now().await; + } + }) + .await?; + + let requests = client.permission_requests.lock().unwrap(); + let request = requests + .iter() + .find(|request| { + request + .tool_call + .fields + .title + .as_ref() + .is_some_and(|title| title == "Implement this plan?") + }) + .expect("expected plan implementation permission request"); + assert!(matches!( + request.tool_call.fields.content.as_deref(), + Some([ + ToolCallContent::Content(Content { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) + ]) if text == "- Step 1\n- Step 2\n" + )); + assert_eq!( + request + .options + .iter() + .map(|option| ( + option.option_id.0.to_string(), + option.name.as_str(), + option.kind + )) + .collect::>(), + vec![ + ( + PLAN_IMPLEMENTATION_ACCEPT_DEFAULT_OPTION_ID.to_string(), + "Accept and continue with Default profile", + PermissionOptionKind::AllowOnce, + ), + ( + PLAN_IMPLEMENTATION_ACCEPT_FULL_ACCESS_OPTION_ID.to_string(), + "Accept and continue with Full Access profile", + PermissionOptionKind::AllowOnce, + ), + ( + PLAN_IMPLEMENTATION_STAY_OPTION_ID.to_string(), + "Reject and continue planning", + PermissionOptionKind::RejectOnce, + ), + ] + ); + let notifications = client.notifications.lock().unwrap(); + assert!(notifications.iter().any(|notification| matches!( + ¬ification.update, + SessionUpdate::ToolCallUpdate(ToolCallUpdate { + fields, + .. + }) if fields.status == Some(ToolCallStatus::Completed) + && fields.title.as_deref() == Some("User accepted plan") + && matches!( + fields.content.as_deref(), + Some([ + ToolCallContent::Content(Content { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) + ]) if text == "- Step 1\n- Step 2\n" + ) + ))); + assert!( + notifications.iter().all(|notification| !matches!( + ¬ification.update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) if text == "- Step 1\n- Step 2\n" + )), + "plan text should be shown in the implementation prompt, not pasted into the chat" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_plan_completion_accept_can_continue_with_full_access_profile() + -> anyhow::Result<()> { + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new(PLAN_IMPLEMENTATION_ACCEPT_FULL_ACCESS_OPTION_ID), + )), + ])); + let (session_id, _client, thread, message_tx, _handle) = setup_with_client(client).await?; + + let (mode_response_tx, mode_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("plan"), + response_tx: mode_response_tx, + })?; + mode_response_rx.await??; + + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id, vec!["plan-turn".into()]), + response_tx: prompt_response_tx, + })?; + let stop_reason_rx = prompt_response_rx.await??; + assert_eq!(stop_reason_rx.await??, StopReason::EndTurn); + + tokio::time::timeout(Duration::from_millis(500), async { + loop { + let has_full_access_implementation = thread.ops.lock().unwrap().iter().any(|op| { + matches!( + op, + Op::UserInputWithTurnContext { + collaboration_mode: Some(mode), + sandbox_policy: Some(SandboxPolicy::DangerFullAccess), + .. + } if mode.mode == ModeKind::Default + ) + }); + if has_full_access_implementation { + break; + } + tokio::task::yield_now().await; + } + }) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_plan_completion_stay_in_plan_does_not_submit_implementation() -> anyhow::Result<()> + { + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new(PLAN_IMPLEMENTATION_STAY_OPTION_ID), + )), + ])); + let (session_id, client, thread, message_tx, _handle) = setup_with_client(client).await?; + + let (mode_response_tx, mode_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::SetMode { + mode: SessionModeId::new("plan"), + response_tx: mode_response_tx, + })?; + mode_response_rx.await??; + + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id, vec!["plan-turn".into()]), + response_tx: prompt_response_tx, + })?; + let stop_reason_rx = prompt_response_rx.await??; + assert_eq!(stop_reason_rx.await??, StopReason::EndTurn); + + tokio::time::sleep(Duration::from_millis(50)).await; + let ops = thread.ops.lock().unwrap(); + assert!( + !ops.iter() + .any(|op| matches!(op, Op::UserInputWithTurnContext { .. })), + "stay in Plan mode should not submit implementation, got {ops:?}" + ); + drop(ops); + + let notifications = client.notifications.lock().unwrap(); + assert!(notifications.iter().any(|notification| matches!( + ¬ification.update, + SessionUpdate::ToolCallUpdate(ToolCallUpdate { + fields, + .. + }) if fields.status == Some(ToolCallStatus::Failed) + && fields.title.as_deref() == Some("Plan rejected") + && matches!( + fields.content.as_deref(), + Some([ + ToolCallContent::Content(Content { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) + ]) if text == "- Step 1\n- Step 2\n" + ) + ))); + assert!( + notifications.iter().all(|notification| !matches!( + ¬ification.update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) if text == "- Step 1\n- Step 2\n" + )), + "rejected plan should stay in the tool-call UI, not be pasted into chat" + ); + + Ok(()) + } + async fn setup() -> anyhow::Result<( SessionId, Arc, Arc, UnboundedSender, tokio::task::JoinHandle<()>, + )> { + setup_with_client(Arc::new(StubClient::new())).await + } + + async fn setup_with_client( + client: Arc, + ) -> anyhow::Result<( + SessionId, + Arc, + Arc, + UnboundedSender, + tokio::task::JoinHandle<()>, )> { let session_id = SessionId::new("test"); - let client = Arc::new(StubClient::new()); let session_client = SessionClient::with_client(session_id.clone(), client.clone(), Arc::default()); let conversation = Arc::new(StubCodexThread::new()); @@ -4468,6 +5575,29 @@ mod tests { fn list_models(&self) -> Pin> + Send + '_>> { Box::pin(async { all_model_presets().to_owned() }) } + + fn list_collaboration_modes( + &self, + ) -> Pin> + Send + '_>> { + Box::pin(async { + vec![ + CollaborationModeMask { + name: "Default".to_string(), + mode: Some(ModeKind::Default), + model: None, + reasoning_effort: None, + developer_instructions: None, + }, + CollaborationModeMask { + name: "Plan".to_string(), + mode: Some(ModeKind::Plan), + model: None, + reasoning_effort: None, + developer_instructions: None, + }, + ] + }) + } } struct StubCodexThread { @@ -4504,7 +5634,7 @@ mod tests { self.ops.lock().unwrap().push(op.clone()); match op { - Op::UserInput { items, .. } => { + Op::UserInput { items, .. } | Op::UserInputWithTurnContext { items, .. } => { *self.active_prompt_id.lock().unwrap() = Some(id.to_string()); let prompt = items .into_iter() @@ -4616,6 +5746,44 @@ mod tests { }), }) .unwrap(); + } else if prompt == "plan-turn" { + let turn_id = id.to_string(); + self.op_tx + .send(Event { + id: id.to_string(), + msg: EventMsg::TurnStarted(TurnStartedEvent { + model_context_window: None, + collaboration_mode_kind: ModeKind::Plan, + turn_id: turn_id.clone(), + started_at: None, + }), + }) + .unwrap(); + self.op_tx + .send(Event { + id: id.to_string(), + msg: EventMsg::ItemCompleted(ItemCompletedEvent { + thread_id: codex_protocol::ThreadId::new(), + turn_id: turn_id.clone(), + item: TurnItem::Plan(codex_protocol::items::PlanItem { + id: "plan-item".to_string(), + text: "- Step 1\n- Step 2\n".to_string(), + }), + }), + }) + .unwrap(); + self.op_tx + .send(Event { + id: id.to_string(), + msg: EventMsg::TurnComplete(TurnCompleteEvent { + last_agent_message: None, + turn_id, + completed_at: None, + duration_ms: None, + time_to_first_token_ms: None, + }), + }) + .unwrap(); } else { self.op_tx .send(Event { @@ -4765,6 +5933,8 @@ mod tests { | Op::ResolveElicitation { .. } | Op::RequestPermissionsResponse { .. } | Op::PatchApproval { .. } + | Op::UserInputAnswer { .. } + | Op::OverrideTurnContext { .. } | Op::Interrupt => {} Op::Shutdown => { if let Some(active_prompt_id) = self.active_prompt_id.lock().unwrap().take() @@ -5013,6 +6183,211 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_request_user_input_submits_user_input_answer() -> anyhow::Result<()> { + let session_id = SessionId::new("test"); + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new("answer-0"), + )), + ])); + let session_client = SessionClient::with_client(session_id, client.clone(), Arc::default()); + let thread = Arc::new(StubCodexThread::new()); + let (response_tx, _response_rx) = tokio::sync::oneshot::channel(); + let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel(); + let mut prompt_state = PromptState::new( + "submission-id".to_string(), + thread.clone(), + message_tx, + response_tx, + ); + + prompt_state + .request_user_input( + &session_client, + RequestUserInputEvent { + call_id: "call-id".to_string(), + turn_id: "turn-id".to_string(), + questions: vec![RequestUserInputQuestion { + id: "confirm_path".to_string(), + header: "Confirm".to_string(), + question: "Continue?".to_string(), + is_other: false, + is_secret: false, + options: Some(vec![ + codex_protocol::request_user_input::RequestUserInputQuestionOption { + label: "yes".to_string(), + description: "Continue".to_string(), + }, + codex_protocol::request_user_input::RequestUserInputQuestionOption { + label: "no".to_string(), + description: "Stop".to_string(), + }, + ]), + }], + }, + ) + .await?; + + let ThreadMessage::PermissionRequestResolved { + submission_id, + request_key, + response, + } = message_rx.recv().await.unwrap() + else { + panic!("expected permission resolution message"); + }; + assert_eq!(submission_id, "submission-id"); + + prompt_state + .handle_permission_request_resolved(&session_client, request_key, response) + .await?; + + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.last(), + Some(Op::UserInputAnswer { id, response }) + if id == "turn-id" + && response + .answers + .get("confirm_path") + .is_some_and(|answer| answer.answers == vec!["yes".to_string()]) + )); + + let notifications = client.notifications.lock().unwrap(); + assert!(notifications.iter().any(|notification| { + matches!( + ¬ification.update, + SessionUpdate::ToolCallUpdate(update) + if update.tool_call_id.0.as_ref() == "call-id" + && matches!( + update.fields.status, + Some(ToolCallStatus::Completed) + ) + ) + })); + + Ok(()) + } + + #[tokio::test] + async fn test_request_user_input_chains_questions_before_submitting_answers() + -> anyhow::Result<()> { + let session_id = SessionId::new("test"); + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new("answer-0"), + )), + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new("answer-1"), + )), + ])); + let session_client = SessionClient::with_client(session_id, client.clone(), Arc::default()); + let thread = Arc::new(StubCodexThread::new()); + let (response_tx, _response_rx) = tokio::sync::oneshot::channel(); + let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel(); + let mut prompt_state = PromptState::new( + "submission-id".to_string(), + thread.clone(), + message_tx, + response_tx, + ); + + prompt_state + .request_user_input( + &session_client, + RequestUserInputEvent { + call_id: "call-id".to_string(), + turn_id: "turn-id".to_string(), + questions: vec![ + RequestUserInputQuestion { + id: "target".to_string(), + header: "Target".to_string(), + question: "Which file?".to_string(), + is_other: false, + is_secret: false, + options: Some(vec![ + RequestUserInputQuestionOption { + label: "root".to_string(), + description: "Use the root file".to_string(), + }, + RequestUserInputQuestionOption { + label: "npm".to_string(), + description: "Use the npm package file".to_string(), + }, + ]), + }, + RequestUserInputQuestion { + id: "depth".to_string(), + header: "Depth".to_string(), + question: "How detailed?".to_string(), + is_other: false, + is_secret: false, + options: Some(vec![ + RequestUserInputQuestionOption { + label: "brief".to_string(), + description: "Keep it short".to_string(), + }, + RequestUserInputQuestionOption { + label: "detailed".to_string(), + description: "Include more detail".to_string(), + }, + ]), + }, + ], + }, + ) + .await?; + + for expected_request_key in ["user-input:call-id:0", "user-input:call-id:1"] { + let ThreadMessage::PermissionRequestResolved { + submission_id, + request_key, + response, + } = message_rx.recv().await.unwrap() + else { + panic!("expected permission resolution message"); + }; + assert_eq!(submission_id, "submission-id"); + assert_eq!(request_key, expected_request_key); + + prompt_state + .handle_permission_request_resolved(&session_client, request_key, response) + .await?; + } + + let ops = thread.ops.lock().unwrap(); + assert_eq!(ops.len(), 1); + assert!(matches!( + ops.last(), + Some(Op::UserInputAnswer { id, response }) + if id == "turn-id" + && response + .answers + .get("target") + .is_some_and(|answer| answer.answers == vec!["root".to_string()]) + && response + .answers + .get("depth") + .is_some_and(|answer| answer.answers == vec!["detailed".to_string()]) + )); + drop(ops); + + let requests = client.permission_requests.lock().unwrap(); + assert_eq!( + requests + .iter() + .map(|request| request.tool_call.fields.title.as_deref()) + .collect::>(), + vec![ + Some("Need user input: Target"), + Some("Need user input: Depth"), + ] + ); + + Ok(()) + } + #[tokio::test] async fn test_mcp_tool_approval_elicitation_routes_to_permission_request() -> anyhow::Result<()> {