diff --git a/src/thread.rs b/src/thread.rs index d1ea773..ba94a99 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -2570,6 +2570,8 @@ struct ThreadActor { custom_prompts: Rc>>, /// The models available for this thread. models_manager: Arc, + /// Raw model ids selected or configured before this actor's model catalog knew about them. + remembered_raw_models: Vec, /// Internal message sender used to route spawned interaction results back to the actor. resolution_tx: mpsc::UnboundedSender, /// A sender for each interested `Op` submission that needs events routed. @@ -2594,6 +2596,8 @@ impl ThreadActor { resolution_tx: mpsc::UnboundedSender, resolution_rx: mpsc::UnboundedReceiver, ) -> Self { + let remembered_raw_models = config.model.iter().cloned().collect(); + Self { auth, client, @@ -2601,6 +2605,7 @@ impl ThreadActor { config, custom_prompts: Rc::default(), models_manager, + remembered_raw_models, resolution_tx, submissions: HashMap::new(), message_rx, @@ -2842,9 +2847,11 @@ impl ThreadActor { )) } - async fn find_current_model(&self) -> Option { - let model_presets = self.models_manager.list_models().await; - let config_model = self.get_current_model().await; + fn find_current_model_in_presets( + &self, + model_presets: &[ModelPreset], + config_model: &str, + ) -> Option { let preset = model_presets .iter() .find(|preset| preset.model == config_model)?; @@ -2873,6 +2880,49 @@ impl ThreadActor { Some((model.to_owned(), reasoning)) } + fn has_preset_for_model(presets: &[ModelPreset], model: &str) -> bool { + presets + .iter() + .any(|preset| preset.model == model || preset.id == model) + } + + fn remember_raw_model(&mut self, model: impl Into) { + let model = model.into(); + if model.is_empty() || self.remembered_raw_models.iter().any(|m| m == &model) { + return; + } + + self.remembered_raw_models.push(model); + } + + fn push_raw_config_model_option( + model_select_options: &mut Vec, + seen_model_ids: &mut HashSet, + model: &str, + ) { + if model.is_empty() || !seen_model_ids.insert(model.to_owned()) { + return; + } + + model_select_options.push(SessionConfigSelectOption::new( + model.to_owned(), + model.to_owned(), + )); + } + + fn push_raw_model_info( + available_models: &mut Vec, + seen_model_ids: &mut HashSet, + model: &str, + ) { + if model.is_empty() || !seen_model_ids.insert(model.to_owned()) { + return; + } + + let model_id = ModelId::new(model.to_owned()); + available_models.push(ModelInfo::new(model_id, model.to_owned())); + } + async fn config_options(&self) -> Result, Error> { let mut options = Vec::new(); @@ -2901,15 +2951,26 @@ impl ThreadActor { let current_preset = presets.iter().find(|p| p.model == current_model).cloned(); let mut model_select_options = Vec::new(); + let mut raw_model_ids = HashSet::new(); if current_preset.is_none() { - // If no preset found, return the current model string as-is - model_select_options.push(SessionConfigSelectOption::new( - current_model.clone(), - current_model.clone(), - )); + Self::push_raw_config_model_option( + &mut model_select_options, + &mut raw_model_ids, + ¤t_model, + ); }; + for model in &self.remembered_raw_models { + if !Self::has_preset_for_model(&presets, model) { + Self::push_raw_config_model_option( + &mut model_select_options, + &mut raw_model_ids, + model, + ); + } + } + model_select_options.extend( presets .into_iter() @@ -3052,6 +3113,10 @@ impl ThreadActor { .await .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + if preset.is_none() { + self.remember_raw_model(model_to_use.clone()); + } + self.config.model = Some(model_to_use); self.config.model_reasoning_effort = effort_to_use; @@ -3107,20 +3172,26 @@ impl ThreadActor { async fn models(&self) -> Result { let mut available_models = Vec::new(); let config_model = self.get_current_model().await; + let presets = self.models_manager.list_models().await; + let mut raw_model_ids = HashSet::new(); - let current_model_id = if let Some(model_id) = self.find_current_model().await { - model_id - } else { - // If no preset found, return the current model string as-is - let model_id = ModelId::new(self.get_current_model().await); - available_models.push(ModelInfo::new(model_id.clone(), model_id.to_string())); - model_id - }; + let current_model_id = + if let Some(model_id) = self.find_current_model_in_presets(&presets, &config_model) { + model_id + } else { + let model_id = ModelId::new(config_model.clone()); + Self::push_raw_model_info(&mut available_models, &mut raw_model_ids, &config_model); + model_id + }; + + for model in &self.remembered_raw_models { + if !Self::has_preset_for_model(&presets, model) { + Self::push_raw_model_info(&mut available_models, &mut raw_model_ids, model); + } + } available_models.extend( - self.models_manager - .list_models() - .await + presets .iter() .filter(|model| model.show_in_picker || model.model == config_model) .flat_map(|preset| { @@ -3314,7 +3385,9 @@ impl ThreadActor { async fn handle_set_model(&mut self, model: ModelId) -> Result<(), Error> { // Try parsing as preset format, otherwise use as-is, fallback to config - let (model_to_use, effort_to_use) = if let Some((m, e)) = Self::parse_model_id(&model) { + let parsed_model = Self::parse_model_id(&model); + let is_raw_model = parsed_model.is_none(); + let (model_to_use, effort_to_use) = if let Some((m, e)) = parsed_model { (m, Some(e)) } else { let model_str = model.0.to_string(); @@ -3347,6 +3420,10 @@ impl ThreadActor { .await .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + if is_raw_model { + self.remember_raw_model(model_to_use.clone()); + } + self.config.model = Some(model_to_use); self.config.model_reasoning_effort = effort_to_use; @@ -4051,9 +4128,12 @@ mod tests { use std::sync::atomic::AtomicUsize; use std::time::Duration; - use agent_client_protocol::{RequestPermissionResponse, TextContent}; + use agent_client_protocol::{ + RequestPermissionResponse, SessionConfigKind, SessionConfigSelectOptions, TextContent, + }; use codex_core::{config::ConfigOverrides, test_support::all_model_presets}; use codex_protocol::config_types::ModeKind; + use codex_protocol::openai_models::{ReasoningEffortPreset, default_input_modalities}; use tokio::{ sync::{Mutex, Notify, mpsc::UnboundedSender}, task::LocalSet, @@ -4557,6 +4637,92 @@ mod tests { Ok(()) } + #[tokio::test] + async fn raw_configured_model_stays_in_config_options_after_switching_models() + -> anyhow::Result<()> { + let older_model = test_model_preset("test-older-model", true); + let mut actor = setup_actor_with_models( + ConfigOverrides { + model: Some("test-new-model".to_owned()), + ..ConfigOverrides::default() + }, + vec![older_model.clone()], + ) + .await?; + + let model_options = model_config_option_values(&actor.config_options().await?); + assert!(model_options.contains(&"test-new-model".to_owned())); + + actor + .handle_set_config_model(SessionConfigValueId::new(older_model.id.clone())) + .await?; + + let model_options = model_config_option_values(&actor.config_options().await?); + assert!(model_options.contains(&"test-new-model".to_owned())); + assert!(model_options.contains(&older_model.id)); + + Ok(()) + } + + #[tokio::test] + async fn hidden_preset_model_does_not_stay_in_config_options_after_switching_models() + -> anyhow::Result<()> { + let hidden_model = test_model_preset("test-hidden-model", false); + let visible_model = test_model_preset("test-visible-model", true); + let mut actor = setup_actor_with_models( + ConfigOverrides { + model: Some(hidden_model.model.clone()), + ..ConfigOverrides::default() + }, + vec![hidden_model.clone(), visible_model.clone()], + ) + .await?; + + let model_options = model_config_option_values(&actor.config_options().await?); + assert!(model_options.contains(&hidden_model.id)); + + actor + .handle_set_config_model(SessionConfigValueId::new(visible_model.id.clone())) + .await?; + + let model_options = model_config_option_values(&actor.config_options().await?); + assert!(!model_options.contains(&hidden_model.id)); + assert!(model_options.contains(&visible_model.id)); + + Ok(()) + } + + #[tokio::test] + async fn raw_configured_model_stays_in_legacy_models_after_switching_models() + -> anyhow::Result<()> { + let older_model = test_model_preset("test-legacy-older-model", true); + let mut actor = setup_actor_with_models( + ConfigOverrides { + model: Some("test-legacy-new-model".to_owned()), + ..ConfigOverrides::default() + }, + vec![older_model.clone()], + ) + .await?; + + let model_ids = legacy_model_ids(&actor.models().await?); + assert!(model_ids.contains(&"test-legacy-new-model".to_owned())); + + actor + .handle_set_model(ModelId::new(format!( + "{}/{}", + older_model.id.as_str(), + ReasoningEffort::Medium + ))) + .await?; + + let model_ids = legacy_model_ids(&actor.models().await?); + assert!(model_ids.contains(&"test-legacy-new-model".to_owned())); + assert!(model_ids.contains(&format!("{}/{}", older_model.id, ReasoningEffort::Medium))); + + Ok(()) + } + async fn setup( custom_prompts: Vec, ) -> anyhow::Result<( @@ -4571,7 +4737,7 @@ mod tests { let session_client = SessionClient::with_client(session_id.clone(), client.clone(), Arc::default()); let conversation = Arc::new(StubCodexThread::new()); - let models_manager = Arc::new(StubModelsManager); + let models_manager = Arc::new(StubModelsManager::default()); let config = Config::load_with_cli_overrides_and_harness_overrides( vec![], ConfigOverrides::default(), @@ -4597,6 +4763,79 @@ mod tests { Ok((session_id, client, conversation, message_tx, local_set)) } + async fn setup_actor_with_models( + config_overrides: ConfigOverrides, + models: Vec, + ) -> anyhow::Result> { + let session_id = SessionId::new("test"); + let client = Arc::new(StubClient::new()); + let session_client = SessionClient::with_client(session_id, client.clone(), Arc::default()); + let conversation = Arc::new(StubCodexThread::new()); + let models_manager = Arc::new(StubModelsManager::new(models)); + let config = + Config::load_with_cli_overrides_and_harness_overrides(vec![], config_overrides).await?; + let (_message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel(); + let (resolution_tx, resolution_rx) = tokio::sync::mpsc::unbounded_channel(); + + Ok(ThreadActor::new( + StubAuth, + session_client, + conversation, + models_manager, + config, + message_rx, + resolution_tx, + resolution_rx, + )) + } + + fn test_model_preset(model: &str, show_in_picker: bool) -> ModelPreset { + ModelPreset { + id: model.to_owned(), + model: model.to_owned(), + display_name: model.to_owned(), + description: format!("{model} description"), + default_reasoning_effort: ReasoningEffort::Medium, + supported_reasoning_efforts: vec![ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "medium".to_owned(), + }], + supports_personality: false, + is_default: false, + upgrade: None, + show_in_picker, + availability_nux: None, + supported_in_api: true, + input_modalities: default_input_modalities(), + } + } + + fn model_config_option_values(config_options: &[SessionConfigOption]) -> Vec { + let model_option = config_options + .iter() + .find(|option| option.id.0.as_ref() == "model") + .expect("model config option"); + let SessionConfigKind::Select(select) = &model_option.kind else { + panic!("model config option should be a select"); + }; + let SessionConfigSelectOptions::Ungrouped(options) = &select.options else { + panic!("model config option should be ungrouped"); + }; + + options + .iter() + .map(|option| option.value.0.to_string()) + .collect() + } + + fn legacy_model_ids(models: &SessionModelState) -> Vec { + models + .available_models + .iter() + .map(|model| model.model_id.0.to_string()) + .collect() + } + struct StubAuth; impl Auth for StubAuth { @@ -4605,16 +4844,35 @@ mod tests { } } - struct StubModelsManager; + struct StubModelsManager { + models: Vec, + } + + impl StubModelsManager { + fn new(models: Vec) -> Self { + Self { models } + } + } + + impl Default for StubModelsManager { + fn default() -> Self { + Self::new(all_model_presets().to_owned()) + } + } #[async_trait::async_trait] impl ModelsManagerImpl for StubModelsManager { - async fn get_model(&self, _model_id: &Option) -> String { - all_model_presets()[0].to_owned().id + async fn get_model(&self, model_id: &Option) -> String { + model_id.clone().unwrap_or_else(|| { + self.models + .first() + .map(|model| model.model.clone()) + .unwrap_or_default() + }) } async fn list_models(&self) -> Vec { - all_model_presets().to_owned() + self.models.clone() } } @@ -4895,6 +5153,7 @@ mod tests { | Op::ResolveElicitation { .. } | Op::RequestPermissionsResponse { .. } | Op::PatchApproval { .. } + | Op::OverrideTurnContext { .. } | Op::Interrupt => {} Op::Shutdown => { if let Some(active_prompt_id) = self.active_prompt_id.lock().unwrap().take() { @@ -5430,7 +5689,7 @@ mod tests { let session_client = SessionClient::with_client(session_id.clone(), client.clone(), Arc::default()); let conversation = Arc::new(StubCodexThread::new()); - let models_manager = Arc::new(StubModelsManager); + let models_manager = Arc::new(StubModelsManager::default()); let config = Config::load_with_cli_overrides_and_harness_overrides( vec![], ConfigOverrides::default(),