Skip to content

Commit

Permalink
No need to use JSONDict for encryption request field schema, a simple…
Browse files Browse the repository at this point in the history
… marshmallow Dict field is sufficient.
  • Loading branch information
derekpierre committed Jun 23, 2023
1 parent 30a2cd8 commit 717da68
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 87 deletions.
37 changes: 0 additions & 37 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,40 +102,3 @@ def _deserialize(self, value, attr, data, **kwargs):
f"Unexpected object type, {type(result)}; expected {self.expected_type}")

return result


class JSONDict(BaseField, fields.Dict):
"""Serializes/Deserializes Dictionaries to/from JSON strings."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _serialize(self, value, attr, obj, **kwargs):
try:
value = super()._serialize(value, attr, obj, **kwargs)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to JSON: {e}"
)
try:
value_json = json.dumps(value)
return value_json
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to JSON: {e}"
)

def _deserialize(self, value, attr, data, **kwargs):
try:
result = json.loads(value)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to dictionary: {e}"
)

try:
return super()._deserialize(result, attr, data, **kwargs)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to dictionary: {e}"
)
24 changes: 10 additions & 14 deletions porter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
from marshmallow import INCLUDE, Schema
from marshmallow import fields as marshmallow_fields
from marshmallow import validates_schema
from marshmallow.fields import URL, Dict, String

from porter.cli.types import EIP55_CHECKSUM_ADDRESS
from porter.fields.base import (
JSON,
Base64BytesRepresentation,
Integer,
JSONDict,
PositiveInteger,
StringList,
)
from porter.fields.base import JSON, Integer, PositiveInteger, StringList
from porter.fields.cbd import (
EncryptedThresholdDecryptionRequestField,
EncryptedThresholdDecryptionResponseField,
Expand Down Expand Up @@ -54,7 +46,7 @@ def option_bob_encrypting_key():
class UrsulaInfoSchema(BaseSchema):
"""Schema for the result of sampling of Ursulas."""
checksum_address = UrsulaChecksumAddress()
uri = URL()
uri = marshmallow_fields.URL()
encrypting_key = UmbralKey()

# maintain field declaration ordering
Expand Down Expand Up @@ -132,7 +124,9 @@ class PRERetrievalOutcomeSchema(BaseSchema):
"""Schema for the result of /retrieve_cfrags endpoint."""

cfrags = marshmallow_fields.Dict(keys=UrsulaChecksumAddress(), values=CapsuleFrag())
errors = marshmallow_fields.Dict(keys=UrsulaChecksumAddress(), values=String())
errors = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=marshmallow_fields.String()
)

# maintain field declaration ordering
class Meta:
Expand Down Expand Up @@ -211,10 +205,12 @@ class PRERetrieveCFrags(BaseSchema):
class CBDDecryptionOutcomeSchema(BaseSchema):
"""Schema for the result of /retrieve_cfrags endpoint."""

encrypted_decryption_responses = Dict(
encrypted_decryption_responses = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=EncryptedThresholdDecryptionResponseField()
)
errors = Dict(keys=UrsulaChecksumAddress(), values=String())
errors = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=marshmallow_fields.String()
)

# maintain field declaration ordering
class Meta:
Expand All @@ -233,7 +229,7 @@ class CBDDecrypt(BaseSchema):
required=True
)
)
encrypted_decryption_requests = JSONDict(
encrypted_decryption_requests = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(),
values=EncryptedThresholdDecryptionRequestField(),
required=True,
Expand Down
19 changes: 6 additions & 13 deletions tests/cbd/test_cbd_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,37 +61,30 @@ def test_cbd_decrypt(

with pytest.raises(InvalidInputData):
request_data = {
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests)
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# invalid param names
with pytest.raises(InvalidInputData):
request_data = {
"dkg_threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

with pytest.raises(InvalidInputData):
request_data = {
"threshold": threshold,
"encrypted_dec_requests": json.dumps(encrypted_decryption_requests),
"encrypted_dec_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# invalid param types
with pytest.raises(InvalidInputData):
request_data = {
"threshold": "threshold? we don't need no stinking threshold",
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
}
cbd_decrypt_schema.load(request_data)

with pytest.raises(InvalidInputData):
request_data = {
"threshold": threshold,
"encrypted_decryption_requests": encrypted_decryption_requests, # not json string
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

Expand All @@ -100,14 +93,14 @@ def test_cbd_decrypt(
request_data = {
"threshold": len(encrypted_decryption_requests)
+ 1, # threshold larger than number of requests
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# simple schema successful load
request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

Expand Down
5 changes: 2 additions & 3 deletions tests/cbd/test_porter_cbd_web_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ThresholdDecryptionRequest,
)
from nucypher_core.ferveo import (
Ciphertext,
DecryptionShareSimple,
combine_decryption_shares_simple,
decrypt_with_shared_secret,
Expand Down Expand Up @@ -67,7 +66,7 @@ def test_cbd_decrypt(

request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}

#
Expand Down Expand Up @@ -157,7 +156,7 @@ def test_cbd_decrypt(

request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
response = porter_web_controller.post(
"/cbd_decrypt", data=json.dumps(request_data)
Expand Down
29 changes: 9 additions & 20 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from eth_utils import to_canonical_address
from marshmallow import fields as marshmallow_fields
from nucypher.crypto.ferveo.dkg import FerveoVariant
from nucypher_core import (
Address,
Expand All @@ -23,7 +24,6 @@
from porter.fields.base import (
JSON,
Base64BytesRepresentation,
JSONDict,
PositiveInteger,
String,
StringList,
Expand Down Expand Up @@ -260,12 +260,12 @@ def test_json_field():
field._deserialize(value=json.dumps(d), attr=None, data=None)


def test_cbd_json_dict_field(get_random_checksum_address):
def test_cbd_dict_field(get_random_checksum_address):
# test data
original_data = {}
expected_serialized_result = {}
num_decryption_requests = 5
for i in range(0, 5):
for i in range(0, num_decryption_requests):
ursula_checksum_address = get_random_checksum_address()
encrypted_decryption_request = os.urandom(32)
original_data[ursula_checksum_address] = encrypted_decryption_request
Expand All @@ -274,36 +274,25 @@ def test_cbd_json_dict_field(get_random_checksum_address):
).decode()

# mimic usage for CBD
field = JSONDict(keys=UrsulaChecksumAddress(), values=Base64BytesRepresentation())
field = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=Base64BytesRepresentation()
)
serialized = field._serialize(value=original_data, attr=None, obj=None)
assert serialized == json.dumps(expected_serialized_result)
assert serialized == expected_serialized_result

deserialized = field._deserialize(value=serialized, attr=None, data=None)
assert deserialized == original_data

with pytest.raises(InvalidInputData):
# attempt to deserialize invalid key; must be checksum address
json_to_deserialize = json.dumps({"a": b64encode(os.urandom(32)).decode()})
json_to_deserialize = {"a": b64encode(os.urandom(32)).decode()}
field._deserialize(value=json_to_deserialize, attr=None, data=None)

with pytest.raises(InvalidInputData):
# attempt to deserialize invalid value; must be base64 string
json_to_deserialize = json.dumps({get_random_checksum_address(): 1})
json_to_deserialize = {get_random_checksum_address(): "+_--1"}
field._deserialize(value=json_to_deserialize, attr=None, data=None)

with pytest.raises(InvalidInputData):
# attempt to deserialize non-dict object
json_to_deserialize = json.dumps({get_random_checksum_address(): 1})
field._deserialize("the hills are alive...", attr=None, data=None)

with pytest.raises(InvalidInputData):
# non-dict object
field._serialize(value=[1, 2, 3], attr=None, obj=None)

with pytest.raises(InvalidInputData):
# attempt to serialize invalid key; must be checksum address
field._serialize(value={"a": os.urandom(32)}, attr=None, obj=None)


def test_encrypted_threshold_decryption_request(dkg_setup, dkg_encrypted_data):
ritual_id, _, _, _, _ = dkg_setup
Expand Down

0 comments on commit 717da68

Please sign in to comment.