From 49bcafbd25a3b76a9a5cd7cd3f93c1457eb2abe9 Mon Sep 17 00:00:00 2001
From: Ryan McGinty
Date: Wed, 26 Jun 2024 12:36:04 -0700
Subject: [PATCH] add support for S3 references to batch data sync requests

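BatchDataSyncRequest.requests may now be either an inline list of
DataSyncRequest objects or an S3 path referencing a JSON file that
contains the serialized list. When
PrepareBatchDataSyncRequest.intermediate_s3_path is set, the prepare
handler uploads each batch as request_{i}.json under that prefix and
returns S3 references instead of inline payloads; BatchDataSyncHandler
downloads and deserializes any referenced payload before syncing.

Illustrative request shapes (sketch only; bucket and key values are
placeholders, not real locations):

    # inline form (unchanged)
    BatchDataSyncRequest(requests=[DataSyncRequest(...)])

    # referenced form, produced when intermediate_s3_path is set
    BatchDataSyncRequest(
        requests=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json")
    )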
---
 .../handlers/data_sync/operations.py          | 27 ++++++++--
 .../handlers/data_sync/test_operations.py     | 49 +++++++++++++++++--
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/src/aibs_informatics_aws_lambda/handlers/data_sync/operations.py b/src/aibs_informatics_aws_lambda/handlers/data_sync/operations.py
index fda6d3f..a452308 100644
--- a/src/aibs_informatics_aws_lambda/handlers/data_sync/operations.py
+++ b/src/aibs_informatics_aws_lambda/handlers/data_sync/operations.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Union, cast
 
 from aibs_informatics_aws_utils.data_sync import (
     DataSyncOperations,
@@ -122,7 +122,15 @@ class BatchDataSyncHandler(LambdaHandler[BatchDataSyncRequest, BatchDataSyncResponse]):
     def handle(self, request: BatchDataSyncRequest) -> BatchDataSyncResponse:
         self.logger.info(f"Received {len(request.requests)} requests to transfer")
         responses = []
-        for _ in request.requests:
+        if isinstance(request.requests, S3URI):
+            self.logger.info(f"Request is stored at {request.requests}... fetching content.")
+            content = download_to_json(request.requests)
+            assert isinstance(content, list)
+            batch_requests = [DataSyncRequest.from_dict(item) for item in content]
+        else:
+            batch_requests = request.requests
+
+        for _ in batch_requests:
             sync_operations = DataSyncOperations(_)
             self.logger.info(f"Syncing content from {_.source_path} to {_.destination_path}")
             sync_operations.sync(
@@ -170,7 +178,20 @@ def handle(self, request: PrepareBatchDataSyncRequest) -> PrepareBatchDataSyncResponse:
                     )
                 )
             batch_data_sync_requests.append(BatchDataSyncRequest(requests=data_sync_requests))
-        return PrepareBatchDataSyncResponse(requests=batch_data_sync_requests)
+
+        if request.intermediate_s3_path:
+            self.logger.info(f"Uploading batch requests to {request.intermediate_s3_path}")
+            new_batch_data_sync_requests = []
+
+            for i, batch_data_sync_request in enumerate(batch_data_sync_requests):
+                upload_json(
+                    [cast(DataSyncRequest, req).to_dict() for req in batch_data_sync_request.requests],
+                    s3_path=(s3_path := request.intermediate_s3_path / f"request_{i}.json"),
+                )
+                new_batch_data_sync_requests.append(BatchDataSyncRequest(requests=s3_path))
+            return PrepareBatchDataSyncResponse(requests=new_batch_data_sync_requests)
+        else:
+            return PrepareBatchDataSyncResponse(requests=batch_data_sync_requests)
 
     @classmethod
     def build_source_path(
diff --git a/test/aibs_informatics_aws_lambda/handlers/data_sync/test_operations.py b/test/aibs_informatics_aws_lambda/handlers/data_sync/test_operations.py
index 9958a7d..a288e90 100644
--- a/test/aibs_informatics_aws_lambda/handlers/data_sync/test_operations.py
+++ b/test/aibs_informatics_aws_lambda/handlers/data_sync/test_operations.py
@@ -104,9 +104,9 @@ def test__handles__puts_content_with_local_path_specified(self):
 
         self.assertHandles(self.handler, request.to_dict(), response.to_dict())
 
-        self.assertIsInstance(response.path, Path)
-
-        self.assertEqual(response.path.read_text(), f"{content}")
+        assert isinstance(response.path, Path)
+        assert response.path.exists()
+        assert response.path.read_text() == f"{content}"
 
         self.mock_upload_content.assert_not_called()
 
@@ -138,6 +138,10 @@ class PrepareBatchDataSyncHandlerTests(LambdaHandlerTestCase):
     def setUp(self) -> None:
         super().setUp()
+        self.mock_upload_json = self.create_patch(
+            "aibs_informatics_aws_lambda.handlers.data_sync.operations.upload_json"
+        )
+
     @property
     def handler(self) -> LambdaHandlerType:
         return PrepareBatchDataSyncHandler.get_handler()
 
@@ -198,6 +202,45 @@ def test__handle__prepare_local_to_s3__simple(self):
         )
         self.assertHandles(self.handler, request.to_dict(), expected.to_dict())
 
+    def test__handle__prepare_local_to_s3__simple__upload_to_s3(self):
+        fs = self.setUpLocalFS(
+            ("a", 1),
+            ("b", 1),
+            ("c", 1),
+        )
+        source_path = fs
+        destination_path = S3Path.build(bucket_name="bucket", key="key/")
+        request = PrepareBatchDataSyncRequest(
+            source_path=source_path,
+            destination_path=destination_path,
+            batch_size_bytes_limit=10,
+            max_concurrency=10,
+            retain_source_data=True,
+            intermediate_s3_path=S3Path.build(bucket_name="bucket", key="intermediate/"),
+        )
+
+        expected_json = [
+            DataSyncRequest(
+                source_path=source_path,
+                destination_path=destination_path,
+                max_concurrency=10,
+                retain_source_data=True,
+            ).to_dict()
+        ]
+        expected = PrepareBatchDataSyncResponse(
+            requests=[
+                BatchDataSyncRequest(
+                    requests=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json"),
+                )
+            ]
+        )
+        self.assertHandles(self.handler, request.to_dict(), expected.to_dict())
+
+        self.mock_upload_json.assert_called_once_with(
+            expected_json,
+            s3_path=S3Path.build(bucket_name="bucket", key="intermediate/request_0.json"),
+        )
+
     def test__handle__prepare_local_to_s3__complex(self):
         fs = self.setUpLocalFS(
             ("a", 3),