diff --git a/src/thread.rs b/src/thread.rs index 4b074cd..40c7cb1 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -31,9 +31,10 @@ use codex_core::{ }; use codex_protocol::{ approvals::{ElicitationRequest, ElicitationRequestEvent}, - config_types::TrustLevel, + config_types::{CollaborationMode, CollaborationModeMask, ModeKind, Settings, TrustLevel}, custom_prompts::CustomPrompt, dynamic_tools::{DynamicToolCallOutputContentItem, DynamicToolCallRequest}, + items::TurnItem, mcp::CallToolResult, models::{MacOsSeatbeltProfileExtensions, PermissionProfile, ResponseItem, WebSearchAction}, openai_models::{ModelPreset, ReasoningEffort}, @@ -57,6 +58,10 @@ use codex_protocol::{ request_permissions::{ PermissionGrantScope, RequestPermissionsEvent, RequestPermissionsResponse, }, + request_user_input::{ + RequestUserInputAnswer, RequestUserInputEvent, RequestUserInputQuestion, + RequestUserInputResponse, + }, user_input::UserInput, }; use codex_shell_command::parse_command::parse_command; @@ -98,6 +103,7 @@ impl CodexThreadImpl for CodexThread { pub trait ModelsManagerImpl { async fn get_model(&self, model_id: &Option) -> String; async fn list_models(&self) -> Vec; + async fn list_collaboration_modes(&self) -> Vec; } #[async_trait::async_trait] @@ -110,6 +116,10 @@ impl ModelsManagerImpl for ModelsManager { async fn list_models(&self) -> Vec { self.list_models(RefreshStrategy::OnlineIfUncached).await } + + async fn list_collaboration_modes(&self) -> Vec { + self.list_collaboration_modes() + } } pub trait Auth { @@ -343,6 +353,14 @@ enum PendingPermissionRequest { call_id: String, permissions: PermissionProfile, }, + UserInput { + turn_id: String, + call_id: String, + questions: Vec, + question_index: usize, + answers: HashMap, + option_map: HashMap, + }, } struct PendingPermissionInteraction { @@ -362,6 +380,19 @@ fn permissions_request_key(call_id: &str) -> String { format!("permissions:{call_id}") } +fn user_input_request_key(call_id: &str, question_index: usize) -> String { + format!("user-input:{call_id}:{question_index}") +} + +fn mode_kind_as_id(mode: ModeKind) -> &'static str { + match mode { + ModeKind::Plan => "plan", + ModeKind::Default => "default", + ModeKind::PairProgramming => "pair_programming", + ModeKind::Execute => "execute", + } +} + enum SubmissionState { /// Loading custom prompts from the project CustomPrompts(CustomPromptsState), @@ -537,9 +568,131 @@ impl PromptState { } } + async fn submit_user_input_answers( + &self, + turn_id: String, + answers: HashMap, + ) -> Result<(), Error> { + self.thread + .submit(Op::UserInputAnswer { + id: turn_id, + response: RequestUserInputResponse { answers }, + }) + .await + .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + Ok(()) + } + + async fn complete_user_input_tool_call(&self, client: &SessionClient, call_id: String) { + client + .send_tool_call_update(ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new().status(ToolCallStatus::Completed), + )) + .await; + } + + async fn finalize_user_input_answers( + &self, + client: &SessionClient, + call_id: String, + turn_id: String, + answers: HashMap, + ) -> Result<(), Error> { + self.submit_user_input_answers(turn_id, answers).await?; + self.complete_user_input_tool_call(client, call_id).await; + Ok(()) + } + + fn spawn_user_input_question_request( + &mut self, + client: &SessionClient, + turn_id: String, + call_id: String, + questions: Vec, + question_index: usize, + answers: HashMap, + ) { + let Some(question) = questions.get(question_index).cloned() else { + let thread = self.thread.clone(); + let client = client.clone(); + tokio::task::spawn_local(async move { + if let Err(err) = thread + .submit(Op::UserInputAnswer { + id: turn_id, + response: RequestUserInputResponse { answers }, + }) + .await + { + warn!("Failed to submit UserInputAnswer fallback: {err}"); + return; + } + + client + .send_tool_call_update(ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new().status(ToolCallStatus::Completed), + )) + .await; + }); + return; + }; + + let (options, option_map) = build_user_input_permission_options(&question); + + let mut content_lines = vec![question.question.clone()]; + if let Some(question_options) = question.options.as_ref() { + content_lines.extend( + question_options + .iter() + .map(|option| format!("- {}: {}", option.label, option.description)), + ); + } + if question.is_other { + content_lines.push( + "- Other: custom answer will be available when the client UI supports structured input" + .to_string(), + ); + } + + let title = if question.header.is_empty() { + "Need user input".to_string() + } else { + format!("Need user input: {}", question.header) + }; + + let request_key = user_input_request_key(&call_id, question_index); + self.spawn_permission_request( + client, + request_key, + PendingPermissionRequest::UserInput { + turn_id, + call_id: call_id.clone(), + questions, + question_index, + answers, + option_map, + }, + ToolCallUpdate::new( + ToolCallId::new(call_id), + ToolCallUpdateFields::new() + .kind(ToolKind::Think) + .status(ToolCallStatus::Pending) + .title(title) + .raw_input(serde_json::json!({ + "request_type": "request_user_input", + "question": question, + "fallback": "session/request_permission", + })) + .content(vec![content_lines.join("\n").into()]), + ), + options, + ); + } + async fn handle_permission_request_resolved( &mut self, - _client: &SessionClient, + client: &SessionClient, request_key: String, response: Result, ) -> Result<(), Error> { @@ -635,6 +788,57 @@ impl PromptState { .await .map_err(|e| Error::from(anyhow::anyhow!(e)))?; } + PendingPermissionRequest::UserInput { + turn_id, + call_id, + questions, + question_index, + mut answers, + option_map, + } => { + let Some(question) = questions.get(question_index) else { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + }; + + let selected_answer = match response.outcome { + RequestPermissionOutcome::Selected(SelectedPermissionOutcome { + option_id, + .. + }) => option_map.get(option_id.0.as_ref()).cloned(), + RequestPermissionOutcome::Cancelled | _ => None, + }; + + let Some(answer) = selected_answer else { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + }; + + answers.insert( + question.id.clone(), + RequestUserInputAnswer { + answers: vec![answer], + }, + ); + + let next_question_index = question_index + 1; + if next_question_index >= questions.len() { + self.finalize_user_input_answers(client, call_id, turn_id, answers) + .await?; + return Ok(()); + } + + self.spawn_user_input_question_request( + client, + turn_id, + call_id, + questions, + next_question_index, + answers, + ); + } } Ok(()) @@ -665,6 +869,7 @@ impl PromptState { | EventMsg::TurnAborted(..) | EventMsg::EnteredReviewMode(..) | EventMsg::ExitedReviewMode(..) + | EventMsg::RequestUserInput(..) | EventMsg::ShutdownComplete => { self.complete_web_search(client).await; } @@ -885,6 +1090,10 @@ impl PromptState { item, }) => { info!("Item completed: thread_id={}, turn_id={}, item={:?}", thread_id, turn_id, item); + if let TurnItem::Plan(plan_item) = item { + // Fallback for ACP clients that do not render plan items natively. + client.send_agent_text(plan_item.text).await; + } } EventMsg::TurnComplete(TurnCompleteEvent { last_agent_message, turn_id }) => { info!( @@ -1017,6 +1226,19 @@ impl PromptState { drop(response_tx.send(Err(err))); } } + EventMsg::RequestUserInput(event) => { + info!( + "Request user input: call_id={}, turn_id={}, questions={}", + event.call_id, + event.turn_id, + event.questions.len() + ); + if let Err(err) = self.request_user_input(client, event).await + && let Some(response_tx) = self.response_tx.take() + { + drop(response_tx.send(Err(err))); + } + } // Ignore these events EventMsg::ImageGenerationBegin(..) @@ -1058,7 +1280,6 @@ impl PromptState { // Used for returning a single history entry | EventMsg::GetHistoryEntryResponse(..) | EventMsg::DeprecationNotice(..) - | EventMsg::RequestUserInput(..) | EventMsg::ListRemoteSkillsResponse(..) | EventMsg::RemoteSkillDownloaded(..)) => { warn!("Unexpected event: {:?}", e); @@ -1754,6 +1975,35 @@ impl PromptState { } } + async fn request_user_input( + &mut self, + client: &SessionClient, + event: RequestUserInputEvent, + ) -> Result<(), Error> { + let RequestUserInputEvent { + call_id, + turn_id, + questions, + } = event; + + if questions.is_empty() { + self.finalize_user_input_answers(client, call_id, turn_id, HashMap::new()) + .await?; + return Ok(()); + } + + self.spawn_user_input_question_request( + client, + turn_id, + call_id, + questions, + 0, + HashMap::new(), + ); + + Ok(()) + } + async fn request_permissions( &mut self, client: &SessionClient, @@ -1965,6 +2215,53 @@ fn build_exec_permission_options( .collect() } +fn build_user_input_permission_options( + question: &RequestUserInputQuestion, +) -> (Vec, HashMap) { + let mut option_map = HashMap::new(); + let mut options = Vec::new(); + + if let Some(question_options) = question.options.as_ref() { + for (index, option) in question_options.iter().enumerate() { + let option_id = format!("answer-{index}"); + option_map.insert(option_id.clone(), option.label.clone()); + options.push(PermissionOption::new( + option_id, + option.label.clone(), + PermissionOptionKind::AllowOnce, + )); + } + } + + if question.is_other { + let option_id = "answer-other".to_string(); + option_map.insert(option_id.clone(), "other".to_string()); + options.push(PermissionOption::new( + option_id, + "Other", + PermissionOptionKind::AllowOnce, + )); + } + + if options.is_empty() { + let option_id = "answer-continue".to_string(); + option_map.insert(option_id.clone(), "continue".to_string()); + options.push(PermissionOption::new( + option_id, + "Continue", + PermissionOptionKind::AllowOnce, + )); + } + + options.push(PermissionOption::new( + "cancel", + "Cancel", + PermissionOptionKind::RejectOnce, + )); + + (options, option_map) +} + struct ParseCommandToolCall { title: String, file_extension: Option, @@ -2205,6 +2502,8 @@ struct ThreadActor { resolution_rx: mpsc::UnboundedReceiver, /// Last config options state we emitted to the client, used for deduping updates. last_sent_config_options: Option>, + /// Current collaboration mode kind for this session. + current_collaboration_mode_kind: ModeKind, } impl ThreadActor { @@ -2231,6 +2530,7 @@ impl ThreadActor { message_rx, resolution_rx, last_sent_config_options: None, + current_collaboration_mode_kind: ModeKind::Default, } } @@ -2518,6 +2818,37 @@ impl ThreadActor { ); } + let collaboration_modes = self.models_manager.list_collaboration_modes().await; + let mut collaboration_mode_options = Vec::new(); + for mask in collaboration_modes { + let Some(mode) = mask.mode else { + continue; + }; + if !mode.is_tui_visible() { + continue; + } + let mode_id = mode_kind_as_id(mode); + if collaboration_mode_options + .iter() + .any(|opt: &SessionConfigSelectOption| opt.value.0.as_ref() == mode_id) + { + continue; + } + collaboration_mode_options.push(SessionConfigSelectOption::new(mode_id, mask.name)); + } + if !collaboration_mode_options.is_empty() { + options.push( + SessionConfigOption::select( + "collaboration_mode", + "Collaboration Mode", + mode_kind_as_id(self.current_collaboration_mode_kind), + collaboration_mode_options, + ) + .category(SessionConfigOptionCategory::Mode) + .description("Choose collaboration behavior (Default or Plan mode)"), + ); + } + let presets = self.models_manager.list_models().await; let current_model = self.get_current_model().await; @@ -2621,12 +2952,62 @@ impl ThreadActor { }; match config_id.0.as_ref() { "mode" => self.handle_set_mode(SessionModeId::new(value.0)).await, + "collaboration_mode" => self.handle_set_collaboration_mode(value).await, "model" => self.handle_set_config_model(value).await, "reasoning_effort" => self.handle_set_config_reasoning_effort(value).await, _ => Err(Error::invalid_params().data("Unsupported config option")), } } + async fn handle_set_collaboration_mode( + &mut self, + value: SessionConfigValueId, + ) -> Result<(), Error> { + let mode: ModeKind = serde_json::from_value(value.0.as_ref().into()) + .map_err(|_| Error::invalid_params().data("Unsupported collaboration mode"))?; + if !mode.is_tui_visible() { + return Err(Error::invalid_params().data("Unsupported collaboration mode")); + } + + let masks = self.models_manager.list_collaboration_modes().await; + let Some(mask) = masks.iter().find(|mask| mask.mode == Some(mode)) else { + return Err(Error::invalid_params().data("Collaboration mode is unavailable")); + }; + + let current_mode = CollaborationMode { + mode: self.current_collaboration_mode_kind, + settings: Settings { + model: self.get_current_model().await, + reasoning_effort: self.config.model_reasoning_effort, + developer_instructions: self.config.developer_instructions.clone(), + }, + }; + let next_mode = current_mode.apply_mask(mask); + + self.thread + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(next_mode.clone()), + personality: None, + windows_sandbox_level: None, + service_tier: None, + }) + .await + .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + + self.current_collaboration_mode_kind = next_mode.mode; + self.config.model = Some(next_mode.settings.model.clone()); + self.config.model_reasoning_effort = next_mode.settings.reasoning_effort; + self.config.developer_instructions = next_mode.settings.developer_instructions; + + Ok(()) + } + async fn handle_set_config_model(&mut self, value: SessionConfigValueId) -> Result<(), Error> { let model_id = value.0; @@ -3315,6 +3696,14 @@ impl ThreadActor { } async fn handle_event(&mut self, Event { id, msg }: Event) { + if let EventMsg::TurnStarted(TurnStartedEvent { + collaboration_mode_kind, + .. + }) = &msg + { + self.current_collaboration_mode_kind = *collaboration_mode_kind; + } + if let Some(submission) = self.submissions.get_mut(&id) { submission.handle_event(&self.client, msg).await; } else { @@ -4031,6 +4420,25 @@ mod tests { async fn list_models(&self) -> Vec { all_model_presets().to_owned() } + + async fn list_collaboration_modes(&self) -> Vec { + vec![ + CollaborationModeMask { + name: "Default".to_string(), + mode: Some(ModeKind::Default), + model: None, + reasoning_effort: None, + developer_instructions: None, + }, + CollaborationModeMask { + name: "Plan".to_string(), + mode: Some(ModeKind::Plan), + model: None, + reasoning_effort: None, + developer_instructions: None, + }, + ] + } } struct StubCodexThread { @@ -4569,6 +4977,86 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_request_user_input_submits_user_input_answer() -> anyhow::Result<()> { + LocalSet::new() + .run_until(async { + let session_id = SessionId::new("test"); + let client = Arc::new(StubClient::with_permission_responses(vec![ + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new("answer-0"), + )), + ])); + let session_client = + SessionClient::with_client(session_id, client.clone(), Arc::default()); + let thread = Arc::new(StubCodexThread::new()); + let (response_tx, _response_rx) = tokio::sync::oneshot::channel(); + let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel(); + let mut prompt_state = PromptState::new( + "submission-id".to_string(), + thread.clone(), + message_tx, + response_tx, + ); + + prompt_state + .request_user_input( + &session_client, + RequestUserInputEvent { + call_id: "call-id".to_string(), + turn_id: "turn-id".to_string(), + questions: vec![RequestUserInputQuestion { + id: "confirm_path".to_string(), + header: "Confirm".to_string(), + question: "Continue?".to_string(), + is_other: false, + is_secret: false, + options: Some(vec![ + codex_protocol::request_user_input::RequestUserInputQuestionOption { + label: "yes".to_string(), + description: "Continue".to_string(), + }, + codex_protocol::request_user_input::RequestUserInputQuestionOption { + label: "no".to_string(), + description: "Stop".to_string(), + }, + ]), + }], + }, + ) + .await?; + + let ThreadMessage::PermissionRequestResolved { + submission_id, + request_key, + response, + } = message_rx.recv().await.unwrap() + else { + panic!("expected permission resolution message"); + }; + assert_eq!(submission_id, "submission-id"); + prompt_state + .handle_permission_request_resolved(&session_client, request_key, response) + .await?; + + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.last(), + Some(Op::UserInputAnswer { id, response }) + if id == "turn-id" + && response + .answers + .get("confirm_path") + .is_some_and(|answer| answer.answers == vec!["yes".to_string()]) + )); + + anyhow::Ok(()) + }) + .await?; + + Ok(()) + } + #[tokio::test] async fn test_mcp_elicitation_declines_unsupported_form_requests() -> anyhow::Result<()> { LocalSet::new()