diff --git a/README.md b/README.md index 34070a9..31d9ad1 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ This tool implements an ACP adapter around the Codex CLI, supporting: - /review-commit - /init - /compact + - /fast - /logout - Custom Prompts - Client MCP servers diff --git a/src/thread.rs b/src/thread.rs index 143206a..5d43dcc 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -38,7 +38,7 @@ use codex_protocol::{ ElicitationRequest, ElicitationRequestEvent, GuardianAssessmentAction, GuardianCommandSource, }, - config_types::TrustLevel, + config_types::{ServiceTier, TrustLevel}, dynamic_tools::{DynamicToolCallOutputContentItem, DynamicToolCallRequest}, error::CodexErr, mcp::CallToolResult, @@ -2920,6 +2920,13 @@ impl ThreadActor { "compact", "summarize conversation to prevent hitting the context limit", ), + AvailableCommand::new( + "fast", + "toggle Fast mode to enable fastest inference with increased plan usage", + ) + .input(AvailableCommandInput::Unstructured( + UnstructuredCommandInput::new("on|off|status"), + )), AvailableCommand::new("logout", "logout of Codex"), ] } @@ -3251,6 +3258,10 @@ impl ThreadActor { let op; if let Some((name, rest)) = extract_slash_command(&items) { match name { + "fast" => { + self.handle_fast_command(rest, response_tx).await?; + return Ok(response_rx); + } "compact" => op = Op::Compact, "init" => { op = Op::UserInput { @@ -3346,6 +3357,82 @@ impl ThreadActor { Ok(response_rx) } + async fn handle_fast_command( + &mut self, + rest: &str, + response_tx: oneshot::Sender<Result<StopReason, Error>>, + ) -> Result<(), Error> { + let arg = rest.trim(); + let tier = match arg { + "" => { + if self.fast_mode_enabled() { + None + } else { + Some(ServiceTier::Fast) + } + } + "on" => Some(ServiceTier::Fast), + "off" => None, + "status" => { + self.client.send_agent_text(format!( + "Fast mode is {}\n", + if self.fast_mode_enabled() { + "on" + } else { + "off" + } + )); + drop(response_tx.send(Ok(StopReason::EndTurn))); + return Ok(()); + } + _ => { + 
self.client + .send_agent_text("Usage: /fast [on|off|status]\n"); + drop(response_tx.send(Ok(StopReason::EndTurn))); + return Ok(()); + } + }; + + self.thread + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: None, + personality: None, + windows_sandbox_level: None, + service_tier: Some(tier.map(|tier| tier.request_value().to_string())), + approvals_reviewer: None, + permission_profile: None, + }) + .await + .map_err(|e| Error::from(anyhow::anyhow!(e)))?; + + self.config.service_tier = tier.map(|tier| tier.request_value().to_string()); + self.client.send_agent_text(format!( + "Fast mode set to {}\n", + if self.fast_mode_enabled() { + "on" + } else { + "off" + } + )); + drop(response_tx.send(Ok(StopReason::EndTurn))); + + Ok(()) + } + + fn fast_mode_enabled(&self) -> bool { + self.config + .service_tier + .as_deref() + .and_then(ServiceTier::from_request_value) + .is_some_and(|tier| tier == ServiceTier::Fast) + } + async fn handle_set_mode(&mut self, mode: SessionModeId) -> Result<(), Error> { let preset = APPROVAL_PRESETS .iter() @@ -4393,6 +4480,105 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_fast_on() -> anyhow::Result<()> { + let (session_id, client, thread, message_tx, _handle) = setup().await?; + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["/fast on".into()]), + response_tx: prompt_response_tx, + })?; + + let stop_reason = prompt_response_rx.await??.await??; + assert_eq!(stop_reason, StopReason::EndTurn); + drop(message_tx); + + let notifications = client.notifications.lock().unwrap(); + assert_eq!(notifications.len(), 1); + assert!(matches!( + &notifications[0].update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. 
+ }) if text == "Fast mode set to on\n" + )); + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.as_slice(), + [Op::OverrideTurnContext { + service_tier: Some(Some(service_tier)), + .. + }] if service_tier == ServiceTier::Fast.request_value() + )); + + Ok(()) + } + + #[tokio::test] + async fn test_fast_off() -> anyhow::Result<()> { + let (session_id, client, thread, message_tx, _handle) = setup().await?; + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["/fast off".into()]), + response_tx: prompt_response_tx, + })?; + + let stop_reason = prompt_response_rx.await??.await??; + assert_eq!(stop_reason, StopReason::EndTurn); + drop(message_tx); + + let notifications = client.notifications.lock().unwrap(); + assert_eq!(notifications.len(), 1); + assert!(matches!( + &notifications[0].update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) if text == "Fast mode set to off\n" + )); + let ops = thread.ops.lock().unwrap(); + assert!(matches!( + ops.as_slice(), + [Op::OverrideTurnContext { + service_tier: Some(None), + .. 
+ }] + )); + + Ok(()) + } + + #[tokio::test] + async fn test_fast_status() -> anyhow::Result<()> { + let (session_id, client, thread, message_tx, _handle) = setup().await?; + let (prompt_response_tx, prompt_response_rx) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["/fast status".into()]), + response_tx: prompt_response_tx, + })?; + + let stop_reason = prompt_response_rx.await??.await??; + assert_eq!(stop_reason, StopReason::EndTurn); + drop(message_tx); + + let notifications = client.notifications.lock().unwrap(); + assert_eq!(notifications.len(), 1); + assert!(matches!( + &notifications[0].update, + SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(TextContent { text, .. }), + .. + }) if text == "Fast mode is off\n" + )); + let ops = thread.ops.lock().unwrap(); + assert!(ops.is_empty()); + + Ok(()) + } + #[test] fn test_guardian_execve_summary_uses_argv_without_duplication() -> anyhow::Result<()> { let action = GuardianAssessmentAction::Execve { @@ -4761,11 +4947,12 @@ mod tests { SessionClient::with_client(session_id.clone(), client.clone(), Arc::default()); let conversation = Arc::new(StubCodexThread::new()); let models_manager = Arc::new(StubModelsManager); - let config = Config::load_with_cli_overrides_and_harness_overrides( + let mut config = Config::load_with_cli_overrides_and_harness_overrides( vec![], ConfigOverrides::default(), ) .await?; + config.service_tier = None; let (message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel(); let (resolution_tx, resolution_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -5130,6 +5317,7 @@ mod tests { .unwrap(); } Op::ExecApproval { .. } + | Op::OverrideTurnContext { .. } | Op::ResolveElicitation { .. } | Op::RequestPermissionsResponse { .. } | Op::PatchApproval { .. }