diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi index 4d540655..503bc77e 100644 --- a/ferveo-python/ferveo/__init__.pyi +++ b/ferveo-python/ferveo/__init__.pyi @@ -179,11 +179,8 @@ class SharedSecret: @final class FerveoVariant: - @staticmethod - def simple() -> str: ... - - @staticmethod - def precomputed() -> str: ... + simple: str + precomputed: str def encrypt(message: bytes, aad: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext: diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index b045fec3..38fa1f06 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -11,7 +11,8 @@ Dkg, AggregatedTranscript, DkgPublicKey, - ThresholdEncryptionError + ThresholdEncryptionError, + FerveoVariant ) @@ -19,26 +20,26 @@ def gen_eth_addr(i: int) -> str: return f"0x{i:040x}" -def decryption_share_for_variant(variant, agg_transcript): - if variant == "simple": +def decryption_share_for_variant(v: FerveoVariant, agg_transcript): + if v == FerveoVariant.simple: return agg_transcript.create_decryption_share_simple - elif variant == "precomputed": + elif v == FerveoVariant.precomputed: return agg_transcript.create_decryption_share_precomputed else: raise ValueError("Unknown variant") -def combine_shares_for_variant(variant, decryption_shares): - if variant == "simple": +def combine_shares_for_variant(v: FerveoVariant, decryption_shares): + if v == FerveoVariant.simple: return combine_decryption_shares_simple(decryption_shares) - elif variant == "precomputed": + elif v == FerveoVariant.precomputed: return combine_decryption_shares_precomputed(decryption_shares) else: raise ValueError("Unknown variant") -def scenario_for_variant(variant, shares_num, threshold, shares_to_use): - if variant not in ["simple", "precomputed"]: +def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_to_use): + if variant not in [FerveoVariant.simple, FerveoVariant.precomputed]: raise ValueError("Unknown variant: " + variant) tau = 1 @@ -98,12 +99,12 @@ def scenario_for_variant(variant, shares_num, threshold, shares_to_use): shared_secret = combine_shares_for_variant(variant, decryption_shares) - if variant == "simple" and len(decryption_shares) < threshold: + if variant == FerveoVariant.simple and len(decryption_shares) < threshold: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return - if variant == "precomputed" and len(decryption_shares) < shares_num: + if variant == FerveoVariant.precomputed and len(decryption_shares) < shares_num: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return @@ -113,30 +114,30 @@ def scenario_for_variant(variant, shares_num, threshold, shares_to_use): def test_simple_tdec_has_enough_messages(): - scenario_for_variant("simple", shares_num=4, threshold=3, shares_to_use=3) + scenario_for_variant(FerveoVariant.simple, shares_num=4, threshold=3, shares_to_use=3) def test_simple_tdec_doesnt_have_enough_messages(): - scenario_for_variant("simple", shares_num=4, threshold=3, shares_to_use=2) + scenario_for_variant(FerveoVariant.simple, shares_num=4, threshold=3, shares_to_use=2) def test_precomputed_tdec_has_enough_messages(): - scenario_for_variant("precomputed", shares_num=4, threshold=4, shares_to_use=4) + scenario_for_variant(FerveoVariant.precomputed, shares_num=4, threshold=4, shares_to_use=4) def test_precomputed_tdec_doesnt_have_enough_messages(): - scenario_for_variant("precomputed", shares_num=4, threshold=4, shares_to_use=3) + scenario_for_variant(FerveoVariant.precomputed, shares_num=4, threshold=4, shares_to_use=3) PARAMS = [ - (1, 'simple'), - (4, 'simple'), - (8, 'simple'), - (32, 'simple'), - (1, 'precomputed'), - (4, 'precomputed'), - (8, 'precomputed'), - (32, 'precomputed'), + (1, FerveoVariant.simple), + (4, FerveoVariant.simple), + (8, FerveoVariant.simple), + (32, FerveoVariant.simple), + (1, FerveoVariant.precomputed), + (4, FerveoVariant.precomputed), + (8, FerveoVariant.precomputed), + (32, FerveoVariant.precomputed), ] diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py index ee48cd8a..f608f5b9 100644 --- a/ferveo-python/test/test_serialization.py +++ b/ferveo-python/test/test_serialization.py @@ -81,5 +81,5 @@ def test_public_key_serialization(): def test_ferveo_variant_serialization(): - assert FerveoVariant.precomputed() == "FerveoVariant::Precomputed" - assert FerveoVariant.simple() == "FerveoVariant::Simple" + assert FerveoVariant.precomputed == "FerveoVariant::Precomputed" + assert FerveoVariant.simple == "FerveoVariant::Simple" diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index 99538e60..eab2a3db 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -273,12 +273,12 @@ struct FerveoVariant {} #[pymethods] impl FerveoVariant { - #[staticmethod] + #[classattr] fn precomputed() -> &'static str { api::FerveoVariant::Precomputed.as_str() } - #[staticmethod] + #[classattr] fn simple() -> &'static str { api::FerveoVariant::Simple.as_str() }