diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 4c438975245..33bf94c8d4f 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -386,10 +386,16 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: version = getattr(proto, "version", -1) if not SyftObjectRegistry.has_serde_class(canonical_name, version): + # relative + from ..server.server import CODE_RELOADER + + for load_user_code in CODE_RELOADER.values(): + load_user_code() # third party - raise Exception( - f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" - ) + if not SyftObjectRegistry.has_serde_class(canonical_name, version): + raise Exception( + f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" + ) # TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the # module name space in sub-processes or threads even though they are loaded on start diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index ba1ae048f95..4bf96a58c5a 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -12,7 +12,6 @@ from inspect import Signature from io import StringIO import sys -import types from typing import Any from typing import ClassVar @@ -1133,19 +1132,34 @@ def submit_policy_code_to_user_code() -> list[Callable]: ] -def add_class_to_user_module(klass: type, unique_name: str) -> type: - klass.__module__ = "syft.user" - klass.__name__ = unique_name - # syft absolute - import syft as sy +def register_policy_class(klass: type, unique_name: str) -> None: + nonrecursive = False + _serialize = None + _deserialize = None + attributes = list(klass.model_fields.keys()) + exclude_attrs: list = [] + serde_overrides: dict = {} + hash_exclude_attrs: list = [] + cls = klass + attribute_types: list = [] + version = 1 + + serde_attributes = ( + nonrecursive, + _serialize, + _deserialize, + attributes, + exclude_attrs, + serde_overrides, + hash_exclude_attrs, + cls, + attribute_types, + version, + ) - if not hasattr(sy, "user"): - user_module = types.ModuleType("user") - sys.modules["syft"].user = user_module - user_module = sy.user - setattr(user_module, unique_name, klass) - sys.modules["syft"].user = user_module - return klass + SyftObjectRegistry.register_cls( + canonical_name=unique_name, version=version, serde_attributes=serde_attributes + ) def execute_policy_code(user_policy: UserPolicy) -> Any: @@ -1169,7 +1183,7 @@ def execute_policy_code(user_policy: UserPolicy) -> Any: exec(user_policy.byte_code) # nosec policy_class = eval(user_policy.unique_name) # nosec - policy_class = add_class_to_user_module(policy_class, user_policy.unique_name) + register_policy_class(policy_class, user_policy.unique_name) sys.stdout = stdout_ sys.stderr = stderr_