Skip to content

Commit

Permalink
Merge branch 'dev' into safer_execution
Browse files Browse the repository at this point in the history
  • Loading branch information
teo-milea authored Jan 9, 2024
2 parents 879ee22 + 3767bde commit a2582ff
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 20 deletions.
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@
"SeaweedSecureFilePathLocation": {
"2": {
"version": 2,
"hash": "3ca49db7536a33d5712485164e95406000df9af2aed78e9f9fa2bb2bbbb34fe6",
"hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739",
"action": "add"
}
},
Expand Down
26 changes: 18 additions & 8 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import BlobStorageMetadata
from ...types.blob_storage import CreateBlobStorageEntry
from ...types.blob_storage import SeaweedSecureFilePathLocation
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError
Expand Down Expand Up @@ -63,6 +64,7 @@ def mount_azure(
account_key: str,
container_name: str,
bucket_name: str,
use_direct_connections=True,
):
# stdlib

Expand Down Expand Up @@ -105,14 +107,22 @@ def mount_azure(
objects = res["Contents"]
file_sizes = [object["Size"] for object in objects]
file_paths = [object["Key"] for object in objects]
secure_file_paths = [
AzureSecureFilePathLocation(
path=file_path,
azure_profile_name=remote_name,
bucket_name=bucket_name,
)
for file_path in file_paths
]
if use_direct_connections:
secure_file_paths = [
AzureSecureFilePathLocation(
path=file_path,
azure_profile_name=remote_name,
bucket_name=bucket_name,
)
for file_path in file_paths
]
else:
secure_file_paths = [
SeaweedSecureFilePathLocation(
path=file_path,
)
for file_path in file_paths
]

for sfp, file_size in zip(secure_file_paths, file_sizes):
blob_storage_entry = BlobStorageEntry(
Expand Down
26 changes: 18 additions & 8 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,13 @@ def check(self) -> Union[SyftSuccess, SyftError]:
return SyftError(
message=f"set_obj type {type(self.data)} must match set_mock type {type(self.mock)}"
)
if not _is_action_data_empty(self.mock):
data_shape = get_shape_or_len(self.data)
mock_shape = get_shape_or_len(self.mock)
if data_shape != mock_shape:
return SyftError(
message=f"set_obj shape {data_shape} must match set_mock shape {mock_shape}"
)
# if not _is_action_data_empty(self.mock):
# data_shape = get_shape_or_len(self.data)
# mock_shape = get_shape_or_len(self.mock)
# if data_shape != mock_shape:
# return SyftError(
# message=f"set_obj shape {data_shape} must match set_mock shape {mock_shape}"
# )
total_size_mb = get_mb_size(self.data) + get_mb_size(self.mock)
if total_size_mb > DATA_SIZE_WARNING_LIMIT:
print(
Expand All @@ -438,7 +438,10 @@ def get_shape_or_len(obj: Any) -> Optional[Union[Tuple[int, ...], int]]:
return shape
len_attr = getattr(obj, "__len__", None)
if len_attr is not None:
return len_attr()
len_value = len_attr()
if isinstance(len_value, int):
return (len_value,)
return len_value
return None


Expand Down Expand Up @@ -774,6 +777,12 @@ def add_msg_creation_time(context: TransformContext) -> TransformContext:
return context


def add_default_node_uid(context: TransformContext) -> TransformContext:
if context.output["node_uid"] is None:
context.output["node_uid"] = context.node.id
return context


@transform(CreateAsset, Asset)
def createasset_to_asset() -> List[Callable]:
return [
Expand All @@ -782,6 +791,7 @@ def createasset_to_asset() -> List[Callable]:
infer_shape,
create_and_store_twin,
set_data_subjects,
add_default_node_uid,
]


Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/worker/worker_pool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def add_workers(

result = self.image_stash.get_by_uid(
credentials=context.credentials,
uid=worker_pool.syft_worker_image_id,
uid=worker_pool.image.id,
)

if result.is_err():
Expand All @@ -204,7 +204,7 @@ def add_workers(
worker_stash=self.worker_stash,
)

worker_pool.worker_list.append(worker_list)
worker_pool.worker_list += worker_list
worker_pool.max_count = existing_worker_cnt + number

update_result = self.stash.update(
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/types/blob_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class SeaweedSecureFilePathLocation(SecureFilePathLocation):
__canonical_name__ = "SeaweedSecureFilePathLocation"
__version__ = SYFT_OBJECT_VERSION_2

upload_id: str
upload_id: Optional[str] = None

def generate_url(self, connection, type_, bucket_name):
try:
Expand Down

0 comments on commit a2582ff

Please sign in to comment.