Skip to content

Commit 43f33a5

Browse files
committed
make it more robust
1 parent 5561375 commit 43f33a5

File tree

5 files changed

+173
-7
lines changed

5 files changed

+173
-7
lines changed

codex-rs/utils/pty/src/pipe.rs

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,29 @@ use crate::process::SpawnedProcess;
2626
use libc;
2727

2828
struct PipeChildTerminator {
29+
#[cfg(windows)]
2930
pid: u32,
31+
#[cfg(unix)]
32+
process_group_id: u32,
3033
}
3134

3235
impl ChildTerminator for PipeChildTerminator {
3336
fn kill(&mut self) -> io::Result<()> {
34-
kill_process(self.pid)
35-
}
36-
}
37+
#[cfg(unix)]
38+
{
39+
return crate::process_group::kill_process_group(self.process_group_id);
40+
}
3741

38-
#[cfg(unix)]
39-
fn kill_process(pid: u32) -> io::Result<()> {
40-
crate::process_group::kill_process_group_by_pid(pid)
42+
#[cfg(windows)]
43+
{
44+
return kill_process(self.pid);
45+
}
46+
47+
#[cfg(not(any(unix, windows)))]
48+
{
49+
Ok(())
50+
}
51+
}
4152
}
4253

4354
#[cfg(windows)]
@@ -125,6 +136,8 @@ pub async fn spawn_process(
125136
let pid = child
126137
.id()
127138
.ok_or_else(|| io::Error::other("missing child pid"))?;
139+
#[cfg(unix)]
140+
let process_group_id = pid;
128141

129142
let stdin = child.stdin.take();
130143
let stdout = child.stdout.take();
@@ -159,6 +172,13 @@ pub async fn spawn_process(
159172
read_output_stream(BufReader::new(stderr), output_tx).await;
160173
})
161174
});
175+
let mut reader_abort_handles = Vec::new();
176+
if let Some(handle) = stdout_handle.as_ref() {
177+
reader_abort_handles.push(handle.abort_handle());
178+
}
179+
if let Some(handle) = stderr_handle.as_ref() {
180+
reader_abort_handles.push(handle.abort_handle());
181+
}
162182
let reader_handle = tokio::spawn(async move {
163183
if let Some(handle) = stdout_handle {
164184
let _ = handle.await;
@@ -189,8 +209,14 @@ pub async fn spawn_process(
189209
writer_tx,
190210
output_tx,
191211
initial_output_rx,
192-
Box::new(PipeChildTerminator { pid }),
212+
Box::new(PipeChildTerminator {
213+
#[cfg(windows)]
214+
pid,
215+
#[cfg(unix)]
216+
process_group_id,
217+
}),
193218
reader_handle,
219+
reader_abort_handles,
194220
writer_handle,
195221
wait_handle,
196222
exit_status,

codex-rs/utils/pty/src/process.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use portable_pty::SlavePty;
99
use tokio::sync::broadcast;
1010
use tokio::sync::mpsc;
1111
use tokio::sync::oneshot;
12+
use tokio::task::AbortHandle;
1213
use tokio::task::JoinHandle;
1314

1415
pub(crate) trait ChildTerminator: Send + Sync {
@@ -32,6 +33,7 @@ pub struct ProcessHandle {
3233
output_tx: broadcast::Sender<Vec<u8>>,
3334
killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
3435
reader_handle: StdMutex<Option<JoinHandle<()>>>,
36+
reader_abort_handles: StdMutex<Vec<AbortHandle>>,
3537
writer_handle: StdMutex<Option<JoinHandle<()>>>,
3638
wait_handle: StdMutex<Option<JoinHandle<()>>>,
3739
exit_status: Arc<AtomicBool>,
@@ -55,6 +57,7 @@ impl ProcessHandle {
5557
initial_output_rx: broadcast::Receiver<Vec<u8>>,
5658
killer: Box<dyn ChildTerminator>,
5759
reader_handle: JoinHandle<()>,
60+
reader_abort_handles: Vec<AbortHandle>,
5861
writer_handle: JoinHandle<()>,
5962
wait_handle: JoinHandle<()>,
6063
exit_status: Arc<AtomicBool>,
@@ -67,6 +70,7 @@ impl ProcessHandle {
6770
output_tx,
6871
killer: StdMutex::new(Some(killer)),
6972
reader_handle: StdMutex::new(Some(reader_handle)),
73+
reader_abort_handles: StdMutex::new(reader_abort_handles),
7074
writer_handle: StdMutex::new(Some(writer_handle)),
7175
wait_handle: StdMutex::new(Some(wait_handle)),
7276
exit_status,
@@ -110,6 +114,11 @@ impl ProcessHandle {
110114
handle.abort();
111115
}
112116
}
117+
if let Ok(mut handles) = self.reader_abort_handles.lock() {
118+
for handle in handles.drain(..) {
119+
handle.abort();
120+
}
121+
}
113122
if let Ok(mut h) = self.writer_handle.lock() {
114123
if let Some(handle) = h.take() {
115124
handle.abort();

codex-rs/utils/pty/src/process_group.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//! - `set_process_group` is called in `pre_exec` so the child starts its own
66
//! process group.
77
//! - `kill_process_group_by_pid` targets the whole group (children/grandchildren)
8+
//! - `kill_process_group` targets a known process group ID directly
89
//! instead of a single PID.
910
//! - `set_parent_death_signal` (Linux only) arranges for the child to receive a
1011
//! `SIGTERM` when the parent exits, and re-checks the parent PID to avoid
@@ -94,6 +95,29 @@ pub fn kill_process_group_by_pid(_pid: u32) -> io::Result<()> {
9495
Ok(())
9596
}
9697

98+
#[cfg(unix)]
99+
/// Kill a specific process group ID (best-effort).
100+
pub fn kill_process_group(process_group_id: u32) -> io::Result<()> {
101+
use std::io::ErrorKind;
102+
103+
let pgid = process_group_id as libc::pid_t;
104+
let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
105+
if result == -1 {
106+
let err = io::Error::last_os_error();
107+
if err.kind() != ErrorKind::NotFound {
108+
return Err(err);
109+
}
110+
}
111+
112+
Ok(())
113+
}
114+
115+
#[cfg(not(unix))]
116+
/// No-op on non-Unix platforms.
117+
pub fn kill_process_group(_process_group_id: u32) -> io::Result<()> {
118+
Ok(())
119+
}
120+
97121
#[cfg(unix)]
98122
/// Kill the process group for a tokio child (best-effort).
99123
pub fn kill_child_process_group(child: &mut Child) -> io::Result<()> {

codex-rs/utils/pty/src/pty.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ pub async fn spawn_process(
158158
initial_output_rx,
159159
Box::new(PtyChildTerminator { killer }),
160160
reader_handle,
161+
Vec::new(),
161162
writer_handle,
162163
wait_handle,
163164
exit_status,

codex-rs/utils/pty/src/tests.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
use std::collections::HashMap;
22
use std::path::Path;
33

4+
#[cfg(unix)]
5+
use std::io::ErrorKind;
6+
47
use pretty_assertions::assert_eq;
58

69
use crate::spawn_pipe_process;
710
use crate::spawn_pty_process;
11+
#[cfg(unix)]
12+
use crate::SpawnedProcess;
813

914
fn find_python() -> Option<String> {
1015
for candidate in ["python3", "python"] {
@@ -20,6 +25,27 @@ fn find_python() -> Option<String> {
2025
None
2126
}
2227

28+
fn setsid_available() -> bool {
29+
if cfg!(windows) {
30+
return false;
31+
}
32+
std::process::Command::new("setsid")
33+
.arg("true")
34+
.status()
35+
.map(|status| status.success())
36+
.unwrap_or(false)
37+
}
38+
39+
#[cfg(unix)]
40+
fn process_exists(pid: i32) -> bool {
41+
let result = unsafe { libc::kill(pid, 0) };
42+
if result == 0 {
43+
return true;
44+
}
45+
let err = std::io::Error::last_os_error();
46+
err.kind() != ErrorKind::NotFound
47+
}
48+
2349
fn shell_command(program: &str) -> (String, Vec<String>) {
2450
if cfg!(windows) {
2551
let cmd = std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string());
@@ -190,3 +216,83 @@ async fn pipe_drains_stderr_without_stdout_activity() -> anyhow::Result<()> {
190216

191217
Ok(())
192218
}
219+
220+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
221+
async fn pipe_terminate_aborts_detached_readers() -> anyhow::Result<()> {
222+
if !setsid_available() {
223+
eprintln!("setsid not available; skipping pipe_terminate_aborts_detached_readers");
224+
return Ok(());
225+
}
226+
227+
let env_map: HashMap<String, String> = std::env::vars().collect();
228+
let script =
229+
"setsid sh -c 'i=0; while [ $i -lt 200 ]; do echo tick; sleep 0.01; i=$((i+1)); done' &";
230+
let (program, args) = shell_command(script);
231+
let mut spawned = spawn_pipe_process(&program, &args, Path::new("."), &env_map, &None).await?;
232+
233+
let _ = tokio::time::timeout(
234+
tokio::time::Duration::from_millis(500),
235+
spawned.output_rx.recv(),
236+
)
237+
.await
238+
.map_err(|_| anyhow::anyhow!("expected detached output before terminate"))??;
239+
240+
spawned.session.terminate();
241+
let mut post_rx = spawned.session.output_receiver();
242+
243+
let post_terminate =
244+
tokio::time::timeout(tokio::time::Duration::from_millis(200), post_rx.recv()).await;
245+
246+
match post_terminate {
247+
Err(_) => Ok(()),
248+
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => Ok(()),
249+
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
250+
anyhow::bail!("unexpected output after terminate (lagged)")
251+
}
252+
Ok(Ok(chunk)) => anyhow::bail!(
253+
"unexpected output after terminate: {:?}",
254+
String::from_utf8_lossy(&chunk)
255+
),
256+
}
257+
}
258+
259+
#[cfg(unix)]
260+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
261+
async fn pipe_terminate_kills_process_group_after_exit() -> anyhow::Result<()> {
262+
let env_map: HashMap<String, String> = std::env::vars().collect();
263+
let (program, args) = shell_command("sleep 9999 & echo $!");
264+
let SpawnedProcess {
265+
session,
266+
output_rx,
267+
exit_rx,
268+
} = spawn_pipe_process(&program, &args, Path::new("."), &env_map, &None).await?;
269+
270+
let (output, code) = collect_output_until_exit(output_rx, exit_rx, 2_000).await;
271+
assert_eq!(code, 0, "expected shell to exit cleanly");
272+
273+
let pid_line = String::from_utf8_lossy(&output)
274+
.lines()
275+
.find(|line| !line.trim().is_empty())
276+
.unwrap_or("")
277+
.trim()
278+
.to_string();
279+
let pid: i32 = pid_line
280+
.parse()
281+
.map_err(|_| anyhow::anyhow!("failed to parse background pid from {pid_line:?}"))?;
282+
283+
session.terminate();
284+
285+
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
286+
while tokio::time::Instant::now() < deadline && process_exists(pid) {
287+
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
288+
}
289+
290+
if process_exists(pid) {
291+
unsafe {
292+
libc::kill(pid, libc::SIGKILL);
293+
}
294+
anyhow::bail!("background pid still alive after terminate: {pid}");
295+
}
296+
297+
Ok(())
298+
}

0 commit comments

Comments
 (0)