From 48d57948b5c351e934e49ad61dc31c210a9f2d3d Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Fri, 30 Jun 2023 14:54:42 +0200
Subject: [PATCH] use mapping to store participant pk
---
.../contracts/coordination/Coordinator.sol | 19 ++++++-
tests/application/conftest.py | 2 +-
tests/test_coordinator.py | 56 ++++++++++---------
3 files changed, 47 insertions(+), 30 deletions(-)
diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol
index 5e0e4f70..6bd2bcae 100644
--- a/contracts/contracts/coordination/Coordinator.sol
+++ b/contracts/contracts/coordination/Coordinator.sol
@@ -65,6 +65,8 @@ contract Coordinator is AccessControlDefaultAdminRules {
bytes32 public constant INITIATOR_ROLE = keccak256("INITIATOR_ROLE");
+ mapping(address => bytes) public providerPublicKey;
+
IAccessControlApplication public immutable application;
Ritual[] public rituals;
@@ -124,6 +126,12 @@ contract Coordinator is AccessControlDefaultAdminRules {
_setRoleAdmin(INITIATOR_ROLE, bytes32(0));
}
+ function setProviderPublicKey(bytes calldata publicKey) external {
+ // TODO: Verify public key length
+ require(publicKey.length == 48, "Invalid public key length");
+ providerPublicKey[msg.sender] = publicKey;
+ }
+
function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) {
emit TimeoutChanged(timeout, newTimeout);
timeout = newTimeout;
@@ -136,7 +144,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
function setReimbursementPool(IReimbursementPool pool) external onlyRole(DEFAULT_ADMIN_ROLE) {
require(
- address(pool) == address(0) ||
+ address(pool) == address(0) ||
pool.isAuthorized(address(this)),
"Invalid ReimbursementPool"
);
@@ -179,6 +187,11 @@ contract Coordinator is AccessControlDefaultAdminRules {
for(uint256 i=0; i < length; i++){
Participant storage newParticipant = ritual.participant.push();
address current = providers[i];
+ // Make sure that current provider has already set their public key
+ require(
+ providerPublicKey[current].length > 0,
+ "Provider has not set their public key"
+ );
require(previous < current, "Providers must be sorted");
// TODO: Improve check for eligible nodes (staking, etc) - nucypher#3109
// TODO: Change check to isAuthorized(), without amount
@@ -191,7 +204,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
}
processRitualPayment(id, providers, duration);
-
+
// TODO: Include cohort fingerprint in StartRitual event?
emit StartRitual(id, ritual.authority, providers);
return id;
@@ -369,7 +382,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
currency.transferFrom(address(this), ritual.initiator, refundableFee);
}
}
-
+
function processReimbursement(uint256 initialGasLeft) internal {
if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method
uint256 gasUsed = initialGasLeft - gasleft();
diff --git a/tests/application/conftest.py b/tests/application/conftest.py
index 282841b2..d1f456c5 100644
--- a/tests/application/conftest.py
+++ b/tests/application/conftest.py
@@ -35,7 +35,7 @@
REWARD_DURATION = 60 * 60 * 24 * 7 # one week in seconds
DEAUTHORIZATION_DURATION = 60 * 60 * 24 * 60 # 60 days in seconds
-DEPENDENCY = project.dependencies["openzeppelin"]["4.8.1"]
+DEPENDENCY = project.dependencies["openzeppelin"]["4.9.1"]
@pytest.fixture()
diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py
index 78e7c82e..c69b9a2c 100644
--- a/tests/test_coordinator.py
+++ b/tests/test_coordinator.py
@@ -60,10 +60,7 @@ def erc20(project, initiator):
@pytest.fixture()
def flat_rate_fee_model(project, deployer, stake_info, erc20):
contract = project.FlatRateFeeModel.deploy(
- erc20.address,
- FEE_RATE,
- stake_info.address,
- sender=deployer
+ erc20.address, FEE_RATE, stake_info.address, sender=deployer
)
return contract
@@ -77,7 +74,7 @@ def coordinator(project, deployer, stake_info, flat_rate_fee_model, initiator):
MAX_DKG_SIZE,
admin,
flat_rate_fee_model.address,
- sender=deployer
+ sender=deployer,
)
contract.grantRole(contract.INITIATOR_ROLE(), initiator, sender=admin)
return contract
@@ -100,6 +97,11 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
with ape.reverts("Invalid ritual duration"):
coordinator.initiateRitual(nodes, initiator, 0, sender=initiator)
+ with ape.reverts("Provider has not set their public key"):
+ coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+
+ for node in nodes:
+ coordinator.setProviderPublicKey(os.urandom(48), sender=node)
with ape.reverts("Providers must be sorted"):
coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator)
@@ -108,11 +110,17 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
-def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
+def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes):
+ for node in nodes:
+ coordinator.setProviderPublicKey(os.urandom(48), sender=node)
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
- authority = initiator
- tx = coordinator.initiateRitual(nodes, authority, DURATION, sender=initiator)
+ tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+ return initiator, tx
+
+
+def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
+ authority, tx = initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
events = list(coordinator.StartRitual.from_receipt(tx))
assert len(events) == 1
@@ -125,9 +133,7 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod
def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
- cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
- erc20.approve(coordinator.address, cost, sender=initiator)
- coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+ initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
for node in nodes:
assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS
@@ -150,27 +156,27 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_mod
assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS
-def test_post_transcript_but_not_part_of_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
- cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
- erc20.approve(coordinator.address, cost, sender=initiator)
- coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+def test_post_transcript_but_not_part_of_ritual(
+ coordinator, nodes, initiator, erc20, flat_rate_fee_model
+):
+ initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
with ape.reverts("Participant not part of ritual"):
coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=initiator)
-def test_post_transcript_but_already_posted_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
- cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
- erc20.approve(coordinator.address, cost, sender=initiator)
- coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+def test_post_transcript_but_already_posted_transcript(
+ coordinator, nodes, initiator, erc20, flat_rate_fee_model
+):
+ initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0])
with ape.reverts("Node already posted transcript"):
coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0])
-def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
- cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
- erc20.approve(coordinator.address, cost, sender=initiator)
- coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+def test_post_transcript_but_not_waiting_for_transcripts(
+ coordinator, nodes, initiator, erc20, flat_rate_fee_model
+):
+ initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
for node in nodes:
transcript = os.urandom(TRANSCRIPT_SIZE)
coordinator.postTranscript(0, transcript, sender=node)
@@ -180,9 +186,7 @@ def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, ini
def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
- cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
- erc20.approve(coordinator.address, cost, sender=initiator)
- coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
+ initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(TRANSCRIPT_SIZE)
for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)