diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 83fa53b9973..e9e71140ad0 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -113,6 +113,10 @@ client_request_definitions! { params: v2::ThreadArchiveParams, response: v2::ThreadArchiveResponse, }, + ThreadRollback => "thread/rollback" { + params: v2::ThreadRollbackParams, + response: v2::ThreadRollbackResponse, + }, ThreadList => "thread/list" { params: v2::ThreadListParams, response: v2::ThreadListResponse, diff --git a/codex-rs/app-server-protocol/src/protocol/thread_history.rs b/codex-rs/app-server-protocol/src/protocol/thread_history.rs index ba1e6261cc6..6fa6dfabbd4 100644 --- a/codex-rs/app-server-protocol/src/protocol/thread_history.rs +++ b/codex-rs/app-server-protocol/src/protocol/thread_history.rs @@ -6,6 +6,7 @@ use crate::protocol::v2::UserInput; use codex_protocol::protocol::AgentReasoningEvent; use codex_protocol::protocol::AgentReasoningRawContentEvent; use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::ThreadRolledBackEvent; use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::UserMessageEvent; @@ -57,6 +58,7 @@ impl ThreadHistoryBuilder { EventMsg::TokenCount(_) => {} EventMsg::EnteredReviewMode(_) => {} EventMsg::ExitedReviewMode(_) => {} + EventMsg::ThreadRolledBack(payload) => self.handle_thread_rollback(payload), EventMsg::UndoCompleted(_) => {} EventMsg::TurnAborted(payload) => self.handle_turn_aborted(payload), _ => {} @@ -130,6 +132,23 @@ impl ThreadHistoryBuilder { turn.status = TurnStatus::Interrupted; } + fn handle_thread_rollback(&mut self, payload: &ThreadRolledBackEvent) { + self.finish_current_turn(); + + let n = usize::try_from(payload.num_turns).unwrap_or(usize::MAX); + if n >= self.turns.len() { + self.turns.clear(); + } else { + self.turns.truncate(self.turns.len().saturating_sub(n)); + } + 
+ // Re-number subsequent synthetic ids so the pruned history is consistent. + self.next_turn_index = + i64::try_from(self.turns.len().saturating_add(1)).unwrap_or(i64::MAX); + let item_count: usize = self.turns.iter().map(|t| t.items.len()).sum(); + self.next_item_index = i64::try_from(item_count.saturating_add(1)).unwrap_or(i64::MAX); + } + fn finish_current_turn(&mut self) { if let Some(turn) = self.current_turn.take() { if turn.items.is_empty() { @@ -213,6 +232,7 @@ mod tests { use codex_protocol::protocol::AgentMessageEvent; use codex_protocol::protocol::AgentReasoningEvent; use codex_protocol::protocol::AgentReasoningRawContentEvent; + use codex_protocol::protocol::ThreadRolledBackEvent; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::UserMessageEvent; @@ -410,4 +430,95 @@ mod tests { } ); } + + #[test] + fn drops_last_turns_on_thread_rollback() { + let events = vec![ + EventMsg::UserMessage(UserMessageEvent { + message: "First".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "A1".into(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "Second".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "A2".into(), + }), + EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns: 1 }), + EventMsg::UserMessage(UserMessageEvent { + message: "Third".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "A3".into(), + }), + ]; + + let turns = build_turns_from_event_msgs(&events); + let expected = vec![ + Turn { + id: "turn-1".into(), + status: TurnStatus::Completed, + error: None, + items: vec![ + ThreadItem::UserMessage { + id: "item-1".into(), + content: vec![UserInput::Text { + text: "First".into(), + }], + }, + ThreadItem::AgentMessage { + id: "item-2".into(), + text: "A1".into(), + }, + ], + }, + Turn { + id: "turn-2".into(), + status: TurnStatus::Completed, + error: 
None, + items: vec![ + ThreadItem::UserMessage { + id: "item-3".into(), + content: vec![UserInput::Text { + text: "Third".into(), + }], + }, + ThreadItem::AgentMessage { + id: "item-4".into(), + text: "A3".into(), + }, + ], + }, + ]; + assert_eq!(turns, expected); + } + + #[test] + fn thread_rollback_clears_all_turns_when_num_turns_exceeds_history() { + let events = vec![ + EventMsg::UserMessage(UserMessageEvent { + message: "One".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "A1".into(), + }), + EventMsg::UserMessage(UserMessageEvent { + message: "Two".into(), + images: None, + }), + EventMsg::AgentMessage(AgentMessageEvent { + message: "A2".into(), + }), + EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns: 99 }), + ]; + + let turns = build_turns_from_event_msgs(&events); + assert_eq!(turns, Vec::<Turn>::new()); + } } diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index 4311baba5fa..dedf51ea1bb 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -89,6 +89,7 @@ pub enum CodexErrorInfo { InternalServerError, Unauthorized, BadRequest, + ThreadRollbackFailed, SandboxError, /// The response SSE stream disconnected in the middle of a turn before completion. 
ResponseStreamDisconnected { @@ -119,6 +120,7 @@ impl From for CodexErrorInfo { CoreCodexErrorInfo::InternalServerError => CodexErrorInfo::InternalServerError, CoreCodexErrorInfo::Unauthorized => CodexErrorInfo::Unauthorized, CoreCodexErrorInfo::BadRequest => CodexErrorInfo::BadRequest, + CoreCodexErrorInfo::ThreadRollbackFailed => CodexErrorInfo::ThreadRollbackFailed, CoreCodexErrorInfo::SandboxError => CodexErrorInfo::SandboxError, CoreCodexErrorInfo::ResponseStreamDisconnected { http_status_code } => { CodexErrorInfo::ResponseStreamDisconnected { http_status_code } @@ -1055,6 +1057,30 @@ pub struct ThreadArchiveParams { #[ts(export_to = "v2/")] pub struct ThreadArchiveResponse {} +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct ThreadRollbackParams { + pub thread_id: String, + /// The number of turns to drop from the end of the thread. Must be >= 1. + /// + /// This only modifies the thread's history and does not revert local file changes + /// that have been made by the agent. Clients are responsible for reverting these changes. + pub num_turns: u32, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct ThreadRollbackResponse { + /// The updated thread after applying the rollback, with `turns` populated. + /// + /// The ThreadItems stored in each Turn are lossy since we explicitly do not + /// persist all agent interactions, such as command executions. This is the same + /// behavior as `thread/resume`. + pub thread: Thread, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1193,7 +1219,7 @@ pub struct Thread { pub source: SessionSource, /// Optional Git metadata captured when the thread was created. 
pub git_info: Option, - /// Only populated on a `thread/resume` response. + /// Only populated on `thread/resume` and `thread/rollback` responses. /// For all other responses and notifications returning a Thread, /// the turns field will be an empty list. pub turns: Vec, diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 267d7577e53..b8f7de90ab1 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -72,6 +72,7 @@ Example (from OpenAI's official VSCode extension): - `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it. - `thread/list` — page through stored rollouts; supports cursor-based pagination and optional `modelProviders` filtering. - `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success. +- `thread/rollback` — drop the last N turns from the agent’s in-memory context and persist a rollback marker in the rollout so future resumes see the pruned history; returns the updated `thread` (with `turns` populated) on success. - `turn/start` — add user input to a thread and begin Codex generation; responds with the initial `turn` object and streams `turn/started`, `item/*`, and `turn/completed` notifications. - `turn/interrupt` — request cancellation of an in-flight turn by `(thread_id, turn_id)`; success is an empty `{}` response and the turn finishes with `status: "interrupted"`. - `review/start` — kick off Codex’s automated reviewer for a thread; responds like `turn/start` and emits `item/started`/`item/completed` notifications with `enteredReviewMode` and `exitedReviewMode` items, plus a final assistant `agentMessage` containing the review. 
diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index ad0455a0587..c9b78fe8c88 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1,7 +1,13 @@ use crate::codex_message_processor::ApiVersion; use crate::codex_message_processor::PendingInterrupts; +use crate::codex_message_processor::PendingRollbacks; use crate::codex_message_processor::TurnSummary; use crate::codex_message_processor::TurnSummaryStore; +use crate::codex_message_processor::read_event_msgs_from_rollout; +use crate::codex_message_processor::read_summary_from_rollout; +use crate::codex_message_processor::summary_to_thread; +use crate::error_code::INTERNAL_ERROR_CODE; +use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::outgoing_message::OutgoingMessageSender; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; use codex_app_server_protocol::AgentMessageDeltaNotification; @@ -27,6 +33,7 @@ use codex_app_server_protocol::FileUpdateChange; use codex_app_server_protocol::InterruptConversationResponse; use codex_app_server_protocol::ItemCompletedNotification; use codex_app_server_protocol::ItemStartedNotification; +use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::McpToolCallError; use codex_app_server_protocol::McpToolCallResult; use codex_app_server_protocol::McpToolCallStatus; @@ -40,6 +47,7 @@ use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::TerminalInteractionNotification; use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadRollbackResponse; use codex_app_server_protocol::ThreadTokenUsage; use codex_app_server_protocol::ThreadTokenUsageUpdatedNotification; use codex_app_server_protocol::Turn; @@ -50,9 +58,11 @@ use codex_app_server_protocol::TurnInterruptResponse; use 
codex_app_server_protocol::TurnPlanStep; use codex_app_server_protocol::TurnPlanUpdatedNotification; use codex_app_server_protocol::TurnStatus; +use codex_app_server_protocol::build_turns_from_event_msgs; use codex_core::CodexConversation; use codex_core::parse_command::shlex_join; use codex_core::protocol::ApplyPatchApprovalRequestEvent; +use codex_core::protocol::CodexErrorInfo as CoreCodexErrorInfo; use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; @@ -78,14 +88,17 @@ use tracing::error; type JsonValue = serde_json::Value; +#[allow(clippy::too_many_arguments)] pub(crate) async fn apply_bespoke_event_handling( event: Event, conversation_id: ConversationId, conversation: Arc, outgoing: Arc, pending_interrupts: PendingInterrupts, + pending_rollbacks: PendingRollbacks, turn_summary_store: TurnSummaryStore, api_version: ApiVersion, + fallback_model_provider: String, ) { let Event { id: event_turn_id, @@ -337,6 +350,26 @@ pub(crate) async fn apply_bespoke_event_handling( .await; } EventMsg::Error(ev) => { + let message = ev.message.clone(); + let codex_error_info = ev.codex_error_info.clone(); + + // If this error belongs to an in-flight `thread/rollback` request, fail that request + // (and clear pending state) so subsequent rollbacks are unblocked. + // + // Don't send a notification for this error. 
+ if matches!( + codex_error_info, + Some(CoreCodexErrorInfo::ThreadRollbackFailed) + ) { + return handle_thread_rollback_failed( + conversation_id, + message, + &pending_rollbacks, + &outgoing, + ) + .await; + }; + let turn_error = TurnError { message: ev.message, codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from), @@ -345,7 +378,7 @@ pub(crate) async fn apply_bespoke_event_handling( handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await; outgoing .send_server_notification(ServerNotification::Error(ErrorNotification { - error: turn_error, + error: turn_error.clone(), will_retry: false, thread_id: conversation_id.to_string(), turn_id: event_turn_id.clone(), @@ -690,6 +723,58 @@ pub(crate) async fn apply_bespoke_event_handling( ) .await; } + EventMsg::ThreadRolledBack(_rollback_event) => { + let pending = { + let mut map = pending_rollbacks.lock().await; + map.remove(&conversation_id) + }; + + if let Some(request_id) = pending { + let rollout_path = conversation.rollout_path(); + let response = match read_summary_from_rollout( + rollout_path.as_path(), + fallback_model_provider.as_str(), + ) + .await + { + Ok(summary) => { + let mut thread = summary_to_thread(summary); + match read_event_msgs_from_rollout(rollout_path.as_path()).await { + Ok(events) => { + thread.turns = build_turns_from_event_msgs(&events); + ThreadRollbackResponse { thread } + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!( + "failed to load rollout `{}`: {err}", + rollout_path.display() + ), + data: None, + }; + outgoing.send_error(request_id, error).await; + return; + } + } + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!( + "failed to load rollout `{}`: {err}", + rollout_path.display() + ), + data: None, + }; + outgoing.send_error(request_id, error).await; + return; + } + }; + + outgoing.send_response(request_id, response).await; + } + } 
EventMsg::TurnDiff(turn_diff_event) => { handle_turn_diff( conversation_id, @@ -906,6 +991,31 @@ async fn handle_turn_interrupted( .await; } +async fn handle_thread_rollback_failed( + conversation_id: ConversationId, + message: String, + pending_rollbacks: &PendingRollbacks, + outgoing: &OutgoingMessageSender, +) { + let pending_rollback = { + let mut map = pending_rollbacks.lock().await; + map.remove(&conversation_id) + }; + + if let Some(request_id) = pending_rollback { + outgoing + .send_error( + request_id, + JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: message.clone(), + data: None, + }, + ) + .await; + } +} + async fn handle_token_count_event( conversation_id: ConversationId, turn_id: String, diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index c1166d95088..70569c9b4ec 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -91,6 +91,7 @@ use codex_app_server_protocol::ThreadListParams; use codex_app_server_protocol::ThreadListResponse; use codex_app_server_protocol::ThreadResumeParams; use codex_app_server_protocol::ThreadResumeResponse; +use codex_app_server_protocol::ThreadRollbackParams; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; use codex_app_server_protocol::ThreadStartedNotification; @@ -178,6 +179,8 @@ use uuid::Uuid; type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>; pub(crate) type PendingInterrupts = Arc>>; +pub(crate) type PendingRollbacks = Arc>>; + /// Per-conversation accumulation of the latest states e.g. error message while a turn runs. #[derive(Default, Clone)] pub(crate) struct TurnSummary { @@ -220,6 +223,8 @@ pub(crate) struct CodexMessageProcessor { active_login: Arc>>, // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. 
pending_interrupts: PendingInterrupts, + // Pending rollback request per conversation (at most one). We reply when ThreadRolledBack arrives. + pending_rollbacks: PendingRollbacks, turn_summary_store: TurnSummaryStore, pending_fuzzy_searches: Arc>>>, feedback: CodexFeedback, @@ -275,6 +280,7 @@ impl CodexMessageProcessor { conversation_listeners: HashMap::new(), active_login: Arc::new(Mutex::new(None)), pending_interrupts: Arc::new(Mutex::new(HashMap::new())), + pending_rollbacks: Arc::new(Mutex::new(HashMap::new())), turn_summary_store: Arc::new(Mutex::new(HashMap::new())), pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())), feedback, @@ -365,6 +371,9 @@ impl CodexMessageProcessor { ClientRequest::ThreadArchive { request_id, params } => { self.thread_archive(request_id, params).await; } + ClientRequest::ThreadRollback { request_id, params } => { + self.thread_rollback(request_id, params).await; + } ClientRequest::ThreadList { request_id, params } => { self.thread_list(request_id, params).await; } @@ -1506,6 +1515,52 @@ impl CodexMessageProcessor { } } + async fn thread_rollback(&mut self, request_id: RequestId, params: ThreadRollbackParams) { + let ThreadRollbackParams { + thread_id, + num_turns, + } = params; + + if num_turns == 0 { + self.send_invalid_request_error(request_id, "numTurns must be >= 1".to_string()) + .await; + return; + } + + let (conversation_id, conversation) = + match self.conversation_from_thread_id(&thread_id).await { + Ok(v) => v, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + { + let mut map = self.pending_rollbacks.lock().await; + if map.contains_key(&conversation_id) { + self.send_invalid_request_error( + request_id, + "rollback already in progress for this thread".to_string(), + ) + .await; + return; + } + + map.insert(conversation_id, request_id.clone()); + } + + if let Err(err) = conversation.submit(Op::ThreadRollback { num_turns }).await { + // No ThreadRolledBack event will arrive if an 
error occurs. + // Clean up and reply immediately. + let mut map = self.pending_rollbacks.lock().await; + map.remove(&conversation_id); + + self.send_internal_error(request_id, format!("failed to start rollback: {err}")) + .await; + } + } + async fn thread_list(&self, request_id: RequestId, params: ThreadListParams) { let ThreadListParams { cursor, @@ -3095,8 +3150,10 @@ impl CodexMessageProcessor { let outgoing_for_task = self.outgoing.clone(); let pending_interrupts = self.pending_interrupts.clone(); + let pending_rollbacks = self.pending_rollbacks.clone(); let turn_summary_store = self.turn_summary_store.clone(); let api_version_for_task = api_version; + let fallback_model_provider = self.config.model_provider_id.clone(); tokio::spawn(async move { loop { tokio::select! { @@ -3152,8 +3209,10 @@ impl CodexMessageProcessor { conversation.clone(), outgoing_for_task.clone(), pending_interrupts.clone(), + pending_rollbacks.clone(), turn_summary_store.clone(), api_version_for_task, + fallback_model_provider.clone(), ) .await; } @@ -3354,7 +3413,7 @@ async fn derive_config_from_params( Config::load_with_cli_overrides_and_harness_overrides(cli_overrides, overrides).await } -async fn read_summary_from_rollout( +pub(crate) async fn read_summary_from_rollout( path: &Path, fallback_provider: &str, ) -> std::io::Result { @@ -3413,6 +3472,24 @@ async fn read_summary_from_rollout( }) } +pub(crate) async fn read_event_msgs_from_rollout( + path: &Path, +) -> std::io::Result> { + let items = match RolloutRecorder::get_rollout_history(path).await? 
{ + InitialHistory::New => Vec::new(), + InitialHistory::Forked(items) => items, + InitialHistory::Resumed(resumed) => resumed.history, + }; + + Ok(items + .into_iter() + .filter_map(|item| match item { + RolloutItem::EventMsg(event) => Some(event), + _ => None, + }) + .collect()) +} + fn extract_conversation_summary( path: PathBuf, head: &[serde_json::Value], @@ -3474,7 +3551,7 @@ fn parse_datetime(timestamp: Option<&str>) -> Option> { }) } -fn summary_to_thread(summary: ConversationSummary) -> Thread { +pub(crate) fn summary_to_thread(summary: ConversationSummary) -> Thread { let ConversationSummary { conversation_id, path, diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index 98b2cabaaa0..026b01ebbad 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -45,6 +45,7 @@ use codex_app_server_protocol::SetDefaultModelParams; use codex_app_server_protocol::ThreadArchiveParams; use codex_app_server_protocol::ThreadListParams; use codex_app_server_protocol::ThreadResumeParams; +use codex_app_server_protocol::ThreadRollbackParams; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::TurnInterruptParams; use codex_app_server_protocol::TurnStartParams; @@ -316,6 +317,15 @@ impl McpProcess { self.send_request("thread/archive", params).await } + /// Send a `thread/rollback` JSON-RPC request. + pub async fn send_thread_rollback_request( + &mut self, + params: ThreadRollbackParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("thread/rollback", params).await + } + /// Send a `thread/list` JSON-RPC request. 
pub async fn send_thread_list_request( &mut self, diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index f23f792d399..1ef00c6939d 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -7,6 +7,7 @@ mod review; mod thread_archive; mod thread_list; mod thread_resume; +mod thread_rollback; mod thread_start; mod turn_interrupt; mod turn_start; diff --git a/codex-rs/app-server/tests/suite/v2/thread_rollback.rs b/codex-rs/app-server/tests/suite/v2/thread_rollback.rs new file mode 100644 index 00000000000..f3313c759ef --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/thread_rollback.rs @@ -0,0 +1,177 @@ +use anyhow::Result; +use app_test_support::McpProcess; +use app_test_support::create_final_assistant_message_sse_response; +use app_test_support::create_mock_chat_completions_server_unchecked; +use app_test_support::to_response; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadResumeParams; +use codex_app_server_protocol::ThreadResumeResponse; +use codex_app_server_protocol::ThreadRollbackParams; +use codex_app_server_protocol::ThreadRollbackResponse; +use codex_app_server_protocol::ThreadStartParams; +use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::UserInput as V2UserInput; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test] +async fn thread_rollback_drops_last_turns_and_persists_to_rollout() -> Result<()> { + // Three Codex turns hit the mock model (session start + two turn/start calls). 
+ let responses = vec![ + create_final_assistant_message_sse_response("Done")?, + create_final_assistant_message_sse_response("Done")?, + create_final_assistant_message_sse_response("Done")?, + ]; + let server = create_mock_chat_completions_server_unchecked(responses).await; + + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + // Start a thread. + let start_id = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + ..Default::default() + }) + .await?; + let start_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(start_id)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response::(start_resp)?; + + // Two turns. + let first_text = "First"; + let turn1_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: first_text.to_string(), + }], + ..Default::default() + }) + .await?; + let _turn1_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn1_id)), + ) + .await??; + let _completed1 = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + + let turn2_id = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "Second".to_string(), + }], + ..Default::default() + }) + .await?; + let _turn2_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn2_id)), + ) + .await??; + let _completed2 = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + + // Roll back the last turn. 
+ let rollback_id = mcp + .send_thread_rollback_request(ThreadRollbackParams { + thread_id: thread.id.clone(), + num_turns: 1, + }) + .await?; + let rollback_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(rollback_id)), + ) + .await??; + let ThreadRollbackResponse { + thread: rolled_back_thread, + } = to_response::(rollback_resp)?; + + assert_eq!(rolled_back_thread.turns.len(), 1); + assert_eq!(rolled_back_thread.turns[0].items.len(), 2); + match &rolled_back_thread.turns[0].items[0] { + ThreadItem::UserMessage { content, .. } => { + assert_eq!( + content, + &vec![V2UserInput::Text { + text: first_text.to_string() + }] + ); + } + other => panic!("expected user message item, got {other:?}"), + } + + // Resume and confirm the history is pruned. + let resume_id = mcp + .send_thread_resume_request(ThreadResumeParams { + thread_id: thread.id, + ..Default::default() + }) + .await?; + let resume_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(resume_id)), + ) + .await??; + let ThreadResumeResponse { thread, .. } = to_response::(resume_resp)?; + + assert_eq!(thread.turns.len(), 1); + assert_eq!(thread.turns[0].items.len(), 2); + match &thread.turns[0].items[0] { + ThreadItem::UserMessage { content, .. 
} => { + assert_eq!( + content, + &vec![V2UserInput::Text { + text: first_text.to_string() + }] + ); + } + other => panic!("expected user message item, got {other:?}"), + } + + Ok(()) +} + +fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index e0729b57acb..23feace8144 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1028,6 +1028,25 @@ impl Session { } } + /// Persist the event to the rollout file, flush it, and only then deliver it to clients. + /// + /// Most events can be delivered immediately after queueing the rollout write, but some + /// clients (e.g. app-server thread/rollback) re-read the rollout file synchronously on + /// receipt of the event and depend on the marker already being visible on disk. + pub(crate) async fn send_event_raw_flushed(&self, event: Event) { + // Record the last known agent status. 
+ if let Some(status) = agent_status_from_event(&event.msg) { + let mut guard = self.agent_status.write().await; + *guard = status; + } + self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())]) + .await; + self.flush_rollout().await; + if let Err(e) = self.tx_event.send(event).await { + error!("failed to send tool call event: {e}"); + } + } + pub(crate) async fn emit_turn_item_started(&self, turn_context: &TurnContext, item: &TurnItem) { self.send_event( turn_context, @@ -1245,6 +1264,9 @@ impl Session { history.replace(rebuilt); } } + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => { + history.drop_last_n_user_turns(rollback.num_turns); + } _ => {} } } @@ -1684,6 +1706,9 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv Op::Compact => { handlers::compact(&sess, sub.id.clone()).await; } + Op::ThreadRollback { num_turns } => { + handlers::thread_rollback(&sess, sub.id.clone(), num_turns).await; + } Op::RunUserShellCommand { command } => { handlers::run_user_shell_command( &sess, @@ -1741,6 +1766,7 @@ mod handlers { use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::ReviewRequest; use codex_protocol::protocol::SkillsListEntry; + use codex_protocol::protocol::ThreadRolledBackEvent; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::WarningEvent; @@ -2060,6 +2086,46 @@ mod handlers { .await; } + pub async fn thread_rollback(sess: &Arc, sub_id: String, num_turns: u32) { + if num_turns == 0 { + sess.send_event_raw(Event { + id: sub_id, + msg: EventMsg::Error(ErrorEvent { + message: "num_turns must be >= 1".to_string(), + codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed), + }), + }) + .await; + return; + } + + let has_active_turn = { sess.active_turn.lock().await.is_some() }; + if has_active_turn { + sess.send_event_raw(Event { + id: sub_id, + msg: EventMsg::Error(ErrorEvent { + message: "Cannot rollback while a turn is in progress.".to_string(), + 
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed), + }), + }) + .await; + return; + } + + let turn_context = sess.new_default_turn_with_sub_id(sub_id).await; + + let mut history = sess.clone_history().await; + history.drop_last_n_user_turns(num_turns); + sess.replace_history(history.get_history()).await; + sess.recompute_token_usage(turn_context.as_ref()).await; + + sess.send_event_raw_flushed(Event { + id: turn_context.sub_id.clone(), + msg: EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns }), + }) + .await; + } + pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { sess.abort_all_tasks(TurnAbortReason::Interrupted).await; sess.services @@ -2973,6 +3039,131 @@ mod tests { assert_eq!(expected, actual); } + #[tokio::test] + async fn thread_rollback_drops_last_turn_from_history() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()); + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + let turn_1 = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "turn 1 user".to_string(), + }], + }, + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "turn 1 assistant".to_string(), + }], + }, + ]; + sess.record_into_history(&turn_1, tc.as_ref()).await; + + let turn_2 = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "turn 2 user".to_string(), + }], + }, + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "turn 2 assistant".to_string(), + }], + }, + ]; + sess.record_into_history(&turn_2, tc.as_ref()).await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); 
+ + let mut expected = Vec::new(); + expected.extend(initial_context); + expected.extend(turn_1); + + let actual = sess.clone_history().await.get_history(); + assert_eq!(expected, actual); + } + + #[tokio::test] + async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()); + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + let turn_1 = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "turn 1 user".to_string(), + }], + }]; + sess.record_into_history(&turn_1, tc.as_ref()).await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 99); + + let actual = sess.clone_history().await.get_history(); + assert_eq!(initial_context, actual); + } + + #[tokio::test] + async fn thread_rollback_fails_when_turn_in_progress() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()); + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + *sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let actual = sess.clone_history().await.get_history(); + assert_eq!(initial_context, actual); + } + + #[tokio::test] + async fn thread_rollback_fails_when_num_turns_is_zero() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()); + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + 
handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!(error_event.message, "num_turns must be >= 1"); + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let actual = sess.clone_history().await.get_history(); + assert_eq!(initial_context, actual); + } + #[tokio::test] async fn set_rate_limits_retains_previous_credits() { let codex_home = tempfile::tempdir().expect("create temp dir"); @@ -3206,6 +3397,44 @@ mod tests { assert_eq!(expected, got); } + async fn wait_for_thread_rolled_back( + rx: &async_channel::Receiver, + ) -> crate::protocol::ThreadRolledBackEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::ThreadRolledBack(payload) => return payload, + _ => continue, + } + } + } + + async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver) -> ErrorEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::Error(payload) + if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) => + { + return payload; + } + _ => continue, + } + } + } + fn text_block(s: &str) -> ContentBlock { ContentBlock::TextContent(TextContent { annotations: None, diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index c18ad7df8ec..e52561c9fed 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -5,6 +5,9 @@ 
use crate::truncate::approx_token_count; use crate::truncate::approx_tokens_from_byte_count; use crate::truncate::truncate_function_output_items_with_policy; use crate::truncate::truncate_text; +use crate::user_instructions::SkillInstructions; +use crate::user_instructions::UserInstructions; +use crate::user_shell_command::is_user_shell_command_text; use codex_protocol::models::ContentItem; use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; @@ -152,6 +155,39 @@ impl ContextManager { } } + /// Drop the last `num_turns` user turns from this history. + /// + /// "User turns" are identified as `ResponseItem::Message` entries whose role is `"user"`. + /// + /// This mirrors thread-rollback semantics: + /// - `num_turns == 0` is a no-op + /// - if there are no user turns, this is a no-op + /// - if `num_turns` exceeds the number of user turns, all user turns are dropped while + /// preserving any items that occurred before the first user message. + pub(crate) fn drop_last_n_user_turns(&mut self, num_turns: u32) { + if num_turns == 0 { + return; + } + + // Keep behavior consistent with call sites that previously operated on `get_history()`: + // normalize first (call/output invariants), then truncate based on the normalized view. 
+ let snapshot = self.get_history(); + let user_positions = user_message_positions(&snapshot); + let Some(&first_user_idx) = user_positions.first() else { + self.replace(snapshot); + return; + }; + + let n_from_end = usize::try_from(num_turns).unwrap_or(usize::MAX); + let cut_idx = if n_from_end >= user_positions.len() { + first_user_idx + } else { + user_positions[user_positions.len() - n_from_end] + }; + + self.replace(snapshot[..cut_idx].to_vec()); + } + pub(crate) fn update_token_info( &mut self, usage: &TokenUsage, @@ -291,6 +327,56 @@ fn estimate_reasoning_length(encoded_len: usize) -> usize { .saturating_sub(650) } +fn is_session_prefix(text: &str) -> bool { + let trimmed = text.trim_start(); + let lowered = trimmed.to_ascii_lowercase(); + lowered.starts_with("") +} + +fn is_user_turn_boundary(item: &ResponseItem) -> bool { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + + if role != "user" { + return false; + } + + if UserInstructions::is_user_instructions(content) + || SkillInstructions::is_skill_instructions(content) + { + return false; + } + + for content_item in content { + match content_item { + ContentItem::InputText { text } => { + if is_session_prefix(text) || is_user_shell_command_text(text) { + return false; + } + } + ContentItem::OutputText { text } => { + if is_session_prefix(text) { + return false; + } + } + ContentItem::InputImage { .. 
} => {} + } + } + + true +} + +fn user_message_positions(items: &[ResponseItem]) -> Vec { + let mut positions = Vec::new(); + for (idx, item) in items.iter().enumerate() { + if is_user_turn_boundary(item) { + positions.push(idx); + } + } + positions +} + #[cfg(test)] #[path = "history_tests.rs"] mod tests; diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index d121b7dc634..1fb47e1e66b 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -43,6 +43,16 @@ fn user_msg(text: &str) -> ResponseItem { } } +fn user_input_text_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + } +} + fn reasoning_msg(text: &str) -> ResponseItem { ResponseItem::Reasoning { id: String::new(), @@ -227,6 +237,127 @@ fn remove_first_item_handles_local_shell_pair() { assert_eq!(h.contents(), vec![]); } +#[test] +fn drop_last_n_user_turns_preserves_prefix() { + let items = vec![ + assistant_msg("session prefix item"), + user_msg("u1"), + assistant_msg("a1"), + user_msg("u2"), + assistant_msg("a2"), + ]; + + let mut history = create_history_with_items(items); + history.drop_last_n_user_turns(1); + assert_eq!( + history.get_history(), + vec![ + assistant_msg("session prefix item"), + user_msg("u1"), + assistant_msg("a1"), + ] + ); + + let mut history = create_history_with_items(vec![ + assistant_msg("session prefix item"), + user_msg("u1"), + assistant_msg("a1"), + user_msg("u2"), + assistant_msg("a2"), + ]); + history.drop_last_n_user_turns(99); + assert_eq!( + history.get_history(), + vec![assistant_msg("session prefix item")] + ); +} + +#[test] +fn drop_last_n_user_turns_ignores_session_prefix_user_messages() { + let items = vec![ + user_input_text_msg("ctx"), + user_input_text_msg("do the thing"), + user_input_text_msg( + "# AGENTS.md 
instructions for test_directory\n\n\ntest_text\n", + ), + user_input_text_msg( + "\ndemo\nskills/demo/SKILL.md\nbody\n", + ), + user_input_text_msg("echo 42"), + user_input_text_msg("turn 1 user"), + assistant_msg("turn 1 assistant"), + user_input_text_msg("turn 2 user"), + assistant_msg("turn 2 assistant"), + ]; + + let mut history = create_history_with_items(items); + history.drop_last_n_user_turns(1); + + let expected_prefix_and_first_turn = vec![ + user_input_text_msg("ctx"), + user_input_text_msg("do the thing"), + user_input_text_msg( + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", + ), + user_input_text_msg( + "\ndemo\nskills/demo/SKILL.md\nbody\n", + ), + user_input_text_msg("echo 42"), + user_input_text_msg("turn 1 user"), + assistant_msg("turn 1 assistant"), + ]; + + assert_eq!(history.get_history(), expected_prefix_and_first_turn); + + let expected_prefix_only = vec![ + user_input_text_msg("ctx"), + user_input_text_msg("do the thing"), + user_input_text_msg( + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", + ), + user_input_text_msg( + "\ndemo\nskills/demo/SKILL.md\nbody\n", + ), + user_input_text_msg("echo 42"), + ]; + + let mut history = create_history_with_items(vec![ + user_input_text_msg("ctx"), + user_input_text_msg("do the thing"), + user_input_text_msg( + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", + ), + user_input_text_msg( + "\ndemo\nskills/demo/SKILL.md\nbody\n", + ), + user_input_text_msg("echo 42"), + user_input_text_msg("turn 1 user"), + assistant_msg("turn 1 assistant"), + user_input_text_msg("turn 2 user"), + assistant_msg("turn 2 assistant"), + ]); + history.drop_last_n_user_turns(2); + assert_eq!(history.get_history(), expected_prefix_only); + + let mut history = create_history_with_items(vec![ + user_input_text_msg("ctx"), + user_input_text_msg("do the thing"), + user_input_text_msg( + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", + ), + user_input_text_msg( + 
"\ndemo\nskills/demo/SKILL.md\nbody\n", + ), + user_input_text_msg("echo 42"), + user_input_text_msg("turn 1 user"), + assistant_msg("turn 1 assistant"), + user_input_text_msg("turn 2 user"), + assistant_msg("turn 2 assistant"), + ]); + history.drop_last_n_user_turns(3); + assert_eq!(history.get_history(), expected_prefix_only); +} + #[test] fn remove_first_item_handles_custom_tool_pair() { let items = vec![ diff --git a/codex-rs/core/src/conversation_manager.rs b/codex-rs/core/src/conversation_manager.rs index c5bb9a17586..cf0f0106dbe 100644 --- a/codex-rs/core/src/conversation_manager.rs +++ b/codex-rs/core/src/conversation_manager.rs @@ -16,10 +16,9 @@ use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::SessionConfiguredEvent; use crate::rollout::RolloutRecorder; +use crate::rollout::truncation; use crate::skills::SkillsManager; use codex_protocol::ConversationId; -use codex_protocol::items::TurnItem; -use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelPreset; use codex_protocol::protocol::InitialHistory; use codex_protocol::protocol::Op; @@ -307,30 +306,8 @@ impl ConversationManagerState { /// Return a prefix of `items` obtained by cutting strictly before the nth user message /// (0-based) and all items that follow it. fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory { - // Work directly on rollout items, and cut the vector at the nth user message input. let items: Vec = history.get_rollout_items(); - - // Find indices of user message inputs in rollout order. - let mut user_positions: Vec = Vec::new(); - for (idx, item) in items.iter().enumerate() { - if let RolloutItem::ResponseItem(item @ ResponseItem::Message { .. }) = item - && matches!( - crate::event_mapping::parse_turn_item(item), - Some(TurnItem::UserMessage(_)) - ) - { - user_positions.push(idx); - } - } - - // If fewer than or equal to n user messages exist, treat as empty (out of range). 
- if user_positions.len() <= n { - return InitialHistory::New; - } - - // Cut strictly before the nth user message (do not keep the nth itself). - let cut_idx = user_positions[n]; - let rolled: Vec = items.into_iter().take(cut_idx).collect(); + let rolled = truncation::truncate_rollout_before_nth_user_message_from_start(&items, n); if rolled.is_empty() { InitialHistory::New diff --git a/codex-rs/core/src/rollout/mod.rs b/codex-rs/core/src/rollout/mod.rs index 540d204be3e..d7e24602fd1 100644 --- a/codex-rs/core/src/rollout/mod.rs +++ b/codex-rs/core/src/rollout/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod error; pub mod list; pub(crate) mod policy; pub mod recorder; +pub(crate) mod truncation; pub use codex_protocol::protocol::SessionMeta; pub(crate) use error::map_session_init_error; diff --git a/codex-rs/core/src/rollout/policy.rs b/codex-rs/core/src/rollout/policy.rs index 07c8af1144b..6c02ad09425 100644 --- a/codex-rs/core/src/rollout/policy.rs +++ b/codex-rs/core/src/rollout/policy.rs @@ -45,6 +45,7 @@ pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool { | EventMsg::ContextCompacted(_) | EventMsg::EnteredReviewMode(_) | EventMsg::ExitedReviewMode(_) + | EventMsg::ThreadRolledBack(_) | EventMsg::UndoCompleted(_) | EventMsg::TurnAborted(_) => true, EventMsg::Error(_) diff --git a/codex-rs/core/src/rollout/truncation.rs b/codex-rs/core/src/rollout/truncation.rs new file mode 100644 index 00000000000..b8127f0345b --- /dev/null +++ b/codex-rs/core/src/rollout/truncation.rs @@ -0,0 +1,195 @@ +//! Helpers for truncating rollouts based on "user turn" boundaries. +//! +//! In core, "user turns" are detected by scanning `ResponseItem::Message` items and +//! interpreting them via `event_mapping::parse_turn_item(...)`. 
+ +use crate::event_mapping; +use codex_protocol::items::TurnItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::RolloutItem; + +/// Return the indices of user message boundaries in a rollout. +/// +/// A user message boundary is a `RolloutItem::ResponseItem(ResponseItem::Message { .. })` +/// whose parsed turn item is `TurnItem::UserMessage`. +/// +/// Rollouts can contain `ThreadRolledBack` markers. Those markers indicate that the +/// last N user turns were removed from the effective thread history; we apply them here so +/// indexing uses the post-rollback history rather than the raw stream. +pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec { + let mut user_positions = Vec::new(); + for (idx, item) in items.iter().enumerate() { + match item { + RolloutItem::ResponseItem(item @ ResponseItem::Message { .. }) + if matches!( + event_mapping::parse_turn_item(item), + Some(TurnItem::UserMessage(_)) + ) => + { + user_positions.push(idx); + } + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => { + let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX); + let new_len = user_positions.len().saturating_sub(num_turns); + user_positions.truncate(new_len); + } + _ => {} + } + } + user_positions +} + +/// Return a prefix of `items` obtained by cutting strictly before the nth user message. +/// +/// The boundary index is 0-based from the start of `items` (so `n_from_start = 0` returns +/// a prefix that excludes the first user message and everything after it). +/// +/// If fewer than or equal to `n_from_start` user messages exist, this returns an empty +/// vector (out of range). 
+pub(crate) fn truncate_rollout_before_nth_user_message_from_start( + items: &[RolloutItem], + n_from_start: usize, +) -> Vec { + let user_positions = user_message_positions_in_rollout(items); + + // If fewer than or equal to n user messages exist, treat as empty (out of range). + if user_positions.len() <= n_from_start { + return Vec::new(); + } + + // Cut strictly before the nth user message (do not keep the nth itself). + let cut_idx = user_positions[n_from_start]; + items[..cut_idx].to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::codex::make_session_and_context; + use assert_matches::assert_matches; + use codex_protocol::models::ContentItem; + use codex_protocol::models::ReasoningItemReasoningSummary; + use codex_protocol::protocol::ThreadRolledBackEvent; + use pretty_assertions::assert_eq; + + fn user_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + } + } + + fn assistant_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + } + } + + #[test] + fn truncates_rollout_from_start_before_nth_user_only() { + let items = [ + user_msg("u1"), + assistant_msg("a1"), + assistant_msg("a2"), + user_msg("u2"), + assistant_msg("a3"), + ResponseItem::Reasoning { + id: "r1".to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "s".to_string(), + }], + content: None, + encrypted_content: None, + }, + ResponseItem::FunctionCall { + id: None, + name: "tool".to_string(), + arguments: "{}".to_string(), + call_id: "c1".to_string(), + }, + assistant_msg("a4"), + ]; + + let rollout: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1); + let expected = vec![ + 
RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + ]; + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); + + let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2); + assert_matches!(truncated2.as_slice(), []); + } + + #[test] + fn truncates_rollout_from_start_applies_thread_rollback_markers() { + let rollout_items = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(user_msg("u2")), + RolloutItem::ResponseItem(assistant_msg("a2")), + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { + num_turns: 1, + })), + RolloutItem::ResponseItem(user_msg("u3")), + RolloutItem::ResponseItem(assistant_msg("a3")), + RolloutItem::ResponseItem(user_msg("u4")), + RolloutItem::ResponseItem(assistant_msg("a4")), + ]; + + // Effective user history after applying rollback(1) is: u1, u3, u4. + // So n_from_start=2 should cut before u4 (not u3). 
+ let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2); + let expected = rollout_items[..7].to_vec(); + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); + } + + #[tokio::test] + async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() { + let (session, turn_context) = make_session_and_context().await; + let mut items = session.build_initial_context(&turn_context); + items.push(user_msg("feature request")); + items.push(assistant_msg("ack")); + items.push(user_msg("second question")); + items.push(assistant_msg("answer")); + + let rollout_items: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1); + let expected: Vec = vec![ + RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + ]; + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); + } +} diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index 40afab7c9c6..ba6d99c5a82 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -595,7 +595,8 @@ impl EventProcessor for EventProcessorWithHumanOutput { | EventMsg::ReasoningRawContentDelta(_) | EventMsg::SkillsUpdateAvailable | EventMsg::UndoCompleted(_) - | EventMsg::UndoStarted(_) => {} + | EventMsg::UndoStarted(_) + | EventMsg::ThreadRolledBack(_) => {} } CodexStatus::Running } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index b50d0c7640a..629f4cc493b 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -311,6 +311,7 @@ async fn 
run_codex_tool_session_inner( | EventMsg::UndoCompleted(_) | EventMsg::ExitedReviewMode(_) | EventMsg::ContextCompacted(_) + | EventMsg::ThreadRolledBack(_) | EventMsg::DeprecationNotice(_) => { // For now, we do not do anything extra for these // events. Note that diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 7fb3d62d250..4dbd7902c53 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -210,6 +210,12 @@ pub enum Op { /// Request Codex to undo a turn (turn are stacked so it is the same effect as CMD + Z). Undo, + /// Request Codex to drop the last N user turns from in-memory context. + /// + /// This does not attempt to revert local filesystem changes. Clients are + /// responsible for undoing any edits on disk. + ThreadRollback { num_turns: u32 }, + /// Request a code review from the agent. Review { review_request: ReviewRequest }, @@ -541,6 +547,9 @@ pub enum EventMsg { /// Conversation history was compacted (either automatically or manually). ContextCompacted(ContextCompactedEvent), + /// Conversation history was rolled back by dropping the last N user turns. + ThreadRolledBack(ThreadRolledBackEvent), + /// Agent has started a task TaskStarted(TaskStartedEvent), @@ -718,6 +727,7 @@ pub enum CodexErrorInfo { ResponseTooManyFailedAttempts { http_status_code: Option, }, + ThreadRollbackFailed, Other, } @@ -1618,6 +1628,12 @@ pub struct UndoCompletedEvent { pub message: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] +pub struct ThreadRolledBackEvent { + /// Number of user turns that were removed from context. 
+ pub num_turns: u32, +} + #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct StreamErrorEvent { pub message: String, diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index c83629976cf..152724261ac 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -2114,6 +2114,7 @@ impl ChatWidget { } EventMsg::ExitedReviewMode(review) => self.on_exited_review_mode(review), EventMsg::ContextCompacted(_) => self.on_agent_message("Context compacted".to_owned()), + EventMsg::ThreadRolledBack(_) => {} EventMsg::RawResponseItem(_) | EventMsg::ItemStarted(_) | EventMsg::ItemCompleted(_) diff --git a/codex-rs/tui2/src/chatwidget.rs b/codex-rs/tui2/src/chatwidget.rs index ab594678719..7700c534249 100644 --- a/codex-rs/tui2/src/chatwidget.rs +++ b/codex-rs/tui2/src/chatwidget.rs @@ -1921,6 +1921,7 @@ impl ChatWidget { EventMsg::ExitedReviewMode(review) => self.on_exited_review_mode(review), EventMsg::ContextCompacted(_) => self.on_agent_message("Context compacted".to_owned()), EventMsg::RawResponseItem(_) + | EventMsg::ThreadRolledBack(_) | EventMsg::ItemStarted(_) | EventMsg::ItemCompleted(_) | EventMsg::AgentMessageContentDelta(_)