Skip to content

Commit

Permalink
Make (N)EVR(A) objects comparable (#379)
Browse files Browse the repository at this point in the history
Make (N)EVR(A) objects comparable

Related to packit/packit-service#2378. In the end it will probably not be needed for the sidetags related stuff, but it can be useful nevertheless.

Reviewed-by: Laura Barcziová
  • Loading branch information
softwarefactory-project-zuul[bot] authored May 23, 2024
2 parents 9f26a94 + b842ff4 commit 9d533dc
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 8 deletions.
11 changes: 4 additions & 7 deletions specfile/changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import subprocess
from typing import List, Optional, Union, overload

import rpm

from specfile.exceptions import SpecfileException
from specfile.formatter import formatted
from specfile.macros import Macros
Expand Down Expand Up @@ -282,10 +280,9 @@ def filter(

def parse_evr(s):
try:
evr = EVR.from_string(s)
return EVR.from_string(s)
except SpecfileException:
return "0", "0", ""
return str(evr.epoch), evr.version or "0", evr.release
return EVR(version="0")

if since is None:
start_index = 0
Expand All @@ -294,7 +291,7 @@ def parse_evr(s):
(
i
for i, e in enumerate(self.data)
if rpm.labelCompare(parse_evr(e.evr), parse_evr(since)) >= 0
if parse_evr(e.evr) >= parse_evr(since)
),
len(self.data) + 1,
)
Expand All @@ -305,7 +302,7 @@ def parse_evr(s):
(
i + 1
for i, e in reversed(list(enumerate(self.data)))
if rpm.labelCompare(parse_evr(e.evr), parse_evr(until)) <= 0
if parse_evr(e.evr) <= parse_evr(until)
),
0,
)
Expand Down
117 changes: 116 additions & 1 deletion specfile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import sys
from typing import TYPE_CHECKING, Tuple

import rpm

from specfile.constants import ARCH_NAMES
from specfile.exceptions import SpecfileException, UnterminatedMacroException
from specfile.formatter import formatted
Expand All @@ -28,10 +30,41 @@ def _key(self) -> tuple:
def __hash__(self) -> int:
return hash(self._key())

def _rpm_evr_tuple(self) -> Tuple[str, str, str]:
return str(self.epoch), self.version or "0", self.release

def _cmp(self, other: "EVR") -> int:
return rpm.labelCompare(self._rpm_evr_tuple(), other._rpm_evr_tuple())

def __lt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._cmp(other) < 0

def __le__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._cmp(other) <= 0

def __eq__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._key() == other._key()
return self._cmp(other) == 0

def __ne__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._cmp(other) != 0

def __ge__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._cmp(other) >= 0

def __gt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self._cmp(other) > 0

@formatted
def __repr__(self) -> str:
Expand Down Expand Up @@ -65,6 +98,44 @@ def __init__(
def _key(self) -> tuple:
return self.name, self.epoch, self.version, self.release

def __lt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name:
return NotImplemented
return self._cmp(other) < 0

def __le__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name:
return NotImplemented
return self._cmp(other) <= 0

def __eq__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self.name == other.name and self._cmp(other) == 0

def __ne__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return self.name != other.name or self._cmp(other) != 0

def __ge__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name:
return NotImplemented
return self._cmp(other) >= 0

def __gt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name:
return NotImplemented
return self._cmp(other) > 0

@formatted
def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -101,6 +172,50 @@ def __init__(
def _key(self) -> tuple:
return self.name, self.epoch, self.version, self.release, self.arch

def __lt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name or self.arch != other.arch:
return NotImplemented
return self._cmp(other) < 0

def __le__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name or self.arch != other.arch:
return NotImplemented
return self._cmp(other) <= 0

def __eq__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return (
self.name == other.name
and self.arch == other.arch
and self._cmp(other) == 0
)

def __ne__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
return (
self.name != other.name or self.arch != other.arch or self._cmp(other) != 0
)

def __ge__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name or self.arch != other.arch:
return NotImplemented
return self._cmp(other) >= 0

def __gt__(self, other: object) -> bool:
if type(other) is not self.__class__:
return NotImplemented
if self.name != other.name or self.arch != other.arch:
return NotImplemented
return self._cmp(other) > 0

@formatted
def __repr__(self) -> str:
return (
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,42 @@ def test_get_filename_from_location(location, filename):
assert get_filename_from_location(location) == filename


def test_EVR_compare():
assert EVR(version="0") == EVR(version="0")
assert EVR(version="0", release="1") != EVR(version="0", release="2")
assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="1")
assert EVR(version="12.0", release="1") <= EVR(version="12.0", release="2")
assert EVR(epoch=2, version="56.8", release="5") > EVR(
epoch=1, version="99.2", release="2"
)


def test_NEVR_compare():
assert NEVR(name="test", version="1", release="1") == NEVR(
name="test", version="1", release="1"
)
assert NEVR(name="test", version="3", release="1") != NEVR(
name="test2", version="3", release="1"
)
with pytest.raises(TypeError):
NEVR(name="test", version="3", release="1") > NEVR(
name="test2", version="1", release="2"
)


def test_NEVRA_compare():
assert NEVRA(name="test", version="1", release="1", arch="x86_64") == NEVRA(
name="test", version="1", release="1", arch="x86_64"
)
assert NEVRA(name="test", version="2", release="1", arch="x86_64") != NEVRA(
name="test", version="2", release="1", arch="aarch64"
)
with pytest.raises(TypeError):
NEVRA(name="test", version="1", release="1", arch="aarch64") < NEVRA(
name="test", version="2", release="1", arch="x86_64"
)


@pytest.mark.parametrize(
"evr, result",
[
Expand Down

0 comments on commit 9d533dc

Please sign in to comment.