Skip to content

Commit

Permalink
test: update tests for dkgs with relaxed constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Feb 1, 2024
1 parent 514221e commit 175dda7
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 297 deletions.
12 changes: 5 additions & 7 deletions ferveo-python/examples/server_api_precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@ def gen_eth_addr(i: int) -> str:

tau = 1
shares_num = 4
validators_num = shares_num + 2
# In precomputed variant, security threshold must be equal to shares_num
security_threshold = shares_num

validator_keypairs = [Keypair.random() for _ in range(0, shares_num)]
validator_keypairs = [Keypair.random() for _ in range(0, validators_num)]
validators = [
Validator(gen_eth_addr(i), keypair.public_key(), i)
for i, keypair in enumerate(validator_keypairs)
]

# Validators must be sorted by their public key
validators.sort(key=lambda v: v.address)

# Each validator holds their own DKG instance and generates a transcript every
# validator, including themselves
messages = []
Expand All @@ -52,11 +50,11 @@ def gen_eth_addr(i: int) -> str:

# Server can aggregate the transcripts
server_aggregate = dkg.aggregate_transcripts(messages)
assert server_aggregate.verify(shares_num, messages)
assert server_aggregate.verify(validators_num, messages)

# And the client can also aggregate and verify the transcripts
client_aggregate = AggregatedTranscript(messages)
assert client_aggregate.verify(shares_num, messages)
assert client_aggregate.verify(validators_num, messages)

# In the meantime, the client creates a ciphertext and decryption request
msg = "abc".encode()
Expand All @@ -76,7 +74,7 @@ def gen_eth_addr(i: int) -> str:

# We can also obtain the aggregated transcript from the side-channel (deserialize)
aggregate = AggregatedTranscript(messages)
assert aggregate.verify(shares_num, messages)
assert aggregate.verify(validators_num, messages)

# The ciphertext is obtained from the client

Expand Down
9 changes: 5 additions & 4 deletions ferveo-python/examples/server_api_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def gen_eth_addr(i: int) -> str:
tau = 1
security_threshold = 3
shares_num = 4
validator_keypairs = [Keypair.random() for _ in range(0, shares_num)]
validators_num = shares_num + 2
validator_keypairs = [Keypair.random() for _ in range(0, validators_num)]
validators = [
Validator(gen_eth_addr(i), keypair.public_key(), i)
for i, keypair in enumerate(validator_keypairs)
Expand Down Expand Up @@ -52,11 +53,11 @@ def gen_eth_addr(i: int) -> str:

# Server can aggregate the transcripts
server_aggregate = dkg.aggregate_transcripts(messages)
assert server_aggregate.verify(shares_num, messages)
assert server_aggregate.verify(validators_num, messages)

# And the client can also aggregate and verify the transcripts
client_aggregate = AggregatedTranscript(messages)
assert client_aggregate.verify(shares_num, messages)
assert client_aggregate.verify(validators_num, messages)

# In the meantime, the client creates a ciphertext and decryption request
msg = "abc".encode()
Expand All @@ -79,7 +80,7 @@ def gen_eth_addr(i: int) -> str:

# We can also obtain the aggregated transcript from the side-channel (deserialize)
aggregate = AggregatedTranscript(messages)
assert aggregate.verify(shares_num, messages)
assert aggregate.verify(validators_num, messages)

# The ciphertext is obtained from the client

Expand Down
7 changes: 6 additions & 1 deletion ferveo-python/ferveo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@
InvalidDkgPublicKey,
InsufficientValidators,
InvalidTranscriptAggregate,
ValidatorsNotSorted,
ValidatorPublicKeyMismatch,
SerializationError,
InvalidVariant,
InvalidDkgParameters,
InvalidDkgParametersForPrecomputedVariant,
InvalidShareIndex,
DuplicatedShareIndex,
NoTranscriptsToAggregate,
InvalidAggregateVerificationParameters,
)
30 changes: 25 additions & 5 deletions ferveo-python/ferveo/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ class FerveoPublicKey:

@final
class Validator:
def __init__(self, address: str, public_key: FerveoPublicKey): ...
def __init__(self, address: str, public_key: FerveoPublicKey, share_index: int): ...

address: str

public_key: FerveoPublicKey

share_index: int

@final
class Transcript:
@staticmethod
Expand Down Expand Up @@ -104,7 +106,7 @@ class DecryptionSharePrecomputed:
@final
class AggregatedTranscript:
def __init__(self, messages: Sequence[ValidatorMessage]): ...
def verify(self, shares_num: int, messages: Sequence[ValidatorMessage]) -> bool: ...
def verify(self, validators_num: int, messages: Sequence[ValidatorMessage]) -> bool: ...
def create_decryption_share_simple(
self,
dkg: Dkg,
Expand Down Expand Up @@ -189,11 +191,29 @@ class InsufficientValidators(Exception):
class InvalidTranscriptAggregate(Exception):
pass

class ValidatorsNotSorted(Exception):
pass

class ValidatorPublicKeyMismatch(Exception):
pass

class SerializationError(Exception):
pass

class InvalidVariant(Exception):
pass

class InvalidDkgParameters(Exception):
pass

class InvalidDkgParametersForPrecomputedVariant(Exception):
pass

class InvalidShareIndex(Exception):
pass

class DuplicatedShareIndex(Exception):
pass

class NoTranscriptsToAggregate(Exception):
pass

class InvalidAggregateVerificationParameters(Exception):
pass
94 changes: 71 additions & 23 deletions ferveo-python/test/test_ferveo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
combine_decryption_shares_simple,
combine_decryption_shares_precomputed,
decrypt_with_shared_secret,
AggregatedTranscript,
Keypair,
Validator,
ValidatorMessage,
Expand Down Expand Up @@ -37,18 +38,29 @@ def combine_shares_for_variant(v: FerveoVariant, decryption_shares):
raise ValueError("Unknown variant")


def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_to_use):
def scenario_for_variant(
variant: FerveoVariant, shares_num, validators_num, threshold, shares_to_use
):
if variant not in [FerveoVariant.Simple, FerveoVariant.Precomputed]:
raise ValueError("Unknown variant: " + variant)

if validators_num < shares_num:
raise ValueError("validators_num must be >= shares_num")

if variant == FerveoVariant.Precomputed and shares_to_use != validators_num:
raise ValueError(
"In precomputed variant, shares_to_use must be equal to validators_num"
)

tau = 1
validator_keypairs = [Keypair.random() for _ in range(0, shares_num)]
validator_keypairs = [Keypair.random() for _ in range(0, validators_num)]
validators = [
Validator(gen_eth_addr(i), keypair.public_key(), i)
for i, keypair in enumerate(validator_keypairs)
]
validators.sort(key=lambda v: v.address)

# Each validator holds their own DKG instance and generates a transcript every
# validator, including themselves
messages = []
for sender in validators:
dkg = Dkg(
Expand All @@ -60,25 +72,31 @@ def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_t
)
messages.append(ValidatorMessage(sender, dkg.generate_transcript()))

# Both client and server should be able to verify the aggregated transcript
dkg = Dkg(
tau=tau,
shares_num=shares_num,
security_threshold=threshold,
validators=validators,
me=validators[0],
)
pvss_aggregated = dkg.aggregate_transcripts(messages)
assert pvss_aggregated.verify(shares_num, messages)
server_aggregate = dkg.aggregate_transcripts(messages)
assert server_aggregate.verify(validators_num, messages)

dkg_pk_bytes = bytes(dkg.public_key)
dkg_pk = DkgPublicKey.from_bytes(dkg_pk_bytes)
client_aggregate = AggregatedTranscript(messages)
assert client_aggregate.verify(validators_num, messages)

# Client creates a ciphertext and requests decryption shares from validators
msg = "abc".encode()
aad = "my-aad".encode()
ciphertext = encrypt(msg, aad, dkg_pk)
ciphertext = encrypt(msg, aad, dkg.public_key)

# Having aggregated the transcripts, the validators can now create decryption shares
decryption_shares = []
for validator, validator_keypair in zip(validators, validator_keypairs):
assert validator.public_key == validator_keypair.public_key()
print("validator: ", validator.share_index)

dkg = Dkg(
tau=tau,
shares_num=shares_num,
Expand All @@ -87,23 +105,25 @@ def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_t
me=validator,
)
pvss_aggregated = dkg.aggregate_transcripts(messages)
assert pvss_aggregated.verify(shares_num, messages)
assert pvss_aggregated.verify(validators_num, messages)

decryption_share = decryption_share_for_variant(variant, pvss_aggregated)(
dkg, ciphertext.header, aad, validator_keypair
)
decryption_shares.append(decryption_share)

decryption_shares = decryption_shares[:shares_to_use]
# We are limiting the number of decryption shares to use for testing purposes
# decryption_shares = decryption_shares[:shares_to_use]

# Client combines the decryption shares and decrypts the ciphertext
shared_secret = combine_shares_for_variant(variant, decryption_shares)

if variant == FerveoVariant.Simple and len(decryption_shares) < threshold:
with pytest.raises(ThresholdEncryptionError):
decrypt_with_shared_secret(ciphertext, aad, shared_secret)
return

if variant == FerveoVariant.Precomputed and len(decryption_shares) < shares_num:
if variant == FerveoVariant.Precomputed and len(decryption_shares) < threshold:
with pytest.raises(ThresholdEncryptionError):
decrypt_with_shared_secret(ciphertext, aad, shared_secret)
return
Expand All @@ -113,27 +133,55 @@ def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_t


def test_simple_tdec_has_enough_messages():
scenario_for_variant(
FerveoVariant.Simple, shares_num=4, threshold=3, shares_to_use=3
)
shares_num = 4
threshold = shares_num - 1
for validators_num in [shares_num, shares_num + 2]:
scenario_for_variant(
FerveoVariant.Simple,
shares_num=shares_num,
validators_num=validators_num,
threshold=threshold,
shares_to_use=threshold,
)


def test_simple_tdec_doesnt_have_enough_messages():
scenario_for_variant(
FerveoVariant.Simple, shares_num=4, threshold=3, shares_to_use=2
)
shares_num = 4
threshold = shares_num - 1
for validators_num in [shares_num, shares_num + 2]:
scenario_for_variant(
FerveoVariant.Simple,
shares_num=shares_num,
validators_num=validators_num,
threshold=threshold,
shares_to_use=validators_num - 1,
)


def test_precomputed_tdec_has_enough_messages():
scenario_for_variant(
FerveoVariant.Precomputed, shares_num=4, threshold=4, shares_to_use=4
)
shares_num = 4
threshold = shares_num # in precomputed variant, we need all shares
for validators_num in [shares_num, shares_num + 2]:
scenario_for_variant(
FerveoVariant.Precomputed,
shares_num=shares_num,
validators_num=validators_num,
threshold=threshold,
shares_to_use=validators_num,
)


def test_precomputed_tdec_doesnt_have_enough_messages():
scenario_for_variant(
FerveoVariant.Precomputed, shares_num=4, threshold=4, shares_to_use=3
)
shares_num = 4
threshold = shares_num # in precomputed variant, we need all shares
for validators_num in [shares_num, shares_num + 2]:
scenario_for_variant(
FerveoVariant.Simple,
shares_num=shares_num,
validators_num=validators_num,
threshold=threshold,
shares_to_use=threshold - 1,
)


PARAMS = [
Expand Down
Loading

0 comments on commit 175dda7

Please sign in to comment.