Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove use of JSONDict for encryption requests field schema; marshmallow Dict field is sufficient. #30

Merged
merged 5 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 0 additions & 37 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,40 +102,3 @@ def _deserialize(self, value, attr, data, **kwargs):
f"Unexpected object type, {type(result)}; expected {self.expected_type}")

return result


class JSONDict(BaseField, fields.Dict):
"""Serializes/Deserializes Dictionaries to/from JSON strings."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _serialize(self, value, attr, obj, **kwargs):
try:
value = super()._serialize(value, attr, obj, **kwargs)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to JSON: {e}"
)
try:
value_json = json.dumps(value)
return value_json
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to JSON: {e}"
)

def _deserialize(self, value, attr, data, **kwargs):
try:
result = json.loads(value)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to dictionary: {e}"
)

try:
return super()._deserialize(result, attr, data, **kwargs)
except Exception as e:
raise InvalidInputData(
f"Could not convert input for {self.name} to dictionary: {e}"
)
24 changes: 10 additions & 14 deletions porter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
from marshmallow import INCLUDE, Schema
from marshmallow import fields as marshmallow_fields
from marshmallow import validates_schema
from marshmallow.fields import URL, Dict, String

from porter.cli.types import EIP55_CHECKSUM_ADDRESS
from porter.fields.base import (
JSON,
Base64BytesRepresentation,
Integer,
JSONDict,
PositiveInteger,
StringList,
)
from porter.fields.base import JSON, Integer, PositiveInteger, StringList
from porter.fields.cbd import (
EncryptedThresholdDecryptionRequestField,
EncryptedThresholdDecryptionResponseField,
Expand Down Expand Up @@ -54,7 +46,7 @@ def option_bob_encrypting_key():
class UrsulaInfoSchema(BaseSchema):
"""Schema for the result of sampling of Ursulas."""
checksum_address = UrsulaChecksumAddress()
uri = URL()
uri = marshmallow_fields.URL()
encrypting_key = UmbralKey()

# maintain field declaration ordering
Expand Down Expand Up @@ -132,7 +124,9 @@ class PRERetrievalOutcomeSchema(BaseSchema):
"""Schema for the result of /retrieve_cfrags endpoint."""

cfrags = marshmallow_fields.Dict(keys=UrsulaChecksumAddress(), values=CapsuleFrag())
errors = marshmallow_fields.Dict(keys=UrsulaChecksumAddress(), values=String())
errors = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=marshmallow_fields.String()
)

# maintain field declaration ordering
class Meta:
Expand Down Expand Up @@ -211,10 +205,12 @@ class PRERetrieveCFrags(BaseSchema):
class CBDDecryptionOutcomeSchema(BaseSchema):
"""Schema for the result of /retrieve_cfrags endpoint."""

encrypted_decryption_responses = Dict(
encrypted_decryption_responses = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=EncryptedThresholdDecryptionResponseField()
)
errors = Dict(keys=UrsulaChecksumAddress(), values=String())
errors = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=marshmallow_fields.String()
)

# maintain field declaration ordering
class Meta:
Expand All @@ -233,7 +229,7 @@ class CBDDecrypt(BaseSchema):
required=True
)
)
encrypted_decryption_requests = JSONDict(
encrypted_decryption_requests = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(),
values=EncryptedThresholdDecryptionRequestField(),
required=True,
Expand Down
19 changes: 6 additions & 13 deletions tests/cbd/test_cbd_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,37 +61,30 @@ def test_cbd_decrypt(

with pytest.raises(InvalidInputData):
request_data = {
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests)
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# invalid param names
with pytest.raises(InvalidInputData):
request_data = {
"dkg_threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

with pytest.raises(InvalidInputData):
request_data = {
"threshold": threshold,
"encrypted_dec_requests": json.dumps(encrypted_decryption_requests),
"encrypted_dec_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# invalid param types
with pytest.raises(InvalidInputData):
request_data = {
"threshold": "threshold? we don't need no stinking threshold",
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
}
cbd_decrypt_schema.load(request_data)

with pytest.raises(InvalidInputData):
request_data = {
"threshold": threshold,
"encrypted_decryption_requests": encrypted_decryption_requests, # not json string
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

Expand All @@ -100,14 +93,14 @@ def test_cbd_decrypt(
request_data = {
"threshold": len(encrypted_decryption_requests)
+ 1, # threshold larger than number of requests
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

# simple schema successful load
request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
cbd_decrypt_schema.load(request_data)

Expand Down
5 changes: 2 additions & 3 deletions tests/cbd/test_porter_cbd_web_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ThresholdDecryptionRequest,
)
from nucypher_core.ferveo import (
Ciphertext,
DecryptionShareSimple,
combine_decryption_shares_simple,
decrypt_with_shared_secret,
Expand Down Expand Up @@ -67,7 +66,7 @@ def test_cbd_decrypt(

request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}

#
Expand Down Expand Up @@ -157,7 +156,7 @@ def test_cbd_decrypt(

request_data = {
"threshold": threshold,
"encrypted_decryption_requests": json.dumps(encrypted_decryption_requests),
"encrypted_decryption_requests": encrypted_decryption_requests,
}
response = porter_web_controller.post(
"/cbd_decrypt", data=json.dumps(request_data)
Expand Down
29 changes: 9 additions & 20 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from eth_utils import to_canonical_address
from marshmallow import fields as marshmallow_fields
from nucypher.crypto.ferveo.dkg import FerveoVariant
from nucypher_core import (
Address,
Expand All @@ -23,7 +24,6 @@
from porter.fields.base import (
JSON,
Base64BytesRepresentation,
JSONDict,
PositiveInteger,
String,
StringList,
Expand Down Expand Up @@ -260,12 +260,12 @@ def test_json_field():
field._deserialize(value=json.dumps(d), attr=None, data=None)


def test_cbd_json_dict_field(get_random_checksum_address):
def test_cbd_dict_field(get_random_checksum_address):
# test data
original_data = {}
expected_serialized_result = {}
num_decryption_requests = 5
for i in range(0, 5):
for i in range(0, num_decryption_requests):
ursula_checksum_address = get_random_checksum_address()
encrypted_decryption_request = os.urandom(32)
original_data[ursula_checksum_address] = encrypted_decryption_request
Expand All @@ -274,36 +274,25 @@ def test_cbd_json_dict_field(get_random_checksum_address):
).decode()

# mimic usage for CBD
field = JSONDict(keys=UrsulaChecksumAddress(), values=Base64BytesRepresentation())
field = marshmallow_fields.Dict(
keys=UrsulaChecksumAddress(), values=Base64BytesRepresentation()
)
serialized = field._serialize(value=original_data, attr=None, obj=None)
assert serialized == json.dumps(expected_serialized_result)
assert serialized == expected_serialized_result

deserialized = field._deserialize(value=serialized, attr=None, data=None)
assert deserialized == original_data

with pytest.raises(InvalidInputData):
# attempt to deserialize invalid key; must be checksum address
json_to_deserialize = json.dumps({"a": b64encode(os.urandom(32)).decode()})
json_to_deserialize = {"a": b64encode(os.urandom(32)).decode()}
field._deserialize(value=json_to_deserialize, attr=None, data=None)

with pytest.raises(InvalidInputData):
# attempt to deserialize invalid value; must be base64 string
json_to_deserialize = json.dumps({get_random_checksum_address(): 1})
json_to_deserialize = {get_random_checksum_address(): "+_--1"}
derekpierre marked this conversation as resolved.
Show resolved Hide resolved
field._deserialize(value=json_to_deserialize, attr=None, data=None)

with pytest.raises(InvalidInputData):
# attempt to deserialize non-dict object
json_to_deserialize = json.dumps({get_random_checksum_address(): 1})
field._deserialize("the hills are alive...", attr=None, data=None)

with pytest.raises(InvalidInputData):
# non-dict object
field._serialize(value=[1, 2, 3], attr=None, obj=None)

with pytest.raises(InvalidInputData):
# attempt to serialize invalid key; must be checksum address
field._serialize(value={"a": os.urandom(32)}, attr=None, obj=None)


def test_encrypted_threshold_decryption_request(dkg_setup, dkg_encrypted_data):
ritual_id, _, _, _, _ = dkg_setup
Expand Down