diff --git a/bin/systemd/switchmap_poller b/bin/systemd/switchmap_poller index 929b499fa..66260ec52 100755 --- a/bin/systemd/switchmap_poller +++ b/bin/systemd/switchmap_poller @@ -74,7 +74,7 @@ class PollingAgent(Agent): """ # Initialize key variables delay = self._server_config.polling_interval() - multiprocessing = self._server_config.multiprocessing() + max_concurrent = self._server_config.agent_subprocesses() # Post data to the remote server while True: @@ -89,7 +89,7 @@ class PollingAgent(Agent): open(self.lockfile, "a").close() # Poll after sleeping - poll.devices(multiprocessing=multiprocessing) + poll.run_devices(max_concurrent_devices=max_concurrent) # Delete lockfile os.remove(self.lockfile) diff --git a/requirements.txt b/requirements.txt index 244676e3e..b8788380a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ gunicorn==20.0.4 # Posting requests +aiohttp # Polling easysnmp==0.2.5 @@ -31,6 +32,7 @@ graphql_server==3.0.0b5 # Testing mock pytest +pytest-asyncio # Other more-itertools diff --git a/snmp_test.py b/snmp_test.py deleted file mode 100644 index 2d8e2d10a..000000000 --- a/snmp_test.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -"""Test async_snmp_info.everything() using proper credential validation.""" - -import asyncio -import sys -import traceback -import time - -from switchmap.poller.snmp.async_snmp_info import Query -from switchmap.poller.snmp import async_snmp_manager -from switchmap.poller.configuration import ConfigPoller -from switchmap.poller import POLLING_OPTIONS, POLL - - -async def test_everything(): - """Test everything() method with proper SNMP credential validation.""" - print("Testing async_snmp_info.everything()") - - hostname = "162.249.37.218" - - try: - # SNMP configuration - print(f"Getting SNMP configuration...") - config = ConfigPoller() - - print(f"Validating SNMP credentials for {hostname}...") - validate = async_snmp_manager.Validate( - POLLING_OPTIONS( - hostname=hostname, authorizations=config.snmp_auth() - ) - ) - - # Get valid authorization - authorization = await validate.credentials() - if not authorization: - print(f"Failed to get valid SNMP credentials for {hostname}") - return None - - snmp_object = async_snmp_manager.Interact( - POLL(hostname=hostname, authorization=authorization) - ) - - print(f"Testing device connectivity...") - is_contactable = await snmp_object.contactable() - if not is_contactable: - print(f"device {hostname} is not contactable via SNMP") - return None - - print(f"device {hostname} is contactable!") - - sysobjectid = await snmp_object.sysobjectid() - enterprise_no = await snmp_object.enterprise_number() - print(f"Device info:") - print(f"SysObjectID: {sysobjectid}") - print(f"Enterprise: {enterprise_no}") - - query_obj = Query(snmp_object) - - print(f"Calling everything() method...") - print(f"wait a little...") - - start_time = time.time() - everything_data = await query_obj.everything() - end_time = time.time() - - print(f"Completed in {end_time - start_time:.2f} seconds") - - # Display results - if everything_data: - print(f"\nSUCCESS! everything() returned data :)))))))") - print(f"Data str:") - for key, value in everything_data.items(): - if isinstance(value, dict) and value: - print(f" {key}: {len(value)} items") - # Show sample of nested data - sample_key = list(value.keys())[0] - sample_value = value[sample_key] - if isinstance(sample_value, dict): - print( - f"Sample {sample_key}: {len(sample_value)} sub-items" - ) - else: - print(f"Sample {sample_key}: {type(sample_value)}") - elif isinstance(value, dict): - print(f" {key}: empty dict") - else: - print(f" {key}: {type(value)} = {value}") - else: - print(f"everything() returned nonee result") - - return everything_data - - except Exception as e: - print(f"ERROR: :(((((({e}") - print(f"Full traceback:") - traceback.print_exc() - return None - - -async def main(): - """Main test function.""" - print("Async SNMP Info Test (Proper Credentials)") - print("=" * 50) - - result = await test_everything() - - if result is not None: - print(f"\nTest completed successfully!") - return True - else: - print(f"\nTest failed!") - return False - - -if __name__ == "__main__": - print(" Running Async SNMP Test with Proper Credentials...") - success = asyncio.run(main()) - - if success: - print("Test completed!") - sys.exit(0) - else: - print("Test failed!") - sys.exit(1) diff --git a/switchmap/core/files.py b/switchmap/core/files.py index b59a24877..dd450ee51 100644 --- a/switchmap/core/files.py +++ b/switchmap/core/files.py @@ -253,7 +253,7 @@ def read_yaml_file(filepath, as_string=False, die=True): "".format(filepath) ) if bool(die) is True: - log.log2die_safe(1001, log_message) + log.log2die_safe(1008, log_message) else: log.log2debug(1002, log_message) return {} diff --git a/switchmap/core/mac_utils.py b/switchmap/core/mac_utils.py new file mode 100644 index 000000000..b09a5d6a5 --- /dev/null +++ b/switchmap/core/mac_utils.py @@ -0,0 +1,40 @@ +"""MAC address utility functions.""" + +import binascii + + +def decode_mac_address(encoded_mac): + """Decode double-encoded MAC addresses from async poller. + + This function handles MAC addresses that may be double hex-encoded + and returns them in a standard format. + + Args: + encoded_mac: MAC address that may be double hex-encoded + + Returns: + str: Properly formatted MAC address or original if already valid + + """ + # Fast-path non-strings + if not isinstance(encoded_mac, str): + return encoded_mac + + s = encoded_mac.strip() + + # Handle plain '0x' prefix + if s.lower().startswith("0x"): + return s[2:] + + # Attempt to unhexlify only when likely hex and long enough + hexchars = "0123456789abcdefABCDEF" + if len(s) > 12 and len(s) % 2 == 0 and all(c in hexchars for c in s): + try: + decoded = binascii.unhexlify(s).decode("ascii") + except (binascii.Error, UnicodeDecodeError, ValueError): + return encoded_mac + if decoded.lower().startswith("0x"): + return decoded[2:] + return decoded + + return s diff --git a/switchmap/poller/poll.py b/switchmap/poller/poll.py index 7cd32517a..c4937cab9 100644 --- a/switchmap/poller/poll.py +++ b/switchmap/poller/poll.py @@ -1,37 +1,32 @@ -"""Switchmap-NG poll modulre. +"""Async Switchmap-NG poll module.""" -Updates the database with device SNMP data. - -""" - -# Standard libraries -from multiprocessing import Pool +import asyncio from collections import namedtuple from pprint import pprint import os +import time +import aiohttp # Import app libraries -from switchmap import API_POLLER_POST_URI +from switchmap import API_POLLER_POST_URI, API_PREFIX from switchmap.poller.snmp import poller from switchmap.poller.update import device as udevice from switchmap.poller.configuration import ConfigPoller -from switchmap.core import log -from switchmap.core import rest -from switchmap.core import files +from switchmap.core import log, rest, files from switchmap import AGENT_POLLER _META = namedtuple("_META", "zone hostname config") -def devices(multiprocessing=False): - """Poll all devices for data using subprocesses and create YAML files. +async def devices(max_concurrent_devices=None): + """Poll all devices asynchronously. Args: - multiprocessing: Run multiprocessing when True + max_concurrent_devices: Maximum number of devices to poll concurrently. + If None, uses config.agent_subprocesses() Returns: None - """ # Initialize key variables arguments = [] @@ -39,95 +34,188 @@ def devices(multiprocessing=False): # Get configuration config = ConfigPoller() - # Get the number of threads to use in the pool - pool_size = config.agent_subprocesses() + # Use config value if not provided + if max_concurrent_devices is None: + max_concurrent_devices = config.agent_subprocesses() + elif ( + not isinstance(max_concurrent_devices, int) + or max_concurrent_devices < 1 + ): + log.log2warning( + 1401, + f"Invalid concurrency={max_concurrent_devices}; defaulting to 1", + ) + max_concurrent_devices = 1 # Create a list of polling objects - zones = sorted(config.zones()) + zones = sorted(config.zones(), key=lambda z: z.name) - # Create a list of arguments for zone in zones: + if not zone.hostnames: + continue arguments.extend( _META(zone=zone.name, hostname=_, config=config) for _ in zone.hostnames ) - # Process the data - if bool(multiprocessing) is False: - for argument in arguments: - device(argument) - - else: - # Create a multiprocessing pool of sub process resources - with Pool(processes=pool_size) as pool: - # Create sub processes from the pool - pool.map(device, arguments) - + if not arguments: + log_message = "No devices found in configuration" + log.log2info(1400, log_message) + return -def device(poll, post=True): - """Poll single device for data and create YAML files. + log_message = ( + f"Starting async polling of {len(arguments)} devices " + f"with max concurrency: {max_concurrent_devices}" + ) + log.log2info(1401, log_message) + + # Semaphore to limit concurrent devices + device_semaphore = asyncio.Semaphore(max_concurrent_devices) + + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + tasks = [ + device(argument, device_semaphore, session, post=True) + for argument in arguments + ] + # Execute all devices concurrently + start_time = time.time() + results = await asyncio.gather(*tasks, return_exceptions=True) + end_time = time.time() + + # Process results and log summary + success_count = sum(1 for r in results if r is True) + error_count = sum(1 for r in results if isinstance(r, Exception)) + failed_count = len(results) - success_count - error_count + + log_message = ( + f"Polling completed in {end_time - start_time:.2f}s: " + f"{success_count} succeeded, {failed_count} failed, " + f"{error_count} errors" + ) + log.log2info(1402, log_message) + # Log specific errors + for i, result in enumerate(results): + if isinstance(result, Exception): + hostname = arguments[i].hostname + log_message = f"Device {hostname} polling error: {result}" + log.log2warning(1403, log_message) + + +async def device(poll_meta, device_semaphore, session, post=True): + """Poll each device asynchronously. Args: - poll: _META object - post: Post the data if True, else just print it. + poll_meta: _META object containing zone, hostname, config + device_semaphore: Semaphore to limit concurrent devices + session: aiohttp ClientSession for HTTP requests + post: Post the data if True, else just print it Returns: - None - + bool: True if successful, False otherwise """ - # Initialize key variables - hostname = poll.hostname - zone = poll.zone - config = poll.config - - # Do nothing if the skip file exists - skip_file = files.skip_file(AGENT_POLLER, config) - if os.path.isfile(skip_file) is True: - log_message = """\ -Skip file {} found. Aborting poll for {} in zone "{}". A daemon \ -shutdown request was probably requested""".format( - skip_file, hostname, zone - ) - log.log2debug(1041, log_message) - return + async with device_semaphore: + # Initialize key variables + hostname = poll_meta.hostname + zone = poll_meta.zone + config = poll_meta.config + + # Do nothing if the skip file exists + skip_file = files.skip_file(AGENT_POLLER, config) + if os.path.isfile(skip_file): + log_message = ( + f"Skip file {skip_file} found. Aborting poll for " + f"{hostname} in zone '{zone}'" + ) + log.log2debug(1404, log_message) + return False + + # Poll data for obviously valid hostname + if ( + not hostname + or not isinstance(hostname, str) + or hostname.lower() == "none" + ): + log_message = f"Invalid hostname: {hostname}" + log.log2debug(1405, log_message) + return False + + try: + poll = poller.Poll(hostname) + + # Initialize SNMP connection + if not await poll.initialize_snmp(): + log_message = f"Failed to initialize SNMP for {hostname}" + log.log2debug(1406, log_message) + return False + + # Query device data asynchronously + snmp_data = await poll.query() + + # Process if we get valid data + if bool(snmp_data) and isinstance(snmp_data, dict): + # Process device data + _device = udevice.Device(snmp_data) + data = _device.process() + data["misc"]["zone"] = zone + + if post: + try: + # Construct full URL for posting + url = ( + f"{config.server_url_root()}{API_PREFIX}" + f"{API_POLLER_POST_URI}" + ) + log_message = f"Posting data for {hostname} to {url}" + log.log2debug(1416, log_message) + + async with session.post(url, json=data) as res: + if 200 <= res.status < 300: + log_message = ( + f"Successfully polled and posted data " + f"for {hostname}" + ) + log.log2debug(1407, log_message) + else: + log_message = ( + f"Failed to post data for {hostname}, " + f"status={res.status}" + ) + log.log2warning(1414, log_message) + return False + except aiohttp.ClientError as e: + log_message = ( + f"HTTP error posting data for {hostname}: {e}" + ) + log.log2warning(1415, log_message) + return False - # Poll data for obviously valid hostnames (eg. "None" used in installation) - if bool(hostname) is True: - if isinstance(hostname, str) is True: - if hostname.lower() != "none": - poll = poller.Poll(hostname) - snmp_data = poll.query() - - # Process if we get valid data - if bool(snmp_data) and isinstance(snmp_data, dict): - # Process device data - _device = udevice.Device(snmp_data) - data = _device.process() - data["misc"]["zone"] = zone - - if bool(post) is True: - # Update the database tables with polled data - rest.post(API_POLLER_POST_URI, data, config) - else: - pprint(data) else: - log_message = """\ -Device {} returns no data. Check your connectivity and/or SNMP configuration\ -""".format( - hostname - ) - log.log2debug(1025, log_message) + pprint(data) + return True + else: + log_message = ( + f"Device {hostname} returns no data. Check " + f"connectivity/SNMP configuration" + ) + log.log2debug(1408, log_message) + return False -def cli_device(hostname): - """Poll single device for data and create YAML files. + except (asyncio.TimeoutError, KeyError, ValueError) as e: + log_message = f"Recoverable error polling device {hostname}: {e}" + log.log2warning(1409, log_message) + return False + + +async def cli_device(hostname): + """Poll single device for data - CLI interface. Args: hostname: Host to poll Returns: None - """ # Initialize key variables arguments = [] @@ -136,19 +224,75 @@ def cli_device(hostname): config = ConfigPoller() # Create a list of polling objects - zones = sorted(config.zones()) + zones = sorted(config.zones(), key=lambda z: z.name) # Create a list of arguments for zone in zones: + if not zone.hostnames: + continue for next_hostname in zone.hostnames: if next_hostname == hostname: arguments.append( _META(zone=zone.name, hostname=hostname, config=config) ) - if bool(arguments) is True: - for argument in arguments: - device(argument, post=False) + if arguments: + log_message = ( + f"Found {hostname} in {len(arguments)} zone(s), starting async poll" + ) + log.log2info(1410, log_message) + + # Poll each zone occurrence + semaphore = asyncio.Semaphore(1) + async with aiohttp.ClientSession() as session: + tasks = [ + device(argument, semaphore, session, post=False) + for argument in arguments + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check results + success_count = sum(1 for r in results if r is True) + if success_count > 0: + log_message = ( + f"Successfully polled {hostname} from " + f"{success_count}/{len(results)} zone(s)" + ) + log.log2info(1411, log_message) + else: + log_message = f"Failed to poll {hostname} from any configured zone" + log.log2warning(1412, log_message) + else: - log_message = "No hostname {} found in configuration".format(hostname) - log.log2see(1036, log_message) + log_message = f"No hostname {hostname} found in configuration" + log.log2see(1413, log_message) + + +def run_devices(max_concurrent_devices=None): + """Run device polling - main entry point. + + Args: + max_concurrent_devices (int, optional): Maximum number of devices to + poll concurrently. If None, uses config.agent_subprocesses(). + + Returns: + None + """ + # Use config if not specified + if max_concurrent_devices is None: + config = ConfigPoller() + max_concurrent_devices = config.agent_subprocesses() + + asyncio.run(devices(max_concurrent_devices)) + + +def run_cli_device(hostname): + """Run CLI device polling - main entry point. + + Args: + hostname (str): The hostname of the device to poll. + + Returns: + None + """ + asyncio.run(cli_device(hostname)) diff --git a/switchmap/poller/snmp/async_poller.py b/switchmap/poller/snmp/async_poller.py deleted file mode 100644 index debb8b32d..000000000 --- a/switchmap/poller/snmp/async_poller.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Asynchronous SNMP Poller module for switchmap-ng.""" - -# Switchmap imports -from switchmap.poller.configuration import ConfigPoller -from switchmap.poller import POLLING_OPTIONS, SNMP, POLL -from . import async_snmp_manager -from . import async_snmp_info -from switchmap.core import log - - -class Poll: - """Asynchronous SNMP poller for switchmap-ng that gathers network data. - - This class manages SNMP credential validation and data querying for - network devices using asynchronous operations for improved - performance and scalability. - - Args: - hostname (str): The hostname or IP address of the device to poll - - Methods: - initialize_snmp(): Validates SNMP credentials and - initializes SNMP interaction - query(): Queries the device for topology data asynchronously - """ - - def __init__(self, hostname): - """Initialize the class. - - Args: - hostname: Hostname to poll - - Returns: - None - """ - # Initialize key variables - self._server_config = ConfigPoller() - self._hostname = hostname - self.snmp_object = None - - async def initialize_snmp(self): - """Initialize SNMP connection asynchronously. - - Returns; - bool: True if successful, False otherwise - """ - # Get snmp config information from Switchmap-NG - validate = async_snmp_manager.Validate( - POLLING_OPTIONS( - hostname=self._hostname, - authorizations=self._server_config.snmp_auth(), - ) - ) - - # Get credentials asynchronously - authorization = await validate.credentials() - - # Create an SNMP object for querying - if _do_poll(authorization) is True: - self.snmp_object = async_snmp_manager.Interact( - POLL(hostname=self._hostname, authorization=authorization) - ) - return True - else: - log_message = ( - "Uncontactable or disabled host {}, or no valid SNMP " - "credentials found in it.".format(self._hostname) - ) - log.log2info(1081, log_message) - return False - - async def query(self): - """Query all remote hosts for data. - - Args: - None - - Returns: - dict: Polled data or None if failed - """ - # Initialize key variables - _data = None - - # Only query if the device is contactable - if bool(self.snmp_object) is False: - log.log2die(1001, f"No valid SNMP object for {self._hostname} ") - return _data - - # Get data - log_message = """\ -Querying topology data from host: {}.""".format( - self._hostname - ) - - log.log2info(1078, log_message) - - status = async_snmp_info.Query(snmp_object=self.snmp_object) - - _data = await status.everything() - - return _data - - -def _do_poll(authorization): - """Determine whether doing a poll is valid. - - Args: - authorization: SNMP object - - Returns: - poll: True if a poll should be done - """ - # Initialize key variables - poll = False - - if bool(authorization) is True: - if isinstance(authorization, SNMP) is True: - poll = bool(authorization.enabled) - - return poll diff --git a/switchmap/poller/snmp/async_snmp_info.py b/switchmap/poller/snmp/async_snmp_info.py deleted file mode 100644 index 122530a46..000000000 --- a/switchmap/poller/snmp/async_snmp_info.py +++ /dev/null @@ -1,467 +0,0 @@ -"""Async module to aggregate query results.""" - -import time -from collections import defaultdict -from switchmap.core import log -import asyncio - -from . import iana_enterprise -from . import get_queries - - -class Query: - """Async class interacts with devices - use existing MIB classes. - - Args: - None - - Returns: - None - """ - - def __init__(self, snmp_object): - """Instantiate the class. - - Args: - snmp_object: SNMP interact class object from async_snmp_manager.py - - Returns: - None - """ - # Define query object - self.snmp_object = snmp_object - - async def everything(self): - """Get all information from device. - - Args: - None - - Returns: - data: Aggregated data - """ - # Initialize key variables - data = {} - - # Run all sections concurrently - results = await asyncio.gather( - self.misc(), - self.system(), - self.layer1(), - self.layer2(), - self.layer3(), - return_exceptions=True, - ) - - keys = ["misc", "system", "layer1", "layer2", "layer3"] - for key, result in zip(keys, results): - if isinstance(result, Exception): - log.warning(f"{key} failed: {result}") - elif result: - data[key] = result - - # Return - return data - - async def misc(self): - """Provide miscellaneous information about the device and the poll.""" - # Initialize data - data = defaultdict(lambda: defaultdict(dict)) - data["timestamp"] = int(time.time()) - data["host"] = self.snmp_object.hostname() - - # Get vendor information - sysobjectid = await self.snmp_object.sysobjectid() - vendor = iana_enterprise.Query(sysobjectid=sysobjectid) - data["IANAEnterpriseNumber"] = vendor.enterprise() - - return data - - async def system(self): - """Get all system information from device. - - Args: - None - - Returns: - data: Aggregated system data - """ - # Initialize key variables - data = defaultdict(lambda: defaultdict(dict)) - processed = False - - # Get system information from various MIB classes - system_queries = get_queries("system") - - # Create all query instances - query_items = [ - (query_class(self.snmp_object), query_class.__name__) - for query_class in system_queries - ] - - # Check if supported - support_results = await asyncio.gather( - *[item.supported() for item, _ in query_items] - ) - - supported_items = [ - (item, name) - for (item, name), supported in zip(query_items, support_results) - if supported - ] - - if supported_items: - results = await asyncio.gather( - *[ - _add_system(item, defaultdict(lambda: defaultdict(dict))) - for item, _ in supported_items - ] - ) - - # Merge results - for result in results: - for key, value in result.items(): - data[key].update(value) - processed = True - - if processed is True: - return data - else: - return None - - async def layer1(self): - """Get all layer 1 information from device. - - Args: - None - - Returns: - data: Aggregated layer1 data - """ - # Initialize key values - data = defaultdict(lambda: defaultdict(dict)) - processed = False - - layer1_queries = get_queries("layer1") - - query_items = [ - (query_class(self.snmp_object), query_class.__name__) - for query_class in layer1_queries - ] - - # Concurrent support check - support_results = await asyncio.gather( - *[item.supported() for item, _ in query_items] - ) - - supported_items = [ - (item, name) - for (item, name), supported in zip(query_items, support_results) - if supported - ] - - if supported_items: - results = await asyncio.gather( - *[ - _add_layer1(item, defaultdict(lambda: defaultdict(dict))) - for item, _ in supported_items - ], - return_exceptions=True, - ) - - for i, result in enumerate(results): - if isinstance(result, Exception): - item_name = supported_items[i][1] - log.log2warning( - 1005, f"Layer1 error in {item_name}: {result}" - ) - continue - - for key, value in result.items(): - data[key].update(value) - - processed = True - - # Return - if processed is True: - return data - else: - return None - - async def layer2(self): - """Get all layer 2 information from device. - - Args: - None - - Returns: - data: Aggregated layer2 data - """ - # Initialize key variables - data = defaultdict(lambda: defaultdict(dict)) - processed = False - - # Get layer2 information from MIB classes - layer2_queries = get_queries("layer2") - - query_items = [ - (query_class(self.snmp_object), query_class.__name__) - for query_class in layer2_queries - ] - - support_results = await asyncio.gather( - *[item.supported() for item, _ in query_items] - ) - - # Filter supported MIBs - supported_items = [ - (item, name) - for (item, name), supported in zip(query_items, support_results) - if supported - ] - - if supported_items: - # Concurrent processing - results = await asyncio.gather( - *[ - _add_layer2(item, defaultdict(lambda: defaultdict(dict))) - for item, _ in supported_items - ], - return_exceptions=True, - ) - - for i, result in enumerate(results): - if isinstance(result, Exception): - item_name = supported_items[i][1] - log.log2warning( - 1007, f"Layer2 error in {item_name}: {result}" - ) - continue - - # Merge this MIB's complete results - for key, value in result.items(): - data[key].update(value) - - processed = True - - # Return - - if processed is True: - return data - else: - return None - - async def layer3(self): - """Get all layer3 information from device. - - Args: - None - - Returns: - data: Aggregated layer3 data - """ - # Initialize key variables - data = defaultdict(lambda: defaultdict(dict)) - processed = False - - # Get layer3 information from MIB classes - layer3_queries = get_queries("layer3") - - query_items = [ - (query_class(self.snmp_object), query_class.__name__) - for query_class in layer3_queries - ] - - support_results = await asyncio.gather( - *[item.supported() for item, _ in query_items] - ) - - # Filter supported MIBs - supported_items = [ - (item, name) - for (item, name), supported in zip(query_items, support_results) - if supported - ] - - if supported_items: - # Concurrent processing - results = await asyncio.gather( - *[ - _add_layer3(item, defaultdict(lambda: defaultdict(dict))) - for item, _ in supported_items - ], - return_exceptions=True, - ) - - for i, result in enumerate(results): - if isinstance(result, Exception): - item_name = supported_items[i][1] - log.log2warning( - 1006, f"Layer3 error in {item_name}: {result}" - ) - continue - - # Merge this MIB's complete results - for key, value in result.items(): - data[key].update(value) - - processed = True - - if processed is True: - return data - return None - - -async def _add_data(source, target): - """Add data from source to target dict. Both dicts must have two keys. - - Args: - source: Source dict - target: Target dict - - Returns: - target: Aggregated data - """ - # Process data - for primary in source.keys(): - for secondary, value in source[primary].items(): - target[primary][secondary] = value - - # Return - return target - - -async def _add_system(query, data): - """Add data from successful system MIB query to original data provided. - - Args: - query: MIB query object - data: Three keyed dict of data - - Returns: - data: Aggregated data - """ - try: - result = None - - if asyncio.iscoroutinefunction(query.system): - result = await query.system() - else: - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, query.system) - - # Merge only if we have data - if not result: - return data - for primary, secondary_map in result.items(): - if isinstance(secondary_map, dict): - for secondary, maybe_tertiary in secondary_map.items(): - if isinstance(maybe_tertiary, dict): - for tertiary, value in maybe_tertiary.items(): - data[primary][secondary][tertiary] = value - else: - data[primary][secondary] = maybe_tertiary - else: - # Handle case where secondary level is not a dict - data[primary] = secondary_map - - return data - except Exception as e: - log.log2warning(1320, f"Error in _add_system: {e}") - return data - - -async def _add_layer1(query, data): - """Add data from successful layer1 MIB query to original data provided. - - Args: - query: MIB query object - data: dict of data - - Returns: - data: Aggregated data - """ - try: - mib_name = query.__class__.__name__ - - result = None - if asyncio.iscoroutinefunction(query.layer1): - result = await query.layer1() - else: - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, query.layer1) - - if result: - data = await _add_data(result, data) - else: - log.log2debug(1302, f" No layer1 data returned for {mib_name}") - - return data - - except Exception as e: - log.log2warning(1316, f" Error in _add_layer1 for {mib_name}: {e}") - return data - - -async def _add_layer2(query, data): - """Add data from successful layer2 MIB query to original data provided. - - Args: - query: MIB query object - data: dict of data - - Returns: - data: Aggregated data - """ - try: - mib_name = query.__class__.__name__ - result = None - if asyncio.iscoroutinefunction(query.layer2): - result = await query.layer2() - else: - - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, query.layer2) - - if result: - data = await _add_data(result, data) - else: - log.log2debug(1306, f" No layer2 data returned for {mib_name}") - - return data - - except Exception as e: - log.log2warning(1308, f" Error in _add_layer2 for {mib_name}: {e}") - return data - - -async def _add_layer3(query, data): - """Add data from successful layer3 MIB query to original data provided. - - Args: - query: MIB query object - data: dict of data - - Returns: - data: Aggregated data - """ - try: - mib_name = query.__class__.__name__ - - result = None - if asyncio.iscoroutinefunction(query.layer3): - result = await query.layer3() - else: - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, query.layer3) - - if result: - data = await _add_data(result, data) - else: - log.log2debug(1309, f" No layer3 data returned for {mib_name}") - - return data - - except Exception as e: - log.log2warning(1310, f" Error in _add_layer3 for {mib_name}: {e}") - return data diff --git a/switchmap/poller/snmp/async_snmp_manager.py b/switchmap/poller/snmp/async_snmp_manager.py deleted file mode 100644 index 7d2945c65..000000000 --- a/switchmap/poller/snmp/async_snmp_manager.py +++ /dev/null @@ -1,1128 +0,0 @@ -"""Async SNMP manager class.""" - -import os -import asyncio - - -# import project libraries -from switchmap.core import log -from switchmap.core import files -from switchmap.poller import POLL -from switchmap.poller.configuration import ConfigPoller - -from . import iana_enterprise - -from pysnmp.hlapi.asyncio import ( - SnmpEngine, - CommunityData, - UdpTransportTarget, - ContextData, - ObjectType, - ObjectIdentity, - getCmd, - nextCmd, - bulkCmd, - UsmUserData, - # Authentication protocols - usmHMACMD5AuthProtocol, - usmHMACSHAAuthProtocol, - usmHMAC128SHA224AuthProtocol, - usmHMAC192SHA256AuthProtocol, - usmHMAC256SHA384AuthProtocol, - usmHMAC384SHA512AuthProtocol, - # Privacy protocols - usmDESPrivProtocol, - usmAesCfb128Protocol, - usmAesCfb192Protocol, - usmAesCfb256Protocol, -) - -from pysnmp.error import PySnmpError -from pysnmp.proto.rfc1905 import EndOfMibView, NoSuchInstance, NoSuchObject - - -class Validate: - """Class to validate SNMP data asynchronously.""" - - def __init__(self, options): - """Initialize the Validate class. - - Args: - options: POLLING_OPTIONS object containing SNMP configuration. - - Returns: - None - """ - self._options = options - - async def credentials(self): - """Determine valid SNMP credentials for a host. - - Args: - None - - Returns: - authentication: SNMP authorization object containing valid - credentials, or None if no valid credentials found - """ - # Initialize key variables - cache_exists = False - filename = files.snmp_file(self._options.hostname, ConfigPoller()) - if os.path.exists(filename): - cache_exists = True - group = None - - if cache_exists is False: - authentication = await self.validation() - - # Save credentials if successful - if bool(authentication): - _update_cache(filename, authentication.group) - else: - # Read credentials from cache - if os.path.isfile(filename): - with open(filename) as f_handle: - group = f_handle.readline().strip() or None - - # Get Credentials - authentication = await self.validation(group) - - # Try the rest if the credentials fail - if bool(authentication) is False: - authentication = await self.validation() - - # update cache if found - if bool(authentication): - _update_cache(filename, authentication.group) - - return authentication - - async def validation(self, group=None): - """Determine valid SNMP authorization for a host. - - Args: - group: String containing SNMP group name to try, or None to try all - groups - - Returns: - result: SNMP authorization object if valid credentials found, - None otherwise - """ - # Initialize key variables - result = None - - # Probe device with all SNMP options - for authorization in self._options.authorizations: - # Only process enabled SNMP values - if bool(authorization.enabled) is False: - continue - - # Setup contact with the remote device - device = Interact( - POLL( - hostname=self._options.hostname, authorization=authorization - ) - ) - # Try successive groups check if device is contactable - if group is None: - if await device.contactable() is True: - result = authorization - break - else: - if authorization.group == group: - if await device.contactable() is True: - result = authorization - - return result - - -class Interact: - """Class Gets SNMP data.""" - - def __init__(self, _poll): - """Initialize the Interact class. - - Args: - _poll: POLL object containing SNMP configuration and target info - - Returns: - None - """ - # Initialize key variables - self._poll = _poll - self._engine = SnmpEngine() - - # Rate Limiting - self._semaphore = asyncio.Semaphore(10) - - # Fail if there is no authentication - if bool(self._poll.authorization) is False: - log_message = ( - "SNMP parameters provided are either blank or missing." - "Non existent host?" - ) - log.log2die(1045, log_message) - - async def enterprise_number(self): - """Get SNMP enterprise number for the device. - - Args: - None - - Returns: - int: SNMP enterprise number identifying the device vendor - """ - # Get the sysObjectID.0 value of the device - sysid = await self.sysobjectid() - - # Get the vendor ID - enterprise_obj = iana_enterprise.Query(sysobjectid=sysid) - enterprise = enterprise_obj.enterprise() - - return enterprise - - def hostname(self): - """Get SNMP hostname for the interaction. - - Args: - None - - Returns: - str: Hostname of the target device - """ - return self._poll.hostname - - async def contactable(self): - """Check if device is reachable via SNMP. - - Args: - None - - Returns: - bool: True if device responds to SNMP queries, False otherwise - """ - # key variables - contactable = False - result = None - - # Try to reach device - try: - # Test if we can poll the SNMP sysObjectID - # if true, then the device is contactable - result = await self.sysobjectid(check_reachability=True) - if bool(result) is True: - contactable = True - - except Exception: - # Not Contactable - contactable = False - - return contactable - - async def sysobjectid(self, check_reachability=False): - """Get the sysObjectID of the device. - - Args: - check_reachability: Boolean indicating whether to test connectivity. - Some session errors are ignored to return null result. - - Returns: - str: sysObjectID value as string, or None if not available - """ - # Initialize key variables - oid = ".1.3.6.1.2.1.1.2.0" - object_id = None - - # Get sysObjectID - results = await self.get(oid, check_reachability) - # Pysnmp already returns out value as value unlike easysnmp - if bool(results) is True: - # Both formats: with and without leading dot - object_id = results.get(oid) - if object_id is None: - oid_without_dot = oid.lstrip(".") - object_id = results.get(oid_without_dot) - - # Convert bytes to string if needed - if isinstance(object_id, bytes): - object_id = object_id.decode("utf-8") - - return object_id - - async def oid_exists(self, oid_to_get, context_name=""): - """Determine if an OID exists on the device. - - Args: - oid_to_get: String containing OID to check - context_name: String containing SNMPv3 context name. - Default is empty string. - - Returns: - bool: True if OID exists, False otherwise - """ - try: - # Initialize key - validity = False - - # Validate OID - if ( - await self._oid_exists_get( - oid_to_get, context_name=context_name - ) - is True - ): - validity = True - if validity is False: - if ( - await self._oid_exists_walk( - oid_to_get, context_name=context_name - ) - is True - ): - validity = True - - return validity - except Exception as e: - log.log2warning( - 1305, f"OID existence check failed for {oid_to_get}: {e}" - ) - return False - - async def _oid_exists_get(self, oid_to_get, context_name=""): - """Determine existence of OID on device. - - Args: - oid_to_get: OID to get - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - - Returns: - validity: True if exists - """ - try: - validity = False - - (_, exists, result) = await self.query( - oid_to_get, - get=True, - check_reachability=True, - check_existence=True, - context_name=context_name, - ) - - if exists and bool(result): - # Make sure the OID key exists in result - if isinstance(result, dict) and oid_to_get in result: - if result[oid_to_get] is not None: - validity = True - elif isinstance(result, dict) and result: - # If result has data but not exact OID, still consider valid - validity = True - - return validity - except Exception as e: - log.log2warning( - 1305, f"OID existence check failed for {oid_to_get}: {e}" - ) - return False - - async def _oid_exists_walk(self, oid_to_get, context_name=""): - """Check OID existence on device using WALK. - - Args: - oid_to_get: OID to get - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - - Returns: - validity: True if exist - """ - try: - (_, exists, results) = await self.query( - oid_to_get, - get=False, - check_existence=True, - context_name=context_name, - check_reachability=True, - ) - # Check if we get valid results - if exists and isinstance(results, dict) and results: - return True - return False - except Exception as e: - log.log2warning( - 1306, f"Walk existence check failed for {oid_to_get}: {e}" - ) - return False - - async def get( - self, - oid_to_get, - check_reachability=False, - check_existence=False, - normalized=False, - context_name="", - ): - """Do an SNMPget. - - Args: - oid_to_get: OID to get - check_reachability: Set if testing for connectivity. Some session - errors are ignored so that a null result is returned - check_existence: Set if checking for the existence of the OID - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - - Returns: - result: Dictionary of {OID: value} pairs - """ - (_, _, result) = await self.query( - oid_to_get, - get=True, - check_reachability=check_reachability, - check_existence=check_existence, - normalized=normalized, - context_name=context_name, - ) - return result - - async def walk( - self, - oid_to_get, - normalized=False, - check_reachability=False, - check_existence=False, - context_name="", - safe=False, - ): - """Do an async SNMPwalk. - - Args: - oid_to_get: OID to walk - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort - check_reachability: - Set if testing for connectivity. Some session - errors are ignored so that a null result is returned - check_existence: - Set if checking for the existence of the OID - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - safe: Safe query if true. If there is an exception, then return \ - blank values. - - Returns: - result: Dictionary of tuples (OID, value) - """ - (_, _, result) = await self.query( - oid_to_get, - get=False, - check_reachability=check_reachability, - check_existence=check_existence, - normalized=normalized, - context_name=context_name, - safe=safe, - ) - - return result - - async def swalk(self, oid_to_get, normalized=False, context_name=""): - """Perform a safe async SNMPwalk that handles errors gracefully. - - Args: - oid_to_get: OID to get - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - - Returns: - dict: Results of SNMP walk as OID-value pairs - """ - # Process data - return await self.walk( - oid_to_get, - normalized=normalized, - check_reachability=True, - check_existence=True, - context_name=context_name, - safe=True, - ) - - async def query( - self, - oid_to_get, - get=False, - check_reachability=False, - check_existence=False, - normalized=False, - context_name="", - safe=False, - ): - """Do an SNMP query. - - Args: - oid_to_get: OID to walk - get: Flag determining whether to do a GET or WALK - check_reachability: Set if testing for connectivity. Some session - errors are ignored so that a null result is returned - check_existence: Set if checking for the existence of the OID - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort - context_name: Set the contextName used for SNMPv3 messages. - The default contextName is the empty string "". Overrides the - defContext token in the snmp.conf file. - safe: Safe query if true. If there is an exception, then return\ - blank values. - - Returns: - return_value: Tuple of (_contactable, exists, values) - """ - # Initialize variables - _contactable = True - exists = True - results = [] - # Initialize formatted_result to avoid undefined variable error - formatted_result = {} - - # Check if OID is valid - if _oid_valid_format(oid_to_get) is False: - log_message = "OID {} has an invalid format".format(oid_to_get) - log.log2die(1057, log_message) - - # Get session parameters - async with self._semaphore: - try: - # Create SNMP session - session = Session( - self._poll, self._engine, context_name=context_name - ) - - # Use shorter timeouts for walk operations - auth_data, transport_target = await session._session( - walk_operation=(not get) - ) - context_data = ContextData(contextName=context_name) - - # Perform the SNMP operation - if get is True: - results = await session._do_async_get( - oid_to_get, auth_data, transport_target, context_data - ) - else: - results = await session._do_async_walk( - oid_to_get, auth_data, transport_target, context_data - ) - - formatted_result = _format_results( - results, oid_to_get, normalized=normalized - ) - - except PySnmpError as exception_error: - # Handle PySNMP errors similar to sync version - if check_reachability is True: - _contactable = False - exists = False - elif check_existence is True: - exists = False - elif safe is True: - _contactable = None - exists = None - log_message = ( - f"Async SNMP error for {self._poll.hostname}: " - f"{exception_error}" - ) - log.log2info(1209, log_message) - else: - log_message = ( - f"Async SNMP error for {self._poll.hostname}: " - f"{exception_error}" - ) - log.log2die(1003, log_message) - # Ensure formatted_result is set for exception cases - formatted_result = {} - - except Exception as exception_error: - # Handle unexpected errors - if safe is True: - _contactable = None - exists = None - log_message = ( - f"Unexpected async SNMP error for " - f"{self._poll.hostname}: {exception_error}" - ) - log.log2info(1210, log_message) - else: - log_message = ( - f"Unexpected async SNMP error for " - f"{self._poll.hostname}: {exception_error}" - ) - log.log2die(1003, log_message) - # Ensure formatted_result is set for exception cases - formatted_result = {} - - # Return - values = (_contactable, exists, formatted_result) - return values - - -class Session: - """Class to create a SNMP session with a device.""" - - def __init__(self, _poll, engine, context_name=""): - """Initialize the _Session class. - - Args: - _poll: POLL object containing SNMP configuration - engine: SNMP engine object - context_name: String containing SNMPv3 context name. - Default is empty string. - - Returns: - session: SNMP session - """ - # Assign variables - self.context_name = context_name - self._poll = _poll - self._engine = engine - - # Fail if there is no authentication - if bool(self._poll.authorization) is False: - log_message = ( - "SNMP parameters provided are blank. None existent host? " - ) - log.log2die(1046, log_message) - - async def _session(self, walk_operation=False): - """Create SNMP session parameters based on configuration. - - Returns: - Tuple of (auth_data, transport_target) - """ - # Initialize key variables - auth = self._poll.authorization - - # Use shorter timeouts for walk operations to prevent hanging - if walk_operation: - timeout = 3 - retries = 1 - else: - # Normal timeout for GET operations - timeout = 10 - retries = 3 - - # Create transport target - transport_target = UdpTransportTarget( - (self._poll.hostname, auth.port), timeout=timeout, retries=retries - ) - - # Create authentication data based on SNMP version - if auth.version == 3: - # SNMPv3 with USM - # If authprotocol/privprotocol is None/False/Empty, leave as None - auth_protocol = None - priv_protocol = None - - # Set auth protocol only if authprotocol is specified - if auth.authprotocol: - auth_proto = auth.authprotocol.lower() - if auth_proto == "md5": - auth_protocol = usmHMACMD5AuthProtocol - elif auth_proto == "sha1" or auth_proto == "sha": - auth_protocol = usmHMACSHAAuthProtocol - elif auth_proto == "sha224": - auth_protocol = usmHMAC128SHA224AuthProtocol - elif auth_proto == "sha256": - auth_protocol = usmHMAC192SHA256AuthProtocol - elif auth_proto == "sha384": - auth_protocol = usmHMAC256SHA384AuthProtocol - elif auth_proto == "sha512": - auth_protocol = usmHMAC384SHA512AuthProtocol - else: - # Default to SHA-256 for better security - auth_protocol = usmHMAC192SHA256AuthProtocol - - # Set privacy protocol only if privprotocol is specified - # Also if we have authentication (privacy requires authentication) - if auth.privprotocol and auth_protocol is not None: - priv_proto = auth.privprotocol.lower() - if priv_proto == "des": - priv_protocol = usmDESPrivProtocol - elif priv_proto == "aes128" or priv_proto == "aes": - priv_protocol = usmAesCfb128Protocol - elif priv_proto == "aes192": - priv_protocol = usmAesCfb192Protocol - elif priv_proto == "aes256": - priv_protocol = usmAesCfb256Protocol - else: - # Default to AES-256 for best security - priv_protocol = usmAesCfb256Protocol - - auth_data = UsmUserData( - userName=auth.secname, - authKey=auth.authpassword, - privKey=auth.privpassword, - authProtocol=auth_protocol, - privProtocol=priv_protocol, - ) - else: - # SNMPv1/v2c with community - mp_model = 0 if auth.version == 1 else 1 - auth_data = CommunityData(auth.community, mpModel=mp_model) - - return auth_data, transport_target - - async def _do_async_get( - self, oid, auth_data, transport_target, context_data - ): - """Pure async SNMP GET using pysnmp.""" - error_indication, error_status, error_index, var_binds = await getCmd( - self._engine, - auth_data, - transport_target, - context_data, - ObjectType(ObjectIdentity(oid)), - ) - - if error_indication: - raise PySnmpError(f"SNMP GET error: {error_indication}") - elif error_status: - raise PySnmpError(f"SNMP GET error status: {error_status}") - - # Return in object format expected by _format_results - results = [] - for var_bind in var_binds: - oid_str = str(var_bind[0]) - value = var_bind[1] - results.append((oid_str, value)) - - return results - - async def _do_async_walk( - self, oid_prefix, auth_data, transport_target, context_data - ): - """Pure async SNMP WALK using pysnmp async capabilities.""" - # Initialize key variables - results = [] - - # Use correct walk method based on SNMP version - if hasattr(auth_data, "mpModel") and auth_data.mpModel == 0: - # SNMPv1 - use nextCMD - results = await self._async_walk_v1( - oid_prefix, auth_data, transport_target, context_data - ) - else: - # SNMPv2c/v3 - use bulkCmd - - try: - results = await asyncio.wait_for( - self._async_walk_v2( - oid_prefix, auth_data, transport_target, context_data - ), - timeout=60.0, - ) - except asyncio.TimeoutError: - log.log2info( - 1011, f"bulk walk timeout after 60s for prefix {oid_prefix}" - ) - # Fallback to SNMPv1 walk which would be more reliable - results = await self._async_walk_v1( - oid_prefix, auth_data, transport_target, context_data - ) - - return results - - async def _async_walk_v1( - self, oid_prefix, auth_data, transport_target, context_data - ): - """Pure async walk for SNMPv1 using nextCmd.""" - # Initialize key variables - results = [] - - try: - async for ( - error_indication, - error_status, - error_index, - var_binds, - ) in nextCmd( - self._engine, - auth_data, - transport_target, - context_data, - ObjectType(ObjectIdentity(oid_prefix)), - lexicographicMode=False, - ): - # Handle errors first - if error_indication: - log.log2warning( - 1216, - f"SNMP v1 walk network error for {oid_prefix}: " - f"{error_indication}.", - ) - break - - elif error_status: - log.log2info( - 1217, - f"SNMP v1 walk protocol error for {oid_prefix}: " - f"{error_status} at index {error_index}", - ) - - # Handle specific SNMP errors - error_msg = error_status.prettyPrint() - if error_msg == "noSuchName": - # This OID doesn't exist, try next - continue - else: - # Other errors are usually fatal - break - - # Process successful response - for oid, value in var_binds: - oid_str = str(oid) - prefix_normalized = str(oid_prefix).lstrip(".") - oid_normalized = oid_str.lstrip(".") - if not oid_normalized.startswith(prefix_normalized): - log.log2debug( - 1220, - f"Reached end of OID tree for prefix {oid_prefix}", - ) - return results - results.append((oid_str, value)) - - # Return results after the loop completes - return results - - except Exception as e: - log.log2warning( - 1222, f"Unexpected error in SNMP v1 walk for {oid_prefix}: {e}." - ) - return results - - async def _async_walk_v2( - self, oid_prefix, auth_data, transport_target, context_data - ): - """Async walk for SNMPv2c/v3 using bulkCmd.""" - # Initialize key variables - results = [] - current_oids = [ObjectType(ObjectIdentity(oid_prefix))] - - try: - # !checking for 50, 100 would be too long to prevent from hanging - max_iterations = 50 - iterations = 0 - consecutive_empty_responses = 0 - # Stop after 3 consecutive empty responses - max_empty_responses = 3 - - while current_oids and iterations < max_iterations: - iterations += 1 - # non-repeaters = 0 , max_repetitions = 25 - error_indication, error_status, error_index, var_bind_table = ( - await bulkCmd( - self._engine, - auth_data, - transport_target, - context_data, - 0, - 25, - *current_oids, - ) - ) - - if error_indication: - log.log2info( - 1211, f"BULK error indication: {error_indication}" - ) - break - elif error_status: - log.log2info( - 1212, - f"BULK error status: {error_status.prettyPrint()} " - f"at {error_index}", - ) - break - - # Check if we got any response - if not var_bind_table: - consecutive_empty_responses += 1 - if consecutive_empty_responses >= max_empty_responses: - break - continue - else: - consecutive_empty_responses = 0 - - # Process the response - found_valid_data = False - prefix_normalized = str(oid_prefix).lstrip(".") - - for var_bind in var_bind_table: - if not var_bind or len(var_bind) == 0: - continue - - # Get the ObjectType from the list - for obj_type in var_bind: - oid, value = obj_type[0], obj_type[1] - - # Check for end of MIB - if isinstance(value, EndOfMibView): - continue - oid_str = str(oid) - - oid_normalized = oid_str.lstrip(".") - if not oid_normalized.startswith(prefix_normalized): - continue - results.append((oid_str, value)) - found_valid_data = True - - # Advance the walk using only the last row's OIDs - next_oids = [] - if var_bind_table: - last_row = var_bind_table[-1] - for obj_type in last_row: - oid, value = obj_type[0], obj_type[1] - if isinstance(value, EndOfMibView): - continue - oid_str = str(oid) - if oid_str.lstrip(".").startswith(prefix_normalized): - next_oids.append(ObjectType(ObjectIdentity(oid))) - - if not found_valid_data: - log.log2info( - 1213, - f"BULK walk: No more valid data for prefix " - f"{oid_prefix}", - ) - break - - current_oids = next_oids - - # In case, we get too many results - if len(results) > 10000: - log.log2warning( - 1214, - f"Stopping after collecting {len(results)} results " - f"(safety limit)", - ) - break - - except Exception as e: - log.log2warning(1215, f"BULK walk error: {e}") - return await self._async_walk_v1( - oid_prefix, auth_data, transport_target, context_data - ) - - return results - - -def _oid_valid_format(oid): - """Validate OID string format matching sync version. - - Args: - oid: String containing OID to validate - - Returns: - bool: True if OID format is valid, False otherwise - """ - # oid cannot be numeric - if isinstance(oid, str) is False: - return False - - # Make sure that oid is not blank - stripped_oid = oid.strip() - if not stripped_oid: - return False - - # Must start with a '.' - if oid[0] != ".": - return False - - # Must not end with a '.' - if oid[-1] == ".": - return False - - # Test each octet to be numeric - octets = oid.split(".") - - # Remove the first element of the list - octets.pop(0) - for value in octets: - try: - int(value) - except (ValueError, TypeError): - return False - - # Otherwise valid - return True - - -def _convert(value): - """Convert SNMP value from pysnmp object to Python type. - - Args: - value: pysnmp value object - - Returns: - converted: Value converted to appropriate Python type (bytes or int), - or None for null/empty values - """ - # Handle pysnmp exception values - if isinstance(value, NoSuchObject): - return None - if isinstance(value, NoSuchInstance): - return None - if isinstance(value, EndOfMibView): - return None - - if hasattr(value, "prettyPrint"): - value_str = value.prettyPrint() - - # Determine type based on pysnmp object type - value_type = type(value).__name__ - - # Handle string-like types - Convert to types for MIB compatibility - if any( - t in value_type - for t in [ - "OctetString", - "DisplayString", - "Opaque", - "Bits", - "IpAddress", - "ObjectIdentifier", - ] - ): - # For objectID, convert to string first then to bytes - if "ObjectIdentifier" in value_type: - return bytes(str(value_str), "utf-8") - else: - return bytes(value_str, "utf-8") - - # Handle integer types - elif any( - t in value_type - for t in ["Integer", "Counter", "Gauge", "TimeTicks", "Unsigned"] - ): - try: - return int(value_str) - except ValueError: - # Direct int conversion of the obj if prettyPrint fails - if hasattr(value, "__int__"): - try: - return int(value) - except (ValueError, TypeError): - pass - - # Accessing .value attr directly - if hasattr(value, "value"): - try: - return int(value.value) - except (ValueError, TypeError): - pass - - log_message = ( - f"Failed to convert pysnmp integer value: " - f"{value_type}, prettyPrint'{value_str}" - ) - log.log2warning(1059, log_message) - return None - - # Handle direct access to value (for objects without prettyPrint) - if hasattr(value, "value"): - try: - return int(value.value) - except (ValueError, TypeError): - return bytes(str(value.value), "utf-8") - - # Default Fallback - convert to string then to bytes - try: - return bytes(str(value), "utf-8") - except Exception: - return None - - -def _format_results(results, mock_filter, normalized=False): - """Normalize and format SNMP results. - - Args: - results: List of (OID, value) tuples from pysnmp - mock_filter: The original OID to get. Facilitates unittesting by - filtering Mock values. - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort - - Returns: - dict: Formatted results as OID-value pairs - """ - # Initialize key variables - formatted = {} - - for oid_str, value in results: - - # Normalize both OIDs for comparison to handle leading dot mismatch - if mock_filter: - # Remove leading dots for comparison - filter_normalized = mock_filter.lstrip(".") - oid_normalized = oid_str.lstrip(".") - - if not oid_normalized.startswith(filter_normalized): - continue - - # convert value using proper type conversion - converted_value = _convert(value=value) - - if normalized is True: - # use only the last node of the OID - key = oid_str.split(".")[-1] - else: - key = oid_str - - formatted[key] = converted_value - - return formatted - - -def _update_cache(filename, group): - """Update SNMP credentials cache file. - - Args: - filename: String containing path to cache file - group: String containing SNMP group name to cache - - Returns: - None - """ - try: - with open(filename, "w+") as f_handle: - f_handle.write(group) - except Exception as e: - log_message = f"Failed to update cache file {filename}: {e}" - log.log2warning(1049, log_message) diff --git a/switchmap/poller/snmp/mib/generic/mib_if.py b/switchmap/poller/snmp/mib/generic/mib_if.py index f979c36d3..9d4053594 100644 --- a/switchmap/poller/snmp/mib/generic/mib_if.py +++ b/switchmap/poller/snmp/mib/generic/mib_if.py @@ -109,7 +109,7 @@ async def limited_query(method, name): try: return name, await method() except Exception as e: - log.log2warning(1301, f"Error in {name}: {e}") + log.log2warning(1092, f"Error in {name}: {e}") return name, {} queries = [ diff --git a/switchmap/poller/snmp/poller.py b/switchmap/poller/snmp/poller.py index 3be2f2b30..f03dea232 100644 --- a/switchmap/poller/snmp/poller.py +++ b/switchmap/poller/snmp/poller.py @@ -1,10 +1,10 @@ -"""SNMP Poller module.""" +"""Asynchronous SNMP Poller module for switchmap-ng.""" # Switchmap imports from switchmap.poller.configuration import ConfigPoller from switchmap.poller import POLLING_OPTIONS, SNMP, POLL -from . import snmp_info from . import snmp_manager +from . import snmp_info from switchmap.core import log @@ -12,15 +12,15 @@ class Poll: """Asynchronous SNMP poller for switchmap-ng that gathers network data. This class manages SNMP credential validation and data querying for - network devices using asynchronous operations for improved performance - and scalability. + network devices using asynchronous operations for improved + performance and scalability. Args: hostname (str): The hostname or IP address of the device to poll Methods: - initialize_snmp(): Validates SNMP credentials and initializes SNMP - interaction + initialize_snmp(): Validates SNMP credentials and + initializes SNMP interaction query(): Queries the device for topology data asynchronously """ @@ -32,64 +32,72 @@ def __init__(self, hostname): Returns: None - """ # Initialize key variables self._server_config = ConfigPoller() self._hostname = hostname - self._snmp_object = None + self.snmp_object = None + + async def initialize_snmp(self): + """Initialize SNMP connection asynchronously. - # Get snmp configuration information from Switchmap-NG + Returns: + bool: True if successful, False otherwise + """ + # Get snmp config information from Switchmap-NG validate = snmp_manager.Validate( POLLING_OPTIONS( - hostname=hostname, + hostname=self._hostname, authorizations=self._server_config.snmp_auth(), ) ) - authorization = validate.credentials() + + # Get credentials asynchronously + authorization = await validate.credentials() # Create an SNMP object for querying if _do_poll(authorization) is True: - self._snmp_object = snmp_manager.Interact( - POLL( - hostname=hostname, - authorization=authorization, - ) + self.snmp_object = snmp_manager.Interact( + POLL(hostname=self._hostname, authorization=authorization) ) + return True else: log_message = ( "Uncontactable or disabled host {}, or no valid SNMP " - "credentials found for it.".format(self._hostname) + "credentials found in it.".format(self._hostname) ) log.log2info(1081, log_message) + return False - def query(self): + async def query(self): """Query all remote hosts for data. Args: None Returns: - None - + dict: Polled data or None if failed """ # Initialize key variables _data = None - # Only query if wise - if bool(self._snmp_object) is False: + # Only query if the device is contactable + if bool(self.snmp_object) is False: + log.log2warning(1001, f"No valid SNMP object for {self._hostname} ") return _data # Get data log_message = """\ -Querying topology data from host {}.""".format( +Querying topology data from host: {}.""".format( self._hostname ) + log.log2info(1078, log_message) - # Return the data polled from the device - status = snmp_info.Query(self._snmp_object) - _data = status.everything() + status = snmp_info.Query(snmp_object=self.snmp_object) + + _data = await status.everything() + return _data @@ -101,7 +109,6 @@ def _do_poll(authorization): Returns: poll: True if a poll should be done - """ # Initialize key variables poll = False @@ -110,5 +117,4 @@ def _do_poll(authorization): if isinstance(authorization, SNMP) is True: poll = bool(authorization.enabled) - # Return return poll diff --git a/switchmap/poller/snmp/snmp_info.py b/switchmap/poller/snmp/snmp_info.py index cee2f140d..d664b446a 100644 --- a/switchmap/poller/snmp/snmp_info.py +++ b/switchmap/poller/snmp/snmp_info.py @@ -1,37 +1,37 @@ -"""Module to aggregate query results.""" +"""Async module to aggregate query results.""" import time from collections import defaultdict +from switchmap.core import log +import asyncio from . import iana_enterprise from . import get_queries class Query: - """Class interacts with IfMIB devices. + """Async class interacts with devices - use existing MIB classes. Args: None Returns: None - """ def __init__(self, snmp_object): """Instantiate the class. Args: - snmp_object: SNMP Interact class object from snmp_manager.py + snmp_object: SNMP interact class object from async_snmp_manager.py Returns: None - """ # Define query object self.snmp_object = snmp_object - def everything(self): + async def everything(self): """Get all information from device. Args: @@ -39,95 +39,151 @@ def everything(self): Returns: data: Aggregated data - """ # Initialize key variables data = {} - # Append data - data["misc"] = self.misc() - data["layer1"] = self.layer1() - data["layer2"] = self.layer2() - data["layer3"] = self.layer3() - data["system"] = self.system() + # Run all sections concurrently + results = await asyncio.gather( + self.misc(), + self.system(), + self.layer1(), + self.layer2(), + self.layer3(), + return_exceptions=True, + ) + + keys = ["misc", "system", "layer1", "layer2", "layer3"] + for key, result in zip(keys, results): + if isinstance(result, Exception): + log.log2warning(1417, f"{key} failed: {result}") + elif result: + data[key] = result # Return return data - def misc(self): - """Provide miscellaneous information about device and the poll. - - Args: - None - - Returns: - data: Aggregated data - - """ + async def misc(self): + """Provide miscellaneous information about the device and the poll.""" # Initialize data data = defaultdict(lambda: defaultdict(dict)) data["timestamp"] = int(time.time()) data["host"] = self.snmp_object.hostname() # Get vendor information - sysobjectid = self.snmp_object.sysobjectid() - vendor = iana_enterprise.Query(sysobjectid=sysobjectid) - data["IANAEnterpriseNumber"] = vendor.enterprise() + sysobjectid = await self.snmp_object.sysobjectid() + if sysobjectid: + vendor = iana_enterprise.Query(sysobjectid=sysobjectid) + data["IANAEnterpriseNumber"] = vendor.enterprise() + else: + data["IANAEnterpriseNumber"] = None - # Return return data - def system(self): + async def system(self): """Get all system information from device. Args: None Returns: - data: Aggregated data - + data: Aggregated system data """ - # Initialize data + # Initialize key variables data = defaultdict(lambda: defaultdict(dict)) processed = False - # Get system information from SNMPv2-MIB, ENTITY-MIB, IF-MIB - # Instantiate a query object for each system query - for item in [ - Query(self.snmp_object) for Query in get_queries("system") - ]: - if item.supported(): - processed = True - data = _add_system(item, data) + # Get system information from various MIB classes + system_queries = get_queries("system") + + # Create all query instances + query_items = [ + (query_class(self.snmp_object), query_class.__name__) + for query_class in system_queries + ] + + # Check if supported + support_results = await asyncio.gather( + *[item.supported() for item, _ in query_items] + ) + + supported_items = [ + (item, name) + for (item, name), supported in zip(query_items, support_results) + if supported + ] + + if supported_items: + results = await asyncio.gather( + *[ + _add_system(item, defaultdict(lambda: defaultdict(dict))) + for item, _ in supported_items + ] + ) + + # Merge results + for result in results: + for key, value in result.items(): + data[key].update(value) + processed = True - # Return if processed is True: return data else: return None - def layer1(self): - """Get all layer1 information from device. + async def layer1(self): + """Get all layer 1 information from device. Args: None Returns: - data: Aggregated data - + data: Aggregated layer1 data """ # Initialize key values data = defaultdict(lambda: defaultdict(dict)) processed = False - # Get information layer1 queries - - for item in [ - Query(self.snmp_object) for Query in get_queries("layer1") - ]: - if item.supported(): - processed = True - data = _add_layer1(item, data) + layer1_queries = get_queries("layer1") + + query_items = [ + (query_class(self.snmp_object), query_class.__name__) + for query_class in layer1_queries + ] + + # Concurrent support check + support_results = await asyncio.gather( + *[item.supported() for item, _ in query_items] + ) + + supported_items = [ + (item, name) + for (item, name), supported in zip(query_items, support_results) + if supported + ] + + if supported_items: + results = await asyncio.gather( + *[ + _add_layer1(item, defaultdict(lambda: defaultdict(dict))) + for item, _ in supported_items + ], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, Exception): + item_name = supported_items[i][1] + log.log2warning( + 1005, f"Layer1 error in {item_name}: {result}" + ) + continue + + for key, value in result.items(): + data[key].update(value) + + processed = True # Return if processed is True: @@ -135,157 +191,280 @@ def layer1(self): else: return None - def layer2(self): - """Get all layer2 information from device. + async def layer2(self): + """Get all layer 2 information from device. Args: None Returns: - data: Aggregated data - + data: Aggregated layer2 data """ # Initialize key variables data = defaultdict(lambda: defaultdict(dict)) processed = False - for item in [ - Query(self.snmp_object) for Query in get_queries("layer2") - ]: - if item.supported(): - processed = True - data = _add_layer2(item, data) + # Get layer2 information from MIB classes + layer2_queries = get_queries("layer2") + + query_items = [ + (query_class(self.snmp_object), query_class.__name__) + for query_class in layer2_queries + ] + + support_results = await asyncio.gather( + *[item.supported() for item, _ in query_items] + ) + + # Filter supported MIBs + supported_items = [ + (item, name) + for (item, name), supported in zip(query_items, support_results) + if supported + ] + + if supported_items: + # Concurrent processing + results = await asyncio.gather( + *[ + _add_layer2(item, defaultdict(lambda: defaultdict(dict))) + for item, _ in supported_items + ], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, Exception): + item_name = supported_items[i][1] + log.log2warning( + 1007, f"Layer2 error in {item_name}: {result}" + ) + continue + + # Merge this MIB's complete results + for key, value in result.items(): + data[key].update(value) + + processed = True # Return + if processed is True: return data else: return None - def layer3(self): + async def layer3(self): """Get all layer3 information from device. Args: None Returns: - data: Aggregated data - + data: Aggregated layer3 data """ # Initialize key variables data = defaultdict(lambda: defaultdict(dict)) processed = False - for item in [ - Query(self.snmp_object) for Query in get_queries("layer3") - ]: - if item.supported(): - processed = True - data = _add_layer3(item, data) + # Get layer3 information from MIB classes + layer3_queries = get_queries("layer3") + + query_items = [ + (query_class(self.snmp_object), query_class.__name__) + for query_class in layer3_queries + ] + + support_results = await asyncio.gather( + *[item.supported() for item, _ in query_items] + ) + + # Filter supported MIBs + supported_items = [ + (item, name) + for (item, name), supported in zip(query_items, support_results) + if supported + ] + + if supported_items: + # Concurrent processing + results = await asyncio.gather( + *[ + _add_layer3(item, defaultdict(lambda: defaultdict(dict))) + for item, _ in supported_items + ], + return_exceptions=True, + ) + + for i, result in enumerate(results): + if isinstance(result, Exception): + item_name = supported_items[i][1] + log.log2warning( + 1006, f"Layer3 error in {item_name}: {result}" + ) + continue + + # Merge this MIB's complete results + for key, value in result.items(): + data[key].update(value) + + processed = True - # Return if processed is True: return data - else: - return None + return None -def _add_data(source, target): +async def _add_data(source, target): """Add data from source to target dict. Both dicts must have two keys. Args: source: Source dict - target: Target dict + target: Target dict Returns: target: Aggregated data - """ # Process data for primary in source.keys(): for secondary, value in source[primary].items(): target[primary][secondary] = value - # Return + # Return return target -def _add_layer1(query, original_data): - """Add data from successful layer1 MIB query to original data provided. +async def _add_system(query, data): + """Add data from successful system MIB query to original data provided. Args: query: MIB query object - original_data: Two keyed dict of data + data: Three keyed dict of data Returns: - new_data: Aggregated data - + data: Aggregated data """ - # Process query - result = query.layer1() - new_data = _add_data(result, original_data) + try: + result = None - # Return - return new_data + if asyncio.iscoroutinefunction(query.system): + result = await query.system() + else: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, query.system) + # Merge only if we have data + if not result: + return data + for primary, secondary_map in result.items(): + if isinstance(secondary_map, dict): + for secondary, maybe_tertiary in secondary_map.items(): + if isinstance(maybe_tertiary, dict): + for tertiary, value in maybe_tertiary.items(): + data[primary][secondary][tertiary] = value + else: + data[primary][secondary] = maybe_tertiary + else: + # Handle case where secondary level is not a dict + data[primary] = secondary_map -def _add_layer2(query, original_data): - """Add data from successful layer2 MIB query to original data provided. + return data + except Exception as e: + log.log2warning(1320, f"Error in _add_system: {e}") + return data + + +async def _add_layer1(query, data): + """Add data from successful layer1 MIB query to original data provided. Args: query: MIB query object - original_data: Two keyed dict of data + data: dict of data Returns: - new_data: Aggregated data - + data: Aggregated data """ - # Process query - result = query.layer2() - new_data = _add_data(result, original_data) + try: + mib_name = query.__class__.__name__ - # Return - return new_data + result = None + if asyncio.iscoroutinefunction(query.layer1): + result = await query.layer1() + else: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, query.layer1) + if result: + data = await _add_data(result, data) + else: + log.log2debug(1302, f" No layer1 data returned for {mib_name}") -def _add_layer3(query, original_data): - """Add data from successful layer3 MIB query to original data provided. + return data + + except Exception as e: + log.log2warning(1316, f" Error in _add_layer1 for {mib_name}: {e}") + return data + + +async def _add_layer2(query, data): + """Add data from successful layer2 MIB query to original data provided. Args: query: MIB query object - original_data: Two keyed dict of data + data: dict of data Returns: - new_data: Aggregated data - + data: Aggregated data """ - # Process query - result = query.layer3() - new_data = _add_data(result, original_data) + try: + mib_name = query.__class__.__name__ + result = None + if asyncio.iscoroutinefunction(query.layer2): + result = await query.layer2() + else: - # Return - return new_data + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, query.layer2) + if result: + data = await _add_data(result, data) + else: + log.log2debug(1306, f" No layer2 data returned for {mib_name}") -def _add_system(query, data): - """Add data from successful system MIB query to original data provided. + return data + + except Exception as e: + log.log2warning(1308, f" Error in _add_layer2 for {mib_name}: {e}") + return data + + +async def _add_layer3(query, data): + """Add data from successful layer3 MIB query to original data provided. Args: query: MIB query object - data: Three keyed dict of data + data: dict of data Returns: data: Aggregated data - """ - # Process query - result = query.system() + try: + mib_name = query.__class__.__name__ - # Add tag - for primary in result.keys(): - for secondary in result[primary].keys(): - for tertiary, value in result[primary][secondary].items(): - data[primary][secondary][tertiary] = value + result = None + if asyncio.iscoroutinefunction(query.layer3): + result = await query.layer3() + else: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, query.layer3) - # Return - return data + if result: + data = await _add_data(result, data) + else: + log.log2debug(1309, f" No layer3 data returned for {mib_name}") + + return data + + except Exception as e: + log.log2warning(1310, f" Error in _add_layer3 for {mib_name}: {e}") + return data diff --git a/switchmap/poller/snmp/snmp_manager.py b/switchmap/poller/snmp/snmp_manager.py index 5ce5bf4d6..9ab1f73e5 100644 --- a/switchmap/poller/snmp/snmp_manager.py +++ b/switchmap/poller/snmp/snmp_manager.py @@ -1,35 +1,61 @@ -"""SNMP manager class.""" +"""Async SNMP manager class.""" import os -import sys +import asyncio -import easysnmp -from easysnmp import exceptions -# Import project libraries -from switchmap.poller.configuration import ConfigPoller -from switchmap.poller import POLL +# import project libraries from switchmap.core import log from switchmap.core import files +from switchmap.poller import POLL +from switchmap.poller.configuration import ConfigPoller + from . import iana_enterprise +from pysnmp.hlapi.asyncio import ( + SnmpEngine, + CommunityData, + UdpTransportTarget, + ContextData, + ObjectType, + ObjectIdentity, + getCmd, + nextCmd, + bulkCmd, + UsmUserData, + # Authentication protocols + usmHMACMD5AuthProtocol, + usmHMACSHAAuthProtocol, + usmHMAC128SHA224AuthProtocol, + usmHMAC192SHA256AuthProtocol, + usmHMAC256SHA384AuthProtocol, + usmHMAC384SHA512AuthProtocol, + # Privacy protocols + usmDESPrivProtocol, + usmAesCfb128Protocol, + usmAesCfb192Protocol, + usmAesCfb256Protocol, +) + +from pysnmp.error import PySnmpError +from pysnmp.proto.rfc1905 import EndOfMibView, NoSuchInstance, NoSuchObject + class Validate: - """Class Verify SNMP data.""" + """Class to validate SNMP data asynchronously.""" def __init__(self, options): """Initialize the Validate class. Args: - options: POLLING_OPTIONS object containing SNMP configuration + options: POLLING_OPTIONS object containing SNMP configuration. Returns: - None + None """ - # Initialize key variables self._options = options - def credentials(self): + async def credentials(self): """Determine valid SNMP credentials for a host. Args: @@ -41,47 +67,42 @@ def credentials(self): """ # Initialize key variables cache_exists = False - - # Create cache directory / file if not yet created filename = files.snmp_file(self._options.hostname, ConfigPoller()) - if os.path.exists(filename) is True: + if os.path.exists(filename): cache_exists = True + group = None - # Create file if necessary if cache_exists is False: - # Get credentials - authentication = self.validation() + authentication = await self.validation() # Save credentials if successful if bool(authentication): _update_cache(filename, authentication.group) - else: # Read credentials from cache if os.path.isfile(filename): with open(filename) as f_handle: - group = f_handle.readline() + group = f_handle.readline().strip() or None - # Get credentials - authentication = self.validation(group) + # Get Credentials + authentication = await self.validation(group) - # Try the rest if these credentials fail + # Try the rest if the credentials fail if bool(authentication) is False: - authentication = self.validation() + authentication = await self.validation() - # Update cache if found + # update cache if found if bool(authentication): _update_cache(filename, authentication.group) - # Return return authentication - def validation(self, group=None): + async def validation(self, group=None): """Determine valid SNMP authorization for a host. Args: group: String containing SNMP group name to try, or None to try all - groups + groups Returns: result: SNMP authorization object if valid credentials found, @@ -99,24 +120,19 @@ def validation(self, group=None): # Setup contact with the remote device device = Interact( POLL( - hostname=self._options.hostname, - authorization=authorization, + hostname=self._options.hostname, authorization=authorization ) ) - - # Try successive groups + # Try successive groups check if device is contactable if group is None: - # Verify connectivity - if device.contactable() is True: + if await device.contactable() is True: result = authorization break else: if authorization.group == group: - # Verify connectivity - if device.contactable() is True: + if await device.contactable() is True: result = authorization - # Return return result @@ -134,15 +150,20 @@ def __init__(self, _poll): """ # Initialize key variables self._poll = _poll + self._engine = SnmpEngine() + + # Rate Limiting + self._semaphore = asyncio.Semaphore(10) # Fail if there is no authentication if bool(self._poll.authorization) is False: log_message = ( - "SNMP parameters provided are blank. " "Non existent host?" + "SNMP parameters provided are either blank or missing." + "Non existent host?" ) log.log2die(1045, log_message) - def enterprise_number(self): + async def enterprise_number(self): """Get SNMP enterprise number for the device. Args: @@ -152,13 +173,12 @@ def enterprise_number(self): int: SNMP enterprise number identifying the device vendor """ # Get the sysObjectID.0 value of the device - sysid = self.sysobjectid() + sysid = await self.sysobjectid() # Get the vendor ID enterprise_obj = iana_enterprise.Query(sysobjectid=sysid) enterprise = enterprise_obj.enterprise() - # Return return enterprise def hostname(self): @@ -170,48 +190,36 @@ def hostname(self): Returns: str: Hostname of the target device """ - # Initialize key variables - hostname = self._poll.hostname - - # Return - return hostname + return self._poll.hostname - def contactable(self): + async def contactable(self): """Check if device is reachable via SNMP. Args: None Returns: - bool: True if device responds to SNMP queries, False otherwise + bool: True if device responds to SNMP queries, False otherwise """ - # Define key variables + # key variables contactable = False result = None # Try to reach device try: - # If we can poll the SNMP sysObjectID, - # then the device is contactable - result = self.sysobjectid(check_reachability=True) + # Test if we can poll the SNMP sysObjectID + # if true, then the device is contactable + result = await self.sysobjectid(check_reachability=True) if bool(result) is True: contactable = True except Exception: - # Not contactable + # Not Contactable contactable = False - except: - # Log a message - log_message = "Unexpected SNMP error for device {}" "".format( - self._poll.hostname - ) - log.log2die(1008, log_message) - - # Return return contactable - def sysobjectid(self, check_reachability=False): + async def sysobjectid(self, check_reachability=False): """Get the sysObjectID of the device. Args: @@ -226,14 +234,22 @@ def sysobjectid(self, check_reachability=False): object_id = None # Get sysObjectID - results = self.get(oid, check_reachability=check_reachability) + results = await self.get(oid, check_reachability) + # Pysnmp already returns out value as value unlike easysnmp if bool(results) is True: - object_id = results[oid].decode("utf-8") + # Both formats: with and without leading dot + object_id = results.get(oid) + if object_id is None: + oid_without_dot = oid.lstrip(".") + object_id = results.get(oid_without_dot) + + # Convert bytes to string if needed + if isinstance(object_id, bytes): + object_id = object_id.decode("utf-8") - # Return return object_id - def oid_exists(self, oid_to_get, context_name=""): + async def oid_exists(self, oid_to_get, context_name=""): """Determine if an OID exists on the device. Args: @@ -244,24 +260,35 @@ def oid_exists(self, oid_to_get, context_name=""): Returns: bool: True if OID exists, False otherwise """ - # Initialize key variables - validity = False - - # Validate OID - if self._oid_exists_get(oid_to_get, context_name=context_name) is True: - validity = True + try: + # Initialize key + validity = False - if validity is False: + # Validate OID if ( - self._oid_exists_walk(oid_to_get, context_name=context_name) + await self._oid_exists_get( + oid_to_get, context_name=context_name + ) is True ): validity = True + if validity is False: + if ( + await self._oid_exists_walk( + oid_to_get, context_name=context_name + ) + is True + ): + validity = True + + return validity + except Exception as e: + log.log2warning( + 1305, f"OID existence check failed for {oid_to_get}: {e}" + ) + return False - # Return - return validity - - def _oid_exists_get(self, oid_to_get, context_name=""): + async def _oid_exists_get(self, oid_to_get, context_name=""): """Determine existence of OID on device. Args: @@ -272,70 +299,84 @@ def _oid_exists_get(self, oid_to_get, context_name=""): Returns: validity: True if exists - """ - # Initialize key variables - validity = False - - # Process - (_, validity, result) = self.query( - oid_to_get, - get=True, - check_reachability=True, - context_name=context_name, - check_existence=True, - ) - - # If we get no result, then override validity - if bool(result) is False: + try: validity = False - elif isinstance(result, dict) is True: - if result[oid_to_get] is None: - validity = False - # Return - return validity + (_, exists, result) = await self.query( + oid_to_get, + get=True, + check_reachability=True, + check_existence=True, + context_name=context_name, + ) - def _oid_exists_walk(self, oid_to_get, context_name=""): - """Determine existence of OID on device. + if exists and bool(result): + # Make sure the OID key exists in result + exact_key = oid_to_get + alt_key = oid_to_get.lstrip(".") + if isinstance(result, dict) and oid_to_get in result: + if ( + result.get(exact_key) is not None + or result.get(alt_key) is not None + ): + validity = True + elif isinstance(result, dict) and result: + # If result has data but not exact OID, still consider valid + validity = True + + return validity + except Exception as e: + log.log2warning( + 1305, f"OID existence check failed for {oid_to_get}: {e}" + ) + return False - Args: + async def _oid_exists_walk(self, oid_to_get, context_name=""): + """Check OID existence on device using WALK. + + Args: oid_to_get: OID to get context_name: Set the contextName used for SNMPv3 messages. The default contextName is the empty string "". Overrides the defContext token in the snmp.conf file. Returns: - validity: True if exists - + validity: True if exist """ - # Initialize key variables - validity = False - - # Process - (_, validity, results) = self.query( - oid_to_get, - get=False, - check_reachability=True, - context_name=context_name, - check_existence=True, - ) - - # If we get no result, then override validity - if isinstance(results, dict) is True: - for _, value in results.items(): - if value is None: - validity = False - break - - # Return - return validity + try: + (_, exists, results) = await self.query( + oid_to_get, + get=False, + check_existence=True, + context_name=context_name, + check_reachability=True, + ) + # Check if we get valid results + if exists and isinstance(results, dict) and results: + return True + return False + except Exception as e: + log.log2warning( + 1306, f"Walk existence check failed for {oid_to_get}: {e}" + ) + return False - def swalk(self, oid_to_get, normalized=False, context_name=""): - """Perform a safe SNMPwalk that handles errors gracefully. + async def get( + self, + oid_to_get, + check_reachability=False, + check_existence=False, + normalized=False, + context_name="", + ): + """Do an SNMPget. Args: oid_to_get: OID to get + check_reachability: Set if testing for connectivity. Some session + errors are ignored so that a null result is returned + check_existence: Set if checking for the existence of the OID normalized: If True, then return results as a dict keyed by only the last node of an OID, otherwise return results keyed by the entire OID string. Normalization is useful @@ -347,22 +388,19 @@ def swalk(self, oid_to_get, normalized=False, context_name=""): defContext token in the snmp.conf file. Returns: - dict: Results of SNMP walk as OID-value pairs + result: Dictionary of {OID: value} pairs """ - # Process data - results = self.walk( + (_, _, result) = await self.query( oid_to_get, + get=True, + check_reachability=check_reachability, + check_existence=check_existence, normalized=normalized, - check_reachability=True, - check_existence=True, context_name=context_name, - safe=True, ) + return result - # Return - return results - - def walk( + async def walk( self, oid_to_get, normalized=False, @@ -371,7 +409,7 @@ def walk( context_name="", safe=False, ): - """Do an SNMPwalk. + """Do an async SNMPwalk. Args: oid_to_get: OID to walk @@ -394,9 +432,8 @@ def walk( Returns: result: Dictionary of tuples (OID, value) - """ - (_, _, result) = self.query( + (_, _, result) = await self.query( oid_to_get, get=False, check_reachability=check_reachability, @@ -405,23 +442,14 @@ def walk( context_name=context_name, safe=safe, ) + return result - def get( - self, - oid_to_get, - check_reachability=False, - check_existence=False, - normalized=False, - context_name="", - ): - """Do an SNMPget. + async def swalk(self, oid_to_get, normalized=False, context_name=""): + """Perform a safe async SNMPwalk that handles errors gracefully. Args: oid_to_get: OID to get - check_reachability: Set if testing for connectivity. Some session - errors are ignored so that a null result is returned - check_existence: Set if checking for the existence of the OID normalized: If True, then return results as a dict keyed by only the last node of an OID, otherwise return results keyed by the entire OID string. Normalization is useful @@ -433,20 +461,19 @@ def get( defContext token in the snmp.conf file. Returns: - result: Dictionary of tuples (OID, value) - + dict: Results of SNMP walk as OID-value pairs """ - (_, _, result) = self.query( + # Process data + return await self.walk( oid_to_get, - get=True, - check_reachability=check_reachability, - check_existence=check_existence, normalized=normalized, + check_reachability=True, + check_existence=True, context_name=context_name, + safe=True, ) - return result - def query( + async def query( self, oid_to_get, get=False, @@ -477,466 +504,468 @@ def query( blank values. Returns: - return_value: List of tuples (_contactable, exists, values) - + return_value: Tuple of (_contactable, exists, values) """ # Initialize variables _contactable = True exists = True results = [] + # Initialize formatted_result to avoid undefined variable error + formatted_result = {} # Check if OID is valid if _oid_valid_format(oid_to_get) is False: log_message = "OID {} has an invalid format".format(oid_to_get) log.log2die(1057, log_message) - # Create SNMP session - session = _Session(self._poll, context_name=context_name).session + # Get session parameters + async with self._semaphore: + try: + # Create SNMP session + session = Session( + self._poll, self._engine, context_name=context_name + ) - # Fill the results object by getting OID data - try: - # Get the data - if get is True: - results = [session.get(oid_to_get)] + # Use shorter timeouts for walk operations + auth_data, transport_target = await session._session( + walk_operation=(not get) + ) + context_data = ContextData(contextName=context_name) - else: - if self._poll.authorization.version != 1: - # Bulkwalk for SNMPv2 and SNMPv3 - results = session.bulkwalk( - oid_to_get, non_repeaters=0, max_repetitions=25 + # Perform the SNMP operation + if get is True: + results = await session._do_async_get( + oid_to_get, auth_data, transport_target, context_data ) else: - # Bulkwalk not supported in SNMPv1 - results = session.walk(oid_to_get) - - # Crash on error, return blank results if doing certain types of - # connectivity checks - except ( - exceptions.EasySNMPConnectionError, - exceptions.EasySNMPTimeoutError, - exceptions.EasySNMPUnknownObjectIDError, - exceptions.EasySNMPNoSuchNameError, - exceptions.EasySNMPNoSuchObjectError, - exceptions.EasySNMPNoSuchInstanceError, - exceptions.EasySNMPUndeterminedTypeError, - ) as exception_error: - # Update the error message - log_message = _exception_message( - self._poll.hostname, - oid_to_get, - context_name, - sys.exc_info(), - ) - - # Process easysnmp errors - (_contactable, exists) = _process_error( - log_message, - exception_error, - check_reachability, - check_existence, - ) - - except SystemError as exception_error: - log_message = _exception_message( - self._poll.hostname, - oid_to_get, - context_name, - sys.exc_info(), - ) - - # Process easysnmp errors - (_contactable, exists) = _process_error( - log_message, - exception_error, - check_reachability, - check_existence, - system_error=True, - ) + results = await session._do_async_walk( + oid_to_get, auth_data, transport_target, context_data + ) - except: - # Update the error message - log_message = _exception_message( - self._poll.hostname, - oid_to_get, - context_name, - sys.exc_info(), - ) - if bool(safe): - _contactable = None - exists = None - log.log2info(1209, log_message) - else: - log.log2die(1003, log_message) + formatted_result = _format_results( + results, oid_to_get, normalized=normalized + ) - # Format results - values = _format_results(results, oid_to_get, normalized=normalized) + except PySnmpError as exception_error: + # Handle PySNMP errors similar to sync version + if check_reachability is True: + _contactable = False + exists = False + elif check_existence is True: + exists = False + elif safe is True: + _contactable = None + exists = None + log_message = ( + f"Async SNMP error for {self._poll.hostname}: " + f"{exception_error}" + ) + log.log2info(1209, log_message) + else: + log_message = ( + f"Async SNMP error for {self._poll.hostname}: " + f"{exception_error}" + ) + log.log2die(1003, log_message) + # Ensure formatted_result is set for exception cases + formatted_result = {} + + except Exception as exception_error: + # Handle unexpected errors + if safe is True: + _contactable = None + exists = None + log_message = ( + f"Unexpected async SNMP error for " + f"{self._poll.hostname}: {exception_error}" + ) + log.log2info(1041, log_message) + else: + log_message = ( + f"Unexpected async SNMP error for " + f"{self._poll.hostname}: {exception_error}" + ) + log.log2die(1023, log_message) + # Ensure formatted_result is set for exception cases + formatted_result = {} # Return - return_value = (_contactable, exists, values) - return return_value + values = (_contactable, exists, formatted_result) + return values -class _Session: - """Class to create an SNMP session with a device.""" +class Session: + """Class to create a SNMP session with a device.""" - def __init__(self, _poll, context_name=""): + def __init__(self, _poll, engine, context_name=""): """Initialize the _Session class. Args: _poll: POLL object containing SNMP configuration + engine: SNMP engine object context_name: String containing SNMPv3 context name. Default is empty string. Returns: session: SNMP session - """ - # Initialize key variables - self._context_name = context_name - # Assign variables + self.context_name = context_name self._poll = _poll + self._engine = engine # Fail if there is no authentication if bool(self._poll.authorization) is False: log_message = ( - "SNMP parameters provided are blank. " "Non existent host?" + "SNMP parameters provided are blank. None existent host? " ) log.log2die(1046, log_message) - # Create SNMP session - self.session = self._session() - - def _session(self): - """Create an SNMP session for queries. - - Args: - None - - Returns: - session: SNMP session - - """ - # Create session - if self._poll.authorization.version != 3: - session = easysnmp.Session( - community=self._poll.authorization.community, - hostname=self._poll.hostname, - version=self._poll.authorization.version, - remote_port=self._poll.authorization.port, - use_numeric=True, - context=self._context_name, - ) - else: - session = easysnmp.Session( - hostname=self._poll.hostname, - version=self._poll.authorization.version, - remote_port=self._poll.authorization.port, - use_numeric=True, - context=self._context_name, - security_level=self._security_level(), - security_username=self._poll.authorization.secname, - privacy_protocol=self._priv_protocol(), - privacy_password=self._poll.authorization.privpassword, - auth_protocol=self._auth_protocol(), - auth_password=self._poll.authorization.authpassword, - ) - - # Return - return session - - def _security_level(self): - """Determine SNMPv3 security level string. - - Args: - None - - Returns: - result: Security level - """ - # Determine the security level - if bool(self._poll.authorization.authprotocol) is True: - if bool(self._poll.authorization.privprotocol) is True: - result = "authPriv" - else: - result = "authNoPriv" - else: - result = "noAuthNoPriv" - - # Return - return result - - def _auth_protocol(self): - """Get SNMPv3 authentication protocol. - - Args: - None + async def _session(self, walk_operation=False): + """Create SNMP session parameters based on configuration. Returns: - str: Authentication protocol string ('MD5', 'SHA', or 'DEFAULT') + Tuple of (auth_data, transport_target) """ # Initialize key variables - protocol = self._poll.authorization.authprotocol + auth = self._poll.authorization - # Setup AuthProtocol (Default SHA) - if bool(protocol) is False: - result = "DEFAULT" + # Use shorter timeouts for walk operations to prevent hanging + if walk_operation: + timeout = 3 + retries = 1 else: - if protocol.lower() == "md5": - result = "MD5" - else: - result = "SHA" - - # Return - return result - - def _priv_protocol(self): - """Get SNMPv3 privacy protocol. - - Args: - None + # Normal timeout for GET operations + timeout = 10 + retries = 3 - Returns: - str: Privacy protocol string ('DES', 'AES', or 'DEFAULT') - """ - # Initialize key variables - protocol = self._poll.authorization.privprotocol + # Create transport target + transport_target = UdpTransportTarget( + (self._poll.hostname, auth.port), timeout=timeout, retries=retries + ) - # Setup privProtocol (Default AES256) - if bool(protocol) is False: - result = "DEFAULT" + # Create authentication data based on SNMP version + if auth.version == 3: + # SNMPv3 with USM + # If authprotocol/privprotocol is None/False/Empty, leave as None + auth_protocol = None + priv_protocol = None + + # Set auth protocol only if authprotocol is specified + if auth.authprotocol: + auth_proto = auth.authprotocol.lower() + if auth_proto == "md5": + auth_protocol = usmHMACMD5AuthProtocol + elif auth_proto == "sha1" or auth_proto == "sha": + auth_protocol = usmHMACSHAAuthProtocol + elif auth_proto == "sha224": + auth_protocol = usmHMAC128SHA224AuthProtocol + elif auth_proto == "sha256": + auth_protocol = usmHMAC192SHA256AuthProtocol + elif auth_proto == "sha384": + auth_protocol = usmHMAC256SHA384AuthProtocol + elif auth_proto == "sha512": + auth_protocol = usmHMAC384SHA512AuthProtocol + else: + log.log2warning( + 1218, + f"Unknown auth protocol '{auth.authprotocol}'," + f"leaving unset", + ) + auth_protocol = None + + # Set privacy protocol only if privprotocol is specified + # Also if we have authentication (privacy requires authentication) + if auth.privprotocol and auth_protocol is not None: + priv_proto = auth.privprotocol.lower() + if priv_proto == "des": + priv_protocol = usmDESPrivProtocol + elif priv_proto == "aes128" or priv_proto == "aes": + priv_protocol = usmAesCfb128Protocol + elif priv_proto == "aes192": + priv_protocol = usmAesCfb192Protocol + elif priv_proto == "aes256": + priv_protocol = usmAesCfb256Protocol + else: + log.log2warning( + 1218, + f"Unknown auth protocol '{auth.privprotocol}'," + f"leaving unset", + ) + priv_protocol = None + + auth_data = UsmUserData( + userName=auth.secname, + authKey=auth.authpassword, + privKey=auth.privpassword, + authProtocol=auth_protocol, + privProtocol=priv_protocol, + ) else: - if protocol.lower() == "des": - result = "DES" - else: - result = "AES" + # SNMPv1/v2c with community + mp_model = 0 if auth.version == 1 else 1 + auth_data = CommunityData(auth.community, mpModel=mp_model) - # Return - return result + return auth_data, transport_target + async def _do_async_get( + self, oid, auth_data, transport_target, context_data + ): + """Pure async SNMP GET using pysnmp.""" + error_indication, error_status, error_index, var_binds = await getCmd( + self._engine, + auth_data, + transport_target, + context_data, + ObjectType(ObjectIdentity(oid)), + ) -def _exception_message(hostname, oid, context, exc_info): - """Create standardized exception message for SNMP errors. - - Args: - hostname: Hostname - oid: OID being polled - context: SNMP context - exc_info: Exception information + if error_indication: + raise PySnmpError(f"SNMP GET error: {error_indication}") + elif error_status: + raise PySnmpError(f"SNMP GET error status: {error_status}") - Returns: - str: Formatted error message - """ - # Create failure log message - try_log_message = ( - "Error occurred during SNMP query on host " - 'OID {} from {} for context "{}"' - "".format(oid, hostname, context) - ) - - # Add exception information - result = """\ -{}: [{}, {}, {}]""".format( - try_log_message, - exc_info[0], - exc_info[1], - exc_info[2], - ) - - # Return - return result - - -def _process_error( - log_message, - exception_error, - check_reachability, - check_existence, - system_error=False, -): - """Process the SNMP error. + # Return in object format expected by _format_results + results = [] + for var_bind in var_binds: + oid_str = str(var_bind[0]) + value = var_bind[1] + results.append((oid_str, value)) - Args: - log_message: Log message - exception_error: Exception error object - check_reachability: Attempt to contact the device if True - check_existence: Check existence of the device if True - system_error: True if a System error + return results - Returns: - alive: True if contactable + async def _do_async_walk( + self, oid_prefix, auth_data, transport_target, context_data + ): + """Pure async SNMP WALK using pysnmp async capabilities.""" + # Initialize key variables + results = [] - """ - # Initialize key varialbes - _contactable = True - exists = True - if system_error is False: - error_name = "EasySNMPError" - else: - error_name = "SystemError" - - # Check existence of OID - if check_existence is True: - if system_error is False: - if ( - isinstance( - exception_error, - easysnmp.exceptions.EasySNMPUnknownObjectIDError, + # Use correct walk method based on SNMP version + if hasattr(auth_data, "mpModel") and auth_data.mpModel == 0: + # SNMPv1 - use nextCMD + results = await self._async_walk_v1( + oid_prefix, auth_data, transport_target, context_data + ) + else: + # SNMPv2c/v3 - use bulkCmd + + try: + results = await asyncio.wait_for( + self._async_walk_v2( + oid_prefix, auth_data, transport_target, context_data + ), + timeout=60.0, ) - is True - ): - exists = False - return (_contactable, exists) - elif ( - isinstance( - exception_error, - easysnmp.exceptions.EasySNMPNoSuchNameError, + except asyncio.TimeoutError: + log.log2info( + 1011, f"bulk walk timeout after 60s for prefix {oid_prefix}" ) - is True - ): - exists = False - return (_contactable, exists) - elif ( - isinstance( - exception_error, - easysnmp.exceptions.EasySNMPNoSuchObjectError, + # Fallback to SNMPv1 walk which would be more reliable + results = await self._async_walk_v1( + oid_prefix, auth_data, transport_target, context_data ) - is True - ): - exists = False - return (_contactable, exists) - elif ( - isinstance( - exception_error, - easysnmp.exceptions.EasySNMPNoSuchInstanceError, - ) - is True - ): - exists = False - return (_contactable, exists) - else: - exists = False - return (_contactable, exists) - # Checking if the device is reachable - if check_reachability is True: - _contactable = False - exists = False - return (_contactable, exists) - - # Die an agonizing death! - log_message = "{}: {}".format(error_name, log_message) - log.log2die(1023, log_message) + return results + async def _async_walk_v1( + self, oid_prefix, auth_data, transport_target, context_data + ): + """Pure async walk for SNMPv1 using nextCmd.""" + # Initialize key variables + results = [] -def _format_results(results, mock_filter, normalized=False): - """Normalize and format SNMP walk results. + try: + async for ( + error_indication, + error_status, + error_index, + var_binds, + ) in nextCmd( + self._engine, + auth_data, + transport_target, + context_data, + ObjectType(ObjectIdentity(oid_prefix)), + lexicographicMode=False, + ): + # Handle errors first + if error_indication: + log.log2warning( + 1216, + f"SNMP v1 walk network error for {oid_prefix}: " + f"{error_indication}.", + ) + break - Args: - results: List of lists of results - mock_filter: The original OID to get. Facilitates unittesting by - filtering Mock values. - normalized: If True, then return results as a dict keyed by - only the last node of an OID, otherwise return results - keyed by the entire OID string. Normalization is useful - when trying to create multidimensional dicts where the - primary key is a universal value such as IF-MIB::ifIndex - or BRIDGE-MIB::dot1dBasePort + elif error_status: + log.log2info( + 1217, + f"SNMP v1 walk protocol error for {oid_prefix}: " + f"{error_status} at index {error_index}", + ) - Returns: - dict: Formatted results as OID-value pairs - """ - # Initialize key variables - return_results = {} + # Handle specific SNMP errors + error_msg = error_status.prettyPrint() + if error_msg == "noSuchName": + # This OID doesn't exist, try next + continue + else: + # Other errors are usually fatal + break + + # Process successful response + for oid, value in var_binds: + oid_str = str(oid) + prefix_normalized = str(oid_prefix).lstrip(".") + oid_normalized = oid_str.lstrip(".") + if not oid_normalized.startswith(prefix_normalized): + log.log2debug( + 1220, + f"Reached end of OID tree for prefix {oid_prefix}", + ) + return results + results.append((oid_str, value)) + + # Return results after the loop completes + return results + + except Exception as e: + log.log2warning( + 1222, f"Unexpected error in SNMP v1 walk for {oid_prefix}: {e}." + ) + return results - for result in results: - # Recreate the OID - oid = "{}.{}".format(result.oid, result.oid_index) + async def _async_walk_v2( + self, oid_prefix, auth_data, transport_target, context_data + ): + """Async walk for SNMPv2c/v3 using bulkCmd.""" + # Initialize key variables + results = [] + current_oids = [ObjectType(ObjectIdentity(oid_prefix))] - # Ignore unwanted OIDs - if mock_filter not in oid: - continue + try: + # !checking for 50, 100 would be too long to prevent from hanging + max_iterations = 50 + iterations = 0 + consecutive_empty_responses = 0 + # Stop after 3 consecutive empty responses + max_empty_responses = 3 + + while current_oids and iterations < max_iterations: + iterations += 1 + # non-repeaters = 0 , max_repetitions = 25 + error_indication, error_status, error_index, var_bind_table = ( + await bulkCmd( + self._engine, + auth_data, + transport_target, + context_data, + 0, + 25, + *current_oids, + ) + ) - # Process the rest - if normalized is True: - return_results[result.oid_index] = _convert(result) - else: - return_results[oid] = _convert(result) + if error_indication: + log.log2info( + 1211, f"BULK error indication: {error_indication}" + ) + break + elif error_status: + log.log2info( + 1212, + f"BULK error status: {error_status.prettyPrint()} " + f"at {error_index}", + ) + break - # Return - return return_results + # Check if we got any response + if not var_bind_table: + consecutive_empty_responses += 1 + if consecutive_empty_responses >= max_empty_responses: + break + continue + else: + consecutive_empty_responses = 0 + + # Process the response + found_valid_data = False + prefix_normalized = str(oid_prefix).lstrip(".") + + for var_bind in var_bind_table: + if not var_bind or len(var_bind) == 0: + continue + + # Get the ObjectType from the list + for obj_type in var_bind: + oid, value = obj_type[0], obj_type[1] + + # Check for end of MIB + if isinstance(value, EndOfMibView): + continue + oid_str = str(oid) + + oid_normalized = oid_str.lstrip(".") + if not oid_normalized.startswith(prefix_normalized): + continue + results.append((oid_str, value)) + found_valid_data = True + + # Advance the walk using only the last row's OIDs + next_oids = [] + if var_bind_table: + last_row = var_bind_table[-1] + for obj_type in last_row: + oid, value = obj_type[0], obj_type[1] + if isinstance(value, EndOfMibView): + continue + oid_str = str(oid) + if oid_str.lstrip(".").startswith(prefix_normalized): + next_oids.append(ObjectType(ObjectIdentity(oid))) + + if not found_valid_data: + log.log2info( + 1213, + f"BULK walk: No more valid data for prefix " + f"{oid_prefix}", + ) + break + current_oids = next_oids -def _convert(result): - """Convert SNMP value from pysnmp object to Python type. + # In case, we get too many results + if len(results) > 10000: + log.log2warning( + 1214, + f"Stopping after collecting {len(results)} results " + f"(safety limit)", + ) + break - Args: - result: Named tuple containing SNMP result + except Exception as e: + log.log2warning(1215, f"BULK walk error: {e}") + return await self._async_walk_v1( + oid_prefix, auth_data, transport_target, context_data + ) - Returns: - converted: Value converted to appropriate Python type (bytes or int), - or None for null/empty values - """ - # Initialieze key values - converted = None - value = result.value - snmp_type = result.snmp_type - - # Convert string type values to bytes - if snmp_type.upper() == "OCTETSTR": - converted = bytes(value, "utf-8") - elif snmp_type.upper() == "OPAQUE": - converted = bytes(value, "utf-8") - elif snmp_type.upper() == "BITS": - converted = bytes(value, "utf-8") - elif snmp_type.upper() == "IPADDR": - converted = bytes(value, "utf-8") - elif snmp_type.upper() == "NETADDR": - converted = bytes(value, "utf-8") - elif snmp_type.upper() == "OBJECTID": - # DO NOT CHANGE !!! - converted = bytes(str(value), "utf-8") - elif snmp_type.upper() == "NOSUCHOBJECT": - # Nothing if OID not found - converted = None - elif snmp_type.upper() == "NOSUCHINSTANCE": - # Nothing if OID not found - converted = None - elif snmp_type.upper() == "ENDOFMIBVIEW": - # Nothing - converted = None - elif snmp_type.upper() == "NULL": - # Nothing - converted = None - else: - # Convert everything else into integer values - # rfc1902.Integer - # rfc1902.Integer32 - # rfc1902.Counter32 - # rfc1902.Gauge32 - # rfc1902.Unsigned32 - # rfc1902.TimeTicks - # rfc1902.Counter64 - converted = int(value) - - # Return - return converted + return results def _oid_valid_format(oid): - """Validate OID string format. + """Validate OID string format matching sync version. Args: - oid: String containing OID to validate + oid: String containing OID to validate Returns: - bool: True if OID format is valid, False otherwise + bool: True if OID format is valid, False otherwise """ # oid cannot be numeric if isinstance(oid, str) is False: return False - # Make sure the oid is not blank + # Make sure that oid is not blank stripped_oid = oid.strip() if not stripped_oid: return False @@ -957,13 +986,143 @@ def _oid_valid_format(oid): for value in octets: try: int(value) - except: + except (ValueError, TypeError): return False # Otherwise valid return True +def _convert(value): + """Convert SNMP value from pysnmp object to Python type. + + Args: + value: pysnmp value object + + Returns: + converted: Value converted to appropriate Python type (bytes or int), + or None for null/empty values + """ + # Handle pysnmp exception values + if isinstance(value, NoSuchObject): + return None + if isinstance(value, NoSuchInstance): + return None + if isinstance(value, EndOfMibView): + return None + + if hasattr(value, "prettyPrint"): + value_str = value.prettyPrint() + + # Determine type based on pysnmp object type + value_type = type(value).__name__ + + # Handle string-like types - Convert to types for MIB compatibility + if any( + t in value_type + for t in [ + "OctetString", + "DisplayString", + "Opaque", + "Bits", + "IpAddress", + "ObjectIdentifier", + ] + ): + # For objectID, convert to string first then to bytes + if "ObjectIdentifier" in value_type: + return bytes(str(value_str), "utf-8") + else: + return bytes(value_str, "utf-8") + + # Handle integer types + elif any( + t in value_type + for t in ["Integer", "Counter", "Gauge", "TimeTicks", "Unsigned"] + ): + try: + return int(value_str) + except ValueError: + # Direct int conversion of the obj if prettyPrint fails + if hasattr(value, "__int__"): + try: + return int(value) + except (ValueError, TypeError): + pass + + # Accessing .value attr directly + if hasattr(value, "value"): + try: + return int(value.value) + except (ValueError, TypeError): + pass + + log_message = ( + f"Failed to convert pysnmp integer value: " + f"{value_type}, prettyPrint'{value_str}" + ) + log.log2warning(1036, log_message) + return None + + # Handle direct access to value (for objects without prettyPrint) + if hasattr(value, "value"): + try: + return int(value.value) + except (ValueError, TypeError): + return bytes(str(value.value), "utf-8") + + # Default Fallback - convert to string then to bytes + try: + return bytes(str(value), "utf-8") + except Exception: + return None + + +def _format_results(results, mock_filter, normalized=False): + """Normalize and format SNMP results. + + Args: + results: List of (OID, value) tuples from pysnmp + mock_filter: The original OID to get. Facilitates unittesting by + filtering Mock values. + normalized: If True, then return results as a dict keyed by + only the last node of an OID, otherwise return results + keyed by the entire OID string. Normalization is useful + when trying to create multidimensional dicts where the + primary key is a universal value such as IF-MIB::ifIndex + or BRIDGE-MIB::dot1dBasePort + + Returns: + dict: Formatted results as OID-value pairs + """ + # Initialize key variables + formatted = {} + + for oid_str, value in results: + + # Normalize both OIDs for comparison to handle leading dot mismatch + if mock_filter: + # Remove leading dots for comparison + filter_normalized = mock_filter.lstrip(".") + oid_normalized = oid_str.lstrip(".") + + if not oid_normalized.startswith(filter_normalized): + continue + + # convert value using proper type conversion + converted_value = _convert(value=value) + + if normalized is True: + # use only the last node of the OID + key = oid_str.split(".")[-1] + else: + key = oid_str + + formatted[key] = converted_value + + return formatted + + def _update_cache(filename, group): """Update SNMP credentials cache file. @@ -974,6 +1133,9 @@ def _update_cache(filename, group): Returns: None """ - # Do update - with open(filename, "w+") as env: - env.write(group) + try: + with open(filename, "w+") as f_handle: + f_handle.write(group) + except Exception as e: + log_message = f"Failed to update cache file {filename}: {e}" + log.log2warning(1025, log_message) diff --git a/switchmap/server/db/ingest/update/device.py b/switchmap/server/db/ingest/update/device.py index 63f54352b..8a49ff81f 100644 --- a/switchmap/server/db/ingest/update/device.py +++ b/switchmap/server/db/ingest/update/device.py @@ -8,6 +8,7 @@ # Application imports from switchmap.core import log from switchmap.core import general +from switchmap.core.mac_utils import decode_mac_address from switchmap.server.db.ingest.query import device as _misc_device from switchmap.server.db.misc import interface as _historical from switchmap.server.db.table import device as _device @@ -505,13 +506,14 @@ def macport(self, test=False): {self._device.hostname} based on SNMP MIB-BRIDGE entries""" log.log2debug(1065, log_message) - # Iterate over the MACs found for next_mac in sorted(_macs): - # Initialize loop variables + # Initialize variables valid_mac = None # Create lowercase version of MAC address - mactest = general.mac(next_mac) + # Handle double-encoded MAC addresses from async poller + decoded_mac = decode_mac_address(next_mac) + mactest = general.mac(decoded_mac) if bool(mactest.valid) is False: continue else: diff --git a/switchmap/server/db/ingest/update/zone.py b/switchmap/server/db/ingest/update/zone.py index 623d7ad47..88e6b8b9e 100644 --- a/switchmap/server/db/ingest/update/zone.py +++ b/switchmap/server/db/ingest/update/zone.py @@ -7,6 +7,7 @@ # Application imports from switchmap.core import log from switchmap.core import general +from switchmap.core.mac_utils import decode_mac_address from switchmap.server.db.table import oui as _oui from switchmap.server import ZoneObjects from switchmap.server import PairMacIp @@ -400,7 +401,9 @@ def _process_pairmacips(idx_zone, table): continue # Create lowercase version of mac address. Skip if invalid - mactest = general.mac(next_mac) + # Handle double-encoded MAC addresses from async poller + decoded_mac = decode_mac_address(next_mac) + mactest = general.mac(decoded_mac) if bool(mactest.valid) is False: continue else: @@ -442,7 +445,9 @@ def _arp_table(idx_zone, data): continue # Create lowercase version of mac address. Skip if invalid. - mactest = general.mac(next_mac) + # Handle double-encoded MAC addresses from async poller + decoded_mac = decode_mac_address(next_mac) + mactest = general.mac(decoded_mac) if bool(mactest.valid) is False: continue else: diff --git a/switchmap/server/db/misc/oui.py b/switchmap/server/db/misc/oui.py index 44f32a24c..3fe8b6c2f 100644 --- a/switchmap/server/db/misc/oui.py +++ b/switchmap/server/db/misc/oui.py @@ -11,11 +11,12 @@ from sqlalchemy.exc import IntegrityError -def update_db_oui(filepath): +def update_db_oui(filepath, new=False): """Update the database with Oui data. Args: filepath: File to process + new: If True, skip existing entry checks for new installations Returns: None @@ -46,11 +47,20 @@ def update_db_oui(filepath): organization = row["organization"].strip() rows.append(IOui(oui=oui, organization=organization, enabled=1)) - for row in rows: - existing_entry = _oui.exists(row.oui) - if not existing_entry: - _oui.insert_row([row]) - elif existing_entry.organization != row.organization: - _oui.update_row(existing_entry.idx_oui, row) + if new: + # Bulk insert on fresh install, skip checks + try: + _oui.insert_row(rows) + except IntegrityError: + SCOPED_SESSION.rollback() + new = False + if not new: + # For updates, check existing entries + for row in rows: + existing_entry = _oui.exists(row.oui) + if not existing_entry: + _oui.insert_row([row]) + elif existing_entry.organization != row.organization: + _oui.update_row(existing_entry.idx_oui, row) SCOPED_SESSION.commit() diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoc2900.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoc2900.py index 266201a60..545034e9d 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoc2900.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoc2900.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -153,7 +153,7 @@ def test_init_query(self): pass -class TestMibCiscoc2900(unittest.TestCase): +class TestMibCiscoc2900(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -167,11 +167,9 @@ class TestMibCiscoc2900(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + # Configure async methods + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Initializing key variables expected_dict = { @@ -214,7 +212,7 @@ def test___init__(self): """Testing function __init__.""" pass - def test_layer1(self): + async def test_layer1(self): """Testing function layer1.""" # Initializing key variables expected_dict = { @@ -230,12 +228,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -245,7 +242,7 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_c2900portlinkbeatstatus(self): + async def test_c2900portlinkbeatstatus(self): """Testing function c2900portlinkbeatstatus.""" # Initialize key variables oid_key = "c2900PortLinkbeatStatus" @@ -253,7 +250,7 @@ def test_c2900portlinkbeatstatus(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.c2900portlinkbeatstatus() + results = await testobj.c2900portlinkbeatstatus() # Basic testing of results for key, value in results.items(): @@ -261,18 +258,18 @@ def test_c2900portlinkbeatstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.c2900portlinkbeatstatus(oidonly=True) + results = await testobj.c2900portlinkbeatstatus(oidonly=True) self.assertEqual(results, oid) - def test_c2900portduplexstatus(self): + async def test_c2900portduplexstatus(self): """Testing function c2900portduplexstatus.""" # Initialize key variables - oid_key = "c2900PortLinkbeatStatus" + oid_key = "c2900PortDuplexStatus" oid = ".1.3.6.1.4.1.9.9.87.1.4.1.1.32" # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.c2900portduplexstatus() + results = await testobj.c2900portduplexstatus() # Basic testing of results for key, value in results.items(): @@ -280,7 +277,7 @@ def test_c2900portduplexstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.c2900portduplexstatus(oidonly=True) + results = await testobj.c2900portduplexstatus(oidonly=True) self.assertEqual(results, oid) diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscocdp.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscocdp.py index d5deec65a..359efee45 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscocdp.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscocdp.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -106,7 +106,7 @@ def swalk(self): pass -class TestCiscoCDPFunctions(unittest.TestCase): +class TestCiscoCDPFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -142,7 +142,7 @@ def test_init_query(self): pass -class TestCiscoCDP(unittest.TestCase): +class TestCiscoCDP(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -174,26 +174,24 @@ def tearDownClass(cls): # Cleanup the CONFIG.cleanup() - def test_supported(self): + async def test_supported(self): """Testing method / function supported.""" # Set the stage for oid_exists returning True snmpobj = Mock(spec=Query) - mock_spec = {"oid_exists.return_value": True} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=True) # Test supported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), True) + self.assertEqual(await testobj.supported(), True) # Set the stage for oid_exists returning False - mock_spec = {"oid_exists.return_value": False} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=False) # Test unsupported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), False) + self.assertEqual(await testobj.supported(), False) - def test_layer1(self): + async def test_layer1(self): """Testing method / function layer1.""" # Initializing key variables expected_dict = { @@ -211,12 +209,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.walk_results_string} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.walk_results_string) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -226,61 +223,58 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_cdpcachedeviceid(self): + async def test_cdpcachedeviceid(self): """Testing method / function cdpcachedeviceid.""" # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.walk_results_string} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.walk_results_string) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.cdpcachedeviceid() + results = await testobj.cdpcachedeviceid() # Basic testing of results for key in results.keys(): self.assertEqual(isinstance(key, int), True) # Test that we are getting the correct OID - results = testobj.cdpcachedeviceid(oidonly=True) + results = await testobj.cdpcachedeviceid(oidonly=True) self.assertEqual(results, ".1.3.6.1.4.1.9.9.23.1.2.1.1.6") - def test_cdpcacheplatform(self): + async def test_cdpcacheplatform(self): """Testing method / function cdpcacheplatform.""" # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.walk_results_string} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.walk_results_string) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.cdpcacheplatform() + results = await testobj.cdpcacheplatform() # Basic testing of results for key in results.keys(): self.assertEqual(isinstance(key, int), True) # Test that we are getting the correct OID - results = testobj.cdpcacheplatform(oidonly=True) + results = await testobj.cdpcacheplatform(oidonly=True) self.assertEqual(results, ".1.3.6.1.4.1.9.9.23.1.2.1.1.8") - def test_cdpcachedeviceport(self): + async def test_cdpcachedeviceport(self): """Testing method / function cdpcachedeviceport.""" # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.walk_results_string} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.walk_results_string) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.cdpcachedeviceport() + results = await testobj.cdpcachedeviceport() # Basic testing of results for key in results.keys(): self.assertEqual(isinstance(key, int), True) # Test that we are getting the correct OID - results = testobj.cdpcachedeviceport(oidonly=True) + results = await testobj.cdpcachedeviceport(oidonly=True) self.assertEqual(results, ".1.3.6.1.4.1.9.9.23.1.2.1.1.7") def test__ifindex(self): diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoietfip.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoietfip.py index 8839e7f84..9f33d7059 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoietfip.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscoietfip.py @@ -104,7 +104,7 @@ def swalk(self): pass -class TestCiscoIetfIpQueryFunctions(unittest.TestCase): +class TestCiscoIetfIpQueryFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -140,7 +140,7 @@ def test_init_query(self): pass -class TestCiscoIetfIpQuery(unittest.TestCase): +class TestCiscoIetfIpQuery(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscostack.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscostack.py index 234736da2..0a1f8341c 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscostack.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscostack.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -117,7 +117,7 @@ def walk(self): pass -class TestMibCiscoStackFunctions(unittest.TestCase): +class TestMibCiscoStackFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -153,7 +153,7 @@ def test_init_query(self): pass -class TestMibCiscoStack(unittest.TestCase): +class TestMibCiscoStack(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -167,11 +167,8 @@ class TestMibCiscoStack(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Initializing key variables expected_dict = {100: {"portDuplex": 100}, 200: {"portDuplex": 100}} @@ -209,7 +206,7 @@ def test_layer1(self): """Testing function layer1.""" pass - def test_portduplex(self): + async def test_portduplex(self): """Testing function portduplex.""" # Initialize key variables oid_key = "portDuplex" @@ -217,7 +214,7 @@ def test_portduplex(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.portduplex() + results = await testobj.portduplex() # Basic testing of results for key, value in results.items(): @@ -225,10 +222,10 @@ def test_portduplex(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.portduplex(oidonly=True) + results = await testobj.portduplex(oidonly=True) self.assertEqual(results, oid) - def test__portifindex(self): + async def test__portifindex(self): """Testing function _portifindex.""" # Initialize key variables oid_key = "portDuplex" @@ -236,7 +233,7 @@ def test__portifindex(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj._portifindex() + results = await testobj._portifindex() # Basic testing of results for key, value in results.items(): @@ -244,7 +241,7 @@ def test__portifindex(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj._portifindex(oidonly=True) + results = await testobj._portifindex(oidonly=True) self.assertEqual(results, oid) diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlaniftablerelationship.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlaniftablerelationship.py index 164772296..d76e1cf1e 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlaniftablerelationship.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlaniftablerelationship.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -55,6 +55,7 @@ sys.exit(2) # Create the necessary configuration to load the module +from switchmap.poller import snmp from tests.testlib_ import setup CONFIG = setup.config() @@ -119,7 +120,7 @@ def walk(self): pass -class TestMibCiscoVlanIfTableFunctions(unittest.TestCase): +class TestMibCiscoVlanIfTableFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -155,7 +156,7 @@ def test_init_query(self): pass -class TestMibCiscoVlanIfTable(unittest.TestCase): +class TestMibCiscoVlanIfTable(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -174,11 +175,8 @@ class TestMibCiscoVlanIfTable(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Initializing key variables expected_dict = { @@ -217,26 +215,24 @@ def test___init__(self): """Testing function __init__.""" pass - def test_supported(self): + async def test_supported(self): """Testing method / function supported.""" # Set the stage for oid_exists returning True snmpobj = Mock(spec=Query) - mock_spec = {"oid_exists.return_value": True} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=True) # Test supported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), True) + self.assertEqual(await testobj.supported(), True) # Set the stage for oid_exists returning False - mock_spec = {"oid_exists.return_value": False} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=False) # Test unsupported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), False) + self.assertEqual(await testobj.supported(), False) - def test_layer1(self): + async def test_layer1(self): """Testing function layer1.""" # Initializing key variables expected_dict = { @@ -248,12 +244,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -263,14 +258,14 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_cviroutedvlanifindex(self): + async def test_cviroutedvlanifindex(self): """Testing function cviroutedvlanifindex.""" oid_key = "CiscoVlanIftableRelationship" oid = ".1.3.6.1.4.1.9.9.128.1.1.1.1.3" # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.cviroutedvlanifindex() + results = await testobj.cviroutedvlanifindex() # Basic testing of results for key, value in results.items(): @@ -279,7 +274,7 @@ def test_cviroutedvlanifindex(self): self.assertEqual(value, expected_value) # Test that we are getting the correct OID - results = testobj.cviroutedvlanifindex(oidonly=True) + results = await testobj.cviroutedvlanifindex(oidonly=True) self.assertEqual(results, oid) diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlanmembership.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlanmembership.py index 0c697f5f6..b7bd27dc2 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlanmembership.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovlanmembership.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -119,7 +119,7 @@ def walk(self): pass -class TestMibCiscoVlanMemberFunctions(unittest.TestCase): +class TestMibCiscoVlanMemberFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -155,7 +155,7 @@ def test_init_query(self): pass -class TestMibCiscoVlanMember(unittest.TestCase): +class TestMibCiscoVlanMember(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -169,11 +169,8 @@ class TestMibCiscoVlanMember(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Initializing key variables expected_dict = { @@ -210,26 +207,24 @@ def test___init__(self): """Testing function __init__.""" pass - def test_supported(self): + async def test_supported(self): """Testing method / function supported.""" # Set the stage for oid_exists returning True snmpobj = Mock(spec=Query) - mock_spec = {"oid_exists.return_value": True} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=True) # Test supported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), True) + self.assertEqual(await testobj.supported(), True) # Set the stage for oid_exists returning False - mock_spec = {"oid_exists.return_value": False} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=False) # Test unsupported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), False) + self.assertEqual(await testobj.supported(), False) - def test_layer1(self): + async def test_layer1(self): """Testing function layer1.""" # Initializing key variables expected_dict = { @@ -239,12 +234,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -254,14 +248,14 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_vmvlan(self): + async def test_vmvlan(self): """Testing function vmvlan.""" oid_key = "vmVlan" oid = ".1.3.6.1.4.1.9.9.68.1.2.2.1.2" # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vmvlan() + results = await testobj.vmvlan() # Basic testing of results for key, value in results.items(): @@ -269,10 +263,10 @@ def test_vmvlan(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vmvlan(oidonly=True) + results = await testobj.vmvlan(oidonly=True) self.assertEqual(results, oid) - def test_vmportstatus(self): + async def test_vmportstatus(self): """Testing function vmportstatus.""" # Initialize key variables oid_key = "vmPortStatus" @@ -280,7 +274,7 @@ def test_vmportstatus(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vmportstatus() + results = await testobj.vmportstatus() # Basic testing of results for key, value in results.items(): @@ -288,7 +282,7 @@ def test_vmportstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vmportstatus(oidonly=True) + results = await testobj.vmportstatus(oidonly=True) self.assertEqual(results, oid) diff --git a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovtp.py b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovtp.py index 64d9007e0..f974950bb 100755 --- a/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovtp.py +++ b/tests/switchmap_/poller/snmp/mib/cisco/test_mib_ciscovtp.py @@ -5,7 +5,7 @@ import sys import binascii import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -118,7 +118,7 @@ def walk(self): pass -class TestMibCiscoVTPFunctions(unittest.TestCase): +class TestMibCiscoVTPFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -154,7 +154,7 @@ def test_init_query(self): pass -class TestMibCiscoVTP(unittest.TestCase): +class TestMibCiscoVTP(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -168,33 +168,24 @@ class TestMibCiscoVTP(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Normalized walk returning integers for the ifIndex nwalk_results_ifindex = {100: 100, 200: 200} # Set the stage for SNMPwalk for integer results for the ifIndex snmpobj_ifindex = Mock(spec=Query) - mock_spec_ifindex = { - "swalk.return_value": nwalk_results_ifindex, - "walk.return_value": nwalk_results_ifindex, - } - snmpobj_ifindex.configure_mock(**mock_spec_ifindex) + snmpobj_ifindex.swalk = AsyncMock(return_value=nwalk_results_ifindex) + snmpobj_ifindex.walk = AsyncMock(return_value=nwalk_results_ifindex) # Normalized walk returning strings nwalk_results_bytes = {100: b"1234", 200: b"5678"} # Set the stage for SNMPwalk for string results snmpobj_bytes = Mock(spec=Query) - mock_spec_bytes = { - "swalk.return_value": nwalk_results_bytes, - "walk.return_value": nwalk_results_bytes, - } - snmpobj_bytes.configure_mock(**mock_spec_bytes) + snmpobj_bytes.swalk = AsyncMock(return_value=nwalk_results_bytes) + snmpobj_bytes.walk = AsyncMock(return_value=nwalk_results_bytes) # Normalized walk returning binary data nwalk_results_binary = { @@ -204,11 +195,8 @@ class TestMibCiscoVTP(unittest.TestCase): # Set the stage for SNMPwalk for binary results snmpobj_binary = Mock(spec=Query) - mock_spec_binary = { - "swalk.return_value": nwalk_results_binary, - "walk.return_value": nwalk_results_binary, - } - snmpobj_binary.configure_mock(**mock_spec_binary) + snmpobj_binary.swalk = AsyncMock(return_value=nwalk_results_binary) + snmpobj_binary.walk = AsyncMock(return_value=nwalk_results_binary) # Initializing key variables expected_dict = { @@ -218,7 +206,7 @@ class TestMibCiscoVTP(unittest.TestCase): "vlanTrunkPortNativeVlan": 1234, "vlanTrunkPortEncapsulationType": 1234, "vlanTrunkPortVlansEnabled": 1234, - "vtpVlanType": 1234, + "vtpVlanType": "1234", "vtpVlanName": "1234", "vtpVlanState": 1234, }, @@ -228,7 +216,7 @@ class TestMibCiscoVTP(unittest.TestCase): "vlanTrunkPortNativeVlan": 5678, "vlanTrunkPortEncapsulationType": 5678, "vlanTrunkPortVlansEnabled": 5678, - "vtpVlanType": 5678, + "vtpVlanType": "5678", "vtpVlanName": "5678", "vtpVlanState": 5678, }, @@ -275,7 +263,7 @@ def test_layer1(self): # the same type of results (eg. int, string, hex) pass - def test_vlantrunkportencapsulationtype(self): + async def test_vlantrunkportencapsulationtype(self): """Testing function vlantrunkportencapsulationtype.""" # Initialize key variables oid_key = "vlanTrunkPortEncapsulationType" @@ -283,7 +271,7 @@ def test_vlantrunkportencapsulationtype(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vlantrunkportencapsulationtype() + results = await testobj.vlantrunkportencapsulationtype() # Basic testing of results for key, value in results.items(): @@ -291,10 +279,10 @@ def test_vlantrunkportencapsulationtype(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vlantrunkportencapsulationtype(oidonly=True) + results = await testobj.vlantrunkportencapsulationtype(oidonly=True) self.assertEqual(results, oid) - def test_vlantrunkportnativevlan(self): + async def test_vlantrunkportnativevlan(self): """Testing function vlantrunkportnativevlan.""" # Initialize key variables oid_key = "vlanTrunkPortNativeVlan" @@ -302,7 +290,7 @@ def test_vlantrunkportnativevlan(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vlantrunkportnativevlan() + results = await testobj.vlantrunkportnativevlan() # Basic testing of results for key, value in results.items(): @@ -310,10 +298,10 @@ def test_vlantrunkportnativevlan(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vlantrunkportnativevlan(oidonly=True) + results = await testobj.vlantrunkportnativevlan(oidonly=True) self.assertEqual(results, oid) - def test_vlantrunkportdynamicstatus(self): + async def test_vlantrunkportdynamicstatus(self): """Testing function vlantrunkportdynamicstatus.""" # Initialize key variables oid_key = "vlanTrunkPortDynamicStatus" @@ -321,7 +309,7 @@ def test_vlantrunkportdynamicstatus(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vlantrunkportdynamicstatus() + results = await testobj.vlantrunkportdynamicstatus() # Basic testing of results for key, value in results.items(): @@ -329,10 +317,10 @@ def test_vlantrunkportdynamicstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vlantrunkportdynamicstatus(oidonly=True) + results = await testobj.vlantrunkportdynamicstatus(oidonly=True) self.assertEqual(results, oid) - def test_vlantrunkportdynamicstate(self): + async def test_vlantrunkportdynamicstate(self): """Testing function vlantrunkportdynamicstate.""" # Initialize key variables oid_key = "vlanTrunkPortDynamicState" @@ -340,7 +328,7 @@ def test_vlantrunkportdynamicstate(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vlantrunkportdynamicstate() + results = await testobj.vlantrunkportdynamicstate() # Basic testing of results for key, value in results.items(): @@ -348,10 +336,10 @@ def test_vlantrunkportdynamicstate(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vlantrunkportdynamicstate(oidonly=True) + results = await testobj.vlantrunkportdynamicstate(oidonly=True) self.assertEqual(results, oid) - def test_vtpvlanname(self): + async def test_vtpvlanname(self): """Testing function vtpvlanname.""" # Initialize key variables oid_key = "vtpVlanName" @@ -359,7 +347,7 @@ def test_vtpvlanname(self): # Get results testobj = testimport.init_query(self.snmpobj_bytes) - results = testobj.vtpvlanname() + results = await testobj.vtpvlanname() # Basic testing of results for key, value in results.items(): @@ -367,10 +355,10 @@ def test_vtpvlanname(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vtpvlanname(oidonly=True) + results = await testobj.vtpvlanname(oidonly=True) self.assertEqual(results, oid) - def test_vtpvlantype(self): + async def test_vtpvlantype(self): """Testing function vtpvlantype.""" # Initialize key variables oid_key = "vtpVlanType" @@ -378,7 +366,7 @@ def test_vtpvlantype(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vtpvlantype() + results = await testobj.vtpvlantype() # Basic testing of results for key, value in results.items(): @@ -386,10 +374,10 @@ def test_vtpvlantype(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vtpvlantype(oidonly=True) + results = await testobj.vtpvlantype(oidonly=True) self.assertEqual(results, oid) - def test_vtpvlanstate(self): + async def test_vtpvlanstate(self): """Testing function vtpvlanstate.""" # Initialize key variables oid_key = "vtpVlanState" @@ -397,7 +385,7 @@ def test_vtpvlanstate(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.vtpvlanstate() + results = await testobj.vtpvlanstate() # Basic testing of results for key, value in results.items(): @@ -405,7 +393,7 @@ def test_vtpvlanstate(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.vtpvlanstate(oidonly=True) + results = await testobj.vtpvlanstate(oidonly=True) self.assertEqual(results, oid) def test_vlantrunkportvlansenabled(self): diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_bridge.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_bridge.py index 935bbbaac..814b077f1 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_bridge.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_bridge.py @@ -115,7 +115,7 @@ def walk(self): pass -class TestMibBridgeFunctions(unittest.TestCase): +class TestMibBridgeFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -151,7 +151,7 @@ def test_init_query(self): pass -class TestMibBridge(unittest.TestCase): +class TestMibBridge(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_entity.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_entity.py index 8444addf4..ccf0a90f5 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_entity.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_entity.py @@ -115,7 +115,7 @@ def walk(self): pass -class TestMibEntityFunctions(unittest.TestCase): +class TestMibEntityFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -151,7 +151,7 @@ def test_init_query(self): pass -class TestMibEntity(unittest.TestCase): +class TestMibEntity(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_essswitch.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_essswitch.py index f75f3df71..aa5381c00 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_essswitch.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_essswitch.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -106,7 +106,7 @@ def swalk(self): pass -class TestMibESSSwitchFunctions(unittest.TestCase): +class TestMibESSSwitchFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -142,7 +142,7 @@ def test_init_query(self): pass -class TestMibESSSwitch(unittest.TestCase): +class TestMibESSSwitch(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -171,26 +171,24 @@ def tearDownClass(cls): # Cleanup the CONFIG.cleanup() - def test_supported(self): + async def test_supported(self): """Testing method / function supported.""" # Set the stage for oid_exists returning True snmpobj = Mock(spec=Query) - mock_spec = {"oid_exists.return_value": True} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=True) # Test supported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), True) + self.assertEqual(await testobj.supported(), True) # Set the stage for oid_exists returning False - mock_spec = {"oid_exists.return_value": False} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=False) # Test unsupported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), False) + self.assertEqual(await testobj.supported(), False) - def test_layer1(self): + async def test_layer1(self): """Testing method / function layer1.""" # Initializing key variables expected_dict = { @@ -200,12 +198,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -215,23 +212,22 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_swportduplexstatus(self): + async def test_swportduplexstatus(self): """Testing method / function swportduplexstatus.""" # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.swportduplexstatus() + results = await testobj.swportduplexstatus() # Basic testing of results for key in results.keys(): self.assertEqual(isinstance(key, int), True) # Test that we are getting the correct OID - results = testobj.swportduplexstatus(oidonly=True) + results = await testobj.swportduplexstatus(oidonly=True) self.assertEqual(results, ".1.3.6.1.4.1.437.1.1.3.3.1.1.30") diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_etherlike.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_etherlike.py index 5c141bc38..1631f092d 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_etherlike.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_etherlike.py @@ -2,7 +2,7 @@ """Test the mib_etherlike module.""" import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock import os import sys @@ -153,7 +153,7 @@ def test_init_query(self): pass -class TestMibEtherlike(unittest.TestCase): +class TestMibEtherlike(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -182,26 +182,24 @@ def tearDownClass(cls): # Cleanup the CONFIG.cleanup() - def test_supported(self): + async def test_supported(self): """Testing method / function supported.""" # Set the stage for oid_exists returning True snmpobj = Mock(spec=Query) - mock_spec = {"oid_exists.return_value": True} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=True) # Test supported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), True) + self.assertEqual(await testobj.supported(), True) # Set the stage for oid_exists returning False - mock_spec = {"oid_exists.return_value": False} - snmpobj.configure_mock(**mock_spec) + snmpobj.oid_exists = AsyncMock(return_value=False) # Test unsupported testobj = testimport.init_query(snmpobj) - self.assertEqual(testobj.supported(), False) + self.assertEqual(await testobj.supported(), False) - def test_layer1(self): + async def test_layer1(self): """Testing method / function layer1.""" # Initializing key variables expected_dict = { @@ -211,12 +209,11 @@ def test_layer1(self): # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.layer1() + results = await testobj.layer1() # Basic testing of results for primary in results.keys(): @@ -226,23 +223,22 @@ def test_layer1(self): expected_dict[primary][secondary], ) - def test_dot3statsduplexstatus(self): + async def test_dot3statsduplexstatus(self): """Testing method / function dot3statsduplexstatus.""" # Set the stage for SNMPwalk snmpobj = Mock(spec=Query) - mock_spec = {"swalk.return_value": self.nwalk_results_integer} - snmpobj.configure_mock(**mock_spec) + snmpobj.swalk = AsyncMock(return_value=self.nwalk_results_integer) # Get results testobj = testimport.init_query(snmpobj) - results = testobj.dot3statsduplexstatus() + results = await testobj.dot3statsduplexstatus() # Basic testing of results for key in results.keys(): self.assertEqual(isinstance(key, int), True) # Test that we are getting the correct OID - results = testobj.dot3statsduplexstatus(oidonly=True) + results = await testobj.dot3statsduplexstatus(oidonly=True) self.assertEqual(results, ".1.3.6.1.2.1.10.7.2.1.19") diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_if.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_if.py index 4c1ab15df..cded2fa74 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_if.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_if.py @@ -5,7 +5,7 @@ import sys import binascii import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -154,7 +154,7 @@ def test_init_query(self): pass -class TestMibIf(unittest.TestCase): +class TestMibIf(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -168,33 +168,24 @@ class TestMibIf(unittest.TestCase): # Set the stage for SNMPwalk for integer results snmpobj_integer = Mock(spec=Query) - mock_spec_integer = { - "swalk.return_value": nwalk_results_integer, - "walk.return_value": nwalk_results_integer, - } - snmpobj_integer.configure_mock(**mock_spec_integer) + snmpobj_integer.swalk = AsyncMock(return_value=nwalk_results_integer) + snmpobj_integer.walk = AsyncMock(return_value=nwalk_results_integer) # Normalized walk returning integers for the ifIndex nwalk_results_ifindex = {100: 100, 200: 200} # Set the stage for SNMPwalk for integer results for the ifIndex snmpobj_ifindex = Mock(spec=Query) - mock_spec_ifindex = { - "swalk.return_value": nwalk_results_ifindex, - "walk.return_value": nwalk_results_ifindex, - } - snmpobj_ifindex.configure_mock(**mock_spec_ifindex) + snmpobj_ifindex.swalk = AsyncMock(return_value=nwalk_results_ifindex) + snmpobj_ifindex.walk = AsyncMock(return_value=nwalk_results_ifindex) # Normalized walk returning strings nwalk_results_bytes = {100: b"1234", 200: b"5678"} # Set the stage for SNMPwalk for string results snmpobj_bytes = Mock(spec=Query) - mock_spec_bytes = { - "swalk.return_value": nwalk_results_bytes, - "walk.return_value": nwalk_results_bytes, - } - snmpobj_bytes.configure_mock(**mock_spec_bytes) + snmpobj_bytes.swalk = AsyncMock(return_value=nwalk_results_bytes) + snmpobj_bytes.walk = AsyncMock(return_value=nwalk_results_bytes) # Normalized walk returning binary data nwalk_results_binary = { @@ -204,11 +195,8 @@ class TestMibIf(unittest.TestCase): # Set the stage for SNMPwalk for binary results snmpobj_binary = Mock(spec=Query) - mock_spec_binary = { - "swalk.return_value": nwalk_results_binary, - "walk.return_value": nwalk_results_binary, - } - snmpobj_binary.configure_mock(**mock_spec_binary) + snmpobj_binary.swalk = AsyncMock(return_value=nwalk_results_binary) + snmpobj_binary.walk = AsyncMock(return_value=nwalk_results_binary) # Initializing key variables expected_dict = { @@ -289,7 +277,7 @@ def test_layer1(self): # the same type of results (eg. int, string, hex) pass - def test_iflastchange(self): + async def test_iflastchange(self): """Testing function iflastchange.""" # Initialize key variables oid_key = "ifLastChange" @@ -297,7 +285,7 @@ def test_iflastchange(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.iflastchange() + results = await testobj.iflastchange() # Basic testing of results for key, value in results.items(): @@ -305,10 +293,10 @@ def test_iflastchange(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.iflastchange(oidonly=True) + results = await testobj.iflastchange(oidonly=True) self.assertEqual(results, oid) - def test_ifinoctets(self): + async def test_ifinoctets(self): """Testing function ifinoctets.""" # Initialize key variables oid_key = "ifInOctets" @@ -316,7 +304,7 @@ def test_ifinoctets(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifinoctets() + results = await testobj.ifinoctets() # Basic testing of results for key, value in results.items(): @@ -324,10 +312,10 @@ def test_ifinoctets(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifinoctets(oidonly=True) + results = await testobj.ifinoctets(oidonly=True) self.assertEqual(results, oid) - def test_ifoutoctets(self): + async def test_ifoutoctets(self): """Testing function ifoutoctets.""" # Initialize key variables oid_key = "ifOutOctets" @@ -335,7 +323,7 @@ def test_ifoutoctets(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifoutoctets() + results = await testobj.ifoutoctets() # Basic testing of results for key, value in results.items(): @@ -343,10 +331,10 @@ def test_ifoutoctets(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifoutoctets(oidonly=True) + results = await testobj.ifoutoctets(oidonly=True) self.assertEqual(results, oid) - def test_ifdescr(self): + async def test_ifdescr(self): """Testing function ifdescr.""" # Initialize key variables oid_key = "ifDescr" @@ -354,7 +342,7 @@ def test_ifdescr(self): # Get results testobj = testimport.init_query(self.snmpobj_bytes) - results = testobj.ifdescr() + results = await testobj.ifdescr() # Basic testing of results for key, value in results.items(): @@ -362,10 +350,10 @@ def test_ifdescr(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifdescr(oidonly=True) + results = await testobj.ifdescr(oidonly=True) self.assertEqual(results, oid) - def test_iftype(self): + async def test_iftype(self): """Testing function iftype.""" # Initialize key variables oid_key = "ifType" @@ -373,7 +361,7 @@ def test_iftype(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.iftype() + results = await testobj.iftype() # Basic testing of results for key, value in results.items(): @@ -381,10 +369,10 @@ def test_iftype(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.iftype(oidonly=True) + results = await testobj.iftype(oidonly=True) self.assertEqual(results, oid) - def test_ifspeed(self): + async def test_ifspeed(self): """Testing function ifspeed.""" # Initialize key variables oid_key = "ifSpeed" @@ -392,7 +380,7 @@ def test_ifspeed(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifspeed() + results = await testobj.ifspeed() # Basic testing of results for key, value in results.items(): @@ -400,10 +388,10 @@ def test_ifspeed(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifspeed(oidonly=True) + results = await testobj.ifspeed(oidonly=True) self.assertEqual(results, oid) - def test_ifadminstatus(self): + async def test_ifadminstatus(self): """Testing function ifadminstatus.""" # Initialize key variables oid_key = "ifAdminStatus" @@ -411,7 +399,7 @@ def test_ifadminstatus(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifadminstatus() + results = await testobj.ifadminstatus() # Basic testing of results for key, value in results.items(): @@ -419,10 +407,10 @@ def test_ifadminstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifadminstatus(oidonly=True) + results = await testobj.ifadminstatus(oidonly=True) self.assertEqual(results, oid) - def test_ifoperstatus(self): + async def test_ifoperstatus(self): """Testing function ifoperstatus.""" # Initialize key variables oid_key = "ifOperStatus" @@ -430,7 +418,7 @@ def test_ifoperstatus(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifoperstatus() + results = await testobj.ifoperstatus() # Basic testing of results for key, value in results.items(): @@ -438,10 +426,10 @@ def test_ifoperstatus(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifoperstatus(oidonly=True) + results = await testobj.ifoperstatus(oidonly=True) self.assertEqual(results, oid) - def test_ifalias(self): + async def test_ifalias(self): """Testing function ifalias.""" # Initialize key variables oid_key = "ifAlias" @@ -449,7 +437,7 @@ def test_ifalias(self): # Get results testobj = testimport.init_query(self.snmpobj_bytes) - results = testobj.ifalias() + results = await testobj.ifalias() # Basic testing of results for key, value in results.items(): @@ -457,10 +445,10 @@ def test_ifalias(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifalias(oidonly=True) + results = await testobj.ifalias(oidonly=True) self.assertEqual(results, oid) - def test_ifname(self): + async def test_ifname(self): """Testing function ifname.""" # Initialize key variables oid_key = "ifName" @@ -468,7 +456,7 @@ def test_ifname(self): # Get results testobj = testimport.init_query(self.snmpobj_bytes) - results = testobj.ifname() + results = await testobj.ifname() # Basic testing of results for key, value in results.items(): @@ -476,10 +464,10 @@ def test_ifname(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifname(oidonly=True) + results = await testobj.ifname(oidonly=True) self.assertEqual(results, oid) - def test_ifindex(self): + async def test_ifindex(self): """Testing function ifindex.""" # Initialize key variables oid_key = "ifIndex" @@ -487,7 +475,7 @@ def test_ifindex(self): # Get results testobj = testimport.init_query(self.snmpobj_ifindex) - results = testobj.ifindex() + results = await testobj.ifindex() # Basic testing of results for key, value in results.items(): @@ -499,10 +487,10 @@ def test_ifindex(self): self.assertEqual(key, value) # Test that we are getting the correct OID - results = testobj.ifindex(oidonly=True) + results = await testobj.ifindex(oidonly=True) self.assertEqual(results, oid) - def test_ifphysaddress(self): + async def test_ifphysaddress(self): """Testing function ifphysaddress.""" # Initialize key variables oid_key = "ifPhysAddress" @@ -510,7 +498,7 @@ def test_ifphysaddress(self): # Get results testobj = testimport.init_query(self.snmpobj_binary) - results = testobj.ifphysaddress() + results = await testobj.ifphysaddress() # Basic testing of results for key, value in results.items(): @@ -518,10 +506,10 @@ def test_ifphysaddress(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifphysaddress(oidonly=True) + results = await testobj.ifphysaddress(oidonly=True) self.assertEqual(results, oid) - def test_ifinmulticastpkts(self): + async def test_ifinmulticastpkts(self): """Testing function ifinmulticastpkts.""" # Initialize key variables oid_key = "ifInMulticastPkts" @@ -529,7 +517,7 @@ def test_ifinmulticastpkts(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifinmulticastpkts() + results = await testobj.ifinmulticastpkts() # Basic testing of results for key, value in results.items(): @@ -537,10 +525,10 @@ def test_ifinmulticastpkts(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifinmulticastpkts(oidonly=True) + results = await testobj.ifinmulticastpkts(oidonly=True) self.assertEqual(results, oid) - def test_ifoutmulticastpkts(self): + async def test_ifoutmulticastpkts(self): """Testing function ifoutmulticastpkts.""" # Initialize key variables oid_key = "ifOutMulticastPkts" @@ -548,7 +536,7 @@ def test_ifoutmulticastpkts(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifoutmulticastpkts() + results = await testobj.ifoutmulticastpkts() # Basic testing of results for key, value in results.items(): @@ -556,10 +544,10 @@ def test_ifoutmulticastpkts(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifoutmulticastpkts(oidonly=True) + results = await testobj.ifoutmulticastpkts(oidonly=True) self.assertEqual(results, oid) - def test_ifinbroadcastpkts(self): + async def test_ifinbroadcastpkts(self): """Testing function ifinbroadcastpkts.""" # Initialize key variables oid_key = "ifInBroadcastPkts" @@ -567,7 +555,7 @@ def test_ifinbroadcastpkts(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifinbroadcastpkts() + results = await testobj.ifinbroadcastpkts() # Basic testing of results for key, value in results.items(): @@ -575,10 +563,10 @@ def test_ifinbroadcastpkts(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifinbroadcastpkts(oidonly=True) + results = await testobj.ifinbroadcastpkts(oidonly=True) self.assertEqual(results, oid) - def test_ifoutbroadcastpkts(self): + async def test_ifoutbroadcastpkts(self): """Testing function ifoutbroadcastpkts.""" # Initialize key variables oid_key = "ifOutBroadcastPkts" @@ -586,7 +574,7 @@ def test_ifoutbroadcastpkts(self): # Get results testobj = testimport.init_query(self.snmpobj_integer) - results = testobj.ifoutbroadcastpkts() + results = await testobj.ifoutbroadcastpkts() # Basic testing of results for key, value in results.items(): @@ -594,7 +582,7 @@ def test_ifoutbroadcastpkts(self): self.assertEqual(value, self.expected_dict[key][oid_key]) # Test that we are getting the correct OID - results = testobj.ifoutbroadcastpkts(oidonly=True) + results = await testobj.ifoutbroadcastpkts(oidonly=True) self.assertEqual(results, oid) def test_ifstackstatus(self): diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_ip.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_ip.py index 69767046e..e781d29ab 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_ip.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_ip.py @@ -4,7 +4,7 @@ import os import sys import unittest -from mock import Mock +from unittest.mock import Mock, AsyncMock # Try to create a working PYTHONPATH EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -153,7 +153,7 @@ def test_init_query(self): pass -class TestMibIp(unittest.TestCase): +class TestMibIp(unittest.IsolatedAsyncioTestCase): """Checks all functions and methods.""" ######################################################################### @@ -181,18 +181,12 @@ class TestMibIp(unittest.TestCase): # Set the stage for SNMPwalk for binary results snmpobj_ipv4_binary = Mock(spec=Query) - mock_spec_ipv4_binary = { - "swalk.return_value": walk_results_ipv4_binary, - "walk.return_value": walk_results_ipv4_binary, - } - snmpobj_ipv4_binary.configure_mock(**mock_spec_ipv4_binary) + snmpobj_ipv4_binary.swalk = AsyncMock(return_value=walk_results_ipv4_binary) + snmpobj_ipv4_binary.walk = AsyncMock(return_value=walk_results_ipv4_binary) snmpobj_ipv6_binary = Mock(spec=Query) - mock_spec_ipv6_binary = { - "swalk.return_value": walk_results_ipv6_binary, - "walk.return_value": walk_results_ipv6_binary, - } - snmpobj_ipv6_binary.configure_mock(**mock_spec_ipv6_binary) + snmpobj_ipv6_binary.swalk = AsyncMock(return_value=walk_results_ipv6_binary) + snmpobj_ipv6_binary.walk = AsyncMock(return_value=walk_results_ipv6_binary) # Initialize expected results ipv4_expected_dict = { @@ -242,14 +236,14 @@ def test_layer3(self): # Initializing key variables pass - def test_ipnettomediatable(self): + async def test_ipnettomediatable(self): """Testing method / function ipnettomediatable.""" # Initialize key variables oid = ".1.3.6.1.2.1.4.22.1.2" # Get results testobj = testimport.init_query(self.snmpobj_ipv4_binary) - results = testobj.ipnettomediatable() + results = await testobj.ipnettomediatable() # Basic testing of results for key, value in results.items(): @@ -257,17 +251,17 @@ def test_ipnettomediatable(self): self.assertEqual(value, self.ipv4_expected_dict[key]) # Test that we are getting the correct OID - results = testobj.ipnettomediatable(oidonly=True) + results = await testobj.ipnettomediatable(oidonly=True) self.assertEqual(results, oid) - def test_ipnettophysicalphysaddress(self): + async def test_ipnettophysicalphysaddress(self): """Testing method / function ipnettophysicalphysaddress.""" # Initialize key variables oid = ".1.3.6.1.2.1.4.35.1.4" # Get results testobj = testimport.init_query(self.snmpobj_ipv6_binary) - results = testobj.ipnettophysicalphysaddress() + results = await testobj.ipnettophysicalphysaddress() # Basic testing of results for key, value in results.items(): @@ -275,7 +269,7 @@ def test_ipnettophysicalphysaddress(self): self.assertEqual(value, self.ipv6_expected_dict[key]) # Test that we are getting the correct OID - results = testobj.ipnettophysicalphysaddress(oidonly=True) + results = await testobj.ipnettophysicalphysaddress(oidonly=True) self.assertEqual(results, oid) diff --git a/tests/switchmap_/poller/snmp/mib/generic/test_mib_lldp.py b/tests/switchmap_/poller/snmp/mib/generic/test_mib_lldp.py index f84e2af4e..d5c98ef93 100755 --- a/tests/switchmap_/poller/snmp/mib/generic/test_mib_lldp.py +++ b/tests/switchmap_/poller/snmp/mib/generic/test_mib_lldp.py @@ -4,7 +4,6 @@ import unittest import os import sys -from mock import Mock # Try to create a working PYTHONPATH diff --git a/tests/switchmap_/poller/snmp/mib/juniper/test_mib_junipervlan.py b/tests/switchmap_/poller/snmp/mib/juniper/test_mib_junipervlan.py index 12c78ce02..4e7545ab2 100755 --- a/tests/switchmap_/poller/snmp/mib/juniper/test_mib_junipervlan.py +++ b/tests/switchmap_/poller/snmp/mib/juniper/test_mib_junipervlan.py @@ -116,7 +116,7 @@ def walk(self): pass -class TestJuniperVlanFunctions(unittest.TestCase): +class TestJuniperVlanFunctions(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### @@ -152,7 +152,7 @@ def test_init_query(self): pass -class TestJuniperVlan(unittest.TestCase): +class TestJuniperVlan(unittest.IsolatedAsyncioTestCase): """Checks all methods.""" ######################################################################### diff --git a/tests/switchmap_/poller/test_async_poll.py b/tests/switchmap_/poller/test_async_poll.py new file mode 100644 index 000000000..8d715a258 --- /dev/null +++ b/tests/switchmap_/poller/test_async_poll.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +"""Test the async poller poll module.""" + +import os +import sys +import pytest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock +from switchmap.poller.poll import devices, device, cli_device, _META + +# Create a working PYTHONPATH +EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) +ROOT_DIR = os.path.abspath( + os.path.join( + os.path.abspath( + os.path.join(EXEC_DIR, os.pardir) # Move up to 'poller' + ), + os.pardir, # Move up to 'switchmap_' + ) +) +_EXPECTED = "{0}switchmap-ng{0}tests{0}switchmap_{0}poller".format(os.sep) +if EXEC_DIR.endswith(_EXPECTED) is True: + # We need to prepend the path in case the repo has been installed + # elsewhere on the system using PIP. This could corrupt expected results + sys.path.insert(0, ROOT_DIR) +else: + print( + f'This script is not installed in the "{_EXPECTED}" directory. ' + "Please fix." + ) + sys.exit(2) + + +@pytest.fixture +def mock_config_setup(): + """Set up mock configuration for tests. + + Args: + None + + Returns: + MagicMock: Mock configuration instance with zones and subprocesses. + """ + mock_config_instance = MagicMock() + mock_zone = MagicMock() + mock_zone.name = "zone1" + mock_zone.hostnames = ["device1", "device2"] + mock_config_instance.zones.return_value = [mock_zone] + mock_config_instance.agent_subprocesses.return_value = 2 + return mock_config_instance + + +@pytest.fixture +def mock_poll_meta(): + """Create a mock poll meta object. + + Args: + None + + Returns: + _META: Mock poll meta object with zone, hostname, and config. + """ + return _META(zone="zone1", hostname="device1", config=MagicMock()) + + +class TestAsyncPoll: + """Test cases for the async poll module functionality.""" + + @pytest.mark.asyncio + async def test_devices_basic_functionality(self, mock_config_setup): + """Test basic device polling functionality.""" + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_setup + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + mock_device.return_value = True + + await devices() + + # Verify device was called for each hostname + assert mock_device.call_count == 2 + + # Check that the calls were made with correct hostnames + call_args_list = mock_device.call_args_list + hostnames_called = [ + call[0][0].hostname for call in call_args_list + ] + assert "device1" in hostnames_called + assert "device2" in hostnames_called + + @pytest.mark.asyncio + async def test_devices_invalid_concurrency(self, mock_config_setup): + """Test devices() with invalid concurrency values.""" + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_setup + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + mock_device.return_value = True + + # Test negative concurrency + await devices(max_concurrent_devices=-1) + # Should still call device functions with default concurrency + assert mock_device.call_count == 2 + + @pytest.mark.asyncio + async def test_devices_empty_zones(self): + """Test devices() with zones containing no hostnames.""" + mock_config_instance = MagicMock() + mock_zone = MagicMock() + mock_zone.name = "empty_zone" + mock_zone.hostnames = [] # Empty hostnames + mock_config_instance.zones.return_value = [mock_zone] + mock_config_instance.agent_subprocesses.return_value = 2 + + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_instance + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + await devices() + # Should not call device when no hostnames + assert mock_device.call_count == 0 + + @pytest.mark.asyncio + async def test_devices_with_custom_concurrency(self, mock_config_setup): + """Test device polling with custom concurrency limit.""" + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_setup + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + mock_device.return_value = True + + await devices(max_concurrent_devices=1) + + # Should still call device for both hostnames + assert mock_device.call_count == 2 + + @pytest.mark.asyncio + async def test_cli_device_found(self, mock_config_setup): + """Test CLI device polling when hostname is found.""" + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_setup + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + mock_device.return_value = True + + await cli_device("device1") + + # Should call device once for the found hostname + assert mock_device.call_count == 1 + assert mock_device.call_args[0][0].hostname == "device1" + + @pytest.mark.asyncio + async def test_cli_device_not_found(self, mock_config_setup): + """Test CLI device polling when hostname is not found.""" + with patch("switchmap.poller.poll.ConfigPoller") as mock_config: + mock_config.return_value = mock_config_setup + with patch( + "switchmap.poller.poll.device", new_callable=AsyncMock + ) as mock_device: + + await cli_device("nonexistent_device") + + # Should not call device for non-existent hostname + assert mock_device.call_count == 0 + + @pytest.mark.asyncio + async def test_device_with_skip_file(self, mock_poll_meta): + """Test device processing when skip file exists.""" + with patch("switchmap.poller.poll.files.skip_file") as mock_skip_file: + mock_skip_file.return_value = "/path/to/skip/file" + with patch("switchmap.poller.poll.os.path.isfile") as mock_isfile: + mock_isfile.return_value = True + with patch("switchmap.poller.poll.log.log2debug") as mock_log: + # Create mock semaphore and session + mock_semaphore = asyncio.Semaphore(1) + mock_session = MagicMock() + + result = await device( + mock_poll_meta, mock_semaphore, mock_session + ) + + # Should return False when skip file exists + assert result is False + mock_log.assert_called() + + @pytest.mark.asyncio + async def test_device_invalid_hostname(self): + """Test device processing with invalid hostname.""" + mock_semaphore = asyncio.Semaphore(1) + mock_session = MagicMock() + + # Test with None hostname + poll_meta = _META(zone="zone1", hostname=None, config=MagicMock()) + result = await device(poll_meta, mock_semaphore, mock_session) + assert result is False + + # Test with "none" hostname + poll_meta = _META(zone="zone1", hostname="none", config=MagicMock()) + result = await device(poll_meta, mock_semaphore, mock_session) + assert result is False + + @pytest.mark.asyncio + async def test_device_snmp_failure(self, mock_poll_meta): + """Test device processing when SNMP initialization fails.""" + mock_skip_file_path = "/path/to/skip/file" + with patch("switchmap.poller.poll.files.skip_file") as mock_skip_file: + mock_skip_file.return_value = mock_skip_file_path + with patch("switchmap.poller.poll.os.path.isfile") as mock_isfile: + mock_isfile.return_value = False + with patch( + "switchmap.poller.poll.poller.Poll" + ) as mock_poll_cls: + mock_poll_instance = AsyncMock() + mock_poll_instance.initialize_snmp.return_value = False + mock_poll_cls.return_value = mock_poll_instance + + mock_semaphore = asyncio.Semaphore(1) + mock_session = MagicMock() + + result = await device( + mock_poll_meta, mock_semaphore, mock_session + ) + + # Should return False when SNMP initialization fails + assert result is False + + @pytest.mark.asyncio + async def test_device_successful_poll_no_post(self, mock_poll_meta): + """Test successful device polling without posting data.""" + mock_skip_file_path = "/path/to/skip/file" + with patch("switchmap.poller.poll.files.skip_file") as mock_skip_file: + mock_skip_file.return_value = mock_skip_file_path + with patch("switchmap.poller.poll.os.path.isfile") as mock_isfile: + mock_isfile.return_value = False + with patch( + "switchmap.poller.poll.poller.Poll" + ) as mock_poll_cls: + mock_poll_instance = AsyncMock() + mock_poll_instance.initialize_snmp.return_value = True + mock_poll_instance.query.return_value = {"test": "data"} + mock_poll_cls.return_value = mock_poll_instance + + mock_semaphore = asyncio.Semaphore(1) + mock_session = MagicMock() + + with patch( + "switchmap.poller.poll.udevice.Device" + ) as mock_device_class: + mock_device_instance = MagicMock() + mock_device_instance.process.return_value = { + "misc": {}, + "test": "data", + } + mock_device_class.return_value = mock_device_instance + + result = await device( + mock_poll_meta, + mock_semaphore, + mock_session, + post=False, + ) + + # Should return True for successful poll + assert result is True + # Should create Device instance and call process + mock_device_class.assert_called_once() + mock_device_instance.process.assert_called_once() diff --git a/tests/switchmap_/poller/test_poll.py b/tests/switchmap_/poller/test_poll.py deleted file mode 100644 index 83f109ac3..000000000 --- a/tests/switchmap_/poller/test_poll.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -"""Test the poller poll module.""" - -import os -import sys -import unittest -from unittest.mock import patch, MagicMock, call -from multiprocessing import ProcessError, TimeoutError -from switchmap.poller.poll import devices, device, cli_device, _META - -# Try to create a working PYTHONPATH -EXEC_DIR = os.path.dirname(os.path.realpath(__file__)) -ROOT_DIR = os.path.abspath( - os.path.join( - os.path.abspath( - os.path.join(EXEC_DIR, os.pardir) # Move up to 'poller' - ), - os.pardir, # Move up to 'switchmap_' - ) -) -_EXPECTED = "{0}tests{0}switchmap_{0}poller".format(os.sep) - -if EXEC_DIR.endswith(_EXPECTED): - # Prepend the root directory to the Python path - sys.path.insert(0, ROOT_DIR) -else: - print( - f'This script is not installed in the "{_EXPECTED}" directory. ' - "Please fix." - ) - sys.exit(2) - - -class TestPollModule(unittest.TestCase): - """Test cases for the poll module functionality.""" - - def setUp(self): - """Set up the test environment.""" - self.mock_config_instance = MagicMock() - self.mock_zone = MagicMock() - self.mock_zone.name = "zone1" - self.mock_zone.hostnames = ["device1", "device2"] - self.mock_config_instance.zones.return_value = [self.mock_zone] - self.mock_config_instance.agent_subprocesses.return_value = 2 - - def test_devices_without_multiprocessing(self): - """Test device processing without multiprocessing.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.device") as mock_device: - devices(multiprocessing=False) - expected_calls = [ - call( - _META( - zone="zone1", - hostname="device1", - config=self.mock_config_instance, - ) - ), - call( - _META( - zone="zone1", - hostname="device2", - config=self.mock_config_instance, - ) - ), - ] - mock_device.assert_has_calls(expected_calls) - - def test_devices_with_multiprocessing(self): - """Test device processing with multiprocessing.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.Pool") as mock_pool: - mock_pool_instance = MagicMock() - mock_pool.return_value.__enter__.return_value = ( - mock_pool_instance - ) - - devices(multiprocessing=True) - - mock_pool.assert_called_once_with(processes=2) - expected_args = [ - _META( - zone="zone1", - hostname="device1", - config=self.mock_config_instance, - ), - _META( - zone="zone1", - hostname="device2", - config=self.mock_config_instance, - ), - ] - mock_pool_instance.map.assert_called_once_with( - device, expected_args - ) - - def test_devices_with_multiprocessing_pool_error(self): - """Test error handling when pool creation fails.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.Pool") as mock_pool: - mock_pool.side_effect = OSError("Failed to create pool") - with self.assertRaises(OSError): - devices(multiprocessing=True) - - def test_devices_with_multiprocessing_worker_error(self): - """Test error handling when a worker process fails.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.Pool") as mock_pool: - mock_pool_instance = MagicMock() - mock_pool.return_value.__enter__.return_value = ( - mock_pool_instance - ) - mock_pool_instance.map.side_effect = RuntimeError( - "Worker process failed" - ) - - with self.assertRaises((RuntimeError, ProcessError)): - devices(multiprocessing=True) - - def test_devices_with_multiprocessing_timeout(self): - """Test handling of worker process timeout in multiprocessing mode.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.Pool") as mock_pool: - mock_pool_instance = MagicMock() - mock_pool.return_value.__enter__.return_value = ( - mock_pool_instance - ) - mock_pool_instance.map.side_effect = TimeoutError( - "Worker timeout" - ) - - with self.assertRaises(TimeoutError): - devices(multiprocessing=True) - - @patch("switchmap.poller.poll.files.skip_file") - @patch("switchmap.poller.poll.os.path.isfile") - @patch("switchmap.poller.poll.poller.Poll") - @patch("switchmap.poller.poll.rest.post") - def test_device_with_skip_file( - self, mock_rest_post, mock_poll, mock_isfile, mock_skip_file - ): - """Test device processing when skip file exists.""" - mock_skip_file.return_value = "/path/to/skip/file" - mock_isfile.return_value = True - - with patch("switchmap.poller.poll.log.log2debug") as mock_log: - poll_meta = _META(zone="zone1", hostname="device1", config=None) - device(poll_meta) - mock_log.assert_called_once_with(1041, unittest.mock.ANY) - mock_poll.assert_not_called() - mock_rest_post.assert_not_called() - - @patch("switchmap.poller.poll.files.skip_file") - @patch("switchmap.poller.poll.os.path.isfile") - @patch("switchmap.poller.poll.poller.Poll") - @patch("switchmap.poller.poll.rest.post") - def test_device_with_invalid_snmp_data( - self, mock_rest_post, mock_poll, mock_isfile, mock_skip_file - ): - """Test device processing with invalid SNMP data.""" - mock_skip_file.return_value = "/path/to/skip/file" - mock_isfile.return_value = False - mock_poll_instance = MagicMock() - mock_poll_instance.query.return_value = None - mock_poll.return_value = mock_poll_instance - - with patch("switchmap.poller.poll.log.log2debug") as mock_log: - poll_meta = _META(zone="zone1", hostname="device1", config=None) - device(poll_meta) - mock_log.assert_called_once_with(1025, unittest.mock.ANY) - mock_rest_post.assert_not_called() - - def test_cli_device_not_found(self): - """Test CLI device handling when device not found.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.log.log2see") as mock_log: - cli_device("unknown-device") - mock_log.assert_called_once_with( - 1036, "No hostname unknown-device found in configuration" - ) - - def test_cli_device_found(self): - """Test CLI device handling when device is found.""" - with patch("switchmap.poller.poll.ConfigPoller") as mock_config: - mock_config.return_value = self.mock_config_instance - with patch("switchmap.poller.poll.device") as mock_device: - cli_device("device1") - expected_meta = _META( - zone="zone1", - hostname="device1", - config=self.mock_config_instance, - ) - mock_device.assert_called_once_with(expected_meta, post=False) - - -if __name__ == "__main__": - unittest.main()