diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 4fc3665..1d4a7c6 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -19,6 +19,7 @@ jobs: - s390x-unknown-linux-gnu - aarch64-apple-darwin - x86_64-apple-darwin + - x86_64-pc-windows-gnu runs-on: ${{ (matrix.target == 'aarch64-apple-darwin' || matrix.target == 'x86_64-apple-darwin') && 'macos-latest' || 'ubuntu-latest' }} steps: - name: Checkout repository diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3862952..9a8932f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -64,7 +64,7 @@ jobs: - s390x-unknown-linux-gnu - aarch64-apple-darwin - x86_64-apple-darwin - # - x86_64-pc-windows-gnu + - x86_64-pc-windows-gnu runs-on: ${{ (matrix.target == 'aarch64-apple-darwin' || matrix.target == 'x86_64-apple-darwin') && 'macos-latest' || 'ubuntu-latest' }} steps: - name: Checkout repository diff --git a/src/cli.rs b/src/cli.rs index 7e01424..c9eb880 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,8 +1,8 @@ use clap::{Parser, ValueEnum}; use clap_verbosity_flag::{Verbosity, WarnLevel}; use core::fmt; -use nix::sys::signal::Signal; -use std::str::FromStr; + +use crate::signal::KillportSignal; /// Modes of operation for killport. #[derive(Debug, Clone, Copy, PartialEq, ValueEnum)] @@ -67,7 +67,7 @@ pub struct KillPortArgs { default_value = "sigkill", value_parser = parse_signal )] - pub signal: Signal, + pub signal: KillportSignal, /// A verbosity flag to control the level of logging output. #[command(flatten)] @@ -81,14 +81,6 @@ pub struct KillPortArgs { pub dry_run: bool, } -fn parse_signal(arg: &str) -> Result { - let str_arg = arg.parse::(); - match str_arg { - Ok(str_arg) => { - let signal_str = str_arg.to_uppercase(); - let signal = Signal::from_str(signal_str.as_str())?; - return Ok(signal); - } - Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e)), - } +fn parse_signal(arg: &str) -> Result { + arg.to_uppercase().parse() } diff --git a/src/docker.rs b/src/docker.rs index 1a4f3f1..b53011c 100644 --- a/src/docker.rs +++ b/src/docker.rs @@ -1,7 +1,7 @@ +use crate::signal::KillportSignal; use bollard::container::{KillContainerOptions, ListContainersOptions}; use bollard::Docker; use log::debug; -use nix::sys::signal::Signal; use std::collections::HashMap; use std::io::Error; use tokio::runtime::Runtime; @@ -17,7 +17,7 @@ impl DockerContainer { /// /// * `name` - A container name. /// * `signal` - A enum value representing the signal type. - pub fn kill_container(name: &String, signal: Signal) -> Result<(), Error> { + pub fn kill_container(name: &str, signal: KillportSignal) -> Result<(), Error> { let rt = Runtime::new()?; rt.block_on(async { let docker = Docker::connect_with_socket_defaults() @@ -63,11 +63,7 @@ impl DockerContainer { .as_ref()? .first() .map(|name| DockerContainer { - name: if name.starts_with('/') { - name[1..].to_string() - } else { - name.clone() - }, + name: name.strip_prefix('/').unwrap_or(name).to_string(), }) }) .collect()) diff --git a/src/killport.rs b/src/killport.rs index a17e7c8..97f0857 100644 --- a/src/killport.rs +++ b/src/killport.rs @@ -1,64 +1,34 @@ -use crate::cli::Mode; use crate::docker::DockerContainer; #[cfg(target_os = "linux")] use crate::linux::find_target_processes; #[cfg(target_os = "macos")] use crate::macos::find_target_processes; -use log::info; -use nix::sys::signal::{kill, Signal}; -use nix::unistd::Pid; -use std::io::Error; - -#[derive(Debug)] -pub struct NativeProcess { - /// System native process ID. - pub pid: Pid, - pub name: String, -} +#[cfg(target_os = "windows")] +use crate::windows::find_target_processes; +use crate::{cli::Mode, signal::KillportSignal}; +use std::{fmt::Display, io::Error}; /// Interface for killable targets such as native process and docker container. pub trait Killable { - fn kill(&self, signal: Signal) -> Result; - fn get_type(&self) -> String; + fn kill(&self, signal: KillportSignal) -> Result; + + fn get_type(&self) -> KillableType; + fn get_name(&self) -> String; } -impl Killable for NativeProcess { - /// Entry point to kill the linux native process. - /// - /// # Arguments - /// - /// * `signal` - A enum value representing the signal type. - fn kill(&self, signal: Signal) -> Result { - info!("Killing process '{}' with PID {}", self.name, self.pid); - - kill(self.pid, signal).map(|_| true).map_err(|e| { - Error::new( - std::io::ErrorKind::Other, - format!( - "Failed to kill process '{}' with PID {}: {}", - self.name, self.pid, e - ), - ) - }) - } - - /// Returns the type of the killable target. - /// - /// This method is used to identify the type of the target (either a native process or a Docker container) - /// that is being handled. This information can be useful for logging, error handling, or other needs - /// where type of the target is relevant. - /// - /// # Returns - /// - /// * `String` - A string that describes the type of the killable target. For a `NativeProcess` it will return "process", - /// and for a `DockerContainer` it will return "container". - fn get_type(&self) -> String { - "process".to_string() - } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum KillableType { + Process, + Container, +} - fn get_name(&self) -> String { - self.name.to_string() +impl Display for KillableType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + KillableType::Process => "process", + KillableType::Container => "container", + }) } } @@ -68,10 +38,8 @@ impl Killable for DockerContainer { /// # Arguments /// /// * `signal` - A enum value representing the signal type. - fn kill(&self, signal: Signal) -> Result { - if let Err(err) = Self::kill_container(&self.name, signal) { - return Err(err); - } + fn kill(&self, signal: KillportSignal) -> Result { + Self::kill_container(&self.name, signal)?; Ok(true) } @@ -84,10 +52,10 @@ impl Killable for DockerContainer { /// /// # Returns /// - /// * `String` - A string that describes the type of the killable target. For a `NativeProcess` it will return "process", + /// * `String` - A string that describes the type of the killable target. For a `UnixProcess` it will return "process", /// and for a `DockerContainer` it will return "container". - fn get_type(&self) -> String { - "container".to_string() + fn get_type(&self) -> KillableType { + KillableType::Container } fn get_name(&self) -> String { @@ -104,10 +72,10 @@ pub trait KillportOperations { fn kill_service_by_port( &self, port: u16, - signal: Signal, + signal: KillportSignal, mode: Mode, dry_run: bool, - ) -> Result, Error>; + ) -> Result, Error>; } pub struct Killport; @@ -133,9 +101,10 @@ impl KillportOperations for Killport { for process in target_processes { // Check if the process name contains 'docker' and skip if in docker mode - if docker_present && process.name.to_lowercase().contains("docker") { + if docker_present && process.get_name().to_lowercase().contains("docker") { continue; } + target_killables.push(Box::new(process)); } } @@ -166,10 +135,10 @@ impl KillportOperations for Killport { fn kill_service_by_port( &self, port: u16, - signal: Signal, + signal: KillportSignal, mode: Mode, dry_run: bool, - ) -> Result, Error> { + ) -> Result, Error> { let mut results = Vec::new(); let target_killables = self.find_target_killables(port, mode)?; // Use the existing function to find targets @@ -179,7 +148,7 @@ impl KillportOperations for Killport { results.push((killable.get_type(), killable.get_name())); } else { // In actual mode, attempt to kill the entity and collect its information if successful - if killable.kill(signal)? { + if killable.kill(signal.clone())? { results.push((killable.get_type(), killable.get_name())); } } diff --git a/src/lib.rs b/src/lib.rs index 97f84b9..49c452d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,11 @@ pub mod cli; pub mod docker; pub mod killport; +pub mod signal; + +#[cfg(unix)] +pub mod unix; + #[cfg(target_os = "linux")] pub mod linux; #[cfg(target_os = "macos")] diff --git a/src/linux.rs b/src/linux.rs index de0151a..d9d5d7d 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -1,4 +1,4 @@ -use crate::killport::NativeProcess; +use crate::unix::UnixProcess; use log::debug; use nix::unistd::Pid; @@ -75,8 +75,8 @@ fn find_target_inodes(port: u16) -> Vec { /// # Arguments /// /// * `inodes` - Target inodes -pub fn find_target_processes(port: u16) -> Result, Error> { - let mut target_pids: Vec = vec![]; +pub fn find_target_processes(port: u16) -> Result, Error> { + let mut target_pids: Vec = vec![]; let inodes = find_target_inodes(port); for inode in inodes { @@ -96,10 +96,7 @@ pub fn find_target_processes(port: u16) -> Result, Error> { .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))? .join(" "); debug!("Found process '{}' with PID {}", name, process.pid()); - target_pids.push(NativeProcess { - pid: Pid::from_raw(process.pid), - name: name, - }); + target_pids.push(UnixProcess::new(Pid::from_raw(process.pid), name)); } } } diff --git a/src/macos.rs b/src/macos.rs index bdcbd1a..b7114f9 100644 --- a/src/macos.rs +++ b/src/macos.rs @@ -1,4 +1,4 @@ -use crate::killport::NativeProcess; +use crate::unix::UnixProcess; use libproc::libproc::file_info::pidfdinfo; use libproc::libproc::file_info::{ListFDs, ProcFDType}; @@ -16,8 +16,8 @@ use std::io; /// # Arguments /// /// * `port` - Target port number -pub fn find_target_processes(port: u16) -> Result, io::Error> { - let mut target_pids: Vec = vec![]; +pub fn find_target_processes(port: u16) -> Result, io::Error> { + let mut target_pids: Vec = vec![]; if let Ok(procs) = pids_by_type(ProcFilter::All) { for p in procs { @@ -56,10 +56,10 @@ pub fn find_target_processes(port: u16) -> Result, io::Error> "Found process '{}' with PID {} listening on port {}", process_name, pid, port ); - target_pids.push(NativeProcess { - pid: Pid::from_raw(pid), - name: process_name, - }); + target_pids.push(UnixProcess::new( + Pid::from_raw(pid), + process_name, + )); } } _ => (), diff --git a/src/main.rs b/src/main.rs index a4dedc0..8895564 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ fn main() { // Attempt to kill processes listening on specified ports for port in args.ports { - match killport.kill_service_by_port(port, args.signal, args.mode, args.dry_run) { + match killport.kill_service_by_port(port, args.signal.clone(), args.mode, args.dry_run) { Ok(killed_services) => { if killed_services.is_empty() { println!("No {} found using port {}", service_type_singular, port); diff --git a/src/signal.rs b/src/signal.rs new file mode 100644 index 0000000..7f16e9e --- /dev/null +++ b/src/signal.rs @@ -0,0 +1,35 @@ +//! Wrapper around signals for platforms that they are not supported on + +use std::{fmt::Display, str::FromStr}; + +#[cfg(unix)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KillportSignal(pub nix::sys::signal::Signal); + +/// On a platform where we don't have the proper signals enum +#[cfg(not(unix))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KillportSignal(pub String); + +impl Display for KillportSignal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl FromStr for KillportSignal { + type Err = std::io::Error; + + fn from_str(value: &str) -> Result { + #[cfg(unix)] + { + let signal = nix::sys::signal::Signal::from_str(value)?; + Ok(KillportSignal(signal)) + } + + #[cfg(not(unix))] + { + Ok(KillportSignal(value.to_string())) + } + } +} diff --git a/src/unix.rs b/src/unix.rs new file mode 100644 index 0000000..8b13cda --- /dev/null +++ b/src/unix.rs @@ -0,0 +1,59 @@ +use crate::killport::{Killable, KillableType}; +use crate::signal::KillportSignal; +use log::info; +use nix::sys::signal::kill; +use nix::unistd::Pid; +use std::io::Error; + +/// Process type shared amongst unix-like operating systems +#[derive(Debug)] +pub struct UnixProcess { + /// System native process ID. + pid: Pid, + name: String, +} + +impl UnixProcess { + pub fn new(pid: Pid, name: String) -> Self { + Self { pid, name } + } +} + +impl Killable for UnixProcess { + /// Entry point to kill the linux native process. + /// + /// # Arguments + /// + /// * `signal` - A enum value representing the signal type. + fn kill(&self, signal: KillportSignal) -> Result { + info!("Killing process '{}' with PID {}", self.name, self.pid); + + kill(self.pid, signal.0).map(|_| true).map_err(|e| { + Error::new( + std::io::ErrorKind::Other, + format!( + "Failed to kill process '{}' with PID {}: {}", + self.name, self.pid, e + ), + ) + }) + } + + /// Returns the type of the killable target. + /// + /// This method is used to identify the type of the target (either a native process or a Docker container) + /// that is being handled. This information can be useful for logging, error handling, or other needs + /// where type of the target is relevant. + /// + /// # Returns + /// + /// * `String` - A string that describes the type of the killable target. For a `UnixProcess` it will return "process", + /// and for a `DockerContainer` it will return "container". + fn get_type(&self) -> KillableType { + KillableType::Process + } + + fn get_name(&self) -> String { + self.name.to_string() + } +} diff --git a/src/windows.rs b/src/windows.rs index 6090e43..1fab695 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -1,8 +1,8 @@ -use crate::KillPortSignalOptions; +use crate::killport::{Killable, KillableType}; use log::info; use std::{ alloc::{alloc, dealloc, Layout}, - collections::HashSet, + collections::{HashMap, HashSet}, ffi::c_void, io::{Error, ErrorKind, Result}, ptr::addr_of, @@ -29,24 +29,36 @@ use windows_sys::Win32::{ }, }; -/// Attempts to kill processes listening on the specified `port`. -/// -/// # Arguments +/// Represents a windows native process +#[derive(Debug)] +pub struct WindowsProcess { + pid: u32, + name: String, + parent: Option>, +} + +impl WindowsProcess { + pub fn new(pid: u32, name: String) -> Self { + Self { + pid, + name, + parent: None, + } + } +} + +/// Finds the processes associated with the specified `port`. /// -/// * `port` - A u16 value representing the port number. +/// Returns a `Vec` of native processes. /// -/// # Returns +/// # Arguments /// -/// A `Result` containing a tuple. The first element is a boolean indicating if -/// at least one process was killed (true if yes, false otherwise). The second -/// element is a string indicating the type of the killed entity. An `Error` is -/// returned if the operation failed or the platform is unsupported. -pub fn kill_processes_by_port( - port: u16, - _: KillPortSignalOptions, -) -> Result<(bool, String), Error> { +/// * `port` - Target port number +pub fn find_target_processes(port: u16) -> Result> { + let lookup_table: ProcessLookupTable = ProcessLookupTable::create()?; let mut pids: HashSet = HashSet::new(); - unsafe { + + let processes = unsafe { // Find processes in the TCP IPv4 table use_extended_table::(port, &mut pids)?; @@ -59,87 +71,269 @@ pub fn kill_processes_by_port( // Find processes in the UDP IPv6 table use_extended_table::(port, &mut pids)?; - // Nothing was found - if pids.is_empty() { - return Ok((false, "None".to_string())); + let mut processes: Vec = Vec::with_capacity(pids.len()); + + for pid in pids { + let process_name = lookup_table + .process_names + .get(&pid) + .cloned() + .unwrap_or_else(|| "Unknown".to_string()); + + let mut process = WindowsProcess::new(pid, process_name); + + // Resolve the process parents + lookup_process_parents(&lookup_table, &mut process)?; + + processes.push(process); } - // Collect parents of the PIDs - collect_parents(&mut pids)?; + processes + }; - for pid in pids { - kill_process(pid)?; + Ok(processes) +} + +impl Killable for WindowsProcess { + fn kill(&self, _signal: crate::signal::KillportSignal) -> Result { + let mut killed = false; + let mut next = Some(self); + while let Some(current) = next { + unsafe { + kill_process(current)?; + } + + killed = true; + next = current.parent.as_ref().map(|value| value.as_ref()); } - // Something had to have been killed to reach here - Ok((true, "process".to_string())) + Ok(killed) + } + + fn get_type(&self) -> KillableType { + KillableType::Process + } + + fn get_name(&self) -> String { + self.name.to_string() } } -/// Collects all the parent processes for the PIDs in -/// the provided set +/// Checks if there is a running process with the provided pid /// /// # Arguments /// -/// * `pids` - The set to match PIDs from and insert PIDs into -unsafe fn collect_parents(pids: &mut HashSet) -> Result<()> { - // Request a snapshot handle - let handle: HANDLE = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); +/// * `pid` - The process ID to search for +fn is_process_running(pid: u32) -> Result { + let mut snapshot = WindowsProcessesSnapshot::create()?; + let is_running = snapshot.any(|entry| entry.th32ProcessID == pid); + Ok(is_running) +} - // Ensure we got a valid handle - if handle == INVALID_HANDLE_VALUE { - let error: WIN32_ERROR = GetLastError(); - return Err(Error::new( - ErrorKind::Other, - format!("Failed to get handle to processes: {:#x}", error), - )); +/// Lookup table for finding the names and parents for +/// a process using its pid +pub struct ProcessLookupTable { + /// Mapping from pid to name + process_names: HashMap, + /// Mapping from pid to parent pid + process_parents: HashMap, +} + +impl ProcessLookupTable { + pub fn create() -> Result { + let mut process_names: HashMap = HashMap::new(); + let mut process_parents: HashMap = HashMap::new(); + + WindowsProcessesSnapshot::create()?.for_each(|entry| { + process_names.insert(entry.th32ProcessID, get_process_entry_name(&entry)); + process_parents.insert(entry.th32ProcessID, entry.th32ParentProcessID); + }); + + Ok(Self { + process_names, + process_parents, + }) + } +} + +/// Finds any parent processes of the provided process, adding +/// the process to the list of parents +/// +/// WARNING - This worked in the previous versions because the implementation +/// was flawwed and didn't properly look up the tree of parents, trying to kill +/// all of the parents causes problems since you'll end up killing explorer.exe +/// or some other windows sys process. This has been disabled (Depth of 0) but +/// may be enabled in a future release +/// +/// +/// +/// # Arguments +/// +/// * `process` - The process to collect parents for +fn lookup_process_parents( + lookup_table: &ProcessLookupTable, + process: &mut WindowsProcess, +) -> Result<()> { + const MAX_PARENT_DEPTH: u8 = 0; + + let mut current_procces = process; + let mut depth = 0; + + while let Some(&parent_pid) = lookup_table.process_parents.get(¤t_procces.pid) { + if depth == MAX_PARENT_DEPTH { + break; + } + + let process_name = lookup_table + .process_names + .get(&parent_pid) + .cloned() + .unwrap_or_else(|| "Unknown".to_string()); + + // Add the new parent process + let parent = current_procces + .parent + .insert(Box::new(WindowsProcess::new(parent_pid, process_name))); + + current_procces = parent; + depth += 1 } - // Allocate the memory to use for the entries - let mut entry: PROCESSENTRY32 = std::mem::zeroed(); - entry.dwSize = std::mem::size_of::() as u32; + Ok(()) +} + +/// Parses the name from a process entry, falls back to "Unknown" +/// for invalid names +/// +/// # Arguments +/// +/// * `entry` - The process entry +fn get_process_entry_name(entry: &PROCESSENTRY32) -> String { + let name_chars = entry + .szExeFile + .iter() + .copied() + .take_while(|value| *value != 0) + .collect(); + + let name = String::from_utf8(name_chars); + name.unwrap_or_else(|_| "Unknown".to_string()) +} - // Process the first item - if Process32First(handle, &mut entry) != FALSE { - let mut count = 0; +/// Snapshot of the running windows processes that can be iterated to find +/// information about various processes such as parent processes and +/// process names +/// +/// This is a safe abstraction +pub struct WindowsProcessesSnapshot { + /// Handle to the snapshot + handle: HANDLE, + /// The memory for reading process entries + entry: PROCESSENTRY32, + /// State of reading + state: SnapshotState, +} - loop { - // Add matching processes to the output - if pids.contains(&entry.th32ProcessID) { - pids.insert(entry.th32ParentProcessID); - count += 1; - } +/// State for the snapshot iterator +pub enum SnapshotState { + /// Can read the first entry + First, + /// Can read the next entry + Next, + /// Reached the end, cannot iterate further always give [None] + End, +} - // Process the next entry - if Process32Next(handle, &mut entry) == FALSE { - break; - } +impl WindowsProcessesSnapshot { + /// Creates a new process snapshot to iterate + pub fn create() -> Result { + // Request a snapshot handle + let handle: HANDLE = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) }; + + // Ensure we got a valid handle + if handle == INVALID_HANDLE_VALUE { + let error: WIN32_ERROR = unsafe { GetLastError() }; + return Err(Error::new( + ErrorKind::Other, + format!("Failed to get handle to processes: {:#x}", error), + )); } - info!("Collected {} parent processes", count); + // Allocate the memory to use for the entries + let mut entry: PROCESSENTRY32 = unsafe { std::mem::zeroed() }; + entry.dwSize = std::mem::size_of::() as u32; + + Ok(Self { + handle, + entry, + state: SnapshotState::First, + }) } +} - // Close the handle now that its no longer needed - CloseHandle(handle); +impl Iterator for WindowsProcessesSnapshot { + type Item = PROCESSENTRY32; - Ok(()) + fn next(&mut self) -> Option { + match self.state { + SnapshotState::First => { + // Process the first entry + if unsafe { Process32First(self.handle, &mut self.entry) } == FALSE { + self.state = SnapshotState::End; + return None; + } + self.state = SnapshotState::Next; + + Some(self.entry) + } + SnapshotState::Next => { + // Process the next entry + if unsafe { Process32Next(self.handle, &mut self.entry) } == FALSE { + self.state = SnapshotState::End; + return None; + } + + Some(self.entry) + } + SnapshotState::End => None, + } + } +} + +impl Drop for WindowsProcessesSnapshot { + fn drop(&mut self) { + unsafe { + // Close the handle now that its no longer needed + CloseHandle(self.handle); + } + } } /// Kills a process with the provided process ID /// /// # Arguments /// -/// * `pid` - The process ID -unsafe fn kill_process(pid: u32) -> Result<()> { - info!("Killing process with PID {}", pid); +/// * `process` - The process +unsafe fn kill_process(process: &WindowsProcess) -> Result<()> { + info!("Killing process {}:{}", process.get_name(), process.pid); // Open the process handle with intent to terminate - let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, pid); + let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, process.pid); if handle == 0 { + // If the process just isn't running we can ignore the error + if !is_process_running(process.pid)? { + return Ok(()); + } + let error: WIN32_ERROR = GetLastError(); return Err(Error::new( ErrorKind::Other, - format!("Failed to obtain handle to process {}: {:#x}", pid, error), + format!( + "Failed to obtain handle to process {}:{}: {:#x}", + process.get_name(), + process.pid, + error + ), )); } @@ -153,7 +347,12 @@ unsafe fn kill_process(pid: u32) -> Result<()> { let error: WIN32_ERROR = GetLastError(); return Err(Error::new( ErrorKind::Other, - format!("Failed to terminate process {}: {:#x}", pid, error), + format!( + "Failed to terminate process {}:{}: {:#x}", + process.get_name(), + process.pid, + error + ), )); } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 6ee506d..bbe4ef0 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -4,6 +4,11 @@ use utils::start_listener_process; use assert_cmd::Command; use tempfile::tempdir; +#[cfg(unix)] +const MOCK_PROCESS_NAME: &str = "mock_process"; +#[cfg(windows)] +const MOCK_PROCESS_NAME: &str = "mock_process.exe"; + #[test] fn test_basic_kill_no_process() { let mut cmd = Command::cargo_bin("killport").unwrap(); @@ -21,10 +26,9 @@ fn test_basic_kill_process() { let mut child = start_listener_process(tempdir_path, 8080); let mut cmd = Command::cargo_bin("killport").unwrap(); - cmd.args(&["8080"]) - .assert() - .success() - .stdout("Successfully killed process 'mock_process' listening on port 8080\n"); + cmd.args(&["8080"]).assert().success().stdout(format!( + "Successfully killed process '{MOCK_PROCESS_NAME}' listening on port 8080\n" + )); // Clean up let _ = child.kill(); @@ -44,7 +48,7 @@ fn test_signal_handling() { .assert() .success() .stdout(format!( - "Successfully killed process 'mock_process' listening on port 8081\n" + "Successfully killed process '{MOCK_PROCESS_NAME}' listening on port 8081\n" )); // Clean up @@ -66,7 +70,7 @@ fn test_mode_option() { .assert() .success() .stdout(format!( - "Successfully killed process 'mock_process' listening on port 8082\n" + "Successfully killed process '{MOCK_PROCESS_NAME}' listening on port 8082\n" )); // Clean up let _ = child.kill(); @@ -103,7 +107,9 @@ fn test_dry_run_option() { cmd.args(&["8083", "--dry-run"]) .assert() .success() - .stdout("Would kill process 'mock_process' listening on port 8083\n"); + .stdout(format!( + "Would kill process '{MOCK_PROCESS_NAME}' listening on port 8083\n" + )); // Clean up let _ = child.kill(); diff --git a/tests/killport_tests.rs b/tests/killport_unix_tests.rs similarity index 60% rename from tests/killport_tests.rs rename to tests/killport_unix_tests.rs index d668e80..f2f0af7 100644 --- a/tests/killport_tests.rs +++ b/tests/killport_unix_tests.rs @@ -1,50 +1,55 @@ +#![cfg(unix)] + use killport::cli::Mode; use killport::docker::DockerContainer; -use killport::killport::KillportOperations; -use killport::killport::{Killable, NativeProcess}; +use killport::killport::{Killable, KillableType, KillportOperations}; +use killport::signal::KillportSignal; +use killport::unix::UnixProcess; use mockall::*; use nix::sys::signal::Signal; use nix::unistd::Pid; use std::io::Error; -use std::sync::{Arc, Mutex}; // Setup Mocks mock! { DockerContainer {} impl Killable for DockerContainer { - fn kill(&self, signal: Signal) -> Result; - fn get_type(&self) -> String; + fn kill(&self, signal: KillportSignal) -> Result; + fn get_type(&self) -> KillableType; fn get_name(&self) -> String; } } mock! { - NativeProcess {} + UnixProcess {} - impl Killable for NativeProcess { - fn kill(&self, signal: Signal) -> Result; - fn get_type(&self) -> String; + impl Killable for UnixProcess { + fn kill(&self, signal: KillportSignal) -> Result; + fn get_type(&self) -> KillableType; fn get_name(&self) -> String; } } mock! { KillportOperations { fn find_target_killables(&self, port: u16, mode: Mode) -> Result>, Error>; - fn kill_service_by_port(&self, port: u16, signal: Signal, mode: Mode, dry_run: bool) -> Result, Error>; + fn kill_service_by_port(&self, port: u16, signal: KillportSignal, mode: Mode, dry_run: bool) -> Result, Error>; } } #[test] fn native_process_kill_succeeds() { - let mut mock_process = MockNativeProcess::new(); + let mut mock_process = MockUnixProcess::new(); // Setup the expectation for the mock mock_process .expect_kill() - .with(mockall::predicate::eq(Signal::SIGKILL)) + .with(mockall::predicate::eq(KillportSignal(Signal::SIGKILL))) .times(1) // Ensure the kill method is called exactly once .returning(|_| Ok(true)); // Simulate successful kill - assert_eq!(mock_process.kill(Signal::SIGKILL).unwrap(), true); + assert_eq!( + mock_process.kill(KillportSignal(Signal::SIGKILL)).unwrap(), + true + ); } #[test] @@ -52,11 +57,16 @@ fn docker_container_kill_succeeds() { let mut mock_container = MockDockerContainer::new(); mock_container .expect_kill() - .with(mockall::predicate::eq(Signal::SIGKILL)) + .with(mockall::predicate::eq(KillportSignal(Signal::SIGKILL))) .times(1) .returning(|_| Ok(true)); - assert_eq!(mock_container.kill(Signal::SIGKILL).unwrap(), true); + assert_eq!( + mock_container + .kill(KillportSignal(Signal::SIGKILL)) + .unwrap(), + true + ); } #[test] @@ -67,10 +77,10 @@ fn find_killables_processes_only() { .expect_find_target_killables() .withf(|&port, &mode| port == 8080 && mode == Mode::Process) .returning(|_, _| { - let mut mock_process = MockNativeProcess::new(); + let mut mock_process = MockUnixProcess::new(); mock_process .expect_get_type() - .return_const("process".to_string()); + .return_const(KillableType::Process); mock_process .expect_get_name() .return_const("mock_process".to_string()); @@ -80,47 +90,46 @@ fn find_killables_processes_only() { let port = 8080; let mode = Mode::Process; let found_killables = mock_killport.find_target_killables(port, mode).unwrap(); - assert!(found_killables.iter().all(|k| k.get_type() == "process")); + assert!(found_killables + .iter() + .all(|k| k.get_type() == KillableType::Process)); } #[test] fn kill_service_by_port_dry_run() { let mut mock_killport = MockKillportOperations::new(); - let mut mock_process = MockNativeProcess::new(); + let mut mock_process = MockUnixProcess::new(); mock_process.expect_kill().never(); mock_process .expect_get_type() - .return_const("process".to_string()); + .return_const(KillableType::Process); mock_process .expect_get_name() .return_const("mock_process".to_string()); mock_killport .expect_kill_service_by_port() - .returning(|_, _, _, _| Ok(vec![("process".to_string(), "mock_process".to_string())])); + .returning(|_, _, _, _| Ok(vec![(KillableType::Process, "mock_process".to_string())])); let port = 8080; let mode = Mode::Process; let dry_run = true; - let signal = Signal::SIGKILL; + let signal = KillportSignal(Signal::SIGKILL); let results = mock_killport .kill_service_by_port(port, signal, mode, dry_run) .unwrap(); assert_eq!(results.len(), 1); - assert_eq!(results[0].0, "process"); + assert_eq!(results[0].0, KillableType::Process); assert_eq!(results[0].1, "mock_process"); } #[test] fn check_process_type_and_name() { - let process = NativeProcess { - pid: Pid::from_raw(1234), - name: "unique_process".to_string(), - }; + let process = UnixProcess::new(Pid::from_raw(1234), "unique_process".to_string()); - assert_eq!(process.get_type(), "process"); + assert_eq!(process.get_type(), KillableType::Process); assert_eq!(process.get_name(), "unique_process"); } @@ -130,12 +139,12 @@ fn check_docker_container_type_and_name() { mock_container .expect_get_type() .times(1) - .returning(|| "container".to_string()); + .returning(|| KillableType::Container); mock_container .expect_get_name() .times(1) .returning(|| "docker_container".to_string()); - assert_eq!(mock_container.get_type(), "container"); + assert_eq!(mock_container.get_type(), KillableType::Container); assert_eq!(mock_container.get_name(), "docker_container"); } diff --git a/tests/killport_windows_tests.rs b/tests/killport_windows_tests.rs new file mode 100644 index 0000000..dd36b56 --- /dev/null +++ b/tests/killport_windows_tests.rs @@ -0,0 +1,149 @@ +#![cfg(windows)] + +use killport::cli::Mode; +use killport::killport::{Killable, KillableType}; +use killport::signal::KillportSignal; +use killport::windows::WindowsProcess; +use mockall::*; + +use std::io::Error; + +// Setup Mocks +mock! { + DockerContainer {} + + impl Killable for DockerContainer { + fn kill(&self, signal: KillportSignal) -> Result; + fn get_type(&self) -> KillableType; + fn get_name(&self) -> String; + } +} + +mock! { + WindowsProcess {} + + impl Killable for WindowsProcess { + fn kill(&self, signal: KillportSignal) -> Result; + fn get_type(&self) -> KillableType; + fn get_name(&self) -> String; + } +} +mock! { + KillportOperations { + fn find_target_killables(&self, port: u16, mode: Mode) -> Result>, Error>; + fn kill_service_by_port(&self, port: u16, signal: KillportSignal, mode: Mode, dry_run: bool) -> Result, Error>; + } +} + +#[test] +fn native_process_kill_succeeds() { + let mut mock_process = MockWindowsProcess::new(); + // Setup the expectation for the mock + mock_process + .expect_kill() + .with(mockall::predicate::eq(KillportSignal( + "SIGKILL".to_string(), + ))) + .times(1) // Ensure the kill method is called exactly once + .returning(|_| Ok(true)); // Simulate successful kill + + assert!(mock_process + .kill(KillportSignal("SIGKILL".to_string())) + .unwrap()); +} + +#[test] +fn docker_container_kill_succeeds() { + let mut mock_container = MockDockerContainer::new(); + mock_container + .expect_kill() + .with(mockall::predicate::eq(KillportSignal( + "SIGKILL".to_string(), + ))) + .times(1) + .returning(|_| Ok(true)); + + assert!(mock_container + .kill(KillportSignal("SIGKILL".to_string())) + .unwrap()); +} + +#[test] +fn find_killables_processes_only() { + let mut mock_killport = MockKillportOperations::new(); + + mock_killport + .expect_find_target_killables() + .withf(|&port, &mode| port == 8080 && mode == Mode::Process) + .returning(|_, _| { + let mut mock_process = MockWindowsProcess::new(); + mock_process + .expect_get_type() + .return_const(KillableType::Process); + mock_process + .expect_get_name() + .return_const("mock_process".to_string()); + Ok(vec![Box::new(mock_process)]) + }); + + let port = 8080; + let mode = Mode::Process; + let found_killables = mock_killport.find_target_killables(port, mode).unwrap(); + assert!(found_killables + .iter() + .all(|k| k.get_type() == KillableType::Process)); +} + +#[test] +fn kill_service_by_port_dry_run() { + let mut mock_killport = MockKillportOperations::new(); + let mut mock_process = MockWindowsProcess::new(); + + mock_process.expect_kill().never(); + mock_process + .expect_get_type() + .return_const(KillableType::Process); + mock_process + .expect_get_name() + .return_const("mock_process".to_string()); + + mock_killport + .expect_kill_service_by_port() + .returning(|_, _, _, _| Ok(vec![(KillableType::Process, "mock_process".to_string())])); + + let port = 8080; + let mode = Mode::Process; + let dry_run = true; + let signal = KillportSignal("SIGKILL".to_string()); + + let results = mock_killport + .kill_service_by_port(port, signal, mode, dry_run) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, KillableType::Process); + assert_eq!(results[0].1, "mock_process"); +} + +#[test] +fn check_process_type_and_name() { + let process = WindowsProcess::new(1234, "unique_process".to_string()); + + assert_eq!(process.get_type(), KillableType::Process); + assert_eq!(process.get_name(), "unique_process"); +} + +#[test] +fn check_docker_container_type_and_name() { + let mut mock_container = MockDockerContainer::new(); + mock_container + .expect_get_type() + .times(1) + .returning(|| KillableType::Container); + mock_container + .expect_get_name() + .times(1) + .returning(|| "docker_container".to_string()); + + assert_eq!(mock_container.get_type(), KillableType::Container); + assert_eq!(mock_container.get_name(), "docker_container"); +} diff --git a/tests/utils.rs b/tests/utils.rs index 7feb689..eb17a36 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -34,7 +34,7 @@ pub fn start_listener_process(tempdir_path: &Path, port: u16) -> Child { .expect("Failed to write mock process code"); let status = SystemCommand::new("rustc") - .args(&[ + .args([ mock_process_path.to_str().unwrap(), "--out-dir", tempdir_path.to_str().unwrap(),