Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node version filtering for sampling #74

Merged
merged 12 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ Parameters
| | | | are greater than this max default value are |
| | | | capped at the default value |
+----------------------------------+------------------+------------------------------------------------+
| ``min_version`` | *(Optional)* | | Minimum acceptable version of Ursula. |
| | VersionString | | |
+----------------------------------+------------------+------------------------------------------------+


Returns
Expand Down
10 changes: 10 additions & 0 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import click
from marshmallow import fields
from packaging.version import parse

from porter.fields.exceptions import InvalidInputData

Expand Down Expand Up @@ -108,3 +109,12 @@ def _deserialize(self, value, attr, data, **kwargs):
f"Unexpected object type, {type(result)}; expected {self.expected_type}")

return result


class VersionString(String):

def _validate(self, value):
try:
parse(value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

except Exception:
raise InvalidInputData(f"{self.name} must be a correct version.")
4 changes: 4 additions & 0 deletions porter/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def get_ursulas(
include_ursulas: Optional[List[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Dict:
ursulas_info = self.implementer.get_ursulas(
quantity=quantity,
exclude_ursulas=exclude_ursulas,
include_ursulas=include_ursulas,
timeout=timeout,
duration=duration,
min_version=min_version,
)

response_data = {"ursulas": ursulas_info} # list of UrsulaInfo objects
Expand Down Expand Up @@ -104,13 +106,15 @@ def bucket_sampling(
exclude_ursulas: Optional[List[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Dict:
ursulas, block_number = self.implementer.bucket_sampling(
quantity=quantity,
random_seed=random_seed,
exclude_ursulas=exclude_ursulas,
timeout=timeout,
duration=duration,
min_version=min_version,
)

response_data = {"ursulas": ursulas, "block_number": block_number}
Expand Down
59 changes: 51 additions & 8 deletions porter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TreasureMap,
)
from nucypher_core.umbral import PublicKey
from packaging.version import Version, parse
from prometheus_flask_exporter import PrometheusMetrics

import porter
Expand Down Expand Up @@ -100,6 +101,12 @@ class DecryptOutcome(NamedTuple):
]
errors: Dict[ChecksumAddress, str]

class UrsulaVersionTooOld(Exception):
def __init__(self, ursula_address: str, version: str, min_version: str):
super().__init__(
f"Ursula ({ursula_address}) version is too old ({version} < {min_version})"
)

def __init__(
self,
eth_endpoint: str,
Expand Down Expand Up @@ -155,18 +162,30 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str):
):
BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint)

@staticmethod
def _is_version_greater_or_equal(min_version: Version, version: str) -> bool:
return parse(version) >= min_version

def _get_ursula_version(self, ursula: Ursula) -> str:
response = self.network_middleware.client.get(
node_or_sprout=ursula, path="status", params={"json": "true"}
)
return response.json()["version"]

def get_ursulas(
self,
quantity: int,
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
include_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> List[UrsulaInfo]:
timeout = self._configure_timeout(
"sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration)
available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir)
Expand All @@ -184,11 +203,18 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo:
ursula_address = to_checksum_address(ursula_address)
ursula = self.known_nodes[ursula_address]
try:
# ensure node is up and reachable
self.network_middleware.ping(ursula)
return Porter.UrsulaInfo(checksum_address=ursula_address,
uri=f"{ursula.rest_interface.formal_uri}",
encrypting_key=ursula.public_keys(DecryptingPower))
# ensure node is up and reachable and possibly check version
version = self._get_ursula_version(ursula)
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise self.UrsulaVersionTooOld(ursula_address, version, min_version)

return Porter.UrsulaInfo(
checksum_address=ursula_address,
uri=f"{ursula.rest_interface.formal_uri}",
encrypting_key=ursula.public_keys(DecryptingPower),
)
except Exception as e:
self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}")
raise
Expand Down Expand Up @@ -299,11 +325,13 @@ def bucket_sampling(
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Tuple[List[ChecksumAddress], int]:
timeout = self._configure_timeout(
"bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING:
raise ValueError("Bucket sampling is only for TACo Mainnet")
Expand Down Expand Up @@ -364,7 +392,10 @@ def __init__(self, _reservoir, need_successes: int):
self.reservoir = _reservoir
self.need_successes = need_successes
self.predefined_buckets = self.read_buckets()
self.bucketed_nodes = defaultdict(list)
self.bucketed_nodes = defaultdict(
list
) # <bucket> -> <list of checksum addresses>
self.selected_nodes = dict() # <checksum address> -> <bucket>

def read_buckets(self) -> Dict:
try:
Expand All @@ -391,6 +422,11 @@ def find_bucket(self, node):
return bucket_name
return None

def mark_as_not_successful(self, unsuccessful_node: ChecksumAddress):
bucket = self.selected_nodes.get(unsuccessful_node)
if bucket:
self.bucketed_nodes[bucket].remove(unsuccessful_node)

def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
batch = []
batch_size = self.need_successes - _successes
Expand All @@ -403,6 +439,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP:
continue
self.bucketed_nodes[bucket].append(selected)
self.selected_nodes[selected] = bucket
batch.append(selected)
if not batch:
return None
Expand All @@ -417,12 +454,18 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress:
ursula_address = to_checksum_address(ursula_address)
ursula = self.known_nodes[ursula_address]
try:
# ensure node is up and reachable
self.network_middleware.ping(ursula)
# ensure node is up and reachable and possibly check version
version = self._get_ursula_version(ursula)
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise self.UrsulaVersionTooOld(ursula_address, version, min_version)

return ursula_address
except Exception as e:
message = f"Ursula ({ursula_address}) is unreachable: {str(e)}"
self.log.debug(message)
value_factory.mark_as_not_successful(ursula_address)
raise

self.block_until_number_of_known_nodes_is(
Expand Down
25 changes: 25 additions & 0 deletions porter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NonNegativeInteger,
PositiveInteger,
StringList,
VersionString,
)
from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData
from porter.fields.retrieve import CapsuleFrag, RetrievalKit
Expand Down Expand Up @@ -125,6 +126,18 @@ class GetUrsulas(BaseSchema):
),
)

min_version = VersionString(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update the README to add this optional parameter.

Filed #75 which can be addressed in a separate PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added small note in README for get_ursulas

required=False,
load_only=True,
click=click.option(
"--min-version",
"-mv",
help="Minimum acceptable version of Ursula",
type=click.STRING,
required=False,
),
)

# output
ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True)

Expand Down Expand Up @@ -369,6 +382,18 @@ class BucketSampling(BaseSchema):
),
)

min_version = VersionString(
required=False,
load_only=True,
click=click.option(
"--min-version",
"-mv",
help="Minimum acceptable version of Ursula",
type=click.STRING,
required=False,
),
)

# output
ursulas = marshmallow_fields.List(UrsulaChecksumAddress, dump_only=True)
block_number = marshmallow_fields.Int(dump_only=True)
46 changes: 46 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from tests.constants import (
MOCK_ETH_PROVIDER_URI,
TEMPORARY_DOMAIN,
TEST_ETH_PROVIDER_URI,
TESTERCHAIN_CHAIN_ID,
)
from tests.mock.interfaces import MockBlockchain
from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient
from tests.utils.registry import MockRegistrySource, mock_registry_sources

# Crash on server error by default
Expand Down Expand Up @@ -245,6 +247,50 @@ def mock_signer(get_random_checksum_address):
return signer


class _MockMiddlewareClient(_TestMiddlewareClient):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ursulas_versions = {}

def get(self, *args, **kwargs):
if kwargs.get("path") == "status" and kwargs.get("params")["json"]:
node_address = kwargs.get("node_or_sprout").checksum_address
version = self.ursulas_versions.get(node_address, "1.1.1")
return _MockMiddlewareClient.MockResponse({"version": version}, 200)

real_get = super(_TestMiddlewareClient, self).__getattr__("get")
return real_get(*args, **kwargs)


class _MockRestMiddleware(MockRestMiddleware):
"""
Modified middleware to emulate returning status with version.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI)

def set_ursulas_versions(self, ursulas_versions: dict):
self.client.ursulas_versions = dict(ursulas_versions)

def clean_ursulas_versions(self):
self.client.ursulas_versions = {}


@pytest.fixture(scope="module")
def mock_rest_middleware():
return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI)


@pytest.fixture(scope="module")
@pytest.mark.usefixtures('testerchain', 'agency')
def porter(ursulas, mock_rest_middleware, test_registry):
Expand Down
Loading
Loading