diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi index 4b5aaa4c..1dfab2f0 100644 --- a/ferveo-python/ferveo/__init__.pyi +++ b/ferveo-python/ferveo/__init__.pyi @@ -195,8 +195,8 @@ class SharedSecret: @final class FerveoVariant: - simple: FerveoVariant - precomputed: FerveoVariant + Simple: FerveoVariant + Precomputed: FerveoVariant def __eq__(self, other: object) -> bool: ... diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 9406a5ae..7d93c637 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -21,25 +21,25 @@ def gen_eth_addr(i: int) -> str: def decryption_share_for_variant(v: FerveoVariant, agg_transcript): - if v == FerveoVariant.simple: + if v == FerveoVariant.Simple: return agg_transcript.create_decryption_share_simple - elif v == FerveoVariant.precomputed: + elif v == FerveoVariant.Precomputed: return agg_transcript.create_decryption_share_precomputed else: raise ValueError("Unknown variant") def combine_shares_for_variant(v: FerveoVariant, decryption_shares): - if v == FerveoVariant.simple: + if v == FerveoVariant.Simple: return combine_decryption_shares_simple(decryption_shares) - elif v == FerveoVariant.precomputed: + elif v == FerveoVariant.Precomputed: return combine_decryption_shares_precomputed(decryption_shares) else: raise ValueError("Unknown variant") def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_to_use): - if variant not in [FerveoVariant.simple, FerveoVariant.precomputed]: + if variant not in [FerveoVariant.Simple, FerveoVariant.Precomputed]: raise ValueError("Unknown variant: " + variant) tau = 1 @@ -99,12 +99,12 @@ def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_t shared_secret = combine_shares_for_variant(variant, decryption_shares) - if variant == FerveoVariant.simple and len(decryption_shares) < threshold: + if variant == FerveoVariant.Simple and len(decryption_shares) < threshold: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return - if variant == FerveoVariant.precomputed and len(decryption_shares) < shares_num: + if variant == FerveoVariant.Precomputed and len(decryption_shares) < shares_num: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return @@ -114,32 +114,32 @@ def scenario_for_variant(variant: FerveoVariant, shares_num, threshold, shares_t def test_simple_tdec_has_enough_messages(): - scenario_for_variant(FerveoVariant.simple, shares_num=4, threshold=3, shares_to_use=3) + scenario_for_variant(FerveoVariant.Simple, shares_num=4, threshold=3, shares_to_use=3) def test_simple_tdec_doesnt_have_enough_messages(): - scenario_for_variant(FerveoVariant.simple, shares_num=4, threshold=3, shares_to_use=2) + scenario_for_variant(FerveoVariant.Simple, shares_num=4, threshold=3, shares_to_use=2) def test_precomputed_tdec_has_enough_messages(): - scenario_for_variant(FerveoVariant.precomputed, shares_num=4, threshold=4, shares_to_use=4) + scenario_for_variant(FerveoVariant.Precomputed, shares_num=4, threshold=4, shares_to_use=4) def test_precomputed_tdec_doesnt_have_enough_messages(): - scenario_for_variant(FerveoVariant.precomputed, shares_num=4, threshold=4, shares_to_use=3) + scenario_for_variant(FerveoVariant.Precomputed, shares_num=4, threshold=4, shares_to_use=3) PARAMS = [ - (1, FerveoVariant.simple), - (3, FerveoVariant.simple), - (4, FerveoVariant.simple), - (7, FerveoVariant.simple), - (8, FerveoVariant.simple), - (1, FerveoVariant.precomputed), - (3, FerveoVariant.precomputed), - (4, FerveoVariant.precomputed), - (7, FerveoVariant.precomputed), - (8, FerveoVariant.precomputed), + (1, FerveoVariant.Simple), + (3, FerveoVariant.Simple), + (4, FerveoVariant.Simple), + (7, FerveoVariant.Simple), + (8, FerveoVariant.Simple), + (1, FerveoVariant.Precomputed), + (3, FerveoVariant.Precomputed), + (4, FerveoVariant.Precomputed), + (7, FerveoVariant.Precomputed), + (8, FerveoVariant.Precomputed), ] TEST_CASES_WITH_THRESHOLD_RANGE = [] diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py index 78e07868..8533d437 100644 --- a/ferveo-python/test/test_serialization.py +++ b/ferveo-python/test/test_serialization.py @@ -81,8 +81,8 @@ def test_public_key_serialization(): def test_ferveo_variant_serialization(): - assert str(FerveoVariant.precomputed) == "FerveoVariant::Precomputed" - assert str(FerveoVariant.simple) == "FerveoVariant::Simple" - assert FerveoVariant.precomputed == FerveoVariant.precomputed - assert FerveoVariant.simple == FerveoVariant.simple - assert FerveoVariant.precomputed != FerveoVariant.simple + assert str(FerveoVariant.Precomputed) == "FerveoVariant::Precomputed" + assert str(FerveoVariant.Simple) == "FerveoVariant::Simple" + assert FerveoVariant.Precomputed == FerveoVariant.Precomputed + assert FerveoVariant.Simple == FerveoVariant.Simple + assert FerveoVariant.Precomputed != FerveoVariant.Simple diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index 8ee06dbd..7d1cb93b 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -279,11 +279,13 @@ pub struct FerveoVariant(pub(crate) api::FerveoVariant); #[pymethods] impl FerveoVariant { #[classattr] + #[pyo3(name = "Precomputed")] fn precomputed() -> FerveoVariant { api::FerveoVariant::Precomputed.into() } #[classattr] + #[pyo3(name = "Simple")] fn simple() -> FerveoVariant { api::FerveoVariant::Simple.into() }