diff --git a/packages/syft/src/syft/capnp/recursive_serde.capnp b/packages/syft/src/syft/capnp/recursive_serde.capnp index c29ba57aae6..5b6fadb5c65 100644 --- a/packages/syft/src/syft/capnp/recursive_serde.capnp +++ b/packages/syft/src/syft/capnp/recursive_serde.capnp @@ -3,8 +3,7 @@ struct RecursiveSerde { fieldsName @0 :List(Text); fieldsData @1 :List(List(Data)); - fullyQualifiedName @2 :Text; - nonrecursiveBlob @3 :List(Data); - canonicalName @4 :Text; - version @5 :Int32; + nonrecursiveBlob @2 :List(Data); + canonicalName @3 :Text; + version @4 :Int32; } diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index d60416fdf3b..0ff9299bdfe 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -32,7 +32,6 @@ from ..protocol.data_protocol import get_data_protocol from ..protocol.data_protocol import migrate_args_and_kwargs from ..serde.deserialize import _deserialize -from ..serde.recursive import index_syft_by_module_name from ..serde.serializable import serializable from ..serde.serialize import _serialize from ..serde.signature import Signature @@ -63,6 +62,7 @@ from ..util.markdown import as_markdown_python_code from ..util.notebook_ui.components.tabulator_template import build_tabulator_table from ..util.telemetry import instrument +from ..util.util import index_syft_by_module_name from ..util.util import prompt_warning_message from .connection import ServerConnection diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index 254d64b702e..f01266221bd 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -79,6 +79,7 @@ def merged_custom_cmds(self, sep: str = ";") -> str: return sep.join(self.custom_cmds) +@serializable(canonical_name="WorkerConfig", version=1) class WorkerConfig(SyftBaseModel): pass diff --git a/packages/syft/src/syft/protocol/protocol_version.json 
b/packages/syft/src/syft/protocol/protocol_version.json index 1b5f6f9afcb..fa34ec9c538 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,6 +1,13 @@ { "dev": { "object_versions": { + "SyftObjectVersioned": { + "1": { + "version": 1, + "hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4", + "action": "add" + } + }, "BaseDateTime": { "1": { "version": 1, @@ -8,6 +15,13 @@ "action": "add" } }, + "SyftObject": { + "1": { + "version": 1, + "hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406", + "action": "add" + } + }, "PartialSyftObject": { "1": { "version": 1, @@ -988,6 +1002,13 @@ "action": "add" } }, + "ProjectEvent": { + "1": { + "version": 1, + "hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb", + "action": "add" + } + }, "ProjectThreadMessage": { "1": { "version": 1, diff --git a/packages/syft/src/syft/serde/array.py b/packages/syft/src/syft/serde/array.py index fa1ed27e74b..3f19e575b97 100644 --- a/packages/syft/src/syft/serde/array.py +++ b/packages/syft/src/syft/serde/array.py @@ -162,6 +162,14 @@ version=SYFT_OBJECT_VERSION_1, ) +recursive_serde_register( + np.number, + serialize=lambda x: x.tobytes(), + deserialize=lambda buffer: frombuffer(buffer, dtype=np.number)[0], + canonical_name="numpy_number", + version=SYFT_OBJECT_VERSION_1, +) + # TODO: There is an incorrect mapping in looping,which makes it not work. 
# numpy_scalar_types = [ # np.bool_, diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 19ede3d0040..4c438975245 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -3,7 +3,6 @@ from enum import Enum from enum import EnumMeta import os -import sys import tempfile import types from typing import Any @@ -17,7 +16,6 @@ # relative from ..types.syft_object_registry import SyftObjectRegistry -from ..util.util import index_syft_by_module_name from .capnp import get_capnp_schema from .util import compatible_with_large_file_writes_capnp @@ -285,7 +283,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild # todo: rewrite and make sure every object has a canonical name and version canonical_name, version = SyftObjectRegistry.get_canonical_name_version(self) - if not SyftObjectRegistry.has_serde_class("", canonical_name, version): + if not SyftObjectRegistry.has_serde_class(canonical_name, version): # third party raise Exception( f"obj2proto: {canonical_name} version {version} not in SyftObjectRegistry" @@ -382,34 +380,12 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: # relative from .deserialize import _deserialize - # clean this mess, Tudor - module_parts = proto.fullyQualifiedName.split(".") - klass = module_parts.pop() class_type: type | Any = type(None) - if klass != "NoneType": - try: - class_type = index_syft_by_module_name(proto.fullyQualifiedName) # type: ignore[assignment,unused-ignore] - except Exception: # nosec - try: - class_type = getattr(sys.modules[".".join(module_parts)], klass) - except Exception: # nosec - if "syft.user" in proto.fullyQualifiedName: - # relative - from ..server.server import CODE_RELOADER - - for load_user_code in CODE_RELOADER.values(): - load_user_code() - try: - class_type = getattr(sys.modules[".".join(module_parts)], klass) - except Exception: # nosec - pass - canonical_name = 
proto.canonicalName version = getattr(proto, "version", -1) - fqn = getattr(proto, "fullyQualifiedName", "") - fqn = map_fqns_for_backward_compatibility(fqn) - if not SyftObjectRegistry.has_serde_class(fqn, canonical_name, version): + + if not SyftObjectRegistry.has_serde_class(canonical_name, version): # third party raise Exception( f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" @@ -431,13 +407,9 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: cls, _, version, - ) = SyftObjectRegistry.get_serde_properties_bw_compatible( - fqn, canonical_name, version - ) + ) = SyftObjectRegistry.get_serde_properties(canonical_name, version) - if class_type == type(None) or fqn != "": - # yes this looks stupid but it works and the opposite breaks - class_type = cls + class_type = cls if nonrecursive: if deserialize is None: @@ -468,14 +440,15 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: # if we skip the __new__ flow of BaseModel we get the error # AttributeError: object has no attribute '__fields_set__' - if "syft.user" in proto.fullyQualifiedName: - # weird issues with pydantic and ForwardRef on user classes being inited - # with custom state args / kwargs - obj = class_type() - for attr_name, attr_value in kwargs.items(): - setattr(obj, attr_name, attr_value) - else: - obj = class_type(**kwargs) + # if "syft.user" in proto.fullyQualifiedName: + # # weird issues with pydantic and ForwardRef on user classes being inited + # # with custom state args / kwargs + # obj = class_type() + # for attr_name, attr_value in kwargs.items(): + # setattr(obj, attr_name, attr_value) + # else: + # obj = class_type(**kwargs) + obj = class_type(**kwargs) else: obj = class_type.__new__(class_type) # type: ignore diff --git a/packages/syft/src/syft/serde/recursive_primitives.py b/packages/syft/src/syft/serde/recursive_primitives.py index 9385219b8e0..38d8281434d 100644 --- a/packages/syft/src/syft/serde/recursive_primitives.py +++ 
b/packages/syft/src/syft/serde/recursive_primitives.py @@ -8,12 +8,14 @@ from enum import Enum from enum import EnumMeta import functools +import inspect import pathlib from pathlib import PurePath import sys import tempfile from types import MappingProxyType from types import UnionType +import typing from typing import Any from typing import GenericAlias from typing import Optional @@ -27,6 +29,7 @@ import weakref # relative +from ..types.syft_object_registry import SyftObjectRegistry from .capnp import get_capnp_schema from .recursive import chunk_bytes from .recursive import combine_bytes @@ -169,22 +172,31 @@ def deserialize_enum(enum_type: type, enum_buf: bytes) -> Enum: return enum_type(enum_value) -def serialize_type(serialized_type: type) -> bytes: +def serialize_type(_type_to_serialize: type) -> bytes: # relative - from ..util.util import full_name_with_qualname + type_to_serialize = typing.get_origin(_type_to_serialize) or _type_to_serialize + canonical_name, version = SyftObjectRegistry.get_identifier_for_type( + type_to_serialize + ) + return f"{canonical_name}:{version}".encode() - fqn = full_name_with_qualname(klass=serialized_type) - module_parts = fqn.split(".") - return ".".join(module_parts).encode() + # from ..util.util import full_name_with_qualname + + # fqn = full_name_with_qualname(klass=serialized_type) + # module_parts = fqn.split(".") + # return ".".join(module_parts).encode() def deserialize_type(type_blob: bytes) -> type: deserialized_type = type_blob.decode() - module_parts = deserialized_type.split(".") - klass = module_parts.pop() - klass = "None" if klass == "NoneType" else klass - exception_type = getattr(sys.modules[".".join(module_parts)], klass) - return exception_type + canonical_name, version = deserialized_type.split(":", 1) + return SyftObjectRegistry.get_serde_class(canonical_name, int(version)) + + # module_parts = deserialized_type.split(".") + # klass = module_parts.pop() + # klass = "None" if klass == "NoneType" else 
klass + # exception_type = getattr(sys.modules[".".join(module_parts)], klass) + # return exception_type TPath = TypeVar("TPath", bound=PurePath) @@ -434,6 +446,8 @@ def recursive_serde_register_type( canonical_name: str | None = None, version: int | None = None, ) -> None: + # former case: e.g. _GenericAlias itself or _UnionGenericAlias + # latter case: e.g. List[str], which is currently not used if (isinstance(t, type) and issubclass(t, _GenericAlias)) or issubclass( type(t), _GenericAlias ): @@ -471,6 +485,31 @@ def deserialize_union_type(type_blob: bytes) -> type: return functools.reduce(lambda x, y: x | y, args) +def serialize_union(serialized_type: UnionType) -> bytes: + return b"" + + +def deserialize_union(type_blob: bytes) -> type: # type: ignore + return Union # type: ignore + + +def serialize_typevar(serialized_type: TypeVar) -> bytes: + return f"{serialized_type.__name__}".encode() + + +def deserialize_typevar(type_blob: bytes) -> type: + name = type_blob.decode() + return TypeVar(name=name) # type: ignore + + +def serialize_any(serialized_type: TypeVar) -> bytes: + return b"" + + +def deserialize_any(type_blob: bytes) -> type: # type: ignore + return Any # type: ignore + + recursive_serde_register( UnionType, serialize=serialize_union_type, @@ -481,8 +520,27 @@ def deserialize_union_type(type_blob: bytes) -> type: recursive_serde_register_type(_SpecialForm, canonical_name="_SpecialForm", version=1) recursive_serde_register_type(_GenericAlias, canonical_name="_GenericAlias", version=1) -recursive_serde_register_type(Union, canonical_name="Union", version=1) -recursive_serde_register_type(TypeVar, canonical_name="TypeVar", version=1) +recursive_serde_register( + Union, + canonical_name="Union", + serialize=serialize_union, + deserialize=deserialize_union, + version=1, +) +recursive_serde_register( + TypeVar, + canonical_name="TypeVar", + serialize=serialize_typevar, + deserialize=deserialize_typevar, + version=1, +) 
+recursive_serde_register( + Any, + canonical_name="Any", + serialize=serialize_any, + deserialize=deserialize_any, + version=1, +) recursive_serde_register_type( _UnionGenericAlias, @@ -503,7 +561,9 @@ def deserialize_union_type(type_blob: bytes) -> type: ) recursive_serde_register_type(GenericAlias, canonical_name="GenericAlias", version=1) -recursive_serde_register_type(Any, canonical_name="Any", version=1) +# recursive_serde_register_type(Any, canonical_name="Any", version=1) recursive_serde_register_type(EnumMeta, canonical_name="EnumMeta", version=1) recursive_serde_register_type(ABCMeta, canonical_name="ABCMeta", version=1) + +recursive_serde_register_type(inspect._empty, canonical_name="inspect_empty", version=1) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index ae2e0d6e66c..bbad29396b9 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -230,6 +230,7 @@ def repr_uid(_id: LineageID) -> str: ) +@serializable(canonical_name="ActionObjectPointer", version=1) class ActionObjectPointer: pass diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 05829eafb57..04e758b1bdc 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -30,6 +30,7 @@ from .server_peer import ServerPeer +@serializable(canonical_name="ServerRoute", version=1) class ServerRoute: def client_with_context( self, context: ServerServiceContext diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py index 1a6965365dc..f8baceee38a 100644 --- a/packages/syft/src/syft/service/notification/email_templates.py +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -3,6 +3,7 @@ from typing import cast # relative +from 
...serde.serializable import serializable from ...store.linked_obj import LinkedObject from ..context import AuthedServiceContext @@ -21,6 +22,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s return "" +@serializable(canonical_name="OnboardEmailTemplate", version=1) class OnBoardEmailTemplate(EmailTemplate): @staticmethod def email_title(notification: "Notification", context: AuthedServiceContext) -> str: @@ -107,6 +109,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s return f"""{head} {body}""" +@serializable(canonical_name="RequestEmailTemplate", version=1) class RequestEmailTemplate(EmailTemplate): @staticmethod def email_title(notification: "Notification", context: AuthedServiceContext) -> str: @@ -254,6 +257,7 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s return f"""{head} {body}""" +@serializable(canonical_name="RequestUpdateEmailTemplate", version=1) class RequestUpdateEmailTemplate(EmailTemplate): @staticmethod def email_title(notification: "Notification", context: AuthedServiceContext) -> str: diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 01a601d0da9..26dafe34e44 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -31,6 +31,7 @@ def send( TBaseNotifier = TypeVar("TBaseNotifier", bound=BaseNotifier) +@serializable(canonical_name="EmailNotifier", version=1) class EmailNotifier(BaseNotifier): smtp_client: SMTPClient sender = "" diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index b5a66a185cd..415f46fefb4 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -65,6 +65,7 @@ def metadata_to_server_identity() -> list[Callable]: return [rename("id", "server_id"), 
rename("name", "server_name")] +@serializable() class ProjectEvent(SyftObject): __canonical_name__ = "ProjectEvent" __version__ = SYFT_OBJECT_VERSION_1 diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index 4612ec5a416..fb8e854ccfe 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -181,10 +181,12 @@ def _coll_repr_(self) -> dict[str, str]: return {"file_name": self.file_name} +@serializable(canonical_name="BlobFileType", version=1) class BlobFileType(type): pass +@serializable(canonical_name="BlobFileObjectPointer", version=1) class BlobFileObjectPointer(ActionObjectPointer): pass diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 661f4d1db5a..f987479cafc 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -308,7 +308,8 @@ def get_migration_for_version( ] -class SyftObjectVersioned(SyftBaseObject, SyftObjectRegistry, SyftMigrationRegistry): +@serializable() +class SyftObjectVersioned(SyftBaseObject, SyftMigrationRegistry): __canonical_name__ = "SyftObjectVersioned" __version__ = SYFT_OBJECT_VERSION_1 @@ -345,6 +346,7 @@ def __lt__(self, other: Self) -> bool: return self.utc_timestamp < other.utc_timestamp +@serializable() class SyftObject(SyftObjectVersioned): __canonical_name__ = "SyftObject" __version__ = SYFT_OBJECT_VERSION_1 @@ -503,7 +505,6 @@ def __getitem__(self, key: str | int) -> Any: # transform from one supported type to another def to(self, projection: type[T], context: Context | None = None) -> T: # relative - from .syft_object_registry import SyftObjectRegistry # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? 
transform = SyftObjectRegistry.get_transform(type(self), projection) @@ -756,7 +757,6 @@ class StorableObjectType: def to(self, projection: type, context: Context | None = None) -> Any: # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? # relative - from .syft_object_registry import SyftObjectRegistry transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 7226e6a246f..d5cc342635e 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -38,6 +38,13 @@ def get_versions(cls, canonical_name: str) -> list[int]: ) return list(available_versions.keys()) + @classmethod + def get_identifier_for_type(cls, obj: Any) -> tuple[str, int]: + """ + This is to create the string in nonrecursiveBlob + """ + return cls.__type_to_canonical_name__[obj] + @classmethod def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: """ @@ -53,7 +60,7 @@ def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: get_canonical_name_version([1,2,3]) -> "list" get_canonical_name_version(list) -> "type" get_canonical_name_version(MyEnum.A) -> "MyEnum" - get_canonical_name_version(MyEnum) -> "EnumMeta" + get_canonical_name_version(MyEnum) -> "type" Args: obj: The object or type for which to get the canonical name. @@ -62,8 +69,7 @@ def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: The canonical name and version of the object or type. 
""" - # NOTE the metaclass of the object is not needed during serde - # so we can safely ignore it + # for types we return "type" if isinstance(obj, type): return cls.__type_to_canonical_name__[type] @@ -74,7 +80,15 @@ def get_canonical_name_version(cls, obj: Any) -> tuple[str, int]: @classmethod def get_serde_properties(cls, canonical_name: str, version: int) -> tuple: - return cls.__object_serialization_registry__[canonical_name][version] + try: + return cls.__object_serialization_registry__[canonical_name][version] + except Exception: + # This is a hack for python 3.10 in which Any is not a type + # if the server uses py>3.10 and the client 3.10 this goes wrong + if canonical_name == "Any_typing._SpecialForm": + return cls.__object_serialization_registry__["Any"][version] + else: + raise @classmethod def get_serde_class(cls, canonical_name: str, version: int) -> type["SyftObject"]: @@ -82,74 +96,12 @@ def get_serde_class(cls, canonical_name: str, version: int) -> type["SyftObject" return serde_properties[7] @classmethod - def get_serde_properties_bw_compatible( - cls, fqn: str, canonical_name: str, version: int - ) -> tuple: + def has_serde_class(cls, canonical_name: str | None, version: int) -> bool: # relative - from ..serde.recursive import TYPE_BANK - - if canonical_name != "" and canonical_name is not None: - return cls.get_serde_properties(canonical_name, version) - else: - # this is for backward compatibility with 0.8.6 - try: - # relative - from ..protocol.data_protocol import get_data_protocol - - serde_props = TYPE_BANK[fqn] - klass = serde_props[7] - is_syftobject = hasattr(klass, "__canonical_name__") - if is_syftobject: - canonical_name = klass.__canonical_name__ - dp = get_data_protocol() - try: - version_mutations = dp.protocol_history[ - SYFT_086_PROTOCOL_VERSION - ]["object_versions"][canonical_name] - except Exception: - print(f"could not find {canonical_name} in protocol history") - raise - - version_086 = max( - [ - int(k) - for k, v in 
version_mutations.items() - if v["action"] == "add" - ] - ) - try: - res = cls.get_serde_properties(canonical_name, version_086) - - except Exception: - print( - f"could not find {canonical_name} {version_086} in ObjectRegistry" - ) - raise - return res - else: - # TODO, add refactoring for non syftobject versions - canonical_name = fqn - version = 1 - return cls.get_serde_properties(canonical_name, version) - except Exception as e: - print(e) - raise - - @classmethod - def has_serde_class( - cls, fqn: str, canonical_name: str | None, version: int - ) -> bool: - # relative - from ..serde.recursive import TYPE_BANK - - if canonical_name != "" and canonical_name is not None: - return ( - canonical_name in cls.__object_serialization_registry__ - and version in cls.__object_serialization_registry__[canonical_name] - ) - else: - # this is for backward compatibility with 0.8.6 - return fqn in TYPE_BANK + return ( + canonical_name in cls.__object_serialization_registry__ + and version in cls.__object_serialization_registry__[canonical_name] + ) @classmethod def add_transform(