Skip to content

Commit

Permalink
Add signature shares logic
Browse files Browse the repository at this point in the history
  • Loading branch information
evgeny-stakewise committed Feb 2, 2024
1 parent 95ba429 commit 2803caf
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 53 deletions.
72 changes: 30 additions & 42 deletions src/validators/keystores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from web3 import Web3

from src.common.typings import Oracles
from src.config.settings import NETWORKS, REMOTE_SIGNER_TIMEOUT, settings
from src.config.networks import NETWORKS
from src.config.settings import REMOTE_SIGNER_TIMEOUT, settings
from src.validators.keystores.base import BaseKeystore
from src.validators.signing.common import encrypt_signature
from src.validators.signing.key_shares import reconstruct_shared_bls_signature
from src.validators.signing.key_shares import bls_signature_and_public_key_to_shares
from src.validators.typings import ExitSignatureShards

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,29 +94,27 @@ async def get_exit_signature_shards(
genesis_validators_root=settings.network_config.GENESIS_VALIDATORS_ROOT,
fork=fork,
)
pubkey_shares = self.pubkeys_to_shares.get(public_key)
if not pubkey_shares:
raise RuntimeError(f'Failed to get signature for {public_key}.')

validator_pubkey_shares = [BLSPubkey(Web3.to_bytes(hexstr=s)) for s in pubkey_shares]

signature_shards = []
for validator_pubkey_share, oracle_pubkey in zip(
validator_pubkey_shares, oracles.public_keys
):
shard = await self._fetch_signature_shard(
pubkey_share=validator_pubkey_share,
validator_index=validator_index,
fork=fork,
message=message,
)

# Encrypt it with the oracle's pubkey
signature_shards.append(encrypt_signature(oracle_pubkey, shard))
public_key_bytes = BLSPubkey(Web3.to_bytes(hexstr=public_key))
threshold = oracles.exit_signature_recover_threshold
total = len(oracles.public_keys)

exit_signature = await self._sign(public_key_bytes, validator_index, fork, message)

exit_signature_shares, public_key_shares = bls_signature_and_public_key_to_shares(
message, exit_signature, public_key_bytes, threshold, total
)

encrypted_exit_signature_shares: list[HexStr] = []

for exit_signature_share, oracle_pubkey in zip(exit_signature_shares, oracles.public_keys):
encrypted_exit_signature_shares.append(
encrypt_signature(oracle_pubkey, exit_signature_share)
)

return ExitSignatureShards(
public_keys=[Web3.to_hex(pubkey) for pubkey in validator_pubkey_shares],
exit_signatures=signature_shards,
public_keys=[Web3.to_hex(p) for p in public_key_shares],
exit_signatures=encrypted_exit_signature_shares,
)

async def get_exit_signature(
Expand All @@ -126,19 +125,10 @@ async def get_exit_signature(
genesis_validators_root=NETWORKS[network].GENESIS_VALIDATORS_ROOT,
fork=fork,
)
signature_shards = []
for pubkey_share in self.pubkeys_to_shares[public_key]:
signature_shards.append(
await self._fetch_signature_shard(
pubkey_share=BLSPubkey(Web3.to_bytes(hexstr=pubkey_share)),
validator_index=validator_index,
fork=fork,
message=message,
)
)
exit_signature = reconstruct_shared_bls_signature(
signatures=dict(enumerate(signature_shards))
)
public_key_bytes = BLSPubkey(Web3.to_bytes(hexstr=public_key))

exit_signature = await self._sign(public_key_bytes, validator_index, fork, message)

bls.Verify(BLSPubkey(Web3.to_bytes(hexstr=public_key)), message, exit_signature)
return exit_signature

Expand All @@ -157,9 +147,9 @@ def _load_data(cls, data: dict) -> 'RemoteSignerKeystore':

return RemoteSignerKeystore(pubkeys_to_shares=pubkeys_to_shares)

async def _fetch_signature_shard(
async def _sign(
self,
pubkey_share: BLSPubkey,
public_key: BLSPubkey,
validator_index: int,
fork: ConsensusFork,
message: bytes,
Expand All @@ -181,17 +171,15 @@ async def _fetch_signature_shard(
)

async with ClientSession(timeout=ClientTimeout(REMOTE_SIGNER_TIMEOUT)) as session:
signer_url = f'{settings.remote_signer_url}/api/v1/eth2/sign/0x{pubkey_share.hex()}'
signer_url = f'{settings.remote_signer_url}/api/v1/eth2/sign/0x{public_key.hex()}'

response = await session.post(signer_url, json=dataclasses.asdict(data))

if response.status == 404:
# Pubkey not present on remote signer side
raise RuntimeError(
f'Failed to get signature for {pubkey_share.hex()}.'
f' Is this keyshare present in the remote signer?'
f' If the oracle set changed, you may need to regenerate'
f' and reimport the new key shares!'
f'Failed to get signature for {public_key.hex()}.'
f' Is this public key present in the remote signer?'
)

response.raise_for_status()
Expand Down
80 changes: 78 additions & 2 deletions src/validators/signing/key_shares.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
from random import randint
import secrets
from typing import TypeAlias

from eth_typing import BLSPubkey, BLSSignature
from py_ecc.bls import G2ProofOfPossession
from py_ecc.bls.g2_primitives import (
G1_to_pubkey,
G2_to_signature,
pubkey_to_G1,
signature_to_G2,
)
from py_ecc.bls.hash_to_curve import hash_to_G2
from py_ecc.optimized_bls12_381.optimized_curve import (
G1 as P1, # don't mess group name (G1) and primitive element name (P1)
)
from py_ecc.optimized_bls12_381.optimized_curve import (
Z1,
Z2,
add,
curve_order,
multiply,
)
from py_ecc.typing import Optimized_Field, Optimized_Point3D
from py_ecc.utils import prime_field_inv

from src.validators.typings import BLSPrivkey

# element of G1 or G2
G12: TypeAlias = Optimized_Point3D[Optimized_Field]


def get_polynomial_points(coefficients: list[int], num_points: int) -> list[int]:
"""Calculates polynomial points."""
Expand All @@ -35,6 +45,23 @@ def get_polynomial_points(coefficients: list[int], num_points: int) -> list[int]
return points


def get_G12_polynomial_points(coefficients: list, num_points: int) -> list:
"""Calculates polynomial points in G1 or G2."""
points = []
for x in range(1, num_points + 1):
# start with x=1 and calculate the value of y
y = coefficients[0]
# calculate each term and add it to y, using modular math
for i in range(1, len(coefficients)):
exponentiation = (x**i) % curve_order
term = multiply(coefficients[i], exponentiation)
y = add(y, term)

# add the point to the list of points
points.append(y)
return points


def private_key_to_private_key_shares(
private_key: BLSPrivkey,
threshold: int,
Expand All @@ -43,13 +70,62 @@ def private_key_to_private_key_shares(
coefficients: list[int] = [int.from_bytes(private_key, 'big')]

for _ in range(threshold - 1):
coefficients.append(randint(0, curve_order - 1)) # nosec
coefficients.append(secrets.randbelow(curve_order))

points = get_polynomial_points(coefficients, total)

return [BLSPrivkey(p.to_bytes(32, 'big')) for p in points]


def bls_signature_to_shares(
bls_signature: BLSSignature,
coefficients_G2: list[G12],
total: int,
) -> list[BLSSignature]:
coefficients_G2 = [signature_to_G2(bls_signature)] + coefficients_G2

points = get_G12_polynomial_points(coefficients_G2, total)

return [BLSSignature(G2_to_signature(p)) for p in points]


def bls_public_key_to_shares(
public_key: BLSPubkey,
coefficients_G1: list,
total: int,
) -> list[BLSPubkey]:
coefficients_G1 = [pubkey_to_G1(public_key)] + coefficients_G1

points = get_G12_polynomial_points(coefficients_G1, total)

return [BLSPubkey(G1_to_pubkey(p)) for p in points]


def bls_signature_and_public_key_to_shares(
message: bytes, signature: BLSSignature, public_key: BLSPubkey, threshold: int, total: int
) -> tuple[list[BLSSignature], list[BLSPubkey]]:
"""
Given `message`, `signature` and `public_key` so that
`signature` for `message` can be verified with `public_key`.
The function splits `signature` and `public_key` to shares so that
each signature share can be verified with corresponding public key share.
The most straight forward way to do this is to use private key shares.
But this function does not require private key.
"""
message_g2 = hash_to_G2(message, G2ProofOfPossession.DST, G2ProofOfPossession.xmd_hash_function)

coefficients_int = [secrets.randbelow(curve_order) for _ in range(threshold - 1)]
coefficients_G1 = [multiply(P1, coef) for coef in coefficients_int]
coefficients_G2 = [multiply(message_g2, coef) for coef in coefficients_int]

bls_signature_shards = bls_signature_to_shares(signature, coefficients_G2, total)
public_key_shards = bls_public_key_to_shares(public_key, coefficients_G1, total)

return bls_signature_shards, public_key_shards


def reconstruct_shared_bls_signature(signatures: dict[int, BLSSignature]) -> BLSSignature:
"""
Reconstructs shared BLS private key signature.
Expand Down
23 changes: 14 additions & 9 deletions src/validators/signing/tests/oracle_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def verify_signature_shards(
validator_index: int,
fork: ConsensusFork,
exit_signature_shards: ExitSignatureShards,
exit_signature: BLSSignature | None = None,
):
# Decrypt the signature shards using the oracle private keys
exit_signatures_decrypted = []
exit_signature_shares_decrypted = []
for oracle_privkey, exit_signature_shard in zip(
self.oracle_privkeys, exit_signature_shards.exit_signatures
):
exit_signatures_decrypted.append(
exit_signature_shares_decrypted.append(
BLSSignature(
ecies.decrypt(oracle_privkey.secret, Web3.to_bytes(hexstr=exit_signature_shard))
)
Expand All @@ -52,23 +53,27 @@ def verify_signature_shards(
aggregate_key = get_aggregate_key(validator_pubkey_shares)
assert aggregate_key == validator_pubkey

# Verify the signature (shards)
# Verify the signature shards using public key shards
message = get_exit_message_signing_root(
validator_index=validator_index,
genesis_validators_root=settings.network_config.GENESIS_VALIDATORS_ROOT,
fork=fork,
)
for idx, (signature, validator_pubkey_share) in enumerate(
zip(exit_signatures_decrypted, exit_signature_shards.public_keys)
for idx, (signature_share, validator_pubkey_share) in enumerate(
zip(exit_signature_shares_decrypted, exit_signature_shards.public_keys)
):
pubkey = Web3.to_bytes(hexstr=validator_pubkey_share)
assert bls.Verify(pubkey, message, signature) is True
pubkey_share = Web3.to_bytes(hexstr=validator_pubkey_share)
assert bls.Verify(pubkey_share, message, signature_share) is True

# Verify the full reconstructed signature
signatures = dict(enumerate(exit_signatures_decrypted))
# Verify the full reconstructed signature using full public key
signatures = dict(enumerate(exit_signature_shares_decrypted))
random_indexes = random.sample(sorted(signatures), k=self.exit_signature_recover_threshold)
random_signature_subset = {k: v for k, v in signatures.items() if k in random_indexes}
reconstructed_full_signature = reconstruct_shared_bls_signature(random_signature_subset)
assert (
bls.Verify(aggregate_key, message, reconstructed_full_signature) is True
), 'Unable to reconstruct full signature'

# The case when we split signature created by third party service
if exit_signature is not None:
assert reconstructed_full_signature == exit_signature

0 comments on commit 2803caf

Please sign in to comment.