diff --git a/src/windows.rs b/src/windows.rs index 2ea303d..e0fc0f4 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -272,25 +272,35 @@ trait TableClass { unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet); } +/// Implementation for get_processes is identical for all of the +/// implementations only difference is the type of row pointer +/// other than that all the fields accessed are the same to in +/// order to prevent repeating this its a macro now +macro_rules! impl_get_processes { + ($ty:ty) => { + unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { + let row_ptr: *const $ty = addr_of!((*table).table).cast(); + let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; + + slice::from_raw_parts(row_ptr, length) + .iter() + .for_each(|element| { + // Convert the port value + let local_port: u16 = (element.dwLocalPort as u16).to_be(); + if local_port == port { + pids.insert(element.dwOwningPid); + } + }); + } + }; +} + impl TableClass for MIB_TCPTABLE_OWNER_MODULE { const TABLE_FN: GetExtendedTable = GetExtendedTcpTable; const FAMILY: AddressFamily = INET; const TABLE_CLASS: TableClassType = TCP_TYPE; - unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { - let row_ptr: *const MIB_TCPROW_OWNER_MODULE = addr_of!((*table).table).cast(); - let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - - slice::from_raw_parts(row_ptr, length) - .iter() - .for_each(|element| { - // Convert the port value - let local_port: u16 = (element.dwLocalPort as u16).to_be(); - if local_port == port { - pids.insert(element.dwOwningPid); - } - }); - } + impl_get_processes!(MIB_TCPROW_OWNER_MODULE); } impl TableClass for MIB_TCP6TABLE_OWNER_MODULE { @@ -298,20 +308,7 @@ impl TableClass for MIB_TCP6TABLE_OWNER_MODULE { const FAMILY: AddressFamily = INET6; const TABLE_CLASS: TableClassType = TCP_TYPE; - unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { - let row_ptr: *const MIB_TCP6ROW_OWNER_MODULE = addr_of!((*table).table).cast(); - let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - - slice::from_raw_parts(row_ptr, length) - .iter() - .for_each(|element| { - // Convert the port value - let local_port: u16 = (element.dwLocalPort as u16).to_be(); - if local_port == port { - pids.insert(element.dwOwningPid); - } - }); - } + impl_get_processes!(MIB_TCP6ROW_OWNER_MODULE); } impl TableClass for MIB_UDPTABLE_OWNER_MODULE { @@ -319,20 +316,7 @@ impl TableClass for MIB_UDPTABLE_OWNER_MODULE { const FAMILY: AddressFamily = INET; const TABLE_CLASS: TableClassType = UDP_TYPE; - unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { - let row_ptr: *const MIB_UDPROW_OWNER_MODULE = addr_of!((*table).table).cast(); - let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - - slice::from_raw_parts(row_ptr, length) - .iter() - .for_each(|element| { - // Convert the port value - let local_port: u16 = (element.dwLocalPort as u16).to_be(); - if local_port == port { - pids.insert(element.dwOwningPid); - } - }); - } + impl_get_processes!(MIB_UDPROW_OWNER_MODULE); } impl TableClass for MIB_UDP6TABLE_OWNER_MODULE { @@ -340,18 +324,5 @@ impl TableClass for MIB_UDP6TABLE_OWNER_MODULE { const FAMILY: AddressFamily = INET6; const TABLE_CLASS: TableClassType = UDP_TYPE; - unsafe fn get_processes(table: *const Self, port: u16, pids: &mut HashSet) { - let row_ptr: *const MIB_UDP6ROW_OWNER_MODULE = addr_of!((*table).table).cast(); - let length: usize = addr_of!((*table).dwNumEntries).read_unaligned() as usize; - - slice::from_raw_parts(row_ptr, length) - .iter() - .for_each(|element| { - // Convert the port value - let local_port: u16 = (element.dwLocalPort as u16).to_be(); - if local_port == port { - pids.insert(element.dwOwningPid); - } - }); - } + impl_get_processes!(MIB_UDP6ROW_OWNER_MODULE); }