diff --git a/src/thread.rs b/src/thread.rs index 4b074cd..e4037af 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -35,7 +35,7 @@ use codex_protocol::{ custom_prompts::CustomPrompt, dynamic_tools::{DynamicToolCallOutputContentItem, DynamicToolCallRequest}, mcp::CallToolResult, - models::{MacOsSeatbeltProfileExtensions, PermissionProfile, ResponseItem, WebSearchAction}, + models::{FunctionCallOutputPayload, MacOsSeatbeltProfileExtensions, PermissionProfile, ResponseItem, WebSearchAction}, openai_models::{ModelPreset, ReasoningEffort}, parse_command::ParsedCommand, plan_tool::{PlanItemArg, StepStatus, UpdatePlanArgs}, @@ -467,6 +467,7 @@ struct PromptState { response_tx: Option>>, seen_message_deltas: bool, seen_reasoning_deltas: bool, + retain_without_response: bool, } impl PromptState { @@ -487,14 +488,33 @@ impl PromptState { response_tx: Some(response_tx), seen_message_deltas: false, seen_reasoning_deltas: false, + retain_without_response: false, + } + } + + fn new_replay(thread: Arc) -> Self { + let (resolution_tx, _) = mpsc::unbounded_channel(); + Self { + submission_id: String::new(), + active_commands: HashMap::new(), + active_web_search: None, + thread, + resolution_tx, + pending_permission_interactions: HashMap::new(), + event_count: 0, + response_tx: None, + seen_message_deltas: false, + seen_reasoning_deltas: false, + retain_without_response: true, } } fn is_active(&self) -> bool { - let Some(response_tx) = &self.response_tx else { - return false; - }; - !response_tx.is_closed() + self.retain_without_response + || self + .response_tx + .as_ref() + .is_some_and(|response_tx| !response_tx.is_closed()) } fn abort_pending_interactions(&mut self) { @@ -891,6 +911,7 @@ impl PromptState { "Task {turn_id} completed successfully after {} events. Last agent message: {last_agent_message:?}", self.event_count ); + self.retain_without_response = false; self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::EndTurn)).ok(); @@ -927,6 +948,7 @@ impl PromptState { codex_error_info, }) => { error!("Unhandled error during turn: {message} {codex_error_info:?}"); + self.retain_without_response = false; self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { response_tx @@ -938,6 +960,7 @@ impl PromptState { } EventMsg::TurnAborted(TurnAbortedEvent { reason, turn_id }) => { info!("Turn {turn_id:?} aborted: {reason:?}"); + self.retain_without_response = false; self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::Cancelled)).ok(); @@ -945,6 +968,7 @@ impl PromptState { } EventMsg::ShutdownComplete => { info!("Agent shutting down"); + self.retain_without_response = false; self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { response_tx.send(Ok(StopReason::Cancelled)).ok(); @@ -1500,48 +1524,7 @@ impl PromptState { parsed_cmd, process_id: _, } = event; - // Create a new tool call for the command execution - let tool_call_id = ToolCallId::new(call_id.clone()); - let ParseCommandToolCall { - title, - file_extension, - locations, - terminal_output, - kind, - } = parse_command_tool_call(parsed_cmd, &cwd); - - let active_command = ActiveCommand { - tool_call_id: tool_call_id.clone(), - output: String::new(), - file_extension, - terminal_output, - }; - let (content, meta) = if client.supports_terminal_output(&active_command) { - let content = vec![ToolCallContent::Terminal(Terminal::new(call_id.clone()))]; - let meta = Some(Meta::from_iter([( - "terminal_info".to_owned(), - serde_json::json!({ - "terminal_id": call_id, - "cwd": cwd - }), - )])); - (content, meta) - } else { - (vec![], None) - }; - - self.active_commands.insert(call_id.clone(), active_command); - - client - .send_tool_call( - ToolCall::new(tool_call_id, title) - .kind(kind) - .status(ToolCallStatus::InProgress) - .locations(locations) - .raw_input(raw_input) - .content(content) - .meta(meta), - ) + self.begin_active_command(client, call_id, cwd, parsed_cmd, Some(raw_input)) .await; } @@ -1613,38 +1596,14 @@ impl PromptState { process_id: _, status, } = event; - if let Some(active_command) = self.active_commands.remove(&call_id) { - let is_success = exit_code == 0; - - let status = match status { - ExecCommandStatus::Completed => ToolCallStatus::Completed, - _ if is_success => ToolCallStatus::Completed, - ExecCommandStatus::Failed | ExecCommandStatus::Declined => ToolCallStatus::Failed, - }; - - client - .send_tool_call_update( - ToolCallUpdate::new( - active_command.tool_call_id.clone(), - ToolCallUpdateFields::new() - .status(status) - .raw_output(raw_output), - ) - .meta( - client.supports_terminal_output(&active_command).then(|| { - Meta::from_iter([( - "terminal_exit".into(), - serde_json::json!({ - "terminal_id": call_id, - "exit_code": exit_code, - "signal": null - }), - )]) - }), - ), - ) - .await; - } + let is_success = exit_code == 0; + let status = match status { + ExecCommandStatus::Completed => ToolCallStatus::Completed, + _ if is_success => ToolCallStatus::Completed, + ExecCommandStatus::Failed | ExecCommandStatus::Declined => ToolCallStatus::Failed, + }; + self.finish_active_command(client, call_id, exit_code, status, raw_output) + .await; } async fn terminal_interaction( @@ -1696,6 +1655,94 @@ impl PromptState { } } + async fn begin_active_command( + &mut self, + client: &SessionClient, + call_id: String, + cwd: PathBuf, + parsed_cmd: Vec, + raw_input: Option, + ) { + let tool_call_id = ToolCallId::new(call_id.clone()); + let ParseCommandToolCall { + title, + file_extension, + locations, + terminal_output, + kind, + } = parse_command_tool_call(parsed_cmd, &cwd); + + let active_command = ActiveCommand { + tool_call_id: tool_call_id.clone(), + output: String::new(), + file_extension, + terminal_output, + }; + let (content, meta) = if client.supports_terminal_output(&active_command) { + let content = vec![ToolCallContent::Terminal(Terminal::new(call_id.clone()))]; + let meta = Some(Meta::from_iter([( + "terminal_info".to_owned(), + serde_json::json!({ + "terminal_id": call_id, + "cwd": cwd + }), + )])); + (content, meta) + } else { + (vec![], None) + }; + + self.active_commands.insert(call_id.clone(), active_command); + + let mut tool_call = ToolCall::new(tool_call_id, title) + .kind(kind) + .status(ToolCallStatus::InProgress) + .locations(locations) + .content(content); + if let Some(raw_input) = raw_input { + tool_call = tool_call.raw_input(raw_input); + } + if let Some(meta) = meta { + tool_call = tool_call.meta(meta); + } + + client.send_tool_call(tool_call).await; + } + + async fn finish_active_command( + &mut self, + client: &SessionClient, + call_id: String, + exit_code: i32, + status: ToolCallStatus, + raw_output: serde_json::Value, + ) { + if let Some(active_command) = self.active_commands.remove(&call_id) { + client + .send_tool_call_update( + ToolCallUpdate::new( + active_command.tool_call_id.clone(), + ToolCallUpdateFields::new() + .status(status) + .raw_output(raw_output), + ) + .meta( + client.supports_terminal_output(&active_command).then(|| { + Meta::from_iter([( + "terminal_exit".into(), + serde_json::json!({ + "terminal_id": call_id, + "exit_code": exit_code, + "signal": null + }), + )]) + }), + ), + ) + .await; + } + } + async fn start_web_search(&mut self, client: &SessionClient, call_id: String) { self.active_web_search = Some(call_id.clone()); client @@ -1965,6 +2012,29 @@ fn build_exec_permission_options( .collect() } +#[derive(Clone)] +struct ReplayExecCall { + turn_id: String, + original_call_id: String, +} + +enum ReplayFunctionCall { + UnifiedExec(ReplayExecCall), +} + +#[derive(Default)] +struct ReplayState { + current_turn_id: Option, + pending_function_calls: HashMap, + processes: HashMap, +} + +struct ReplayUnifiedExecOutput { + process_id: Option, + exit_code: Option, + output: String, +} + struct ParseCommandToolCall { title: String, file_extension: Option, @@ -3000,18 +3070,23 @@ impl ThreadActor { /// This is called when loading a session to stream all prior messages. /// /// We process both `EventMsg` and `ResponseItem`: - /// - `EventMsg` for user/agent messages and reasoning (like the TUI does) - /// - `ResponseItem` for tool calls only (not persisted as EventMsg) + /// - `EventMsg` for user/agent messages and turn lifecycle + /// - `ResponseItem` for tool calls, including unified exec tool output async fn handle_replay_history(&mut self, history: Vec) -> Result<(), Error> { + let mut replay_state = ReplayState::default(); for item in history { match item { + RolloutItem::TurnContext(turn_context) => { + replay_state.current_turn_id = turn_context.turn_id; + } RolloutItem::EventMsg(event_msg) => { - self.replay_event_msg(&event_msg).await; + self.replay_event_msg(&mut replay_state, &event_msg).await; } RolloutItem::ResponseItem(response_item) => { - self.replay_response_item(&response_item).await; + self.replay_response_item(&mut replay_state, &response_item) + .await; } - // Skip SessionMeta, TurnContext, Compacted + // Skip SessionMeta and Compacted _ => {} } } @@ -3019,8 +3094,8 @@ impl ThreadActor { } /// Convert and send an EventMsg as ACP notification(s) during replay. - /// Handles messages and reasoning - mirrors the live event handling in PromptState. - async fn replay_event_msg(&self, msg: &EventMsg) { + /// Handles messages/reasoning and clears replay-retained prompt state on turn completion. + async fn replay_event_msg(&mut self, replay_state: &mut ReplayState, msg: &EventMsg) { match msg { EventMsg::UserMessage(UserMessageEvent { message, .. }) => { self.client.send_user_message(message.clone()).await; @@ -3034,10 +3109,45 @@ impl ThreadActor { EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent { text }) => { self.client.send_agent_thought(text.clone()).await; } - // Skip other event types during replay - they either: - // - Are transient (deltas, turn lifecycle) - // - Don't have direct ACP equivalents - // - Are handled via ResponseItem instead + EventMsg::TurnStarted(TurnStartedEvent { turn_id, .. }) => { + replay_state.current_turn_id = Some(turn_id.clone()); + } + EventMsg::TurnComplete(TurnCompleteEvent { turn_id, .. }) => { + replay_state.current_turn_id = Some(turn_id.clone()); + self.replay_existing_prompt_event(turn_id, msg.clone()) + .await; + replay_state + .pending_function_calls + .retain(|_, call| !matches!(call, ReplayFunctionCall::UnifiedExec(exec) if exec.turn_id == turn_id.as_str())); + replay_state + .processes + .retain(|_, exec| exec.turn_id != turn_id.as_str()); + } + EventMsg::TurnAborted(TurnAbortedEvent { turn_id, .. }) => { + replay_state.current_turn_id = turn_id.clone(); + if let Some(turn_id) = turn_id.as_deref() { + self.replay_existing_prompt_event(turn_id, msg.clone()) + .await; + replay_state + .pending_function_calls + .retain(|_, call| !matches!(call, ReplayFunctionCall::UnifiedExec(exec) if exec.turn_id == turn_id)); + replay_state + .processes + .retain(|_, exec| exec.turn_id != turn_id); + } + } + EventMsg::ShutdownComplete => { + if let Some(turn_id) = replay_state.current_turn_id.clone() { + self.replay_existing_prompt_event(&turn_id, msg.clone()) + .await; + replay_state + .pending_function_calls + .retain(|_, call| !matches!(call, ReplayFunctionCall::UnifiedExec(exec) if exec.turn_id == turn_id.as_str())); + replay_state + .processes + .retain(|_, exec| exec.turn_id != turn_id.as_str()); + } + } _ => {} } } @@ -3170,9 +3280,183 @@ impl ThreadActor { Some((title, kind, locations)) } + fn parse_exec_command_function_call( + &self, + arguments: &str, + ) -> Option<(PathBuf, Vec)> { + #[derive(serde::Deserialize)] + struct ExecCommandArgs { + cmd: String, + #[serde(default, alias = "cwd")] + workdir: Option, + } + + let args: ExecCommandArgs = serde_json::from_str(arguments).ok()?; + let command_vec = shlex::split(&args.cmd) + .filter(|parts| !parts.is_empty()) + .unwrap_or_else(|| vec![args.cmd.clone()]); + let cwd = args.workdir.unwrap_or_else(|| self.config.cwd.clone()); + Some((cwd, parse_command(&command_vec))) + } + + fn parse_write_stdin_function_call(arguments: &str) -> Option<(String, Option)> { + #[derive(serde::Deserialize)] + struct WriteStdinArgs { + session_id: serde_json::Value, + #[serde(default)] + chars: Option, + } + + let args: WriteStdinArgs = serde_json::from_str(arguments).ok()?; + let session_id = match args.session_id { + serde_json::Value::String(session_id) => session_id, + serde_json::Value::Number(session_id) => session_id.to_string(), + _ => return None, + }; + Some((session_id, args.chars)) + } + + fn parse_unified_exec_output( + output: &FunctionCallOutputPayload, + ) -> Option { + let text = output.body.to_text()?; + let (header, body) = text + .split_once("\nOutput:\n") + .unwrap_or((text.as_str(), "")); + + let process_id = header.lines().find_map(|line| { + line.strip_prefix("Process running with session ID ") + .map(str::trim) + .filter(|session_id| !session_id.is_empty()) + .map(ToOwned::to_owned) + }); + let exit_code = header.lines().find_map(|line| { + line.strip_prefix("Process exited with code ") + .and_then(|code| code.trim().parse::().ok()) + }); + + if process_id.is_none() && exit_code.is_none() { + return None; + } + + Some(ReplayUnifiedExecOutput { + process_id, + exit_code, + output: body.to_string(), + }) + } + + fn take_or_create_replay_prompt(&mut self, turn_id: &str) -> PromptState { + match self.submissions.remove(turn_id) { + Some(SubmissionState::Prompt(prompt)) => prompt, + Some(other) => { + warn!("Encountered non-prompt submission while replaying turn {turn_id}"); + self.submissions.insert(turn_id.to_string(), other); + PromptState::new_replay(self.thread.clone()) + } + None => PromptState::new_replay(self.thread.clone()), + } + } + + fn store_replay_prompt(&mut self, turn_id: String, prompt: PromptState) { + if prompt.is_active() { + self.submissions + .insert(turn_id, SubmissionState::Prompt(prompt)); + } + } + + async fn replay_existing_prompt_event(&mut self, turn_id: &str, event: EventMsg) { + let Some(submission) = self.submissions.remove(turn_id) else { + return; + }; + let SubmissionState::Prompt(mut prompt) = submission else { + self.submissions.insert(turn_id.to_string(), submission); + return; + }; + prompt.handle_event(&self.client, event).await; + self.store_replay_prompt(turn_id.to_string(), prompt); + } + + async fn replay_unified_exec_begin( + &mut self, + turn_id: String, + call_id: String, + arguments: &str, + ) -> bool { + let Some((cwd, parsed_cmd)) = self.parse_exec_command_function_call(arguments) else { + return false; + }; + let raw_input = serde_json::from_str(arguments).ok(); + let mut prompt = self.take_or_create_replay_prompt(&turn_id); + prompt + .begin_active_command(&self.client, call_id.clone(), cwd, parsed_cmd, raw_input) + .await; + self.store_replay_prompt(turn_id, prompt); + true + } + + async fn replay_unified_exec_input(&mut self, exec: &ReplayExecCall, chars: &str) { + if chars.is_empty() { + return; + } + let mut prompt = self.take_or_create_replay_prompt(&exec.turn_id); + prompt + .terminal_interaction( + &self.client, + TerminalInteractionEvent { + call_id: exec.original_call_id.clone(), + process_id: String::new(), + stdin: chars.to_string(), + }, + ) + .await; + self.store_replay_prompt(exec.turn_id.clone(), prompt); + } + + async fn replay_unified_exec_output( + &mut self, + exec: ReplayExecCall, + output: &FunctionCallOutputPayload, + ) -> bool { + let Some(parsed_output) = Self::parse_unified_exec_output(output) else { + return false; + }; + + let mut prompt = self.take_or_create_replay_prompt(&exec.turn_id); + if !parsed_output.output.is_empty() { + prompt + .exec_command_output_delta( + &self.client, + ExecCommandOutputDeltaEvent { + call_id: exec.original_call_id.clone(), + chunk: parsed_output.output.as_bytes().to_vec(), + stream: codex_protocol::protocol::ExecOutputStream::Stdout, + }, + ) + .await; + } + if let Some(exit_code) = parsed_output.exit_code { + prompt + .finish_active_command( + &self.client, + exec.original_call_id.clone(), + exit_code, + if exit_code == 0 { + ToolCallStatus::Completed + } else { + ToolCallStatus::Failed + }, + serde_json::to_value(output).unwrap_or(serde_json::Value::Null), + ) + .await; + } + self.store_replay_prompt(exec.turn_id, prompt); + true + } + /// Convert and send a single ResponseItem as ACP notification(s) during replay. /// Only handles tool calls - messages/reasoning are handled via EventMsg. - async fn replay_response_item(&self, item: &ResponseItem) { + async fn replay_response_item(&mut self, replay_state: &mut ReplayState, item: &ResponseItem) { match item { // Skip Message and Reasoning - these are handled via EventMsg ResponseItem::Message { .. } | ResponseItem::Reasoning { .. } => {} @@ -3182,6 +3466,36 @@ impl ThreadActor { call_id, .. } => { + if name == "exec_command" + && let Some(turn_id) = replay_state.current_turn_id.clone() + && self + .replay_unified_exec_begin(turn_id.clone(), call_id.clone(), arguments) + .await + { + replay_state.pending_function_calls.insert( + call_id.clone(), + ReplayFunctionCall::UnifiedExec(ReplayExecCall { + turn_id, + original_call_id: call_id.clone(), + }), + ); + return; + } + + if name == "write_stdin" + && let Some((session_id, chars)) = + Self::parse_write_stdin_function_call(arguments) + && let Some(exec) = replay_state.processes.get(&session_id).cloned() + { + if let Some(chars) = chars.as_deref() { + self.replay_unified_exec_input(&exec, chars).await; + } + replay_state + .pending_function_calls + .insert(call_id.clone(), ReplayFunctionCall::UnifiedExec(exec)); + return; + } + // Check if this is a shell command - parse it like we do for LocalShellCall if matches!(name.as_str(), "shell" | "container.exec" | "shell_command") && let Some((title, kind, locations)) = @@ -3212,6 +3526,22 @@ impl ThreadActor { .await; } ResponseItem::FunctionCallOutput { call_id, output } => { + if let Some(ReplayFunctionCall::UnifiedExec(exec)) = + replay_state.pending_function_calls.remove(call_id) + && let Some(parsed_output) = Self::parse_unified_exec_output(output) + { + if let Some(process_id) = parsed_output.process_id.clone() { + replay_state.processes.insert(process_id, exec.clone()); + } else { + replay_state + .processes + .retain(|_, active| active.original_call_id != exec.original_call_id); + } + + self.replay_unified_exec_output(exec, output).await; + return; + } + self.client .send_tool_call_completed(call_id.clone(), serde_json::to_value(output).ok()) .await; @@ -3468,7 +3798,7 @@ mod tests { use agent_client_protocol::{RequestPermissionResponse, TextContent}; use codex_core::{config::ConfigOverrides, test_support::all_model_presets}; - use codex_protocol::config_types::ModeKind; + use codex_protocol::{config_types::ModeKind, protocol::ExecCommandSource}; use tokio::{ sync::{Mutex, Notify, mpsc::UnboundedSender}, task::LocalSet, @@ -4012,6 +4342,36 @@ mod tests { Ok((session_id, client, conversation, message_tx, local_set)) } + async fn setup_actor( + custom_prompts: Vec, + ) -> anyhow::Result<(Arc, Arc, ThreadActor)> { + let session_id = SessionId::new("test"); + let client = Arc::new(StubClient::new()); + let session_client = SessionClient::with_client(session_id, client.clone(), Arc::default()); + let conversation = Arc::new(StubCodexThread::new()); + let models_manager = Arc::new(StubModelsManager); + let config = Config::load_with_cli_overrides_and_harness_overrides( + vec![], + ConfigOverrides::default(), + ) + .await?; + let (_message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel(); + + let (resolution_tx, resolution_rx) = tokio::sync::mpsc::unbounded_channel(); + let mut actor = ThreadActor::new( + StubAuth, + session_client, + conversation.clone(), + models_manager, + config, + message_rx, + resolution_tx, + resolution_rx, + ); + actor.custom_prompts = Rc::new(RefCell::new(custom_prompts)); + Ok((client, conversation, actor)) + } + struct StubAuth; impl Auth for StubAuth { @@ -4402,6 +4762,90 @@ mod tests { } } + fn replay_turn_context(turn_id: &str) -> RolloutItem { + RolloutItem::TurnContext(codex_protocol::protocol::TurnContextItem { + turn_id: Some(turn_id.to_string()), + cwd: std::env::current_dir().unwrap(), + current_date: None, + timezone: None, + approval_policy: codex_protocol::protocol::AskForApproval::OnRequest, + sandbox_policy: codex_protocol::protocol::SandboxPolicy::DangerFullAccess, + network: None, + model: "gpt-5".to_string(), + personality: None, + collaboration_mode: None, + effort: None, + summary: Default::default(), + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: None, + realtime_active: None, + trace_id: None, + }) + } + + fn replay_function_call( + name: &str, + call_id: &str, + arguments: serde_json::Value, + ) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCall { + id: None, + name: name.to_string(), + arguments: arguments.to_string(), + call_id: call_id.to_string(), + }) + } + + fn replay_function_call_output(call_id: &str, output: impl Into) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload::from_text(output.into()), + }) + } + + fn unified_exec_output_text( + output: &str, + process_id: Option<&str>, + exit_code: Option, + ) -> String { + let mut lines = vec![ + "Chunk ID: replay".to_string(), + "Wall time: 0.0100 seconds".to_string(), + ]; + if let Some(exit_code) = exit_code { + lines.push(format!("Process exited with code {exit_code}")); + } + if let Some(process_id) = process_id { + lines.push(format!("Process running with session ID {process_id}")); + } + lines.push("Original token count: 0".to_string()); + lines.push("Output:".to_string()); + lines.push(output.to_string()); + lines.join("\n") + } + + fn tool_calls(notifications: &[SessionNotification]) -> Vec { + notifications + .iter() + .filter_map(|notification| match ¬ification.update { + SessionUpdate::ToolCall(tool_call) => Some(tool_call.clone()), + _ => None, + }) + .collect() + } + + fn tool_call_updates(notifications: &[SessionNotification]) -> Vec { + notifications + .iter() + .filter_map(|notification| match ¬ification.update { + SessionUpdate::ToolCallUpdate(update) => Some(update.clone()), + _ => None, + }) + .collect() + } + #[tokio::test] async fn test_parallel_exec_commands() -> anyhow::Result<()> { let (session_id, client, _, message_tx, local_set) = setup(vec![]).await?; @@ -4482,6 +4926,94 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_replay_unified_exec_routes_write_stdin_to_original_command() -> anyhow::Result<()> + { + let (client, _, mut actor) = setup_actor(vec![]).await?; + let turn_id = "turn-replay"; + let cwd = std::env::current_dir().unwrap(); + actor + .handle_replay_history(vec![ + replay_turn_context(turn_id), + replay_function_call( + "exec_command", + "call-exec", + serde_json::json!({ + "cmd": "make build", + "workdir": cwd, + }), + ), + replay_function_call_output( + "call-exec", + unified_exec_output_text("Compiling...\n", Some("74418"), None), + ), + replay_function_call( + "write_stdin", + "call-poll-1", + serde_json::json!({ + "session_id": 74418, + "chars": "status\n", + }), + ), + replay_function_call_output( + "call-poll-1", + unified_exec_output_text("Still running...\n", Some("74418"), None), + ), + replay_function_call( + "write_stdin", + "call-poll-2", + serde_json::json!({ + "session_id": 74418, + "chars": "", + }), + ), + replay_function_call_output( + "call-poll-2", + unified_exec_output_text("Done\n", None, Some(0)), + ), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + last_agent_message: None, + turn_id: turn_id.to_string(), + })), + ]) + .await?; + + let notifications = client.notifications.lock().unwrap(); + let tool_calls = tool_calls(¬ifications); + let updates = tool_call_updates(¬ifications); + + assert_eq!( + tool_calls.len(), + 1, + "expected only the original exec_command tool call, got {tool_calls:?}" + ); + assert_eq!(tool_calls[0].tool_call_id, ToolCallId::new("call-exec")); + + let completed_updates: Vec<_> = updates + .iter() + .filter(|update| { + update.tool_call_id == ToolCallId::new("call-exec") + && update.fields.status == Some(ToolCallStatus::Completed) + }) + .collect(); + assert_eq!( + completed_updates.len(), + 1, + "expected one completion for the original exec_command, got {updates:?}" + ); + assert!( + !format!("{tool_calls:?}").contains("call-poll"), + "write_stdin polling should not create separate tool calls: {tool_calls:?}" + ); + assert!( + format!("{updates:?}").contains("Still running...") + && format!("{updates:?}").contains("status"), + "expected replayed polling output and stdin to attach to the original command: {updates:?}" + ); + + Ok(()) + } + #[tokio::test] async fn test_exec_approval_uses_available_decisions() -> anyhow::Result<()> { LocalSet::new() @@ -4569,6 +5101,102 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_replay_unified_exec_preserves_active_command_for_live_terminal_end() + -> anyhow::Result<()> { + let (client, _, mut actor) = setup_actor(vec![]).await?; + let turn_id = "turn-live"; + let cwd = std::env::current_dir().unwrap(); + actor + .handle_replay_history(vec![ + replay_turn_context(turn_id), + replay_function_call( + "exec_command", + "call-live", + serde_json::json!({ + "cmd": "make build", + "workdir": cwd.clone(), + }), + ), + replay_function_call_output( + "call-live", + unified_exec_output_text("Compiling...\n", Some("74418"), None), + ), + ]) + .await?; + + actor + .handle_event(Event { + id: turn_id.to_string(), + msg: EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent { + call_id: "call-live".into(), + stream: codex_protocol::protocol::ExecOutputStream::Stdout, + chunk: b"Done\n".to_vec(), + }), + }) + .await; + actor + .handle_event(Event { + id: turn_id.to_string(), + msg: EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id: "call-live".into(), + process_id: Some("74418".into()), + turn_id: turn_id.to_string(), + command: vec!["make".into(), "build".into()], + cwd, + parsed_cmd: vec![], + source: ExecCommandSource::UnifiedExecStartup, + interaction_input: None, + stdout: "Compiling...\nDone\n".into(), + stderr: String::new(), + aggregated_output: "Compiling...\nDone\n".into(), + exit_code: 0, + duration: std::time::Duration::from_millis(10), + formatted_output: "Compiling...\nDone\n".into(), + status: ExecCommandStatus::Completed, + }), + }) + .await; + actor + .handle_event(Event { + id: turn_id.to_string(), + msg: EventMsg::TurnComplete(TurnCompleteEvent { + last_agent_message: None, + turn_id: turn_id.to_string(), + }), + }) + .await; + + let notifications = client.notifications.lock().unwrap(); + let tool_calls = tool_calls(¬ifications); + let updates = tool_call_updates(¬ifications); + + assert_eq!( + tool_calls.len(), + 1, + "expected one replayed exec tool call, got {tool_calls:?}" + ); + let completed_updates: Vec<_> = updates + .iter() + .filter(|update| { + update.tool_call_id == ToolCallId::new("call-live") + && update.fields.status == Some(ToolCallStatus::Completed) + }) + .collect(); + assert_eq!( + completed_updates.len(), + 1, + "expected the live terminal end to complete the replayed exec command, got {updates:?}" + ); + assert!( + format!("{updates:?}").contains("Compiling...") + && format!("{updates:?}").contains("Done"), + "expected replayed output and live output to land on the same tool call: {updates:?}" + ); + + Ok(()) + } + #[tokio::test] async fn test_mcp_elicitation_declines_unsupported_form_requests() -> anyhow::Result<()> { LocalSet::new()