Skip to content

Commit

Permalink
Merge pull request #8174 from OpenMined/obj-version-data-migrate
Browse files Browse the repository at this point in the history
Force Migrate existing data to new version
  • Loading branch information
shubham3121 authored Oct 30, 2023
2 parents 850e742 + ced7ebe commit 234db36
Show file tree
Hide file tree
Showing 14 changed files with 216 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-tests-syft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ jobs:
ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}"
TEST_NOTEBOOK_PATHS: "${{ matrix.notebook-paths }}"
with:
timeout_seconds: 1800
timeout_seconds: 2400
max_attempts: 3
command: tox -e syft.test.notebook

Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from ..serde.serializable import serializable
from ..serde.serialize import _serialize
from ..service.context import NodeServiceContext
from ..service.metadata.node_metadata import NodeMetadata
from ..service.metadata.node_metadata import NodeMetadataJSON
from ..service.metadata.node_metadata import NodeMetadataV2
from ..service.response import SyftError
from ..service.response import SyftSuccess
from ..service.user.user import UserCreate
Expand Down Expand Up @@ -597,7 +597,7 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]:
result = self.api.services.network.exchange_credentials_with(
self_node_route=self_node_route,
remote_node_route=remote_node_route,
remote_node_verify_key=client.metadata.to(NodeMetadata).verify_key,
remote_node_verify_key=client.metadata.to(NodeMetadataV2).verify_key,
)

return result
Expand Down
99 changes: 77 additions & 22 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from ..service.dataset.dataset_service import DatasetService
from ..service.enclave.enclave_service import EnclaveService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadata
from ..service.metadata.node_metadata import NodeMetadataV2
from ..service.network.network_service import NetworkService
from ..service.notification.notification_service import NotificationService
from ..service.object_search.migration_state_service import MigrateStateService
Expand Down Expand Up @@ -339,6 +339,9 @@ def __init__(

self.init_blob_storage(config=blob_storage_config)

# Migrate data before any operation on db
self.find_and_migrate_data()

NodeRegistry.set_node_for(self.id, self)

def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None:
Expand Down Expand Up @@ -459,49 +462,101 @@ def root_client(self):
root_client.api.refresh_api_callback()
return root_client

def _find_pending_migrations(self):
klasses_to_be_migrated = []

def _find_klasses_pending_for_migration(
self, object_types: List[SyftObject]
) -> List[SyftObject]:
context = AuthedServiceContext(
node=self,
credentials=self.verify_key,
role=ServiceRole.ADMIN,
)
migration_state_service = self.get_service(MigrateStateService)

canonical_name_version_map = []

# Track all object types from document store
for partition in self.document_store.partitions.values():
object_type = partition.settings.object_type
canonical_name = object_type.__canonical_name__
object_version = object_type.__version__
canonical_name_version_map.append((canonical_name, object_version))
klasses_to_be_migrated = []

# Track all object types from action store
action_object_types = [Action, ActionObject]
action_object_types.extend(ActionObject.__subclasses__())
for object_type in action_object_types:
for object_type in object_types:
canonical_name = object_type.__canonical_name__
object_version = object_type.__version__
canonical_name_version_map.append((canonical_name, object_version))

for canonical_name, current_version in canonical_name_version_map:
migration_state = migration_state_service.get_state(context, canonical_name)
if (
migration_state is not None
and migration_state.current_version != migration_state.latest_version
):
klasses_to_be_migrated.append(canonical_name)
klasses_to_be_migrated.append(object_type)
else:
migration_state_service.register_migration_state(
context,
current_version=current_version,
current_version=object_version,
canonical_name=canonical_name,
)

return klasses_to_be_migrated

def find_and_migrate_data(self):
    """Migrate stored data for every object type that is pending migration.

    The document store is handled first, then the action store. For each
    store the candidate object types are passed to
    ``_find_klasses_pending_for_migration`` and every type it returns is
    migrated through that store's ``migrate_data``.

    Raises:
        Exception: if any ``migrate_data`` call returns an error result.
    """
    context = AuthedServiceContext(
        node=self,
        credentials=self.verify_key,
        role=ServiceRole.ADMIN,
    )

    # --- Document store -------------------------------------------------
    # One candidate type per partition currently held by the document store.
    partition_types = [
        store_partition.settings.object_type
        for store_partition in self.document_store.partitions.values()
    ]
    stale_document_types = self._find_klasses_pending_for_migration(
        object_types=partition_types
    )

    if stale_document_types:
        print(
            "Object in Document Store that needs migration: ",
            stale_document_types,
        )

    for klass in stale_document_types:
        canonical_name = klass.__canonical_name__
        target_partition = self.document_store.partitions.get(canonical_name)
        # Types without a partition under this name are skipped.
        if target_partition is not None:
            print(f"Migrating data for: {canonical_name} table.")
            migration_status = target_partition.migrate_data(
                to_klass=klass, context=context
            )
            if migration_status.is_err():
                raise Exception(
                    f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}"
                )

    # --- Action store ---------------------------------------------------
    # Candidates: Action, ActionObject and every direct ActionObject subclass.
    action_types = [Action, ActionObject, *ActionObject.__subclasses__()]
    stale_action_types = self._find_klasses_pending_for_migration(action_types)

    if stale_action_types:
        print(
            "Object in Action Store that needs migration: ",
            stale_action_types,
        )

    for klass in stale_action_types:
        canonical_name = klass.__canonical_name__
        migration_status = self.action_store.migrate_data(
            to_klass=klass, credentials=self.verify_key
        )
        if migration_status.is_err():
            raise Exception(
                f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}"
            )

    print("Data Migrated to latest version !!!")

@property
def guest_client(self):
return self.get_guest_client()
Expand Down Expand Up @@ -686,7 +741,7 @@ def _get_service_method_from_path(self, path: str) -> Callable:
return getattr(service_obj, method_name)

@property
def metadata(self) -> NodeMetadata:
def metadata(self) -> NodeMetadataV2:
name = ""
deployed_on = ""
organization = ""
Expand All @@ -709,7 +764,7 @@ def metadata(self) -> NodeMetadata:
admin_email = settings_data.admin_email
show_warnings = settings_data.show_warnings

return NodeMetadata(
return NodeMetadataV2(
name=name,
id=self.id,
verify_key=self.verify_key,
Expand Down
24 changes: 24 additions & 0 deletions packages/syft/src/syft/service/action/action_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,30 @@ def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
for permission in permissions:
self.add_permission(permission)

def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey):
    """Migrate stored objects that share ``to_klass``'s canonical name.

    Only a caller presenting the store's root verify key may migrate data.

    Args:
        to_klass: Target class. Entries whose ``__canonical_name__`` differs
            from ``to_klass.__canonical_name__`` are left untouched.
        credentials: Verify key of the caller; must equal
            ``self.root_verify_key``.

    Returns:
        ``Ok(True)`` once every matching entry has been migrated and passed
        to ``self.set``; otherwise an ``Err`` describing the missing
        permission, the entry that failed to migrate, or the error reported
        by ``self.set``.
    """
    has_root_permission = credentials == self.root_verify_key
    if not has_root_permission:
        return Err("You don't have permissions to migrate data.")

    # Iterating the mapping directly yields keys only, so ask for the
    # (key, value) pairs explicitly. The pairs are snapshotted because
    # ``self.set`` is called for entries while we are still looping.
    for key, value in list(self.data.items()):
        try:
            if value.__canonical_name__ != to_klass.__canonical_name__:
                continue
            migrated_value = value.migrate_to(to_klass)
        except Exception as exc:
            # Keep the cause in the message instead of discarding it.
            return Err(
                f"Failed to migrate data to {to_klass} for qk: {key}. Exception: {exc}"
            )

        result = self.set(
            uid=key,
            credentials=credentials,
            syft_object=migrated_value,
        )
        if result.is_err():
            # Return a Result (not the bare error payload) so callers can
            # keep using ``is_err()`` / ``err()`` on the return value.
            return Err(result.err())

    return Ok(True)


@serializable()
class DictActionStore(KeyValueActionStore):
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/metadata/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
from ...types.syft_migration import migrate
from ...types.transforms import rename
from .node_metadata import NodeMetadata
from .node_metadata import NodeMetadataV1
from .node_metadata import NodeMetadataV2


@migrate(NodeMetadataV1, NodeMetadata)
@migrate(NodeMetadata, NodeMetadataV2)
def upgrade_metadata_v1_to_v2():
return [
rename("highest_object_version", "highest_version"),
rename("lowest_object_version", "lowest_version"),
]


@migrate(NodeMetadata, NodeMetadataV1)
@migrate(NodeMetadataV2, NodeMetadata)
def downgrade_metadata_v2_to_v1():
return [
rename("highest_version", "highest_object_version"),
Expand Down
8 changes: 4 additions & 4 deletions packages/syft/src/syft/service/metadata/node_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class NodeMetadataUpdate(SyftObject):


@serializable()
class NodeMetadataV1(SyftObject):
class NodeMetadata(SyftObject):
__canonical_name__ = "NodeMetadata"
__version__ = SYFT_OBJECT_VERSION_1

Expand Down Expand Up @@ -92,7 +92,7 @@ def check_version(self, client_version: str) -> bool:


@serializable()
class NodeMetadata(SyftObject):
class NodeMetadataV2(SyftObject):
__canonical_name__ = "NodeMetadata"
__version__ = SYFT_OBJECT_VERSION_2

Expand Down Expand Up @@ -155,7 +155,7 @@ def check_version(self, client_version: str) -> bool:
)


@transform(NodeMetadata, NodeMetadataJSON)
@transform(NodeMetadataV2, NodeMetadataJSON)
def metadata_to_json() -> List[Callable]:
return [
drop(["__canonical_name__"]),
Expand All @@ -166,7 +166,7 @@ def metadata_to_json() -> List[Callable]:
]


@transform(NodeMetadataJSON, NodeMetadata)
@transform(NodeMetadataJSON, NodeMetadataV2)
def json_to_metadata() -> List[Callable]:
return [
drop(["metadata_version", "supported_protocols"]),
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/network/network_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...util.telemetry import instrument
from ..context import AuthedServiceContext
from ..data_subject.data_subject import NamePartitionKey
from ..metadata.node_metadata import NodeMetadata
from ..metadata.node_metadata import NodeMetadataV2
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
Expand Down Expand Up @@ -398,7 +398,7 @@ def node_route_to_http_connection(
return HTTPConnection(url=url, proxy_target_uid=obj.proxy_target_uid)


@transform(NodeMetadata, NodePeer)
@transform(NodeMetadataV2, NodePeer)
def metadata_to_peer() -> List[Callable]:
return [
keep(["id", "name", "verify_key", "node_type", "admin_email"]),
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/network/node_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ...types.syft_object import SyftObject
from ...types.uid import UID
from ..context import NodeServiceContext
from ..metadata.node_metadata import NodeMetadata
from ..metadata.node_metadata import NodeMetadataV2
from .routes import NodeRoute
from .routes import NodeRouteType
from .routes import connection_to_route
Expand Down Expand Up @@ -53,7 +53,7 @@ def from_client(client: SyftClient) -> Self:
if not client.metadata:
raise Exception("Client has have metadata first")

peer = client.metadata.to(NodeMetadata).to(NodePeer)
peer = client.metadata.to(NodeMetadataV2).to(NodePeer)
route = connection_to_route(client.connection)
peer.node_routes.append(route)
return peer
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...serde.serialize import _serialize
from ...service.metadata.node_metadata import NodeMetadata
from ...service.metadata.node_metadata import NodeMetadataV2
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.identity import Identity
Expand Down Expand Up @@ -64,7 +64,7 @@ class EventAlreadyAddedException(SyftException):
pass


@transform(NodeMetadata, NodeIdentity)
@transform(NodeMetadataV2, NodeIdentity)
def metadata_to_node_identity() -> List[Callable]:
return [rename("id", "node_id"), rename("name", "node_name")]

Expand Down Expand Up @@ -1232,7 +1232,7 @@ def to_node_identity(val: Union[SyftClient, NodeIdentity]):
if isinstance(val, NodeIdentity):
return val
elif isinstance(val, SyftClient):
metadata = val.metadata.to(NodeMetadata)
metadata = val.metadata.to(NodeMetadataV2)
return metadata.to(NodeIdentity)
else:
raise SyftException(
Expand Down
19 changes: 19 additions & 0 deletions packages/syft/src/syft/store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..node.credentials import SyftVerifyKey
from ..serde.serializable import serializable
from ..service.action.action_permissions import ActionObjectPermission
from ..service.context import AuthedServiceContext
from ..service.response import SyftSuccess
from ..types.base import SyftBaseModel
from ..types.syft_object import SYFT_OBJECT_VERSION_1
Expand Down Expand Up @@ -456,6 +457,16 @@ def all(
) -> Result[List[BaseStash.object_type], str]:
return self._thread_safe_cbk(self._all, credentials, order_by, has_permission)

def migrate_data(
    self,
    to_klass: SyftObject,
    context: AuthedServiceContext,
    has_permission: Optional[bool] = False,
) -> Result[bool, str]:
    """Run ``_migrate_data`` through ``_thread_safe_cbk``.

    Args:
        to_klass: Class the stored data should be migrated to.
        context: Service context forwarded to ``_migrate_data``.
        has_permission: Forwarded unchanged to ``_migrate_data``;
            defaults to ``False``.

    Returns:
        Whatever ``_thread_safe_cbk`` returns for the ``_migrate_data`` call.
    """
    forwarded_args = (to_klass, context, has_permission)
    return self._thread_safe_cbk(self._migrate_data, *forwarded_args)

# Potentially thread-unsafe methods.
# CAUTION:
# * Don't use self.lock here.
Expand Down Expand Up @@ -497,6 +508,14 @@ def remove_permission(self, permission: ActionObjectPermission) -> None:
def has_permission(self, permission: ActionObjectPermission) -> bool:
raise NotImplementedError

def _migrate_data(
    self,
    to_klass: SyftObject,
    context: AuthedServiceContext,
    has_permission: bool,
) -> Result[bool, str]:
    """Migration hook with no implementation at this level.

    Always raises ``NotImplementedError`` here; presumably concrete store
    partitions override it (their implementations are not visible in this
    view).

    Args:
        to_klass: Class the stored data should be migrated to.
        context: Service context supplied by the caller.
        has_permission: Permission flag supplied by the caller.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError


@instrument
@serializable()
Expand Down
Loading

0 comments on commit 234db36

Please sign in to comment.