Skip to content

Commit

Permalink
add --scan-per-target option (#236)
Browse files Browse the repository at this point in the history
Signed-off-by: hirokuni-kitahara <[email protected]>
  • Loading branch information
hirokuni-kitahara authored May 28, 2024
1 parent a2dd8d1 commit 2a8946e
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 23 deletions.
102 changes: 81 additions & 21 deletions ansible_risk_insight/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import argparse

from ..scanner import ARIScanner, config
Expand All @@ -24,6 +26,7 @@
get_role_metadata,
split_name_and_version,
)
from ..finder import list_scan_target


class ARICLI:
Expand Down Expand Up @@ -62,6 +65,16 @@ def __init__(self):
parser.add_argument(
"--save-only-rule-result", action="store_true", help="if true, save only rule results and remove node details to reduce result file size"
)
parser.add_argument(
"--scan-per-target",
action="store_true",
help="if true, do scanning per playbook, role or taskfile (this reduces memory usage while scanning)",
)
parser.add_argument(
"--task-num-threshold",
default="100",
help="A threshold number to give up scanning a file where the number of tasks exceeds this (default to 100)",
)
parser.add_argument("-o", "--out-dir", help="output directory for the rule evaluation result")
parser.add_argument(
"-r", "--rules-dir", help=f"specify custom rule directories. use `-R` instead to ignore default rules in {config.rules_dir}"
Expand Down Expand Up @@ -158,25 +171,72 @@ def run(self):
pretty=pretty,
output_format=output_format,
)
if not silent and not pretty:
print("Start preparing dependencies")
root_install = not args.skip_install
if not silent and not pretty:

if args.scan_per_target:
c.silent = True
task_num_threshold = int(args.task_num_threshold)
print("Listing scan targets (This might take several minutes for a large proejct)")
targets = list_scan_target(root_dir=target_name, task_num_threshold=task_num_threshold)
print("Start scanning")
c.evaluate(
type=args.target_type,
name=target_name,
version=target_version,
install_dependencies=root_install,
dependency_dir=args.dependency_dir,
collection_name=collection_name,
role_name=role_name,
source_repository=args.source,
playbook_only=args.playbook_only,
taskfile_only=args.taskfile_only,
include_test_contents=args.include_tests,
load_all_taskfiles=load_all_taskfiles,
save_only_rule_result=save_only_rule_result,
objects=args.objects,
out_dir=args.out_dir,
)
total = len(targets)
file_list = {"playbook": [], "role": [], "taskfile": []}
for i, target_info in enumerate(targets):
fpath = target_info["filepath"]
fpath_from_root = target_info["path_from_root"]
scan_type = target_info["scan_type"]
count_in_type = len(file_list[scan_type])
print(f"\r[{i+1}/{total}] {scan_type} {fpath_from_root} ", end="")
out_dir = os.path.join(args.out_dir, f"{scan_type}s", str(count_in_type))
c.evaluate(
type=scan_type,
name=fpath,
target_path=fpath,
version=target_version,
install_dependencies=False,
dependency_dir=args.dependency_dir,
collection_name=collection_name,
role_name=role_name,
source_repository=args.source,
playbook_only=True,
taskfile_only=True,
include_test_contents=args.include_tests,
load_all_taskfiles=load_all_taskfiles,
save_only_rule_result=save_only_rule_result,
objects=args.objects,
out_dir=out_dir,
)
file_list[scan_type].append(fpath_from_root)
print("")
for scan_type, list_per_type in file_list.items():
index_data = {}
if not list_per_type:
continue
for i, fpath in enumerate(list_per_type):
index_data[i] = fpath
list_file_path = os.path.join(args.out_dir, f"{scan_type}s", "index.json")
with open(list_file_path, "w") as file:
json.dump(index_data, file)

else:
if not silent and not pretty:
print("Start preparing dependencies")
root_install = not args.skip_install
if not silent and not pretty:
print("Start scanning")
c.evaluate(
type=args.target_type,
name=target_name,
version=target_version,
install_dependencies=root_install,
dependency_dir=args.dependency_dir,
collection_name=collection_name,
role_name=role_name,
source_repository=args.source,
playbook_only=args.playbook_only,
taskfile_only=args.taskfile_only,
include_test_contents=args.include_tests,
load_all_taskfiles=load_all_taskfiles,
save_only_rule_result=save_only_rule_result,
objects=args.objects,
out_dir=args.out_dir,
)
87 changes: 87 additions & 0 deletions ansible_risk_insight/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,90 @@ def label_yml_file(yml_path: str = "", yml_body: str = "", task_num_thresh: int
else:
label = "others"
return label, name_count, None


def get_yml_label(file_path, root_path, task_num_threshold: int = -1):
relative_path = file_path.replace(root_path, "")
if relative_path[-1] == "/":
relative_path = relative_path[:-1]

label, _, error = label_yml_file(file_path, task_num_thresh=task_num_threshold)
role_name, role_path = get_role_info_from_path(file_path)
role_info = None
if role_name and role_path:
role_info = {"name": role_name, "path": role_path}

project_name, project_path = get_project_info_for_file(file_path, root_path)
project_info = None
if project_name and project_path:
project_info = {"name": project_name, "path": project_path}

# print(f"[{label}] {relative_path} {role_info}")
if error:
logger.debug(f"failed to get yml label:\n {error}")
label = "error"
return label, role_info, project_info


def get_yml_list(root_dir: str, task_num_threshold: int = -1):
found_ymls = find_all_ymls(root_dir)
all_files = []
for yml_path in found_ymls:
label, role_info, project_info = get_yml_label(yml_path, root_dir, task_num_threshold)
if not role_info:
role_info = {}
if not project_info:
project_info = {}
if role_info:
if role_info["path"] and not role_info["path"].startswith(root_dir):
role_info["path"] = os.path.join(root_dir, role_info["path"])
role_info["is_external_dependency"] = True if "." in role_info["name"] else False
in_role = True if role_info else False
in_project = True if project_info else False
all_files.append(
{
"filepath": yml_path,
"path_from_root": yml_path.replace(root_dir, "").lstrip("/"),
"label": label,
"role_info": role_info,
"project_info": project_info,
"in_role": in_role,
"in_project": in_project,
}
)
return all_files


def list_scan_target(root_dir: str, task_num_threshold: int = -1):
yml_list = get_yml_list(root_dir=root_dir, task_num_threshold=task_num_threshold)
known_roles = set()
all_targets = []
for yml_info in yml_list:
if yml_info["label"] not in ["playbook", "taskfile"]:
continue
role_path = ""
if yml_info["in_role"]:
role_path = yml_info["role_info"].get("path", None)
if role_path and role_path in known_roles:
continue
scan_type = ""
filepath = ""
path_from_root = ""
if role_path:
scan_type = "role"
filepath = role_path
path_from_root = role_path.replace(root_dir, "").lstrip("/")
known_roles.add(role_path)
else:
scan_type = yml_info["label"]
filepath = yml_info["filepath"]
path_from_root = yml_info["path_from_root"]
target_info = {
"filepath": filepath,
"path_from_root": path_from_root,
"scan_type": scan_type,
}
all_targets.append(target_info)
all_targets = sorted(all_targets, key=lambda x: x["filepath"])
all_targets = sorted(all_targets, key=lambda x: x["scan_type"])
return all_targets
32 changes: 30 additions & 2 deletions ansible_risk_insight/risk_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,18 @@
import time

import ansible_risk_insight.logger as logger
from .models import AnsibleRunContext, ARIResult, TargetResult, NodeResult, RuleResult, Rule, SpecMutation, FatalRuleResultError
from .models import (
AnsibleRunContext,
ARIResult,
TargetResult,
NodeResult,
RuleResult,
Rule,
SpecMutation,
FatalRuleResultError,
RunTarget,
TaskCall,
)
from .keyutil import detect_type, key_delimiter
from .analyzer import load_taskcalls_in_trees
from .utils import load_classes_in_dir
Expand Down Expand Up @@ -204,7 +215,7 @@ def detect(contexts: List[AnsibleRunContext], rules_dir: str = "", rules: list =
n_result.rules.append(r_result)
# remove node details
if save_only_rule_result:
n_result.node = None
n_result.node = omit_node_details(n_result.node)
t_result.nodes.append(n_result)
ari_result.targets.append(t_result)

Expand All @@ -214,6 +225,23 @@ def detect(contexts: List[AnsibleRunContext], rules_dir: str = "", rules: list =
return data_report, loaded_rules


def omit_node_details(node: RunTarget):
spec = None
if getattr(node, "spec"):
spec = {
"type": getattr(node.spec, "type"),
"name": getattr(node.spec, "name"),
"defined_in": getattr(node.spec, "defined_in"),
}
if isinstance(node, TaskCall):
spec["line_num_in_file"] = (getattr(node.spec, "line_num_in_file"),)
summary = {
"type": node.type,
"spec": spec,
}
return summary


def main():
parser = argparse.ArgumentParser(
prog="risk_detector.py",
Expand Down

0 comments on commit 2a8946e

Please sign in to comment.