diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 62788762acf..4a346ebde9d 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,6 +1,7 @@ # stdlib from collections import defaultdict import logging +from typing import Any # syft absolute import syft @@ -16,6 +17,7 @@ from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry from ...types.twin_object import TwinObject +from ...types.uid import UID from ..action.action_object import Action from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission @@ -26,7 +28,10 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method +from ..sync.sync_service import get_store +from ..sync.sync_service import get_store_by_type from ..user.user_roles import ADMIN_ROLE_LEVEL +from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from ..worker.utils import DEFAULT_WORKER_POOL_NAME from .object_migration_state import MigrationData from .object_migration_state import StoreMetadata @@ -493,3 +498,29 @@ def reset_and_restore( ) return SyftSuccess(message="Database reset successfully.") + + @service_method( + path="migration._get_object", + name="_get_object", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def _get_object( + self, context: AuthedServiceContext, uid: UID, object_type: type + ) -> Any: + return ( + get_store_by_type(context, object_type) + .get_by_uid(credentials=context.credentials, uid=uid) + .unwrap() + ) + + @service_method( + path="migration._update_object", + name="_update_object", + roles=ADMIN_ROLE_LEVEL, + ) + def _update_object(self, context: AuthedServiceContext, object: Any) -> Any: + return ( + get_store(context, object) + .update(credentials=context.credentials, obj=object) + .unwrap() + ) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 7c5495a756e..ab07c7380f1 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -585,6 +585,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat # which tries to send an email to the admin and ends up here pass # lets keep going + self.refresh() if len(self.history) == 0: return RequestStatus.PENDING diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 19bed044eb4..b6cc955ac4f 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -39,10 +39,14 @@ def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash: - if isinstance(item, ActionObject): + return get_store_by_type(context=context, obj_type=type(item)) + + +def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash: + if issubclass(obj_type, ActionObject): service = context.server.services.action # type: ignore return service.stash # type: ignore - service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore + service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore return service.stash diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index f6a4d3233cb..7b30ffaa562 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -429,6 +429,17 @@ def make_id(cls, values: Any) -> Any: __table_coll_widths__: ClassVar[list[str] | None] = None __table_sort_attr__: ClassVar[str | None] = None + def refresh(self) -> None: + try: + api = self._get_api() + new_object = api.services.migration._get_object( + uid=self.id, object_type=type(self) + ) + if type(new_object) == type(self): + self.__dict__.update(new_object.__dict__) + except Exception as _: + return + def __syft_get_funcs__(self) -> list[tuple[str, Signature]]: funcs = print_type_cache[type(self)] if len(funcs) > 0: diff --git a/packages/syft/tests/syft/service/sync/get_set_object_test.py b/packages/syft/tests/syft/service/sync/get_set_object_test.py new file mode 100644 index 00000000000..e6681dc621f --- /dev/null +++ b/packages/syft/tests/syft/service/sync/get_set_object_test.py @@ -0,0 +1,57 @@ +# third party + +# syft absolute +import syft as sy +from syft.client.datasite_client import DatasiteClient +from syft.service.action.action_object import ActionObject +from syft.service.dataset.dataset import Dataset + + +def get_ds_client(client: DatasiteClient) -> DatasiteClient: + client.register( + name="a", + email="a@a.com", + password="asdf", + password_verify="asdf", + ) + return client.login(email="a@a.com", password="asdf") + + +def test_get_set_object(high_worker): + high_client: DatasiteClient = high_worker.root_client + _ = get_ds_client(high_client) + root_datasite_client = high_worker.root_client + dataset = sy.Dataset( + name="local_test", + asset_list=[ + sy.Asset( + name="local_test", + data=[1, 2, 3], + mock=[1, 1, 1], + ) + ], + ) + root_datasite_client.upload_dataset(dataset) + dataset = root_datasite_client.datasets[0] + + other_dataset = high_client.api.services.migration._get_object( + uid=dataset.id, object_type=Dataset + ) + other_dataset.server_uid = dataset.server_uid + assert dataset == other_dataset + other_dataset.name = "new_name" + updated_dataset = high_client.api.services.migration._update_object( + object=other_dataset + ) + assert updated_dataset.name == "new_name" + + asset = root_datasite_client.datasets[0].assets[0] + source_ao = high_client.api.services.action.get(uid=asset.action_id) + ao = high_client.api.services.migration._get_object( + uid=asset.action_id, object_type=ActionObject + ) + ao._set_obj_location_( + high_worker.id, + root_datasite_client.credentials, + ) + assert source_ao == ao