Skip to content

Commit

Permalink
Cleaned up, added extra typing, added more documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtread committed May 18, 2023
1 parent 7639be4 commit d1b3bcd
Showing 1 changed file with 86 additions and 59 deletions.
145 changes: 86 additions & 59 deletions src/windows.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
use crate::KillPortSignalOptions;
use log::{debug, info};
use log::info;
use std::{
alloc::Layout,
alloc::{alloc, dealloc, Layout},
collections::HashSet,
ffi::c_void,
io::{Error, ErrorKind},
io::{Error, ErrorKind, Result},
ptr::addr_of,
slice,
};
use windows_sys::Win32::{
Foundation::{
CloseHandle, GetLastError, ERROR_INSUFFICIENT_BUFFER, INVALID_HANDLE_VALUE, NO_ERROR,
CloseHandle, GetLastError, BOOL, ERROR_INSUFFICIENT_BUFFER, FALSE, HANDLE,
INVALID_HANDLE_VALUE, NO_ERROR, WIN32_ERROR,
},
NetworkManagement::IpHelper::{
GetExtendedTcpTable, GetExtendedUdpTable, MIB_TCP6ROW_OWNER_MODULE,
MIB_TCP6TABLE_OWNER_MODULE, MIB_TCPROW_OWNER_MODULE, MIB_TCPTABLE_OWNER_MODULE,
MIB_UDP6ROW_OWNER_MODULE, MIB_UDP6TABLE_OWNER_MODULE, MIB_UDPROW_OWNER_MODULE,
MIB_UDPTABLE_OWNER_MODULE, TCP_TABLE_CLASS, TCP_TABLE_OWNER_MODULE_ALL, UDP_TABLE_CLASS,
UDP_TABLE_OWNER_MODULE,
MIB_UDPTABLE_OWNER_MODULE, TCP_TABLE_OWNER_MODULE_ALL, UDP_TABLE_OWNER_MODULE,
},
Networking::WinSock::{ADDRESS_FAMILY, AF_INET, AF_INET6},
Networking::WinSock::{AF_INET, AF_INET6},
System::{
Diagnostics::ToolHelp::{
CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32,
Expand All @@ -37,8 +38,8 @@ use windows_sys::Win32::{
/// # Arguments
///
/// * `port` - A u16 value representing the port number.
pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result<bool, Error> {
let mut pids = HashSet::new();
pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result<bool> {
let mut pids: HashSet<u32> = HashSet::new();
unsafe {
// Find processes in the TCP IPv4 table
use_extended_table::<MIB_TCPTABLE_OWNER_MODULE>(port, &mut pids)?;
Expand All @@ -61,7 +62,6 @@ pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result<boo
collect_parents(&mut pids)?;

for pid in pids {
debug!("Found process with PID {}", pid);
kill_process(pid)?;
}

Expand All @@ -76,14 +76,14 @@ pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result<boo
/// # Arguments
///
/// * `pids` - The set to match PIDs from and insert PIDs into
unsafe fn collect_parents(pids: &mut HashSet<u32>) -> Result<(), Error> {
unsafe fn collect_parents(pids: &mut HashSet<u32>) -> Result<()> {
// Request a snapshot handle
let handle = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
let handle: HANDLE = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);

// Ensure we got a valid handle
if handle == INVALID_HANDLE_VALUE {
let error = GetLastError();
return Err(std::io::Error::new(
let error: WIN32_ERROR = GetLastError();
return Err(Error::new(
ErrorKind::Other,
format!("Failed to get handle to processes: {:#x}", error),
));
Expand All @@ -94,7 +94,7 @@ unsafe fn collect_parents(pids: &mut HashSet<u32>) -> Result<(), Error> {
entry.dwSize = std::mem::size_of::<PROCESSENTRY32>() as u32;

// Process the first item
if Process32First(handle, &mut entry) != 0 {
if Process32First(handle, &mut entry) != FALSE {
let mut count = 0;

loop {
Expand All @@ -105,15 +105,15 @@ unsafe fn collect_parents(pids: &mut HashSet<u32>) -> Result<(), Error> {
}

// Process the next entry
if Process32Next(handle, &mut entry) == 0 {
if Process32Next(handle, &mut entry) == FALSE {
break;
}
}

info!("Collected {} parent processes", count);
}

// Close the handle we obtained
// Close the handle now that its no longer needed
CloseHandle(handle);

Ok(())
Expand All @@ -124,23 +124,28 @@ unsafe fn collect_parents(pids: &mut HashSet<u32>) -> Result<(), Error> {
/// # Arguments
///
/// * `pid` - The process ID
unsafe fn kill_process(pid: u32) -> Result<(), Error> {
unsafe fn kill_process(pid: u32) -> Result<()> {
info!("Killing process with PID {}", pid);

// Open the process handle with intent to terminate
let handle = OpenProcess(PROCESS_TERMINATE, 0, pid);
let handle: HANDLE = OpenProcess(PROCESS_TERMINATE, FALSE, pid);
if handle == 0 {
let error = GetLastError();
return Err(std::io::Error::new(
let error: WIN32_ERROR = GetLastError();
return Err(Error::new(
ErrorKind::Other,
format!("Failed to obtain handle to process {}: {:#x}", pid, error),
));
}

let result = TerminateProcess(handle, 0);
if result == 0 {
let error = GetLastError();
return Err(std::io::Error::new(
// Terminate the process
let result: BOOL = TerminateProcess(handle, 0);

// Close the handle now that its no longer needed
CloseHandle(handle);

if result == FALSE {
let error: WIN32_ERROR = GetLastError();
return Err(Error::new(
ErrorKind::Other,
format!("Failed to terminate process {}: {:#x}", pid, error),
));
Expand All @@ -157,26 +162,28 @@ unsafe fn kill_process(pid: u32) -> Result<(), Error> {
///
/// * `port` - The port to check for
/// * `pids` - The output list of process IDs
unsafe fn use_extended_table<T>(port: u16, pids: &mut HashSet<u32>) -> Result<(), Error>
unsafe fn use_extended_table<T>(port: u16, pids: &mut HashSet<u32>) -> Result<()>
where
T: TableClass,
{
let mut layout = Layout::new::<T>();
let mut buffer = std::alloc::alloc(layout);
// Allocation of initial memory
let mut layout: Layout = Layout::new::<T>();
let mut buffer: *mut u8 = alloc(layout);

// Size estimate for resizing the buffer
let mut size = 0;
// Current buffer size later changed by the fn call to be the estimated size
// for resizing the buffer
let mut size: u32 = layout.size() as u32;

// Result of asking for the TCP table
let mut result: u32;
// Result of asking for the table
let mut result: WIN32_ERROR;

loop {
// Ask windows for the extended table
result = (T::TABLE_FN)(
buffer.cast(),
&mut size,
1,
T::FAMILY as u32,
FALSE,
T::FAMILY,
T::TABLE_CLASS,
0,
);
Expand All @@ -186,40 +193,60 @@ where
break;
}

// Always deallocate the memory regardless of the error
// (Resizing needs to reallocate the memory anyway)
dealloc(buffer, layout);

// Handle buffer too small
if result == ERROR_INSUFFICIENT_BUFFER {
// Deallocate the old memory layout
std::alloc::dealloc(buffer, layout);

// Create the new memory layout from the new size and previous alignment
layout = Layout::from_size_align_unchecked(size as usize, layout.align());
// Allocate the new chunk of memory
buffer = std::alloc::alloc(layout);
buffer = alloc(layout);
continue;
}

// Deallocate the buffer memory
std::alloc::dealloc(buffer, layout);

// Handle unknown failures
return Err(std::io::Error::new(
return Err(Error::new(
ErrorKind::Other,
"Failed to get size estimate for TCP table",
format!(
"Failed to get size estimate for extended table: {:#x}",
result
),
));
}

let table: *const T = buffer.cast();

// Obtain the processes from the table
T::get_processes(table, port, pids);

// Deallocate the buffer memory
std::alloc::dealloc(buffer, layout);
dealloc(buffer, layout);

Ok(())
}

/// Type of the GetExtended[UDP/TCP]Table Windows API function
type GetExtendedTable = unsafe extern "system" fn(*mut c_void, *mut u32, i32, u32, i32, u32) -> u32;
type GetExtendedTable =
unsafe extern "system" fn(*mut c_void, *mut u32, i32, AddressFamily, i32, u32) -> WIN32_ERROR;

/// For some reason the actual INET types are u16 so this
/// is just a casted version to u32
type AddressFamily = u32;

/// IPv4 Address family
const INET: AddressFamily = AF_INET as u32;
/// IPv6 Address family
const INET6: AddressFamily = AF_INET6 as u32;

/// Table class type (either TCP_TABLE_CLASS for TCP or UDP_TABLE_CLASS for UDP)
type TableClassType = i32;

/// TCP class type for the owner to module mappings
const TCP_TYPE: TableClassType = TCP_TABLE_OWNER_MODULE_ALL;
/// UDP class type for the owner to module mappings
const UDP_TYPE: TableClassType = UDP_TABLE_OWNER_MODULE;

/// Trait implemented by extended tables that can
/// be enumerated for processes that match a
Expand All @@ -229,10 +256,10 @@ trait TableClass {
const TABLE_FN: GetExtendedTable;

/// Address family type
const FAMILY: ADDRESS_FAMILY;
const FAMILY: AddressFamily;

/// Windows table class type
const TABLE_CLASS: i32;
const TABLE_CLASS: TableClassType;

/// Iterates the contents of the extended table inserting any
/// process entires that match the provided `port` into the
Expand All @@ -247,14 +274,14 @@ trait TableClass {

impl TableClass for MIB_TCPTABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedTcpTable;
const FAMILY: ADDRESS_FAMILY = AF_INET;
const TABLE_CLASS: TCP_TABLE_CLASS = TCP_TABLE_OWNER_MODULE_ALL;
const FAMILY: AddressFamily = INET;
const TABLE_CLASS: TableClassType = TCP_TYPE;

unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>) {
let row_ptr: *const MIB_TCPROW_OWNER_MODULE = addr_of!((*table).table).cast();
let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize;

std::slice::from_raw_parts(row_ptr, length)
slice::from_raw_parts(row_ptr, length)
.iter()
.for_each(|element| {
// Convert the port value
Expand All @@ -268,14 +295,14 @@ impl TableClass for MIB_TCPTABLE_OWNER_MODULE {

impl TableClass for MIB_TCP6TABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedTcpTable;
const FAMILY: ADDRESS_FAMILY = AF_INET6;
const TABLE_CLASS: TCP_TABLE_CLASS = TCP_TABLE_OWNER_MODULE_ALL;
const FAMILY: AddressFamily = INET6;
const TABLE_CLASS: TableClassType = TCP_TYPE;

unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>) {
let row_ptr: *const MIB_TCP6ROW_OWNER_MODULE = addr_of!((*table).table).cast();
let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize;

std::slice::from_raw_parts(row_ptr, length)
slice::from_raw_parts(row_ptr, length)
.iter()
.for_each(|element| {
// Convert the port value
Expand All @@ -289,14 +316,14 @@ impl TableClass for MIB_TCP6TABLE_OWNER_MODULE {

impl TableClass for MIB_UDPTABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedUdpTable;
const FAMILY: ADDRESS_FAMILY = AF_INET;
const TABLE_CLASS: UDP_TABLE_CLASS = UDP_TABLE_OWNER_MODULE;
const FAMILY: AddressFamily = INET;
const TABLE_CLASS: TableClassType = UDP_TYPE;

unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>) {
let row_ptr: *const MIB_UDPROW_OWNER_MODULE = addr_of!((*table).table).cast();
let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize;

std::slice::from_raw_parts(row_ptr, length)
slice::from_raw_parts(row_ptr, length)
.iter()
.for_each(|element| {
// Convert the port value
Expand All @@ -310,14 +337,14 @@ impl TableClass for MIB_UDPTABLE_OWNER_MODULE {

impl TableClass for MIB_UDP6TABLE_OWNER_MODULE {
const TABLE_FN: GetExtendedTable = GetExtendedUdpTable;
const FAMILY: ADDRESS_FAMILY = AF_INET6;
const TABLE_CLASS: UDP_TABLE_CLASS = UDP_TABLE_OWNER_MODULE;
const FAMILY: AddressFamily = INET6;
const TABLE_CLASS: TableClassType = UDP_TYPE;

unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet<u32>) {
let row_ptr: *const MIB_UDP6ROW_OWNER_MODULE = addr_of!((*table).table).cast();
let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize;

std::slice::from_raw_parts(row_ptr, length)
slice::from_raw_parts(row_ptr, length)
.iter()
.for_each(|element| {
// Convert the port value
Expand Down

0 comments on commit d1b3bcd

Please sign in to comment.