Skip to content

Commit

Permalink
Merge pull request #334 from cygnusv/tsize
Browse files Browse the repository at this point in the history
Make Coordinator aware of transcript size
  • Loading branch information
cygnusv committed Sep 18, 2024
2 parents d931844 + 6231be4 commit a21ac6f
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 80 deletions.
21 changes: 20 additions & 1 deletion contracts/contracts/coordination/Coordinator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
return keccak256(abi.encode(nodes));
}

function expectedTranscriptSize(
uint16 dkgSize,
uint16 threshold
) public pure returns (uint256) {
return 40 + (dkgSize + 1) * BLS12381.G2_POINT_SIZE + threshold * BLS12381.G1_POINT_SIZE;
}

function postTranscript(uint32 ritualId, bytes calldata transcript) external {
uint256 initialGasLeft = gasleft();

Expand All @@ -394,6 +401,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
"Not waiting for transcripts"
);

require(
transcript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold),
"Invalid transcript size"
);

address provider = application.operatorToStakingProvider(msg.sender);
Participant storage participant = getParticipant(ritual, provider);

Expand Down Expand Up @@ -449,6 +461,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
"Invalid length for decryption request static key"
);

require(
aggregatedTranscript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold),
"Invalid transcript size"
);

// nodes commit to their aggregation result
bytes32 aggregatedTranscriptDigest = keccak256(aggregatedTranscript);
participant.aggregated = true;
Expand Down Expand Up @@ -618,7 +635,9 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable

function processReimbursement(uint256 initialGasLeft) internal {
if (address(reimbursementPool) != address(0)) {
uint256 gasUsed = initialGasLeft - gasleft();
// For calldataGasCost calculation, see https://github.com/nucypher/nucypher-contracts/issues/328
uint256 calldataGasCost = (msg.data.length - 128) * 16 + 128 * 4;
uint256 gasUsed = initialGasLeft - gasleft() + calldataGasCost;
try reimbursementPool.refund(gasUsed, msg.sender) {
return;
} catch {
Expand Down
24 changes: 24 additions & 0 deletions deployment/artifacts/lynx.json
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,30 @@
}
]
},
{
"type": "function",
"name": "expectedTranscriptSize",
"stateMutability": "pure",
"inputs": [
{
"name": "dkgSize",
"type": "uint16",
"internalType": "uint16"
},
{
"name": "threshold",
"type": "uint16",
"internalType": "uint16"
}
],
"outputs": [
{
"name": "",
"type": "uint256",
"internalType": "uint256"
}
]
},
{
"type": "function",
"name": "extendRitual",
Expand Down
Empty file added tests/__init__.py
Empty file.
49 changes: 41 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,47 @@
import os
import pytest
from ape import convert, project
from ape import project
from enum import IntEnum

# Common constants
G1_SIZE = 48
G2_SIZE = 48 * 2
ONE_DAY = 24 * 60 * 60

RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# Utility functions
def transcript_size(shares, threshold):
return 40 + (1 + shares) * G2_SIZE + threshold * G1_SIZE


def generate_transcript(shares, threshold):
return os.urandom(transcript_size(shares, threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


# Fixtures
@pytest.fixture(scope="session")
def oz_dependency():
return project.dependencies["openzeppelin"]["5.0.0"]
Expand All @@ -20,10 +60,3 @@ def account1(accounts):
@pytest.fixture
def account2(accounts):
return accounts[2]


@pytest.fixture
def nu_token(NuCypherToken, creator):
TOTAL_SUPPLY = convert("1_000_000_000 ether", int)
nu_token = creator.deploy(NuCypherToken, TOTAL_SUPPLY)
return nu_token
102 changes: 62 additions & 40 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from enum import IntEnum

import ape
import pytest
Expand All @@ -8,41 +7,13 @@
from hexbytes import HexBytes
from web3 import Web3

from tests.conftest import ONE_DAY, gen_public_key, generate_transcript, RitualState

TIMEOUT = 1000
MAX_DKG_SIZE = 31
FEE_RATE = 42
ERC20_SUPPLY = 10**24
DURATION = 48 * 60 * 60
ONE_DAY = 24 * 60 * 60

RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# This formula returns an approximated size
# To have a representative size, create transcripts with `nucypher-core`
def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -275,7 +246,9 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, fee_model, global
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS
Expand Down Expand Up @@ -309,7 +282,10 @@ def test_post_transcript_but_not_part_of_ritual(
allow_logic=global_allow_list,
)

transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

with ape.reverts("Participant not part of ritual"):
coordinator.postTranscript(0, transcript, sender=initiator)

Expand All @@ -325,12 +301,40 @@ def test_post_transcript_but_already_posted_transcript(
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

coordinator.postTranscript(0, transcript, sender=nodes[0])
with ape.reverts("Node already posted transcript"):
coordinator.postTranscript(0, transcript, sender=nodes[0])


def test_post_transcript_but_wrong_size(
coordinator, nodes, initiator, erc20, fee_model, global_allow_list
):
initiate_ritual(
coordinator=coordinator,
fee_model=fee_model,
erc20=erc20,
authority=initiator,
nodes=nodes,
allow_logic=global_allow_list,
)

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
bad_transcript = generate_transcript(size, threshold + 1)

with ape.reverts("Invalid transcript size"):
coordinator.postTranscript(0, bad_transcript, sender=nodes[0])

bad_transcript = b""
with ape.reverts("Invalid transcript size"):
coordinator.postTranscript(0, bad_transcript, sender=nodes[0])


def test_post_transcript_but_not_waiting_for_transcripts(
coordinator, nodes, initiator, erc20, fee_model, global_allow_list
):
Expand All @@ -342,7 +346,11 @@ def test_post_transcript_but_not_waiting_for_transcripts(
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)

Expand All @@ -359,7 +367,10 @@ def test_get_participants(coordinator, nodes, initiator, erc20, fee_model, globa
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
_ = coordinator.postTranscript(0, transcript, sender=node)
Expand Down Expand Up @@ -413,7 +424,10 @@ def test_get_participant(nodes, coordinator, initiator, erc20, fee_model, global
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
_ = coordinator.postTranscript(0, transcript, sender=node)
Expand Down Expand Up @@ -462,8 +476,12 @@ def test_post_aggregation(
nodes=nodes,
allow_logic=global_allow_list,
)

ritualID = 0
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(ritualID, transcript, sender=node)

Expand Down Expand Up @@ -520,8 +538,12 @@ def test_post_aggregation_fails(
nodes=nodes,
allow_logic=global_allow_list,
)

ritualID = 0
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(ritualID, transcript, sender=node)

Expand All @@ -535,7 +557,7 @@ def test_post_aggregation_fails(
)

# Second node screws up everything
bad_aggregated = os.urandom(transcript_size(len(nodes), len(nodes)))
bad_aggregated = generate_transcript(size, threshold)
tx = coordinator.postAggregation(
ritualID, bad_aggregated, dkg_public_key, decryption_request_static_keys[1], sender=nodes[1]
)
Expand Down
36 changes: 5 additions & 31 deletions tests/test_global_allow_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,13 @@
from eth_account.messages import encode_defunct
from web3 import Web3

from tests.conftest import gen_public_key, generate_transcript

TIMEOUT = 1000
MAX_DKG_SIZE = 31
FEE_RATE = 42
ERC20_SUPPLY = 10**24
DURATION = 48 * 60 * 60
ONE_DAY = 24 * 60 * 60

RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# This formula returns an approximated size
# To have a representative size, create transcripts with `nucypher-core`
def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -153,6 +125,7 @@ def test_authorize_using_global_allow_list(
signable_message = encode_defunct(digest)
signed_digest = w3.eth.account.sign_message(signable_message, private_key=deployer.private_key)
signature = signed_digest.signature
size = len(nodes)

# Not authorized
assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(digest))
Expand All @@ -168,7 +141,8 @@ def test_authorize_using_global_allow_list(
coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(digest))

# Finalize ritual
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)
for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)

Expand Down

0 comments on commit a21ac6f

Please sign in to comment.