Skip to content

Commit

Permalink
Merge pull request #9055 from OpenMined/fix-serde-for-types
Browse files Browse the repository at this point in the history
make serialization of types less dynamic
  • Loading branch information
koenvanderveen authored Jul 19, 2024
2 parents 4faf529 + 562ce8c commit d167b49
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 133 deletions.
7 changes: 3 additions & 4 deletions packages/syft/src/syft/capnp/recursive_serde.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
struct RecursiveSerde {
fieldsName @0 :List(Text);
fieldsData @1 :List(List(Data));
fullyQualifiedName @2 :Text;
nonrecursiveBlob @3 :List(Data);
canonicalName @4 :Text;
version @5 :Int32;
nonrecursiveBlob @2 :List(Data);
canonicalName @3 :Text;
version @4 :Int32;
}
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ..protocol.data_protocol import get_data_protocol
from ..protocol.data_protocol import migrate_args_and_kwargs
from ..serde.deserialize import _deserialize
from ..serde.recursive import index_syft_by_module_name
from ..serde.serializable import serializable
from ..serde.serialize import _serialize
from ..serde.signature import Signature
Expand Down Expand Up @@ -63,6 +62,7 @@
from ..util.markdown import as_markdown_python_code
from ..util.notebook_ui.components.tabulator_template import build_tabulator_table
from ..util.telemetry import instrument
from ..util.util import index_syft_by_module_name
from ..util.util import prompt_warning_message
from .connection import ServerConnection

Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/custom_worker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def merged_custom_cmds(self, sep: str = ";") -> str:
return sep.join(self.custom_cmds)


@serializable(canonical_name="WorkerConfig", version=1)
class WorkerConfig(SyftBaseModel):
pass

Expand Down
21 changes: 21 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
{
"dev": {
"object_versions": {
"SyftObjectVersioned": {
"1": {
"version": 1,
"hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4",
"action": "add"
}
},
"BaseDateTime": {
"1": {
"version": 1,
"hash": "614db484b1950be729902b1861bd3a7b33899176507c61cef11dc0d44611cfd3",
"action": "add"
}
},
"SyftObject": {
"1": {
"version": 1,
"hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406",
"action": "add"
}
},
"PartialSyftObject": {
"1": {
"version": 1,
Expand Down Expand Up @@ -988,6 +1002,13 @@
"action": "add"
}
},
"ProjectEvent": {
"1": {
"version": 1,
"hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb",
"action": "add"
}
},
"ProjectThreadMessage": {
"1": {
"version": 1,
Expand Down
8 changes: 8 additions & 0 deletions packages/syft/src/syft/serde/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@
version=SYFT_OBJECT_VERSION_1,
)

# Register np.number with the recursive serde under the canonical name
# "numpy_number": values are serialized with .tobytes() and read back by
# taking element 0 of frombuffer(buffer, dtype=np.number).
# NOTE(review): np.number is an abstract scalar base class rather than a
# concrete dtype. NumPy documents converting abstract scalar types to a dtype
# as deprecated, so the round trip may not preserve the original dtype —
# confirm against the NumPy version in use.
recursive_serde_register(
np.number,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.number)[0],
canonical_name="numpy_number",
version=SYFT_OBJECT_VERSION_1,
)

# TODO: There is an incorrect mapping in the loop, which makes it not work.
# numpy_scalar_types = [
# np.bool_,
Expand Down
55 changes: 14 additions & 41 deletions packages/syft/src/syft/serde/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from enum import Enum
from enum import EnumMeta
import os
import sys
import tempfile
import types
from typing import Any
Expand All @@ -17,7 +16,6 @@

# relative
from ..types.syft_object_registry import SyftObjectRegistry
from ..util.util import index_syft_by_module_name
from .capnp import get_capnp_schema
from .util import compatible_with_large_file_writes_capnp

Expand Down Expand Up @@ -285,7 +283,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild
# todo: rewrite and make sure every object has a canonical name and version
canonical_name, version = SyftObjectRegistry.get_canonical_name_version(self)

if not SyftObjectRegistry.has_serde_class("", canonical_name, version):
if not SyftObjectRegistry.has_serde_class(canonical_name, version):
# third party
raise Exception(
f"obj2proto: {canonical_name} version {version} not in SyftObjectRegistry"
Expand Down Expand Up @@ -382,34 +380,12 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any:
# relative
from .deserialize import _deserialize

# clean this mess, Tudor
module_parts = proto.fullyQualifiedName.split(".")
klass = module_parts.pop()
class_type: type | Any = type(None)

if klass != "NoneType":
try:
class_type = index_syft_by_module_name(proto.fullyQualifiedName) # type: ignore[assignment,unused-ignore]
except Exception: # nosec
try:
class_type = getattr(sys.modules[".".join(module_parts)], klass)
except Exception: # nosec
if "syft.user" in proto.fullyQualifiedName:
# relative
from ..server.server import CODE_RELOADER

for load_user_code in CODE_RELOADER.values():
load_user_code()
try:
class_type = getattr(sys.modules[".".join(module_parts)], klass)
except Exception: # nosec
pass

canonical_name = proto.canonicalName
version = getattr(proto, "version", -1)
fqn = getattr(proto, "fullyQualifiedName", "")
fqn = map_fqns_for_backward_compatibility(fqn)
if not SyftObjectRegistry.has_serde_class(fqn, canonical_name, version):

if not SyftObjectRegistry.has_serde_class(canonical_name, version):
# third party
raise Exception(
f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry"
Expand All @@ -431,13 +407,9 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any:
cls,
_,
version,
) = SyftObjectRegistry.get_serde_properties_bw_compatible(
fqn, canonical_name, version
)
) = SyftObjectRegistry.get_serde_properties(canonical_name, version)

if class_type == type(None) or fqn != "":
# yes this looks stupid but it works and the opposite breaks
class_type = cls
class_type = cls

if nonrecursive:
if deserialize is None:
Expand Down Expand Up @@ -468,14 +440,15 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any:
# if we skip the __new__ flow of BaseModel we get the error
# AttributeError: object has no attribute '__fields_set__'

if "syft.user" in proto.fullyQualifiedName:
# weird issues with pydantic and ForwardRef on user classes being inited
# with custom state args / kwargs
obj = class_type()
for attr_name, attr_value in kwargs.items():
setattr(obj, attr_name, attr_value)
else:
obj = class_type(**kwargs)
# if "syft.user" in proto.fullyQualifiedName:
# # weird issues with pydantic and ForwardRef on user classes being inited
# # with custom state args / kwargs
# obj = class_type()
# for attr_name, attr_value in kwargs.items():
# setattr(obj, attr_name, attr_value)
# else:
# obj = class_type(**kwargs)
obj = class_type(**kwargs)

else:
obj = class_type.__new__(class_type) # type: ignore
Expand Down
86 changes: 73 additions & 13 deletions packages/syft/src/syft/serde/recursive_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from enum import Enum
from enum import EnumMeta
import functools
import inspect
import pathlib
from pathlib import PurePath
import sys
import tempfile
from types import MappingProxyType
from types import UnionType
import typing
from typing import Any
from typing import GenericAlias
from typing import Optional
Expand All @@ -27,6 +29,7 @@
import weakref

# relative
from ..types.syft_object_registry import SyftObjectRegistry
from .capnp import get_capnp_schema
from .recursive import chunk_bytes
from .recursive import combine_bytes
Expand Down Expand Up @@ -169,22 +172,31 @@ def deserialize_enum(enum_type: type, enum_buf: bytes) -> Enum:
return enum_type(enum_value)


def serialize_type(serialized_type: type) -> bytes:
def serialize_type(_type_to_serialize: type) -> bytes:
# relative
from ..util.util import full_name_with_qualname
type_to_serialize = typing.get_origin(_type_to_serialize) or _type_to_serialize
canonical_name, version = SyftObjectRegistry.get_identifier_for_type(
type_to_serialize
)
return f"{canonical_name}:{version}".encode()

fqn = full_name_with_qualname(klass=serialized_type)
module_parts = fqn.split(".")
return ".".join(module_parts).encode()
# from ..util.util import full_name_with_qualname

# fqn = full_name_with_qualname(klass=serialized_type)
# module_parts = fqn.split(".")
# return ".".join(module_parts).encode()


def deserialize_type(type_blob: bytes) -> type:
deserialized_type = type_blob.decode()
module_parts = deserialized_type.split(".")
klass = module_parts.pop()
klass = "None" if klass == "NoneType" else klass
exception_type = getattr(sys.modules[".".join(module_parts)], klass)
return exception_type
canonical_name, version = deserialized_type.split(":", 1)
return SyftObjectRegistry.get_serde_class(canonical_name, int(version))

# module_parts = deserialized_type.split(".")
# klass = module_parts.pop()
# klass = "None" if klass == "NoneType" else klass
# exception_type = getattr(sys.modules[".".join(module_parts)], klass)
# return exception_type


TPath = TypeVar("TPath", bound=PurePath)
Expand Down Expand Up @@ -434,6 +446,8 @@ def recursive_serde_register_type(
canonical_name: str | None = None,
version: int | None = None,
) -> None:
# The former case applies, for instance, to _GenericAlias itself or _UnionGenericAlias.
# The latter case applies, for instance, to List[str], which is currently not used.
if (isinstance(t, type) and issubclass(t, _GenericAlias)) or issubclass(
type(t), _GenericAlias
):
Expand Down Expand Up @@ -471,6 +485,31 @@ def deserialize_union_type(type_blob: bytes) -> type:
return functools.reduce(lambda x, y: x | y, args)


def serialize_union(serialized_type: UnionType) -> bytes:
return b""


def deserialize_union(type_blob: bytes) -> type:  # type: ignore
    """Return ``typing.Union`` regardless of the contents of ``type_blob``.

    The blob is ignored; the serialized form of ``Union`` is an empty payload.
    """
    return Union  # type: ignore


def serialize_typevar(serialized_type: TypeVar) -> bytes:
    """Serialize a ``TypeVar`` as the encoded bytes of its ``__name__``.

    Only the name is written; nothing else about the TypeVar is read here.
    """
    type_name = str(serialized_type.__name__)
    return type_name.encode("utf-8")


def deserialize_typevar(type_blob: bytes) -> type:
    """Build a new ``TypeVar`` whose name is the decoded ``type_blob``."""
    return TypeVar(type_blob.decode("utf-8"))  # type: ignore


def serialize_any(serialized_type: Any) -> bytes:
    """Serialize the ``typing.Any`` marker.

    ``Any`` has no state to write out, so the payload is always an empty
    byte string. The parameter is annotated ``Any`` (it was previously
    annotated ``TypeVar``) to match the type this serializer handles.
    """
    return b""


def deserialize_any(type_blob: bytes) -> type:  # type: ignore
    """Return ``typing.Any`` regardless of the contents of ``type_blob``.

    The blob is ignored; the serialized form of ``Any`` is an empty payload.
    """
    return Any  # type: ignore


recursive_serde_register(
UnionType,
serialize=serialize_union_type,
Expand All @@ -481,8 +520,27 @@ def deserialize_union_type(type_blob: bytes) -> type:

recursive_serde_register_type(_SpecialForm, canonical_name="_SpecialForm", version=1)
recursive_serde_register_type(_GenericAlias, canonical_name="_GenericAlias", version=1)
recursive_serde_register_type(Union, canonical_name="Union", version=1)
recursive_serde_register_type(TypeVar, canonical_name="TypeVar", version=1)
recursive_serde_register(
Union,
canonical_name="Union",
serialize=serialize_union,
deserialize=deserialize_union,
version=1,
)
recursive_serde_register(
TypeVar,
canonical_name="TypeVar",
serialize=serialize_typevar,
deserialize=deserialize_typevar,
version=1,
)
recursive_serde_register(
Any,
canonical_name="Any",
serialize=serialize_any,
deserialize=deserialize_any,
version=1,
)

recursive_serde_register_type(
_UnionGenericAlias,
Expand All @@ -503,7 +561,9 @@ def deserialize_union_type(type_blob: bytes) -> type:
)
recursive_serde_register_type(GenericAlias, canonical_name="GenericAlias", version=1)

recursive_serde_register_type(Any, canonical_name="Any", version=1)
# recursive_serde_register_type(Any, canonical_name="Any", version=1)
recursive_serde_register_type(EnumMeta, canonical_name="EnumMeta", version=1)

recursive_serde_register_type(ABCMeta, canonical_name="ABCMeta", version=1)

recursive_serde_register_type(inspect._empty, canonical_name="inspect_empty", version=1)
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def repr_uid(_id: LineageID) -> str:
)


@serializable(canonical_name="ActionObjectPointer", version=1)
class ActionObjectPointer:
pass

Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/network/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .server_peer import ServerPeer


@serializable(canonical_name="ServerRoute", version=1)
class ServerRoute:
def client_with_context(
self, context: ServerServiceContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast

# relative
from ...serde.serializable import serializable
from ...store.linked_obj import LinkedObject
from ..context import AuthedServiceContext

Expand All @@ -21,6 +22,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s
return ""


@serializable(canonical_name="OnboardEmailTemplate", version=1)
class OnBoardEmailTemplate(EmailTemplate):
@staticmethod
def email_title(notification: "Notification", context: AuthedServiceContext) -> str:
Expand Down Expand Up @@ -107,6 +109,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s
return f"""<html>{head} {body}</html>"""


@serializable(canonical_name="RequestEmailTemplate", version=1)
class RequestEmailTemplate(EmailTemplate):
@staticmethod
def email_title(notification: "Notification", context: AuthedServiceContext) -> str:
Expand Down Expand Up @@ -254,6 +257,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s
return f"""<html>{head} {body}</html>"""


@serializable(canonical_name="RequestUpdateEmailTemplate", version=1)
class RequestUpdateEmailTemplate(EmailTemplate):
@staticmethod
def email_title(notification: "Notification", context: AuthedServiceContext) -> str:
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/notifier/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def send(
TBaseNotifier = TypeVar("TBaseNotifier", bound=BaseNotifier)


@serializable(canonical_name="EmailNotifier", version=1)
class EmailNotifier(BaseNotifier):
smtp_client: SMTPClient
sender = ""
Expand Down
Loading

0 comments on commit d167b49

Please sign in to comment.