diff --git a/napalm_base/base.py b/napalm_base/base.py index bef7cfa1..7e321845 100644 --- a/napalm_base/base.py +++ b/napalm_base/base.py @@ -16,20 +16,76 @@ from __future__ import print_function from __future__ import unicode_literals -# std libs +# Import stdlib import sys +import logging +import functools -# local modules +# Import NAPALM Base import napalm_base.exceptions import napalm_base.helpers - import napalm_base.constants as c - from napalm_base import validate +log = logging.getLogger(__file__) + + +def _raise_napalm_error(ERROR_MAP, meth): + ''' + Wrap a method and raise the method indicated + in the ERROR_MAP hashmap. + If there's no binding found, this will raise the initial exception. + ''' + @functools.wraps(meth) + def fun(*args, **kwargs): + try: + return meth(*args, **kwargs) + except Exception as error: + err_name = error.__class__.__name__ + err_mod = error.__class__.__module__ + err_full_name = '{}.{}'.format(err_mod, err_name) + log.error('Raised {}'.format(err_full_name), exc_info=True) + if err_full_name in ERROR_MAP: + err_class = ERROR_MAP[err_full_name] + log.info('Raising {} instead'.format(err_class.__name__)) + err_obj = err_class(error) + err_obj.original_exc = error + raise err_obj, None, sys.exc_info()[2] + elif err_name not in dir(napalm_base.exceptions) and \ + err_name not in __builtins__.keys(): + log.debug('Didnt catch that, raising UncaughtException.') + err_msg = ( + 'NAPALM didn\'t catch this exception. Please, fill a bugfix on ' + 'https://github.com/napalm-automation/napalm/issues\n' + 'Don\'t forget to include this traceback.' + ) + err_obj = napalm_base.exceptions.UncaughtException(err_msg) + err_obj.original_exc = error + raise err_obj, None, sys.exc_info()[2] + # Raise everything else, using the original class. + raise + return fun + + +class _NAPALMErrorCatcherMeta(type): + ''' + Metaclass to wrap the driver methods using the + _raise_napalm_error function. + ''' + def __new__(cls, name, bases, dct): + log.info('Setting up metaclass') + ERROR_MAP = dct.get('ERROR_MAP', {}) + for meth in dct: + if not meth.startswith('_') and hasattr(dct[meth], '__call__'): + log.debug('Wrapping {}'.format(meth)) + dct[meth] = _raise_napalm_error(ERROR_MAP, dct[meth]) + return type.__new__(cls, name, bases, dct) + class NetworkDriver(object): + __metaclass__ = _NAPALMErrorCatcherMeta + def __init__(self, hostname, username, password, timeout=60, optional_args=None): """ This is the base class you have to inherit from when writing your own Network Driver to diff --git a/napalm_base/exceptions.py b/napalm_base/exceptions.py index 5a142288..0eefa53b 100644 --- a/napalm_base/exceptions.py +++ b/napalm_base/exceptions.py @@ -115,3 +115,7 @@ class TemplateRenderException(Exception): class ValidationException(Exception): pass + + +class UncaughtException(Exception): + pass