diff --git a/src/windows.rs b/src/windows.rs index b244d65..2ea303d 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -1,24 +1,25 @@ use crate::KillPortSignalOptions; -use log::{debug, info}; +use log::info; use std::{ - alloc::Layout, + alloc::{alloc, dealloc, Layout}, collections::HashSet, ffi::c_void, - io::{Error, ErrorKind}, + io::{Error, ErrorKind, Result}, ptr::addr_of, + slice, }; use windows_sys::Win32::{ Foundation::{ - CloseHandle, GetLastError, ERROR_INSUFFICIENT_BUFFER, INVALID_HANDLE_VALUE, NO_ERROR, + CloseHandle, GetLastError, BOOL, ERROR_INSUFFICIENT_BUFFER, FALSE, HANDLE, + INVALID_HANDLE_VALUE, NO_ERROR, WIN32_ERROR, }, NetworkManagement::IpHelper::{ GetExtendedTcpTable, GetExtendedUdpTable, MIB_TCP6ROW_OWNER_MODULE, MIB_TCP6TABLE_OWNER_MODULE, MIB_TCPROW_OWNER_MODULE, MIB_TCPTABLE_OWNER_MODULE, MIB_UDP6ROW_OWNER_MODULE, MIB_UDP6TABLE_OWNER_MODULE, MIB_UDPROW_OWNER_MODULE, - MIB_UDPTABLE_OWNER_MODULE, TCP_TABLE_CLASS, TCP_TABLE_OWNER_MODULE_ALL, UDP_TABLE_CLASS, - UDP_TABLE_OWNER_MODULE, + MIB_UDPTABLE_OWNER_MODULE, TCP_TABLE_OWNER_MODULE_ALL, UDP_TABLE_OWNER_MODULE, }, - Networking::WinSock::{ADDRESS_FAMILY, AF_INET, AF_INET6}, + Networking::WinSock::{AF_INET, AF_INET6}, System::{ Diagnostics::ToolHelp::{ CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32, @@ -37,8 +38,8 @@ use windows_sys::Win32::{ /// # Arguments /// /// * `port` - A u16 value representing the port number. -pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result { - let mut pids = HashSet::new(); +pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result { + let mut pids: HashSet = HashSet::new(); unsafe { // Find processes in the TCP IPv4 table use_extended_table::(port, &mut pids)?; @@ -61,7 +62,6 @@ pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result Result) -> Result<(), Error> { +unsafe fn collect_parents(pids: &mut HashSet) -> Result<()> { // Request a snapshot handle - let handle = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); + let handle: HANDLE = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); // Ensure we got a valid handle if handle == INVALID_HANDLE_VALUE { - let error = GetLastError(); - return Err(std::io::Error::new( + let error: WIN32_ERROR = GetLastError(); + return Err(Error::new( ErrorKind::Other, format!("Failed to get handle to processes: {:#x}", error), )); @@ -94,7 +94,7 @@ unsafe fn collect_parents(pids: &mut HashSet) -> Result<(), Error> { entry.dwSize = std::mem::size_of::() as u32; // Process the first item - if Process32First(handle, &mut entry) != 0 { + if Process32First(handle, &mut entry) != FALSE { let mut count = 0; loop { @@ -105,7 +105,7 @@ unsafe fn collect_parents(pids: &mut HashSet) -> Result<(), Error> { } // Process the next entry - if Process32Next(handle, &mut entry) == 0 { + if Process32Next(handle, &mut entry) == FALSE { break; } } @@ -113,7 +113,7 @@ unsafe fn collect_parents(pids: &mut HashSet) -> Result<(), Error> { info!("Collected {} parent processes", count); } - // Close the handle we obtained + // Close the handle now that its no longer needed CloseHandle(handle); Ok(()) @@ -124,23 +124,28 @@ unsafe fn collect_parents(pids: &mut HashSet) -> Result<(), Error> { /// # Arguments /// /// * `pid` - The process ID -unsafe fn kill_process(pid: u32) -> Result<(), Error> { +unsafe fn kill_process(pid: u32) -> Result<()> { info!("Killing process with PID {}", pid); // Open the process handle with intent to terminate - let handle = OpenProcess(PROCESS_TERMINATE, 0, pid); + let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, pid); if handle == 0 { - let error = GetLastError(); - return Err(std::io::Error::new( + let error: WIN32_ERROR = GetLastError(); + return Err(Error::new( ErrorKind::Other, format!("Failed to obtain handle to process {}: {:#x}", pid, error), )); } - let result = TerminateProcess(handle, 0); - if result == 0 { - let error = GetLastError(); - return Err(std::io::Error::new( + // Terminate the process + let result: BOOL = TerminateProcess(handle, 0); + + // Close the handle now that its no longer needed + CloseHandle(handle); + + if result == FALSE { + let error: WIN32_ERROR = GetLastError(); + return Err(Error::new( ErrorKind::Other, format!("Failed to terminate process {}: {:#x}", pid, error), )); @@ -157,26 +162,28 @@ unsafe fn kill_process(pid: u32) -> Result<(), Error> { /// /// * `port` - The port to check for /// * `pids` - The output list of process IDs -unsafe fn use_extended_table(port: u16, pids: &mut HashSet) -> Result<(), Error> +unsafe fn use_extended_table(port: u16, pids: &mut HashSet) -> Result<()> where T: TableClass, { - let mut layout = Layout::new::(); - let mut buffer = std::alloc::alloc(layout); + // Allocation of initial memory + let mut layout: Layout = Layout::new::(); + let mut buffer: *mut u8 = alloc(layout); - // Size estimate for resizing the buffer - let mut size = 0; + // Current buffer size later changed by the fn call to be the estimated size + // for resizing the buffer + let mut size: u32 = layout.size() as u32; - // Result of asking for the TCP table - let mut result: u32; + // Result of asking for the table + let mut result: WIN32_ERROR; loop { // Ask windows for the extended table result = (T::TABLE_FN)( buffer.cast(), &mut size, - 1, - T::FAMILY as u32, + FALSE, + T::FAMILY, T::TABLE_CLASS, 0, ); @@ -186,40 +193,60 @@ where break; } + // Always deallocate the memory regardless of the error + // (Resizing needs to reallocate the memory anyway) + dealloc(buffer, layout); + // Handle buffer too small if result == ERROR_INSUFFICIENT_BUFFER { - // Deallocate the old memory layout - std::alloc::dealloc(buffer, layout); - // Create the new memory layout from the new size and previous alignment layout = Layout::from_size_align_unchecked(size as usize, layout.align()); // Allocate the new chunk of memory - buffer = std::alloc::alloc(layout); + buffer = alloc(layout); continue; } - // Deallocate the buffer memory - std::alloc::dealloc(buffer, layout); - // Handle unknown failures - return Err(std::io::Error::new( + return Err(Error::new( ErrorKind::Other, - "Failed to get size estimate for TCP table", + format!( + "Failed to get size estimate for extended table: {:#x}", + result + ), )); } let table: *const T = buffer.cast(); + // Obtain the processes from the table T::get_processes(table, port, pids); // Deallocate the buffer memory - std::alloc::dealloc(buffer, layout); + dealloc(buffer, layout); Ok(()) } /// Type of the GetExtended[UDP/TCP]Table Windows API function -type GetExtendedTable = unsafe extern "system" fn(*mut c_void, *mut u32, i32, u32, i32, u32) -> u32; +type GetExtendedTable = + unsafe extern "system" fn(*mut c_void, *mut u32, i32, AddressFamily, i32, u32) -> WIN32_ERROR; + +/// For some reason the actual INET types are u16 so this +/// is just a casted version to u32 +type AddressFamily = u32; + +/// IPv4 Address family +const INET: AddressFamily = AF_INET as u32; +/// IPv6 Address family +const INET6: AddressFamily = AF_INET6 as u32; + +/// Table class type (either TCP_TABLE_CLASS for TCP or UDP_TABLE_CLASS for UDP) +type TableClassType = i32; + +/// TCP class type for the owner to module mappings +const TCP_TYPE: TableClassType = TCP_TABLE_OWNER_MODULE_ALL; +/// UDP class type for the owner to module mappings +const UDP_TYPE: TableClassType = UDP_TABLE_OWNER_MODULE; /// Trait implemented by extended tables that can /// be enumerated for processes that match a @@ -229,10 +256,10 @@ trait TableClass { const TABLE_FN: GetExtendedTable; /// Address family type - const FAMILY: ADDRESS_FAMILY; + const FAMILY: AddressFamily; /// Windows table class type - const TABLE_CLASS: i32; + const TABLE_CLASS: TableClassType; /// Iterates the contents of the extended table inserting any /// process entires that match the provided `port` into the @@ -247,14 +274,14 @@ trait TableClass { impl TableClass for MIB_TCPTABLE_OWNER_MODULE { const TABLE_FN: GetExtendedTable = GetExtendedTcpTable; - const FAMILY: ADDRESS_FAMILY = AF_INET; - const TABLE_CLASS: TCP_TABLE_CLASS = TCP_TABLE_OWNER_MODULE_ALL; + const FAMILY: AddressFamily = INET; + const TABLE_CLASS: TableClassType = TCP_TYPE; unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { let row_ptr: *const MIB_TCPROW_OWNER_MODULE = addr_of!((*table).table).cast(); let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - std::slice::from_raw_parts(row_ptr, length) + slice::from_raw_parts(row_ptr, length) .iter() .for_each(|element| { // Convert the port value @@ -268,14 +295,14 @@ impl TableClass for MIB_TCPTABLE_OWNER_MODULE { impl TableClass for MIB_TCP6TABLE_OWNER_MODULE { const TABLE_FN: GetExtendedTable = GetExtendedTcpTable; - const FAMILY: ADDRESS_FAMILY = AF_INET6; - const TABLE_CLASS: TCP_TABLE_CLASS = TCP_TABLE_OWNER_MODULE_ALL; + const FAMILY: AddressFamily = INET6; + const TABLE_CLASS: TableClassType = TCP_TYPE; unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { let row_ptr: *const MIB_TCP6ROW_OWNER_MODULE = addr_of!((*table).table).cast(); let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - std::slice::from_raw_parts(row_ptr, length) + slice::from_raw_parts(row_ptr, length) .iter() .for_each(|element| { // Convert the port value @@ -289,14 +316,14 @@ impl TableClass for MIB_TCP6TABLE_OWNER_MODULE { impl TableClass for MIB_UDPTABLE_OWNER_MODULE { const TABLE_FN: GetExtendedTable = GetExtendedUdpTable; - const FAMILY: ADDRESS_FAMILY = AF_INET; - const TABLE_CLASS: UDP_TABLE_CLASS = UDP_TABLE_OWNER_MODULE; + const FAMILY: AddressFamily = INET; + const TABLE_CLASS: TableClassType = UDP_TYPE; unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { let row_ptr: *const MIB_UDPROW_OWNER_MODULE = addr_of!((*table).table).cast(); let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - std::slice::from_raw_parts(row_ptr, length) + slice::from_raw_parts(row_ptr, length) .iter() .for_each(|element| { // Convert the port value @@ -310,14 +337,14 @@ impl TableClass for MIB_UDPTABLE_OWNER_MODULE { impl TableClass for MIB_UDP6TABLE_OWNER_MODULE { const TABLE_FN: GetExtendedTable = GetExtendedUdpTable; - const FAMILY: ADDRESS_FAMILY = AF_INET6; - const TABLE_CLASS: UDP_TABLE_CLASS = UDP_TABLE_OWNER_MODULE; + const FAMILY: AddressFamily = INET6; + const TABLE_CLASS: TableClassType = UDP_TYPE; unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { let row_ptr: *const MIB_UDP6ROW_OWNER_MODULE = addr_of!((*table).table).cast(); let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - std::slice::from_raw_parts(row_ptr, length) + slice::from_raw_parts(row_ptr, length) .iter() .for_each(|element| { // Convert the port value