Skip to content

Commit

Permalink
support group based module_defaults and fix unnecessary loop var condition (#144)
Browse files Browse the repository at this point in the history

* support group module_defaults

Signed-off-by: hirokuni-kitahara <[email protected]>

* fix annotation for unnecessary loop var

Signed-off-by: hirokuni-kitahara <[email protected]>

---------

Signed-off-by: hirokuni-kitahara <[email protected]>
  • Loading branch information
hirokuni-kitahara authored Apr 28, 2023
1 parent 36a4e4b commit c20fd97
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ansible_risk_insight/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.4"
__version__ = "0.1.6"
70 changes: 63 additions & 7 deletions ansible_risk_insight/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,7 @@ class AnsibleRunContext(object):
sequence: RunTargetList = field(default_factory=RunTargetList)
root_key: str = ""
parent: Object = None
ram_client: any = None

# used by rule check
current: RunTarget = None
Expand Down Expand Up @@ -1571,7 +1572,7 @@ def __getitem__(self, i):
return self.sequence[i]

@staticmethod
def from_tree(tree: ObjectList, parent: Object = None, last_item: bool = False):
def from_tree(tree: ObjectList, parent: Object = None, last_item: bool = False, ram_client=None):
if not tree:
return AnsibleRunContext(parent=parent, last_item=last_item)
if len(tree.items) == 0:
Expand All @@ -1584,15 +1585,15 @@ def from_tree(tree: ObjectList, parent: Object = None, last_item: bool = False):
continue
sequence_items.append(item)
tl = RunTargetList(items=sequence_items)
return AnsibleRunContext(sequence=tl, root_key=root_key, parent=parent, last_item=last_item)
return AnsibleRunContext(sequence=tl, root_key=root_key, parent=parent, last_item=last_item, ram_client=ram_client)

@staticmethod
def from_targets(targets: List[RunTarget], root_key: str = "", parent: Object = None, last_item: bool = False):
def from_targets(targets: List[RunTarget], root_key: str = "", parent: Object = None, last_item: bool = False, ram_client=None):
if not root_key:
if len(targets) > 0:
root_key = targets[0].spec.key
tl = RunTargetList(items=targets)
return AnsibleRunContext(sequence=tl, root_key=root_key, parent=parent, last_item=last_item)
return AnsibleRunContext(sequence=tl, root_key=root_key, parent=parent, last_item=last_item, ram_client=ram_client)

def find(self, target: RunTarget):
for t in self.sequence:
Expand All @@ -1606,11 +1607,15 @@ def before(self, target: RunTarget):
if rt.key == target.key:
break
targets.append(rt)
return AnsibleRunContext.from_targets(targets, root_key=self.root_key, parent=self.parent, last_item=self.last_item)
return AnsibleRunContext.from_targets(
targets, root_key=self.root_key, parent=self.parent, last_item=self.last_item, ram_client=self.ram_client
)

def search(self, cond: AnnotationCondition):
targets = [t for t in self.sequence if t.type == RunTargetType.Task and t.has_annotation_by_condition(cond)]
return AnsibleRunContext.from_targets(targets, root_key=self.root_key, parent=self.parent, last_item=self.last_item)
return AnsibleRunContext.from_targets(
targets, root_key=self.root_key, parent=self.parent, last_item=self.last_item, ram_client=self.ram_client
)

def is_end(self, target: RunTarget):
if len(self) == 0:
Expand All @@ -1631,7 +1636,9 @@ def is_begin(self, target: RunTarget):
return target.key == self.sequence[0].key

def copy(self):
return AnsibleRunContext.from_targets(targets=self.sequence.items, root_key=self.root_key, parent=self.parent, last_item=self.last_item)
return AnsibleRunContext.from_targets(
targets=self.sequence.items, root_key=self.root_key, parent=self.parent, last_item=self.last_item, ram_client=self.ram_client
)

@property
def info(self):
Expand Down Expand Up @@ -2122,6 +2129,55 @@ def __eq__(self, tfm):
return self.key == tfm.key and self.name == tfm.name and self.type == tfm.type and self.version == tfm.version and self.hash == tfm.hash


@dataclass
class ActionGroupMetadata(object):
    """Metadata for a collection-defined action group (meta/runtime `action_groups`).

    Used to resolve grouped `module_defaults` such as `group/aws`: `group_name`
    is the `group/...` key and `group_modules` is the list of module names the
    group covers. `type`/`name`/`version`/`hash` identify the collection the
    group was loaded from.
    """

    group_name: str = ""
    group_modules: list = field(default_factory=list)
    type: str = ""
    name: str = ""
    version: str = ""
    hash: str = ""

    @staticmethod
    def from_action_group(group_name: str, group_modules: list, metadata: dict):
        """Build an instance from one `action_groups` entry.

        Returns None when `group_name` or `group_modules` is empty, so callers
        can skip degenerate entries.
        """
        if not group_name:
            return None

        if not group_modules:
            return None

        agm = ActionGroupMetadata()
        agm.group_name = group_name
        agm.group_modules = group_modules
        agm.type = metadata.get("type", "")
        agm.name = metadata.get("name", "")
        agm.version = metadata.get("version", "")
        agm.hash = metadata.get("hash", "")
        return agm

    @staticmethod
    def from_dict(d: dict):
        """Rehydrate an instance from its JSON/dict form (as stored in the index)."""
        agm = ActionGroupMetadata()
        agm.group_name = d.get("group_name", "")
        # Fix: the default for a list-typed field was "" (a str); use [] so
        # `module in agm.group_modules` behaves correctly on partial dicts.
        agm.group_modules = d.get("group_modules", [])
        agm.type = d.get("type", "")
        agm.name = d.get("name", "")
        agm.version = d.get("version", "")
        agm.hash = d.get("hash", "")
        return agm

    def __eq__(self, agm):
        """Equality on group identity and source collection.

        NOTE(review): `group_modules` is intentionally excluded from the
        comparison, matching the sibling metadata classes in this module.
        """
        if not isinstance(agm, ActionGroupMetadata):
            return False
        return (
            self.group_name == agm.group_name
            and self.name == agm.name
            and self.type == agm.type
            and self.version == agm.version
            and self.hash == agm.hash
        )


# following ansible-lint severity levels
class Severity:
VERY_HIGH = "very_high"
Expand Down
86 changes: 83 additions & 3 deletions ansible_risk_insight/risk_assessment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RoleMetadata,
TaskFile,
TaskFileMetadata,
ActionGroupMetadata,
)
from .findings import Findings
from .utils import (
Expand All @@ -49,6 +50,7 @@
module_index_name = "module_index.json"
role_index_name = "role_index.json"
taskfile_index_name = "taskfile_index.json"
action_group_index_name = "action_group_index.json"


@dataclass
Expand All @@ -71,6 +73,9 @@ class RAMClient(object):
role_index: dict = field(default_factory=dict)
taskfile_index: dict = field(default_factory=dict)

# used for grouped module_defaults such as `group/aws`
action_group_index: dict = field(default_factory=dict)

max_cache_size: int = 200

def __post_init__(self):
Expand All @@ -89,6 +94,11 @@ def __post_init__(self):
with open(taskfile_index_path, "r") as file:
self.taskfile_index = json.load(file)

action_group_index_path = os.path.join(self.root_dir, "indices", action_group_index_name)
if os.path.exists(action_group_index_path):
with open(action_group_index_path, "r") as file:
self.action_group_index = json.load(file)

def clear_old_cache(self):
size = self.max_cache_size
self._remove_old_item(self.findings_cache, size)
Expand Down Expand Up @@ -125,6 +135,7 @@ def register_indices_to_ram(self, findings: Findings, include_test_contents: boo
self.register_module_index_to_ram(findings=findings, include_test_contents=include_test_contents)
self.register_role_index_to_ram(findings=findings, include_test_contents=include_test_contents)
self.register_taskfile_index_to_ram(findings=findings, include_test_contents=include_test_contents)
self.register_action_group_index_to_ram(findings=findings)

def register_module_index_to_ram(self, findings: Findings, include_test_contents: bool = False):
new_data_found = False
Expand Down Expand Up @@ -182,14 +193,12 @@ def register_module_index_to_ram(self, findings: Findings, include_test_contents
self.save_module_index(modules)
return

def register_role_index_to_ram(self, findings: Findings, include_test_contents: bool = False):
def register_role_index_to_ram(self, findings: Findings):
new_data_found = False
roles = self.load_role_index()
for role in findings.root_definitions.get("definitions", {}).get("roles", []):
if not isinstance(role, Role):
continue
if include_test_contents and is_test_object(role.defined_in):
continue
r_meta = RoleMetadata.from_role(role, findings.metadata)
current = roles.get(r_meta.fqcn, [])
exists = False
Expand Down Expand Up @@ -242,6 +251,59 @@ def register_taskfile_index_to_ram(self, findings: Findings, include_test_conten
self.save_taskfile_index(taskfiles)
return

def register_action_group_index_to_ram(self, findings: Findings, include_test_contents: bool = False):
    """Register collection action groups from `findings` into the action-group index.

    Each group found in a collection's meta/runtime `action_groups` is indexed
    under both its short key (`group/<name>`) and its fully qualified key
    (`group/<collection>.<name>`). The index is saved only when a new entry was
    added. `include_test_contents` is accepted for signature consistency with
    the other register_*_index_to_ram methods; it is currently unused here.
    """

    def _already_indexed(entries, candidate):
        # True when an equivalent ActionGroupMetadata is already in `entries`.
        # Entries may be raw dicts (loaded from JSON) or live instances.
        for item in entries:
            if isinstance(item, dict):
                known = ActionGroupMetadata.from_dict(item)
            elif isinstance(item, ActionGroupMetadata):
                known = item
            else:
                continue
            if not known:
                continue
            if known == candidate:
                return True
        return False

    new_data_found = False
    action_groups = self.load_action_group_index()

    for collection in findings.root_definitions.get("definitions", {}).get("collections", []):
        if not isinstance(collection, Collection):
            continue
        if not (collection.meta_runtime and isinstance(collection.meta_runtime, dict)):
            continue
        for group_name, group_modules in collection.meta_runtime.get("action_groups", {}).items():
            short_group_name = f"group/{group_name}"
            fq_group_name = f"group/{collection.name}.{group_name}"

            # Index the group under both the short and the fully qualified key.
            for index_key in (short_group_name, fq_group_name):
                agm = ActionGroupMetadata.from_action_group(index_key, group_modules, findings.metadata)
                if agm is None:
                    # Robustness: from_action_group returns None for empty
                    # module lists; never store None in the index.
                    continue
                current = action_groups.get(index_key, [])
                if not _already_indexed(current, agm):
                    current.append(agm)
                    new_data_found = True
                action_groups.update({index_key: current})

    if new_data_found:
        self.save_action_group_index(action_groups)
    return

def make_findings_dir_path(self, type, name, version, hash):
type_root = type + "s"
dir_name = name
Expand Down Expand Up @@ -731,6 +793,18 @@ def search_task(self, name, exact_match=False, max_match=-1, is_key=False, conte
self.task_search_cache[args_str] = matched_tasks
return matched_tasks

def search_action_group(self, name, max_match=-1):
    """Return indexed entries for action group `name`.

    `max_match` caps the number of results; 0 means "no results", a negative
    value (the default) means "no limit". Unknown names yield an empty list.
    """
    if max_match == 0:
        return []

    entries = self.action_group_index.get(name) or []

    # Truncate only when a positive cap is smaller than the hit count.
    if 0 < max_match < len(entries):
        entries = entries[:max_match]
    return entries

def get_object_by_key(self, obj_key: str):
obj_info = get_obj_info_by_key(obj_key)
obj_type = obj_info.get("type", "")
Expand Down Expand Up @@ -892,6 +966,12 @@ def save_taskfile_index(self, taskfiles):
def load_taskfile_index(self):
return self.load_index(taskfile_index_name)

def save_action_group_index(self, action_groups):
    """Persist the action-group index via save_index under action_group_index.json."""
    return self.save_index(action_groups, action_group_index_name)

def load_action_group_index(self):
    """Load the action-group index via load_index from action_group_index.json."""
    return self.load_index(action_group_index_name)

def save_error(self, error: str, out_dir: str):
if out_dir == "":
raise ValueError("output dir must be a non-empty value")
Expand Down
27 changes: 27 additions & 0 deletions ansible_risk_insight/rules/P002_module_argument_key_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Severity,
RuleTag as Tag,
ExecutableType,
ActionGroupMetadata,
)


Expand Down Expand Up @@ -59,6 +60,31 @@ def process(self, ctx: AnsibleRunContext):
default_args = task.module_defaults[module_short]
elif module_fqcn and module_fqcn in task.module_defaults:
default_args = task.module_defaults[module_fqcn]
elif ctx.ram_client:
for group_name in task.module_defaults:
tmp_args = task.module_defaults[group_name]
found = False
if not group_name.startswith("group/"):
continue
groups = ctx.ram_client.search_action_group(group_name)
if not groups:
continue
for group_dict in groups:
if not group_dict:
continue
if not isinstance(group_dict, dict):
continue
group = ActionGroupMetadata.from_dict(group_dict)
if module_short and module_short in group.group_modules:
found = True
default_args = tmp_args
break
elif module_fqcn and module_fqcn in group.group_modules:
found = True
default_args = tmp_args
break
if found:
break

used_keys = []
if isinstance(mo, dict):
Expand Down Expand Up @@ -118,6 +144,7 @@ def process(self, ctx: AnsibleRunContext):
task.set_annotation("module.required_arg_keys", required_keys, rule_id=self.rule_id)
task.set_annotation("module.missing_required_arg_keys", missing_required_keys, rule_id=self.rule_id)
task.set_annotation("module.available_args", available_args, rule_id=self.rule_id)
task.set_annotation("module.default_args", default_args, rule_id=self.rule_id)
task.set_annotation("module.used_alias_and_real_keys", used_alias_and_real_keys, rule_id=self.rule_id)

# TODO: find duplicate keys
Expand Down
19 changes: 18 additions & 1 deletion ansible_risk_insight/rules/P004_variable_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@
)


def is_loop_var(value, task):
    """Return True when `value` references a loop variable of `task`.

    `item` — or an alternative loop variable, if the task's loop spec is a
    dict — must not be replaced, or the loop would break. A reference is any
    of the whitespace-insensitive patterns `{{ var }}`, `{{ var| ... }}` or
    `{{ var.attr }}`.
    """
    loop_spec = task.spec.loop
    loop_vars = list(loop_spec.keys()) if loop_spec and isinstance(loop_spec, dict) else []

    compact = value.replace(" ", "")

    return any(
        ("{{" + var + suffix) in compact
        for var in loop_vars
        for suffix in ("}}", "|", ".")
    )


@dataclass
class VariableValidationRule(Rule):
rule_id: str = "P004"
Expand Down Expand Up @@ -58,7 +74,8 @@ def process(self, ctx: AnsibleRunContext):
if v_name not in unknown_name_vars and v_name not in task_arg_keys:
unknown_name_vars.append(v_name)
if v_name not in unnecessary_loop:
if v_name.startswith("item."):
v_str = "{{ " + v_name + " }}"
if not is_loop_var(v_str, task):
unnecessary_loop.append({"name": v_name, "suggested": v_name.replace("item.", "")})

task.set_annotation("variable.undefined_vars", undefined_variables, rule_id=self.rule_id)
Expand Down
8 changes: 4 additions & 4 deletions ansible_risk_insight/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,13 +586,13 @@ def construct_trees(self, ram_client=None):
logger.info(" tree file saved")
return

def resolve_variables(self):
def resolve_variables(self, ram_client=None):
taskcalls_in_trees = resolve(self.trees, self.additional)
self.taskcalls_in_trees = taskcalls_in_trees

for i, tree in enumerate(self.trees):
last_item = i + 1 == len(self.trees)
ctx = AnsibleRunContext.from_tree(tree=tree, parent=self.target_object, last_item=last_item)
ctx = AnsibleRunContext.from_tree(tree=tree, parent=self.target_object, last_item=last_item, ram_client=ram_client)
self.contexts.append(ctx)

if self.do_save:
Expand Down Expand Up @@ -1004,16 +1004,16 @@ def evaluate(
logger.debug("construct_trees() done")

self.record_begin(time_records, "variable_resolution")
scandata.resolve_variables()
scandata.resolve_variables(_ram_client)
self.record_end(time_records, "variable_resolution")
if not self.silent:
logger.debug("resolve_variables() done")

self.record_begin(time_records, "module_annotators")
scandata.annotate()
self.record_end(time_records, "module_annotators")
if not self.silent:
logger.debug("annotate() done")
self.record_end(time_records, "module_annotators")

self.record_begin(time_records, "apply_rules")
scandata.apply_rules()
Expand Down

0 comments on commit c20fd97

Please sign in to comment.