diff --git a/.github/workflows/pr-tests-syft.yml b/.github/workflows/pr-tests-syft.yml index 60be77fd438..177b836069a 100644 --- a/.github/workflows/pr-tests-syft.yml +++ b/.github/workflows/pr-tests-syft.yml @@ -188,7 +188,7 @@ jobs: ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}" TEST_NOTEBOOK_PATHS: "${{ matrix.notebook-paths }}" with: - timeout_seconds: 1800 + timeout_seconds: 2400 max_attempts: 3 command: tox -e syft.test.notebook diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 3492dd4423a..73bbcb558f7 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -41,8 +41,8 @@ from ..serde.serializable import serializable from ..serde.serialize import _serialize from ..service.context import NodeServiceContext -from ..service.metadata.node_metadata import NodeMetadata from ..service.metadata.node_metadata import NodeMetadataJSON +from ..service.metadata.node_metadata import NodeMetadataV2 from ..service.response import SyftError from ..service.response import SyftSuccess from ..service.user.user import UserCreate @@ -597,7 +597,7 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]: result = self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, - remote_node_verify_key=client.metadata.to(NodeMetadata).verify_key, + remote_node_verify_key=client.metadata.to(NodeMetadataV2).verify_key, ) return result diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index ca7c4ee444e..85bf0f5cc5f 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -65,7 +65,7 @@ from ..service.dataset.dataset_service import DatasetService from ..service.enclave.enclave_service import EnclaveService from ..service.metadata.metadata_service import MetadataService -from ..service.metadata.node_metadata import NodeMetadata +from ..service.metadata.node_metadata import NodeMetadataV2 from ..service.network.network_service import NetworkService from ..service.notification.notification_service import NotificationService from ..service.object_search.migration_state_service import MigrateStateService @@ -339,6 +339,9 @@ def __init__( self.init_blob_storage(config=blob_storage_config) + # Migrate data before any operation on db + self.find_and_migrate_data() + NodeRegistry.set_node_for(self.id, self) def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None: @@ -459,9 +462,9 @@ def root_client(self): root_client.api.refresh_api_callback() return root_client - def _find_pending_migrations(self): - klasses_to_be_migrated = [] - + def _find_klasses_pending_for_migration( + self, object_types: List[SyftObject] + ) -> List[SyftObject]: context = AuthedServiceContext( node=self, credentials=self.verify_key, @@ -469,39 +472,91 @@ def _find_pending_migrations(self): ) migration_state_service = self.get_service(MigrateStateService) - canonical_name_version_map = [] - - # Track all object types from document store - for partition in self.document_store.partitions.values(): - object_type = partition.settings.object_type - canonical_name = object_type.__canonical_name__ - object_version = object_type.__version__ - canonical_name_version_map.append((canonical_name, object_version)) + klasses_to_be_migrated = [] - # Track all object types from action store - action_object_types = [Action, ActionObject] - action_object_types.extend(ActionObject.__subclasses__()) - for object_type in action_object_types: + for object_type in object_types: canonical_name = object_type.__canonical_name__ object_version = object_type.__version__ - canonical_name_version_map.append((canonical_name, object_version)) - for canonical_name, current_version in canonical_name_version_map: migration_state = migration_state_service.get_state(context, canonical_name) if ( migration_state is not None and migration_state.current_version != migration_state.latest_version ): - klasses_to_be_migrated.append(canonical_name) + klasses_to_be_migrated.append(object_type) else: migration_state_service.register_migration_state( context, - current_version=current_version, + current_version=object_version, canonical_name=canonical_name, ) return klasses_to_be_migrated + def find_and_migrate_data(self): + # Track all object type that need migration for document store + context = AuthedServiceContext( + node=self, + credentials=self.verify_key, + role=ServiceRole.ADMIN, + ) + document_store_object_types = [ + partition.settings.object_type + for partition in self.document_store.partitions.values() + ] + + object_pending_migration = self._find_klasses_pending_for_migration( + object_types=document_store_object_types + ) + + if object_pending_migration: + print( + "Object in Document Store that needs migration: ", + object_pending_migration, + ) + + # Migrate data for objects in document store + for object_type in object_pending_migration: + canonical_name = object_type.__canonical_name__ + object_partition = self.document_store.partitions.get(canonical_name) + if object_partition is None: + continue + + print(f"Migrating data for: {canonical_name} table.") + migration_status = object_partition.migrate_data( + to_klass=object_type, context=context + ) + if migration_status.is_err(): + raise Exception( + f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" + ) + + # Track all object types from action store + action_object_types = [Action, ActionObject] + action_object_types.extend(ActionObject.__subclasses__()) + action_object_pending_migration = self._find_klasses_pending_for_migration( + action_object_types + ) + + if action_object_pending_migration: + print( + "Object in Action Store that needs migration: ", + action_object_pending_migration, + ) + + # Migrate data for objects in action store + for object_type in action_object_pending_migration: + canonical_name = object_type.__canonical_name__ + + migration_status = self.action_store.migrate_data( + to_klass=object_type, credentials=self.verify_key + ) + if migration_status.is_err(): + raise Exception( + f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" + ) + print("Data Migrated to latest version !!!") + @property def guest_client(self): return self.get_guest_client() @@ -686,7 +741,7 @@ def _get_service_method_from_path(self, path: str) -> Callable: return getattr(service_obj, method_name) @property - def metadata(self) -> NodeMetadata: + def metadata(self) -> NodeMetadataV2: name = "" deployed_on = "" organization = "" @@ -709,7 +764,7 @@ def metadata(self) -> NodeMetadata: admin_email = settings_data.admin_email show_warnings = settings_data.show_warnings - return NodeMetadata( + return NodeMetadataV2( name=name, id=self.id, verify_key=self.verify_key, diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 779508fd945..25b510cb8ff 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -247,6 +247,30 @@ def add_permissions(self, permissions: List[ActionObjectPermission]) -> None: for permission in permissions: self.add_permission(permission) + def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey): + has_root_permission = credentials == self.root_verify_key + + if has_root_permission: + for key, value in self.data: + try: + if value.__canonical_name__ != to_klass.__canonical_name__: + continue + migrated_value = value.migrate_to(to_klass) + except Exception: + return Err(f"Failed to migrate data to {to_klass} for qk: {key}") + result = self.set( + uid=key, + credentials=credentials, + syft_object=migrated_value, + ) + + if result.is_err(): + return result.err() + + return Ok(True) + + return Err("You don't have permissions to migrate data.") + @serializable() class DictActionStore(KeyValueActionStore): diff --git a/packages/syft/src/syft/service/metadata/migrations.py b/packages/syft/src/syft/service/metadata/migrations.py index dd6200b97a2..58d09021eb2 100644 --- a/packages/syft/src/syft/service/metadata/migrations.py +++ b/packages/syft/src/syft/service/metadata/migrations.py @@ -2,10 +2,10 @@ from ...types.syft_migration import migrate from ...types.transforms import rename from .node_metadata import NodeMetadata -from .node_metadata import NodeMetadataV1 +from .node_metadata import NodeMetadataV2 -@migrate(NodeMetadataV1, NodeMetadata) +@migrate(NodeMetadata, NodeMetadataV2) def upgrade_metadata_v1_to_v2(): return [ rename("highest_object_version", "highest_version"), @@ -13,7 +13,7 @@ def upgrade_metadata_v1_to_v2(): ] -@migrate(NodeMetadata, NodeMetadataV1) +@migrate(NodeMetadataV2, NodeMetadata) def downgrade_metadata_v2_to_v1(): return [ rename("highest_version", "highest_object_version"), diff --git a/packages/syft/src/syft/service/metadata/node_metadata.py b/packages/syft/src/syft/service/metadata/node_metadata.py index a3b3922b36e..05f61ba59e7 100644 --- a/packages/syft/src/syft/service/metadata/node_metadata.py +++ b/packages/syft/src/syft/service/metadata/node_metadata.py @@ -63,7 +63,7 @@ class NodeMetadataUpdate(SyftObject): @serializable() -class NodeMetadataV1(SyftObject): +class NodeMetadata(SyftObject): __canonical_name__ = "NodeMetadata" __version__ = SYFT_OBJECT_VERSION_1 @@ -92,7 +92,7 @@ def check_version(self, client_version: str) -> bool: @serializable() -class NodeMetadata(SyftObject): +class NodeMetadataV2(SyftObject): __canonical_name__ = "NodeMetadata" __version__ = SYFT_OBJECT_VERSION_2 @@ -155,7 +155,7 @@ def check_version(self, client_version: str) -> bool: ) -@transform(NodeMetadata, NodeMetadataJSON) +@transform(NodeMetadataV2, NodeMetadataJSON) def metadata_to_json() -> List[Callable]: return [ drop(["__canonical_name__"]), @@ -166,7 +166,7 @@ def metadata_to_json() -> List[Callable]: ] -@transform(NodeMetadataJSON, NodeMetadata) +@transform(NodeMetadataJSON, NodeMetadataV2) def json_to_metadata() -> List[Callable]: return [ drop(["metadata_version", "supported_protocols"]), diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4297572e5f0..18e9cd1c952 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -29,7 +29,7 @@ from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey -from ..metadata.node_metadata import NodeMetadata +from ..metadata.node_metadata import NodeMetadataV2 from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService @@ -398,7 +398,7 @@ def node_route_to_http_connection( return HTTPConnection(url=url, proxy_target_uid=obj.proxy_target_uid) -@transform(NodeMetadata, NodePeer) +@transform(NodeMetadataV2, NodePeer) def metadata_to_peer() -> List[Callable]: return [ keep(["id", "name", "verify_key", "node_type", "admin_email"]), diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 37c762fe09c..b03037a75dc 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -16,7 +16,7 @@ from ...types.syft_object import SyftObject from ...types.uid import UID from ..context import NodeServiceContext -from ..metadata.node_metadata import NodeMetadata +from ..metadata.node_metadata import NodeMetadataV2 from .routes import NodeRoute from .routes import NodeRouteType from .routes import connection_to_route @@ -53,7 +53,7 @@ def from_client(client: SyftClient) -> Self: if not client.metadata: raise Exception("Client has have metadata first") - peer = client.metadata.to(NodeMetadata).to(NodePeer) + peer = client.metadata.to(NodeMetadataV2).to(NodePeer) route = connection_to_route(client.connection) peer.node_routes.append(route) return peer diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 9fbf7218974..5ad6cbdf2ae 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -31,7 +31,7 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...serde.serialize import _serialize -from ...service.metadata.node_metadata import NodeMetadata +from ...service.metadata.node_metadata import NodeMetadataV2 from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.identity import Identity @@ -64,7 +64,7 @@ class EventAlreadyAddedException(SyftException): pass -@transform(NodeMetadata, NodeIdentity) +@transform(NodeMetadataV2, NodeIdentity) def metadata_to_node_identity() -> List[Callable]: return [rename("id", "node_id"), rename("name", "node_name")] @@ -1232,7 +1232,7 @@ def to_node_identity(val: Union[SyftClient, NodeIdentity]): if isinstance(val, NodeIdentity): return val elif isinstance(val, SyftClient): - metadata = val.metadata.to(NodeMetadata) + metadata = val.metadata.to(NodeMetadataV2) return metadata.to(NodeIdentity) else: raise SyftException( diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 98ca622c5bb..474975a5511 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -26,6 +26,7 @@ from ..node.credentials import SyftVerifyKey from ..serde.serializable import serializable from ..service.action.action_permissions import ActionObjectPermission +from ..service.context import AuthedServiceContext from ..service.response import SyftSuccess from ..types.base import SyftBaseModel from ..types.syft_object import SYFT_OBJECT_VERSION_1 @@ -456,6 +457,16 @@ def all( ) -> Result[List[BaseStash.object_type], str]: return self._thread_safe_cbk(self._all, credentials, order_by, has_permission) + def migrate_data( + self, + to_klass: SyftObject, + context: AuthedServiceContext, + has_permission: Optional[bool] = False, + ) -> Result[bool, str]: + return self._thread_safe_cbk( + self._migrate_data, to_klass, context, has_permission + ) + # Potentially thread-unsafe methods. # CAUTION: # * Don't use self.lock here. @@ -497,6 +508,14 @@ def remove_permission(self, permission: ActionObjectPermission) -> None: def has_permission(self, permission: ActionObjectPermission) -> bool: raise NotImplementedError + def _migrate_data( + self, + to_klass: SyftObject, + context: AuthedServiceContext, + has_permission: bool, + ) -> Result[bool, str]: + raise NotImplementedError + @instrument @serializable() diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index bc1c6d2ea37..0b86685ea66 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -24,6 +24,7 @@ from ..service.action.action_permissions import ActionObjectREAD from ..service.action.action_permissions import ActionObjectWRITE from ..service.action.action_permissions import ActionPermission +from ..service.context import AuthedServiceContext from ..service.response import SyftSuccess from ..types.syft_object import SyftObject from ..types.uid import UID @@ -370,6 +371,7 @@ def _update( qk: QueryKey, obj: SyftObject, has_permission=False, + overwrite=False, ) -> Result[SyftObject, str]: try: if qk.value not in self.data: @@ -396,11 +398,15 @@ def _update( ) # update the object with new data - for key, value in obj.to_dict(exclude_empty=True).items(): - if key == "id": - # protected field - continue - setattr(_original_obj, key, value) + if overwrite: + # Overwrite existing object and their values + _original_obj = obj + else: + for key, value in obj.to_dict(exclude_empty=True).items(): + if key == "id": + # protected field + continue + setattr(_original_obj, key, value) # update data and keys self._set_data_and_keys( @@ -608,3 +614,30 @@ def _set_data_and_keys( self.searchable_keys[pk_key] = ck_col self.data[store_query_key.value] = obj + + def _migrate_data( + self, to_klass: SyftObject, context: AuthedServiceContext, has_permission: bool + ) -> Result[bool, str]: + credentials = context.credentials + has_permission = (credentials == self.root_verify_key) or has_permission + if has_permission: + for key, value in self.data.items(): + try: + migrated_value = value.migrate_to(to_klass.__version__, context) + except Exception: + return Err(f"Failed to migrate data to {to_klass} for qk: {key}") + qk = self.settings.store_key.with_obj(key) + result = self._update( + credentials, + qk=qk, + obj=migrated_value, + has_permission=has_permission, + overwrite=True, + ) + + if result.is_err(): + return result.err() + + return Ok(True) + + return Err("You don't have permissions to migrate data.") diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index b74e877607a..c502db794d2 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -26,6 +26,7 @@ from ..service.action.action_permissions import ActionObjectREAD from ..service.action.action_permissions import ActionObjectWRITE from ..service.action.action_permissions import ActionPermission +from ..service.context import AuthedServiceContext from ..service.response import SyftSuccess from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import StorableObjectType @@ -562,6 +563,42 @@ def __len__(self): collection: MongoCollection = collection_status.ok() return collection.count_documents(filter={}) + def _migrate_data( + self, to_klass: SyftObject, context: AuthedServiceContext, has_permission: bool + ) -> Result[bool, str]: + credentials = context.credentials + has_permission = (credentials == self.root_verify_key) or has_permission + collection_status = self.collection + if collection_status.is_err(): + return collection_status + collection: MongoCollection = collection_status.ok() + + if has_permission: + storage_objs = collection.find({}) + for storage_obj in storage_objs: + obj = self.storage_type(storage_obj) + transform_context = TransformContext(output={}, obj=obj) + value = obj.to(self.settings.object_type, transform_context) + key = obj.get("_id") + try: + migrated_value = value.migrate_to(to_klass.__version__, context) + except Exception: + return Err(f"Failed to migrate data to {to_klass} for qk: {key}") + qk = self.settings.store_key.with_obj(key) + result = self._update( + credentials, + qk=qk, + obj=migrated_value, + has_permission=has_permission, + ) + + if result.is_err(): + return result.err() + + return Ok(True) + + return Err("You don't have permissions to migrate data.") + @serializable() class MongoDocumentStore(DocumentStore): diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index ee474273814..ff61fc58b74 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -319,7 +319,7 @@ def get_migration_for_version( raise Exception( f"No migration found for class type: {type_from} to " - "version: {version_to} in the migration registry." + f"version: {version_to} in the migration registry." ) diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index 596e5cec17b..ad318b82b04 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -14,7 +14,7 @@ from syft.node.credentials import SyftSigningKey from syft.node.credentials import SyftVerifyKey from syft.service.context import AuthedServiceContext -from syft.service.metadata.node_metadata import NodeMetadata +from syft.service.metadata.node_metadata import NodeMetadataV2 from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.service.settings.settings import NodeSettings @@ -227,7 +227,7 @@ def test_settings_allow_guest_registration( # Create a new worker verify_key = SyftSigningKey.generate().verify_key - mock_node_metadata = NodeMetadata( + mock_node_metadata = NodeMetadataV2( name=faker.name(), verify_key=verify_key, highest_version=1, @@ -310,7 +310,7 @@ def get_mock_client(faker, root_client, role): ) verify_key = SyftSigningKey.generate().verify_key - mock_node_metadata = NodeMetadata( + mock_node_metadata = NodeMetadataV2( name=faker.name(), verify_key=verify_key, highest_version=1,