Skip to content

Commit

Permalink
Refact core
Browse files Browse the repository at this point in the history
type hints and make CPE matching more readable.
No logical change, everything should work the same as before.
  • Loading branch information
kissgyorgy committed Nov 15, 2024
1 parent ab89f3d commit b6c6bc6
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 95 deletions.
188 changes: 95 additions & 93 deletions cpematcher/core.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,21 @@
import fnmatch
from typing import Self

from .utils import split_cpe_string
from .version import Version

OR_OPERATOR = "OR"
AND_OPERATOR = "AND"
CPEv23 = "cpe:2.3:"


class CPE:
cpe23_start = "cpe:2.3:"
fields = [
"part",
"vendor",
"product",
"version",
"update",
"edition",
"language",
"sw_edition",
"target_sw",
"target_hw",
"other",
]


def __init__(
self,
cpe_str,
vulnerable=True,
version_start_including=None,
version_start_excluding=None,
version_end_including=None,
version_end_excluding=None,
cpe_str: str,
vulnerable: bool = True,
version_start_including: str | None = None,
version_start_excluding: str | None = None,
version_end_including: str | None = None,
version_end_excluding: str | None = None,
):
"""Create CPE object with information about affected software.
Expand All @@ -40,78 +24,100 @@ def __init__(
then we added the argument `vulnerable`.
There are some examples in CVE database.
"""
assert cpe_str.startswith(self.cpe23_start), "Only CPE 2.3 is supported"
cpe_str = cpe_str.replace(self.cpe23_start, "")

values = split_cpe_string(cpe_str)
if len(values) != 11:
raise ValueError("Incomplete number of fields")

for f in self.fields:
setattr(self, f, values.pop(0))
assert cpe_str.startswith(CPEv23), "Only CPE 2.3 is supported"

attr_values = split_cpe_string(cpe_str)
if len(attr_values) != 13:
raise ValueError("Incomplete number of CPE attributes")

(
*_,
self.part,
self.vendor,
self.product,
self.version_str,
self.update,
self.edition,
self.language,
self.sw_edition,
self.target_sw,
self.target_hw,
self.other,
) = attr_values

self.is_vulnerable = vulnerable

self.version = Version(self.version_str)
self.version_start_including = Version(version_start_including)
self.version_start_excluding = Version(version_start_excluding)
self.version_end_including = Version(version_end_including)
self.version_end_excluding = Version(version_end_excluding)

def matches(self, another_cpe): # noqa: C901
"""Verify if `another_cpe` matches, first through field comparison and
then using the border constraints.
@property
def no_version(self) -> tuple[str, str, str, str, str, str, str, str, str, str]:
return (
self.part,
self.vendor,
self.product,
self.update,
self.edition,
self.language,
self.sw_edition,
self.target_sw,
self.target_hw,
self.other,
)

def matches(self, other: Self) -> bool:
"""Verify if `other` matches, first through attribute comparison
then using version matching and border constraints.
"""
for f in self.fields:
value = getattr(self, f)
another_value = getattr(another_cpe, f)
"""
Depending on the order, fnmatch.fnmatch could return False
if wildcard is the first value.
As wildcard should always return True in any case,
we reorder the arguments based on that.
"""
if another_value == "*":
order = [value, another_value]
else:
order = [another_value, value]

if (
f == "version"
and "*" not in order
and "*" not in order[0]
and "*" not in order[1]
):
if Version(order[0]) != Version(order[1]):
return False
elif not fnmatch.fnmatch(*order):
return self._matches_fields(other) and self._matches_version(other)

@staticmethod
def _glob_equal(value1: str, value2: str) -> bool:
# Depending on the order, fnmatch.fnmatch could return False if wildcard
# is the first value. As wildcard should always return True in any case,
# we reorder the arguments based on that.
glob_values = [value1, value2] if value2 == "*" else [value2, value1]
return fnmatch.fnmatch(*glob_values)

def _matches_fields(self, other: Self) -> bool:
return all(
self._glob_equal(value, other_value)
for value, other_value in zip(self.no_version, other.no_version)
)

def _matches_version(self, other: Self) -> bool: # noqa: C901
if "*" in self.version_str or "*" in other.version_str:
if not self._glob_equal(self.version_str, other.version_str):
return False

version = Version(another_cpe.version)

# Do verifications on start version
if self.version_start_including and version < self.version_start_including:
elif self.version != other.version:
return False

if self.version_start_excluding and version <= self.version_start_excluding:
if (
self.version_start_including
and other.version < self.version_start_including
):
return False

if self.version_end_including and version > self.version_end_including:
if (
self.version_start_excluding
and other.version <= self.version_start_excluding
):
return False

if self.version_end_excluding and version >= self.version_end_excluding:
if self.version_end_including and other.version > self.version_end_including:
return False
if self.version_end_excluding and other.version >= self.version_end_excluding:
return False

# ruff: noqa: SIM103
return True


class CPEOperation:
"""Handle operations defined on CPE sets.
Support for:
- OR operations
Support only OR operations.
"""

VERSION_MAP = {
Expand All @@ -131,25 +137,21 @@ def _get_value(self, cpe_dict, key):
def __init__(self, operation_dict):
self.cpes = set()

operator = operation_dict["operator"]
if operation_dict["operator"] != "OR":
return None

if operator == OR_OPERATOR:
for cpe_dict in operation_dict["cpe"]:
c = CPE(
cpe_dict["cpe23Uri"],
cpe_dict.get("vulnerable"),
version_start_including=self._get_value(cpe_dict, "vsi"),
version_start_excluding=self._get_value(cpe_dict, "vse"),
version_end_including=self._get_value(cpe_dict, "vei"),
version_end_excluding=self._get_value(cpe_dict, "vee"),
)
for cpe_dict in operation_dict["cpe"]:
cpe = CPE(
cpe_dict["cpe23Uri"],
cpe_dict.get("vulnerable"),
version_start_including=self._get_value(cpe_dict, "vsi"),
version_start_excluding=self._get_value(cpe_dict, "vse"),
version_end_including=self._get_value(cpe_dict, "vei"),
version_end_excluding=self._get_value(cpe_dict, "vee"),
)

self.cpes.add(c)
self.cpes.add(cpe)

def matches(self, another_cpe):
def matches(self, other: CPE) -> CPE | None:
"""Return matching CPE object."""
for cpe in self.cpes:
if cpe.matches(another_cpe):
return cpe

return None
return next((cpe for cpe in self.cpes if cpe.matches(other)), None)
2 changes: 1 addition & 1 deletion cpematcher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


# heavily inspired by https://stackoverflow.com/a/21882672
def split_cpe_string(string):
def split_cpe_string(string: str) -> list[str]:
ret = []
current = []
itr = iter(string)
Expand Down
2 changes: 1 addition & 1 deletion cpematcher/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class Version:
def __init__(self, version):
def __init__(self, version: str):
self.version = version

def __bool__(self):
Expand Down

0 comments on commit b6c6bc6

Please sign in to comment.