diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 2372e36c..37308dbf 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -385,6 +385,13 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable return keccak256(abi.encode(nodes)); } + function expectedTranscriptSize( + uint16 dkgSize, + uint16 threshold + ) public pure returns (uint256) { + return 40 + (dkgSize + 1) * BLS12381.G2_POINT_SIZE + threshold * BLS12381.G1_POINT_SIZE; + } + function postTranscript(uint32 ritualId, bytes calldata transcript) external { uint256 initialGasLeft = gasleft(); @@ -394,6 +401,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable "Not waiting for transcripts" ); + require( + transcript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold), + "Invalid transcript size" + ); + address provider = application.operatorToStakingProvider(msg.sender); Participant storage participant = getParticipant(ritual, provider); @@ -449,6 +461,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable "Invalid length for decryption request static key" ); + require( + aggregatedTranscript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold), + "Invalid transcript size" + ); + // nodes commit to their aggregation result bytes32 aggregatedTranscriptDigest = keccak256(aggregatedTranscript); participant.aggregated = true; @@ -618,7 +635,9 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable function processReimbursement(uint256 initialGasLeft) internal { if (address(reimbursementPool) != address(0)) { - uint256 gasUsed = initialGasLeft - gasleft(); + // For calldataGasCost calculation, see https://github.com/nucypher/nucypher-contracts/issues/328 + uint256 calldataGasCost = (msg.data.length - 128) * 16 + 128 * 4; + uint256 gasUsed = initialGasLeft - gasleft() + calldataGasCost; try reimbursementPool.refund(gasUsed, msg.sender) { return; } catch { diff --git a/deployment/artifacts/lynx.json b/deployment/artifacts/lynx.json index 70c197c1..0b5cc1d3 100644 --- a/deployment/artifacts/lynx.json +++ b/deployment/artifacts/lynx.json @@ -3964,6 +3964,30 @@ } ] }, + { + "type": "function", + "name": "expectedTranscriptSize", + "stateMutability": "pure", + "inputs": [ + { + "name": "dkgSize", + "type": "uint16", + "internalType": "uint16" + }, + { + "name": "threshold", + "type": "uint16", + "internalType": "uint16" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ] + }, { "type": "function", "name": "extendRitual", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 6b8bb8d4..734eae90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,47 @@ +import os import pytest -from ape import convert, project +from ape import project +from enum import IntEnum +# Common constants +G1_SIZE = 48 +G2_SIZE = 48 * 2 +ONE_DAY = 24 * 60 * 60 +RitualState = IntEnum( + "RitualState", + [ + "NON_INITIATED", + "DKG_AWAITING_TRANSCRIPTS", + "DKG_AWAITING_AGGREGATIONS", + "DKG_TIMEOUT", + "DKG_INVALID", + "ACTIVE", + "EXPIRED", + ], + start=0, +) + + +# Utility functions +def transcript_size(shares, threshold): + return 40 + (1 + shares) * G2_SIZE + threshold * G1_SIZE + + +def generate_transcript(shares, threshold): + return os.urandom(transcript_size(shares, threshold)) + + +def gen_public_key(): + return (os.urandom(32), os.urandom(32), os.urandom(32)) + + +def access_control_error_message(address, role=None): + role = role or b"\x00" * 32 + return f"account={address}, neededRole={role}" + + +# Fixtures @pytest.fixture(scope="session") def oz_dependency(): return project.dependencies["openzeppelin"]["5.0.0"] @@ -20,10 +60,3 @@ def account1(accounts): @pytest.fixture def account2(accounts): return accounts[2] - - -@pytest.fixture -def nu_token(NuCypherToken, creator): - TOTAL_SUPPLY = convert("1_000_000_000 ether", int) - nu_token = creator.deploy(NuCypherToken, TOTAL_SUPPLY) - return nu_token diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 58d38bf3..24f3711c 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -1,5 +1,4 @@ import os -from enum import IntEnum import ape import pytest @@ -8,41 +7,13 @@ from hexbytes import HexBytes from web3 import Web3 +from tests.conftest import ONE_DAY, gen_public_key, generate_transcript, RitualState + TIMEOUT = 1000 MAX_DKG_SIZE = 31 FEE_RATE = 42 ERC20_SUPPLY = 10**24 DURATION = 48 * 60 * 60 -ONE_DAY = 24 * 60 * 60 - -RitualState = IntEnum( - "RitualState", - [ - "NON_INITIATED", - "DKG_AWAITING_TRANSCRIPTS", - "DKG_AWAITING_AGGREGATIONS", - "DKG_TIMEOUT", - "DKG_INVALID", - "ACTIVE", - "EXPIRED", - ], - start=0, -) - - -# This formula returns an approximated size -# To have a representative size, create transcripts with `nucypher-core` -def transcript_size(shares, threshold): - return int(424 + 240 * (shares / 2) + 50 * (threshold)) - - -def gen_public_key(): - return (os.urandom(32), os.urandom(32), os.urandom(32)) - - -def access_control_error_message(address, role=None): - role = role or b"\x00" * 32 - return f"account={address}, neededRole={role}" @pytest.fixture(scope="module") @@ -275,7 +246,9 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, fee_model, global nodes=nodes, allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) for node in nodes: assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS @@ -309,7 +282,10 @@ def test_post_transcript_but_not_part_of_ritual( allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) + with ape.reverts("Participant not part of ritual"): coordinator.postTranscript(0, transcript, sender=initiator) @@ -325,12 +301,40 @@ def test_post_transcript_but_already_posted_transcript( nodes=nodes, allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) + coordinator.postTranscript(0, transcript, sender=nodes[0]) with ape.reverts("Node already posted transcript"): coordinator.postTranscript(0, transcript, sender=nodes[0]) +def test_post_transcript_but_wrong_size( + coordinator, nodes, initiator, erc20, fee_model, global_allow_list +): + initiate_ritual( + coordinator=coordinator, + fee_model=fee_model, + erc20=erc20, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list, + ) + + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + bad_transcript = generate_transcript(size, threshold + 1) + + with ape.reverts("Invalid transcript size"): + coordinator.postTranscript(0, bad_transcript, sender=nodes[0]) + + bad_transcript = b"" + with ape.reverts("Invalid transcript size"): + coordinator.postTranscript(0, bad_transcript, sender=nodes[0]) + + def test_post_transcript_but_not_waiting_for_transcripts( coordinator, nodes, initiator, erc20, fee_model, global_allow_list ): @@ -342,7 +346,11 @@ def test_post_transcript_but_not_waiting_for_transcripts( nodes=nodes, allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) + for node in nodes: coordinator.postTranscript(0, transcript, sender=node) @@ -359,7 +367,10 @@ def test_get_participants(coordinator, nodes, initiator, erc20, fee_model, globa nodes=nodes, allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) for node in nodes: _ = coordinator.postTranscript(0, transcript, sender=node) @@ -413,7 +424,10 @@ def test_get_participant(nodes, coordinator, initiator, erc20, fee_model, global nodes=nodes, allow_logic=global_allow_list, ) - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) for node in nodes: _ = coordinator.postTranscript(0, transcript, sender=node) @@ -462,8 +476,12 @@ def test_post_aggregation( nodes=nodes, allow_logic=global_allow_list, ) + ritualID = 0 - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) + for node in nodes: coordinator.postTranscript(ritualID, transcript, sender=node) @@ -520,8 +538,12 @@ def test_post_aggregation_fails( nodes=nodes, allow_logic=global_allow_list, ) + ritualID = 0 - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + size = len(nodes) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) + for node in nodes: coordinator.postTranscript(ritualID, transcript, sender=node) @@ -535,7 +557,7 @@ def test_post_aggregation_fails( ) # Second node screws up everything - bad_aggregated = os.urandom(transcript_size(len(nodes), len(nodes))) + bad_aggregated = generate_transcript(size, threshold) tx = coordinator.postAggregation( ritualID, bad_aggregated, dkg_public_key, decryption_request_static_keys[1], sender=nodes[1] ) diff --git a/tests/test_global_allow_list.py b/tests/test_global_allow_list.py index 1294fd3a..663b6675 100644 --- a/tests/test_global_allow_list.py +++ b/tests/test_global_allow_list.py @@ -6,41 +6,13 @@ from eth_account.messages import encode_defunct from web3 import Web3 +from tests.conftest import gen_public_key, generate_transcript + TIMEOUT = 1000 MAX_DKG_SIZE = 31 FEE_RATE = 42 ERC20_SUPPLY = 10**24 DURATION = 48 * 60 * 60 -ONE_DAY = 24 * 60 * 60 - -RitualState = IntEnum( - "RitualState", - [ - "NON_INITIATED", - "DKG_AWAITING_TRANSCRIPTS", - "DKG_AWAITING_AGGREGATIONS", - "DKG_TIMEOUT", - "DKG_INVALID", - "ACTIVE", - "EXPIRED", - ], - start=0, -) - - -# This formula returns an approximated size -# To have a representative size, create transcripts with `nucypher-core` -def transcript_size(shares, threshold): - return int(424 + 240 * (shares / 2) + 50 * (threshold)) - - -def gen_public_key(): - return (os.urandom(32), os.urandom(32), os.urandom(32)) - - -def access_control_error_message(address, role=None): - role = role or b"\x00" * 32 - return f"account={address}, neededRole={role}" @pytest.fixture(scope="module") @@ -153,6 +125,7 @@ def test_authorize_using_global_allow_list( signable_message = encode_defunct(digest) signed_digest = w3.eth.account.sign_message(signable_message, private_key=deployer.private_key) signature = signed_digest.signature + size = len(nodes) # Not authorized assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(digest)) @@ -168,7 +141,8 @@ def test_authorize_using_global_allow_list( coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(digest)) # Finalize ritual - transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + threshold = coordinator.getThresholdForRitualSize(size) + transcript = generate_transcript(size, threshold) for node in nodes: coordinator.postTranscript(0, transcript, sender=node)