From b842ff47ef64b827419dc29d6743105684520310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Forr=C3=B3?= Date: Wed, 22 May 2024 15:41:30 +0200 Subject: [PATCH] Make (N)EVR(A) objects comparable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nikola Forró --- specfile/changelog.py | 11 ++-- specfile/utils.py | 117 ++++++++++++++++++++++++++++++++++++++- tests/unit/test_utils.py | 36 ++++++++++++ 3 files changed, 156 insertions(+), 8 deletions(-) diff --git a/specfile/changelog.py b/specfile/changelog.py index 6628785..506a6f4 100644 --- a/specfile/changelog.py +++ b/specfile/changelog.py @@ -11,8 +11,6 @@ import subprocess from typing import List, Optional, Union, overload -import rpm - from specfile.exceptions import SpecfileException from specfile.formatter import formatted from specfile.macros import Macros @@ -282,10 +280,9 @@ def filter( def parse_evr(s): try: - evr = EVR.from_string(s) + return EVR.from_string(s) except SpecfileException: - return "0", "0", "" - return str(evr.epoch), evr.version or "0", evr.release + return EVR(version="0") if since is None: start_index = 0 @@ -294,7 +291,7 @@ def parse_evr(s): ( i for i, e in enumerate(self.data) - if rpm.labelCompare(parse_evr(e.evr), parse_evr(since)) >= 0 + if parse_evr(e.evr) >= parse_evr(since) ), len(self.data) + 1, ) @@ -305,7 +302,7 @@ def parse_evr(s): ( i + 1 for i, e in reversed(list(enumerate(self.data))) - if rpm.labelCompare(parse_evr(e.evr), parse_evr(until)) <= 0 + if parse_evr(e.evr) <= parse_evr(until) ), 0, ) diff --git a/specfile/utils.py b/specfile/utils.py index e6ae0db..545a75b 100644 --- a/specfile/utils.py +++ b/specfile/utils.py @@ -6,6 +6,8 @@ import sys from typing import TYPE_CHECKING, Tuple +import rpm + from specfile.constants import ARCH_NAMES from specfile.exceptions import SpecfileException, UnterminatedMacroException from specfile.formatter import formatted @@ -28,10 +30,41 @@ def _key(self) -> tuple: def __hash__(self) -> int: return hash(self._key()) + def _rpm_evr_tuple(self) -> Tuple[str, str, str]: + return str(self.epoch), self.version or "0", self.release + + def _cmp(self, other: "EVR") -> int: + return rpm.labelCompare(self._rpm_evr_tuple(), other._rpm_evr_tuple()) + + def __lt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self._cmp(other) < 0 + + def __le__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self._cmp(other) <= 0 + def __eq__(self, other: object) -> bool: if type(other) is not self.__class__: return NotImplemented - return self._key() == other._key() + return self._cmp(other) == 0 + + def __ne__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self._cmp(other) != 0 + + def __ge__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self._cmp(other) >= 0 + + def __gt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self._cmp(other) > 0 @formatted def __repr__(self) -> str: @@ -65,6 +98,44 @@ def __init__( def _key(self) -> tuple: return self.name, self.epoch, self.version, self.release + def __lt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name: + return NotImplemented + return self._cmp(other) < 0 + + def __le__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name: + return NotImplemented + return self._cmp(other) <= 0 + + def __eq__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self.name == other.name and self._cmp(other) == 0 + + def __ne__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return self.name != other.name or self._cmp(other) != 0 + + def __ge__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name: + return NotImplemented + return self._cmp(other) >= 0 + + def __gt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name: + return NotImplemented + return self._cmp(other) > 0 + @formatted def __repr__(self) -> str: return ( @@ -101,6 +172,50 @@ def __init__( def _key(self) -> tuple: return self.name, self.epoch, self.version, self.release, self.arch + def __lt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name or self.arch != other.arch: + return NotImplemented + return self._cmp(other) < 0 + + def __le__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name or self.arch != other.arch: + return NotImplemented + return self._cmp(other) <= 0 + + def __eq__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return ( + self.name == other.name + and self.arch == other.arch + and self._cmp(other) == 0 + ) + + def __ne__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + return ( + self.name != other.name or self.arch != other.arch or self._cmp(other) != 0 + ) + + def __ge__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name or self.arch != other.arch: + return NotImplemented + return self._cmp(other) >= 0 + + def __gt__(self, other: object) -> bool: + if type(other) is not self.__class__: + return NotImplemented + if self.name != other.name or self.arch != other.arch: + return NotImplemented + return self._cmp(other) > 0 + @formatted def __repr__(self) -> str: return ( diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a9af766..ed79278 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -32,6 +32,42 @@ def test_get_filename_from_location(location, filename): assert get_filename_from_location(location) == filename +def test_EVR_compare(): + assert EVR(version="0") == EVR(version="0") + assert EVR(version="0", release="1") != EVR(version="0", release="2") + assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="1") + assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="2") + assert EVR(epoch=2, version="56.8", release="5") > EVR( + epoch=1, version="99.2", release="2" + ) + + +def test_NEVR_compare(): + assert NEVR(name="test", version="1", release="1") == NEVR( + name="test", version="1", release="1" + ) + assert NEVR(name="test", version="3", release="1") != NEVR( + name="test2", version="3", release="1" + ) + with pytest.raises(TypeError): + NEVR(name="test", version="3", release="1") > NEVR( + name="test2", version="1", release="2" + ) + + +def test_NEVRA_compare(): + assert NEVRA(name="test", version="1", release="1", arch="x86_64") == NEVRA( + name="test", version="1", release="1", arch="x86_64" + ) + assert NEVRA(name="test", version="2", release="1", arch="x86_64") != NEVRA( + name="test", version="2", release="1", arch="aarch64" + ) + with pytest.raises(TypeError): + NEVRA(name="test", version="1", release="1", arch="aarch64") < NEVRA( + name="test", version="2", release="1", arch="x86_64" + ) + + @pytest.mark.parametrize( "evr, result", [