diff --git a/src/lib.rs b/src/lib.rs index 7da85ed..d9309b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use tracing_subscriber::EnvFilter; mod codex_agent; mod local_spawner; +mod pricing; mod prompt_args; mod thread; diff --git a/src/pricing.rs b/src/pricing.rs new file mode 100644 index 0000000..9d804a0 --- /dev/null +++ b/src/pricing.rs @@ -0,0 +1,144 @@ +/// Estimated per-token pricing in USD for OpenAI models. +/// +/// These rates are used to derive an approximate `UsageUpdate.cost` from token +/// counts when the upstream Codex event stream does not provide an authoritative +/// cost signal. Prices will need periodic updates as OpenAI changes its +/// pricing. Once upstream cost reporting is resolved, +/// this module can be replaced with the upstream cost. + +/// Per-million-token prices in USD. +pub(crate) struct ModelPricing { + /// Regular (non-cached) input tokens. + input: f64, + /// Cached input tokens. + cached_input: f64, + /// Output tokens (includes reasoning tokens). + output: f64, +} + +impl ModelPricing { + const fn new(input: f64, cached_input: f64, output: f64) -> Self { + Self { + input, + cached_input, + output, + } + } + + /// Estimate cost in USD from per-turn token counts. + /// + /// `input_tokens` is the total input count reported by the API (which + /// *includes* cached tokens). We subtract `cached_input_tokens` to get the + /// non-cached portion that is billed at the regular input rate. + pub(crate) fn estimate_cost( + &self, + input_tokens: u64, + cached_input_tokens: u64, + output_tokens: u64, + ) -> f64 { + let uncached = input_tokens.saturating_sub(cached_input_tokens) as f64; + let cached = cached_input_tokens as f64; + let output = output_tokens as f64; + + (uncached * self.input + cached * self.cached_input + output * self.output) / 1_000_000.0 + } +} + +// Prices per 1 M tokens (USD) — last updated 2025-04. +// Sorted longest-prefix-first so the first match wins. 
+static PRICING_TABLE: &[(&str, ModelPricing)] = &[ + ("gpt-4.1-mini", ModelPricing::new(0.40, 0.10, 1.60)), + ("gpt-4.1-nano", ModelPricing::new(0.10, 0.025, 0.40)), + ("gpt-4.1", ModelPricing::new(2.00, 0.50, 8.00)), + ("gpt-4o-mini", ModelPricing::new(0.15, 0.075, 0.60)), + ("gpt-4o", ModelPricing::new(2.50, 1.25, 10.00)), + ("o4-mini", ModelPricing::new(1.10, 0.275, 4.40)), + ("o3-mini", ModelPricing::new(1.10, 0.55, 4.40)), + ("o3", ModelPricing::new(2.00, 0.50, 8.00)), +]; + +/// Look up pricing for `model`. Tries the full slug first, then strips a +/// trailing date suffix (e.g. `-2025-04-14`) and retries, matching against +/// known prefixes. +pub(crate) fn lookup_pricing(model: &str) -> Option<&'static ModelPricing> { + // Try prefix match against the table (longest prefix listed first). + if let Some(pricing) = prefix_match(model) { + return Some(pricing); + } + + // Strip a trailing `-YYYY-MM-DD` date suffix and retry. + if let Some(base) = model.rsplit_once('-').and_then(|(left, _)| { + // Strips the last three dash-separated segments, assuming they form + // a `-YYYY-MM-DD` date suffix (segments are not validated as digits). 
+ left.rsplit_once('-') + .and_then(|(ll, _)| ll.rsplit_once('-').map(|(lll, _)| lll)) + }) { + return prefix_match(base); + } + + None +} + +fn prefix_match(slug: &str) -> Option<&'static ModelPricing> { + PRICING_TABLE + .iter() + .find(|(prefix, _)| slug.starts_with(prefix)) + .map(|(_, p)| p) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn exact_match() { + let p = lookup_pricing("gpt-4.1").unwrap(); + assert!((p.input - 2.0).abs() < f64::EPSILON); + } + + #[test] + fn date_suffix() { + let p = lookup_pricing("gpt-4.1-2025-04-14").unwrap(); + assert!((p.input - 2.0).abs() < f64::EPSILON); + } + + #[test] + fn mini_before_base() { + let p = lookup_pricing("gpt-4.1-mini").unwrap(); + assert!((p.input - 0.40).abs() < f64::EPSILON); + } + + #[test] + fn nano_match() { + let p = lookup_pricing("gpt-4.1-nano").unwrap(); + assert!((p.input - 0.10).abs() < f64::EPSILON); + } + + #[test] + fn o3_match() { + let p = lookup_pricing("o3").unwrap(); + assert!((p.input - 2.0).abs() < f64::EPSILON); + } + + #[test] + fn o4_mini_match() { + let p = lookup_pricing("o4-mini").unwrap(); + assert!((p.input - 1.10).abs() < f64::EPSILON); + } + + #[test] + fn unknown_model() { + assert!(lookup_pricing("unknown-model-xyz").is_none()); + } + + #[test] + fn cost_calculation() { + let p = lookup_pricing("gpt-4.1").unwrap(); + // 1000 uncached input (2000 total - 1000 cached), 1000 cached, 500 output + // = (1000 * 2.0 + 1000 * 0.5 + 500 * 8.0) / 1_000_000 + // = (2000 + 500 + 4000) / 1_000_000 + // = 0.0065 + let cost = p.estimate_cost(2000, 1000, 500); + assert!((cost - 0.0065).abs() < 1e-10); + } +} diff --git a/src/thread.rs b/src/thread.rs index 4b074cd..510bb25 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -1,5 +1,5 @@ use std::{ - cell::RefCell, + cell::{Cell, RefCell}, collections::HashMap, ops::DerefMut, path::{Path, PathBuf}, @@ -18,7 +18,7 @@ use agent_client_protocol::{ SessionConfigValueId, SessionId, SessionInfoUpdate, SessionMode, SessionModeId, 
SessionModeState, SessionModelState, SessionNotification, SessionUpdate, StopReason, Terminal, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, - ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, UsageUpdate, + ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, Cost, UsageUpdate, }; use codex_apply_patch::parse_patch; use codex_core::{ @@ -467,6 +467,10 @@ struct PromptState { response_tx: Option>>, seen_message_deltas: bool, seen_reasoning_deltas: bool, + /// Shared cumulative cost accumulator (persists across prompts in the session). + cumulative_cost: Rc>, + /// Current model slug, updated on `ModelReroute` events. + current_model: String, } impl PromptState { @@ -475,6 +479,8 @@ impl PromptState { thread: Arc, resolution_tx: mpsc::UnboundedSender, response_tx: oneshot::Sender>, + cumulative_cost: Rc>, + current_model: String, ) -> Self { Self { submission_id, @@ -487,6 +493,8 @@ impl PromptState { response_tx: Some(response_tx), seen_message_deltas: false, seen_reasoning_deltas: false, + cumulative_cost, + current_model, } } @@ -683,11 +691,22 @@ impl PromptState { if let Some(info) = info && let Some(size) = info.model_context_window { let used = info.last_token_usage.tokens_in_context_window().max(0) as u64; + + let mut update = UsageUpdate::new(used, size as u64); + + // Estimate incremental cost from per-turn token counts. 
+ let last = &info.last_token_usage; + if let Some(pricing) = crate::pricing::lookup_pricing(&self.current_model) { + let input = last.input_tokens.max(0) as u64; + let cached = last.cached_input_tokens.max(0) as u64; + let output = last.output_tokens.max(0) as u64; + let increment = pricing.estimate_cost(input, cached, output); + self.cumulative_cost.set(self.cumulative_cost.get() + increment); + update = update.cost(Cost::new(self.cumulative_cost.get(), "USD")); + } + client - .send_notification(SessionUpdate::UsageUpdate(UsageUpdate::new( - used, - size as u64, - ))) + .send_notification(SessionUpdate::UsageUpdate(update)) .await; } } @@ -1003,6 +1022,7 @@ impl PromptState { } EventMsg::ModelReroute(ModelRerouteEvent { from_model, to_model, reason }) => { info!("Model reroute: from={from_model}, to={to_model}, reason={reason:?}"); + self.current_model = to_model; } EventMsg::ContextCompacted(..) => { @@ -2205,6 +2225,8 @@ struct ThreadActor { resolution_rx: mpsc::UnboundedReceiver, /// Last config options state we emitted to the client, used for deduping updates. last_sent_config_options: Option>, + /// Cumulative estimated session cost in USD, shared with active `PromptState`. 
+ cumulative_cost: Rc>, } impl ThreadActor { @@ -2231,6 +2253,7 @@ impl ThreadActor { message_rx, resolution_rx, last_sent_config_options: None, + cumulative_cost: Rc::new(Cell::new(0.0)), } } @@ -2866,11 +2889,14 @@ impl ThreadActor { info!("Submitted prompt with submission_id: {submission_id}"); info!("Starting to wait for conversation events for submission_id: {submission_id}"); + let current_model = self.get_current_model().await; let state = SubmissionState::Prompt(PromptState::new( submission_id.clone(), self.thread.clone(), self.resolution_tx.clone(), response_tx, + self.cumulative_cost.clone(), + current_model, )); self.submissions.insert(submission_id, state); @@ -4502,6 +4528,8 @@ mod tests { thread.clone(), message_tx, response_tx, + Rc::new(Cell::new(0.0)), + String::new(), ); prompt_state @@ -4589,6 +4617,8 @@ mod tests { thread.clone(), message_tx, response_tx, + Rc::new(Cell::new(0.0)), + String::new(), ); prompt_state @@ -4652,7 +4682,7 @@ mod tests { let (response_tx, _response_rx) = tokio::sync::oneshot::channel(); let (message_tx, _message_rx) = tokio::sync::mpsc::unbounded_channel(); let mut prompt_state = - PromptState::new("submission-id".to_string(), thread, message_tx, response_tx); + PromptState::new("submission-id".to_string(), thread, message_tx, response_tx, Rc::new(Cell::new(0.0)), String::new()); prompt_state .handle_event(