Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions codex-rs/core/tests/common/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ impl ResponsesRequest {
self.0.body_json().unwrap()
}

pub fn body_bytes(&self) -> Vec<u8> {
self.0.body.clone()
}

/// Returns all `input_text` spans from `message` inputs for the provided role.
pub fn message_input_texts(&self, role: &str) -> Vec<String> {
self.inputs_of_type("message")
Expand Down Expand Up @@ -701,33 +705,6 @@ pub async fn start_mock_server() -> MockServer {
server
}

// todo(aibrahim): remove this and use our search matching patterns directly
/// Get all POST requests to `/responses` endpoints from the mock server.
/// Filters out GET requests (e.g., `/models`) .
pub async fn get_responses_requests(server: &MockServer) -> Vec<wiremock::Request> {
server
.received_requests()
.await
.expect("mock server should not fail")
.into_iter()
.filter(|req| req.method == "POST" && req.url.path().ends_with("/responses"))
.collect()
}

// todo(aibrahim): remove this and use our search matching patterns directly
/// Get request bodies as JSON values from POST requests to `/responses` endpoints.
/// Filters out GET requests (e.g., `/models`) .
pub async fn get_responses_request_bodies(server: &MockServer) -> Vec<Value> {
get_responses_requests(server)
.await
.into_iter()
.map(|req| {
req.body_json::<Value>()
.expect("request body to be valid JSON")
})
.collect()
}

#[derive(Clone)]
pub struct FunctionCallResponseMocks {
pub function_call: ResponseMock,
Expand Down
16 changes: 14 additions & 2 deletions codex-rs/core/tests/common/test_codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ use tempfile::TempDir;
use wiremock::MockServer;

use crate::load_default_config_for_test;
use crate::responses::get_responses_request_bodies;
use crate::responses::start_mock_server;
use crate::streaming_sse::StreamingSseServer;
use crate::wait_for_event;
use wiremock::Match;
use wiremock::matchers::path_regex;

type ConfigMutator = dyn FnOnce(&mut Config) + Send;
type PreBuildHook = dyn FnOnce(&Path) + Send + 'static;
Expand Down Expand Up @@ -322,7 +323,18 @@ impl TestCodexHarness {
}

pub async fn request_bodies(&self) -> Vec<Value> {
get_responses_request_bodies(&self.server).await
let path_matcher = path_regex(".*/responses$");
self.server
.received_requests()
.await
.expect("mock server should not fail")
.into_iter()
.filter(|req| path_matcher.matches(req))
.map(|req| {
req.body_json::<Value>()
.expect("request body to be valid JSON")
})
.collect()
}

pub async fn function_call_output_value(&self, call_id: &str) -> Value {
Expand Down
159 changes: 66 additions & 93 deletions codex-rs/core/tests/suite/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use codex_otel::otel_manager::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Verbosity;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::WebSearchAction;
Expand All @@ -31,9 +32,9 @@ use codex_protocol::user_input::UserInput;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::get_responses_requests;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_failed;
use core_test_support::skip_if_no_network;
Expand Down Expand Up @@ -324,17 +325,7 @@ async fn includes_conversation_id_and_model_headers_in_request() {
// Mock server
let server = MockServer::start().await;

// First request – must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");

Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_completed("resp1")).await;

let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
Expand Down Expand Up @@ -373,24 +364,19 @@ async fn includes_conversation_id_and_model_headers_in_request() {

wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// get request from the server
let requests = get_responses_requests(&server).await;
let request = requests
.first()
.expect("expected POST request to /responses");
let request_conversation_id = request.headers.get("conversation_id").unwrap();
let request_authorization = request.headers.get("authorization").unwrap();
let request_originator = request.headers.get("originator").unwrap();

assert_eq!(
request_conversation_id.to_str().unwrap(),
conversation_id.to_string()
);
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
assert_eq!(
request_authorization.to_str().unwrap(),
"Bearer Test API Key"
);
let request = resp_mock.single_request();
assert_eq!(request.path(), "/v1/responses");
let request_conversation_id = request
.header("conversation_id")
.expect("conversation_id header");
let request_authorization = request
.header("authorization")
.expect("authorization header");
let request_originator = request.header("originator").expect("originator header");

assert_eq!(request_conversation_id, conversation_id.to_string());
assert_eq!(request_originator, "codex_cli_rs");
assert_eq!(request_authorization, "Bearer Test API Key");
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -451,17 +437,7 @@ async fn chatgpt_auth_sends_correct_request() {
// Mock server
let server = MockServer::start().await;

// First request – must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");

Mock::given(method("POST"))
.and(path("/api/codex/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_completed("resp1")).await;

let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/api/codex", server.uri())),
Expand Down Expand Up @@ -499,27 +475,24 @@ async fn chatgpt_auth_sends_correct_request() {

wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// get request from the server
let requests = get_responses_requests(&server).await;
let request = requests
.first()
.expect("expected POST request to /responses");
let request_conversation_id = request.headers.get("conversation_id").unwrap();
let request_authorization = request.headers.get("authorization").unwrap();
let request_originator = request.headers.get("originator").unwrap();
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
assert_eq!(request.path(), "/api/codex/responses");
let request_conversation_id = request
.header("conversation_id")
.expect("conversation_id header");
let request_authorization = request
.header("authorization")
.expect("authorization header");
let request_originator = request.header("originator").expect("originator header");
let request_chatgpt_account_id = request
.header("chatgpt-account-id")
.expect("chatgpt-account-id header");
let request_body = request.body_json();

assert_eq!(
request_conversation_id.to_str().unwrap(),
conversation_id.to_string()
);
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
assert_eq!(
request_authorization.to_str().unwrap(),
"Bearer Access Token"
);
assert_eq!(request_chatgpt_account_id.to_str().unwrap(), "account_id");
assert_eq!(request_conversation_id, conversation_id.to_string());
assert_eq!(request_originator, "codex_cli_rs");
assert_eq!(request_authorization, "Bearer Access Token");
assert_eq!(request_chatgpt_account_id, "account_id");
assert!(request_body["stream"].as_bool().unwrap());
assert_eq!(
request_body["include"][0].as_str().unwrap(),
Expand Down Expand Up @@ -1107,17 +1080,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
"data: {\"type\":\"response.created\",\"response\":{}}\n\n",
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\"}}\n\n",
);

let template = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_body, "text/event-stream");

Mock::given(method("POST"))
.and(path("/openai/responses"))
.respond_with(template)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_body.to_string()).await;

let provider = ModelProviderInfo {
name: "azure".into(),
Expand Down Expand Up @@ -1202,6 +1165,13 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
arguments: "{}".into(),
call_id: "function-call-id".into(),
});
prompt.input.push(ResponseItem::FunctionCallOutput {
call_id: "function-call-id".into(),
output: FunctionCallOutputPayload {
content: "ok".into(),
..Default::default()
},
});
prompt.input.push(ResponseItem::LocalShellCall {
id: Some("local-shell-id".into()),
call_id: Some("local-shell-call-id".into()),
Expand All @@ -1221,6 +1191,10 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
name: "custom_tool".into(),
input: "{}".into(),
});
prompt.input.push(ResponseItem::CustomToolCallOutput {
call_id: "custom-tool-call-id".into(),
output: "ok".into(),
});

let mut stream = client
.stream(&prompt)
Expand All @@ -1233,21 +1207,27 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
}
}

let requests = get_responses_requests(&server).await;
assert_eq!(requests.len(), 1, "expected a single POST request");
let body: serde_json::Value = requests[0]
.body_json()
.expect("request body to be valid JSON");
let request = resp_mock.single_request();
assert_eq!(request.path(), "/openai/responses");
let body = request.body_json();

assert_eq!(body["store"], serde_json::Value::Bool(true));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(6));
assert_eq!(body["input"].as_array().map(Vec::len), Some(8));
assert_eq!(body["input"][0]["id"].as_str(), Some("reasoning-id"));
assert_eq!(body["input"][1]["id"].as_str(), Some("message-id"));
assert_eq!(body["input"][2]["id"].as_str(), Some("web-search-id"));
assert_eq!(body["input"][3]["id"].as_str(), Some("function-id"));
assert_eq!(body["input"][4]["id"].as_str(), Some("local-shell-id"));
assert_eq!(body["input"][5]["id"].as_str(), Some("custom-tool-id"));
assert_eq!(
body["input"][4]["call_id"].as_str(),
Some("function-call-id")
);
assert_eq!(body["input"][5]["id"].as_str(), Some("local-shell-id"));
assert_eq!(body["input"][6]["id"].as_str(), Some("custom-tool-id"));
assert_eq!(
body["input"][7]["call_id"].as_str(),
Some("custom-tool-call-id")
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -1784,16 +1764,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
]"##;
let sse1 = core_test_support::load_sse_fixture_with_id_from_str(sse_raw, "resp1");

Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse1.clone(), "text/event-stream"),
)
.expect(3) // respond identically to the three sequential turns
.mount(&server)
.await;
let request_log = mount_sse_sequence(&server, vec![sse1.clone(), sse1.clone(), sse1]).await;

// Configure provider to point to mock server (Responses API) and use API key auth.
let model_provider = ModelProviderInfo {
Expand Down Expand Up @@ -1847,8 +1818,11 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// Inspect the three captured requests.
let requests = get_responses_requests(&server).await;
let requests = request_log.requests();
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
for request in &requests {
assert_eq!(request.path(), "/v1/responses");
}

// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
let r3_tail_expected = json!([
Expand Down Expand Up @@ -1880,8 +1854,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
]);

let r3_input_array = requests[2]
.body_json::<serde_json::Value>()
.unwrap()
.body_json()
.get("input")
.and_then(|v| v.as_array())
.cloned()
Expand Down
Loading
Loading