From 48d57948b5c351e934e49ad61dc31c210a9f2d3d Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Fri, 30 Jun 2023 14:54:42 +0200 Subject: [PATCH] use mapping to store participant pk --- .../contracts/coordination/Coordinator.sol | 19 ++++++- tests/application/conftest.py | 2 +- tests/test_coordinator.py | 56 ++++++++++--------- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 5e0e4f70..6bd2bcae 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -65,6 +65,8 @@ contract Coordinator is AccessControlDefaultAdminRules { bytes32 public constant INITIATOR_ROLE = keccak256("INITIATOR_ROLE"); + mapping(address => bytes) public providerPublicKey; + IAccessControlApplication public immutable application; Ritual[] public rituals; @@ -124,6 +126,12 @@ contract Coordinator is AccessControlDefaultAdminRules { _setRoleAdmin(INITIATOR_ROLE, bytes32(0)); } + function setProviderPublicKey(bytes calldata publicKey) external { + // TODO: Verify public key length + require(publicKey.length == 48, "Invalid public key length"); + providerPublicKey[msg.sender] = publicKey; + } + function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) { emit TimeoutChanged(timeout, newTimeout); timeout = newTimeout; @@ -136,7 +144,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function setReimbursementPool(IReimbursementPool pool) external onlyRole(DEFAULT_ADMIN_ROLE) { require( - address(pool) == address(0) || + address(pool) == address(0) || pool.isAuthorized(address(this)), "Invalid ReimbursementPool" ); @@ -179,6 +187,11 @@ contract Coordinator is AccessControlDefaultAdminRules { for(uint256 i=0; i < length; i++){ Participant storage newParticipant = ritual.participant.push(); address current = providers[i]; + // Make sure that current provider has already set their public key + require( + providerPublicKey[current].length > 0, + "Provider has not set their public key" + ); require(previous < current, "Providers must be sorted"); // TODO: Improve check for eligible nodes (staking, etc) - nucypher#3109 // TODO: Change check to isAuthorized(), without amount @@ -191,7 +204,7 @@ contract Coordinator is AccessControlDefaultAdminRules { } processRitualPayment(id, providers, duration); - + // TODO: Include cohort fingerprint in StartRitual event? emit StartRitual(id, ritual.authority, providers); return id; @@ -369,7 +382,7 @@ contract Coordinator is AccessControlDefaultAdminRules { currency.transferFrom(address(this), ritual.initiator, refundableFee); } } - + function processReimbursement(uint256 initialGasLeft) internal { if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method uint256 gasUsed = initialGasLeft - gasleft(); diff --git a/tests/application/conftest.py b/tests/application/conftest.py index 282841b2..d1f456c5 100644 --- a/tests/application/conftest.py +++ b/tests/application/conftest.py @@ -35,7 +35,7 @@ REWARD_DURATION = 60 * 60 * 24 * 7 # one week in seconds DEAUTHORIZATION_DURATION = 60 * 60 * 24 * 60 # 60 days in seconds -DEPENDENCY = project.dependencies["openzeppelin"]["4.8.1"] +DEPENDENCY = project.dependencies["openzeppelin"]["4.9.1"] @pytest.fixture() diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 78e7c82e..c69b9a2c 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -60,10 +60,7 @@ def erc20(project, initiator): @pytest.fixture() def flat_rate_fee_model(project, deployer, stake_info, erc20): contract = project.FlatRateFeeModel.deploy( - erc20.address, - FEE_RATE, - stake_info.address, - sender=deployer + erc20.address, FEE_RATE, stake_info.address, sender=deployer ) return contract @@ -77,7 +74,7 @@ def coordinator(project, deployer, stake_info, flat_rate_fee_model, initiator): MAX_DKG_SIZE, admin, flat_rate_fee_model.address, - sender=deployer + sender=deployer, ) contract.grantRole(contract.INITIATOR_ROLE(), initiator, sender=admin) return contract @@ -100,6 +97,11 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): with ape.reverts("Invalid ritual duration"): coordinator.initiateRitual(nodes, initiator, 0, sender=initiator) + with ape.reverts("Provider has not set their public key"): + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + + for node in nodes: + coordinator.setProviderPublicKey(os.urandom(48), sender=node) with ape.reverts("Providers must be sorted"): coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) @@ -108,11 +110,17 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) -def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): +def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes): + for node in nodes: + coordinator.setProviderPublicKey(os.urandom(48), sender=node) cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) erc20.approve(coordinator.address, cost, sender=initiator) - authority = initiator - tx = coordinator.initiateRitual(nodes, authority, DURATION, sender=initiator) + tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + return initiator, tx + + +def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + authority, tx = initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) events = list(coordinator.StartRitual.from_receipt(tx)) assert len(events) == 1 @@ -125,9 +133,7 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) for node in nodes: assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS @@ -150,27 +156,27 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_mod assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS -def test_post_transcript_but_not_part_of_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) +def test_post_transcript_but_not_part_of_ritual( + coordinator, nodes, initiator, erc20, flat_rate_fee_model +): + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) with ape.reverts("Participant not part of ritual"): coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=initiator) -def test_post_transcript_but_already_posted_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) +def test_post_transcript_but_already_posted_transcript( + coordinator, nodes, initiator, erc20, flat_rate_fee_model +): + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0]) with ape.reverts("Node already posted transcript"): coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0]) -def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) +def test_post_transcript_but_not_waiting_for_transcripts( + coordinator, nodes, initiator, erc20, flat_rate_fee_model +): + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) for node in nodes: transcript = os.urandom(TRANSCRIPT_SIZE) coordinator.postTranscript(0, transcript, sender=node) @@ -180,9 +186,7 @@ def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, ini def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(TRANSCRIPT_SIZE) for node in nodes: coordinator.postTranscript(0, transcript, sender=node)