Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tracing_subscriber::EnvFilter;

mod codex_agent;
mod local_spawner;
mod pricing;
mod prompt_args;
mod thread;

Expand Down
144 changes: 144 additions & 0 deletions src/pricing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//! Estimated per-token pricing in USD for OpenAI models.
//!
//! These rates are used to derive an approximate `UsageUpdate.cost` from token
//! counts when the upstream Codex event stream does not provide an authoritative
//! cost signal. Prices will need periodic updates as OpenAI changes its
//! pricing. Once <https://github.com/openai/codex/issues/16258> is resolved,
//! this module can be replaced with the upstream cost.

/// Per-million-token prices in USD.
pub(crate) struct ModelPricing {
    /// Regular (non-cached) input tokens.
    input: f64,
    /// Cached input tokens.
    cached_input: f64,
    /// Output tokens (includes reasoning tokens).
    output: f64,
}

impl ModelPricing {
    /// Build a pricing entry from per-million-token USD rates.
    const fn new(input: f64, cached_input: f64, output: f64) -> Self {
        Self {
            input,
            cached_input,
            output,
        }
    }

    /// Estimate cost in USD from per-turn token counts.
    ///
    /// `input_tokens` is the total input count reported by the API (which
    /// *includes* cached tokens). We subtract `cached_input_tokens` to get the
    /// non-cached portion that is billed at the regular input rate.
    /// `saturating_sub` guards against the API ever reporting more cached
    /// tokens than total input tokens, which would otherwise underflow.
    pub(crate) fn estimate_cost(
        &self,
        input_tokens: u64,
        cached_input_tokens: u64,
        output_tokens: u64,
    ) -> f64 {
        let uncached = input_tokens.saturating_sub(cached_input_tokens) as f64;
        let cached = cached_input_tokens as f64;
        let output = output_tokens as f64;

        (uncached * self.input + cached * self.cached_input + output * self.output) / 1_000_000.0
    }
}

// Prices per 1 M tokens (USD) — last updated 2025-04.
//
// Ordering invariant: within each model family the more specific (longer)
// slug must precede its shorter prefix — e.g. `gpt-4.1-mini` before
// `gpt-4.1` — because `prefix_match` returns the first `starts_with` hit.
// Entries from different families never prefix one another, so their
// relative order does not matter.
static PRICING_TABLE: &[(&str, ModelPricing)] = &[
    // (slug prefix, ModelPricing::new(input, cached_input, output))
    ("gpt-4.1-mini", ModelPricing::new(0.40, 0.10, 1.60)),
    ("gpt-4.1-nano", ModelPricing::new(0.10, 0.025, 0.40)),
    ("gpt-4.1", ModelPricing::new(2.00, 0.50, 8.00)),
    ("gpt-4o-mini", ModelPricing::new(0.15, 0.075, 0.60)),
    ("gpt-4o", ModelPricing::new(2.50, 1.25, 10.00)),
    ("o4-mini", ModelPricing::new(1.10, 0.275, 4.40)),
    ("o3-mini", ModelPricing::new(1.10, 0.55, 4.40)),
    ("o3", ModelPricing::new(2.00, 0.50, 8.00)),
];

/// Look up pricing for `model`.
///
/// Tries a prefix match against the full slug first, then strips a trailing
/// `-YYYY-MM-DD` date suffix (e.g. `gpt-4.1-2025-04-14` -> `gpt-4.1`) and
/// retries. Returns `None` when no table entry prefixes the slug.
pub(crate) fn lookup_pricing(model: &str) -> Option<&'static ModelPricing> {
    // Try prefix match against the table (most specific prefix listed first).
    if let Some(pricing) = prefix_match(model) {
        return Some(pricing);
    }

    // Strip a trailing date suffix and retry. Only strip when the three
    // trailing dash-separated segments are all numeric, so that non-date
    // suffixes (e.g. `-preview`) are left alone.
    //
    // NOTE(review): because `prefix_match` uses `starts_with`, stripping a
    // suffix cannot turn a miss into a hit today; this retry only matters if
    // matching ever becomes stricter (e.g. exact-match). Kept for robustness.
    if let Some(base) = strip_date_suffix(model) {
        return prefix_match(base);
    }

    None
}

/// If `slug` ends with `-YYYY-MM-DD` (three trailing, non-empty, all-digit,
/// dash-separated segments), return the slug with that suffix removed.
fn strip_date_suffix(slug: &str) -> Option<&str> {
    let mut rest = slug;
    for _ in 0..3 {
        let (left, segment) = rest.rsplit_once('-')?;
        if segment.is_empty() || !segment.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        rest = left;
    }
    Some(rest)
}

fn prefix_match(slug: &str) -> Option<&'static ModelPricing> {
    PRICING_TABLE
        .iter()
        .find(|(prefix, _)| slug.starts_with(prefix))
        .map(|(_, p)| p)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `model` resolves to a pricing entry whose regular input
    /// rate equals `expected_input` USD per million tokens.
    fn assert_input_rate(model: &str, expected_input: f64) {
        let pricing = lookup_pricing(model).unwrap();
        assert!((pricing.input - expected_input).abs() < f64::EPSILON);
    }

    #[test]
    fn exact_match() {
        assert_input_rate("gpt-4.1", 2.0);
    }

    #[test]
    fn date_suffix() {
        assert_input_rate("gpt-4.1-2025-04-14", 2.0);
    }

    #[test]
    fn mini_before_base() {
        assert_input_rate("gpt-4.1-mini", 0.40);
    }

    #[test]
    fn nano_match() {
        assert_input_rate("gpt-4.1-nano", 0.10);
    }

    #[test]
    fn o3_match() {
        assert_input_rate("o3", 2.0);
    }

    #[test]
    fn o4_mini_match() {
        assert_input_rate("o4-mini", 1.10);
    }

    #[test]
    fn unknown_model() {
        assert!(lookup_pricing("unknown-model-xyz").is_none());
    }

    #[test]
    fn cost_calculation() {
        // 2000 total input with 1000 cached leaves 1000 billed at the full
        // input rate, plus 500 output tokens:
        // (1000 * 2.0 + 1000 * 0.5 + 500 * 8.0) / 1_000_000 = 0.0065
        let pricing = lookup_pricing("gpt-4.1").unwrap();
        let cost = pricing.estimate_cost(2000, 1000, 500);
        assert!((cost - 0.0065).abs() < 1e-10);
    }
}
44 changes: 37 additions & 7 deletions src/thread.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
cell::RefCell,
cell::{Cell, RefCell},
collections::HashMap,
ops::DerefMut,
path::{Path, PathBuf},
Expand All @@ -18,7 +18,7 @@ use agent_client_protocol::{
SessionConfigValueId, SessionId, SessionInfoUpdate, SessionMode, SessionModeId,
SessionModeState, SessionModelState, SessionNotification, SessionUpdate, StopReason, Terminal,
TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus,
ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, UsageUpdate,
ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, Cost, UsageUpdate,
};
use codex_apply_patch::parse_patch;
use codex_core::{
Expand Down Expand Up @@ -467,6 +467,10 @@ struct PromptState {
response_tx: Option<oneshot::Sender<Result<StopReason, Error>>>,
seen_message_deltas: bool,
seen_reasoning_deltas: bool,
/// Shared cumulative cost accumulator (persists across prompts in the session).
cumulative_cost: Rc<Cell<f64>>,
/// Current model slug, updated on `ModelReroute` events.
current_model: String,
}

impl PromptState {
Expand All @@ -475,6 +479,8 @@ impl PromptState {
thread: Arc<dyn CodexThreadImpl>,
resolution_tx: mpsc::UnboundedSender<ThreadMessage>,
response_tx: oneshot::Sender<Result<StopReason, Error>>,
cumulative_cost: Rc<Cell<f64>>,
current_model: String,
) -> Self {
Self {
submission_id,
Expand All @@ -487,6 +493,8 @@ impl PromptState {
response_tx: Some(response_tx),
seen_message_deltas: false,
seen_reasoning_deltas: false,
cumulative_cost,
current_model,
}
}

Expand Down Expand Up @@ -683,11 +691,22 @@ impl PromptState {
if let Some(info) = info
&& let Some(size) = info.model_context_window {
let used = info.last_token_usage.tokens_in_context_window().max(0) as u64;

let mut update = UsageUpdate::new(used, size as u64);

// Estimate incremental cost from per-turn token counts.
let last = &info.last_token_usage;
if let Some(pricing) = crate::pricing::lookup_pricing(&self.current_model) {
let input = last.input_tokens.max(0) as u64;
let cached = last.cached_input_tokens.max(0) as u64;
let output = last.output_tokens.max(0) as u64;
let increment = pricing.estimate_cost(input, cached, output);
self.cumulative_cost.set(self.cumulative_cost.get() + increment);
update = update.cost(Cost::new(self.cumulative_cost.get(), "USD"));
}

client
.send_notification(SessionUpdate::UsageUpdate(UsageUpdate::new(
used,
size as u64,
)))
.send_notification(SessionUpdate::UsageUpdate(update))
.await;
}
}
Expand Down Expand Up @@ -1003,6 +1022,7 @@ impl PromptState {
}
EventMsg::ModelReroute(ModelRerouteEvent { from_model, to_model, reason }) => {
info!("Model reroute: from={from_model}, to={to_model}, reason={reason:?}");
self.current_model = to_model;
}

EventMsg::ContextCompacted(..) => {
Expand Down Expand Up @@ -2205,6 +2225,8 @@ struct ThreadActor<A> {
resolution_rx: mpsc::UnboundedReceiver<ThreadMessage>,
/// Last config options state we emitted to the client, used for deduping updates.
last_sent_config_options: Option<Vec<SessionConfigOption>>,
/// Cumulative estimated session cost in USD, shared with active `PromptState`.
cumulative_cost: Rc<Cell<f64>>,
}

impl<A: Auth> ThreadActor<A> {
Expand All @@ -2231,6 +2253,7 @@ impl<A: Auth> ThreadActor<A> {
message_rx,
resolution_rx,
last_sent_config_options: None,
cumulative_cost: Rc::new(Cell::new(0.0)),
}
}

Expand Down Expand Up @@ -2866,11 +2889,14 @@ impl<A: Auth> ThreadActor<A> {
info!("Submitted prompt with submission_id: {submission_id}");
info!("Starting to wait for conversation events for submission_id: {submission_id}");

let current_model = self.get_current_model().await;
let state = SubmissionState::Prompt(PromptState::new(
submission_id.clone(),
self.thread.clone(),
self.resolution_tx.clone(),
response_tx,
self.cumulative_cost.clone(),
current_model,
));

self.submissions.insert(submission_id, state);
Expand Down Expand Up @@ -4502,6 +4528,8 @@ mod tests {
thread.clone(),
message_tx,
response_tx,
Rc::new(Cell::new(0.0)),
String::new(),
);

prompt_state
Expand Down Expand Up @@ -4589,6 +4617,8 @@ mod tests {
thread.clone(),
message_tx,
response_tx,
Rc::new(Cell::new(0.0)),
String::new(),
);

prompt_state
Expand Down Expand Up @@ -4652,7 +4682,7 @@ mod tests {
let (response_tx, _response_rx) = tokio::sync::oneshot::channel();
let (message_tx, _message_rx) = tokio::sync::mpsc::unbounded_channel();
let mut prompt_state =
PromptState::new("submission-id".to_string(), thread, message_tx, response_tx);
PromptState::new("submission-id".to_string(), thread, message_tx, response_tx, Rc::new(Cell::new(0.0)), String::new());

prompt_state
.handle_event(
Expand Down