Merge pull request #10 from AllenInstitute/feature/update-batch-data-sync

add support for s3 references to batch data sync requests
rpmcginty authored Jun 27, 2024
2 parents 481ef75 + 49bcafb commit cb32318
Showing 2 changed files with 70 additions and 6 deletions.
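In short, BatchDataSyncHandler now accepts two payload shapes: an inline list of DataSyncRequest objects, or an S3 reference to a JSON file containing their serialized dicts. A minimal sketch of both, using the class names and constructor arguments visible in the diff below; the import paths are assumptions, not taken from this commit:

from aibs_informatics_aws_utils.data_sync import DataSyncRequest  # import path assumed
from aibs_informatics_aws_lambda.handlers.data_sync.operations import BatchDataSyncRequest  # import path assumed
from aibs_informatics_core.models.aws.s3 import S3Path  # import path assumed

# Inline style: the sync requests travel inside the Lambda event itself.
inline = BatchDataSyncRequest(
    requests=[
        DataSyncRequest(
            source_path=S3Path.build(bucket_name="bucket", key="src/"),
            destination_path=S3Path.build(bucket_name="bucket", key="dst/"),
            max_concurrency=10,
            retain_source_data=True,
        )
    ]
)

# Reference style (new in this PR): the requests live in a JSON file on S3,
# keeping the event payload itself small.
by_reference = BatchDataSyncRequest(
    requests=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json")
)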
27 changes: 24 additions & 3 deletions src/aibs_informatics_aws_lambda/handlers/data_sync/operations.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Union, cast
 
 from aibs_informatics_aws_utils.data_sync import (
     DataSyncOperations,
@@ -122,7 +122,15 @@ class BatchDataSyncHandler(LambdaHandler[BatchDataSyncRequest, BatchDataSyncResponse]):
     def handle(self, request: BatchDataSyncRequest) -> BatchDataSyncResponse:
         self.logger.info(f"Received {len(request.requests)} requests to transfer")
         responses = []
-        for _ in request.requests:
+        if isinstance(request.requests, S3URI):
+            self.logger.info(f"Request is stored at {request.requests}... fetching content.")
+            _ = download_to_json(request.requests)
+            assert isinstance(_, list)
+            batch_requests = [DataSyncRequest.from_dict(__) for __ in _]
+        else:
+            batch_requests = request.requests
+
+        for _ in batch_requests:
             sync_operations = DataSyncOperations(_)
             self.logger.info(f"Syncing content from {_.source_path} to {_.destination_path}")
             sync_operations.sync(
@@ -173,7 +181,20 @@ def handle(self, request: PrepareBatchDataSyncRequest) -> PrepareBatchDataSyncResponse:
                     )
                 )
             batch_data_sync_requests.append(BatchDataSyncRequest(requests=data_sync_requests))
-        return PrepareBatchDataSyncResponse(requests=batch_data_sync_requests)
+
+        if request.intermediate_s3_path:
+            self.logger.info(f"Uploading batch requests to {request.intermediate_s3_path}")
+            new_batch_data_sync_requests = []
+
+            for i, batch_data_sync_request in enumerate(batch_data_sync_requests):
+                upload_json(
+                    [cast(DataSyncRequest, _).to_dict() for _ in batch_data_sync_request.requests],
+                    s3_path=(s3_path := request.intermediate_s3_path / f"request_{i}.json"),
+                )
+                new_batch_data_sync_requests.append(BatchDataSyncRequest(requests=s3_path))
+            return PrepareBatchDataSyncResponse(requests=new_batch_data_sync_requests)
+        else:
+            return PrepareBatchDataSyncResponse(requests=batch_data_sync_requests)
 
     @classmethod
     def build_source_path(
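Taken together, the two handler changes form a round trip: PrepareBatchDataSyncHandler serializes each batch to JSON in S3 and returns only references, and BatchDataSyncHandler resolves those references back into request objects. A rough sketch of that flow, reusing the helpers named in the diff (upload_json, download_to_json, DataSyncRequest.from_dict) with illustrative bucket and key names:

# Prepare side: upload the serialized requests, pass only the reference on.
payload = [sync_request.to_dict() for sync_request in data_sync_requests]
s3_path = intermediate_s3_path / "request_0.json"  # S3Path supports '/' joins per the diff
upload_json(payload, s3_path=s3_path)

# Batch side: detect the S3 reference, fetch the JSON, and rehydrate the
# DataSyncRequest objects before syncing each one.
raw = download_to_json(s3_path)
assert isinstance(raw, list)
batch_requests = [DataSyncRequest.from_dict(d) for d in raw]

The remaining hunks update the accompanying handler tests to cover this path.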
@@ -104,9 +104,9 @@ def test__handles__puts_content_with_local_path_specified(self):
 
         self.assertHandles(self.handler, request.to_dict(), response.to_dict())
 
-        self.assertIsInstance(response.path, Path)
-
-        self.assertEqual(response.path.read_text(), f"{content}")
+        assert isinstance(response.path, Path)
+        assert response.path.exists()
+        assert response.path.read_text() == f"{content}"
 
         self.mock_upload_content.assert_not_called()
 
@@ -138,6 +138,10 @@ class PrepareBatchDataSyncHandlerTests(LambdaHandlerTestCase):
     def setUp(self) -> None:
         super().setUp()
 
+        self.mock_upload_content = self.create_patch(
+            "aibs_informatics_aws_lambda.handlers.data_sync.operations.upload_json"
+        )
+
     @property
     def handler(self) -> LambdaHandlerType:
         return PrepareBatchDataSyncHandler.get_handler()
@@ -198,6 +202,45 @@ def test__handle__prepare_local_to_s3__simple(self):
         )
         self.assertHandles(self.handler, request.to_dict(), expected.to_dict())
 
+    def test__handle__prepare_local_to_s3__simple__upload_to_s3(self):
+        fs = self.setUpLocalFS(
+            ("a", 1),
+            ("b", 1),
+            ("c", 1),
+        )
+        source_path = fs
+        destination_path = S3Path.build(bucket_name="bucket", key="key/")
+        request = PrepareBatchDataSyncRequest(
+            source_path=source_path,
+            destination_path=destination_path,
+            batch_size_bytes_limit=10,
+            max_concurrency=10,
+            retain_source_data=True,
+            intermediate_s3_path=S3Path.build(bucket_name="bucket", key="intermediate/"),
+        )
+
+        expected_json = [
+            DataSyncRequest(
+                source_path=source_path,
+                destination_path=destination_path,
+                max_concurrency=10,
+                retain_source_data=True,
+            ).to_dict()
+        ]
+        expected = PrepareBatchDataSyncResponse(
+            requests=[
+                BatchDataSyncRequest(
+                    requests=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json"),
+                )
+            ]
+        )
+        self.assertHandles(self.handler, request.to_dict(), expected.to_dict())
+
+        self.mock_upload_content.assert_called_once_with(
+            expected_json,
+            s3_path=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json"),
+        )
+
     def test__handle__prepare_local_to_s3__complex(self):
         fs = self.setUpLocalFS(
             ("a", 3),
