Skip to content

Commit

Permalink
Coordinator: tests for upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
vzotova committed Aug 15, 2024
1 parent 5d78a8f commit 2ee2d92
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 13 deletions.
2 changes: 1 addition & 1 deletion contracts/contracts/coordination/Coordinator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
ITACoChildApplication public immutable application;
uint96 private immutable minAuthorization; // TODO use child app for checking eligibility

Ritual[] private ritualsStub; // former rituals
Ritual[] internal ritualsStub; // former rituals, "internal" for testing only
uint32 public timeout;
uint16 public maxDkgSize;
bool private stub1; // former isInitiationPublic
Expand Down
35 changes: 35 additions & 0 deletions contracts/test/CoordinatorTestSet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pragma solidity ^0.8.0;

import "../threshold/ITACoChildApplication.sol";
import "../contracts/coordination/Coordinator.sol";

/**
* @notice Contract for testing Coordinator contract
Expand Down Expand Up @@ -33,3 +34,37 @@ contract ChildApplicationForCoordinatorMock is ITACoChildApplication {
// solhint-disable-next-line no-empty-blocks
function penalize(address _stakingProvider) external {}
}

contract ExtendedCoordinator is Coordinator {
constructor(ITACoChildApplication _application) Coordinator(_application) {}

function initiateOldRitual(
IFeeModel feeModel,
address[] calldata providers,
address authority,
uint32 duration,
IEncryptionAuthorizer accessController
) external returns (uint32) {
uint16 length = uint16(providers.length);

uint32 id = uint32(ritualsStub.length);
Ritual storage ritual = ritualsStub.push();
ritual.initiator = msg.sender;
ritual.authority = authority;
ritual.dkgSize = length;
ritual.threshold = getThresholdForRitualSize(length);
ritual.initTimestamp = uint32(block.timestamp);
ritual.endTimestamp = ritual.initTimestamp + duration;
ritual.accessController = accessController;
ritual.feeModel = feeModel;

address previous = address(0);
for (uint256 i = 0; i < length; i++) {
Participant storage newParticipant = ritual.participant.push();
address current = providers[i];
newParticipant.provider = current;
previous = current;
}
return id;
}
}
110 changes: 98 additions & 12 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import ape
import pytest
from ape.utils import ZERO_ADDRESS
from eth_account import Account
from hexbytes import HexBytes
from web3 import Web3
Expand Down Expand Up @@ -86,9 +87,9 @@ def erc20(project, initiator):


@pytest.fixture()
def coordinator(project, deployer, application, initiator, oz_dependency):
def coordinator(project, deployer, application, oz_dependency):
admin = deployer
contract = project.Coordinator.deploy(
contract = project.ExtendedCoordinator.deploy(
application.address,
sender=deployer,
)
Expand All @@ -100,7 +101,7 @@ def coordinator(project, deployer, application, initiator, oz_dependency):
encoded_initializer_function,
sender=deployer,
)
proxy_contract = project.Coordinator.at(proxy.address)
proxy_contract = project.ExtendedCoordinator.at(proxy.address)
return proxy_contract


Expand Down Expand Up @@ -219,17 +220,20 @@ def test_initiate_ritual(

ritual_struct = coordinator.rituals(ritualID)
assert ritual_struct[0] == initiator
init, end = ritual_struct[1], ritual_struct[2]
init, end = ritual_struct[1], ritual_struct["endTimestamp"]
assert end - init == DURATION
total_transcripts, total_aggregations = ritual_struct[3], ritual_struct[4]
total_transcripts, total_aggregations = (
ritual_struct["totalTranscripts"],
ritual_struct["totalAggregations"],
)
assert total_transcripts == total_aggregations == 0
assert ritual_struct[5] == authority
assert ritual_struct[6] == len(nodes)
assert ritual_struct[7] == 1 + len(nodes) // 2 # threshold
assert not ritual_struct[8] # aggregationMismatch
assert ritual_struct[9] == global_allow_list.address # accessController
assert ritual_struct[10] == (b"\x00" * 32, b"\x00" * 16) # publicKey
assert not ritual_struct[11] # aggregatedTranscript
assert ritual_struct["authority"] == authority
assert ritual_struct["dkgSize"] == len(nodes)
assert ritual_struct["threshold"] == 1 + len(nodes) // 2 # threshold
assert not ritual_struct["aggregationMismatch"] # aggregationMismatch
assert ritual_struct["accessController"] == global_allow_list.address # accessController
assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey
assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript

fee = fee_model.getRitualCost(len(nodes), DURATION)
assert erc20.balanceOf(fee_model) == fee
Expand Down Expand Up @@ -564,3 +568,85 @@ def test_post_aggregation_fails(
assert fee_model.totalPendingFees() == 0
assert fee_model.pendingFees(ritualID) == 0
fee_model.withdrawTokens(fee_model_balance_after_refund, sender=deployer)


def test_upgrade(
coordinator, nodes, initiator, erc20, fee_model, treasury, deployer, global_allow_list
):
coordinator.initiateOldRitual(
fee_model, nodes, initiator, DURATION, global_allow_list.address, sender=initiator
)
coordinator.initiateOldRitual(
ZERO_ADDRESS, [nodes[0]], treasury, DURATION // 2, deployer, sender=initiator
)
assert coordinator.numberOfRituals() == 0
coordinator.initializeNumberOfRituals(sender=deployer)
assert coordinator.numberOfRituals() == 2

initiate_ritual(
coordinator=coordinator,
fee_model=fee_model,
erc20=erc20,
authority=initiator,
nodes=nodes,
allow_logic=global_allow_list,
)
assert coordinator.numberOfRituals() == 3

assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS
assert coordinator.getRitualState(1) == RitualState.DKG_AWAITING_TRANSCRIPTS
assert coordinator.getRitualState(2) == RitualState.DKG_AWAITING_TRANSCRIPTS

ritual_struct = coordinator.rituals(0)
assert ritual_struct["initiator"] == initiator
init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"]
assert end - init == DURATION
total_transcripts, total_aggregations = (
ritual_struct["totalTranscripts"],
ritual_struct["totalAggregations"],
)
assert total_transcripts == total_aggregations == 0
assert ritual_struct["authority"] == initiator
assert ritual_struct["dkgSize"] == len(nodes)
assert ritual_struct["threshold"] == 1 + len(nodes) // 2
assert not ritual_struct["aggregationMismatch"]
assert ritual_struct["accessController"] == global_allow_list.address
assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16)
assert not ritual_struct["aggregatedTranscript"]
assert ritual_struct["feeModel"] == fee_model.address

ritual_struct = coordinator.rituals(1)
assert ritual_struct["initiator"] == initiator
init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"]
assert end - init == DURATION // 2
total_transcripts, total_aggregations = (
ritual_struct["totalTranscripts"],
ritual_struct["totalAggregations"],
)
assert total_transcripts == total_aggregations == 0
assert ritual_struct["authority"] == treasury
assert ritual_struct["dkgSize"] == 1
assert ritual_struct["threshold"] == 1 # threshold
assert not ritual_struct["aggregationMismatch"] # aggregationMismatch
assert ritual_struct["accessController"] == deployer # accessController
assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey
assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript
assert ritual_struct["feeModel"] == ZERO_ADDRESS # feeModel

ritual_struct = coordinator.rituals(2)
assert ritual_struct["initiator"] == initiator
init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"]
assert end - init == DURATION
total_transcripts, total_aggregations = (
ritual_struct["totalTranscripts"],
ritual_struct["totalAggregations"],
)
assert total_transcripts == total_aggregations == 0
assert ritual_struct["authority"] == initiator
assert ritual_struct["dkgSize"] == len(nodes)
assert ritual_struct["threshold"] == 1 + len(nodes) // 2 # threshold
assert not ritual_struct["aggregationMismatch"] # aggregationMismatch
assert ritual_struct["accessController"] == global_allow_list.address # accessController
assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey
assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript
assert ritual_struct["feeModel"] == fee_model.address # feeModel

0 comments on commit 2ee2d92

Please sign in to comment.