From b6c6bc6b9db3a2b37f42224fa638c5a064e7a0ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Kiss?= Date: Fri, 15 Nov 2024 18:09:24 +0100 Subject: [PATCH] Refactor core type hints and make CPE matching more readable. No logical change, everything should work the same as before. --- cpematcher/core.py | 188 +++++++++++++++++++++--------------------- cpematcher/utils.py | 2 +- cpematcher/version.py | 2 +- 3 files changed, 97 insertions(+), 95 deletions(-) diff --git a/cpematcher/core.py b/cpematcher/core.py index 9f84e0a..96e49d4 100644 --- a/cpematcher/core.py +++ b/cpematcher/core.py @@ -1,37 +1,21 @@ import fnmatch +from typing import Self from .utils import split_cpe_string from .version import Version -OR_OPERATOR = "OR" -AND_OPERATOR = "AND" +CPEv23 = "cpe:2.3:" class CPE: - cpe23_start = "cpe:2.3:" - fields = [ - "part", - "vendor", - "product", - "version", - "update", - "edition", - "language", - "sw_edition", - "target_sw", - "target_hw", - "other", - ] - - def __init__( self, - cpe_str, - vulnerable=True, - version_start_including=None, - version_start_excluding=None, - version_end_including=None, - version_end_excluding=None, + cpe_str: str, + vulnerable: bool = True, + version_start_including: str | None = None, + version_start_excluding: str | None = None, + version_end_including: str | None = None, + version_end_excluding: str | None = None, ): """Create CPE object with information about affected software. @@ -40,78 +24,100 @@ def __init__( then we added the argument `vulnerable`. There are some examples in CVE database. 
- """ - assert cpe_str.startswith(self.cpe23_start), "Only CPE 2.3 is supported" - cpe_str = cpe_str.replace(self.cpe23_start, "") - - values = split_cpe_string(cpe_str) - if len(values) != 11: - raise ValueError("Incomplete number of fields") - - for f in self.fields: - setattr(self, f, values.pop(0)) + assert cpe_str.startswith(CPEv23), "Only CPE 2.3 is supported" + + attr_values = split_cpe_string(cpe_str) + if len(attr_values) != 13: + raise ValueError("Incomplete number of CPE attributes") + + ( + *_, + self.part, + self.vendor, + self.product, + self.version_str, + self.update, + self.edition, + self.language, + self.sw_edition, + self.target_sw, + self.target_hw, + self.other, + ) = attr_values self.is_vulnerable = vulnerable + + self.version = Version(self.version_str) self.version_start_including = Version(version_start_including) self.version_start_excluding = Version(version_start_excluding) self.version_end_including = Version(version_end_including) self.version_end_excluding = Version(version_end_excluding) - def matches(self, another_cpe): # noqa: C901 - """Verify if `another_cpe` matches, first through field comparison and - then using the border constraints. - + @property + def no_version(self) -> tuple[str, str, str, str, str, str, str, str, str, str]: + return ( + self.part, + self.vendor, + self.product, + self.update, + self.edition, + self.language, + self.sw_edition, + self.target_sw, + self.target_hw, + self.other, + ) + + def matches(self, other: Self) -> bool: + """Verify if `other` matches, first through attribute comparison + then using version matching and border constraints. """ - for f in self.fields: - value = getattr(self, f) - another_value = getattr(another_cpe, f) - """ - Depending on the order, fnmatch.fnmatch could return False - if wildcard is the first value. - As wildcard should always return True in any case, - we reorder the arguments based on that. 
- """ - if another_value == "*": - order = [value, another_value] - else: - order = [another_value, value] - - if ( - f == "version" - and "*" not in order - and "*" not in order[0] - and "*" not in order[1] - ): - if Version(order[0]) != Version(order[1]): - return False - elif not fnmatch.fnmatch(*order): + return self._matches_fields(other) and self._matches_version(other) + + @staticmethod + def _glob_equal(value1: str, value2: str) -> bool: + # Depending on the order, fnmatch.fnmatch could return False if wildcard + # is the first value. As wildcard should always return True in any case, + # we reorder the arguments based on that. + glob_values = [value1, value2] if value2 == "*" else [value2, value1] + return fnmatch.fnmatch(*glob_values) + + def _matches_fields(self, other: Self) -> bool: + return all( + self._glob_equal(value, other_value) + for value, other_value in zip(self.no_version, other.no_version) + ) + + def _matches_version(self, other: Self) -> bool: # noqa: C901 + if "*" in self.version_str or "*" in other.version_str: + if not self._glob_equal(self.version_str, other.version_str): return False - - version = Version(another_cpe.version) - - # Do verifications on start version - if self.version_start_including and version < self.version_start_including: + elif self.version != other.version: return False - if self.version_start_excluding and version <= self.version_start_excluding: + if ( + self.version_start_including + and other.version < self.version_start_including + ): return False - - if self.version_end_including and version > self.version_end_including: + if ( + self.version_start_excluding + and other.version <= self.version_start_excluding + ): return False - - if self.version_end_excluding and version >= self.version_end_excluding: + if self.version_end_including and other.version > self.version_end_including: + return False + if self.version_end_excluding and other.version >= self.version_end_excluding: return False + # ruff: noqa: 
SIM103 return True class CPEOperation: """Handle operations defined on CPE sets. - - Support for: - - OR operations - + Support only OR operations. """ VERSION_MAP = { @@ -131,25 +137,21 @@ def _get_value(self, cpe_dict, key): def __init__(self, operation_dict): self.cpes = set() - operator = operation_dict["operator"] + if operation_dict["operator"] != "OR": + return None - if operator == OR_OPERATOR: - for cpe_dict in operation_dict["cpe"]: - c = CPE( - cpe_dict["cpe23Uri"], - cpe_dict.get("vulnerable"), - version_start_including=self._get_value(cpe_dict, "vsi"), - version_start_excluding=self._get_value(cpe_dict, "vse"), - version_end_including=self._get_value(cpe_dict, "vei"), - version_end_excluding=self._get_value(cpe_dict, "vee"), - ) + for cpe_dict in operation_dict["cpe"]: + cpe = CPE( + cpe_dict["cpe23Uri"], + cpe_dict.get("vulnerable"), + version_start_including=self._get_value(cpe_dict, "vsi"), + version_start_excluding=self._get_value(cpe_dict, "vse"), + version_end_including=self._get_value(cpe_dict, "vei"), + version_end_excluding=self._get_value(cpe_dict, "vee"), + ) - self.cpes.add(c) + self.cpes.add(cpe) - def matches(self, another_cpe): + def matches(self, other: CPE) -> CPE | None: """Return matching CPE object.""" - for cpe in self.cpes: - if cpe.matches(another_cpe): - return cpe - - return None + return next((cpe for cpe in self.cpes if cpe.matches(other)), None) diff --git a/cpematcher/utils.py b/cpematcher/utils.py index 6939ddd..aff2a59 100644 --- a/cpematcher/utils.py +++ b/cpematcher/utils.py @@ -2,7 +2,7 @@ # heavily inspired by https://stackoverflow.com/a/21882672 -def split_cpe_string(string): +def split_cpe_string(string: str) -> list[str]: ret = [] current = [] itr = iter(string) diff --git a/cpematcher/version.py b/cpematcher/version.py index e189490..b90e03d 100644 --- a/cpematcher/version.py +++ b/cpematcher/version.py @@ -2,7 +2,7 @@ class Version: - def __init__(self, version): + def __init__(self, version: str): self.version = 
version def __bool__(self):