Skip to content
Merged
27 changes: 0 additions & 27 deletions codex-rs/core/tests/common/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,33 +701,6 @@ pub async fn start_mock_server() -> MockServer {
server
}

// todo(aibrahim): remove this and use our search matching patterns directly
/// Get all POST requests to `/responses` endpoints from the mock server.
/// Filters out GET requests (e.g., `/models`) .
pub async fn get_responses_requests(server: &MockServer) -> Vec<wiremock::Request> {
server
.received_requests()
.await
.expect("mock server should not fail")
.into_iter()
.filter(|req| req.method == "POST" && req.url.path().ends_with("/responses"))
.collect()
}

// todo(aibrahim): remove this and use our search matching patterns directly
/// Get request bodies as JSON values from POST requests to `/responses` endpoints.
/// Filters out GET requests (e.g., `/models`) .
pub async fn get_responses_request_bodies(server: &MockServer) -> Vec<Value> {
get_responses_requests(server)
.await
.into_iter()
.map(|req| {
req.body_json::<Value>()
.expect("request body to be valid JSON")
})
.collect()
}

#[derive(Clone)]
pub struct FunctionCallResponseMocks {
pub function_call: ResponseMock,
Expand Down
16 changes: 14 additions & 2 deletions codex-rs/core/tests/common/test_codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ use tempfile::TempDir;
use wiremock::MockServer;

use crate::load_default_config_for_test;
use crate::responses::get_responses_request_bodies;
use crate::responses::start_mock_server;
use crate::streaming_sse::StreamingSseServer;
use crate::wait_for_event;
use wiremock::Match;
use wiremock::matchers::path_regex;

type ConfigMutator = dyn FnOnce(&mut Config) + Send;
type PreBuildHook = dyn FnOnce(&Path) + Send + 'static;
Expand Down Expand Up @@ -322,7 +323,18 @@ impl TestCodexHarness {
}

pub async fn request_bodies(&self) -> Vec<Value> {
get_responses_request_bodies(&self.server).await
let path_matcher = path_regex(".*/responses$");
self.server
.received_requests()
.await
.expect("mock server should not fail")
.into_iter()
.filter(|req| path_matcher.matches(req))
.map(|req| {
req.body_json::<Value>()
.expect("request body to be valid JSON")
})
.collect()
}

pub async fn function_call_output_value(&self, call_id: &str) -> Value {
Expand Down
133 changes: 43 additions & 90 deletions codex-rs/core/tests/suite/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use codex_protocol::user_input::UserInput;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::get_responses_requests;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_failed;
use core_test_support::skip_if_no_network;
Expand Down Expand Up @@ -324,17 +324,7 @@ async fn includes_conversation_id_and_model_headers_in_request() {
// Mock server
let server = MockServer::start().await;

// First request – must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");

Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_completed("resp1")).await;

let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
Expand Down Expand Up @@ -373,24 +363,19 @@ async fn includes_conversation_id_and_model_headers_in_request() {

wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// get request from the server
let requests = get_responses_requests(&server).await;
let request = requests
.first()
.expect("expected POST request to /responses");
let request_conversation_id = request.headers.get("conversation_id").unwrap();
let request_authorization = request.headers.get("authorization").unwrap();
let request_originator = request.headers.get("originator").unwrap();

assert_eq!(
request_conversation_id.to_str().unwrap(),
conversation_id.to_string()
);
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
assert_eq!(
request_authorization.to_str().unwrap(),
"Bearer Test API Key"
);
let request = resp_mock.single_request();
assert_eq!(request.path(), "/v1/responses");
let request_conversation_id = request
.header("conversation_id")
.expect("conversation_id header");
let request_authorization = request
.header("authorization")
.expect("authorization header");
let request_originator = request.header("originator").expect("originator header");

assert_eq!(request_conversation_id, conversation_id.to_string());
assert_eq!(request_originator, "codex_cli_rs");
assert_eq!(request_authorization, "Bearer Test API Key");
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -451,17 +436,7 @@ async fn chatgpt_auth_sends_correct_request() {
// Mock server
let server = MockServer::start().await;

// First request – must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");

Mock::given(method("POST"))
.and(path("/api/codex/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_completed("resp1")).await;

let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/api/codex", server.uri())),
Expand Down Expand Up @@ -499,27 +474,24 @@ async fn chatgpt_auth_sends_correct_request() {

wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// get request from the server
let requests = get_responses_requests(&server).await;
let request = requests
.first()
.expect("expected POST request to /responses");
let request_conversation_id = request.headers.get("conversation_id").unwrap();
let request_authorization = request.headers.get("authorization").unwrap();
let request_originator = request.headers.get("originator").unwrap();
let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap();
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
assert_eq!(request.path(), "/api/codex/responses");
let request_conversation_id = request
.header("conversation_id")
.expect("conversation_id header");
let request_authorization = request
.header("authorization")
.expect("authorization header");
let request_originator = request.header("originator").expect("originator header");
let request_chatgpt_account_id = request
.header("chatgpt-account-id")
.expect("chatgpt-account-id header");
let request_body = request.body_json();

assert_eq!(
request_conversation_id.to_str().unwrap(),
conversation_id.to_string()
);
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
assert_eq!(
request_authorization.to_str().unwrap(),
"Bearer Access Token"
);
assert_eq!(request_chatgpt_account_id.to_str().unwrap(), "account_id");
assert_eq!(request_conversation_id, conversation_id.to_string());
assert_eq!(request_originator, "codex_cli_rs");
assert_eq!(request_authorization, "Bearer Access Token");
assert_eq!(request_chatgpt_account_id, "account_id");
assert!(request_body["stream"].as_bool().unwrap());
assert_eq!(
request_body["include"][0].as_str().unwrap(),
Expand Down Expand Up @@ -1107,17 +1079,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
"data: {\"type\":\"response.created\",\"response\":{}}\n\n",
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\"}}\n\n",
);

let template = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_body, "text/event-stream");

Mock::given(method("POST"))
.and(path("/openai/responses"))
.respond_with(template)
.expect(1)
.mount(&server)
.await;
let resp_mock = mount_sse_once(&server, sse_body.to_string()).await;

let provider = ModelProviderInfo {
name: "azure".into(),
Expand Down Expand Up @@ -1233,11 +1195,9 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
}
}

let requests = get_responses_requests(&server).await;
assert_eq!(requests.len(), 1, "expected a single POST request");
let body: serde_json::Value = requests[0]
.body_json()
.expect("request body to be valid JSON");
let request = resp_mock.single_request();
assert_eq!(request.path(), "/openai/responses");
let body = request.body_json();

assert_eq!(body["store"], serde_json::Value::Bool(true));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
Expand Down Expand Up @@ -1784,16 +1744,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
]"##;
let sse1 = core_test_support::load_sse_fixture_with_id_from_str(sse_raw, "resp1");

Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse1.clone(), "text/event-stream"),
)
.expect(3) // respond identically to the three sequential turns
.mount(&server)
.await;
let request_log = mount_sse_sequence(&server, vec![sse1.clone(), sse1.clone(), sse1]).await;

// Configure provider to point to mock server (Responses API) and use API key auth.
let model_provider = ModelProviderInfo {
Expand Down Expand Up @@ -1847,8 +1798,11 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

// Inspect the three captured requests.
let requests = get_responses_requests(&server).await;
let requests = request_log.requests();
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
for request in &requests {
assert_eq!(request.path(), "/v1/responses");
}

// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
let r3_tail_expected = json!([
Expand Down Expand Up @@ -1880,8 +1834,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() {
]);

let r3_input_array = requests[2]
.body_json::<serde_json::Value>()
.unwrap()
.body_json()
.get("input")
.and_then(|v| v.as_array())
.cloned()
Expand Down
Loading
Loading