From 2a8946e599da2a73333330aeb73b69833515027a Mon Sep 17 00:00:00 2001
From: Hirokuni Kitahara
Date: Tue, 28 May 2024 17:28:09 +0900
Subject: [PATCH] add --scan-per-target option (#236)

Signed-off-by: hirokuni-kitahara
---
 ansible_risk_insight/cli/__init__.py  | 102 ++++++++++++++++++++------
 ansible_risk_insight/finder.py        |  87 ++++++++++++++++++++++
 ansible_risk_insight/risk_detector.py |  32 +++++++-
 3 files changed, 198 insertions(+), 23 deletions(-)

diff --git a/ansible_risk_insight/cli/__init__.py b/ansible_risk_insight/cli/__init__.py
index 3de40bff..a2963804 100644
--- a/ansible_risk_insight/cli/__init__.py
+++ b/ansible_risk_insight/cli/__init__.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import json
 import argparse
 
 from ..scanner import ARIScanner, config
@@ -24,6 +26,7 @@
     get_role_metadata,
     split_name_and_version,
 )
+from ..finder import list_scan_target
 
 
 class ARICLI:
@@ -62,6 +65,16 @@ def __init__(self):
         parser.add_argument(
             "--save-only-rule-result", action="store_true", help="if true, save only rule results and remove node details to reduce result file size"
         )
+        parser.add_argument(
+            "--scan-per-target",
+            action="store_true",
+            help="if true, do scanning per playbook, role or taskfile (this reduces memory usage while scanning)",
+        )
+        parser.add_argument(
+            "--task-num-threshold",
+            default="100",
+            help="A threshold number to give up scanning a file where the number of tasks exceeds this (default to 100)",
+        )
         parser.add_argument("-o", "--out-dir", help="output directory for the rule evaluation result")
         parser.add_argument(
             "-r", "--rules-dir", help=f"specify custom rule directories. use `-R` instead to ignore default rules in {config.rules_dir}"
         )
@@ -158,25 +171,72 @@ def run(self):
             pretty=pretty,
             output_format=output_format,
         )
-        if not silent and not pretty:
-            print("Start preparing dependencies")
-        root_install = not args.skip_install
-        if not silent and not pretty:
+
+        if args.scan_per_target:
+            c.silent = True
+            task_num_threshold = int(args.task_num_threshold)
+            print("Listing scan targets (This might take several minutes for a large project)")
+            targets = list_scan_target(root_dir=target_name, task_num_threshold=task_num_threshold)
             print("Start scanning")
-        c.evaluate(
-            type=args.target_type,
-            name=target_name,
-            version=target_version,
-            install_dependencies=root_install,
-            dependency_dir=args.dependency_dir,
-            collection_name=collection_name,
-            role_name=role_name,
-            source_repository=args.source,
-            playbook_only=args.playbook_only,
-            taskfile_only=args.taskfile_only,
-            include_test_contents=args.include_tests,
-            load_all_taskfiles=load_all_taskfiles,
-            save_only_rule_result=save_only_rule_result,
-            objects=args.objects,
-            out_dir=args.out_dir,
-        )
+            total = len(targets)
+            file_list = {"playbook": [], "role": [], "taskfile": []}
+            for i, target_info in enumerate(targets):
+                fpath = target_info["filepath"]
+                fpath_from_root = target_info["path_from_root"]
+                scan_type = target_info["scan_type"]
+                count_in_type = len(file_list[scan_type])
+                print(f"\r[{i+1}/{total}] {scan_type} {fpath_from_root}                    ", end="")
+                out_dir = os.path.join(args.out_dir, f"{scan_type}s", str(count_in_type))
+                c.evaluate(
+                    type=scan_type,
+                    name=fpath,
+                    target_path=fpath,
+                    version=target_version,
+                    install_dependencies=False,
+                    dependency_dir=args.dependency_dir,
+                    collection_name=collection_name,
+                    role_name=role_name,
+                    source_repository=args.source,
+                    playbook_only=True,
+                    taskfile_only=True,
+                    include_test_contents=args.include_tests,
+                    load_all_taskfiles=load_all_taskfiles,
+                    save_only_rule_result=save_only_rule_result,
+                    objects=args.objects,
+                    out_dir=out_dir,
+                )
+                file_list[scan_type].append(fpath_from_root)
+            print("")
+            for scan_type, list_per_type in file_list.items():
+                index_data = {}
+                if not list_per_type:
+                    continue
+                for i, fpath in enumerate(list_per_type):
+                    index_data[i] = fpath
+                list_file_path = os.path.join(args.out_dir, f"{scan_type}s", "index.json")
+                with open(list_file_path, "w") as file:
+                    json.dump(index_data, file)
+
+        else:
+            if not silent and not pretty:
+                print("Start preparing dependencies")
+            root_install = not args.skip_install
+            if not silent and not pretty:
+                print("Start scanning")
+            c.evaluate(
+                type=args.target_type,
+                name=target_name,
+                version=target_version,
+                install_dependencies=root_install,
+                dependency_dir=args.dependency_dir,
+                collection_name=collection_name,
+                role_name=role_name,
+                source_repository=args.source,
+                playbook_only=args.playbook_only,
+                taskfile_only=args.taskfile_only,
+                include_test_contents=args.include_tests,
+                load_all_taskfiles=load_all_taskfiles,
+                save_only_rule_result=save_only_rule_result,
+                objects=args.objects,
+                out_dir=args.out_dir,
+            )
diff --git a/ansible_risk_insight/finder.py b/ansible_risk_insight/finder.py
index 215b5407..4bdd1653 100644
--- a/ansible_risk_insight/finder.py
+++ b/ansible_risk_insight/finder.py
@@ -644,3 +644,90 @@ def label_yml_file(yml_path: str = "", yml_body: str = "", task_num_thresh: int
     else:
         label = "others"
     return label, name_count, None
+
+
+def get_yml_label(file_path, root_path, task_num_threshold: int = -1):
+    relative_path = file_path.replace(root_path, "")
+    if relative_path[-1] == "/":
+        relative_path = relative_path[:-1]
+
+    label, _, error = label_yml_file(file_path, task_num_thresh=task_num_threshold)
+    role_name, role_path = get_role_info_from_path(file_path)
+    role_info = None
+    if role_name and role_path:
+        role_info = {"name": role_name, "path": role_path}
+
+    project_name, project_path = get_project_info_for_file(file_path, root_path)
+    project_info = None
+    if project_name and project_path:
+        project_info = {"name": project_name, "path": project_path}
+
+    # print(f"[{label}] {relative_path} {role_info}")
+    if error:
+        logger.debug(f"failed to get yml label:\n {error}")
+        label = "error"
+    return label, role_info, project_info
+
+
+def get_yml_list(root_dir: str, task_num_threshold: int = -1):
+    found_ymls = find_all_ymls(root_dir)
+    all_files = []
+    for yml_path in found_ymls:
+        label, role_info, project_info = get_yml_label(yml_path, root_dir, task_num_threshold)
+        if not role_info:
+            role_info = {}
+        if not project_info:
+            project_info = {}
+        if role_info:
+            if role_info["path"] and not role_info["path"].startswith(root_dir):
+                role_info["path"] = os.path.join(root_dir, role_info["path"])
+            role_info["is_external_dependency"] = True if "." in role_info["name"] else False
+        in_role = True if role_info else False
+        in_project = True if project_info else False
+        all_files.append(
+            {
+                "filepath": yml_path,
+                "path_from_root": yml_path.replace(root_dir, "").lstrip("/"),
+                "label": label,
+                "role_info": role_info,
+                "project_info": project_info,
+                "in_role": in_role,
+                "in_project": in_project,
+            }
+        )
+    return all_files
+
+
+def list_scan_target(root_dir: str, task_num_threshold: int = -1):
+    yml_list = get_yml_list(root_dir=root_dir, task_num_threshold=task_num_threshold)
+    known_roles = set()
+    all_targets = []
+    for yml_info in yml_list:
+        if yml_info["label"] not in ["playbook", "taskfile"]:
+            continue
+        role_path = ""
+        if yml_info["in_role"]:
+            role_path = yml_info["role_info"].get("path", None)
+        if role_path and role_path in known_roles:
+            continue
+        scan_type = ""
+        filepath = ""
+        path_from_root = ""
+        if role_path:
+            scan_type = "role"
+            filepath = role_path
+            path_from_root = role_path.replace(root_dir, "").lstrip("/")
+            known_roles.add(role_path)
+        else:
+            scan_type = yml_info["label"]
+            filepath = yml_info["filepath"]
+            path_from_root = yml_info["path_from_root"]
+        target_info = {
+            "filepath": filepath,
+            "path_from_root": path_from_root,
+            "scan_type": scan_type,
+        }
+        all_targets.append(target_info)
+    all_targets = sorted(all_targets, key=lambda x: x["filepath"])
+    all_targets = sorted(all_targets, key=lambda x: x["scan_type"])
+    return all_targets
diff --git a/ansible_risk_insight/risk_detector.py b/ansible_risk_insight/risk_detector.py
index a5b8161f..26fae444 100644
--- a/ansible_risk_insight/risk_detector.py
+++ b/ansible_risk_insight/risk_detector.py
@@ -22,7 +22,18 @@
 import time
 
 import ansible_risk_insight.logger as logger
-from .models import AnsibleRunContext, ARIResult, TargetResult, NodeResult, RuleResult, Rule, SpecMutation, FatalRuleResultError
+from .models import (
+    AnsibleRunContext,
+    ARIResult,
+    TargetResult,
+    NodeResult,
+    RuleResult,
+    Rule,
+    SpecMutation,
+    FatalRuleResultError,
+    RunTarget,
+    TaskCall,
+)
 from .keyutil import detect_type, key_delimiter
 from .analyzer import load_taskcalls_in_trees
 from .utils import load_classes_in_dir
@@ -204,7 +215,7 @@ def detect(contexts: List[AnsibleRunContext], rules_dir: str = "", rules: list =
                     n_result.rules.append(r_result)
             # remove node details
             if save_only_rule_result:
-                n_result.node = None
+                n_result.node = omit_node_details(n_result.node)
             t_result.nodes.append(n_result)
 
         ari_result.targets.append(t_result)
@@ -214,6 +225,23 @@ def detect(contexts: List[AnsibleRunContext], rules_dir: str = "", rules: list =
     return data_report, loaded_rules
 
 
+def omit_node_details(node: RunTarget):
+    spec = None
+    if getattr(node, "spec"):
+        spec = {
+            "type": getattr(node.spec, "type"),
+            "name": getattr(node.spec, "name"),
+            "defined_in": getattr(node.spec, "defined_in"),
+        }
+        if isinstance(node, TaskCall):
+            spec["line_num_in_file"] = getattr(node.spec, "line_num_in_file")
+    summary = {
+        "type": node.type,
+        "spec": spec,
+    }
+    return summary
+
+
 def main():
     parser = argparse.ArgumentParser(
         prog="risk_detector.py",
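
Note on the output layout: with --scan-per-target, results are written per target under <out_dir>/playbooks/<N>/, <out_dir>/roles/<N>/ and <out_dir>/taskfiles/<N>/, and each <out_dir>/<type>s/index.json maps N to the target's path from the project root. The snippet below is a minimal, illustrative reader for that layout under those assumptions; the directory /tmp/ari-results and the helper name load_scan_index are placeholders, not part of ARI.

import json
import os


def load_scan_index(out_dir: str) -> dict:
    """Return {scan_type: {index(int): path_from_root(str)}} for the types that were scanned."""
    index = {}
    for scan_type in ("playbook", "role", "taskfile"):
        # index.json is only written when at least one target of this type was scanned
        index_path = os.path.join(out_dir, f"{scan_type}s", "index.json")
        if not os.path.exists(index_path):
            continue
        with open(index_path) as f:
            # JSON object keys are strings; convert them back to integer indices
            index[scan_type] = {int(k): v for k, v in json.load(f).items()}
    return index


if __name__ == "__main__":
    out_dir = "/tmp/ari-results"  # hypothetical path passed to `-o` when scanning
    for scan_type, targets in load_scan_index(out_dir).items():
        for i, path_from_root in sorted(targets.items()):
            # per-target results live under <out_dir>/<scan_type>s/<i>/
            result_dir = os.path.join(out_dir, f"{scan_type}s", str(i))
            print(f"{scan_type}[{i}]: {path_from_root} -> {result_dir}")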