Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/prod-vectors.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
run: uv run python -m consensus_testing.keys --download --scheme prod

- name: Fill production test fixtures
run: uv run fill --fork=Devnet --scheme prod --clean -n 2
run: uv run fill --fork=Devnet --scheme prod --clean -n auto

- name: Upload production test fixtures
uses: actions/upload-artifact@v4
Expand Down
100 changes: 64 additions & 36 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import tempfile
import urllib.request
from concurrent.futures import ProcessPoolExecutor
from functools import cache, partial
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Iterator

Expand Down Expand Up @@ -131,43 +131,69 @@ def _get_keys_dir(scheme_name: str) -> Path:
return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme"


@cache
def load_keys(scheme_name: str) -> dict[ValidatorIndex, KeyPair]:
"""
Load pre-generated keys from disk (cached after first call).
class LazyKeyDict:
"""Load pre-generated keys from disk (cached after first call)."""

Args:
scheme_name: Name of the signature scheme.
def __init__(self, scheme_name: str) -> None:
"""Initialize with scheme name for locating key files."""
self._scheme_name = scheme_name
self._keys_dir = _get_keys_dir(scheme_name)
self._cache: dict[ValidatorIndex, KeyPair] = {}
self._available_indices: set[int] | None = None

Returns:
Mapping from validator index to key pair.
def _ensure_dir_exists(self) -> None:
if not self._keys_dir.exists():
raise FileNotFoundError(
f"Keys directory not found: {self._keys_dir} - "
f"Run: python -m consensus_testing.keys --scheme {self._scheme_name}"
)

Raises:
FileNotFoundError: If keys directory is missing.
"""
keys_dir = _get_keys_dir(scheme_name)

if not keys_dir.exists():
raise FileNotFoundError(
f"Keys directory not found: {keys_dir} - "
f"Run: python -m consensus_testing.keys --scheme {scheme_name}"
)

# Load all keypair files from the directory
result = {}
for key_file in sorted(keys_dir.glob("*.json")):
# Extract validator index from filename (e.g., "0.json" -> 0)
validator_idx = ValidatorIndex(int(key_file.stem))
def _get_available_indices(self) -> set[int]:
"""Scan directory for available key indices (cached)."""
if self._available_indices is None:
self._ensure_dir_exists()
self._available_indices = {int(f.stem) for f in self._keys_dir.glob("*.json")}
if not self._available_indices:
raise FileNotFoundError(
f"No key files found in: {self._keys_dir} - "
f"Run: python -m consensus_testing.keys --scheme {self._scheme_name}"
)
return self._available_indices

def _load_key(self, idx: int) -> KeyPair:
"""Load a single key from disk."""
key_file = self._keys_dir / f"{idx}.json"
if not key_file.exists():
raise KeyError(f"Key file not found: {key_file}")
data = json.loads(key_file.read_text())
result[validator_idx] = KeyPair.from_dict(data)
return KeyPair.from_dict(data)

def __getitem__(self, idx: ValidatorIndex) -> KeyPair:
"""Get key pair by validator index, loading from disk if needed."""
if idx not in self._cache:
self._cache[idx] = self._load_key(int(idx))
return self._cache[idx]

def __contains__(self, idx: ValidatorIndex) -> bool:
"""Check if a key exists for the given validator index."""
return int(idx) in self._get_available_indices()

def __len__(self) -> int:
"""Return the number of available keys."""
return len(self._get_available_indices())

def __iter__(self) -> Iterator[ValidatorIndex]:
"""Iterate over available validator indices in sorted order."""
return iter(ValidatorIndex(i) for i in sorted(self._get_available_indices()))

def items(self) -> Iterator[tuple[ValidatorIndex, KeyPair]]:
"""Iterate over all keys (loads all into memory)."""
for idx in self:
yield idx, self[idx]

if not result:
raise FileNotFoundError(
f"No key files found in: {keys_dir} - "
f"Run: python -m consensus_testing.keys --scheme {scheme_name}"
)

return result
_LAZY_KEY_CACHE: dict[str, LazyKeyDict] = {}
"""Cache for lazy key dictionaries by scheme name."""


class XmssKeyManager:
Expand Down Expand Up @@ -204,9 +230,11 @@ def __init__(
self.scheme_name = scheme_name

@property
def keys(self) -> dict[ValidatorIndex, KeyPair]:
def keys(self) -> LazyKeyDict:
"""Lazy access to immutable base keys."""
return load_keys(self.scheme_name)
if self.scheme_name not in _LAZY_KEY_CACHE:
_LAZY_KEY_CACHE[self.scheme_name] = LazyKeyDict(self.scheme_name)
return _LAZY_KEY_CACHE[self.scheme_name]

def __getitem__(self, idx: ValidatorIndex) -> KeyPair:
"""Get key pair, returning advanced state if available."""
Expand Down Expand Up @@ -370,7 +398,7 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None:
print(f"Saved {len(key_pairs)} key pairs to {keys_dir}/")

# Clear cache so new keys are loaded
load_keys.cache_clear()
_LAZY_KEY_CACHE.clear()


def _download_keys(scheme: str) -> None:
Expand Down Expand Up @@ -422,7 +450,7 @@ def _download_keys(scheme: str) -> None:
os.unlink(tmp_path)

# Clear cache so new keys are loaded
load_keys.cache_clear()
_LAZY_KEY_CACHE.clear()
print("Download complete!")


Expand Down
Loading