Skip to content

Commit

Permalink
Merge pull request #8379 from OpenMined/fix-image-id-and-worker_pool-…
Browse files Browse the repository at this point in the history
…list

Fix image id and worker pool list
  • Loading branch information
shubham3121 authored Jan 9, 2024
2 parents 4b810d6 + f36d291 commit 3767bde
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@
"SeaweedSecureFilePathLocation": {
"2": {
"version": 2,
"hash": "3ca49db7536a33d5712485164e95406000df9af2aed78e9f9fa2bb2bbbb34fe6",
"hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739",
"action": "add"
}
},
Expand Down
26 changes: 18 additions & 8 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import BlobStorageMetadata
from ...types.blob_storage import CreateBlobStorageEntry
from ...types.blob_storage import SeaweedSecureFilePathLocation
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError
Expand Down Expand Up @@ -63,6 +64,7 @@ def mount_azure(
account_key: str,
container_name: str,
bucket_name: str,
use_direct_connections=True,
):
# stdlib

Expand Down Expand Up @@ -105,14 +107,22 @@ def mount_azure(
objects = res["Contents"]
file_sizes = [object["Size"] for object in objects]
file_paths = [object["Key"] for object in objects]
secure_file_paths = [
AzureSecureFilePathLocation(
path=file_path,
azure_profile_name=remote_name,
bucket_name=bucket_name,
)
for file_path in file_paths
]
if use_direct_connections:
secure_file_paths = [
AzureSecureFilePathLocation(
path=file_path,
azure_profile_name=remote_name,
bucket_name=bucket_name,
)
for file_path in file_paths
]
else:
secure_file_paths = [
SeaweedSecureFilePathLocation(
path=file_path,
)
for file_path in file_paths
]

for sfp, file_size in zip(secure_file_paths, file_sizes):
blob_storage_entry = BlobStorageEntry(
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/worker/worker_pool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def add_workers(

result = self.image_stash.get_by_uid(
credentials=context.credentials,
uid=worker_pool.syft_worker_image_id,
uid=worker_pool.image.id,
)

if result.is_err():
Expand All @@ -204,7 +204,7 @@ def add_workers(
worker_stash=self.worker_stash,
)

worker_pool.worker_list.append(worker_list)
worker_pool.worker_list += worker_list
worker_pool.max_count = existing_worker_cnt + number

update_result = self.stash.update(
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/types/blob_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class SeaweedSecureFilePathLocation(SecureFilePathLocation):
__canonical_name__ = "SeaweedSecureFilePathLocation"
__version__ = SYFT_OBJECT_VERSION_2

upload_id: str
upload_id: Optional[str] = None

def generate_url(self, connection, type_, bucket_name):
try:
Expand Down

0 comments on commit 3767bde

Please sign in to comment.