From 0598172eaf9367b908505b2e2bf398000570eb8e Mon Sep 17 00:00:00 2001
From: ReactorScram
Date: Mon, 29 Jan 2024 11:17:36 -0600
Subject: [PATCH] Initial commit

---
 .github/workflows/_rust.yml |  37 ++++
 Cargo.toml                  |  31 +++
 README.md                   |  13 +-
 src/client.rs               |  89 +++++++++
 src/lib.rs                  | 247 ++++++++++++++++++++++++
 src/main.rs                 |   6 +
 src/multi_process_tests.rs  | 292 +++++++++++++++++++++++++++++
 src/server.rs               | 363 ++++++++++++++++++++++++++++++
 8 files changed, 1076 insertions(+), 2 deletions(-)
 create mode 100755 .github/workflows/_rust.yml
 create mode 100755 Cargo.toml
 mode change 100644 => 100755 README.md
 create mode 100755 src/client.rs
 create mode 100755 src/lib.rs
 create mode 100755 src/main.rs
 create mode 100755 src/multi_process_tests.rs
 create mode 100755 src/server.rs

diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml
new file mode 100755
index 0000000..6b99aac
--- /dev/null
+++ b/.github/workflows/_rust.yml
@@ -0,0 +1,37 @@
+name: Rust
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  test:
+    strategy:
+      fail-fast: false
+      matrix:
+        runs-on:
+          - windows-2019
+          - windows-2022
+    runs-on: ${{ matrix.runs-on }}
+
+    steps:
+      - uses: actions/checkout@v3
+      - name: cargo doc
+        env: { RUSTDOCFLAGS: "-D warnings" }
+        run: cargo doc --all-features --no-deps --document-private-items
+      - name: cargo fmt
+        run: cargo fmt -- --check
+      - name: cargo clippy
+        run: cargo clippy --all-targets --all-features -- -D warnings
+      - name: Single-process test
+        env: { RUST_LOG: debug }
+        run: cargo test
+      - name: Multi-process test
+        env: { RUST_LOG: debug }
+        run: cargo run
+
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100755
index 0000000..2cd127d
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,31 @@
+[package]
+name = "subzone"
+version = "0.1.0"
+keywords = ["command", "process", "subprocess", "worker"]
+description = "Worker subprocesses with async IPC for Windows"
+edition = "2021"
+
+[dependencies]
+anyhow = { version = "1.0" }
+clap = { version = "4.4", features = ["derive", "env"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+thiserror = { version = "1.0", default-features = false }
+tokio = { version = "1.33.0", features = ["io-util", "net", "process", "rt-multi-thread", "sync", "time"] }
+tracing = "0.1.40"
+tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
+uuid = { version = "1.7.0", features = ["v4"] }
+
+[target.'cfg(windows)'.dependencies.windows]
+version = "0.52.0"
+features = [
+  # Needed for `CreateJobObjectA`
+  "Win32_Foundation",
+  # Needed for `CreateJobObjectA`
+  "Win32_Security",
+  # Needed for Windows to automatically kill child processes if the main process crashes
+  "Win32_System_JobObjects",
+  # Needed to check process ID of named pipe clients
+  "Win32_System_Pipes",
+  "Win32_System_Threading",
+]
diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index 0e9505f..9b07133
--- a/README.md
+++ b/README.md
@@ -1,2 +1,11 @@
-# subzone
-IPC for Windows
+# Testing
+
+```bash
+cargo test && cargo run && echo good
+```
+
+# Git history
+
+This repo was split off at commit 634a5439b54e459d4d0109b2b1f64c983abac139
+
+See comment
diff --git a/src/client.rs b/src/client.rs
new file mode 100755
index 0000000..ceb32b1
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,89 @@
+use anyhow::Result;
+use serde::{de::DeserializeOwned, Serialize};
+use std::marker::PhantomData;
+use tokio::{
+    io::AsyncWriteExt,
+    net::windows::named_pipe::{self, NamedPipeClient},
+    sync::mpsc,
+};
+
+use crate::{read_deserialize, write_serialize, Error, ManagerMsgInternal, WorkerMsgInternal};
+
+/// A client that's connected to a server
+///
+/// Manual testing shows that if the corresponding Server's process crashes, Windows will
+/// be nice and return errors for anything trying to read from the Client
+pub struct Client<M, W> {
+    pipe_writer: tokio::io::WriteHalf<NamedPipeClient>,
+    /// Needed to make `next` cancel-safe
+    read_rx: mpsc::Receiver<Vec<u8>>,
+    /// Needed to make `next` cancel-safe
+    reader_task: tokio::task::JoinHandle<Result<()>>,
+    _manager_msg: PhantomData<M>,
+    _worker_msg: PhantomData<W>,
+}
+
+impl<M: DeserializeOwned, W: Serialize> Client<M, W> {
+    /// Creates a `Client` and echoes the security cookie back to the `Server`
+    ///
+    /// Doesn't block, fails instantly if the server isn't up.
+    pub async fn new(server_id: &str) -> Result<Self> {
+        let mut client = Client::new_unsecured(server_id)?;
+        let mut cookie = String::new();
+        std::io::stdin().read_line(&mut cookie)?;
+        let cookie = WorkerMsgInternal::Cookie(cookie.trim().to_string());
+        client.send_internal(&cookie).await?;
+        Ok(client)
+    }
+
+    /// Creates a `Client`. Requires a Tokio context
+    ///
+    /// Doesn't block, will fail instantly if the server isn't ready
+    #[tracing::instrument(skip_all)]
+    pub(crate) fn new_unsecured(server_id: &str) -> Result<Self> {
+        let pipe = named_pipe::ClientOptions::new().open(server_id)?;
+        let (mut pipe_reader, pipe_writer) = tokio::io::split(pipe);
+        let (read_tx, read_rx) = mpsc::channel(1);
+        let reader_task = tokio::spawn(async move {
+            loop {
+                let msg = read_deserialize(&mut pipe_reader).await?;
+                read_tx.send(msg).await?;
+            }
+        });
+
+        Ok(Self {
+            pipe_writer,
+            read_rx,
+            reader_task,
+            _manager_msg: Default::default(),
+            _worker_msg: Default::default(),
+        })
+    }
+
+    pub async fn close(mut self) -> Result<()> {
+        self.pipe_writer.shutdown().await?;
+        self.reader_task.abort();
+        tracing::debug!("Client closing gracefully");
+        Ok(())
+    }
+
+    /// Receives a message from the server
+    ///
+    /// # Cancel safety
+    ///
+    /// This method is cancel-safe, internally it calls `tokio::sync::mpsc::Receiver::recv`
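+    ///
+    /// A sketch of why that matters, assuming Tokio's `macros` feature for
+    /// `select!` (not enabled in this Cargo.toml) and a hypothetical `shutdown`
+    /// channel and `handle` function:
+    ///
+    /// ```rust,ignore
+    /// tokio::select! {
+    ///     // If `shutdown` wins the race, no partially-read message is lost;
+    ///     // it stays buffered in the internal channel for the next call.
+    ///     msg = client.next() => handle(msg?),
+    ///     _ = shutdown.recv() => return Ok(()),
+    /// }
+    /// ```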
+    pub async fn next(&mut self) -> Result<ManagerMsgInternal<M>, Error> {
+        let buf = self.read_rx.recv().await.ok_or_else(|| Error::Eof)?;
+        let buf = std::str::from_utf8(&buf)?;
+        let msg = serde_json::from_str(buf)?;
+        Ok(msg)
+    }
+
+    pub async fn send(&mut self, msg: W) -> Result<(), Error> {
+        self.send_internal(&WorkerMsgInternal::User(msg)).await
+    }
+
+    async fn send_internal(&mut self, msg: &WorkerMsgInternal<W>) -> Result<(), Error> {
+        write_serialize(&mut self.pipe_writer, msg).await
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100755
index 0000000..d14b0d7
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,247 @@
+//! Worker subprocesses with async IPC for Windows
+//!
+//! To run the unit tests and multi-process tests, use
+//! ```bash
+//! cargo test && cargo run && echo good
+//! ```
+//!
+//! # Security
+//!
+//! The IPC module uses Windows' named pipes primitive.
+//!
+//! These seem *relatively* secure. Chromium also uses them.
+//! Privileged applications with admin powers or kernel modules, like Wireshark,
+//! can snoop on named pipes, since they run with elevated privileges.
+//!
+//! Non-privileged processes can enumerate the names of named pipes. To prevent
+//! a process that isn't our child from connecting to our named pipe, I check the
+//! process ID before communicating, and then require the first message to be a cookie
+//! echoed to the child's stdin and back through the pipe, similar to a CSRF token.
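+//!
+//! A sketch of the worker's side of that handshake (`Client::new` does the
+//! cookie echo internally; the message types are user-defined, here borrowed
+//! from the test module):
+//!
+//! ```rust,ignore
+//! // The manager passes the pipe ID as an argument and the cookie on stdin.
+//! let mut client: Client<ManagerMsg, WorkerMsg> = Client::new(&pipe_id).await?;
+//! ```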
+//!
+//! Also by default, non-elevated processes cannot connect to named pipe servers
+//! inside elevated processes.
+//!
+//! # Design
+//!
+//! subzone has these features:
+//!
+//! - Kill unresponsive worker if needed
+//! - Graceful shutdown
+//! - Automatically kill workers even if the manager process crashes
+//! - Bails out if some other process tries to intercept IPC between the two processes
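+//!
+//! A minimal manager-side sketch of those pieces together (message types again
+//! borrowed from the test module; error handling elided):
+//!
+//! ```rust,ignore
+//! let mut leak_guard = LeakGuard::new()?;
+//! // Relaunches the current exe with these args plus the pipe ID, then checks
+//! // the child's PID and cookie before returning.
+//! let Subprocess { mut server, mut worker } =
+//!     Subprocess::new(&mut leak_guard, &["api-worker"]).await?;
+//! server.send(ManagerMsg::Connect).await?;
+//! let reply: WorkerMsg = server.next().await?;
+//! server.close().await?;
+//! worker.wait_then_kill(Duration::from_secs(5)).await?;
+//! ```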
+
+use anyhow::Result;
+use clap::Parser;
+use serde::{Deserialize, Serialize};
+use std::{fmt::Debug, marker::Unpin};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+mod client;
+mod server;
+// Always enabled, since the integration tests can't run in `cargo test` yet
+pub(crate) mod multi_process_tests;
+
+pub use client::Client;
+pub use server::{LeakGuard, Server, SubcommandChild, SubcommandExit, Subprocess};
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    /// Used to detect graceful named pipe closes
+    #[error("EOF")]
+    Eof,
+    /// Any IO error except EOF
+    #[error(transparent)]
+    Io(std::io::Error),
+    #[error(transparent)]
+    Json(#[from] serde_json::Error),
+    #[error("Something went wrong while converting message length to u32 or usize")]
+    MessageLength,
+    #[error("Protocol error, got Cookie or Shutdown at an incorrect time")]
+    Protocol,
+    #[error(transparent)]
+    Utf8(#[from] std::str::Utf8Error),
+}
+
+#[derive(Deserialize, Serialize)]
+pub enum ManagerMsgInternal<T> {
+    Shutdown,
+    User(T),
+}
+
+#[derive(Deserialize, Serialize)]
+pub enum WorkerMsgInternal<T> {
+    Cookie(String),
+    User(T),
+}
+
+impl From<std::io::Error> for Error {
+    fn from(e: std::io::Error) -> Self {
+        if e.kind() == std::io::ErrorKind::UnexpectedEof {
+            Self::Eof
+        } else {
+            Self::Io(e)
+        }
+    }
+}
+
+#[derive(Parser)]
+struct Cli {
+    #[command(subcommand)]
+    cmd: Option<multi_process_tests::Subcommand>,
+}
+
+/// Don't use. This is just for internal tests that are difficult to do with `cargo test`
+pub fn run_multi_process_tests() -> Result<()> {
+    let cli = Cli::parse();
+    multi_process_tests::run(cli.cmd)
+}
+
+/// Returns a random valid named pipe ID based on a UUIDv4
+///
+/// e.g. "\\.\pipe\subzone\9508e87c-1c92-4630-bb20-839325d169bd"
+///
+/// Normally you don't need to call this directly. Tests may need it to inject
+/// a known pipe ID into a process controlled by the test.
+pub(crate) fn random_pipe_id() -> String {
+    named_pipe_path(&uuid::Uuid::new_v4().to_string())
+}
+
+/// Returns a valid named pipe ID
+///
+/// e.g. "\\.\pipe\subzone\{path}"
+pub(crate) fn named_pipe_path(path: &str) -> String {
+    format!(r"\\.\pipe\subzone\{path}")
+}
+
+/// Reads a message from an async reader, with a 32-bit little-endian length prefix
+async fn read_deserialize<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Vec<u8>, Error> {
+    let mut len_buf = [0u8; 4];
+    reader.read_exact(&mut len_buf).await?;
+    let len = u32::from_le_bytes(len_buf);
+    tracing::trace!(?len, "reading message");
+    let len = usize::try_from(len).map_err(|_| Error::MessageLength)?;
+    let mut buf = vec![0u8; len];
+    reader.read_exact(&mut buf).await?;
+    Ok(buf)
+}
+
+/// Writes a message to an async writer, with a 32-bit little-endian length prefix
+async fn write_serialize<W: AsyncWrite + Unpin, T: Serialize>(
+    writer: &mut W,
+    msg: &T,
+) -> Result<(), Error> {
+    // Using JSON because `bincode` couldn't decode `ResourceDescription`
+    let buf = serde_json::to_string(msg)?;
+    let len = u32::try_from(buf.len())
+        .map_err(|_| Error::MessageLength)?
+        .to_le_bytes();
+    tracing::trace!(len = buf.len(), "writing message");
+    writer.write_all(&len).await?;
+    writer.write_all(buf.as_bytes()).await?;
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::server::UnconnectedServer;
+    use super::*;
+    use crate::multi_process_tests::{Callback, ManagerMsg, WorkerMsg};
+    use anyhow::Context;
+    use std::time::{Duration, Instant};
+    use tokio::runtime::Runtime;
+
+    /// Because it turns out `bincode` can't deserialize `ResourceDescription` or something.
+    #[test]
+    fn round_trip_serde() -> Result<()> {
+        let cb: WorkerMsg = WorkerMsg::Callback(Callback::OnUpdateResources(sample_resources()));
+
+        let v = serde_json::to_string(&cb)?;
+        let roundtripped: WorkerMsg = serde_json::from_str(&v)?;
+
+        assert_eq!(roundtripped, cb);
+
+        Ok(())
+    }
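+
+    /// Sanity-checks the 32-bit little-endian length-prefixed framing by
+    /// round-tripping one message through an in-memory duplex stream, which
+    /// stands in for the named pipe here.
+    #[test]
+    fn round_trip_framing() -> Result<()> {
+        let rt = Runtime::new()?;
+        rt.block_on(async move {
+            let (mut write_end, mut read_end) = tokio::io::duplex(1024);
+            let msg = WorkerMsg::Callback(Callback::TunnelReady);
+            write_serialize(&mut write_end, &msg).await?;
+            let buf = read_deserialize(&mut read_end).await?;
+            let roundtripped: WorkerMsg = serde_json::from_str(std::str::from_utf8(&buf)?)?;
+            assert_eq!(roundtripped, msg);
+            Ok::<_, anyhow::Error>(())
+        })?;
+        Ok(())
+    }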
+
+    /// Test just the happy path
+    /// It's hard to simulate a process crash because:
+    /// - If I Drop anything, Tokio will clean it up
+    /// - If I `std::mem::forget` anything, the test process is still running, so Windows will not clean it up
+    #[test]
+    #[tracing::instrument(skip_all)]
+    fn happy_path() -> Result<()> {
+        tracing_subscriber::fmt::try_init().ok();
+
+        let rt = Runtime::new()?;
+        rt.block_on(async move {
+            // Pretend we're in the main process
+            let (server, server_id) = UnconnectedServer::new()?;
+
+            let worker_task = tokio::spawn(async move {
+                // Pretend we're in a worker process
+                let mut client: Client<ManagerMsg, WorkerMsg> = Client::new_unsecured(&server_id)?;
+
+                client
+                    .send(WorkerMsg::Callback(Callback::OnUpdateResources(
+                        sample_resources(),
+                    )))
+                    .await?;
+
+                // Handle requests from the main process
+                loop {
+                    let Ok(ManagerMsgInternal::User(req)) = client.next().await else {
+                        tracing::debug!("shutting down worker_task");
+                        break;
+                    };
+                    tracing::debug!(?req, "worker_task got request");
+                    let resp = WorkerMsg::Response(req.clone());
+                    client.send(resp).await?;
+                }
+                client.close().await?;
+                Ok::<_, anyhow::Error>(())
+            });
+
+            let mut server: Server<ManagerMsg, WorkerMsg> = server.accept().await?;
+
+            let start_time = Instant::now();
+
+            let cb = server
+                .next()
+                .await
+                .context("should have gotten a OnUpdateResources callback")?;
+            assert_eq!(
+                cb,
+                WorkerMsg::Callback(Callback::OnUpdateResources(sample_resources()))
+            );
+
+            server.send(ManagerMsg::Connect).await?;
+            assert_eq!(
+                server.next().await.unwrap(),
+                WorkerMsg::Response(ManagerMsg::Connect)
+            );
+            server.send(ManagerMsg::Connect).await?;
+            assert_eq!(
+                server.next().await.unwrap(),
+                WorkerMsg::Response(ManagerMsg::Connect)
+            );
+
+            let elapsed = start_time.elapsed();
+            assert!(elapsed < Duration::from_millis(20), "{:?}", elapsed);
+
+            server.close().await?;
+
+            // Make sure the worker 'process' exited
+            worker_task.await??;
+
+            Ok::<_, anyhow::Error>(())
+        })?;
+        Ok(())
+    }
+
+    fn sample_resources() -> Vec<String> {
+        vec![
+            "2efe9c25-bd92-49a0-99d7-8b92da014dd5".into(),
+            "613eaf56-6efa-45e5-88aa-ea4ad64d8c18".into(),
+        ]
+    }
+}
diff --git a/src/main.rs b/src/main.rs
new file mode 100755
index 0000000..f9b228b
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,6 @@
+//! Test driver for subzone's multi-process tests, which are difficult to run
+//! inside Cargo's test harness.
+
+fn main() -> anyhow::Result<()> {
+    subzone::run_multi_process_tests()
+}
diff --git a/src/multi_process_tests.rs b/src/multi_process_tests.rs
new file mode 100755
index 0000000..221bfd5
--- /dev/null
+++ b/src/multi_process_tests.rs
@@ -0,0 +1,292 @@
+//! Integration and unit tests for IPC security, leak guard, etc.
+
+// TODO: Try making these into no-harness integration tests, if the IPC module
+// ends up living long enough. See
+
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use std::time::{Duration, Instant};
+use tokio::time::timeout;
+
+use crate::{
+    server::UnconnectedServer, Client, LeakGuard, ManagerMsgInternal, Server, SubcommandChild,
+    SubcommandExit, Subprocess,
+};
+
+#[derive(clap::Subcommand)]
+pub(crate) enum Subcommand {
+    LeakManager {
+        #[arg(long, action = clap::ArgAction::Set)]
+        enable_protection: bool,
+        pipe_id: String,
+    },
+    LeakWorker {
+        pipe_id: String,
+    },
+
+    ApiWorker {
+        pipe_id: String,
+    },
+}
+
+pub(crate) fn run(cmd: Option<Subcommand>) -> Result<()> {
+    tracing_subscriber::fmt::init();
+    let rt = tokio::runtime::Runtime::new()?;
+    rt.block_on(async move {
+        match cmd {
+            None => {
+                test_api().await.context("test_api failed")?;
+                tracing::info!("test_api passed");
+                test_leak(false).await.context("test_leak(false) failed")?;
+                test_leak(true).await.context("test_leak(true) failed")?;
+                tracing::info!("test_leak passed");
+                tracing::info!("all tests passed");
+                Ok(())
+            }
+            Some(Subcommand::LeakManager {
+                enable_protection,
+                pipe_id,
+            }) => leak_manager(pipe_id, enable_protection),
+            Some(Subcommand::LeakWorker { pipe_id }) => leak_worker(pipe_id).await,
+            Some(Subcommand::ApiWorker { pipe_id }) => test_api_worker(pipe_id).await,
+        }
+    })?;
+    Ok(())
+}
+
+/// A message from the manager process
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+pub(crate) enum ManagerMsg {
+    Connect,
+}
+
+/// A message from the worker process
+#[derive(Debug, Deserialize, PartialEq, Serialize)]
+pub(crate) enum WorkerMsg {
+    Callback(Callback),
+    Response(ManagerMsg), // For debugging, just say what manager request we're responding to
+}
+
+#[derive(Debug, Deserialize, PartialEq, Serialize)]
+pub(crate) enum Callback {
+    /// Cookie for named pipe security
+    Cookie(String),
+    DisconnectedTokenExpired,
+    /// Connlib disconnected and we should gracefully join the worker process
+    OnDisconnect,
+    OnUpdateResources(Vec<String>),
+    TunnelReady,
+}
+
+#[tracing::instrument(skip_all)]
+async fn test_api() -> Result<()> {
+    let start_time = Instant::now();
+
+    let mut leak_guard = LeakGuard::new()?;
+    let args = ["api-worker"];
+    let Subprocess {
+        mut server,
+        mut worker,
+    } = timeout(
+        Duration::from_secs(10),
+        Subprocess::new(&mut leak_guard, &args),
+    )
+    .await??;
+    tracing::debug!("Manager got connection from worker");
+
+    let msg = server
+        .next()
+        .await
+        .context("should have gotten a TunnelReady callback")?;
+    assert_eq!(msg, WorkerMsg::Callback(Callback::TunnelReady));
+
+    let msg = server
+        .next()
+        .await
+        .context("should have gotten a OnUpdateResources callback")?;
+    assert_eq!(
+        msg,
+        WorkerMsg::Callback(Callback::OnUpdateResources(sample_resources()))
+    );
+
+    server.send(ManagerMsg::Connect).await?;
+    let msg: WorkerMsg = server
+        .next()
+        .await
+        .context("should have gotten a response to Connect")?;
+    anyhow::ensure!(msg == WorkerMsg::Response(ManagerMsg::Connect));
+
+    let elapsed = start_time.elapsed();
+    anyhow::ensure!(
+        elapsed < Duration::from_millis(100),
+        "IPC took too long: {elapsed:?}"
+    );
+
+    let timer = Instant::now();
+    server.close().await?;
+    let elapsed = timer.elapsed();
+    anyhow::ensure!(
+        elapsed < Duration::from_millis(20),
+        "Server took too long to close: {elapsed:?}"
+    );
+
+    assert_eq!(
+        worker.wait_then_kill(Duration::from_secs(5)).await?,
+        SubcommandExit::Success
+    );
+
+    Ok(())
+}
+
+#[tracing::instrument(skip_all)]
+async fn test_api_worker(pipe_id: String) -> Result<()> {
+    let mut client = Client::new(&pipe_id).await?;
+
+    client
+        .send(WorkerMsg::Callback(Callback::TunnelReady))
+        .await?;
+
+    client
+        .send(WorkerMsg::Callback(Callback::OnUpdateResources(
+            sample_resources(),
+        )))
+        .await?;
+
+    tracing::trace!("Worker connected to named pipe");
+    loop {
+        let ManagerMsgInternal::User(req) = client.next().await? else {
+            break;
+        };
+        client.send(WorkerMsg::Response(req)).await?;
+    }
+
+    let timer = Instant::now();
+    client.close().await?;
+    let elapsed = timer.elapsed();
+    anyhow::ensure!(
+        elapsed < Duration::from_millis(5),
+        "Client took too long to close: {elapsed:?}"
+    );
+    Ok(())
+}
+
+/// Top-level function to test whether the process leak protection works.
+///
+/// 1. Open a named pipe server
+/// 2. Spawn a manager process, passing the pipe name to it
+/// 3. The manager process spawns a worker process, passing the pipe name to it
+/// 4. The manager process sets up leak protection on the worker process
+/// 5. The worker process connects to our pipe server to confirm that it's up
+/// 6. We SIGKILL the manager process
+/// 7. Reading from the named pipe server should return an EOF since the worker process was killed by leak protection.
+///
+/// # Research
+/// - [Stack Overflow example](https://stackoverflow.com/questions/53208/how-do-i-automatically-destroy-child-processes-in-windows)
+/// - [Chromium example](https://source.chromium.org/chromium/chromium/src/+/main:base/process/launch_win.cc;l=421;drc=b7d560c40ceb5283dba3e3d305abd9e2e7e926cd)
+/// - [MSDN docs](https://learn.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-assignprocesstojobobject)
+/// - [windows-rs docs](https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/System/JobObjects/fn.AssignProcessToJobObject.html)
+#[tracing::instrument]
+async fn test_leak(enable_protection: bool) -> Result<()> {
+    let (server, pipe_id) = UnconnectedServer::new()?;
+    let args = [
+        "leak-manager",
+        "--enable-protection",
+        &enable_protection.to_string(),
+        &pipe_id,
+    ];
+    let mut manager = SubcommandChild::new(&args)?;
+    let mut server: Server<ManagerMsg, WorkerMsg> =
+        timeout(Duration::from_secs(5), server.accept()).await??;
+
+    tracing::debug!("Actual pipe client PID = {}", server.client_pid());
+    tracing::debug!("Harness accepted connection from Worker");
+
+    // Send a few requests to make sure the worker is connected and good
+    for _ in 0..3 {
+        server.send(ManagerMsg::Connect).await?;
+        server
+            .next()
+            .await
+            .expect("should have gotten a response to Connect");
+    }
+
+    timeout(Duration::from_secs(5), manager.process.kill()).await??;
+    tracing::debug!("Harness killed manager");
+
+    // I can't think of a good way to synchronize with the worker process stopping,
+    // so just give it about 5 seconds for Windows to stop it.
+    for _ in 0..5 {
+        if server.send(ManagerMsg::Connect).await.is_err() {
+            tracing::info!("confirmed worker stopped responding");
+            break;
+        }
+        if server.next().await.is_err() {
+            tracing::info!("confirmed worker stopped responding");
+            break;
+        }
+        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+    }
+
+    if enable_protection {
+        assert!(
+            server.send(ManagerMsg::Connect).await.is_err(),
+            "worker shouldn't be able to respond here, it should have stopped when the manager stopped"
+        );
+        assert!(
+            server.next().await.is_err(),
+            "worker shouldn't be able to respond here, it should have stopped when the manager stopped"
+        );
+        tracing::info!("enabling leak protection worked");
+    } else {
+        assert!(
+            server.send(ManagerMsg::Connect).await.is_ok(),
+            "worker should still respond here, this failure means the test is invalid"
+        );
+        assert!(
+            server.next().await.is_ok(),
+            "worker should still respond here, this failure means the test is invalid"
+        );
+        tracing::info!("not enabling leak protection worked");
+    }
+    Ok(())
+}
+
+#[tracing::instrument]
+fn leak_manager(pipe_id: String, enable_protection: bool) -> Result<()> {
+    let mut leak_guard = LeakGuard::new()?;
+
+    let worker = SubcommandChild::new(&["leak-worker", &pipe_id])?;
+    tracing::debug!("Expected worker PID = {}", worker.process.id().unwrap());
+
+    if enable_protection {
+        leak_guard.add_process(&worker.process)?;
+    }
+
+    tracing::debug!("Manager set up leak protection, waiting for SIGKILL");
+    loop {
+        std::thread::park();
+    }
+}
+
+#[tracing::instrument(skip_all)]
+async fn leak_worker(pipe_id: String) -> Result<()> {
+    let mut client = Client::new_unsecured(&pipe_id)?;
+    tracing::debug!("Worker connected to named pipe");
+    loop {
+        let ManagerMsgInternal::User(req) = client.next().await? else {
+            break;
+        };
+        let resp = WorkerMsg::Response(req);
+        client.send(resp).await?;
+    }
+    client.close().await?;
+    Ok(())
+}
+
+// Duplicated because I want this to be private in both test modules
+fn sample_resources() -> Vec<String> {
+    vec![
+        "2efe9c25-bd92-49a0-99d7-8b92da014dd5".into(),
+        "613eaf56-6efa-45e5-88aa-ea4ad64d8c18".into(),
+    ]
+}
diff --git a/src/server.rs b/src/server.rs
new file mode 100755
index 0000000..1ac03c9
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,363 @@
+use anyhow::{bail, Context, Result};
+use serde::{de::DeserializeOwned, Serialize};
+use std::{
+    ffi::c_void,
+    marker::PhantomData,
+    os::windows::io::{AsHandle, AsRawHandle},
+    process::Stdio,
+    time::Duration,
+};
+use tokio::{
+    io::{AsyncWriteExt, WriteHalf},
+    net::windows::named_pipe::{self, NamedPipeServer},
+    process::{self, Child},
+    sync::mpsc,
+    time::timeout,
+};
+use windows::Win32::{
+    Foundation::HANDLE,
+    System::JobObjects::{
+        AssignProcessToJobObject, CreateJobObjectA, JobObjectExtendedLimitInformation,
+        SetInformationJobObject, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
+        JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
+    },
+    System::Pipes::GetNamedPipeClientProcessId,
+};
+
+use crate::{read_deserialize, write_serialize, Error, ManagerMsgInternal, WorkerMsgInternal};
+
+/// A named pipe server linked to a worker subprocess
pub struct Subprocess<M, W> {
+    pub server: Server<M, W>,
+    pub worker: SubcommandChild,
+}
+
+impl<M: Serialize, W: DeserializeOwned> Subprocess<M, W> {
+    /// Returns a linked named pipe server and worker subprocess
+    ///
+    /// The process ID and cookie have already been checked for security
+    /// when this function returns.
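+    ///
+    /// Specifically, `new` (1) spawns the child with a piped stdin, (2) accepts
+    /// the pipe connection, (3) bails out unless the pipe client's PID matches
+    /// the child's PID, and (4) writes a random UUID cookie to the child's
+    /// stdin and bails out unless the same cookie comes back over the pipe.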
+    pub async fn new(leak_guard: &mut LeakGuard, args: &[&str]) -> Result<Self> {
+        let (mut server, pipe_id) =
+            UnconnectedServer::new().context("couldn't create UnconnectedServer")?;
+        let mut process = process::Command::new(
+            std::env::current_exe().context("couldn't get current exe name")?,
+        );
+        // Make the child's stdin piped so we can send it a security cookie.
+        process.stdin(Stdio::piped());
+        for arg in args {
+            process.arg(arg);
+        }
+        process.arg(&pipe_id);
+        let mut process = process.spawn().context("couldn't spawn subprocess")?;
+        if let Err(error) = leak_guard.add_process(&process) {
+            tracing::error!("couldn't add subprocess to leak guard, attempting to kill subprocess");
+            process.kill().await.ok();
+            return Err(error.context("couldn't add subprocess to leak guard"));
+        }
+        let child_pid = process
+            .id()
+            .ok_or_else(|| anyhow::anyhow!("child process should have an ID"))?;
+        let mut worker = SubcommandChild { process };
+
+        // Accept the connection
+        server
+            .pipe
+            .connect()
+            .await
+            .context("expected a client connection")?;
+        let client_pid = server.client_pid()?;
+
+        // Make sure our child process connected to our pipe, and not some 3rd-party process
+        if child_pid != client_pid {
+            bail!("PID of child process and pipe client should match");
+        }
+
+        // Make sure the process on the other end of the pipe knows the cookie we sent
+        // to our child process' stdin
+        let mut child_stdin = worker
+            .process
+            .stdin
+            .take()
+            .ok_or_else(|| anyhow::anyhow!("couldn't get stdin of subprocess"))?;
+        let cookie = uuid::Uuid::new_v4().to_string();
+        let line = format!("{}\n", cookie);
+        tracing::trace!(?cookie, "Sending cookie");
+        child_stdin
+            .write_all(line.as_bytes())
+            .await
+            .context("couldn't write cookie to subprocess stdin")?;
+
+        let buf = read_deserialize(&mut server.pipe).await?;
+        let buf = std::str::from_utf8(&buf)?;
+        let WorkerMsgInternal::<W>::Cookie(echoed_cookie) = serde_json::from_str(buf)? else {
+            bail!("didn't receive cookie from pipe client");
+        };
+        tracing::trace!(?echoed_cookie, "Got cookie back");
+        if echoed_cookie != cookie {
+            bail!("cookie received from pipe client should match the cookie we sent to our child process");
+        }
+
+        let server = Server::new(server.pipe)?;
+
+        Ok(Self { server, worker })
+    }
+}
+
+/// A server that accepts only one client
+pub(crate) struct UnconnectedServer {
+    pub(crate) pipe: named_pipe::NamedPipeServer,
+}
+
+impl UnconnectedServer {
+    /// Requires a Tokio context
+    pub(crate) fn new() -> Result<(Self, String)> {
+        let id = super::random_pipe_id();
+        let this = Self::new_with_id(&id)?;
+        Ok((this, id))
+    }
+
+    fn client_pid(&self) -> Result<u32> {
+        get_client_pid(&self.pipe)
+    }
+
+    fn new_with_id(id: &str) -> Result<Self> {
+        let pipe = named_pipe::ServerOptions::new()
+            .first_pipe_instance(true)
+            .create(id)?;
+
+        Ok(Self { pipe })
+    }
+
+    /// Accept an incoming connection
+    ///
+    /// This will wait forever if the client never shows up.
+    /// Try pairing it with `tokio::time::timeout`
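+    ///
+    /// For example (a sketch; this mirrors how the leak test below bounds the wait):
+    ///
+    /// ```rust,ignore
+    /// let (server, pipe_id) = UnconnectedServer::new()?;
+    /// // ...spawn a worker process and pass it `pipe_id` here...
+    /// let mut server: Server<ManagerMsg, WorkerMsg> =
+    ///     timeout(Duration::from_secs(5), server.accept()).await??;
+    /// ```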
+    pub(crate) async fn accept<M: Serialize, W: DeserializeOwned>(self) -> Result<Server<M, W>> {
+        self.pipe.connect().await?;
+        Server::new(self.pipe)
+    }
+}
+
+/// A server that's connected to a client
+///
+/// Manual testing shows that if the corresponding Client's process crashes, Windows will
+/// be nice and return errors for anything trying to read from the Server
+pub struct Server<M, W> {
+    client_pid: u32,
+    pipe_writer: WriteHalf<NamedPipeServer>,
+    /// Needed to make `next` cancel-safe
+    read_rx: mpsc::Receiver<Vec<u8>>,
+    /// Needed to make `next` cancel-safe
+    _reader_task: tokio::task::JoinHandle<Result<()>>,
+    _manager_msg: PhantomData<M>,
+    _worker_msg: PhantomData<W>,
+}
+
+impl<M: Serialize, W: DeserializeOwned> Server<M, W> {
+    #[tracing::instrument(skip_all)]
+    fn new(pipe: named_pipe::NamedPipeServer) -> Result<Self> {
+        let client_pid = get_client_pid(&pipe)?;
+        let (mut pipe_reader, pipe_writer) = tokio::io::split(pipe);
+        let (read_tx, read_rx) = mpsc::channel(1);
+        let _reader_task = tokio::spawn(async move {
+            loop {
+                let msg = read_deserialize(&mut pipe_reader).await?;
+                read_tx.send(msg).await?;
+            }
+        });
+
+        Ok(Self {
+            client_pid,
+            pipe_writer,
+            read_rx,
+            _reader_task,
+            _manager_msg: Default::default(),
+            _worker_msg: Default::default(),
+        })
+    }
+
+    /// Tells the pipe client to shutdown.
+    ///
+    /// Should be wrapped in a Tokio timeout in case the pipe client isn't responding.
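+    ///
+    /// The worker's side of this graceful shutdown is a read loop that exits on
+    /// the first non-`User` message, as the test workers do:
+    ///
+    /// ```rust,ignore
+    /// loop {
+    ///     let ManagerMsgInternal::User(req) = client.next().await? else {
+    ///         break; // Got `Shutdown`
+    ///     };
+    ///     client.send(WorkerMsg::Response(req)).await?;
+    /// }
+    /// client.close().await?;
+    /// ```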
`["debug", "test", "ipc-worker"]` + pub fn new(args: &[&str]) -> Result { + // Need this binding to avoid a "temporary freed while still in use" error + let mut process = process::Command::new(std::env::current_exe()?); + process + // Make stdin a pipe so we can send the child a security cookie + .stdin(Stdio::piped()) + // Best-effort attempt to kill the child when this handle drops + // The Tokio docs say this is hard and we should just try to clean up + // before dropping + .kill_on_drop(true); + for arg in args { + process.arg(arg); + } + let process = process.spawn()?; + Ok(SubcommandChild { process }) + } + + /// Joins the subprocess without blocking, returning an error if the process doesn't stop + #[tracing::instrument(skip(self))] + pub(crate) fn wait_or_kill(&mut self) -> Result { + if let Ok(Some(status)) = self.process.try_wait() { + if status.success() { + Ok(SubcommandExit::Success) + } else { + Ok(SubcommandExit::Failure) + } + } else { + self.process.start_kill()?; + Ok(SubcommandExit::Killed) + } + } + + /// Waits `dur` for process to exit gracefully, and then `dur` to kill process if needed + pub async fn wait_then_kill(&mut self, dur: Duration) -> Result { + if let Ok(status) = timeout(dur, self.process.wait()).await { + return if status?.success() { + Ok(SubcommandExit::Success) + } else { + Ok(SubcommandExit::Failure) + }; + } + + timeout(dur, self.process.kill()).await??; + Ok(SubcommandExit::Killed) + } +} + +impl Drop for SubcommandChild { + fn drop(&mut self) { + match self.wait_or_kill() { + Ok(SubcommandExit::Killed) => tracing::error!("SubcommandChild was killed inside Drop"), + // Don't care - might have already been handled before Drop + Ok(_) => {} + Err(error) => tracing::error!(?error, "SubcommandChild could not be joined or killed"), + } + } +} + +/// Uses a Windows job object to kill child processes when the parent exits +/// +/// This contains a Windows handle that always leaks. Try to create one LeakGuard +/// and use it throughout your whole main process. +pub struct LeakGuard { + // Technically this job object handle does leak + job_object: HANDLE, +} + +impl LeakGuard { + pub fn new() -> Result { + // SAFETY: TODO + let job_object = unsafe { CreateJobObjectA(None, None) }?; + + let mut jeli = JOBOBJECT_EXTENDED_LIMIT_INFORMATION::default(); + jeli.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + // SAFETY: Windows shouldn't hang on to `jeli`. I'm not sure why this is unsafe. + unsafe { + SetInformationJobObject( + job_object, + JobObjectExtendedLimitInformation, + &jeli as *const JOBOBJECT_EXTENDED_LIMIT_INFORMATION as *const c_void, + u32::try_from(std::mem::size_of_val(&jeli))?, + ) + }?; + + Ok(Self { job_object }) + } + + /// Registers a child process with the LeakGuard so that Windows will kill the child if the manager exits or crashes + pub fn add_process(&mut self, process: &Child) -> Result<()> { + // Process IDs are not the same as handles, so get our handle to the process. + let process_handle = process + .raw_handle() + .ok_or_else(|| anyhow::anyhow!("Child should have a handle"))?; + // SAFETY: The docs say this is UB since the null pointer doesn't belong to the same allocated object as the handle. + // I couldn't get `OpenProcess` to work, and I don't have any other way to convert the process ID to a handle safely. + // Since the handles aren't pointers per se, maybe it'll work? 
+pub struct LeakGuard {
+    // Technically this job object handle does leak
+    job_object: HANDLE,
+}
+
+impl LeakGuard {
+    pub fn new() -> Result<Self> {
+        // SAFETY: TODO
+        let job_object = unsafe { CreateJobObjectA(None, None) }?;
+
+        let mut jeli = JOBOBJECT_EXTENDED_LIMIT_INFORMATION::default();
+        jeli.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+        // SAFETY: Windows shouldn't hang on to `jeli`. I'm not sure why this is unsafe.
+        unsafe {
+            SetInformationJobObject(
+                job_object,
+                JobObjectExtendedLimitInformation,
+                &jeli as *const JOBOBJECT_EXTENDED_LIMIT_INFORMATION as *const c_void,
+                u32::try_from(std::mem::size_of_val(&jeli))?,
+            )
+        }?;
+
+        Ok(Self { job_object })
+    }
+
+    /// Registers a child process with the LeakGuard so that Windows will kill the child if the manager exits or crashes
+    pub fn add_process(&mut self, process: &Child) -> Result<()> {
+        // Process IDs are not the same as handles, so get our handle to the process.
+        let process_handle = process
+            .raw_handle()
+            .ok_or_else(|| anyhow::anyhow!("Child should have a handle"))?;
+        // The raw handle is an opaque pointer-sized value, not a real pointer,
+        // so cast it to `isize` rather than doing pointer arithmetic on it.
+        let process_handle = HANDLE(process_handle as isize);
+        // SAFETY: TODO
+        unsafe { AssignProcessToJobObject(self.job_object, process_handle) }
+            .context("AssignProcessToJobObject")?;
+        Ok(())
+    }
+}