diff --git a/src/pyrovelocity/io/gcs.py b/src/pyrovelocity/io/gcs.py
index 9a21cd8d7..9ab431049 100644
--- a/src/pyrovelocity/io/gcs.py
+++ b/src/pyrovelocity/io/gcs.py
@@ -9,8 +9,9 @@
 from pyrovelocity.logging import configure_logging
 
 __all__ = [
-    "download_bucket",
     "download_blob_from_uri",
+    "download_bucket",
+    "download_bucket_from_uri",
     "upload_file_concurrently",
     "upload_directory",
 ]
@@ -21,16 +22,100 @@
 @beartype
 def download_blob_from_uri(
     blob_uri: str,
+    concurrent: bool = False,
 ):
     client = Client()
     parsed_blob_uri = urlparse(blob_uri)
+    if parsed_blob_uri.scheme != "gs":
+        raise ValueError(
+            f"URI scheme must be 'gs', not {parsed_blob_uri.scheme}."
+        )
     blob_path = Path(parsed_blob_uri.path)
     blob_filename = blob_path.name
     blob = Blob.from_string(blob_uri, client)
-    blob.download_to_filename(f"./{blob_filename}")
+    if concurrent:
+        download_blob_concurrently(blob, blob_filename)
+    else:
+        blob.download_to_filename(f"./{blob_filename}")
     print(f"Downloaded {blob_uri} to {blob_filename}.")
 
 
+@beartype
+def download_blob_concurrently(
+    blob: Blob,
+    filename: str | Path,
+    chunk_size: int = 32 * 1024 * 1024,
+    workers: int = 8,
+):
+    """Download a single file in chunks, concurrently in a process pool."""
+
+    transfer_manager.download_chunks_concurrently(
+        blob, filename, chunk_size=chunk_size, max_workers=workers
+    )
+
+
+@beartype
+def download_bucket_from_uri(
+    bucket_uri: str,
+    destination_directory: str | Path = "",
+    workers: int = 8,
+    max_results: int = 1000,
+) -> Result[None, Exception]:
+    """
+    Download all of the blobs in a bucket, concurrently in a process pool.
+
+    The filename of each blob once downloaded is derived from the blob name and
+    the `destination_directory` parameter.
+
+    Directories will be created automatically as needed, for instance to
+    accommodate blob names that include slashes.
+
+    Adapted from:
+    https://github.com/googleapis/python-storage/blob/v2.14.0/samples/snippets/storage_transfer_manager_download_bucket.py
+    """
+
+    try:
+        parsed_bucket_uri = urlparse(bucket_uri)
+        if parsed_bucket_uri.scheme != "gs":
+            raise ValueError(
+                f"URI scheme must be 'gs', not {parsed_bucket_uri.scheme}."
+            )
+        storage_client = Client()
+        bucket = storage_client.bucket(parsed_bucket_uri.netloc)
+
+        blob_names = [
+            blob.name
+            for blob in bucket.list_blobs(
+                max_results=max_results,
+                prefix=parsed_bucket_uri.path[1:],
+            )
+        ]
+
+        results = transfer_manager.download_many_to_path(
+            bucket,
+            blob_names,
+            destination_directory=destination_directory,
+            max_workers=workers,
+        )
+
+        for name, result in zip(blob_names, results):
+            if isinstance(result, Exception):
+                logger.error(
+                    f"Failed to download {name} due to exception: {result}"
+                )
+            else:
+                logger.info(
+                    f"Downloaded {name} to {Path(destination_directory) / name}"
+                )
+        return Success(None)
+
+    except Exception as e:
+        logger.error(
+            f"Failed to download files from {bucket_uri} to {destination_directory}."
+        )
+        return Failure(e)
+
+
 @beartype
 def download_bucket(
     bucket_name: str,
@@ -85,26 +170,6 @@ def download_bucket(
         return Failure(e)
 
 
-def download_blob_concurrently(
-    bucket_name: str,
-    blob_name: str,
-    filename: str | Path,
-    chunk_size: int = 32 * 1024 * 1024,
-    workers: int = 8,
-):
-    """Download a single file in chunks, concurrently in a process pool."""
-
-    storage_client = Client()
-    bucket = storage_client.bucket(bucket_name)
-    blob = bucket.blob(blob_name)
-
-    transfer_manager.download_chunks_concurrently(
-        blob, filename, chunk_size=chunk_size, max_workers=workers
-    )
-
-    print(f"Downloaded {blob_name} to {filename}.")
-
-
 @beartype
 def upload_file_concurrently(
     bucket_name: str,
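
Usage sketch (not part of the diff): a minimal illustration of the two new entry points, assuming application default credentials are configured for google-cloud-storage and that `Success`/`Failure` come from the `returns` library the module already uses; the bucket URI `gs://example-bucket/data` and the blob name are hypothetical placeholders.

from returns.result import Success

from pyrovelocity.io.gcs import (
    download_blob_from_uri,
    download_bucket_from_uri,
)

# Single blob: concurrent=True chunks the download across a process pool
# via the storage transfer manager instead of a single-stream download.
download_blob_from_uri("gs://example-bucket/data/adata.h5ad", concurrent=True)

# Bucket prefix: mirror every blob under data/ into ./downloads, eight
# blobs at a time; per-blob failures are logged, and a top-level error
# is returned as a Failure rather than raised.
result = download_bucket_from_uri(
    "gs://example-bucket/data",
    destination_directory="downloads",
    workers=8,
)
if isinstance(result, Success):
    print("bucket download complete")
else:
    print(f"bucket download failed: {result.failure()}")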