diff --git a/src/codex_agent.rs b/src/codex_agent.rs index 33891c0..dd03b31 100644 --- a/src/codex_agent.rs +++ b/src/codex_agent.rs @@ -534,9 +534,7 @@ impl Agent for CodexAgent { // Get the session state let thread = self.get_thread(&request.session_id)?; - let stop_reason = thread.prompt(request).await?; - - Ok(PromptResponse::new(stop_reason)) + thread.prompt_response(request).await } async fn cancel(&self, args: CancelNotification) -> Result<(), Error> { diff --git a/src/thread.rs b/src/thread.rs index 4b074cd..ca9c11b 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -12,11 +12,13 @@ use agent_client_protocol::{ ConfigOptionUpdate, Content, ContentBlock, ContentChunk, Diff, EmbeddedResource, EmbeddedResourceResource, Error, LoadSessionResponse, Meta, ModelId, ModelInfo, PermissionOption, PermissionOptionKind, Plan, PlanEntry, PlanEntryPriority, PlanEntryStatus, - PromptRequest, RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse, + PromptRequest, PromptResponse, RequestPermissionOutcome, RequestPermissionRequest, + RequestPermissionResponse, ResourceLink, SelectedPermissionOutcome, SessionConfigId, SessionConfigOption, SessionConfigOptionCategory, SessionConfigOptionValue, SessionConfigSelectOption, SessionConfigValueId, SessionId, SessionInfoUpdate, SessionMode, SessionModeId, SessionModeState, SessionModelState, SessionNotification, SessionUpdate, StopReason, Terminal, + Usage, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, UsageUpdate, }; @@ -51,8 +53,9 @@ use codex_protocol::{ Op, PatchApplyBeginEvent, PatchApplyEndEvent, PatchApplyStatus, ReasoningContentDeltaEvent, ReasoningRawContentDeltaEvent, ReviewDecision, ReviewOutputEvent, ReviewRequest, ReviewTarget, RolloutItem, SandboxPolicy, StreamErrorEvent, TerminalInteractionEvent, - TokenCountEvent, TurnAbortedEvent, TurnCompleteEvent, TurnStartedEvent, UserMessageEvent, - ViewImageToolCallEvent, WarningEvent, WebSearchBeginEvent, WebSearchEndEvent, + TokenCountEvent, TokenUsage, TurnAbortedEvent, TurnCompleteEvent, TurnStartedEvent, + UserMessageEvent, ViewImageToolCallEvent, WarningEvent, WebSearchBeginEvent, + WebSearchEndEvent, }, request_permissions::{ PermissionGrantScope, RequestPermissionsEvent, RequestPermissionsResponse, @@ -133,7 +136,8 @@ enum ThreadMessage { }, Prompt { request: PromptRequest, - response_tx: oneshot::Sender>, Error>>, + response_tx: + oneshot::Sender>, Error>>, }, SetMode { mode: SessionModeId, @@ -228,6 +232,12 @@ impl Thread { } pub async fn prompt(&self, request: PromptRequest) -> Result { + self.prompt_response(request) + .await + .map(|response| response.stop_reason) + } + + pub async fn prompt_response(&self, request: PromptRequest) -> Result { let (response_tx, response_rx) = oneshot::channel(); let message = ThreadMessage::Prompt { @@ -464,9 +474,11 @@ struct PromptState { resolution_tx: mpsc::UnboundedSender, pending_permission_interactions: HashMap, event_count: usize, - response_tx: Option>>, + response_tx: Option>>, seen_message_deltas: bool, seen_reasoning_deltas: bool, + token_usage_at_turn_start: Option, + latest_total_token_usage: Option, } impl PromptState { @@ -474,7 +486,8 @@ impl PromptState { submission_id: String, thread: Arc, resolution_tx: mpsc::UnboundedSender, - response_tx: oneshot::Sender>, + response_tx: oneshot::Sender>, + token_usage_at_turn_start: Option, ) -> Self { Self { submission_id, @@ -487,6 +500,8 @@ impl PromptState { response_tx: Some(response_tx), seen_message_deltas: false, seen_reasoning_deltas: false, + latest_total_token_usage: token_usage_at_turn_start.clone(), + token_usage_at_turn_start, } } @@ -680,9 +695,10 @@ impl PromptState { info!("Task started with context window of {turn_id} {model_context_window:?} {collaboration_mode_kind:?}"); } EventMsg::TokenCount(TokenCountEvent { info, .. }) => { - if let Some(info) = info - && let Some(size) = info.model_context_window { - let used = info.last_token_usage.tokens_in_context_window().max(0) as u64; + if let Some(info) = info { + self.latest_total_token_usage = Some(info.total_token_usage.clone()); + if let Some(size) = info.model_context_window { + let used = info.total_token_usage.tokens_in_context_window().max(0) as u64; client .send_notification(SessionUpdate::UsageUpdate(UsageUpdate::new( used, @@ -690,6 +706,7 @@ impl PromptState { ))) .await; } + } } EventMsg::ItemStarted(ItemStartedEvent { thread_id, turn_id, item }) => { info!("Item started with thread_id: {thread_id}, turn_id: {turn_id}, item: {item:?}"); @@ -893,7 +910,14 @@ impl PromptState { ); self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { - response_tx.send(Ok(StopReason::EndTurn)).ok(); + let response = PromptResponse::new(StopReason::EndTurn).usage( + turn_usage( + self.token_usage_at_turn_start.as_ref(), + self.latest_total_token_usage.as_ref(), + ) + .map(usage_from_token_usage), + ); + response_tx.send(Ok(response)).ok(); } } EventMsg::UndoStarted(event) => { @@ -940,14 +964,18 @@ impl PromptState { info!("Turn {turn_id:?} aborted: {reason:?}"); self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { - response_tx.send(Ok(StopReason::Cancelled)).ok(); + response_tx + .send(Ok(PromptResponse::new(StopReason::Cancelled))) + .ok(); } } EventMsg::ShutdownComplete => { info!("Agent shutting down"); self.abort_pending_interactions(); if let Some(response_tx) = self.response_tx.take() { - response_tx.send(Ok(StopReason::Cancelled)).ok(); + response_tx + .send(Ok(PromptResponse::new(StopReason::Cancelled))) + .ok(); } } EventMsg::ViewImageToolCall(ViewImageToolCallEvent { call_id, path }) => { @@ -2205,6 +2233,8 @@ struct ThreadActor { resolution_rx: mpsc::UnboundedReceiver, /// Last config options state we emitted to the client, used for deduping updates. last_sent_config_options: Option>, + /// Latest cumulative token usage observed for this session. + last_total_token_usage: Option, } impl ThreadActor { @@ -2231,6 +2261,7 @@ impl ThreadActor { message_rx, resolution_rx, last_sent_config_options: None, + last_total_token_usage: None, } } @@ -2319,7 +2350,9 @@ impl ThreadActor { request, response_tx, } => { - let result = self.handle_prompt(request).await; + let result = self + .handle_prompt(request, self.last_total_token_usage.clone()) + .await; drop(response_tx.send(result)); } ThreadMessage::SetMode { mode, response_tx } => { @@ -2768,7 +2801,8 @@ impl ThreadActor { async fn handle_prompt( &mut self, request: PromptRequest, - ) -> Result>, Error> { + token_usage_at_turn_start: Option, + ) -> Result>, Error> { let (response_tx, response_rx) = oneshot::channel(); let items = build_prompt_items(request.prompt); @@ -2871,6 +2905,7 @@ impl ThreadActor { self.thread.clone(), self.resolution_tx.clone(), response_tx, + token_usage_at_turn_start, )); self.submissions.insert(submission_id, state); @@ -3315,6 +3350,9 @@ impl ThreadActor { } async fn handle_event(&mut self, Event { id, msg }: Event) { + if let EventMsg::TokenCount(TokenCountEvent { info: Some(info), .. }) = &msg { + self.last_total_token_usage = Some(info.total_token_usage.clone()); + } if let Some(submission) = self.submissions.get_mut(&id) { submission.handle_event(&self.client, msg).await; } else { @@ -3323,6 +3361,34 @@ impl ThreadActor { } } +fn usage_from_token_usage(usage: TokenUsage) -> Usage { + Usage::new( + usage.total_tokens.max(0) as u64, + usage.input_tokens.max(0) as u64, + usage.output_tokens.max(0) as u64, + ) + .thought_tokens( + (usage.reasoning_output_tokens > 0).then_some(usage.reasoning_output_tokens as u64), + ) + .cached_read_tokens((usage.cached_input_tokens > 0).then_some(usage.cached_input_tokens as u64)) +} + +fn turn_usage(start: Option<&TokenUsage>, end: Option<&TokenUsage>) -> Option { + let end = end.cloned()?; + Some(match start { + Some(start) => TokenUsage { + input_tokens: (end.input_tokens - start.input_tokens).max(0), + cached_input_tokens: (end.cached_input_tokens - start.cached_input_tokens).max(0), + output_tokens: (end.output_tokens - start.output_tokens).max(0), + reasoning_output_tokens: (end.reasoning_output_tokens + - start.reasoning_output_tokens) + .max(0), + total_tokens: (end.total_tokens - start.total_tokens).max(0), + }, + None => end, + }) +} + fn build_prompt_items(prompt: Vec) -> Vec { prompt .into_iter() @@ -3488,8 +3554,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3512,6 +3578,66 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_prompt_usage_reports_turn_delta_and_cumulative_context() -> anyhow::Result<()> { + let (session_id, client, _, message_tx, local_set) = setup(vec![]).await?; + let (prompt_response_tx1, prompt_response_rx1) = tokio::sync::oneshot::channel(); + let (prompt_response_tx2, prompt_response_rx2) = tokio::sync::oneshot::channel(); + + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["usage-1".into()]), + response_tx: prompt_response_tx1, + })?; + message_tx.send(ThreadMessage::Prompt { + request: PromptRequest::new(session_id.clone(), vec!["usage-2".into()]), + response_tx: prompt_response_tx2, + })?; + + tokio::try_join!( + async { + let response1 = prompt_response_rx1.await??.await??; + assert_eq!(response1.stop_reason, StopReason::EndTurn); + let usage1 = response1.usage.expect("first response should include usage"); + assert_eq!(usage1.input_tokens, 100); + assert_eq!(usage1.output_tokens, 50); + assert_eq!(usage1.total_tokens, 175); + assert_eq!(usage1.thought_tokens, Some(5)); + assert_eq!(usage1.cached_read_tokens, Some(20)); + + let response2 = prompt_response_rx2.await??.await??; + assert_eq!(response2.stop_reason, StopReason::EndTurn); + let usage2 = response2.usage.expect("second response should include usage"); + assert_eq!(usage2.input_tokens, 30); + assert_eq!(usage2.output_tokens, 20); + assert_eq!(usage2.total_tokens, 33); + assert_eq!(usage2.thought_tokens, Some(3)); + assert_eq!(usage2.cached_read_tokens, Some(5)); + + drop(message_tx); + anyhow::Ok(()) + }, + async { + local_set.await; + anyhow::Ok(()) + } + )?; + + let notifications = client.notifications.lock().unwrap(); + let usage_updates: Vec<_> = notifications + .iter() + .filter_map(|notification| match ¬ification.update { + SessionUpdate::UsageUpdate(update) => Some(update.clone()), + _ => None, + }) + .collect(); + assert_eq!(usage_updates.len(), 2); + assert_eq!(usage_updates[0].used, 175); + assert_eq!(usage_updates[1].used, 208); + assert_eq!(usage_updates[1].size, 256_000); + + Ok(()) + } + #[tokio::test] async fn test_compact() -> anyhow::Result<()> { let (session_id, client, thread, message_tx, local_set) = setup(vec![]).await?; @@ -3524,8 +3650,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3562,8 +3688,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3612,8 +3738,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3662,8 +3788,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3717,8 +3843,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3772,8 +3898,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3829,8 +3955,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3891,8 +4017,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -3943,8 +4069,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -4148,6 +4274,64 @@ mod tests { last_agent_message: None, turn_id, })); + } else if prompt == "usage-1" || prompt == "usage-2" { + let (total_usage, last_usage) = if prompt == "usage-1" { + ( + TokenUsage { + input_tokens: 100, + cached_input_tokens: 20, + output_tokens: 50, + reasoning_output_tokens: 5, + total_tokens: 175, + }, + TokenUsage { + input_tokens: 100, + cached_input_tokens: 20, + output_tokens: 50, + reasoning_output_tokens: 5, + total_tokens: 175, + }, + ) + } else { + ( + TokenUsage { + input_tokens: 130, + cached_input_tokens: 25, + output_tokens: 70, + reasoning_output_tokens: 8, + total_tokens: 208, + }, + TokenUsage { + input_tokens: 30, + cached_input_tokens: 5, + output_tokens: 20, + reasoning_output_tokens: 3, + total_tokens: 33, + }, + ) + }; + self.op_tx + .send(Event { + id: id.to_string(), + msg: EventMsg::TokenCount(TokenCountEvent { + info: Some(codex_protocol::protocol::TokenUsageInfo { + total_token_usage: total_usage, + last_token_usage: last_usage, + model_context_window: Some(256_000), + }), + rate_limits: None, + }), + }) + .unwrap(); + self.op_tx + .send(Event { + id: id.to_string(), + msg: EventMsg::TurnComplete(TurnCompleteEvent { + last_agent_message: None, + turn_id: id.to_string(), + }), + }) + .unwrap(); } else if prompt == "approval-block" { self.op_tx .send(Event { @@ -4414,8 +4598,8 @@ mod tests { tokio::try_join!( async { - let stop_reason = prompt_response_rx.await??.await??; - assert_eq!(stop_reason, StopReason::EndTurn); + let response = prompt_response_rx.await??.await??; + assert_eq!(response.stop_reason, StopReason::EndTurn); drop(message_tx); anyhow::Ok(()) }, @@ -4502,6 +4686,7 @@ mod tests { thread.clone(), message_tx, response_tx, + None, ); prompt_state @@ -4589,6 +4774,7 @@ mod tests { thread.clone(), message_tx, response_tx, + None, ); prompt_state @@ -4651,8 +4837,13 @@ mod tests { let thread = Arc::new(StubCodexThread::new()); let (response_tx, _response_rx) = tokio::sync::oneshot::channel(); let (message_tx, _message_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut prompt_state = - PromptState::new("submission-id".to_string(), thread, message_tx, response_tx); + let mut prompt_state = PromptState::new( + "submission-id".to_string(), + thread, + message_tx, + response_tx, + None, + ); prompt_state .handle_event( @@ -4770,9 +4961,9 @@ mod tests { .await?; tokio::time::timeout(Duration::from_millis(100), thread.shutdown()).await??; - let stop_reason = + let response = tokio::time::timeout(Duration::from_millis(100), stop_reason_rx).await??; - assert_eq!(stop_reason?, StopReason::Cancelled); + assert_eq!(response?.stop_reason, StopReason::Cancelled); anyhow::Ok(()) })