diff --git a/Cargo.lock b/Cargo.lock index 5f5ee7b..4457448 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1534,6 +1534,7 @@ dependencies = [ "codex-models-manager", "codex-protocol", "codex-shell-command", + "codex-thread-store", "codex-utils-approval-presets", "codex-utils-cli", "diffy", diff --git a/Cargo.toml b/Cargo.toml index e68ef0c..b697447 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ codex-models-manager = { git = "https://github.com/openai/codex", tag = "rust-v0 codex-protocol = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } codex-login = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } codex-shell-command = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } +codex-thread-store = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } codex-utils-approval-presets = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } codex-utils-cli = { git = "https://github.com/openai/codex", tag = "rust-v0.129.0" } diffy = { version = "0.5.0", features = ["std"] } diff --git a/src/codex_agent.rs b/src/codex_agent.rs index bae7687..12cdd85 100644 --- a/src/codex_agent.rs +++ b/src/codex_agent.rs @@ -15,8 +15,8 @@ use agent_client_protocol as acp; use codex_config::{McpServerConfig, McpServerTransportConfig}; use codex_core::{ NewThread, RolloutRecorder, SortDirection, StateDbHandle, ThreadManager, ThreadSortKey, - config::Config, find_thread_path_by_id_str, init_state_db, parse_cursor, - resolve_installation_id, thread_store_from_config, + config::Config, find_thread_names_by_ids, find_thread_path_by_id_str, init_state_db, + parse_cursor, resolve_installation_id, thread_store_from_config, }; use codex_exec_server::{EnvironmentManager, EnvironmentManagerArgs, ExecServerRuntimePaths}; use codex_login::{ @@ -28,11 +28,11 @@ use codex_protocol::{ protocol::{InitialHistory, SessionSource}, }; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, path::{Path, PathBuf}, sync::{Arc, Mutex}, }; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use unicode_segmentation::UnicodeSegmentation; use crate::thread::Thread; @@ -566,11 +566,13 @@ impl CodexAgent { .lock() .unwrap() .insert(session_id.clone(), config.cwd.to_path_buf()); + let thread_store = thread_store_from_config(&config, self.state_db.clone()); let thread = Arc::new(Thread::new( session_id.clone(), thread, self.auth_manager.clone(), Arc::new(self.thread_manager.get_models_manager()), + thread_store, self.client_capabilities.clone(), config.clone(), cx, @@ -640,11 +642,13 @@ impl CodexAgent { .await .map_err(|e| Error::internal_error().data(e.to_string()))?; + let thread_store = thread_store_from_config(&config, self.state_db.clone()); let thread = Arc::new(Thread::new( session_id.clone(), thread, self.auth_manager.clone(), Arc::new(self.thread_manager.get_models_manager()), + thread_store, self.client_capabilities.clone(), config.clone(), cx, @@ -695,7 +699,7 @@ impl CodexAgent { .await .map_err(|err| Error::internal_error().data(format!("failed to list sessions: {err}")))?; - let sessions = page + let session_items = page .items .into_iter() .filter_map(|item| { @@ -708,17 +712,36 @@ impl CodexAgent { return None; } - let title = item - .first_user_message - .as_deref() - .and_then(format_session_title); let updated_at = item.updated_at.or(item.created_at); - Some( - SessionInfo::new(SessionId::new(thread_id.to_string()), item_cwd) - .title(title) - .updated_at(updated_at), - ) + Some((thread_id, item_cwd, item.first_user_message, updated_at)) + }) + .collect::>(); + + let thread_ids = session_items + .iter() + .map(|(thread_id, _, _, _)| *thread_id) + .collect::>(); + let thread_names = + match find_thread_names_by_ids(&self.config.codex_home, &thread_ids).await { + Ok(thread_names) => thread_names, + Err(err) => { + warn!("failed to read Codex thread names: {err}"); + HashMap::new() + } + }; + + let sessions = session_items + .into_iter() + .map(|(thread_id, item_cwd, first_user_message, updated_at)| { + let title = thread_names + .get(&thread_id) + .cloned() + .or_else(|| first_user_message.as_deref().and_then(format_session_title)); + + SessionInfo::new(SessionId::new(thread_id.to_string()), item_cwd) + .title(title) + .updated_at(updated_at) }) .collect::>(); diff --git a/src/thread.rs b/src/thread.rs index 143206a..262127e 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -17,11 +17,11 @@ use agent_client_protocol::{ PlanEntryStatus, PromptRequest, RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse, ResourceLink, SelectedPermissionOutcome, SessionConfigId, SessionConfigOption, SessionConfigOptionCategory, SessionConfigOptionValue, - SessionConfigSelectOption, SessionConfigValueId, SessionId, SessionMode, SessionModeId, - SessionModeState, SessionModelState, SessionNotification, SessionUpdate, StopReason, - Terminal, TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, - ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind, - UnstructuredCommandInput, UsageUpdate, + SessionConfigSelectOption, SessionConfigValueId, SessionId, SessionInfoUpdate, SessionMode, + SessionModeId, SessionModeState, SessionModelState, SessionNotification, SessionUpdate, + StopReason, Terminal, TextContent, TextResourceContents, ToolCall, ToolCallContent, + ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, + ToolKind, UnstructuredCommandInput, UsageUpdate, }, }; use codex_apply_patch::parse_patch; @@ -30,10 +30,12 @@ use codex_core::{ config::{Config, set_project_trust_level}, review_format::format_review_findings_block, review_prompts::user_facing_hint, + util::normalize_thread_name, }; use codex_login::auth::AuthManager; use codex_models_manager::manager::{ModelsManager, RefreshStrategy}; use codex_protocol::{ + ThreadId, approvals::{ ElicitationRequest, ElicitationRequestEvent, GuardianAssessmentAction, GuardianCommandSource, @@ -76,6 +78,7 @@ use codex_protocol::{ user_input::UserInput, }; use codex_shell_command::parse_command::parse_command; +use codex_thread_store::{ThreadMetadataPatch, ThreadStore, UpdateThreadMetadataParams}; use codex_utils_approval_presets::{ApprovalPreset, builtin_approval_presets}; use heck::ToTitleCase; use itertools::Itertools; @@ -255,6 +258,47 @@ impl ModelsManagerImpl for Arc { } } +trait ThreadNameStore: Send + Sync { + fn update_thread_name( + &self, + thread_id: ThreadId, + name: String, + ) -> Pin> + Send + '_>>; +} + +struct ThreadStoreNameUpdater { + thread_store: Arc, +} + +impl ThreadStoreNameUpdater { + fn new(thread_store: Arc) -> Self { + Self { thread_store } + } +} + +impl ThreadNameStore for ThreadStoreNameUpdater { + fn update_thread_name( + &self, + thread_id: ThreadId, + name: String, + ) -> Pin> + Send + '_>> { + Box::pin(async move { + self.thread_store + .update_thread_metadata(UpdateThreadMetadataParams { + thread_id, + patch: ThreadMetadataPatch { + name: Some(name), + ..Default::default() + }, + include_archived: false, + }) + .await + .map_err(|err| Error::internal_error().data(err.to_string()))?; + Ok(()) + }) + } +} + pub trait Auth { fn logout(&self) -> impl Future> + Send; } @@ -324,6 +368,7 @@ impl Thread { thread: Arc, auth: Arc, models_manager: Arc, + thread_store: Arc, client_capabilities: Arc>, config: Config, cx: ConnectionTo, @@ -336,6 +381,7 @@ impl Thread { SessionClient::new(session_id, cx, client_capabilities), thread.clone(), models_manager, + Arc::new(ThreadStoreNameUpdater::new(thread_store)), config, message_rx, resolution_tx, @@ -2741,6 +2787,8 @@ struct ThreadActor { config: Config, /// The models available for this thread. models_manager: Arc, + /// Storage surface used for thread metadata updates. + thread_name_store: Arc, /// Internal message sender used to route spawned interaction results back to the actor. resolution_tx: mpsc::UnboundedSender, /// A sender for each interested `Op` submission that needs events routed. @@ -2760,6 +2808,7 @@ impl ThreadActor { client: SessionClient, thread: Arc, models_manager: Arc, + thread_name_store: Arc, config: Config, message_rx: mpsc::UnboundedReceiver, resolution_tx: mpsc::UnboundedSender, @@ -2771,6 +2820,7 @@ impl ThreadActor { thread, config, models_manager, + thread_name_store, resolution_tx, submissions: HashMap::new(), message_rx, @@ -2920,6 +2970,9 @@ impl ThreadActor { "compact", "summarize conversation to prevent hitting the context limit", ), + AvailableCommand::new("rename", "rename the current thread").input( + AvailableCommandInput::Unstructured(UnstructuredCommandInput::new("new name")), + ), AvailableCommand::new("logout", "logout of Codex"), ] } @@ -3307,6 +3360,12 @@ impl ThreadActor { self.auth.logout().await?; return Err(Error::auth_required()); } + "rename" if !rest.is_empty() => { + let name = normalize_thread_name(rest).ok_or_else(Error::invalid_params)?; + self.handle_rename(name).await?; + drop(response_tx.send(Ok(StopReason::EndTurn))); + return Ok(response_rx); + } _ => { op = Op::UserInput { items, @@ -3346,6 +3405,18 @@ impl ThreadActor { Ok(response_rx) } + async fn handle_rename(&mut self, name: String) -> Result<(), Error> { + let thread_id = ThreadId::from_string(self.client.session_id.0.as_ref()) + .map_err(|err| Error::internal_error().data(err.to_string()))?; + self.thread_name_store + .update_thread_name(thread_id, name.clone()) + .await?; + send_session_title_update(&self.client, Some(name.clone())); + self.client + .send_agent_text(format!("Thread renamed to: {name}\n")); + Ok(()) + } + async fn handle_set_mode(&mut self, mode: SessionModeId) -> Result<(), Error> { let preset = APPROVAL_PRESETS .iter() @@ -3506,6 +3577,9 @@ impl ThreadActor { EventMsg::AgentReasoningRawContent(AgentReasoningRawContentEvent { text }) => { self.client.send_agent_thought(text.clone()); } + EventMsg::SessionConfigured(event) => { + send_session_title_update(&self.client, event.thread_name.clone()); + } EventMsg::ThreadGoalUpdated(event) => { self.client .send_agent_text(format_thread_goal_update(event)); @@ -3803,12 +3877,22 @@ impl ThreadActor { async fn handle_event(&mut self, Event { id, msg }: Event) { if let Some(submission) = self.submissions.get_mut(&id) { submission.handle_event(&self.client, msg).await; + } else if let EventMsg::SessionConfigured(event) = msg { + send_session_title_update(&self.client, event.thread_name); } else { warn!("Received event for unknown submission ID: {id} {msg:?}"); } } } +fn send_session_title_update(client: &SessionClient, title: Option) { + if let Some(title) = title { + client.send_notification(SessionUpdate::SessionInfoUpdate( + SessionInfoUpdate::new().title(title), + )); + } +} + fn build_prompt_items(prompt: Vec) -> Vec { prompt .into_iter() @@ -4259,6 +4343,42 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_rename_sends_title_update() -> anyhow::Result<()> { + let (session_id, client, thread, message_tx, _handle) = setup().await?; + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["/rename My New Thread".into()]), + response_tx: prompt_response_tx, + })?; + + let stop_reason = prompt_response_rx.await??.await??; + assert_eq!(stop_reason, StopReason::EndTurn); + drop(message_tx); + + assert!(thread.ops.lock().unwrap().is_empty()); + let notifications = client.notifications.lock().unwrap(); + assert!(notifications.iter().any(|notification| { + matches!( + ¬ification.update, + SessionUpdate::SessionInfoUpdate(update) + if update.title.value() == Some(&"My New Thread".to_string()) + ) + })); + assert!(notifications.iter().any(|notification| { + matches!( + ¬ification.update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) if text == "Thread renamed to: My New Thread\n" + ) + })); + + Ok(()) + } + #[tokio::test] async fn test_thread_goal_updated_is_sent_as_agent_message() -> anyhow::Result<()> { let (session_id, client, _, message_tx, _handle) = setup().await?; @@ -4755,7 +4875,7 @@ mod tests { UnboundedSender, tokio::task::JoinHandle<()>, )> { - let session_id = SessionId::new("test"); + let session_id = SessionId::new(ThreadId::new().to_string()); let client = Arc::new(StubClient::new()); let session_client = SessionClient::with_client(session_id.clone(), client.clone(), Arc::default()); @@ -4774,6 +4894,7 @@ mod tests { session_client, conversation.clone(), models_manager, + Arc::new(StubThreadNameStore::default()), config, message_rx, resolution_tx, @@ -4792,6 +4913,24 @@ mod tests { } } + #[derive(Default)] + struct StubThreadNameStore { + updates: std::sync::Mutex>, + } + + impl ThreadNameStore for StubThreadNameStore { + fn update_thread_name( + &self, + thread_id: ThreadId, + name: String, + ) -> Pin> + Send + '_>> { + Box::pin(async move { + self.updates.lock().unwrap().push((thread_id, name)); + Ok(()) + }) + } + } + struct StubModelsManager; impl ModelsManagerImpl for StubModelsManager { @@ -5646,6 +5785,7 @@ mod tests { session_client, conversation.clone(), models_manager, + Arc::new(StubThreadNameStore::default()), config, message_rx, resolution_tx,