diff --git a/docker_utils/api_config.ini b/docker_utils/api_config.ini index 02213546..0b52d404 100644 --- a/docker_utils/api_config.ini +++ b/docker_utils/api_config.ini @@ -16,6 +16,7 @@ password = vhost = / flush_every = 100 flush_interval = 10 +tz = US/Eastern [Database] db_type = sqlite diff --git a/multiscanner/__init__.py b/multiscanner/__init__.py index 1ec3f674..89606184 100644 --- a/multiscanner/__init__.py +++ b/multiscanner/__init__.py @@ -4,12 +4,12 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .config import ( # noqa F401 - PY3, MS_WD, CONFIG, MODULESDIR + MSConfigParser, MS_WD, PY3, config_init, update_ms_config, update_ms_config_file ) from .ms import ( # noqa F401 - config_init, multiscan, parse_reports, _ModuleInterface, - _GlobalModuleInterface, _Thread, _run_module, _main + multiscan, parse_reports, _ModuleInterface, + _GlobalModuleInterface, _Thread, _run_module, _main, _get_main_modules ) from .version import __version__ # noqa F401 diff --git a/multiscanner/analytics/ssdeep_analytics.py b/multiscanner/analytics/ssdeep_analytics.py index fb211007..ec5338fc 100644 --- a/multiscanner/analytics/ssdeep_analytics.py +++ b/multiscanner/analytics/ssdeep_analytics.py @@ -22,7 +22,6 @@ ''' import argparse -import configparser import json import logging import sys @@ -37,20 +36,13 @@ ssdeep = False -from multiscanner import CONFIG as MS_CONFIG -from multiscanner.common import utils from multiscanner.storage import storage class SSDeepAnalytic: def __init__(self): - storage_conf = utils.get_config_path(MS_CONFIG, 'storage') - config_object = configparser.ConfigParser() - config_object.optionxform = str - config_object.read(storage_conf) - conf = utils.parse_config(config_object) - storage_handler = storage.StorageHandler(configfile=storage_conf) + storage_handler = storage.StorageHandler() es_handler = storage_handler.load_required_module('ElasticSearchStorage') if not es_handler: @@ -59,7 +51,7 @@ def __init__(self): # probably not ideal... self.es = es_handler.es - self.index = conf['ElasticSearchStorage']['index'] + self.index = es_handler.index self.doc_type = '_doc' def ssdeep_compare(self): diff --git a/multiscanner/common/dir_monitor.py b/multiscanner/common/dir_monitor.py index d555a659..5c57d662 100755 --- a/multiscanner/common/dir_monitor.py +++ b/multiscanner/common/dir_monitor.py @@ -19,9 +19,8 @@ from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer -from multiscanner import CONFIG as MS_CONFIG from multiscanner import multiscan, parse_reports -from multiscanner.common import utils +from multiscanner import config as msconf from multiscanner.storage import storage logger = logging.getLogger(__name__) @@ -81,8 +80,7 @@ def start_observer(directory, work_queue, recursive=False): def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, exit_signal): filelist = [] time_stamp = None - storage_conf = utils.get_config_path(config, 'storage') - storage_handler = storage.StorageHandler(configfile=storage_conf) + storage_handler = storage.StorageHandler() while not exit_signal.value: time.sleep(1) try: @@ -102,7 +100,7 @@ def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, e else: continue - resultlist = multiscan(filelist, configfile=config) + resultlist = multiscan(filelist, config=config) results = parse_reports(resultlist, python=True) if delete: for file_name in results: @@ -118,13 +116,16 @@ def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, e def _main(): args = _parse_args() + if args.config != msconf.CONFIG_FILEPATH: + msconf.update_ms_config_file(args.config) + work_queue = multiprocessing.Queue() exit_signal = multiprocessing.Value('b') exit_signal.value = False observer = start_observer(args.Directory, work_queue, args.recursive) ms_process = multiprocessing.Process( target=multiscanner_process, - args=(work_queue, args.config, args.batch, args.seconds, args.delete, exit_signal)) + args=(work_queue, msconf.MS_CONFIG, args.batch, args.seconds, args.delete, exit_signal)) ms_process.start() try: while True: @@ -141,7 +142,7 @@ def _main(): def _parse_args(): parser = argparse.ArgumentParser(description='Monitor a directory and submit new files to MultiScanner') parser.add_argument("-c", "--config", help="The config file to use", required=False, - default=MS_CONFIG) + default=msconf.CONFIG_FILEPATH) parser.add_argument("-s", "--seconds", help="The number of seconds to wait for additional files", required=False, default=120, type=int) parser.add_argument("-b", "--batch", help="The max number of files per batch", required=False, diff --git a/multiscanner/common/pdf_generator/__init__.py b/multiscanner/common/pdf_generator/__init__.py index f0bc130c..4604aded 100644 --- a/multiscanner/common/pdf_generator/__init__.py +++ b/multiscanner/common/pdf_generator/__init__.py @@ -1,13 +1,13 @@ from __future__ import (division, absolute_import, with_statement, print_function, unicode_literals) -import json import os from reportlab.lib import colors, units from reportlab.platypus import TableStyle from multiscanner.common.pdf_generator import generic_pdf +from multiscanner import config as msconf def create_pdf_document(DIR, report): @@ -15,13 +15,12 @@ def create_pdf_document(DIR, report): Method to create a PDF report based of a multiscanner JSON report. Args: - DIR: Represents the a directory containing the 'pdf_config.json' file. + DIR: Represents the a directory containing the 'pdf_config.ini' file. report: A JSON object. ''' - with open(os.path.join(os.path.split(DIR)[0], 'pdf_config.json')) as data_file: - pdf_components = json.load(data_file) - + pdf_config = os.path.join(DIR, 'pdf_config.ini') + pdf_components = msconf.read_config(pdf_config).get_section('pdf') gen_pdf = generic_pdf.GenericPDF(pdf_components) notice = [] diff --git a/multiscanner/common/utils.py b/multiscanner/common/utils.py index d804710d..59e75769 100644 --- a/multiscanner/common/utils.py +++ b/multiscanner/common/utils.py @@ -3,14 +3,11 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import (absolute_import, division, unicode_literals, with_statement) -import ast -import configparser import imp import logging import os -import sys -from multiscanner.config import PY3 +from six import PY3 logger = logging.getLogger(__name__) @@ -76,42 +73,6 @@ def convert_encoding(data, encoding='UTF-8', errors='replace'): return data -def parse_config(config_object): - """Take a config object and returns it as a dictionary""" - return_var = {} - for section in config_object.sections(): - section_dict = dict(config_object.items(section)) - for key in section_dict: - try: - section_dict[key] = ast.literal_eval(section_dict[key]) - except Exception as e: - logger.debug(e) - return_var[section] = section_dict - return return_var - - -def get_config_path(config_file, component): - """Gets the location of the config file for the given multiscanner component - from the multiscanner config file - - Components: - storage - api - web""" - conf = configparser.ConfigParser() - conf.read(config_file) - conf = parse_config(conf) - try: - return conf['main']['%s-config' % component] - except KeyError: - logger.error( - "Couldn't find '{}-config' value in 'main' section " - "of config file. Have you run 'python multiscanner.py init'?" - .format(component) - ) - sys.exit() - - def dirname(path): """OS independent version of os.path.dirname""" split = path.split('/') @@ -134,7 +95,7 @@ def basename(path): return split[-1] -def parseDir(directory, recursive=False, exclude=['__init__']): +def parse_dir(directory, recursive=False, exclude=['__init__']): """ Returns a list of files in a directory. @@ -148,7 +109,7 @@ def parseDir(directory, recursive=False, exclude=['__init__']): item = os.path.join(directory, item) if os.path.isdir(item): if recursive: - filelist.extend(parseDir(item, recursive)) + filelist.extend(parse_dir(item, recursive)) else: continue else: @@ -162,7 +123,7 @@ def parseDir(directory, recursive=False, exclude=['__init__']): return filelist -def parseFileList(FileList, recursive=False): +def parse_file_list(FileList, recursive=False): """ Takes a list of files and directories and returns a list of files. @@ -173,7 +134,7 @@ def parseFileList(FileList, recursive=False): filelist = [] for item in FileList: if os.path.isdir(item): - filelist.extend(parseDir(item, recursive)) + filelist.extend(parse_dir(item, recursive)) elif os.path.isfile(item): if not PY3: filelist.append(item.decode('utf8')) diff --git a/multiscanner/config.py b/multiscanner/config.py index 831b00bf..69ed354b 100644 --- a/multiscanner/config.py +++ b/multiscanner/config.py @@ -1,11 +1,15 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +import ast +import codecs +import configparser import logging import os import sys -from six import PY3 # noqa F401 +from six import PY3, iteritems # noqa F401 +from multiscanner.common.utils import parse_dir logger = logging.getLogger(__name__) @@ -17,7 +21,46 @@ MS_WD = os.path.dirname(os.path.abspath(__file__)) # The directory where the modules are kept -MODULESDIR = os.path.join(MS_WD, 'modules') +MODULES_DIR = os.path.join(MS_WD, 'modules') + +# The default config file +CONFIG_FILEPATH = None + +# Main MultiScanner config, as a ConfigParser object +MS_CONFIG = None + +# The dictionary of modules and whether they're enabled or not +MODULE_LIST = None + + +class MSConfigParser(configparser.ConfigParser): + def __init__(self, *args, **kwargs): + super(MSConfigParser, self).__init__(*args, **kwargs) + self.optionxform = str # Preserve case + + def __getitem__(self, key): + value = super(MSConfigParser, self).__getitem__(key) + return _convert_to_literal(value) + + def get(self, *args, **kwargs): + value = super(MSConfigParser, self).get(*args, **kwargs) + return _convert_to_literal(value) + + def get_section(self, section_name): + section = self.items(section_name) + return {k: _convert_to_literal(v) for k, v in section} + + +def _convert_to_literal(value): + """Attempts to convert value to a Python literal if possible.""" + try: + return ast.literal_eval(value) + except (SyntaxError, ValueError) as e: + # Ignore if config value isn't convertible to a Python literal + pass + except Exception as e: + logger.debug(e) + return value def get_configuration_paths(): @@ -30,7 +73,7 @@ def get_configuration_paths(): ] -def determine_configuration_path(filepath): +def determine_configuration_path(filepath=None): if filepath: return filepath @@ -40,6 +83,7 @@ def determine_configuration_path(filepath): for config_path in config_paths: if os.path.exists(config_path): config_file = config_path + break if not config_file: # If the local storage folder doesn't exist, we create it. @@ -51,5 +95,260 @@ def determine_configuration_path(filepath): return config_file -# The default config file -CONFIG = determine_configuration_path(None) +CONFIG_FILEPATH = determine_configuration_path() + + +def parse_config(config_object): + """Converts a ConfigParser object to a dictionary""" + return_var = {} + for section in config_object.sections(): + section_dict = dict(config_object.items(section)) + for key in section_dict: + try: + section_dict[key] = ast.literal_eval(section_dict[key]) + except (SyntaxError, ValueError) as e: + # Ignore if config value isn't convertible to a Python literal + pass + except Exception as e: + logger.debug(e) + return_var[section] = section_dict + return return_var + + +def dict_to_config(dictionary): + """Converts a dictionary to a ConfigParser object""" + config = MSConfigParser() + + for name, section in dictionary.items(): + config.add_section(name) + for key in section.keys(): + config.set(name, key, str(section[key])) + return config + + +def write_config(config_object, config_file, default_config): + """Write the default configuration to the given config file + + config_object - the ConfigParser object + config_file - the filename of the config file + default_config - dictionary of section names and values to set within this configuration + """ + for section_name, section in default_config.items(): + if section_name not in config_object.sections(): + config_object.add_section(section_name) + for key in section: + config_object.set(section_name, key, str(default_config[section_name][key])) + with codecs.open(config_file, 'w', 'utf-8') as conffile: + config_object.write(conffile) + + +def read_config(config_file, default_config=None): + """Parse a config file into a ConfigParser object + + Can optionally set a default configuration by providing 'section_name' and + 'default_config' arguments. + + config_file - the filename of the config file + default_config - dictionary of section names and values to set within this configuration + """ + config_object = MSConfigParser() + config_object.read(config_file) + if default_config is not None: + contains_sections = set(default_config.keys()).issubset(config_object.sections()) + if not contains_sections or not os.path.isfile(config_file): + # Write default config + write_config(config_object, config_file, default_config) + return config_object + + +MS_CONFIG = read_config(CONFIG_FILEPATH) + + +def get_config_path(component, config=None): + """Gets the location of the config file for the given MultiScanner component + from the MultiScanner config + + Components: + storage + api + web + + component - component to get the path for + config - dictionary or ConfigParser object containing MultiScanner config + """ + if config is None: + config = MS_CONFIG + + try: + return config['main']['%s-config' % component] + except KeyError: + logger.error( + "Couldn't find '{}-config' value in 'main' section " + "of config file. Have you run 'python multiscanner.py init'?" + .format(component) + ) + sys.exit() + + +def get_modules(): + """Returns a dictionary with module names as keys. Values contain a boolean + denoting whether or not they are enabled in the config, and the folder + containing the module. + """ + files = parse_dir(MODULES_DIR, recursive=True, exclude=["__init__"]) + + global MS_CONFIG + modules = {} + # for module in module_names: + for f in files: + folder = os.path.dirname(f) + filename = os.path.splitext(os.path.basename(f)) + + if filename[1] == '.py': + module = filename[0] + # Always run these modules + if module == 'filemeta' or module == 'ssdeeper': + modules[module] = [True, folder] + continue + try: + modules[module] = [MS_CONFIG[module]['ENABLED'], folder] + except KeyError as e: + logger.debug(e) + modules[module] = [False, folder] + return modules + + +MODULE_LIST = get_modules() + + +def update_ms_config(config): + """Update global config object. + + config - the ConfigParser object or dictionary to replace MS_CONFIG with + """ + global MS_CONFIG + if isinstance(config, MSConfigParser): + MS_CONFIG = config + else: + MS_CONFIG = dict_to_config(config) + + +def update_ms_config_file(config_file): + """Update config globals to a different file than the default. + + config_file - the file to be assigned to CONFIG_FILEPATH and read into MS_CONFIG + """ + global CONFIG_FILEPATH, MS_CONFIG + CONFIG_FILEPATH = config_file + MS_CONFIG = read_config(CONFIG_FILEPATH) + + +def update_paths_in_config(conf, filepath): + """Rewrite config values containing paths to point to a new multiscanner config directory. + """ + base_dir = os.path.split(filepath)[0] + if 'storage-config' in conf: + conf['storage-config'] = os.path.join(base_dir, 'storage.ini') + if 'api-config' in conf: + conf['api-config'] = os.path.join(base_dir, 'api_config.ini') + if 'web-config' in conf: + conf['web-config'] = os.path.join(base_dir, 'web_config.ini') + if 'ruledir' in conf: + conf['ruledir'] = os.path.join(base_dir, "etc", "yarasigs") + if 'key' in conf: + conf['key'] = os.path.join(base_dir, 'etc', 'id_rsa') + if 'hash_list' in conf: + conf['hash_list'] = os.path.join(base_dir, 'etc', 'nsrl', 'hash_list') + if 'offsets' in conf: + conf['offsets'] = os.path.join(base_dir, 'etc', 'nsrl', 'offsets') + + +def config_init(filepath, sections, overwrite=False): + """ + Creates a new config file at filepath + + filepath - The config file to create + sections - Dictionary mapping section names to the Python module containing its DEFAULTCONF + overwrite - Whether to overwrite the config file at filepath, if it already exists + """ + + config = MSConfigParser() + + if overwrite or not os.path.isfile(filepath): + return reset_config(sections, config, filepath) + else: + config.read(filepath) + write_missing_config(sections, config, filepath) + return config + + +def write_missing_config(sections, config_object, filepath): + """ + Write in default config for modules not in config file. Returns True if config was written, False if not. + + config_object - The ConfigParser object + filepath - The path to the config file + sections - Dictionary mapping section names to the Python module containing its DEFAULTCONF + """ + ConfNeedsWrite = False + keys = list(sections.keys()) + keys.sort() + for section_name in keys: + if section_name in config_object: + continue + try: + conf = sections[section_name].DEFAULTCONF + except Exception as e: + logger.warning(e) + continue + ConfNeedsWrite = True + update_paths_in_config(conf, filepath) + config_object.add_section(section_name) + for key in conf: + config_object.set(section_name, key, str(conf[key])) + + if ConfNeedsWrite: + with codecs.open(filepath, 'w', 'utf-8') as f: + config_object.write(f) + return True + return False + + +def reset_config(sections, config, filepath=None): + """ + Reset specific sections of a config file to their factory defaults. + + sections - Dictionary mapping section names to the Python module containing its DEFAULTCONF + config - ConfigParser object in which to store config + filepath - Path to the config file + + Returns: + The ConfigParser object that was written to the file. + """ + if not filepath: + filepath = CONFIG_FILEPATH + + # Read in the old config to preserve any sections not being reset + if os.path.isfile(filepath): + config.read(filepath) + + logger.info('Rewriting config at {}...'.format(filepath)) + + keys = list(sections.keys()) + keys.sort() + for section_name in keys: + try: + conf = sections[section_name].DEFAULTCONF + except Exception as e: + logger.warning(e) + continue + + update_paths_in_config(conf, filepath) + if not config.has_section(section_name): + config.add_section(section_name) + for key in conf: + config.set(section_name, key, str(conf[key])) + + with codecs.open(filepath, 'w', 'utf-8') as f: + config.write(f) + return config diff --git a/multiscanner/distributed/api.py b/multiscanner/distributed/api.py index 2e013424..43a5595f 100755 --- a/multiscanner/distributed/api.py +++ b/multiscanner/distributed/api.py @@ -43,8 +43,6 @@ TODO: * Add doc strings to functions ''' -import codecs -import configparser import hashlib import json import logging @@ -68,15 +66,15 @@ from sqlalchemy.exc import SQLAlchemyError -# TODO: Why do we need to parseDir(MODULEDIR) multiple times? -from multiscanner import MODULESDIR, MS_WD, multiscan, parse_reports, CONFIG as MS_CONFIG -from multiscanner.common import utils, pdf_generator, stix2_generator -from multiscanner.config import PY3 +import multiscanner as ms +from multiscanner.analytics.ssdeep_analytics import SSDeepAnalytic +from multiscanner.common import pdf_generator, stix2_generator +from multiscanner.config import PY3, get_config_path, read_config +from multiscanner.distributed.celery_worker import multiscanner_celery, ssdeep_compare_celery from multiscanner.storage import StorageHandler from multiscanner.storage import sql_driver as database from multiscanner.storage.storage import StorageNotLoadedError - TASK_NOT_FOUND = {'Message': 'No task with that ID found!'} INVALID_REQUEST = {'Message': 'Invalid request parameters'} TASK_STILL_PROCESSING = {'Message': 'Task still pending'} @@ -117,39 +115,22 @@ def default(self, obj): app = Flask(__name__) app.json_encoder = CustomJSONEncoder -api_config_object = configparser.ConfigParser() -api_config_object.optionxform = str -# TODO: Why does this multiscanner.common instead of just common? -api_config_file = utils.get_config_path(MS_CONFIG, 'api') -api_config_object.read(api_config_file) -if not api_config_object.has_section('api') or not os.path.isfile(api_config_file): - # Write default config - api_config_object.add_section('api') - for key in DEFAULTCONF: - api_config_object.set('api', key, str(DEFAULTCONF[key])) - conffile = codecs.open(api_config_file, 'w', 'utf-8') - api_config_object.write(conffile) - conffile.close() -api_config = utils.parse_config(api_config_object) - -# TODO: fix this mess -# Needs api_config in order to function properly -from multiscanner.distributed.celery_worker import multiscanner_celery, ssdeep_compare_celery -from multiscanner.analytics.ssdeep_analytics import SSDeepAnalytic +api_config_file = get_config_path('api') +api_config = read_config(api_config_file, {'api': DEFAULTCONF, 'Database': database.Database.DEFAULTCONF}) -db = database.Database(config=api_config.get('Database')) +db = database.Database(config=api_config.get_section('Database'), regenconfig=False) # To run under Apache, we need to set up the DB outside of __main__ # Sleep and retry until database connection is successful try: # wait this many seconds between tries - db_sleep_time = int(api_config_object.get('Database', 'retry_time')) -except (configparser.NoSectionError, configparser.NoOptionError) as e: + db_sleep_time = int(api_config['Database']['retry_time']) +except KeyError as e: logger.debug(e) db_sleep_time = database.Database.DEFAULTCONF['retry_time'] try: # max number of times to retry - db_num_retries = int(api_config_object.get('Database', 'retry_num')) -except (configparser.NoSectionError, configparser.NoOptionError) as e: + db_num_retries = int(api_config['Database']['retry_num']) +except KeyError as e: logger.debug(e) db_num_retries = database.Database.DEFAULTCONF['retry_num'] @@ -168,16 +149,9 @@ def default(self, obj): logger.error("Retrying...") time.sleep(db_sleep_time) -storage_conf = utils.get_config_path(MS_CONFIG, 'storage') -storage_handler = StorageHandler(configfile=storage_conf) +storage_handler = StorageHandler() handler = storage_handler.load_required_module('ElasticSearchStorage') -ms_config_object = configparser.ConfigParser() -ms_config_object.optionxform = str -ms_configfile = MS_CONFIG -ms_config_object.read(ms_configfile) -ms_config = utils.parse_config(ms_config_object) - try: DISTRIBUTED = api_config['api']['distributed'] except KeyError as e: @@ -224,21 +198,18 @@ def multiscanner_process(work_queue, exit_signal): else: continue - filelist = [item[0] for item in metadata_list] - # modulelist = [item[5] for item in metadata_list] - resultlist = multiscan( - filelist, configfile=MS_CONFIG - # module_list - ) - results = parse_reports(resultlist, python=True) - - scan_time = datetime.now().isoformat() + for item in metadata_list: + filelist = [item[0]] + module_list = item[5] + resultlist = ms.multiscan( + filelist, + config=ms.config.MS_CONFIG, + module_list=module_list + ) + results = ms.parse_reports(resultlist, python=True) - if delete_after_scan: - for file_name in results: - os.remove(file_name) + scan_time = datetime.now().isoformat() - for item in metadata_list: # Use the original filename as the index instead of the full path results[item[1]] = results[item[0]] del results[item[0]] @@ -254,12 +225,15 @@ def multiscanner_process(work_queue, exit_signal): task_status='Complete', timestamp=scan_time, ) - metadata_list = [] - storage_handler.store(results, wait=False) + storage_handler.store(results, wait=False) + + if delete_after_scan: + for file_name in results: + os.remove(file_name) - filelist = [] time_stamp = None + metadata_list = [] storage_handler.close() @@ -290,24 +264,10 @@ def modules(): Return a list of module names available for MultiScanner to use, and whether or not they are enabled in the config. ''' - files = utils.parseDir(MODULESDIR, True) - filenames = [os.path.splitext(os.path.basename(f)) for f in files] - module_names = [m[0] for m in filenames if m[1] == '.py'] - - ms_config = configparser.ConfigParser() - ms_config.optionxform = str - ms_config.read(MS_CONFIG) - modules = {} - for module in module_names: - try: - is_enabled = ms_config.get(module, 'ENABLED') - if is_enabled == "True": - modules[module] = True - else: - modules[module] = False - except (configparser.NoSectionError, configparser.NoOptionError) as e: - logger.debug(e) - return jsonify(modules) + modlist = {name: mod[0] for (name, mod) in ms.config.MODULE_LIST.items()} + del modlist['filemeta'] + del modlist['ssdeeper'] + return jsonify(modlist) @app.route('/api/v2/tasks', methods=['GET']) @@ -415,7 +375,7 @@ def save_hashed_filename(f, zipped=False): # TODO: should we check if the file is already there # and skip this step if it is? file_path = os.path.join(api_config['api']['upload_folder'], f_name) - full_path = os.path.join(MS_WD, file_path) + full_path = os.path.join(ms.MS_WD, file_path) if zipped: shutil.copy2(f.name, full_path) else: @@ -475,7 +435,8 @@ def import_task(file_): def queue_task(original_filename, f_name, full_path, metadata, rescan=False, - queue_name='medium_tasks', priority=5, routing_key='tasks.medium'): + module_list=None, queue_name='medium_tasks', priority=5, + routing_key='tasks.medium'): ''' Queue up a single new task, for a single non-archive file. ''' @@ -492,14 +453,15 @@ def queue_task(original_filename, f_name, full_path, metadata, rescan=False, if DISTRIBUTED: # Publish the task to Celery + tmp_config = ms.config.parse_config(ms.config.MS_CONFIG) multiscanner_celery.apply_async( args=(full_path, original_filename, task_id, f_name, metadata), - kwargs=dict(config=MS_CONFIG), + kwargs=dict(config=tmp_config, module_list=module_list), **{'queue': queue_name, 'priority': priority, 'routing_key': routing_key} ) else: # Put the task on the queue - work_queue.put((full_path, original_filename, task_id, f_name, metadata)) + work_queue.put((full_path, original_filename, task_id, f_name, metadata, module_list)) return task_id @@ -546,6 +508,7 @@ def create_task(): task_id_list = [] extract_dir = None rescan = False + modules = None priority = 5 routing_key = 'tasks.medium' queue_name = 'medium_tasks' @@ -558,13 +521,10 @@ def create_task(): elif request.form[key] == 'rescan': rescan = True elif key == 'modules': - module_names = request.form[key] - files = utils.parseDir(MODULESDIR, True) - modules = [] - for f in files: - split = os.path.splitext(os.path.basename(f)) - if split[0] in module_names and split[1] == '.py': - modules.append(f) + module_names = request.form[key].split(',') + modules = list(set(module_names).intersection(ms.config.MODULE_LIST.keys())) + modules.append('filemeta') + modules.append('ssdeeper') elif key == 'archive-analyze' and request.form[key] == 'true': extract_dir = api_config['api']['upload_folder'] if not os.path.isdir(extract_dir): @@ -608,7 +568,9 @@ def create_task(): unzipped_file = open(os.path.join(extract_dir, uzfile)) f_name, full_path = save_hashed_filename(unzipped_file, True) tid = queue_task(uzfile, f_name, full_path, metadata, - rescan=rescan, queue_name=queue_name, priority=priority, routing_key=routing_key) + rescan=rescan, module_list=modules, + queue_name=queue_name, priority=priority, + routing_key=routing_key) task_id_list.append(tid) except RuntimeError as e: msg = 'ERROR: Failed to extract ' + str(file_) + ' - ' + str(e) @@ -625,7 +587,9 @@ def create_task(): unrarred_file = open(os.path.join(extract_dir, urfile)) f_name, full_path = save_hashed_filename(unrarred_file, True) tid = queue_task(urfile, f_name, full_path, metadata, - rescan=rescan, queue_name=queue_name, priority=priority, routing_key=routing_key) + rescan=rescan, module_list=modules, + queue_name=queue_name, priority=priority, + routing_key=routing_key) task_id_list.append(tid) except RuntimeError as e: msg = "ERROR: Failed to extract " + str(file_) + ' - ' + str(e) @@ -638,7 +602,9 @@ def create_task(): # File was not an archive to extract f_name, full_path = save_hashed_filename(file_) tid = queue_task(original_filename, f_name, full_path, metadata, - rescan=rescan, queue_name=queue_name, priority=priority, routing_key=routing_key) + rescan=rescan, module_list=modules, + queue_name=queue_name, priority=priority, + routing_key=routing_key) task_id_list = [tid] except SQLAlchemyError: abort(HTTP_BAD_REQUEST, {'Message': 'Could not queue task(s) due backend error'}) @@ -887,8 +853,9 @@ def get_maec_report(task_id): # Get the MAEC report from Cuckoo try: + cuckoo_report = ms.config.MS_CONFIG.get('Cuckoo', 'API URL', fallback='') maec_report = requests.get( - '{}/v1/tasks/report/{}/maec'.format(ms_config.get('Cuckoo', {}).get('API URL', ''), cuckoo_task_id) + '{}/v1/tasks/report/{}/maec'.format(cuckoo_report, cuckoo_task_id) ) except Exception as e: logger.warning('No MAEC report found for that task! - {}'.format(e)) @@ -1132,7 +1099,8 @@ def generate_pdf_report(task_id): if report_dict == TASK_STILL_PROCESSING: return make_response(jsonify(TASK_STILL_PROCESSING), HTTP_STILL_PROCESSING) - pdf = pdf_generator.create_pdf_document(MS_CONFIG, report_dict) + config_dir = os.path.split(ms.config.CONFIG_FILEPATH)[0] + pdf = pdf_generator.create_pdf_document(config_dir, report_dict) response = make_response(pdf) response.headers['Content-Type'] = 'application/pdf' response.headers['Content-Disposition'] = 'attachment; filename=%s.pdf' % task_id diff --git a/multiscanner/distributed/celery_worker.py b/multiscanner/distributed/celery_worker.py index 5b257071..3ea6d58f 100644 --- a/multiscanner/distributed/celery_worker.py +++ b/multiscanner/distributed/celery_worker.py @@ -4,9 +4,6 @@ from the utils/ directory. ''' -import codecs -import configparser -import os from datetime import datetime from socket import gethostname @@ -16,9 +13,8 @@ from kombu import Exchange, Queue -from multiscanner import CONFIG as MS_CONFIG from multiscanner import multiscan, parse_reports -from multiscanner.common import utils +from multiscanner import config as msconf from multiscanner.storage import elasticsearch_storage, storage from multiscanner.storage import sql_driver as database from multiscanner.analytics.ssdeep_analytics import SSDeepAnalytic @@ -37,41 +33,27 @@ 'tz': 'US/Eastern', } -config_object = configparser.ConfigParser() -config_object.optionxform = str -configfile = utils.get_config_path(MS_CONFIG, 'api') -config_object.read(configfile) - -if not config_object.has_section('celery') or not os.path.isfile(configfile): - # Write default config - config_object.add_section('celery') - for key in DEFAULTCONF: - config_object.set('celery', key, str(DEFAULTCONF[key])) - conffile = codecs.open(configfile, 'w', 'utf-8') - config_object.write(conffile) - conffile.close() -config = utils.parse_config(config_object) -api_config = config.get('api') -worker_config = config.get('celery') -db_config = config.get('Database') - -storage_config_object = configparser.ConfigParser() -storage_config_object.optionxform = str -storage_configfile = utils.get_config_path(MS_CONFIG, 'storage') -storage_config_object.read(storage_configfile) -config = utils.parse_config(storage_config_object) -es_storage_config = config.get('ElasticSearchStorage') +configfile = msconf.get_config_path('api') +config = msconf.read_config(configfile, {'celery': DEFAULTCONF, 'Database': database.Database.DEFAULTCONF}) +db_config = dict(config.items('Database')) + +storage_configfile = msconf.get_config_path('storage') +storage_config = msconf.read_config(storage_configfile) +try: + es_storage_config = storage_config['ElasticSearchStorage'] +except KeyError: + es_storage_config = {} default_exchange = Exchange('celery', type='direct') app = Celery(broker='{0}://{1}:{2}@{3}/{4}'.format( - worker_config.get('protocol'), - worker_config.get('user'), - worker_config.get('password'), - worker_config.get('host'), - worker_config.get('vhost'), + config.get('celery', 'protocol'), + config.get('celery', 'user'), + config.get('celery', 'password'), + config.get('celery', 'host'), + config.get('celery', 'vhost'), )) -app.conf.timezone = worker_config.get('tz') +app.conf.timezone = config.get('celery', 'tz') app.conf.task_queues = [ Queue('low_tasks', default_exchange, routing_key='tasks.low', queue_arguments={'x-max-priority': 10}), Queue('medium_tasks', default_exchange, routing_key='tasks.medium', queue_arguments={'x-max-priority': 10}), @@ -100,8 +82,8 @@ def setup_periodic_tasks(sender, **kwargs): sender.add_periodic_task( crontab(hour=3, minute=0), metricbeat_rollover_celery.s(), - args=(es_storage_config.get('metricbeat_rollover_days')), - kwargs=dict(config=MS_CONFIG), + args=(es_storage_config.get('metricbeat_rollover_days'), 7), + kwargs=dict(config=msconf.MS_CONFIG), **{ 'queue': 'low_tasks', 'routing_key': 'tasks.low', @@ -158,7 +140,7 @@ def on_success(self, retval, task_id, args, kwargs): @app.task(base=MultiScannerTask) def multiscanner_celery(file_, original_filename, task_id, file_hash, metadata, - config=MS_CONFIG, module_list=None): + config=None, module_list=None): ''' Queue up multiscanner tasks @@ -170,12 +152,16 @@ def multiscanner_celery(file_, original_filename, task_id, file_hash, metadata, logger.info('\n\n{}{}Got file: {}.\nOriginal filename: {}.\n'.format('=' * 48, '\n', file_hash, original_filename)) # Get the storage config - storage_conf = utils.get_config_path(config, 'storage') + if config is None: + config = msconf.MS_CONFIG + elif isinstance(config, dict): + config = msconf.dict_to_config(config) + storage_conf = msconf.get_config_path('storage', config) storage_handler = storage.StorageHandler(configfile=storage_conf) resultlist = multiscan( [file_], - configfile=config, + config=config, module_list=module_list ) results = parse_reports(resultlist, python=True) @@ -184,26 +170,22 @@ def multiscanner_celery(file_, original_filename, task_id, file_hash, metadata, # Get the Scan Config that the task was run with and # add it to the task metadata - scan_config_object = configparser.ConfigParser() - scan_config_object.optionxform = str - scan_config_object.read(config) - full_conf = utils.parse_config(scan_config_object) sub_conf = {} - # Count number of modules enabled out of total possible + # Count number of modules enabled out of total possible (-1 for main) # and add it to the Scan Metadata total_enabled = 0 - total_modules = len(full_conf.keys()) + total_modules = len(config.keys()) - 1 # Get the count of modules enabled from the module_list # if it exists, else count via the config if module_list: total_enabled = len(module_list) else: - for key in full_conf: + for key in config: if key == 'main': continue sub_conf[key] = {} - sub_conf[key]['ENABLED'] = full_conf[key]['ENABLED'] + sub_conf[key]['ENABLED'] = config[key]['ENABLED'] if sub_conf[key]['ENABLED'] is True: total_enabled += 1 @@ -249,7 +231,7 @@ def ssdeep_compare_celery(): @app.task() -def metricbeat_rollover_celery(days, config=MS_CONFIG): +def metricbeat_rollover_celery(days): ''' Clean up old Elastic Beats indices ''' @@ -263,7 +245,7 @@ def metricbeat_rollover_celery(days, config=MS_CONFIG): return if not days: - days = es_storage_config.get('metricbeat_rollover_days') + days = es_storage_config.get('metricbeat_rollover_days', 7) if not days: raise NameError("name 'days' is not defined, check storage.ini for 'metricbeat_rollover_days' setting") diff --git a/multiscanner/distributed/distributed_worker.py b/multiscanner/distributed/distributed_worker.py index 6fe300fc..b8b66010 100755 --- a/multiscanner/distributed/distributed_worker.py +++ b/multiscanner/distributed/distributed_worker.py @@ -5,8 +5,6 @@ from __future__ import (absolute_import, division, unicode_literals, with_statement) import argparse -import codecs -import configparser import logging import multiprocessing import os @@ -18,7 +16,7 @@ standard_library.install_aliases() from multiscanner import multiscan, parse_reports -from multiscanner.common import utils +from multiscanner.config import get_config_path, read_config, update_ms_config_file from multiscanner.storage import storage @@ -33,7 +31,7 @@ def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, exit_signal): filelist = [] time_stamp = None - storage_conf = utils.get_config_path(config, 'storage') + storage_conf = get_config_path('storage', config) storage_handler = storage.StorageHandler(configfile=storage_conf) while not exit_signal.value: time.sleep(1) @@ -54,7 +52,7 @@ def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, e else: continue - resultlist = multiscan(filelist, configfile=config) + resultlist = multiscan(filelist, config=config) results = parse_reports(resultlist, python=True) if delete: for file_name in results: @@ -68,19 +66,13 @@ def multiscanner_process(work_queue, config, batch_size, wait_seconds, delete, e storage_handler.close() -def _read_conf(file_path): - conf = configparser.ConfigParser() - conf.optionxform = str - with codecs.open(file_path, 'r', encoding='utf-8') as fp: - conf.readfp(fp) - return utils.parse_config(conf) - - def _main(): args = _parse_args() # Pull config options - conf = _read_conf(args.config) + conf = read_config(args.config) multiscanner_config = conf['worker']['multiscanner_config'] + update_ms_config_file(multiscanner_config) + config = read_config(multiscanner_config) # Start worker task work_queue = multiprocessing.Queue() @@ -88,7 +80,7 @@ def _main(): exit_signal.value = False ms_process = multiprocessing.Process( target=multiscanner_process, - args=(work_queue, multiscanner_config, args.delete, exit_signal)) + args=(work_queue, config, args.delete, exit_signal)) ms_process.start() # Start message pickup task diff --git a/multiscanner/modules/antivirus/AVGScan.py b/multiscanner/modules/antivirus/AVGScan.py index 50efe45d..09b81406 100644 --- a/multiscanner/modules/antivirus/AVGScan.py +++ b/multiscanner/modules/antivirus/AVGScan.py @@ -7,7 +7,7 @@ import subprocess import re -from multiscanner.config import CONFIG +import multiscanner as ms from multiscanner.common.utils import list2cmdline, sshexec, SSH subprocess.list2cmdline = list2cmdline @@ -21,7 +21,7 @@ # Hostname, port, username HOST = ("MultiScanner", 22, "User") # SSH Key -KEY = os.path.join(os.path.split(CONFIG)[0], 'etc', 'id_rsa') +KEY = os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'id_rsa') # Replacement path for SSH connections PATHREPLACE = "X:\\" DEFAULTCONF = { diff --git a/multiscanner/modules/antivirus/MSEScan.py b/multiscanner/modules/antivirus/MSEScan.py index 672aa95c..2e263a7e 100644 --- a/multiscanner/modules/antivirus/MSEScan.py +++ b/multiscanner/modules/antivirus/MSEScan.py @@ -6,7 +6,7 @@ import os import subprocess -from multiscanner.config import CONFIG +import multiscanner as ms from multiscanner.common.utils import list2cmdline, sshconnect, SSH subprocess.list2cmdline = list2cmdline @@ -18,7 +18,7 @@ NAME = "Microsoft Security Essentials" # These are overwritten by the config file # SSH Key -KEY = os.path.join(os.path.split(CONFIG)[0], 'etc', 'id_rsa') +KEY = os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'id_rsa') # Replacement path for SSH connections PATHREPLACE = "X:\\" HOST = ("MultiScanner", 22, "User") diff --git a/multiscanner/modules/antivirus/McAfeeScan.py b/multiscanner/modules/antivirus/McAfeeScan.py index b4f0b6f4..3baa953d 100644 --- a/multiscanner/modules/antivirus/McAfeeScan.py +++ b/multiscanner/modules/antivirus/McAfeeScan.py @@ -7,7 +7,7 @@ import subprocess import re -from multiscanner.config import CONFIG +import multiscanner as ms from multiscanner.common.utils import list2cmdline, sshexec, SSH subprocess.list2cmdline = list2cmdline @@ -19,7 +19,7 @@ NAME = "McAfee" # These are overwritten by the config file # SSH Key -KEY = os.path.join(os.path.split(CONFIG)[0], 'etc', 'id_rsa') +KEY = os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'id_rsa') # Replacement path for SSH connections PATHREPLACE = "X:\\" HOST = ("MultiScanner", 22, "User") diff --git a/multiscanner/modules/database/NSRL.py b/multiscanner/modules/database/NSRL.py index 9d8381f6..05c5f242 100755 --- a/multiscanner/modules/database/NSRL.py +++ b/multiscanner/modules/database/NSRL.py @@ -8,7 +8,7 @@ import os import struct -from multiscanner.config import CONFIG +import multiscanner as ms __author__ = "Drew Bonasera" __license__ = "MPL 2.0" @@ -19,8 +19,8 @@ REQUIRES = ["filemeta"] DEFAULTCONF = { - 'hash_list': os.path.join(os.path.split(CONFIG)[0], 'etc', 'nsrl', 'hash_list'), - 'offsets': os.path.join(os.path.split(CONFIG)[0], 'etc', 'nsrl', 'offsets'), + 'hash_list': os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'nsrl', 'hash_list'), + 'offsets': os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'nsrl', 'offsets'), 'ENABLED': True } diff --git a/multiscanner/modules/machinelearning/EndgameEmber.py b/multiscanner/modules/machinelearning/EndgameEmber.py index 554cd316..a5a837d1 100644 --- a/multiscanner/modules/machinelearning/EndgameEmber.py +++ b/multiscanner/modules/machinelearning/EndgameEmber.py @@ -19,7 +19,7 @@ import os from pathlib import Path -from multiscanner import CONFIG +import multiscanner as ms __authors__ = "Patrick Copeland" @@ -30,7 +30,7 @@ REQUIRES = ['libmagic'] DEFAULTCONF = { 'ENABLED': False, - 'path-to-model': os.path.join(os.path.split(CONFIG)[0], 'etc', 'ember', 'ember_model_2017.txt'), + 'path-to-model': os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'ember', 'ember_model_2017.txt'), } LGBM_MODEL = None diff --git a/multiscanner/modules/metadata/ExifToolsScan.py b/multiscanner/modules/metadata/ExifToolsScan.py index 2dd402e8..47aa3b3d 100644 --- a/multiscanner/modules/metadata/ExifToolsScan.py +++ b/multiscanner/modules/metadata/ExifToolsScan.py @@ -8,7 +8,7 @@ import subprocess import re -from multiscanner.config import CONFIG +import multiscanner as ms from multiscanner.common.utils import list2cmdline, sshexec, SSH subprocess.list2cmdline = list2cmdline @@ -20,7 +20,7 @@ NAME = "ExifTool" # These are overwritten by the config file HOST = ("MultiScanner", 22, "User") -KEY = os.path.join(os.path.split(CONFIG)[0], "etc", "id_rsa") +KEY = os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], "etc", "id_rsa") PATHREPLACE = "X:\\" # Entries to be removed from the final results REMOVEENTRY = ["ExifTool Version Number", "File Name", "Directory", "File Modification Date/Time", diff --git a/multiscanner/modules/metadata/TrID.py b/multiscanner/modules/metadata/TrID.py index 30d7e4a3..344bc868 100644 --- a/multiscanner/modules/metadata/TrID.py +++ b/multiscanner/modules/metadata/TrID.py @@ -8,7 +8,7 @@ import subprocess import re -from multiscanner.config import CONFIG +import multiscanner as ms from multiscanner.common.utils import list2cmdline, sshexec, SSH logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ # Hostname, port, username HOST = ("MultiScanner", 22, "User") # SSH Key -KEY = os.path.join(os.path.split(CONFIG)[0], 'etc', 'id_rsa') +KEY = os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], 'etc', 'id_rsa') # Replacement path for SSH connections PATHREPLACE = "X:\\" DEFAULTCONF = { diff --git a/multiscanner/modules/signature/YaraScan.py b/multiscanner/modules/signature/YaraScan.py index e3a03ea0..31ed5f85 100644 --- a/multiscanner/modules/signature/YaraScan.py +++ b/multiscanner/modules/signature/YaraScan.py @@ -8,8 +8,8 @@ import os import time -from multiscanner.config import CONFIG -from multiscanner.common.utils import parseDir +import multiscanner as ms +from multiscanner.common.utils import parse_dir __authors__ = "Nick Beede, Drew Bonasera" @@ -18,7 +18,7 @@ TYPE = "Signature" NAME = "Yara" DEFAULTCONF = { - "ruledir": os.path.join(os.path.split(CONFIG)[0], "etc", "yarasigs"), + "ruledir": os.path.join(os.path.split(ms.config.CONFIG_FILEPATH)[0], "etc", "yarasigs"), "fileextensions": [".yar", ".yara", ".sig"], "ignore-tags": ["TLPRED"], "string-threshold": 30, @@ -51,7 +51,7 @@ def scan(filelist, conf=DEFAULTCONF): ruleset = {} try: - rules = parseDir(ruleDir, recursive=True) + rules = parse_dir(ruleDir, recursive=True) except (OSError, IOError) as e: logger.error('Cannot read files: {}'.format(e.filename)) return None diff --git a/multiscanner/ms.py b/multiscanner/ms.py index e2f2b0ce..03c3803e 100644 --- a/multiscanner/ms.py +++ b/multiscanner/ms.py @@ -7,7 +7,6 @@ import argparse import codecs -import configparser import datetime import json import logging @@ -16,6 +15,7 @@ import random import re import shutil +import six import string import sys import tempfile @@ -29,9 +29,11 @@ from multiscanner.version import __version__ as MS_VERSION from multiscanner.common.utils import (basename, convert_encoding, load_module, - parse_config, parseDir, parseFileList, - queue2list) -from multiscanner.config import PY3, CONFIG, MODULESDIR, determine_configuration_path + parse_file_list, queue2list) +from multiscanner import config as msconf +from multiscanner.config import (MSConfigParser, PY3, config_init, get_config_path, + update_ms_config, update_ms_config_file, + update_paths_in_config, write_missing_config) from multiscanner.storage import storage @@ -39,9 +41,9 @@ DEFAULTCONF = { "copyfilesto": False, "group-types": ["Antivirus"], - "storage-config": CONFIG.replace('config.ini', 'storage.ini'), - "api-config": CONFIG.replace('config.ini', 'api_config.ini'), - "web-config": CONFIG.replace('config.ini', 'web_config.ini'), + "storage-config": msconf.CONFIG_FILEPATH.replace('config.ini', 'storage.ini'), + "api-config": msconf.CONFIG_FILEPATH.replace('config.ini', 'api_config.ini'), + "web-config": msconf.CONFIG_FILEPATH.replace('config.ini', 'web_config.ini'), } logger = logging.getLogger(__name__) @@ -248,49 +250,6 @@ def _run_module(modname, mod, filelist, threadDict, global_module_interface, con logger.debug("{} failed check()".format(modname)) -def _update_DEFAULTCONF(defaultconf, filepath): - if 'storage-config' in defaultconf: - defaultconf['storage-config'] = filepath.replace('config.ini', 'storage.ini') - if 'api-config' in defaultconf: - defaultconf['api-config'] = filepath.replace('config.ini', 'api_config.ini') - if 'web-config' in defaultconf: - defaultconf['web-config'] = filepath.replace('config.ini', 'web_config.ini') - if 'ruledir' in defaultconf: - defaultconf['ruledir'] = os.path.join(os.path.split(filepath)[0], "etc", "yarasigs") - if 'key' in defaultconf: - defaultconf['key'] = os.path.join(os.path.split(filepath)[0], 'etc', 'id_rsa') - if 'hash_list' in defaultconf: - defaultconf['hash_list'] = os.path.join(os.path.split(filepath)[0], 'etc', 'nsrl', 'hash_list') - if 'offsets' in defaultconf: - defaultconf['offsets'] = os.path.join(os.path.split(filepath)[0], 'etc', 'nsrl', 'offsets') - - -def _get_main_config(config_object, filepath=CONFIG): - """ - Reads in config for main script. It will write defaults if not present. - Returns dictionary. - - Config - The config object - filepath - The path to the config file - """ - filepath = determine_configuration_path(filepath) - # Write main defaults if needed - ConfNeedsWrite = False - if 'main' not in config_object.sections(): - ConfNeedsWrite = True - _update_DEFAULTCONF(DEFAULTCONF, filepath) - config_object.add_section('main') - for key in DEFAULTCONF: - config_object.set('main', key, str(DEFAULTCONF[key])) - - if ConfNeedsWrite: - with codecs.open(filepath, 'w', 'utf-8') as f: - config_object.write(f) - - # Read in main config - return parse_config(config_object)['main'] - - def _copy_to_share(filelist, filedic, sharedir): """ Copies files from filelist to a share and populates the filedic. Returns a @@ -323,12 +282,12 @@ def _copy_to_share(filelist, filedic, sharedir): return filelist -def _start_module_threads(filelist, ModuleList, config, global_module_interface): +def _start_module_threads(filelist, module_list, config, global_module_interface): """ Starts each module on the file list in a separate thread. Returns a list of threads filelist - A lists of strings. The strings are files to be scanned - ModuleList - A list of all the modules to be run + module_list - A list of the names of all modules to be run config - The config dictionary global_module_interface - The global module interface to be injected in each module """ @@ -337,145 +296,43 @@ def _start_module_threads(filelist, ModuleList, config, global_module_interface) ThreadDict = {} global_module_interface.run_count += 1 # Starts a thread for each module. - for module in ModuleList: - if module.endswith(".py"): - modname = os.path.basename(module[:-3]) - - # If the module is disabled we don't mess with it further to prevent spamming errors on screen - if modname in config: - if not config[modname].get('ENABLED', True): - continue - - moddir = os.path.dirname(module) - mod = load_module(os.path.basename(module).split('.')[0], [moddir]) - if not mod: - logger.warning("{} not a valid module...".format(module)) + for modname in module_list: + # If the module is disabled we don't mess with it further to prevent spamming errors on screen + if modname in config: + if not config[modname].get('ENABLED', True): continue - conf = None - if modname in config: - if '_load_default' in config or '_load_default' in config[modname]: - try: - conf = mod.DEFAULTCONF - conf.update(config[modname]) - except Exception as e: - logger.warning(e) - conf = config[modname] - # Remove _load_default from config - if '_load_default' in conf: - del conf['_load_default'] - else: - conf = config[modname] - - # Try and read in the default conf if one was not passed - if not conf: - try: - conf = mod.DEFAULTCONF - except Exception as e: - logger.error(e) - thread = _Thread( - target=_run_module, - args=(modname, mod, filelist, ThreadDict, global_module_interface, conf)) - thread.name = modname - thread.setDaemon(True) - ThreadList.append(thread) - ThreadDict[modname] = thread - for thread in ThreadList: - thread.start() - return ThreadList - + # TODO: What if the module isn't specified in the config -def _write_missing_module_configs(ModuleList, config, filepath=CONFIG): - """ - Write in default config for modules not in config file. Returns True if config was written, False if not. - - ModuleList - The list of modules - config - The config object - """ - filepath = determine_configuration_path(filepath) - ConfNeedsWrite = False - ModuleList.sort() - for module in ModuleList: - if module.endswith(".py"): - modname = os.path.basename(module).split('.')[0] - moddir = os.path.dirname(module) - if modname not in config.sections(): - mod = load_module(os.path.basename(module).split('.')[0], [moddir]) - if mod: - try: - conf = mod.DEFAULTCONF - except Exception as e: - logger.warning(e) - continue - ConfNeedsWrite = True - _update_DEFAULTCONF(conf, filepath) - config.add_section(modname) - for key in conf: - config.set(modname, key, str(conf[key])) - - if 'main' not in config.sections(): - ConfNeedsWrite = True - _update_DEFAULTCONF(DEFAULTCONF, filepath) - config.add_section('main') - for key in DEFAULTCONF: - config.set('main', key, str(DEFAULTCONF[key])) - - if ConfNeedsWrite: - with codecs.open(filepath, 'w', 'utf-8') as f: - config.write(f) - return True - return False - - -def _rewrite_config(ModuleList, config, filepath=CONFIG): - """ - Write in default config for all modules. - - ModuleList - The list of modules - config - The config object - """ - filepath = determine_configuration_path(filepath) - logger.info('Rewriting config...') - ModuleList.sort() - for module in ModuleList: - if module.endswith('.py'): - modname = os.path.basename(module).split('.')[0] - moddir = os.path.dirname(module) - mod = load_module(os.path.basename(module).split('.')[0], [moddir]) - if mod: - try: - conf = mod.DEFAULTCONF - except Exception as e: - logger.warning(e) - continue - _update_DEFAULTCONF(conf, filepath) - config.add_section(modname) - for key in conf: - config.set(modname, key, str(conf[key])) - - _update_DEFAULTCONF(DEFAULTCONF, filepath) - config.add_section('main') - for key in DEFAULTCONF: - config.set('main', key, str(DEFAULTCONF[key])) - - with codecs.open(filepath, 'w', 'utf-8') as f: - config.write(f) - - -def config_init(filepath, module_list=parseDir(MODULESDIR, recursive=True, exclude=["__init__"])): - """ - Creates a new config file at filepath + try: + moddir = msconf.MODULE_LIST[modname][1] + except KeyError: + logger.warning(msconf.MODULE_LIST) + logger.warning("{} not a valid module...".format(modname)) + continue - filepath - The config file to create - """ - config = configparser.ConfigParser() - config.optionxform = str + mod = load_module(modname, [moddir]) + if not mod: + logger.warning("{} not a valid module...".format(modname)) + continue + try: + conf = mod.DEFAULTCONF + except Exception as e: + logger.error(e) + conf = {} + if modname in config: + conf.update(config[modname]) + + thread = _Thread( + target=_run_module, + args=(modname, mod, filelist, ThreadDict, global_module_interface, conf)) + thread.name = modname + thread.setDaemon(True) + ThreadList.append(thread) + ThreadDict[modname] = thread - if filepath: - _rewrite_config(module_list, config, filepath) - else: - filepath = determine_configuration_path(filepath) - _rewrite_config(module_list, config, filepath) - logger.info('Configuration file initialized at {}'.format(filepath)) + for thread in ThreadList: + thread.start() + return ThreadList def parse_reports(resultlist, groups=None, ugly=True, includeMetadata=False, python=False): @@ -529,74 +386,34 @@ def parse_reports(resultlist, groups=None, ugly=True, includeMetadata=False, pyt return json.dumps(finaldata, sort_keys=True, separators=(',', ':'), ensure_ascii=False) -def multiscan(Files, recursive=False, configregen=False, configfile=CONFIG, config=None, module_list=None): +def multiscan(Files, config=None, module_list=None): """ The meat and potatoes. Returns the list of module results Files - A list of files and dirs to be scanned - recursive - If true it will search the dirs in Files recursively - configregen - If True a new config file will be created overwriting the old - configfile - What config file to use. Can be None. - config - A dictionary containing the configuration options to be used. - module_list - A list of file paths to be used as modules. Each string should end in .py + config - ConfigParser object containing the configuration options to be used. + module_list - A list of the names of the modules to run on the files. """ # Init some vars - # If recursive is False we don't parse the file list and take it as is. - if recursive: - filelist = parseFileList(Files, recursive=recursive) - else: - filelist = Files + filelist = Files # A list of files in the module dir if module_list is None: - module_list = parseDir(MODULESDIR, recursive=True, exclude=["__init__"]) + module_list = [modname for modname in msconf.MODULE_LIST] # A dictionary used for the copyfileto parameter filedic = {} - # What will be the config file object - config_object = None - - # Read in config - if configfile: - config_object = configparser.ConfigParser() - config_object.optionxform = str - # Regen the config if needed or wanted - if configregen or not os.path.isfile(configfile): - _rewrite_config(module_list, config_object, filepath=configfile) - - config_object.read(configfile) - main_config = _get_main_config(config_object, filepath=configfile) - if config: - file_conf = parse_config(config_object) - for key in config: - if key not in file_conf: - file_conf[key] = config[key] - file_conf[key]['_load_default'] = True - else: - file_conf[key].update(config[key]) - config = file_conf - else: - config = parse_config(config_object) - else: - if config is None: - config = {} - else: - config['_load_default'] = True - if 'main' in config: - main_config = config['main'] - else: - main_config = DEFAULTCONF - # If none of the files existed - if not filelist: - raise ValueError("No valid files") + if config is None: + config = MSConfigParser() + elif isinstance(config, dict): + config = msconf.dict_to_config(config) # Copy files to a share if configured - if "copyfilesto" not in main_config: - main_config["copyfilesto"] = False - if main_config["copyfilesto"]: - if os.path.isdir(main_config["copyfilesto"]): - filelist = _copy_to_share(filelist, filedic, main_config["copyfilesto"]) + copyfilesto = config.get('main', 'copyfilesto', fallback=DEFAULTCONF['copyfilesto']) + if copyfilesto: + if os.path.isdir(copyfilesto): + filelist = _copy_to_share(filelist, filedic, copyfilesto) else: - raise IOError('The copyfilesto dir "' + main_config["copyfilesto"] + '" is not a valid dir') + raise IOError('The copyfilesto dir "{}" is not a valid dir'.format(copyfilesto)) # Create the global module interface global_module_interface = _GlobalModuleInterface() @@ -604,10 +421,6 @@ def multiscan(Files, recursive=False, configregen=False, configfile=CONFIG, conf # Start a thread for each module thread_list = _start_module_threads(filelist, module_list, config, global_module_interface) - # Write the default configure settings for missing ones - if config_object: - _write_missing_module_configs(module_list, config_object, filepath=configfile) - # Warn about spaces in file names for f in filelist: if ' ' in f: @@ -636,7 +449,7 @@ def multiscan(Files, recursive=False, configregen=False, configfile=CONFIG, conf time.sleep(1) # Delete copied files - if main_config["copyfilesto"]: + if copyfilesto: for item in filelist: try: os.remove(item) @@ -678,20 +491,20 @@ def multiscan(Files, recursive=False, configregen=False, configfile=CONFIG, conf from_filename = filedic[base] subscan_list[i] = (file_path, from_filename, module_name) - results.extend(_subscan(subscan_list, config, main_config, module_list, global_module_interface)) + results.extend(_subscan(subscan_list, config, copyfilesto, module_list, global_module_interface)) global_module_interface._cleanup() return results -def _subscan(subscan_list, config, main_config, module_list, global_module_interface): +def _subscan(subscan_list, config, copyfilesto, module_list, global_module_interface): """ Scans files created by modules subscan_list - The result of _get_subscan_list() from the global module interface config - The configuration dictionary - main_config - A dictionary of the configuration for main + copyfilesto - Directory to copy files to; if False files will not be copied module_list - The list of modules global_module_interface - The global module interface """ @@ -740,10 +553,8 @@ def _subscan(subscan_list, config, main_config, module_list, global_module_inter del subscan_list, subfiles_dict # Copy files to a share if configured - if "copyfilesto" not in main_config: - main_config["copyfilesto"] = False - if main_config["copyfilesto"]: - filelist = _copy_to_share(filelist, filedic, main_config["copyfilesto"]) + if copyfilesto: + filelist = _copy_to_share(filelist, filedic, copyfilesto) # Start a thread for each module thread_list = _start_module_threads(filelist, module_list, config, global_module_interface) @@ -770,7 +581,7 @@ def _subscan(subscan_list, config, main_config, module_list, global_module_inter time.sleep(1) # Delete copied files - if main_config["copyfilesto"]: + if copyfilesto: for item in filelist: os.remove(item) @@ -809,7 +620,7 @@ def _subscan(subscan_list, config, main_config, module_list, global_module_inter null, from_filename = file_mapping[from_filename] subscan_list[i] = (file_path, from_filename, module_name) - results.extend(_subscan(subscan_list, config, main_config, module_list, global_module_interface)) + results.extend(_subscan(subscan_list, config, copyfilesto, module_list, global_module_interface)) return results @@ -819,7 +630,7 @@ def _parse_args(): Parses arguments """ # argparse stuff - desc = "MultiScanner v{} - Analyse files against multiple engines" + desc = "MultiScanner v{} - Analyze files against multiple engines" parser = argparse.ArgumentParser(description=desc.format(MS_VERSION)) parser.add_argument("-c", "--config", required=False, default=None, help="The config file to use") @@ -848,12 +659,29 @@ def _parse_args(): parser.add_argument("--resume", action="store_true", help="Read in the report file and continue where we left off") parser.add_argument('Files', nargs='+', - help="Files and Directories to analyse") + help="Files and Directories to analyze") return parser.parse_args() +def _get_main_modules(): + module_list = {} + module_list['main'] = sys.modules[__name__] # current module + for modname, module in sorted(six.iteritems(msconf.MODULE_LIST)): + moddir = module[1] + mod = load_module(modname, [moddir]) + if mod: + module_list[modname] = mod + return module_list + + def _init(args): # Initialize configuration file + if args.config is None: + args.config = msconf.CONFIG_FILEPATH + + # Compile all the sections to go in the config + module_list = _get_main_modules() + if os.path.isfile(args.config): logger.warning('{} already exists, overwriting will destroy changes'.format(args.config)) try: @@ -862,53 +690,45 @@ def _init(args): logger.warning(e) answer = 'N' if answer == 'y': - config_init(args.config) + config = config_init(args.config, module_list, overwrite=True) + update_ms_config(config) # Set global main config + logger.info('Main configuration file initialized at {}'.format(args.config)) else: - logger.info('Checking for missing modules in configuration...') - ModuleList = parseDir(MODULESDIR, recursive=True, exclude=["__init__"]) - config = configparser.ConfigParser() - config.optionxform = str - config.read(args.config) - _write_missing_module_configs(ModuleList, config, filepath=args.config) + logger.info('Checking for missing modules in main configuration...') + config = config_init(args.config, module_list, overwrite=False) + update_ms_config(config) # Set global main config else: - config_init(args.config) + config = config_init(args.config, module_list) + update_ms_config(config) # Set global main config + logger.info('Main configuration file initialized at {}'.format(args.config)) # Init storage - config = configparser.ConfigParser() - config.optionxform = str - config.read(args.config) - config = _get_main_config(config) - if os.path.isfile(config["storage-config"]): - logger.warning('{} already exists, overwriting will destroy changes'.format(config["storage-config"])) + storage_config = get_config_path('storage') + storage_classes = storage._get_storage_classes() + storage_classes['main'] = sys.modules[storage.__name__] + if os.path.isfile(storage_config): + logger.warning('{} already exists, overwriting will destroy changes'.format(storage_config)) try: answer = input('Do you wish to overwrite the configuration file [y/N]:') except EOFError as e: logger.warning(e) answer = 'N' if answer == 'y': - storage.config_init(config["storage-config"], overwrite=True) - logger.info('Storage configuration file initialized at {}'.format(config["storage-config"])) + config_init(storage_config, storage_classes, overwrite=True) + logger.info('Storage configuration file initialized at {}'.format(storage_config)) else: logger.info('Checking for missing modules in storage configuration...') - storage.config_init(config["storage-config"], overwrite=False) + config_init(storage_config, storage_classes, overwrite=False) else: - storage.config_init(config["storage-config"]) - logger.info('Storage configuration file initialized at {}'.format(config["storage-config"])) + config_init(storage_config, storage_classes) + logger.info('Storage configuration file initialized at {}'.format(storage_config)) exit(0) def _main(): - global CONFIG - # Get args args = _parse_args() - # Set config or update locations - if args.config is None: - args.config = CONFIG - else: - CONFIG = args.config - _update_DEFAULTCONF(DEFAULTCONF, CONFIG) # Send all logs to stderr and set verbose if args.debug or args.verbose > 1: @@ -925,12 +745,23 @@ def _main(): logging.basicConfig(format="%(asctime)s [%(module)s] %(levelname)s: %(message)s", stream=sys.stderr, level=log_lvl) - # Checks if user is trying to initialize - if str(args.Files) == "['init']" and not os.path.isfile('init'): + # Check if user is trying to initialize + if args.Files == ['init'] and not os.path.isfile('init'): _init(args) + # Set config or update locations + if args.config is None: + args.config = msconf.CONFIG_FILEPATH + else: + update_ms_config_file(args.config) + update_paths_in_config(DEFAULTCONF, msconf.CONFIG_FILEPATH) + + module_list = _get_main_modules() if not os.path.isfile(args.config): - config_init(args.config) + config_init(args.config, module_list) + else: + # Write the default config settings for any missing modules + write_missing_config(module_list, msconf.MS_CONFIG, msconf.CONFIG_FILEPATH) # Make sure report is not a dir if args.json: @@ -938,7 +769,7 @@ def _main(): sys.exit('ERROR:', args.json, 'is a directory, a file is expected') # Parse the file list - parsedlist = parseFileList(args.Files, recursive=args.recursive) + parsedlist = parse_file_list(args.Files, recursive=args.recursive) # Unzip zip files if asked to if args.extractzips: @@ -990,17 +821,12 @@ def _main(): starttime = str(datetime.datetime.now()) # Run the multiscan - results = multiscan(filelist, configfile=args.config) + results = multiscan(filelist, config=msconf.MS_CONFIG) # We need to read in the config for the parseReports call - config = configparser.ConfigParser() - config.optionxform = str - config.read(args.config) - config = _get_main_config(config) + config = msconf.MS_CONFIG.get_section('main') # Make sure we have a group-types - if "group-types" not in config: - config["group-types"] = [] - elif not config["group-types"]: + if "group-types" not in config or not config["group-types"]: config["group-types"] = [] # Add in script metadata diff --git a/multiscanner/storage/__init__.py b/multiscanner/storage/__init__.py index 30410fc0..40ce5430 100644 --- a/multiscanner/storage/__init__.py +++ b/multiscanner/storage/__init__.py @@ -1,3 +1,3 @@ -from .storage import config_init, Storage, StorageHandler +from .storage import Storage, StorageHandler __all__ = ['config_init', 'Storage', 'StorageHandler', ] diff --git a/multiscanner/storage/sql_driver.py b/multiscanner/storage/sql_driver.py index 2852ded9..23546447 100644 --- a/multiscanner/storage/sql_driver.py +++ b/multiscanner/storage/sql_driver.py @@ -1,8 +1,6 @@ #!/usr/bin/env python from __future__ import print_function -import codecs -import configparser import json import logging import os @@ -18,10 +16,9 @@ from sqlalchemy.pool import QueuePool from sqlalchemy_utils import create_database, database_exists -from multiscanner import CONFIG +from multiscanner.config import MSConfigParser, get_config_path, reset_config - -CONFIG_FILE = os.path.join(os.path.split(CONFIG)[0], "api_config.ini") +CONFIG_FILEPATH = get_config_path('api') Base = declarative_base() logger = logging.getLogger(__name__) @@ -71,53 +68,39 @@ class Database(object): 'strategy': 'threadlocal' } - def __init__(self, config=None, configfile=CONFIG_FILE, regenconfig=False): + def __init__(self, config=None, configfile=None, regenconfig=False): self.db_connection_string = None self.db_engine = None # Configuration parsing - config_parser = configparser.ConfigParser() - config_parser.optionxform = str + config_parser = MSConfigParser() + if configfile is None: + configfile = CONFIG_FILEPATH + section_name = self.__class__.__name__ # (re)generate conf file if necessary if regenconfig or not os.path.isfile(configfile): - self._rewrite_config(config_parser, configfile, config) + sections = {section_name: self} + reset_config(sections, config_parser, configfile) + # now read in and parse the conf file config_parser.read(configfile) # If we didn't regen the config file in the above check, it's possible # that the file is missing our DB settings... - if not config_parser.has_section(self.__class__.__name__): - self._rewrite_config(config_parser, configfile, config) - config_parser.read(configfile) + if not config_parser.has_section(section_name): + sections = {section_name: self} + reset_config(sections, config_parser, configfile) # If configuration was specified, use what was stored in the config file # as a base and then override specific settings as contained in the user's # config. This allows the user to specify ONLY the config settings they want to # override - config_from_file = dict(config_parser.items(self.__class__.__name__)) + config_from_file = dict(config_parser.items(section_name)) if config: for key_ in config: config_from_file[key_] = config[key_] self.config = config_from_file - def _rewrite_config(self, config_parser, configfile, usr_override_config): - """ - Regenerates the Database-specific part of the API config file - """ - if os.path.isfile(configfile): - # Read in the old config - config_parser.read(configfile) - if not config_parser.has_section(self.__class__.__name__): - config_parser.add_section(self.__class__.__name__) - if not usr_override_config: - usr_override_config = self.DEFAULTCONF - # Update config - for key_ in usr_override_config: - config_parser.set(self.__class__.__name__, key_, str(usr_override_config[key_])) - - with codecs.open(configfile, 'w', 'utf-8') as conffile: - config_parser.write(conffile) - def init_db(self): """ Initializes the database connection based on the configuration parameters @@ -126,7 +109,7 @@ def init_db(self): db_name = self.config['db_name'] if db_type == 'sqlite': # we can ignore host, username, password, etc - sql_lite_db_path = os.path.join(os.path.split(CONFIG)[0], db_name) + sql_lite_db_path = os.path.join(os.path.split(CONFIG_FILEPATH)[0], db_name) self.db_connection_string = 'sqlite:///{}'.format(sql_lite_db_path) else: username = self.config['username'] diff --git a/multiscanner/storage/storage.py b/multiscanner/storage/storage.py index b90090ca..93670e71 100644 --- a/multiscanner/storage/storage.py +++ b/multiscanner/storage/storage.py @@ -4,8 +4,6 @@ from __future__ import (absolute_import, division, unicode_literals, with_statement) -import codecs -import configparser import inspect import logging import os @@ -17,8 +15,8 @@ standard_library.install_aliases() -from multiscanner.config import CONFIG as MS_CONFIG from multiscanner.common import utils +from multiscanner.config import MSConfigParser, get_config_path DEFAULTCONF = { @@ -89,56 +87,32 @@ def teardown(self): class StorageHandler(object): - def __init__(self, configfile=MS_CONFIG, config=None, configregen=False): + def __init__(self, configfile=None, config=None): self.storage_lock = threading.Lock() self.storage_counter = ThreadCounter() # Load all storage classes storage_classes = _get_storage_classes() + if configfile is None: + configfile = get_config_path('storage') # Read in config - if configfile: - configfile = utils.get_config_path(MS_CONFIG, 'storage') - config_object = configparser.ConfigParser() - config_object.optionxform = str - # Regen the config if needed or wanted - if configregen or not os.path.isfile(configfile): - _write_main_config(config_object) - _rewrite_config(storage_classes, config_object, configfile) - - config_object.read(configfile) - if config: - file_conf = utils.parse_config(config_object) - for key in config: - if key not in file_conf: - file_conf[key] = config[key] - file_conf[key]['_load_default'] = True - else: - file_conf[key].update(config[key]) - config = file_conf - else: - config = utils.parse_config(config_object) - else: - if config is None: - config = {} - for storage_name in storage_classes: - config[storage_name] = {} - config['_load_default'] = True - - self.sleep_time = config.get('main', {}).get('retry_time', DEFAULTCONF['retry_time']) - self.num_retries = config.get('main', {}).get('retry_num', DEFAULTCONF['retry_num']) + config_object = MSConfigParser() + config_object.read(configfile) + if config: + for key in config: + if key not in config_object: + config_object[key] = config[key] + else: + config_object[key].update(config[key]) + config = config_object + self.sleep_time = config.get('main', 'retry_time', fallback=DEFAULTCONF['retry_time']) + self.num_retries = config.get('main', 'retry_num', fallback=DEFAULTCONF['retry_num']) # Set the config inside of the storage classes for storage_name in storage_classes: if storage_name in config: - if '_load_default' in config or '_load_default' in config[storage_name]: - # Remove _load_default from config - if '_load_default' in config[storage_name]: - del config[storage_name]['_load_default'] - # Update the default storage config - storage_classes[storage_name].config = storage_classes[storage_name].DEFAULTCONF - storage_classes[storage_name].config.update(config[storage_name]) - else: - storage_classes[storage_name].config = config[storage_name] + storage_classes[storage_name].config = storage_classes[storage_name].DEFAULTCONF + storage_classes[storage_name].config.update(config[storage_name]) self.storage_classes = storage_classes self.loaded_storage = {} @@ -264,79 +238,13 @@ def is_done(self, wait=False): return self.storage_counter.is_done() -def config_init(filepath, overwrite=False, storage_classes=None): - if storage_classes is None: - storage_classes = _get_storage_classes() - config_object = configparser.ConfigParser() - config_object.optionxform = str - if overwrite or not os.path.isfile(filepath): - _write_main_config(config_object) - _rewrite_config(storage_classes, config_object, filepath) - else: - config_object.read(filepath) - _write_main_config(config_object) - _write_missing_config(config_object, filepath, storage_classes=storage_classes) - - -def _write_main_config(config_object): - if not config_object.has_section('main'): - # Write default config - config_object.add_section('main') - for key in DEFAULTCONF: - config_object.set('main', key, str(DEFAULTCONF[key])) - - -def _rewrite_config(storage_classes, config_object, filepath): - keys = list(storage_classes.keys()) - keys.sort() - for class_name in keys: - conf = storage_classes[class_name].DEFAULTCONF - config_object.add_section(class_name) - for key in conf: - config_object.set(class_name, key, str(conf[key])) - - with codecs.open(filepath, 'w', 'utf-8') as f: - config_object.write(f) - - -def _write_missing_config(config_object, filepath, storage_classes=None): - """ - Write in default config for modules not in config file. Returns True if config was written, False if not. - - config_object - The config object - filepath - The path to the config file - storage_classes - The dictionary object from _get_storage_classes. If None we call _get_storage_classes() - """ - if storage_classes is None: - storage_classes = _get_storage_classes() - ConfNeedsWrite = False - keys = list(storage_classes.keys()) - keys.sort() - for module in keys: - if module in config_object: - continue - try: - conf = module.DEFAULTCONF - except Exception as e: - logger.warning(e) - continue - ConfNeedsWrite = True - config_object.add_section(module) - for key in conf: - config_object.set(module, key, str(conf[key])) - - if ConfNeedsWrite: - with codecs.open(filepath, 'w', 'utf-8') as f: - config_object.write(f) - return True - return False - - def _get_storage_classes(dir_path=STORAGE_DIR): storage_classes = {} - dir_list = utils.parseDir(dir_path, recursive=True) + dir_list = utils.parse_dir(dir_path, recursive=True) dir_list.remove(os.path.join(dir_path, 'storage.py')) # dir_list.remove(os.path.join(dir_path, '__init__.py')) + # sql_driver is not configurable in storage.ini, and is used by the api + # and celery workers rather than by the StorageHandler dir_list.remove(os.path.join(dir_path, 'sql_driver.py')) for filename in dir_list: if filename.endswith('.py'): diff --git a/multiscanner/tests/module_tests/Metadefender/test_metadefender_module.py b/multiscanner/tests/module_tests/Metadefender/test_metadefender_module.py index 4c2d95fa..13d6fe1f 100644 --- a/multiscanner/tests/module_tests/Metadefender/test_metadefender_module.py +++ b/multiscanner/tests/module_tests/Metadefender/test_metadefender_module.py @@ -20,6 +20,8 @@ MSG_SERVER_UNAVAILABLE = 'Server unavailable, try again later' FILE_200_COMPLETE_REPORT = 'retrieval_responses/200_found_complete.json' FILE_200_INCOMPLETE_REPORT = 'retrieval_responses/200_found_incomplete.json' +MDF_GET = 'multiscanner.modules.antivirus.Metadefender.requests.get' +MDF_POST = 'multiscanner.modules.antivirus.Metadefender.requests.post' class MockResponse(object): @@ -146,7 +148,7 @@ def create_conf_short_timeout(self): # possible responses to sample submission requests # --------------------------------------------------------------------- - @mock.patch('Metadefender.requests.post', side_effect=mocked_requests_post_sample_submitted) + @mock.patch(MDF_POST, side_effect=mocked_requests_post_sample_submitted) def test_submit_sample_success(self, mock_get): ''' Tests Metadefender._submit_sample()'s handling of a successful response from @@ -158,7 +160,7 @@ def test_submit_sample_success(self, mock_get): self.assertEqual(submit_resp['error'], None) self.assertEqual(submit_resp['scan_id'], generate_scan_id(RANDOM_INPUT_FILES[0])) - @mock.patch('Metadefender.requests.post', side_effect=mocked_requests_post_sample_failed_w_msg) + @mock.patch(MDF_POST, side_effect=mocked_requests_post_sample_failed_w_msg) def test_submit_sample_fail_unavailable(self, mock_get): ''' Tests Metadefender._submit_sample()'s handling of a submission that fails due to @@ -170,7 +172,7 @@ def test_submit_sample_fail_unavailable(self, mock_get): self.assertEqual(submit_resp['error'], MSG_SERVER_UNAVAILABLE) self.assertEqual(submit_resp['scan_id'], None) - @mock.patch('Metadefender.requests.post', side_effect=mocked_requests_post_sample_failed_no_msg) + @mock.patch(MDF_POST, side_effect=mocked_requests_post_sample_failed_no_msg) def test_submit_sample_fail_unavailable_no_msg(self, mock_get): ''' Tests Metadefender._submit_sample()'s handling of a submission that fails due to @@ -186,7 +188,7 @@ def test_submit_sample_fail_unavailable_no_msg(self, mock_get): # This section tests the logic for parsing Metadefender's responses # to requests for analysis results # --------------------------------------------------------------------- - @mock.patch('Metadefender.requests.get', side_effect=mocked_requests_get_sample_200_success) + @mock.patch(MDF_GET, side_effect=mocked_requests_get_sample_200_success) def test_get_results_200_success(self, mock_get): ''' Tests Metadefender._parse_scan_result()'s handling of a complete @@ -218,7 +220,7 @@ def test_get_results_200_success(self, mock_get): else: self.fail('Unexpected Engine: %s' % engine_name) - @mock.patch('Metadefender.requests.get', side_effect=mocked_requests_get_sample_200_not_found) + @mock.patch(MDF_GET, side_effect=mocked_requests_get_sample_200_not_found) def test_get_results_200_not_found(self, mock_get): ''' Tests Metadefender._parse_scan_result()'s handling of a 200 response @@ -235,7 +237,7 @@ def test_get_results_200_not_found(self, mock_get): if len(engine_results) != 0: self.fail('Engine result list should be empty') - @mock.patch('Metadefender.requests.get', side_effect=mocked_requests_get_sample_200_in_progress) + @mock.patch(MDF_GET, side_effect=mocked_requests_get_sample_200_in_progress) def test_get_results_200_succes_in_progress(self, mock_get): ''' Tests Metadefender._parse_scan_result()'s handling of a 200 response @@ -256,8 +258,8 @@ def test_get_results_200_succes_in_progress(self, mock_get): # --------------------------------------------------------------------- # This section tests the entire scan() method # --------------------------------------------------------------------- - @mock.patch('Metadefender.requests.get', side_effect=mocked_requests_get_sample_200_success) - @mock.patch('Metadefender.requests.post', side_effect=mocked_requests_post_sample_submitted) + @mock.patch(MDF_GET, side_effect=mocked_requests_get_sample_200_success) + @mock.patch(MDF_POST, side_effect=mocked_requests_post_sample_submitted) def test_scan_complete_success(self, mock_post, mock_get): ''' Test for a perfect scan. No submission errors, no retrieval errors @@ -269,8 +271,8 @@ def test_scan_complete_success(self, mock_post, mock_get): for scan_res in resultlist: self.assertEqual(scan_res[1]['overall_status'], Metadefender.STATUS_SUCCESS) - @mock.patch('Metadefender.requests.get', side_effect=mocked_requests_get_sample_200_in_progress) - @mock.patch('Metadefender.requests.post', side_effect=mocked_requests_post_sample_submitted) + @mock.patch(MDF_GET, side_effect=mocked_requests_get_sample_200_in_progress) + @mock.patch(MDF_POST, side_effect=mocked_requests_post_sample_submitted) def test_scan_timeout_scan_in_progress(self, mock_post, mock_get): ''' Test for a scan where analysis time exceeds timeout period diff --git a/multiscanner/tests/modules/test_1.py b/multiscanner/tests/modules/test_1.py index 88569fcb..c4699dd5 100644 --- a/multiscanner/tests/modules/test_1.py +++ b/multiscanner/tests/modules/test_1.py @@ -3,13 +3,16 @@ """ TYPE = "Test" NAME = "test_1" +DEFAULTCONF = { + 'ENABLED': True +} -def check(): +def check(conf=DEFAULTCONF): return True -def scan(filelist): +def scan(filelist, conf=DEFAULTCONF): results = [] for fname in filelist: diff --git a/multiscanner/tests/modules/test_2.py b/multiscanner/tests/modules/test_2.py index 57cd7611..71e16663 100644 --- a/multiscanner/tests/modules/test_2.py +++ b/multiscanner/tests/modules/test_2.py @@ -4,7 +4,11 @@ TYPE = "Test" NAME = "test_2" REQUIRES = ["test_1"] -DEFAULTCONF = {'a': 1, 'b': 2} +DEFAULTCONF = { + 'ENABLED': True, + 'a': 1, + 'b': 2 +} def check(conf=DEFAULTCONF): diff --git a/multiscanner/tests/modules/test_subscan.py b/multiscanner/tests/modules/test_subscan.py index 8786999a..4dccff11 100644 --- a/multiscanner/tests/modules/test_subscan.py +++ b/multiscanner/tests/modules/test_subscan.py @@ -3,16 +3,19 @@ """ TYPE = "Test" NAME = "test_subscan" +DEFAULTCONF = { + 'ENABLED': True +} # Overwritten in multiscanner multiscanner = None -def check(): +def check(conf=DEFAULTCONF): return True -def scan(filelist): +def scan(filelist, conf=DEFAULTCONF): results = [] for f in filelist: diff --git a/multiscanner/tests/test_celery_worker.py b/multiscanner/tests/test_celery_worker.py index bab8174e..fa76db24 100644 --- a/multiscanner/tests/test_celery_worker.py +++ b/multiscanner/tests/test_celery_worker.py @@ -5,7 +5,6 @@ import mock import multiscanner -from multiscanner.common import utils from multiscanner.distributed import celery_worker from multiscanner.storage.sql_driver import Database @@ -20,12 +19,10 @@ # Get a subset of simple modules to run in testing # the celery worker -MODULE_LIST = utils.parseDir(multiscanner.MODULESDIR, recursive=True) -DESIRED_MODULES = [ - 'filemeta.py', - 'ssdeep.py' +MODULES_TO_TEST = [ + 'filemeta', + 'ssdeep' ] -MODULES_TO_TEST = [i for e in DESIRED_MODULES for i in MODULE_LIST if e in i] TEST_DB_PATH = os.path.join(CWD, 'testing.db') @@ -41,7 +38,7 @@ with open(TEST_FULL_PATH, 'r') as f: TEST_FILE_HASH = hashlib.sha256(f.read().encode('utf-8')).hexdigest() TEST_METADATA = {} -TEST_CONFIG = multiscanner.CONFIG +TEST_CONFIG = multiscanner.config.MS_CONFIG TEST_REPORT = { 'filemeta': { @@ -60,10 +57,6 @@ def post_file(app): data={'file': (BytesIO(b'my file contents'), 'hello world.txt'), }) -# def mock_delay(file_, original_filename, task_id, f_name, metadata, config): -# return TEST_REPORT - - class CeleryTestCase(unittest.TestCase): def setUp(self): self.sql_db = Database(config=DB_CONF) @@ -79,7 +72,6 @@ def tearDown(self): class TestCeleryCase(CeleryTestCase): def setUp(self): super(self.__class__, self).setUp() - # api.multiscanner_celery.delay = mock_delay def test_base(self): self.assertEqual(True, True) diff --git a/multiscanner/tests/test_common_lib.py b/multiscanner/tests/test_common_lib.py index 30fe1741..6143ccf8 100644 --- a/multiscanner/tests/test_common_lib.py +++ b/multiscanner/tests/test_common_lib.py @@ -58,8 +58,8 @@ def test_basename_win_path(): assert result == 'd' -def test_parseDir(): +def test_parse_dir(): path = os.path.abspath(os.path.join(MS_WD, 'tests', 'dir_test')) - result = utils.parseDir(path, recursive=False) + result = utils.parse_dir(path, recursive=False) expected = [os.path.join(path, '1.1.txt'), os.path.join(path, '1.2.txt')] assert sorted(result) == sorted(expected) diff --git a/multiscanner/tests/test_configs.py b/multiscanner/tests/test_configs.py index e8e5d874..59a43512 100644 --- a/multiscanner/tests/test_configs.py +++ b/multiscanner/tests/test_configs.py @@ -1,4 +1,5 @@ from __future__ import division, absolute_import, with_statement, print_function, unicode_literals +import mock import os import tempfile @@ -8,41 +9,82 @@ # Makes sure we use the multiscanner in ../ CWD = os.path.dirname(os.path.abspath(__file__)) -module_list = [os.path.join(CWD, 'modules', 'test_conf.py')] -filelist = utils.parseDir(os.path.join(CWD, 'files')) +mock_modlist = {'test_conf': [True, os.path.join(CWD, 'modules')]} +mock_modlist2 = {'test_conf': [True, os.path.join(CWD, 'modules')], + 'test_1': [True, os.path.join(CWD, 'modules')], + 'test_2': [True, os.path.join(CWD, 'modules')]} +filelist = utils.parse_dir(os.path.join(CWD, 'files')) +module_list = ['test_conf'] +def test_config_parse_round_trip(): + conf_dict = {'test': {'a': 'b', 'c': 'd'}} + conf_parser = multiscanner.config.dict_to_config(conf_dict) + assert conf_parser.get('test', 'a') == 'b' + conf_dict2 = multiscanner.config.parse_config(conf_parser) + assert conf_dict == conf_dict2 + + +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist) def test_no_config(): results, metadata = multiscanner.multiscan( - filelist, configfile=None, config=None, - recursive=None, module_list=module_list)[0] + filelist, config=None, + module_list=module_list)[0] assert metadata['conf'] == {'a': 'b', 'c': 'd'} +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist) def test_config_api_no_file(): - config = {'test_conf': {'a': 'z'}} + config = multiscanner.config.dict_to_config({'test_conf': {'a': 'z'}}) results, metadata = multiscanner.multiscan( - filelist, configfile=None, config=config, - recursive=None, module_list=module_list)[0] + filelist, config=config, + module_list=module_list)[0] assert metadata['conf'] == {'a': 'z', 'c': 'd'} +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist) def test_config_api_with_empty_file(): - config = {'test_conf': {'a': 'z'}} + config = multiscanner.config.dict_to_config({'test_conf': {'a': 'z'}}) config_file = tempfile.mkstemp()[1] + multiscanner.update_ms_config_file(config_file) results, metadata = multiscanner.multiscan( - filelist, configfile=config_file, config=config, - recursive=None, module_list=module_list)[0] + filelist, config=config, + module_list=module_list)[0] os.remove(config_file) assert metadata['conf'] == {'a': 'z', 'c': 'd'} +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist) def test_config_api_with_real_file(): - config = {'test_conf': {'a': 'z'}} + config = multiscanner.config.dict_to_config({'test_conf': {'a': 'z'}}) config_file = tempfile.mkstemp()[1] - multiscanner.config_init(config_file) + module_list = multiscanner._get_main_modules() + multiscanner.config_init(config_file, module_list) + multiscanner.update_ms_config_file(config_file) results, metadata = multiscanner.multiscan( - filelist, configfile=config_file, config=config, - recursive=None, module_list=module_list)[0] + filelist, config=config, + module_list=module_list)[0] os.remove(config_file) assert metadata['conf'] == {'a': 'z', 'c': 'd'} + + +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist2) +def test_config_reset_not_overwrite(): + config_file = tempfile.mkstemp()[1] + module_list = multiscanner._get_main_modules() + multiscanner.config_init(config_file, module_list) + multiscanner.update_ms_config_file(config_file) + + # Change a config val from default + config_object = multiscanner.MSConfigParser() + config_object.read(config_file) + config_object.set('test_2', 'ENABLED', 'False') + with open(config_file, 'w') as conf_file: + config_object.write(conf_file) + + # call config_init with overwrite=true, but since test_2 isn't in the module list it won't be overwritten + del module_list['test_2'] + multiscanner.config.config_init(config_file, module_list, True) + multiscanner.update_ms_config_file(config_file) + os.remove(config_file) + assert multiscanner.config.MS_CONFIG.get('test_2', 'ENABLED') is False diff --git a/multiscanner/tests/test_module_interface.py b/multiscanner/tests/test_module_interface.py index eedbc1e4..df5882af 100644 --- a/multiscanner/tests/test_module_interface.py +++ b/multiscanner/tests/test_module_interface.py @@ -1,19 +1,22 @@ from __future__ import division, absolute_import, print_function, unicode_literals +import mock import os import multiscanner CWD = os.path.dirname(os.path.abspath(__file__)) +mock_modlist = {'test_subscan': [True, os.path.join(CWD, 'modules')]} def add_int(x, y): return x + y +@mock.patch('multiscanner.config.MODULE_LIST', mock_modlist) def test_subscan(): m = multiscanner.multiscan( - ['fake.zip'], recursive=None, configfile=None, - module_list=[os.path.join(CWD, 'modules', 'test_subscan.py')]) + ['fake.zip'], + module_list=['test_subscan']) assert m == [([(u'fake.zip', 0)], {'Type': 'Test', 'Name': 'test_subscan'}), ([(u'fake.zip/0', u'fake.zip')], {u'Include': False, u'Type': u'subscan', u'Name': u'Parent'}), ([(u'fake.zip', [u'fake.zip/0'])], {u'Include': False, u'Type': u'subscan', u'Name': u'Children'}), ([(u'fake.zip/0', u'test_subscan')], {u'Include': False, u'Type': u'subscan', u'Name': u'Created by'}), ([(u'fake.zip/0', 1)], {'Type': 'Test', 'Name': 'test_subscan'}), ([(u'fake.zip/0/1', u'fake.zip/0')], {u'Include': False, u'Type': u'subscan', u'Name': u'Parent'}), ([(u'fake.zip/0', [u'fake.zip/0/1'])], {u'Include': False, u'Type': u'subscan', u'Name': u'Children'}), ([(u'fake.zip/0/1', u'test_subscan')], {u'Include': False, u'Type': u'subscan', u'Name': u'Created by'}), ([(u'fake.zip/0/1', 2)], {'Type': 'Test', 'Name': 'test_subscan'})] # noqa: E501 diff --git a/multiscanner/tests/test_modules.py b/multiscanner/tests/test_modules.py index 395b25c3..43a71e4c 100644 --- a/multiscanner/tests/test_modules.py +++ b/multiscanner/tests/test_modules.py @@ -26,20 +26,20 @@ def test_fail_loadModule(): class _runmod_tests(object): @classmethod def setup_class(cls): - cls.real_mod_dir = multiscanner.MODULESDIR - multiscanner.MODULESDIR = os.path.join(CWD, "modules") - cls.filelist = utils.parseDir(os.path.join(CWD, 'files')) + cls.real_mod_dir = multiscanner.config.MODULES_DIR + multiscanner.config.MODULES_DIR = os.path.join(CWD, "modules") + cls.filelist = utils.parse_dir(os.path.join(CWD, 'files')) cls.files = ['a', 'b', 'C:\\c', '/d/d'] cls.threadDict = {} @classmethod def teardown_class(cls): - multiscanner.MODULESDIR = cls.real_mod_dir + multiscanner.config.MODULES_DIR = cls.real_mod_dir class Test_runModule_test_1(_runmod_tests): def setup(self): - m = utils.load_module('test_1', [multiscanner.MODULESDIR]) + m = utils.load_module('test_1', [multiscanner.config.MODULES_DIR]) global_module_interface = multiscanner._GlobalModuleInterface() self.result = multiscanner._run_module('test_1', m, self.filelist, self.threadDict, global_module_interface) global_module_interface._cleanup() @@ -55,7 +55,7 @@ def test_runModule_results(self): class Test_runModule_test_2(_runmod_tests): def setup(self): - self.m = utils.load_module('test_2', [multiscanner.MODULESDIR]) + self.m = utils.load_module('test_2', [multiscanner.config.MODULES_DIR]) self.threadDict['test_2'] = mock.Mock() self.threadDict['test_1'] = mock.Mock() self.threadDict['test_1'].ret = ([('a', 'a'), ('C:\\c', 'c')], {}) @@ -100,7 +100,7 @@ def teardown(self): def test_all_started(self): ThreadList = multiscanner._start_module_threads( - self.filelist, utils.parseDir(os.path.join(CWD, "modules")), self.config, self.global_module_interface) + self.filelist, utils.parse_dir(os.path.join(CWD, "modules")), self.config, self.global_module_interface) time.sleep(.001) for t in ThreadList: assert t.started diff --git a/multiscanner/tests/test_multiscanner.py b/multiscanner/tests/test_multiscanner.py index 21b07333..8ac437c5 100644 --- a/multiscanner/tests/test_multiscanner.py +++ b/multiscanner/tests/test_multiscanner.py @@ -1,4 +1,5 @@ from __future__ import division, absolute_import, print_function, unicode_literals +import mock import os import sys @@ -8,29 +9,36 @@ # Makes sure we use the multiscanner in ../ CWD = os.path.dirname(os.path.abspath(__file__)) +TEST_CONFIG_FILEPATH = '.tmpfile.ini' +TEST_REPORT = 'tmp_report.json' +TEST_FILES = os.path.join(CWD, 'files') + class _runmulti_tests(object): @classmethod def setup_class(cls): - cls.real_mod_dir = multiscanner.MODULESDIR - multiscanner.MODULEDIR = os.path.join(CWD, "modules") - cls.filelist = utils.parseDir(os.path.join(CWD, 'files')) - multiscanner.CONFIG = '.tmpfile.ini' + cls.real_mod_dir = multiscanner.config.MODULES_DIR + cls.real_mod_list = multiscanner.config.MODULE_LIST + multiscanner.config.MODULES_DIR = os.path.join(CWD, "modules") + multiscanner.config.MODULE_LIST = multiscanner.config.get_modules() + cls.filelist = utils.parse_dir(TEST_FILES) @classmethod def teardown_class(cls): - multiscanner.MODULESDIR = cls.real_mod_dir + multiscanner.config.MODULES_DIR = cls.real_mod_dir + multiscanner.config.MODULE_LIST = cls.real_mod_list -class Test_multiscan(_runmulti_tests): +class TestMultiscan(_runmulti_tests): def setup(self): - self.result = multiscanner.multiscan( - self.filelist, recursive=False, configregen=False, configfile='.tmpfile.ini') + multiscanner.config_init(TEST_CONFIG_FILEPATH, multiscanner._get_main_modules()) + multiscanner.update_ms_config_file(TEST_CONFIG_FILEPATH) + self.result = multiscanner.multiscan(self.filelist) self.report = multiscanner.parse_reports(self.result, includeMetadata=False, python=True) self.report_m = multiscanner.parse_reports(self.result, includeMetadata=True, python=True) def teardown(self): - os.remove('.tmpfile.ini') + os.remove(TEST_CONFIG_FILEPATH) def test_multiscan_results(self): for f in self.filelist: @@ -38,19 +46,134 @@ def test_multiscan_results(self): assert f in self.report_m['Files'] -class Test_main(_runmulti_tests): +class TestMain(_runmulti_tests): def setup(self): - sys.argv = [''] + multiscanner.config_init(TEST_CONFIG_FILEPATH, multiscanner._get_main_modules()) + multiscanner.update_ms_config_file(TEST_CONFIG_FILEPATH) def teardown(self): try: - os.remove('.tmpfile.ini') - os.remove('tmp_report.json') + os.remove(TEST_CONFIG_FILEPATH) + os.remove(TEST_REPORT) except Exception as e: # TODO: log exception pass def test_basic_main(self): - sys.argv = ['-z', '-j', 'tmp_report.json'] - sys.argv.extend(self.filelist) - multiscanner._main() + with mock.patch.object(sys, 'argv', ['ms.py', '-j', TEST_REPORT] + self.filelist): + try: + multiscanner._main() + except SystemExit: + pass + + +@mock.patch.object(multiscanner.config, 'CONFIG_FILEPATH', TEST_CONFIG_FILEPATH) +class TestInitNoConfig(_runmulti_tests): + def test_basic_main(self): + with mock.patch.object(sys, 'argv', ['ms.py', 'init']), \ + mock.patch('multiscanner.ms.input', return_value='y'): + try: + multiscanner._main() + except SystemExit: + pass + + with mock.patch.object(sys, 'argv', ['ms.py', '-j', TEST_REPORT] + self.filelist): + try: + multiscanner._main() + except SystemExit: + pass + + def teardown(self): + try: + os.remove(TEST_CONFIG_FILEPATH) + os.remove(TEST_REPORT) + except Exception as e: + # TODO: log exception + pass + + +class TestMissingConfig(_runmulti_tests): + def setup(self): + with mock.patch.object(sys, 'argv', ['ms.py', '-c', TEST_CONFIG_FILEPATH, 'init']), \ + mock.patch('multiscanner.ms.input', return_value='y'): + try: + multiscanner._main() + except SystemExit: + pass + + def test_config_init(self): + config_object = multiscanner.MSConfigParser() + config_object.read(TEST_CONFIG_FILEPATH) + + assert config_object.has_section('main') + assert config_object.has_section('test_1') + assert not config_object.has_section('Cuckoo') + + def test_fill_in_missing_config_sections(self): + # Simulate a section missing from config file before multiscanner is imported/run + config_object = multiscanner.MSConfigParser() + config_object.read(TEST_CONFIG_FILEPATH) + config_object.remove_section('main') + config_object.remove_section('test_1') + with open(TEST_CONFIG_FILEPATH, 'w') as conf_file: + config_object.write(conf_file) + multiscanner.update_ms_config_file(TEST_CONFIG_FILEPATH) + + # Run MultiScanner + with mock.patch.object(sys, 'argv', ['ms.py', '-c', TEST_CONFIG_FILEPATH, TEST_FILES]): + multiscanner._main() + with open(TEST_CONFIG_FILEPATH, 'r') as conf_file: + conf = conf_file.read() + assert 'test_1' in conf + + def test_read_config_with_default(self): + multiscanner.config.read_config(TEST_CONFIG_FILEPATH, {'test': {'foo': 'bar'}}) + with open(TEST_CONFIG_FILEPATH, 'r') as conf_file: + conf = conf_file.read() + assert 'foo' in conf + + def test_overwriting_config_on_reset(self): + # Change a config val from default + config_object = multiscanner.MSConfigParser() + config_object.read(TEST_CONFIG_FILEPATH) + config_object.set('test_2', 'ENABLED', 'False') + with open(TEST_CONFIG_FILEPATH, 'w') as conf_file: + config_object.write(conf_file) + + # Trigger reset_config and it gets overwritten + self.setup() + with mock.patch.object(sys, 'argv', ['ms.py', '-c', TEST_CONFIG_FILEPATH, '-j', TEST_REPORT, self.filelist[0]]): + try: + multiscanner._main() + except SystemExit: + pass + + with open(TEST_REPORT, 'r') as report_file: + report = report_file.read() + assert 'test_2' in report + + # teardown + os.remove(TEST_REPORT) + + def test_config_init_no_overwrite(self): + # Remove a section from config file + config_object = multiscanner.MSConfigParser() + config_object.read(TEST_CONFIG_FILEPATH) + config_object.remove_section('test_1') + with open(TEST_CONFIG_FILEPATH, 'w') as conf_file: + config_object.write(conf_file) + + # this time we answer 'no' so config won't be overwritten, but missing modules' configs will be regenerated + with mock.patch.object(sys, 'argv', ['ms.py', '-c', TEST_CONFIG_FILEPATH, 'init']), \ + mock.patch('multiscanner.ms.input', return_value='n'): + try: + multiscanner._main() + except SystemExit: + pass + + with open(TEST_CONFIG_FILEPATH, 'r') as conf_file: + conf = conf_file.read() + assert 'test_1' in conf + + def teardown(self): + os.remove(TEST_CONFIG_FILEPATH) diff --git a/multiscanner/utils/cython_compile_libs.py b/multiscanner/utils/cython_compile_libs.py index 38a358d3..251a1659 100644 --- a/multiscanner/utils/cython_compile_libs.py +++ b/multiscanner/utils/cython_compile_libs.py @@ -16,7 +16,7 @@ def main(): - filelist = utils.parseFileList([LIBS], recursive=True) + filelist = utils.parse_file_list([LIBS], recursive=True) try: import pefile filepath = pefile.__file__[:-1] diff --git a/multiscanner/web/app.py b/multiscanner/web/app.py index 5ee34167..64b76d6b 100644 --- a/multiscanner/web/app.py +++ b/multiscanner/web/app.py @@ -1,13 +1,8 @@ -import codecs -from collections import namedtuple -import configparser from flask import Flask, render_template, request -import os import re -from multiscanner import CONFIG as MS_CONFIG from multiscanner import __version__ -from multiscanner.common import utils +from multiscanner.config import get_config_path, read_config DEFAULTCONF = { 'HOST': "localhost", @@ -32,21 +27,9 @@ app = Flask(__name__) # Finagle Flask to read config from .ini file instead of .py file -web_config_object = configparser.ConfigParser() -web_config_object.optionxform = str -web_config_file = utils.get_config_path(MS_CONFIG, 'web') -web_config_object.read(web_config_file) -if not web_config_object.has_section('web') or not os.path.isfile(web_config_file): - # Write default config - web_config_object.add_section('web') - for key in DEFAULTCONF: - web_config_object.set('web', key, str(DEFAULTCONF[key])) - conffile = codecs.open(web_config_file, 'w', 'utf-8') - web_config_object.write(conffile) - conffile.close() -web_config = utils.parse_config(web_config_object)['web'] -conf_tuple = namedtuple('WebConfig', web_config.keys())(*web_config.values()) -app.config.from_object(conf_tuple) +web_config_file = get_config_path('web') +web_config = read_config(web_config_file, {'web': DEFAULTCONF}).get_section('web') +app.config.update(**web_config) @app.context_processor diff --git a/multiscanner/web/templates/index.html b/multiscanner/web/templates/index.html index e966474d..bda257c1 100644 --- a/multiscanner/web/templates/index.html +++ b/multiscanner/web/templates/index.html @@ -134,7 +134,7 @@ obj['duplicate'] = duplicate_action; // Modules options var moduleList = $("#module-opts input:checked").map(function(){return $(this).attr("name");}); - obj['modules'] = moduleList; + obj['modules'] = moduleList.toArray(); // Archive options if ($('#archive-analyze').is(':checked')) { obj['archive-analyze'] = 'true'; @@ -233,15 +233,15 @@ }) // Add options for selecting which modules to run - //$.get("{{ api_loc }}/api/v2/modules", function(data) { - // var modules = '
Select which modules to use:
'; - // for (mod in data) { - // checked = (data[mod] ? "checked" : ""); - // modules += 'Select which modules to use:
'; + for (mod in data) { + checked = (data[mod] ? "checked" : ""); + modules += '