diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ab7df33fa41..02635c1ba33 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1924,8 +1924,10 @@ dependencies = [ "anyhow", "filedescriptor", "lazy_static", + "libc", "log", "portable-pty", + "pretty_assertions", "shared_library", "tokio", "winapi", diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index 52a28d57533..2a918732383 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -32,6 +32,7 @@ use crate::sandboxing::SandboxPermissions; use crate::spawn::StdioPolicy; use crate::spawn::spawn_child_async; use crate::text_encoding::bytes_to_string_smart; +use codex_utils_pty::process_group::kill_child_process_group; pub const DEFAULT_EXEC_COMMAND_TIMEOUT_MS: u64 = 10_000; @@ -750,38 +751,6 @@ fn synthetic_exit_status(code: i32) -> ExitStatus { std::process::ExitStatus::from_raw(code as u32) } -#[cfg(unix)] -fn kill_child_process_group(child: &mut Child) -> io::Result<()> { - use std::io::ErrorKind; - - if let Some(pid) = child.id() { - let pid = pid as libc::pid_t; - let pgid = unsafe { libc::getpgid(pid) }; - if pgid == -1 { - let err = std::io::Error::last_os_error(); - if err.kind() != ErrorKind::NotFound { - return Err(err); - } - return Ok(()); - } - - let result = unsafe { libc::killpg(pgid, libc::SIGKILL) }; - if result == -1 { - let err = std::io::Error::last_os_error(); - if err.kind() != ErrorKind::NotFound { - return Err(err); - } - } - } - - Ok(()) -} - -#[cfg(not(unix))] -fn kill_child_process_group(_: &mut Child) -> io::Result<()> { - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/codex-rs/core/src/spawn.rs b/codex-rs/core/src/spawn.rs index 48c57c2cd6e..c1a8d4457a1 100644 --- a/codex-rs/core/src/spawn.rs +++ b/codex-rs/core/src/spawn.rs @@ -70,8 +70,8 @@ pub(crate) async fn spawn_child_async( #[cfg(target_os = "linux")] let parent_pid = libc::getpid(); cmd.pre_exec(move || { - if set_process_group && libc::setpgid(0, 0) == -1 { - return Err(std::io::Error::last_os_error()); + if set_process_group { + codex_utils_pty::process_group::set_process_group()?; } // This relies on prctl(2), so it only works on Linux. @@ -79,18 +79,7 @@ pub(crate) async fn spawn_child_async( { // This prctl call effectively requests, "deliver SIGTERM when my // current parent dies." - if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 { - return Err(std::io::Error::last_os_error()); - } - - // Though if there was a race condition and this pre_exec() block is - // run _after_ the parent (i.e., the Codex process) has already - // exited, then parent will be the closest configured "subreaper" - // ancestor process, or PID 1 (init). If the Codex process has exited - // already, so should the child process. - if libc::getppid() != parent_pid { - libc::raise(libc::SIGTERM); - } + codex_utils_pty::process_group::set_parent_death_signal(parent_pid)?; } Ok(()) }); diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index 7769f262a9e..d90ed3b644f 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -33,6 +33,8 @@ struct ExecCommandArgs { shell: Option, #[serde(default = "default_login")] login: bool, + #[serde(default = "default_tty")] + tty: bool, #[serde(default = "default_exec_yield_time_ms")] yield_time_ms: u64, #[serde(default)] @@ -67,6 +69,10 @@ fn default_login() -> bool { true } +fn default_tty() -> bool { + false +} + #[async_trait] impl ToolHandler for UnifiedExecHandler { fn kind(&self) -> ToolKind { @@ -124,6 +130,7 @@ impl ToolHandler for UnifiedExecHandler { let ExecCommandArgs { workdir, + tty, yield_time_ms, max_output_tokens, sandbox_permissions, @@ -173,6 +180,7 @@ impl ToolHandler for UnifiedExecHandler { yield_time_ms, max_output_tokens, workdir, + tty, sandbox_permissions, justification, }, diff --git a/codex-rs/core/src/tools/runtimes/unified_exec.rs b/codex-rs/core/src/tools/runtimes/unified_exec.rs index 58ca66fcbca..2505b10ed20 100644 --- a/codex-rs/core/src/tools/runtimes/unified_exec.rs +++ b/codex-rs/core/src/tools/runtimes/unified_exec.rs @@ -37,6 +37,7 @@ pub struct UnifiedExecRequest { pub command: Vec, pub cwd: PathBuf, pub env: HashMap, + pub tty: bool, pub sandbox_permissions: SandboxPermissions, pub justification: Option, pub exec_approval_requirement: ExecApprovalRequirement, @@ -46,6 +47,7 @@ pub struct UnifiedExecRequest { pub struct UnifiedExecApprovalKey { pub command: Vec, pub cwd: PathBuf, + pub tty: bool, pub sandbox_permissions: SandboxPermissions, } @@ -58,6 +60,7 @@ impl UnifiedExecRequest { command: Vec, cwd: PathBuf, env: HashMap, + tty: bool, sandbox_permissions: SandboxPermissions, justification: Option, exec_approval_requirement: ExecApprovalRequirement, @@ -66,6 +69,7 @@ impl UnifiedExecRequest { command, cwd, env, + tty, sandbox_permissions, justification, exec_approval_requirement, @@ -96,6 +100,7 @@ impl Approvable for UnifiedExecRuntime<'_> { vec![UnifiedExecApprovalKey { command: req.command.clone(), cwd: req.cwd.clone(), + tty: req.tty, sandbox_permissions: req.sandbox_permissions, }] } @@ -189,7 +194,7 @@ impl<'a> ToolRuntime for UnifiedExecRunt .env_for(spec) .map_err(|err| ToolError::Codex(err.into()))?; self.manager - .open_session_with_exec_env(&exec_env) + .open_session_with_exec_env(&exec_env, req.tty) .await .map_err(|err| match err { UnifiedExecError::SandboxDenied { output, .. } => { diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 0a66b414039..54369c7e4b9 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -168,6 +168,15 @@ fn create_exec_command_tool() -> ToolSpec { ), }, ), + ( + "tty".to_string(), + JsonSchema::Boolean { + description: Some( + "Whether to allocate a TTY for the command. Defaults to false (plain pipes); set to true to open a PTY and access TTY process." + .to_string(), + ), + } + ), ( "yield_time_ms".to_string(), JsonSchema::Number { diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index ae10054079f..d274932e8e1 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -77,6 +77,7 @@ pub(crate) struct ExecCommandRequest { pub yield_time_ms: u64, pub max_output_tokens: Option, pub workdir: Option, + pub tty: bool, pub sandbox_permissions: SandboxPermissions, pub justification: Option, } @@ -200,6 +201,7 @@ mod tests { yield_time_ms, max_output_tokens: None, workdir: None, + tty: true, sandbox_permissions: SandboxPermissions::UseDefault, justification: None, }, diff --git a/codex-rs/core/src/unified_exec/process_manager.rs b/codex-rs/core/src/unified_exec/process_manager.rs index 2e80156114f..f1c06388e15 100644 --- a/codex-rs/core/src/unified_exec/process_manager.rs +++ b/codex-rs/core/src/unified_exec/process_manager.rs @@ -131,6 +131,7 @@ impl UnifiedExecProcessManager { cwd.clone(), request.sandbox_permissions, request.justification, + request.tty, context, ) .await; @@ -471,21 +472,34 @@ impl UnifiedExecProcessManager { pub(crate) async fn open_session_with_exec_env( &self, env: &ExecEnv, + tty: bool, ) -> Result { let (program, args) = env .command .split_first() .ok_or(UnifiedExecError::MissingCommandLine)?; - let spawned = codex_utils_pty::spawn_pty_process( - program, - args, - env.cwd.as_path(), - &env.env, - &env.arg0, - ) - .await - .map_err(|err| UnifiedExecError::create_process(err.to_string()))?; + let spawn_result = if tty { + codex_utils_pty::pty::spawn_process( + program, + args, + env.cwd.as_path(), + &env.env, + &env.arg0, + ) + .await + } else { + codex_utils_pty::pipe::spawn_process( + program, + args, + env.cwd.as_path(), + &env.env, + &env.arg0, + ) + .await + }; + let spawned = + spawn_result.map_err(|err| UnifiedExecError::create_process(err.to_string()))?; UnifiedExecProcess::from_spawned(spawned, env.sandbox).await } @@ -495,6 +509,7 @@ impl UnifiedExecProcessManager { cwd: PathBuf, sandbox_permissions: SandboxPermissions, justification: Option, + tty: bool, context: &UnifiedExecContext, ) -> Result { let env = apply_unified_exec_env(create_env(&context.turn.shell_environment_policy)); @@ -517,6 +532,7 @@ impl UnifiedExecProcessManager { command.to_vec(), cwd, env, + tty, sandbox_permissions, justification, exec_approval_requirement, diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index dc6d47fe710..c8720025cc4 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -657,7 +657,13 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { .await; let expected_env_value = "propagated-env-http"; - let rmcp_http_server_bin = cargo_bin("test_streamable_http_server")?; + let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") { + Ok(path) => path, + Err(err) => { + eprintln!("test_streamable_http_server binary not available, skipping test: {err}"); + return Ok(()); + } + }; let listener = TcpListener::bind("127.0.0.1:0")?; let port = listener.local_addr()?.port(); @@ -819,7 +825,13 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { let expected_token = "initial-access-token"; let client_id = "test-client-id"; let refresh_token = "initial-refresh-token"; - let rmcp_http_server_bin = cargo_bin("test_streamable_http_server")?; + let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") { + Ok(path) => path, + Err(err) => { + eprintln!("test_streamable_http_server binary not available, skipping test: {err}"); + return Ok(()); + } + }; let listener = TcpListener::bind("127.0.0.1:0")?; let port = listener.local_addr()?.port(); diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 2e2fd34e67e..7ad97416878 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -37,6 +37,7 @@ use regex_lite::Regex; use serde_json::Value; use serde_json::json; use tokio::time::Duration; +use which::which; fn extract_output_text(item: &Value) -> Option<&str> { item.get("output").and_then(|value| match value { @@ -1287,6 +1288,180 @@ async fn exec_command_reports_chunk_and_exit_metadata() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_defaults_to_pipe() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + skip_if_windows!(Ok(())); + + let python = match which("python").or_else(|_| which("python3")) { + Ok(path) => path, + Err(_) => { + eprintln!("python not found in PATH, skipping tty default test."); + return Ok(()); + } + }; + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "uexec-default-pipe"; + let args = serde_json::json!({ + "cmd": format!("{} -c \"import sys; print(sys.stdin.isatty())\"", python.display()), + "yield_time_ms": 1500, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + let request_log = mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "check default pipe mode".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + + let requests = request_log.requests(); + assert!(!requests.is_empty(), "expected at least one POST request"); + let bodies = requests + .into_iter() + .map(|request| request.body_json()) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + let output = outputs + .get(call_id) + .expect("missing default pipe unified exec output"); + let normalized = output.output.replace("\r\n", "\n"); + + assert!( + normalized.contains("False"), + "stdin should not be a tty by default: {normalized:?}" + ); + assert_eq!(output.exit_code, Some(0)); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_can_enable_tty() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + skip_if_windows!(Ok(())); + + let python = match which("python").or_else(|_| which("python3")) { + Ok(path) => path, + Err(_) => { + eprintln!("python not found in PATH, skipping tty enable test."); + return Ok(()); + } + }; + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "uexec-tty-enabled"; + let args = serde_json::json!({ + "cmd": format!("{} -c \"import sys; print(sys.stdin.isatty())\"", python.display()), + "yield_time_ms": 1500, + "tty": true, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + let request_log = mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "check tty enabled".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + + let requests = request_log.requests(); + assert!(!requests.is_empty(), "expected at least one POST request"); + let bodies = requests + .into_iter() + .map(|request| request.body_json()) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + let output = outputs + .get(call_id) + .expect("missing tty-enabled unified exec output"); + let normalized = output.output.replace("\r\n", "\n"); + + assert!( + normalized.contains("True"), + "stdin should be a tty when tty=true: {normalized:?}" + ); + assert_eq!(output.exit_code, Some(0)); + assert!(output.process_id.is_none(), "process should have exited"); + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn unified_exec_respects_early_exit_notifications() -> Result<()> { skip_if_no_network!(Ok(())); @@ -1404,6 +1579,7 @@ async fn write_stdin_returns_exit_metadata_and_clears_session() -> Result<()> { let start_args = serde_json::json!({ "cmd": "/bin/cat", "yield_time_ms": 500, + "tty": true, }); let send_args = serde_json::json!({ "chars": "hello unified exec\n", @@ -1563,6 +1739,7 @@ async fn unified_exec_emits_end_event_when_session_dies_via_stdin() -> Result<() let start_args = serde_json::json!({ "cmd": "/bin/cat", "yield_time_ms": 200, + "tty": true, }); let echo_call_id = "uexec-end-on-exit-echo"; @@ -1903,6 +2080,7 @@ PY let first_args = serde_json::json!({ "cmd": script, "yield_time_ms": 25, + "tty": true, }); let second_call_id = "uexec-lag-poll"; @@ -2285,6 +2463,7 @@ async fn unified_exec_python_prompt_under_seatbelt() -> Result<()> { let startup_args = serde_json::json!({ "cmd": format!("{} -i", python.display()), "yield_time_ms": 1_500, + "tty": true, }); let exit_call_id = "uexec-python-exit"; diff --git a/codex-rs/utils/pty/Cargo.toml b/codex-rs/utils/pty/Cargo.toml index 1a460ea3deb..ec221bfc94a 100644 --- a/codex-rs/utils/pty/Cargo.toml +++ b/codex-rs/utils/pty/Cargo.toml @@ -10,7 +10,10 @@ workspace = true [dependencies] anyhow = { workspace = true } portable-pty = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync", "time"] } +tokio = { workspace = true, features = ["io-util", "macros", "process", "rt-multi-thread", "sync", "time"] } + +[dev-dependencies] +pretty_assertions = { workspace = true } [target.'cfg(windows)'.dependencies] filedescriptor = "0.8.3" @@ -27,3 +30,5 @@ winapi = { version = "0.3.9", features = [ "winerror", "winnt", ] } +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } diff --git a/codex-rs/utils/pty/README.md b/codex-rs/utils/pty/README.md new file mode 100644 index 00000000000..22f2e3a89b6 --- /dev/null +++ b/codex-rs/utils/pty/README.md @@ -0,0 +1,56 @@ +# codex-utils-pty + +Lightweight helpers for spawning interactive processes either under a PTY (pseudo terminal) or regular pipes. The public API is minimal and mirrors both backends so callers can switch based on their needs (e.g., enabling or disabling TTY). + +## API surface + +- `spawn_pty_process(program, args, cwd, env, arg0)` → `SpawnedProcess` +- `spawn_pipe_process(program, args, cwd, env, arg0)` → `SpawnedProcess` +- `conpty_supported()` → `bool` (Windows only; always true elsewhere) +- `ProcessHandle` exposes: + - `writer_sender()` → `mpsc::Sender>` (stdin) + - `output_receiver()` → `broadcast::Receiver>` (stdout/stderr merged) + - `has_exited()`, `exit_code()`, `terminate()` +- `SpawnedProcess` bundles `handle`, `output_rx`, and `exit_rx` (oneshot exit code). + +## Usage examples + +```rust +use std::collections::HashMap; +use std::path::Path; +use codex_utils_pty::spawn_pty_process; + +# tokio_test::block_on(async { +let env_map: HashMap = std::env::vars().collect(); +let spawned = spawn_pty_process( + "bash", + &["-lc".into(), "echo hello".into()], + Path::new("."), + &env_map, + &None, +).await?; + +let writer = spawned.session.writer_sender(); +writer.send(b"exit\n".to_vec()).await?; + +// Collect output until the process exits. +let mut output_rx = spawned.output_rx; +let mut collected = Vec::new(); +while let Ok(chunk) = output_rx.try_recv() { + collected.extend_from_slice(&chunk); +} +let exit_code = spawned.exit_rx.await.unwrap_or(-1); +# let _ = (collected, exit_code); +# anyhow::Ok(()) +# }); +``` + +Swap in `spawn_pipe_process` for a non-TTY subprocess; the rest of the API stays the same. + +## Tests + +Unit tests live in `src/lib.rs` and cover both backends (PTY Python REPL and pipe-based stdin roundtrip). Run with: + +``` +cargo test -p codex-utils-pty -- --nocapture +``` diff --git a/codex-rs/utils/pty/src/lib.rs b/codex-rs/utils/pty/src/lib.rs index 9e05416ed81..037b0d761ca 100644 --- a/codex-rs/utils/pty/src/lib.rs +++ b/codex-rs/utils/pty/src/lib.rs @@ -1,271 +1,23 @@ -use core::fmt; -use std::collections::HashMap; -use std::io::ErrorKind; -use std::path::Path; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; -use std::sync::Mutex as StdMutex; -use std::time::Duration; - +pub mod pipe; +mod process; +pub mod process_group; +pub mod pty; +#[cfg(test)] +mod tests; #[cfg(windows)] mod win; -use anyhow::Result; -#[cfg(not(windows))] -use portable_pty::native_pty_system; -use portable_pty::CommandBuilder; -use portable_pty::MasterPty; -use portable_pty::PtySize; -use portable_pty::SlavePty; -use tokio::sync::broadcast; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio::sync::Mutex as TokioMutex; -use tokio::task::JoinHandle; - -pub struct PtyPairWrapper { - pub _slave: Option>, - pub _master: Box, -} - -#[derive(Debug)] -pub struct ExecCommandSession { - writer_tx: mpsc::Sender>, - output_tx: broadcast::Sender>, - killer: StdMutex>>, - reader_handle: StdMutex>>, - writer_handle: StdMutex>>, - wait_handle: StdMutex>>, - exit_status: Arc, - exit_code: Arc>>, - // PtyPair must be preserved because the process will receive Control+C if the - // slave is closed - _pair: StdMutex, -} - -impl fmt::Debug for PtyPairWrapper { - fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { - Ok(()) - } -} - -impl ExecCommandSession { - #[allow(clippy::too_many_arguments)] - pub fn new( - writer_tx: mpsc::Sender>, - output_tx: broadcast::Sender>, - initial_output_rx: broadcast::Receiver>, - killer: Box, - reader_handle: JoinHandle<()>, - writer_handle: JoinHandle<()>, - wait_handle: JoinHandle<()>, - exit_status: Arc, - exit_code: Arc>>, - pair: PtyPairWrapper, - ) -> (Self, broadcast::Receiver>) { - ( - Self { - writer_tx, - output_tx, - killer: StdMutex::new(Some(killer)), - reader_handle: StdMutex::new(Some(reader_handle)), - writer_handle: StdMutex::new(Some(writer_handle)), - wait_handle: StdMutex::new(Some(wait_handle)), - exit_status, - exit_code, - _pair: StdMutex::new(pair), - }, - initial_output_rx, - ) - } - - pub fn writer_sender(&self) -> mpsc::Sender> { - self.writer_tx.clone() - } - - pub fn output_receiver(&self) -> broadcast::Receiver> { - self.output_tx.subscribe() - } - - pub fn has_exited(&self) -> bool { - self.exit_status.load(std::sync::atomic::Ordering::SeqCst) - } - - pub fn exit_code(&self) -> Option { - self.exit_code.lock().ok().and_then(|guard| *guard) - } - - pub fn terminate(&self) { - if let Ok(mut killer_opt) = self.killer.lock() { - if let Some(mut killer) = killer_opt.take() { - let _ = killer.kill(); - } - } - - if let Ok(mut h) = self.reader_handle.lock() { - if let Some(handle) = h.take() { - handle.abort(); - } - } - if let Ok(mut h) = self.writer_handle.lock() { - if let Some(handle) = h.take() { - handle.abort(); - } - } - if let Ok(mut h) = self.wait_handle.lock() { - if let Some(handle) = h.take() { - handle.abort(); - } - } - } -} - -impl Drop for ExecCommandSession { - fn drop(&mut self) { - self.terminate(); - } -} - -#[derive(Debug)] -pub struct SpawnedPty { - pub session: ExecCommandSession, - pub output_rx: broadcast::Receiver>, - pub exit_rx: oneshot::Receiver, -} - -#[allow(unreachable_code)] -pub fn conpty_supported() -> bool { - // Annotation required because `win` can't be compiled on other OS. - #[cfg(windows)] - return win::conpty_supported(); - - true -} - -#[cfg(windows)] -fn platform_native_pty_system() -> Box { - Box::new(win::ConPtySystem::default()) -} - -#[cfg(not(windows))] -fn platform_native_pty_system() -> Box { - native_pty_system() -} - -pub async fn spawn_pty_process( - program: &str, - args: &[String], - cwd: &Path, - env: &HashMap, - arg0: &Option, -) -> Result { - if program.is_empty() { - anyhow::bail!("missing program for PTY spawn"); - } - - let pty_system = platform_native_pty_system(); - let pair = pty_system.openpty(PtySize { - rows: 24, - cols: 80, - pixel_width: 0, - pixel_height: 0, - })?; - - let mut command_builder = CommandBuilder::new(arg0.as_ref().unwrap_or(&program.to_string())); - command_builder.cwd(cwd); - command_builder.env_clear(); - for arg in args { - command_builder.arg(arg); - } - for (key, value) in env { - command_builder.env(key, value); - } - - let mut child = pair.slave.spawn_command(command_builder)?; - let killer = child.clone_killer(); - - let (writer_tx, mut writer_rx) = mpsc::channel::>(128); - let (output_tx, _) = broadcast::channel::>(256); - // Subscribe before starting the reader thread. - let initial_output_rx = output_tx.subscribe(); - - let mut reader = pair.master.try_clone_reader()?; - let output_tx_clone = output_tx.clone(); - let reader_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || { - let mut buf = [0u8; 8_192]; - loop { - match reader.read(&mut buf) { - Ok(0) => break, - Ok(n) => { - let _ = output_tx_clone.send(buf[..n].to_vec()); - } - Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, - Err(ref e) if e.kind() == ErrorKind::WouldBlock => { - std::thread::sleep(Duration::from_millis(5)); - continue; - } - Err(_) => break, - } - } - }); - - let writer = pair.master.take_writer()?; - let writer = Arc::new(TokioMutex::new(writer)); - let writer_handle: JoinHandle<()> = tokio::spawn({ - let writer = Arc::clone(&writer); - async move { - while let Some(bytes) = writer_rx.recv().await { - let mut guard = writer.lock().await; - use std::io::Write; - let _ = guard.write_all(&bytes); - let _ = guard.flush(); - } - } - }); - - let (exit_tx, exit_rx) = oneshot::channel::(); - let exit_status = Arc::new(AtomicBool::new(false)); - let wait_exit_status = Arc::clone(&exit_status); - let exit_code = Arc::new(StdMutex::new(None)); - let wait_exit_code = Arc::clone(&exit_code); - let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || { - let code = match child.wait() { - Ok(status) => status.exit_code() as i32, - Err(_) => -1, - }; - wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst); - if let Ok(mut guard) = wait_exit_code.lock() { - *guard = Some(code); - } - let _ = exit_tx.send(code); - }); - - let pair = PtyPairWrapper { - _slave: if cfg!(windows) { - // Keep the slave handle alive on Windows to prevent the process from receiving Control+C - Some(pair.slave) - } else { - None - }, - _master: pair.master, - }; - - let (session, output_rx) = ExecCommandSession::new( - writer_tx, - output_tx, - initial_output_rx, - killer, - reader_handle, - writer_handle, - wait_handle, - exit_status, - exit_code, - pair, - ); - - Ok(SpawnedPty { - session, - output_rx, - exit_rx, - }) -} +/// Spawn a non-interactive process using regular pipes for stdin/stdout/stderr. +pub use pipe::spawn_process as spawn_pipe_process; +/// Handle for interacting with a spawned process (PTY or pipe). +pub use process::ProcessHandle; +/// Bundle of process handles plus output and exit receivers returned by spawn helpers. +pub use process::SpawnedProcess; +/// Backwards-compatible alias for ProcessHandle. +pub type ExecCommandSession = ProcessHandle; +/// Backwards-compatible alias for SpawnedProcess. +pub type SpawnedPty = SpawnedProcess; +/// Report whether ConPTY is available on this platform (Windows only). +pub use pty::conpty_supported; +/// Spawn a process attached to a PTY for interactive use. +pub use pty::spawn_process as spawn_pty_process; diff --git a/codex-rs/utils/pty/src/pipe.rs b/codex-rs/utils/pty/src/pipe.rs new file mode 100644 index 00000000000..c3dcd4ddcbe --- /dev/null +++ b/codex-rs/utils/pty/src/pipe.rs @@ -0,0 +1,232 @@ +use std::collections::HashMap; +use std::io; +use std::io::ErrorKind; +use std::path::Path; +use std::process::Stdio; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; + +use anyhow::Result; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::process::Command; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; + +use crate::process::ChildTerminator; +use crate::process::ProcessHandle; +use crate::process::SpawnedProcess; + +#[cfg(target_os = "linux")] +use libc; + +struct PipeChildTerminator { + #[cfg(windows)] + pid: u32, + #[cfg(unix)] + process_group_id: u32, +} + +impl ChildTerminator for PipeChildTerminator { + fn kill(&mut self) -> io::Result<()> { + #[cfg(unix)] + { + crate::process_group::kill_process_group(self.process_group_id) + } + + #[cfg(windows)] + { + kill_process(self.pid) + } + + #[cfg(not(any(unix, windows)))] + { + Ok(()) + } + } +} + +#[cfg(windows)] +fn kill_process(pid: u32) -> io::Result<()> { + unsafe { + let handle = winapi::um::processthreadsapi::OpenProcess( + winapi::um::winnt::PROCESS_TERMINATE, + 0, + pid, + ); + if handle.is_null() { + return Err(io::Error::last_os_error()); + } + let success = winapi::um::processthreadsapi::TerminateProcess(handle, 1); + let err = io::Error::last_os_error(); + winapi::um::handleapi::CloseHandle(handle); + if success == 0 { + Err(err) + } else { + Ok(()) + } + } +} + +async fn read_output_stream(mut reader: R, output_tx: broadcast::Sender>) +where + R: AsyncRead + Unpin, +{ + let mut buf = vec![0u8; 8_192]; + loop { + match reader.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + let _ = output_tx.send(buf[..n].to_vec()); + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, + Err(_) => break, + } + } +} + +/// Spawn a process using regular pipes (no PTY), returning handles for stdin, output, and exit. +pub async fn spawn_process( + program: &str, + args: &[String], + cwd: &Path, + env: &HashMap, + arg0: &Option, +) -> Result { + if program.is_empty() { + anyhow::bail!("missing program for pipe spawn"); + } + + let mut command = Command::new(program); + #[cfg(unix)] + if let Some(arg0) = arg0 { + command.arg0(arg0); + } + #[cfg(target_os = "linux")] + let parent_pid = unsafe { libc::getpid() }; + #[cfg(unix)] + unsafe { + command.pre_exec(move || { + crate::process_group::set_process_group()?; + #[cfg(target_os = "linux")] + crate::process_group::set_parent_death_signal(parent_pid)?; + Ok(()) + }); + } + #[cfg(not(unix))] + let _ = arg0; + command.current_dir(cwd); + command.env_clear(); + for (key, value) in env { + command.env(key, value); + } + for arg in args { + command.arg(arg); + } + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + command.stderr(Stdio::piped()); + + let mut child = command.spawn()?; + let pid = child + .id() + .ok_or_else(|| io::Error::other("missing child pid"))?; + #[cfg(unix)] + let process_group_id = pid; + + let stdin = child.stdin.take(); + let stdout = child.stdout.take(); + let stderr = child.stderr.take(); + + let (writer_tx, mut writer_rx) = mpsc::channel::>(128); + let (output_tx, _) = broadcast::channel::>(256); + let initial_output_rx = output_tx.subscribe(); + + let writer_handle = tokio::spawn({ + let writer = stdin.map(|w| Arc::new(tokio::sync::Mutex::new(w))); + async move { + while let Some(bytes) = writer_rx.recv().await { + if let Some(writer) = &writer { + let mut guard = writer.lock().await; + let _ = guard.write_all(&bytes).await; + let _ = guard.flush().await; + } + } + } + }); + + let stdout_handle = stdout.map(|stdout| { + let output_tx = output_tx.clone(); + tokio::spawn(async move { + read_output_stream(BufReader::new(stdout), output_tx).await; + }) + }); + let stderr_handle = stderr.map(|stderr| { + let output_tx = output_tx.clone(); + tokio::spawn(async move { + read_output_stream(BufReader::new(stderr), output_tx).await; + }) + }); + let mut reader_abort_handles = Vec::new(); + if let Some(handle) = stdout_handle.as_ref() { + reader_abort_handles.push(handle.abort_handle()); + } + if let Some(handle) = stderr_handle.as_ref() { + reader_abort_handles.push(handle.abort_handle()); + } + let reader_handle = tokio::spawn(async move { + if let Some(handle) = stdout_handle { + let _ = handle.await; + } + if let Some(handle) = stderr_handle { + let _ = handle.await; + } + }); + + let (exit_tx, exit_rx) = oneshot::channel::(); + let exit_status = Arc::new(AtomicBool::new(false)); + let wait_exit_status = Arc::clone(&exit_status); + let exit_code = Arc::new(StdMutex::new(None)); + let wait_exit_code = Arc::clone(&exit_code); + let wait_handle: JoinHandle<()> = tokio::spawn(async move { + let code = match child.wait().await { + Ok(status) => status.code().unwrap_or(-1), + Err(_) => -1, + }; + wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst); + if let Ok(mut guard) = wait_exit_code.lock() { + *guard = Some(code); + } + let _ = exit_tx.send(code); + }); + + let (handle, output_rx) = ProcessHandle::new( + writer_tx, + output_tx, + initial_output_rx, + Box::new(PipeChildTerminator { + #[cfg(windows)] + pid, + #[cfg(unix)] + process_group_id, + }), + reader_handle, + reader_abort_handles, + writer_handle, + wait_handle, + exit_status, + exit_code, + None, + ); + + Ok(SpawnedProcess { + session: handle, + output_rx, + exit_rx, + }) +} diff --git a/codex-rs/utils/pty/src/process.rs b/codex-rs/utils/pty/src/process.rs new file mode 100644 index 00000000000..5c487fd3866 --- /dev/null +++ b/codex-rs/utils/pty/src/process.rs @@ -0,0 +1,147 @@ +use core::fmt; +use std::io; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; + +use portable_pty::MasterPty; +use portable_pty::SlavePty; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::AbortHandle; +use tokio::task::JoinHandle; + +pub(crate) trait ChildTerminator: Send + Sync { + fn kill(&mut self) -> io::Result<()>; +} + +pub struct PtyHandles { + pub _slave: Option>, + pub _master: Box, +} + +impl fmt::Debug for PtyHandles { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PtyHandles").finish() + } +} + +/// Handle for driving an interactive process (PTY or pipe). +pub struct ProcessHandle { + writer_tx: mpsc::Sender>, + output_tx: broadcast::Sender>, + killer: StdMutex>>, + reader_handle: StdMutex>>, + reader_abort_handles: StdMutex>, + writer_handle: StdMutex>>, + wait_handle: StdMutex>>, + exit_status: Arc, + exit_code: Arc>>, + // PtyHandles must be preserved because the process will receive Control+C if the + // slave is closed + _pty_handles: StdMutex>, +} + +impl fmt::Debug for ProcessHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProcessHandle").finish() + } +} + +impl ProcessHandle { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + writer_tx: mpsc::Sender>, + output_tx: broadcast::Sender>, + initial_output_rx: broadcast::Receiver>, + killer: Box, + reader_handle: JoinHandle<()>, + reader_abort_handles: Vec, + writer_handle: JoinHandle<()>, + wait_handle: JoinHandle<()>, + exit_status: Arc, + exit_code: Arc>>, + pty_handles: Option, + ) -> (Self, broadcast::Receiver>) { + ( + Self { + writer_tx, + output_tx, + killer: StdMutex::new(Some(killer)), + reader_handle: StdMutex::new(Some(reader_handle)), + reader_abort_handles: StdMutex::new(reader_abort_handles), + writer_handle: StdMutex::new(Some(writer_handle)), + wait_handle: StdMutex::new(Some(wait_handle)), + exit_status, + exit_code, + _pty_handles: StdMutex::new(pty_handles), + }, + initial_output_rx, + ) + } + + /// Returns a channel sender for writing raw bytes to the child stdin. + pub fn writer_sender(&self) -> mpsc::Sender> { + self.writer_tx.clone() + } + + /// Returns a broadcast receiver that yields stdout/stderr chunks. + pub fn output_receiver(&self) -> broadcast::Receiver> { + self.output_tx.subscribe() + } + + /// True if the child process has exited. + pub fn has_exited(&self) -> bool { + self.exit_status.load(std::sync::atomic::Ordering::SeqCst) + } + + /// Returns the exit code if known. + pub fn exit_code(&self) -> Option { + self.exit_code.lock().ok().and_then(|guard| *guard) + } + + /// Attempts to kill the child and abort helper tasks. + pub fn terminate(&self) { + if let Ok(mut killer_opt) = self.killer.lock() { + if let Some(mut killer) = killer_opt.take() { + let _ = killer.kill(); + } + } + + if let Ok(mut h) = self.reader_handle.lock() { + if let Some(handle) = h.take() { + handle.abort(); + } + } + if let Ok(mut handles) = self.reader_abort_handles.lock() { + for handle in handles.drain(..) { + handle.abort(); + } + } + if let Ok(mut h) = self.writer_handle.lock() { + if let Some(handle) = h.take() { + handle.abort(); + } + } + if let Ok(mut h) = self.wait_handle.lock() { + if let Some(handle) = h.take() { + handle.abort(); + } + } + } +} + +impl Drop for ProcessHandle { + fn drop(&mut self) { + self.terminate(); + } +} + +/// Return value from spawn helpers (PTY or pipe). +#[derive(Debug)] +pub struct SpawnedProcess { + pub session: ProcessHandle, + pub output_rx: broadcast::Receiver>, + pub exit_rx: oneshot::Receiver, +} diff --git a/codex-rs/utils/pty/src/process_group.rs b/codex-rs/utils/pty/src/process_group.rs new file mode 100644 index 00000000000..ae77a36be08 --- /dev/null +++ b/codex-rs/utils/pty/src/process_group.rs @@ -0,0 +1,135 @@ +//! Process-group helpers shared by pipe/pty and shell command execution. +//! +//! This module centralizes the OS-specific pieces that ensure a spawned +//! command can be cleaned up reliably: +//! - `set_process_group` is called in `pre_exec` so the child starts its own +//! process group. +//! - `kill_process_group_by_pid` targets the whole group (children/grandchildren) +//! - `kill_process_group` targets a known process group ID directly +//! instead of a single PID. +//! - `set_parent_death_signal` (Linux only) arranges for the child to receive a +//! `SIGTERM` when the parent exits, and re-checks the parent PID to avoid +//! races during fork/exec. +//! +//! On non-Unix platforms these helpers are no-ops. + +use std::io; + +use tokio::process::Child; + +#[cfg(target_os = "linux")] +/// Ensure the child receives SIGTERM when the original parent dies. +/// +/// This should run in `pre_exec` and uses `parent_pid` captured before spawn to +/// avoid a race where the parent exits between fork and exec. +pub fn set_parent_death_signal(parent_pid: libc::pid_t) -> io::Result<()> { + if unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) } == -1 { + return Err(io::Error::last_os_error()); + } + + if unsafe { libc::getppid() } != parent_pid { + unsafe { + libc::raise(libc::SIGTERM); + } + } + + Ok(()) +} + +#[cfg(not(target_os = "linux"))] +/// No-op on non-Linux platforms. +pub fn set_parent_death_signal(_parent_pid: i32) -> io::Result<()> { + Ok(()) +} + +#[cfg(unix)] +/// Put the calling process into its own process group. +/// +/// Intended for use in `pre_exec` so the child becomes the group leader. +pub fn set_process_group() -> io::Result<()> { + let result = unsafe { libc::setpgid(0, 0) }; + if result == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } +} + +#[cfg(not(unix))] +/// No-op on non-Unix platforms. +pub fn set_process_group() -> io::Result<()> { + Ok(()) +} + +#[cfg(unix)] +/// Kill the process group for the given PID (best-effort). +/// +/// This resolves the PGID for `pid` and sends SIGKILL to the whole group. +pub fn kill_process_group_by_pid(pid: u32) -> io::Result<()> { + use std::io::ErrorKind; + + let pid = pid as libc::pid_t; + let pgid = unsafe { libc::getpgid(pid) }; + if pgid == -1 { + let err = io::Error::last_os_error(); + if err.kind() != ErrorKind::NotFound { + return Err(err); + } + return Ok(()); + } + + let result = unsafe { libc::killpg(pgid, libc::SIGKILL) }; + if result == -1 { + let err = io::Error::last_os_error(); + if err.kind() != ErrorKind::NotFound { + return Err(err); + } + } + + Ok(()) +} + +#[cfg(not(unix))] +/// No-op on non-Unix platforms. +pub fn kill_process_group_by_pid(_pid: u32) -> io::Result<()> { + Ok(()) +} + +#[cfg(unix)] +/// Kill a specific process group ID (best-effort). +pub fn kill_process_group(process_group_id: u32) -> io::Result<()> { + use std::io::ErrorKind; + + let pgid = process_group_id as libc::pid_t; + let result = unsafe { libc::killpg(pgid, libc::SIGKILL) }; + if result == -1 { + let err = io::Error::last_os_error(); + if err.kind() != ErrorKind::NotFound { + return Err(err); + } + } + + Ok(()) +} + +#[cfg(not(unix))] +/// No-op on non-Unix platforms. +pub fn kill_process_group(_process_group_id: u32) -> io::Result<()> { + Ok(()) +} + +#[cfg(unix)] +/// Kill the process group for a tokio child (best-effort). +pub fn kill_child_process_group(child: &mut Child) -> io::Result<()> { + if let Some(pid) = child.id() { + return kill_process_group_by_pid(pid); + } + + Ok(()) +} + +#[cfg(not(unix))] +/// No-op on non-Unix platforms. +pub fn kill_child_process_group(_child: &mut Child) -> io::Result<()> { + Ok(()) +} diff --git a/codex-rs/utils/pty/src/pty.rs b/codex-rs/utils/pty/src/pty.rs new file mode 100644 index 00000000000..0f1ec2c8235 --- /dev/null +++ b/codex-rs/utils/pty/src/pty.rs @@ -0,0 +1,174 @@ +use std::collections::HashMap; +use std::io::ErrorKind; +use std::path::Path; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::time::Duration; + +use anyhow::Result; +#[cfg(not(windows))] +use portable_pty::native_pty_system; +use portable_pty::CommandBuilder; +use portable_pty::PtySize; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; + +use crate::process::ChildTerminator; +use crate::process::ProcessHandle; +use crate::process::PtyHandles; +use crate::process::SpawnedProcess; + +/// Returns true when ConPTY support is available (Windows only). +#[cfg(windows)] +pub fn conpty_supported() -> bool { + crate::win::conpty_supported() +} + +/// Returns true when ConPTY support is available (non-Windows always true). +#[cfg(not(windows))] +pub fn conpty_supported() -> bool { + true +} + +struct PtyChildTerminator { + killer: Box, +} + +impl ChildTerminator for PtyChildTerminator { + fn kill(&mut self) -> std::io::Result<()> { + self.killer.kill() + } +} + +fn platform_native_pty_system() -> Box { + #[cfg(windows)] + { + Box::new(crate::win::ConPtySystem::default()) + } + + #[cfg(not(windows))] + { + native_pty_system() + } +} + +/// Spawn a process attached to a PTY, returning handles for stdin, output, and exit. +pub async fn spawn_process( + program: &str, + args: &[String], + cwd: &Path, + env: &HashMap, + arg0: &Option, +) -> Result { + if program.is_empty() { + anyhow::bail!("missing program for PTY spawn"); + } + + let pty_system = platform_native_pty_system(); + let pair = pty_system.openpty(PtySize { + rows: 24, + cols: 80, + pixel_width: 0, + pixel_height: 0, + })?; + + let mut command_builder = CommandBuilder::new(arg0.as_ref().unwrap_or(&program.to_string())); + command_builder.cwd(cwd); + command_builder.env_clear(); + for arg in args { + command_builder.arg(arg); + } + for (key, value) in env { + command_builder.env(key, value); + } + + let mut child = pair.slave.spawn_command(command_builder)?; + let killer = child.clone_killer(); + + let (writer_tx, mut writer_rx) = mpsc::channel::>(128); + let (output_tx, _) = broadcast::channel::>(256); + let initial_output_rx = output_tx.subscribe(); + + let mut reader = pair.master.try_clone_reader()?; + let output_tx_clone = output_tx.clone(); + let reader_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || { + let mut buf = [0u8; 8_192]; + loop { + match reader.read(&mut buf) { + Ok(0) => break, + Ok(n) => { + let _ = output_tx_clone.send(buf[..n].to_vec()); + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + std::thread::sleep(Duration::from_millis(5)); + continue; + } + Err(_) => break, + } + } + }); + + let writer = pair.master.take_writer()?; + let writer = Arc::new(tokio::sync::Mutex::new(writer)); + let writer_handle: JoinHandle<()> = tokio::spawn({ + let writer = Arc::clone(&writer); + async move { + while let Some(bytes) = writer_rx.recv().await { + let mut guard = writer.lock().await; + use std::io::Write; + let _ = guard.write_all(&bytes); + let _ = guard.flush(); + } + } + }); + + let (exit_tx, exit_rx) = oneshot::channel::(); + let exit_status = Arc::new(AtomicBool::new(false)); + let wait_exit_status = Arc::clone(&exit_status); + let exit_code = Arc::new(StdMutex::new(None)); + let wait_exit_code = Arc::clone(&exit_code); + let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || { + let code = match child.wait() { + Ok(status) => status.exit_code() as i32, + Err(_) => -1, + }; + wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst); + if let Ok(mut guard) = wait_exit_code.lock() { + *guard = Some(code); + } + let _ = exit_tx.send(code); + }); + + let handles = PtyHandles { + _slave: if cfg!(windows) { + Some(pair.slave) + } else { + None + }, + _master: pair.master, + }; + + let (handle, output_rx) = ProcessHandle::new( + writer_tx, + output_tx, + initial_output_rx, + Box::new(PtyChildTerminator { killer }), + reader_handle, + Vec::new(), + writer_handle, + wait_handle, + exit_status, + exit_code, + Some(handles), + ); + + Ok(SpawnedProcess { + session: handle, + output_rx, + exit_rx, + }) +} diff --git a/codex-rs/utils/pty/src/tests.rs b/codex-rs/utils/pty/src/tests.rs new file mode 100644 index 00000000000..cdf9f0824a1 --- /dev/null +++ b/codex-rs/utils/pty/src/tests.rs @@ -0,0 +1,242 @@ +use std::collections::HashMap; +use std::path::Path; + +use pretty_assertions::assert_eq; + +use crate::spawn_pipe_process; +use crate::spawn_pty_process; + +fn find_python() -> Option { + for candidate in ["python3", "python"] { + if let Ok(output) = std::process::Command::new(candidate) + .arg("--version") + .output() + { + if output.status.success() { + return Some(candidate.to_string()); + } + } + } + None +} + +fn setsid_available() -> bool { + if cfg!(windows) { + return false; + } + std::process::Command::new("setsid") + .arg("true") + .status() + .map(|status| status.success()) + .unwrap_or(false) +} + +fn shell_command(program: &str) -> (String, Vec) { + if cfg!(windows) { + let cmd = std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()); + (cmd, vec!["/C".to_string(), program.to_string()]) + } else { + ( + "/bin/sh".to_string(), + vec!["-c".to_string(), program.to_string()], + ) + } +} + +fn echo_sleep_command(marker: &str) -> String { + if cfg!(windows) { + format!("echo {marker} & ping -n 2 127.0.0.1 > NUL") + } else { + format!("echo {marker}; sleep 0.05") + } +} + +async fn collect_output_until_exit( + mut output_rx: tokio::sync::broadcast::Receiver>, + exit_rx: tokio::sync::oneshot::Receiver, + timeout_ms: u64, +) -> (Vec, i32) { + let mut collected = Vec::new(); + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms); + tokio::pin!(exit_rx); + + loop { + tokio::select! { + res = output_rx.recv() => { + if let Ok(chunk) = res { + collected.extend_from_slice(&chunk); + } + } + res = &mut exit_rx => { + let code = res.unwrap_or(-1); + // On Windows (ConPTY in particular), it's possible to observe the exit notification + // before the final bytes are drained from the PTY reader thread. Drain for a brief + // "quiet" window to make output assertions deterministic. + let (quiet_ms, max_ms) = if cfg!(windows) { (200, 2_000) } else { (50, 500) }; + let quiet = tokio::time::Duration::from_millis(quiet_ms); + let max_deadline = + tokio::time::Instant::now() + tokio::time::Duration::from_millis(max_ms); + while tokio::time::Instant::now() < max_deadline { + match tokio::time::timeout(quiet, output_rx.recv()).await { + Ok(Ok(chunk)) => collected.extend_from_slice(&chunk), + Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue, + Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break, + Err(_) => break, + } + } + return (collected, code); + } + _ = tokio::time::sleep_until(deadline) => { + return (collected, -1); + } + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pty_python_repl_emits_output_and_exits() -> anyhow::Result<()> { + let Some(python) = find_python() else { + eprintln!("python not found; skipping pty_python_repl_emits_output_and_exits"); + return Ok(()); + }; + + let env_map: HashMap = std::env::vars().collect(); + let spawned = spawn_pty_process(&python, &[], Path::new("."), &env_map, &None).await?; + let writer = spawned.session.writer_sender(); + let newline = if cfg!(windows) { "\r\n" } else { "\n" }; + writer + .send(format!("print('hello from pty'){newline}").into_bytes()) + .await?; + writer.send(format!("exit(){newline}").into_bytes()).await?; + + let timeout_ms = if cfg!(windows) { 10_000 } else { 5_000 }; + let (output, code) = + collect_output_until_exit(spawned.output_rx, spawned.exit_rx, timeout_ms).await; + let text = String::from_utf8_lossy(&output); + + assert!( + text.contains("hello from pty"), + "expected python output in PTY: {text:?}" + ); + assert_eq!(code, 0, "expected python to exit cleanly"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pipe_process_round_trips_stdin() -> anyhow::Result<()> { + let Some(python) = find_python() else { + eprintln!("python not found; skipping pipe_process_round_trips_stdin"); + return Ok(()); + }; + + let args = vec![ + "-u".to_string(), + "-c".to_string(), + "import sys; print(sys.stdin.readline().strip());".to_string(), + ]; + let env_map: HashMap = std::env::vars().collect(); + let spawned = spawn_pipe_process(&python, &args, Path::new("."), &env_map, &None).await?; + let writer = spawned.session.writer_sender(); + writer.send(b"roundtrip\n".to_vec()).await?; + + let (output, code) = collect_output_until_exit(spawned.output_rx, spawned.exit_rx, 5_000).await; + let text = String::from_utf8_lossy(&output); + + assert!( + text.contains("roundtrip"), + "expected pipe process to echo stdin: {text:?}" + ); + assert_eq!(code, 0, "expected python -c to exit cleanly"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pipe_and_pty_share_interface() -> anyhow::Result<()> { + let env_map: HashMap = std::env::vars().collect(); + + let (pipe_program, pipe_args) = shell_command(&echo_sleep_command("pipe_ok")); + let (pty_program, pty_args) = shell_command(&echo_sleep_command("pty_ok")); + + let pipe = + spawn_pipe_process(&pipe_program, &pipe_args, Path::new("."), &env_map, &None).await?; + let pty = spawn_pty_process(&pty_program, &pty_args, Path::new("."), &env_map, &None).await?; + + let (pipe_out, pipe_code) = + collect_output_until_exit(pipe.output_rx, pipe.exit_rx, 3_000).await; + let (pty_out, pty_code) = collect_output_until_exit(pty.output_rx, pty.exit_rx, 3_000).await; + + assert_eq!(pipe_code, 0); + assert_eq!(pty_code, 0); + assert!( + String::from_utf8_lossy(&pipe_out).contains("pipe_ok"), + "pipe output mismatch: {pipe_out:?}" + ); + assert!( + String::from_utf8_lossy(&pty_out).contains("pty_ok"), + "pty output mismatch: {pty_out:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pipe_drains_stderr_without_stdout_activity() -> anyhow::Result<()> { + let Some(python) = find_python() else { + eprintln!("python not found; skipping pipe_drains_stderr_without_stdout_activity"); + return Ok(()); + }; + + let script = "import sys\nchunk = 'E' * 65536\nfor _ in range(64):\n sys.stderr.write(chunk)\n sys.stderr.flush()\n"; + let args = vec!["-c".to_string(), script.to_string()]; + let env_map: HashMap = std::env::vars().collect(); + let spawned = spawn_pipe_process(&python, &args, Path::new("."), &env_map, &None).await?; + + let (output, code) = + collect_output_until_exit(spawned.output_rx, spawned.exit_rx, 10_000).await; + + assert_eq!(code, 0, "expected python to exit cleanly"); + assert!(!output.is_empty(), "expected stderr output to be drained"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn pipe_terminate_aborts_detached_readers() -> anyhow::Result<()> { + if !setsid_available() { + eprintln!("setsid not available; skipping pipe_terminate_aborts_detached_readers"); + return Ok(()); + } + + let env_map: HashMap = std::env::vars().collect(); + let script = + "setsid sh -c 'i=0; while [ $i -lt 200 ]; do echo tick; sleep 0.01; i=$((i+1)); done' &"; + let (program, args) = shell_command(script); + let mut spawned = spawn_pipe_process(&program, &args, Path::new("."), &env_map, &None).await?; + + let _ = tokio::time::timeout( + tokio::time::Duration::from_millis(500), + spawned.output_rx.recv(), + ) + .await + .map_err(|_| anyhow::anyhow!("expected detached output before terminate"))??; + + spawned.session.terminate(); + let mut post_rx = spawned.session.output_receiver(); + + let post_terminate = + tokio::time::timeout(tokio::time::Duration::from_millis(200), post_rx.recv()).await; + + match post_terminate { + Err(_) => Ok(()), + Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => Ok(()), + Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => { + anyhow::bail!("unexpected output after terminate (lagged)") + } + Ok(Ok(chunk)) => anyhow::bail!( + "unexpected output after terminate: {:?}", + String::from_utf8_lossy(&chunk) + ), + } +}