diff --git a/src/validators/execution.py b/src/validators/execution.py index 31721373..2c33e806 100644 --- a/src/validators/execution.py +++ b/src/validators/execution.py @@ -3,13 +3,8 @@ from typing import Set from eth_typing import BlockNumber, HexStr -from multiproof import StandardMerkleTree -from sw_utils import ( - EventProcessor, - compute_deposit_data, - get_eth1_withdrawal_credentials, - is_valid_deposit_data_signature, -) +from multiproof.standart import MultiProof +from sw_utils import EventProcessor, is_valid_deposit_data_signature from sw_utils.typings import Bytes32 from web3 import Web3 from web3.types import EventData, Wei @@ -23,7 +18,7 @@ from src.common.ipfs import fetch_harvest_params from src.common.metrics import metrics from src.config.networks import ETH_NETWORKS -from src.config.settings import DEPOSIT_AMOUNT, DEPOSIT_AMOUNT_GWEI, settings +from src.config.settings import DEPOSIT_AMOUNT, settings from src.validators.database import NetworkValidatorCrud from src.validators.typings import ( DepositData, @@ -222,9 +217,9 @@ async def update_unused_validator_keys_metric( async def register_single_validator( - tree: StandardMerkleTree, - validator: Validator, approval: OraclesApproval, + multi_proof: MultiProof, + tx_validators: list[bytes], update_state_call: HexStr | None, validators_registry_root: Bytes32, ) -> None: @@ -232,19 +227,15 @@ async def register_single_validator( if settings.network not in ETH_NETWORKS: raise NotImplementedError('networks other than Ethereum not supported') - credentials = get_eth1_withdrawal_credentials(settings.vault) - tx_validator = _encode_tx_validator(credentials, validator) - proof = tree.get_proof([tx_validator, validator.deposit_data_index]) - logger.info('Submitting registration transaction') register_call_args = [ ( validators_registry_root, - tx_validator, + tx_validators[0], approval.signatures, approval.ipfs_hash, ), - proof, + multi_proof.proof, ] if update_state_call is not None: register_call = vault_contract.encodeABI( @@ -259,10 +250,9 @@ async def register_single_validator( await execution_client.eth.wait_for_transaction_receipt(tx, timeout=300) -# pylint: disable-next=too-many-locals async def register_multiple_validator( - tree: StandardMerkleTree, - validators: list[Validator], + multi_proof: MultiProof, + tx_validators: list[bytes], approval: OraclesApproval, update_state_call: HexStr | None, validators_registry_root: Bytes32, @@ -271,18 +261,8 @@ async def register_multiple_validator( if settings.network not in ETH_NETWORKS: raise NotImplementedError('networks other than Ethereum not supported') - credentials = get_eth1_withdrawal_credentials(settings.vault) - tx_validators: list[bytes] = [] - leaves: list[tuple[bytes, int]] = [] - for validator in validators: - tx_validator = _encode_tx_validator(credentials, validator) - tx_validators.append(tx_validator) - leaves.append((tx_validator, validator.deposit_data_index)) - - multi_proof = tree.get_multi_proof(leaves) sorted_tx_validators: list[bytes] = [v[0] for v in multi_proof.leaves] indexes = [sorted_tx_validators.index(v) for v in tx_validators] - logger.info('Submitting registration transaction') register_call_args = [ ( @@ -306,15 +286,3 @@ async def register_multiple_validator( logger.info('Waiting for transaction %s confirmation', Web3.to_hex(tx)) await execution_client.eth.wait_for_transaction_receipt(tx, timeout=300) - - -def _encode_tx_validator(withdrawal_credentials: bytes, validator: Validator) -> bytes: - public_key = Web3.to_bytes(hexstr=validator.public_key) - signature = Web3.to_bytes(hexstr=validator.signature) - deposit_root = compute_deposit_data( - public_key=public_key, - withdrawal_credentials=withdrawal_credentials, - amount_gwei=DEPOSIT_AMOUNT_GWEI, - signature=signature, - ).hash_tree_root - return public_key + signature + deposit_root diff --git a/src/validators/signing.py b/src/validators/signing.py index e204d5aa..24d41a0e 100644 --- a/src/validators/signing.py +++ b/src/validators/signing.py @@ -2,14 +2,17 @@ import milagro_bls_binding as bls from Cryptodome.Random.random import randint from eth_typing import HexStr +from multiproof import StandardMerkleTree +from multiproof.standart import MultiProof from py_ecc.optimized_bls12_381.optimized_curve import curve_order -from sw_utils.signing import get_exit_message_signing_root +from sw_utils import get_eth1_withdrawal_credentials +from sw_utils.signing import compute_deposit_data, get_exit_message_signing_root from sw_utils.typings import ConsensusFork from web3 import Web3 from src.common.typings import Oracles -from src.config.settings import settings -from src.validators.typings import BLSPrivkey, ExitSignatureShards +from src.config.settings import DEPOSIT_AMOUNT_GWEI, settings +from src.validators.typings import BLSPrivkey, ExitSignatureShards, Validator def get_polynomial_points(coefficients: list[int], num_points: int) -> list[bytes]: @@ -60,3 +63,31 @@ def get_exit_signature_shards( public_keys=[Web3.to_hex(bls.SkToPk(priv_key)) for priv_key in private_keys], exit_signatures=exit_signature_shards, ) + + +def get_validators_proof( + tree: StandardMerkleTree, + validators: list[Validator], +) -> tuple[list[bytes], MultiProof]: + credentials = get_eth1_withdrawal_credentials(settings.vault) + tx_validators: list[bytes] = [] + leaves: list[tuple[bytes, int]] = [] + for validator in validators: + tx_validator = encode_tx_validator(credentials, validator) + tx_validators.append(tx_validator) + leaves.append((tx_validator, validator.deposit_data_index)) + + multi_proof = tree.get_multi_proof(leaves) + return tx_validators, multi_proof + + +def encode_tx_validator(withdrawal_credentials: bytes, validator: Validator) -> bytes: + public_key = Web3.to_bytes(hexstr=validator.public_key) + signature = Web3.to_bytes(hexstr=validator.signature) + deposit_root = compute_deposit_data( + public_key=public_key, + withdrawal_credentials=withdrawal_credentials, + amount_gwei=DEPOSIT_AMOUNT_GWEI, + signature=signature, + ).hash_tree_root + return public_key + signature + deposit_root diff --git a/src/validators/tasks.py b/src/validators/tasks.py index ec94cf1f..d16bb7d3 100644 --- a/src/validators/tasks.py +++ b/src/validators/tasks.py @@ -1,5 +1,6 @@ import logging +from multiproof.standart import MultiProof from sw_utils.typings import Bytes32 from web3 import Web3 from web3.types import BlockNumber, Wei @@ -20,7 +21,7 @@ register_multiple_validator, register_single_validator, ) -from src.validators.signing import get_exit_signature_shards +from src.validators.signing import get_exit_signature_shards, get_validators_proof from src.validators.typings import ( ApprovalRequest, DepositData, @@ -34,6 +35,7 @@ logger = logging.getLogger(__name__) +# pylint: disable-next=too-many-locals async def register_validators(keystores: Keystores, deposit_data: DepositData) -> None: """Registers vault validators.""" vault_balance, update_state_call = await get_withdrawable_assets() @@ -71,8 +73,11 @@ async def register_validators(keystores: Keystores, deposit_data: DepositData) - ) return + tx_validators, multi_proof = get_validators_proof( + tree=deposit_data.tree, + validators=validators, + ) registry_root = None - while True: latest_registry_root = await validators_registry_contract.get_registry_root() @@ -85,6 +90,7 @@ async def register_validators(keystores: Keystores, deposit_data: DepositData) - oracles=oracles, keystores=keystores, validators=validators, + multi_proof=multi_proof, ) try: @@ -96,9 +102,9 @@ async def register_validators(keystores: Keystores, deposit_data: DepositData) - if len(validators) == 1: validator = validators[0] await register_single_validator( - tree=deposit_data.tree, - validator=validator, approval=oracles_approval, + multi_proof=multi_proof, + tx_validators=tx_validators, update_state_call=update_state_call, validators_registry_root=registry_root, ) @@ -106,9 +112,9 @@ async def register_validators(keystores: Keystores, deposit_data: DepositData) - if len(validators) > 1: await register_multiple_validator( - tree=deposit_data.tree, - validators=validators, approval=oracles_approval, + multi_proof=multi_proof, + tx_validators=tx_validators, update_state_call=update_state_call, validators_registry_root=registry_root, ) @@ -117,7 +123,11 @@ async def register_validators(keystores: Keystores, deposit_data: DepositData) - async def create_approval_request( - oracles: Oracles, keystores: Keystores, validators: list[Validator], registry_root: Bytes32 + oracles: Oracles, + keystores: Keystores, + validators: list[Validator], + registry_root: Bytes32, + multi_proof: MultiProof, ) -> ApprovalRequest: """Generate validator registration request data""" @@ -139,6 +149,8 @@ async def create_approval_request( deposit_signatures=[], public_key_shards=[], exit_signature_shards=[], + proof=multi_proof.proof, + proof_flags=multi_proof.proof_flags, ) for validator in validators: shards = get_exit_signature_shards( diff --git a/src/validators/typings.py b/src/validators/typings.py index 283a9877..08254b8a 100644 --- a/src/validators/typings.py +++ b/src/validators/typings.py @@ -50,6 +50,8 @@ class ApprovalRequest: deposit_signatures: list[HexStr] public_key_shards: list[list[HexStr]] exit_signature_shards: list[list[HexStr]] + proof: list[HexStr] + proof_flags: list[bool] @dataclass diff --git a/src/validators/utils.py b/src/validators/utils.py index 1ea3c1fb..6d16b23c 100644 --- a/src/validators/utils.py +++ b/src/validators/utils.py @@ -27,10 +27,8 @@ RegistryRootChangedError, ValidatorIndexChangedError, ) -from src.validators.execution import ( - _encode_tx_validator, - get_latest_network_validator_public_keys, -) +from src.validators.execution import get_latest_network_validator_public_keys +from src.validators.signing import encode_tx_validator from src.validators.typings import ( ApprovalRequest, BLSPrivkey, @@ -202,7 +200,7 @@ def load_deposit_data(vault: HexAddress, deposit_data_file: Path) -> DepositData public_key=add_0x_prefix(data['pubkey']), signature=add_0x_prefix(data['signature']), ) - leaves.append((_encode_tx_validator(credentials, validator), i)) + leaves.append((encode_tx_validator(credentials, validator), i)) validators.append(validator) tree = StandardMerkleTree.of(leaves, ['bytes', 'uint256'])