diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index fa31a1419b6..f0100bd0182 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -985,7 +985,7 @@ "SeaweedSecureFilePathLocation": { "2": { "version": 2, - "hash": "3ca49db7536a33d5712485164e95406000df9af2aed78e9f9fa2bb2bbbb34fe6", + "hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739", "action": "add" } }, diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index f485c2e67bf..d1ef73aab9f 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -20,6 +20,7 @@ from ...types.blob_storage import BlobStorageEntry from ...types.blob_storage import BlobStorageMetadata from ...types.blob_storage import CreateBlobStorageEntry +from ...types.blob_storage import SeaweedSecureFilePathLocation from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftError @@ -63,6 +64,7 @@ def mount_azure( account_key: str, container_name: str, bucket_name: str, + use_direct_connections=True, ): # stdlib @@ -105,14 +107,22 @@ def mount_azure( objects = res["Contents"] file_sizes = [object["Size"] for object in objects] file_paths = [object["Key"] for object in objects] - secure_file_paths = [ - AzureSecureFilePathLocation( - path=file_path, - azure_profile_name=remote_name, - bucket_name=bucket_name, - ) - for file_path in file_paths - ] + if use_direct_connections: + secure_file_paths = [ + AzureSecureFilePathLocation( + path=file_path, + azure_profile_name=remote_name, + bucket_name=bucket_name, + ) + for file_path in file_paths + ] + else: + secure_file_paths = [ + SeaweedSecureFilePathLocation( + path=file_path, + ) + for file_path in file_paths + ] for sfp, file_size in zip(secure_file_paths, file_sizes): blob_storage_entry = BlobStorageEntry( diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 6d656c206dc..87fed2a65b2 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -413,13 +413,13 @@ def check(self) -> Union[SyftSuccess, SyftError]: return SyftError( message=f"set_obj type {type(self.data)} must match set_mock type {type(self.mock)}" ) - if not _is_action_data_empty(self.mock): - data_shape = get_shape_or_len(self.data) - mock_shape = get_shape_or_len(self.mock) - if data_shape != mock_shape: - return SyftError( - message=f"set_obj shape {data_shape} must match set_mock shape {mock_shape}" - ) + # if not _is_action_data_empty(self.mock): + # data_shape = get_shape_or_len(self.data) + # mock_shape = get_shape_or_len(self.mock) + # if data_shape != mock_shape: + # return SyftError( + # message=f"set_obj shape {data_shape} must match set_mock shape {mock_shape}" + # ) total_size_mb = get_mb_size(self.data) + get_mb_size(self.mock) if total_size_mb > DATA_SIZE_WARNING_LIMIT: print( @@ -438,7 +438,10 @@ def get_shape_or_len(obj: Any) -> Optional[Union[Tuple[int, ...], int]]: return shape len_attr = getattr(obj, "__len__", None) if len_attr is not None: - return len_attr() + len_value = len_attr() + if isinstance(len_value, int): + return (len_value,) + return len_value return None @@ -774,6 +777,12 @@ def add_msg_creation_time(context: TransformContext) -> TransformContext: return context +def add_default_node_uid(context: TransformContext) -> TransformContext: + if context.output["node_uid"] is None: + context.output["node_uid"] = context.node.id + return context + + @transform(CreateAsset, Asset) def createasset_to_asset() -> List[Callable]: return [ @@ -782,6 +791,7 @@ def createasset_to_asset() -> List[Callable]: infer_shape, create_and_store_twin, set_data_subjects, + add_default_node_uid, ] diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 53577f5d4b1..6bd8143833d 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -185,7 +185,7 @@ def add_workers( result = self.image_stash.get_by_uid( credentials=context.credentials, - uid=worker_pool.syft_worker_image_id, + uid=worker_pool.image.id, ) if result.is_err(): @@ -204,7 +204,7 @@ def add_workers( worker_stash=self.worker_stash, ) - worker_pool.worker_list.append(worker_list) + worker_pool.worker_list += worker_list worker_pool.max_count = existing_worker_cnt + number update_result = self.stash.update( diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index 7f0438acf2e..899cb04fcbd 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -201,7 +201,7 @@ class SeaweedSecureFilePathLocation(SecureFilePathLocation): __canonical_name__ = "SeaweedSecureFilePathLocation" __version__ = SYFT_OBJECT_VERSION_2 - upload_id: str + upload_id: Optional[str] = None def generate_url(self, connection, type_, bucket_name): try: