diff --git a/src/killport.rs b/src/killport.rs index b630be2..97f0857 100644 --- a/src/killport.rs +++ b/src/killport.rs @@ -11,7 +11,9 @@ use std::{fmt::Display, io::Error}; /// Interface for killable targets such as native process and docker container. pub trait Killable { fn kill(&self, signal: KillportSignal) -> Result; + fn get_type(&self) -> KillableType; + fn get_name(&self) -> String; } @@ -102,6 +104,7 @@ impl KillportOperations for Killport { if docker_present && process.get_name().to_lowercase().contains("docker") { continue; } + target_killables.push(Box::new(process)); } } diff --git a/src/windows.rs b/src/windows.rs index 05f6745..19070b6 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -2,7 +2,7 @@ use crate::killport::{Killable, KillableType}; use log::info; use std::{ alloc::{alloc, dealloc, Layout}, - collections::HashSet, + collections::{HashMap, HashSet}, ffi::c_void, io::{Error, ErrorKind, Result}, ptr::addr_of, @@ -33,12 +33,17 @@ use windows_sys::Win32::{ #[derive(Debug)] pub struct WindowsProcess { pid: u32, - name: Option, + name: String, + parent: Option>, } impl WindowsProcess { - pub fn new(pid: u32, name: Option) -> Self { - Self { pid, name } + pub fn new(pid: u32, name: String) -> Self { + Self { + pid, + name, + parent: None, + } } } @@ -50,6 +55,7 @@ impl WindowsProcess { /// /// * `port` - Target port number pub fn find_target_processes(port: u16) -> Result> { + let lookup_table: ProcessLookupTable = ProcessLookupTable::create()?; let mut pids: HashSet = HashSet::new(); let processes = unsafe { @@ -65,16 +71,22 @@ pub fn find_target_processes(port: u16) -> Result> { // Find processes in the UDP IPv6 table use_extended_table::(port, &mut pids)?; - // Collect parents of the PIDs - collect_parents(&mut pids)?; + let mut processes: Vec = Vec::with_capacity(pids.len()); + + for pid in pids { + let process_name = lookup_table + .process_names + .get(&pid) + .cloned() + .unwrap_or_else(|| "Unknown".to_string()); - // Collect the processes - let mut processes: Vec = pids - .into_iter() - .map(|pid| WindowsProcess::new(pid, None)) - .collect(); + let mut process = WindowsProcess::new(pid, process_name); - lookup_proccess_names(&mut processes)?; + // Resolve the process parents + lookup_process_parents(&lookup_table, &mut process)?; + + processes.push(process); + } processes }; @@ -84,22 +96,18 @@ pub fn find_target_processes(port: u16) -> Result> { impl Killable for WindowsProcess { fn kill(&self, _signal: crate::signal::KillportSignal) -> Result { - let mut pids: HashSet = HashSet::new(); - pids.insert(self.pid); + let mut killed = false; + let mut next = Some(self); + while let Some(current) = next { + unsafe { + kill_process(current)?; + } - if pids.is_empty() { - return Ok(false); + killed = true; + next = current.parent.as_ref().map(|value| value.as_ref()); } - unsafe { - collect_parents(&mut pids)?; - - for pid in pids { - kill_process(pid)?; - } - }; - - Ok(true) + Ok(killed) } fn get_type(&self) -> KillableType { @@ -107,133 +115,222 @@ impl Killable for WindowsProcess { } fn get_name(&self) -> String { - match self.name.as_ref() { - Some(value) => value.to_string(), - None => "Unknown".to_string(), - } + self.name.to_string() } } -/// Collects the names for the processes in the provided collection of -/// processes. If name resolving fails that process is just "Unknown" +/// Checks if there is a running process with the provided pid /// /// # Arguments /// -/// * `processes` - The set of processes to resolve the names of -unsafe fn lookup_proccess_names(processes: &mut [WindowsProcess]) -> Result<()> { - // Request a snapshot handle - let handle: HANDLE = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); +/// * `pid` - The process ID to search for +fn is_process_running(pid: u32) -> Result { + let mut snapshot = WindowsProcessesSnapshot::create()?; + let is_running = snapshot.any(|entry| entry.th32ProcessID == pid); + Ok(is_running) +} - // Ensure we got a valid handle - if handle == INVALID_HANDLE_VALUE { - let error: WIN32_ERROR = GetLastError(); - return Err(Error::new( - ErrorKind::Other, - format!("Failed to get handle to processes: {:#x}", error), - )); - } +/// Lookup table for finding the names and parents for +/// a process using its pid +pub struct ProcessLookupTable { + /// Mapping from pid to name + process_names: HashMap, + /// Mapping from pid to parent pid + process_parents: HashMap, +} - // Allocate the memory to use for the entries - let mut entry: PROCESSENTRY32 = std::mem::zeroed(); - entry.dwSize = std::mem::size_of::() as u32; - - // Process the first item - if Process32First(handle, &mut entry) != FALSE { - loop { - let target_process = processes - .iter_mut() - .find(|proc| proc.pid == entry.th32ProcessID); - if let Some(target_process) = target_process { - let name_chars = entry - .szExeFile - .iter() - .copied() - .take_while(|value| *value != 0) - .collect(); - - let name = String::from_utf8(name_chars); - if let Ok(name) = name { - target_process.name = Some(name) - } - } +impl ProcessLookupTable { + pub fn create() -> Result { + let mut process_names: HashMap = HashMap::new(); + let mut process_parents: HashMap = HashMap::new(); - // Process the next entry - if Process32Next(handle, &mut entry) == FALSE { - break; - } - } + WindowsProcessesSnapshot::create()?.for_each(|entry| { + process_names.insert(entry.th32ProcessID, get_process_entry_name(&entry)); + process_parents.insert(entry.th32ProcessID, entry.th32ParentProcessID); + }); + + Ok(Self { + process_names, + process_parents, + }) } +} - // Close the handle now that its no longer needed - CloseHandle(handle); +/// Finds any parent processes of the provided process, adding +/// the process to the list of parents +/// +/// WARNING - This worked in the previous versions because the implementation +/// was flawwed and didn't properly look up the tree of parents, trying to kill +/// all of the parents causes problems since you'll end up killing explorer.exe +/// or some other windows sys process. So I've limited the depth to a single process deep +/// +/// # Arguments +/// +/// * `process` - The process to collect parents for +fn lookup_process_parents( + lookup_table: &ProcessLookupTable, + process: &mut WindowsProcess, +) -> Result<()> { + const MAX_PARENT_DEPTH: u8 = 1; + + let mut current_procces = process; + let mut depth = 0; + + while let Some(&parent_pid) = lookup_table.process_parents.get(¤t_procces.pid) { + if depth == MAX_PARENT_DEPTH { + break; + } + + let process_name = lookup_table + .process_names + .get(&parent_pid) + .cloned() + .unwrap_or_else(|| "Unknown".to_string()); + + // Add the new parent process + let parent = current_procces + .parent + .insert(Box::new(WindowsProcess::new(parent_pid, process_name))); + + current_procces = parent; + depth += 1 + } Ok(()) } -/// Collects all the parent processes for the PIDs in -/// the provided set +/// Parses the name from a process entry, falls back to "Unknown" +/// for invalid names /// /// # Arguments /// -/// * `pids` - The set to match PIDs from and insert PIDs into -unsafe fn collect_parents(pids: &mut HashSet) -> Result<()> { - // Request a snapshot handle - let handle: HANDLE = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); +/// * `entry` - The process entry +fn get_process_entry_name(entry: &PROCESSENTRY32) -> String { + let name_chars = entry + .szExeFile + .iter() + .copied() + .take_while(|value| *value != 0) + .collect(); + + let name = String::from_utf8(name_chars); + name.unwrap_or_else(|_| "Unknown".to_string()) +} - // Ensure we got a valid handle - if handle == INVALID_HANDLE_VALUE { - let error: WIN32_ERROR = GetLastError(); - return Err(Error::new( - ErrorKind::Other, - format!("Failed to get handle to processes: {:#x}", error), - )); +/// Snapshot of the running windows processes that can be iterated to find +/// information about various processes such as parent processes and +/// process names +/// +/// This is a safe abstraction +pub struct WindowsProcessesSnapshot { + /// Handle to the snapshot + handle: HANDLE, + /// The memory for reading process entries + entry: PROCESSENTRY32, + /// State of reading + state: SnapshotState, +} + +/// State for the snapshot iterator +pub enum SnapshotState { + /// Can read the first entry + First, + /// Can read the next entry + Next, + /// Reached the end, cannot iterate further always give [None] + End, +} + +impl WindowsProcessesSnapshot { + /// Creates a new process snapshot to iterate + pub fn create() -> Result { + // Request a snapshot handle + let handle: HANDLE = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) }; + + // Ensure we got a valid handle + if handle == INVALID_HANDLE_VALUE { + let error: WIN32_ERROR = unsafe { GetLastError() }; + return Err(Error::new( + ErrorKind::Other, + format!("Failed to get handle to processes: {:#x}", error), + )); + } + + // Allocate the memory to use for the entries + let mut entry: PROCESSENTRY32 = unsafe { std::mem::zeroed() }; + entry.dwSize = std::mem::size_of::() as u32; + + Ok(Self { + handle, + entry, + state: SnapshotState::First, + }) } +} - // Allocate the memory to use for the entries - let mut entry: PROCESSENTRY32 = std::mem::zeroed(); - entry.dwSize = std::mem::size_of::() as u32; +impl Iterator for WindowsProcessesSnapshot { + type Item = PROCESSENTRY32; - // Process the first item - if Process32First(handle, &mut entry) != FALSE { - let mut count = 0; + fn next(&mut self) -> Option { + match self.state { + SnapshotState::First => { + // Process the first entry + if unsafe { Process32First(self.handle, &mut self.entry) } == FALSE { + self.state = SnapshotState::End; + return None; + } + self.state = SnapshotState::Next; - loop { - // Add matching processes to the output - if pids.contains(&entry.th32ProcessID) { - pids.insert(entry.th32ParentProcessID); - count += 1; + Some(self.entry) } + SnapshotState::Next => { + // Process the next entry + if unsafe { Process32Next(self.handle, &mut self.entry) } == FALSE { + self.state = SnapshotState::End; + return None; + } - // Process the next entry - if Process32Next(handle, &mut entry) == FALSE { - break; + Some(self.entry) } + SnapshotState::End => None, } - - info!("Collected {} parent processes", count); } +} - // Close the handle now that its no longer needed - CloseHandle(handle); - - Ok(()) +impl Drop for WindowsProcessesSnapshot { + fn drop(&mut self) { + unsafe { + // Close the handle now that its no longer needed + CloseHandle(self.handle); + } + } } /// Kills a process with the provided process ID /// /// # Arguments /// -/// * `pid` - The process ID -unsafe fn kill_process(pid: u32) -> Result<()> { - info!("Killing process with PID {}", pid); +/// * `process` - The process +unsafe fn kill_process(process: &WindowsProcess) -> Result<()> { + info!("Killing process {}:{}", process.get_name(), process.pid); // Open the process handle with intent to terminate - let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, pid); + let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, process.pid); if handle == 0 { + // If the process just isn't running we can ignore the error + if !is_process_running(process.pid)? { + return Ok(()); + } + let error: WIN32_ERROR = GetLastError(); return Err(Error::new( ErrorKind::Other, - format!("Failed to obtain handle to process {}: {:#x}", pid, error), + format!( + "Failed to obtain handle to process {}:{}: {:#x}", + process.get_name(), + process.pid, + error + ), )); } @@ -247,7 +344,12 @@ unsafe fn kill_process(pid: u32) -> Result<()> { let error: WIN32_ERROR = GetLastError(); return Err(Error::new( ErrorKind::Other, - format!("Failed to terminate process {}: {:#x}", pid, error), + format!( + "Failed to terminate process {}:{}: {:#x}", + process.get_name(), + process.pid, + error + ), )); } diff --git a/tests/killport_windows_tests.rs b/tests/killport_windows_tests.rs index 6d19668..dd36b56 100644 --- a/tests/killport_windows_tests.rs +++ b/tests/killport_windows_tests.rs @@ -126,7 +126,7 @@ fn kill_service_by_port_dry_run() { #[test] fn check_process_type_and_name() { - let process = WindowsProcess::new(1234, Some("unique_process".to_string())); + let process = WindowsProcess::new(1234, "unique_process".to_string()); assert_eq!(process.get_type(), KillableType::Process); assert_eq!(process.get_name(), "unique_process");