 
 import itertools
 import math
+import shutil
+import subprocess
 from pathlib import Path
 import numpy as np
 from cloudvolume import CloudVolume
@@ -72,6 +73,8 @@ def load_env_config():
         # Local paths (used when USE_GCS_BUCKET is False)
         "INPUT_PATH": Path(os.getenv("INPUT_PATH", "/temp/in")),
         "OUTPUT_PATH": Path(os.getenv("OUTPUT_PATH", "/temp/out")),
+        "DELETE_INPUT": parse_bool(os.getenv("DELETE_INPUT", "false")),
+        "DELETE_OUTPUT": parse_bool(os.getenv("DELETE_OUTPUT", "false")),
         # Processing settings
         "OVERWRITE": parse_bool(os.getenv("OVERWRITE", "false")),
         "OVERWRITE_GCS": parse_bool(os.getenv("OVERWRITE_GCS", "false")),
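
The two new flags reuse the script's existing parse_bool/os.getenv pattern, so cache cleanup can be switched on per run from the environment. A minimal sketch of how they would be toggled (illustrative values; the exact truthy strings accepted depend on the script's parse_bool helper, which is not shown in this diff):

    import os

    # Illustration only: delete cached inputs after processing, keep output chunks on disk.
    os.environ["DELETE_INPUT"] = "true"
    os.environ["DELETE_OUTPUT"] = "false"

    config = load_env_config()  # the function being modified above
    print(config["DELETE_INPUT"], config["DELETE_OUTPUT"])
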
@@ -107,6 +110,8 @@ def load_env_config():
 gcs_output_path = config["GCS_OUTPUT_PREFIX"]
 input_path = config["INPUT_PATH"]
 output_path = config["OUTPUT_PATH"]
+delete_input = config["DELETE_INPUT"]
+delete_output = config["DELETE_OUTPUT"]
 overwrite_output = config["OVERWRITE"]
 overwrite_gcs = config["OVERWRITE_GCS"]
 num_mips = config["NUM_MIPS"]
@@ -148,13 +153,8 @@ def list_gcs_files(bucket_name, prefix="", file_extension=""):
     """
     if gcs_local_list and gcs_local_list.exists():
         print(f"Loading file list from local file: {gcs_local_list}")
-        full_prefix_str = "gs://" + bucket_name + "/"
         with open(gcs_local_list, "r") as f:
-            files = [
-                line.strip().rstrip("/")[len(full_prefix_str):]
-                for line in f
-                if line.strip()
-            ]
+            files = [line.strip() for line in f if line.strip()]
         print(f"Found {len(files)} files in local list")
         return files
     client = storage.Client(project=gcs_project)
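
With the prefix-stripping removed, entries in the local listing file are now used verbatim (only surrounding whitespace is trimmed and blank lines are skipped), so the file must already contain paths in the form the rest of the pipeline expects. A small sketch of the new parsing behaviour, with an illustrative file name and contents:

    from pathlib import Path

    listing = Path("example_gcs_listing.txt")  # illustrative file
    listing.write_text("gs://my-bucket/tiles/r01_c02.zarr/\n\n  gs://my-bucket/tiles/r01_c03.zarr  \n")

    with open(listing, "r") as f:
        files = [line.strip() for line in f if line.strip()]

    print(files)
    # ['gs://my-bucket/tiles/r01_c02.zarr/', 'gs://my-bucket/tiles/r01_c03.zarr']

Note that a trailing slash on an entry survives; the basename handling in get_local_cache_path (next hunk) accounts for that.
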
@@ -281,9 +281,9 @@ def get_local_cache_path(row, col):
 
     # Create local filename based on remote file
     if use_gcs_bucket:
-        cache_dir = input_path / "cache"
+        cache_dir = input_path
         cache_dir.mkdir(exist_ok=True, parents=True)
-        local_name = str(remote_file).split("/")[-1]
+        local_name = str(remote_file).rstrip("/").split("/")[-1]
         output = Path(cache_dir / local_name)
     else:
         output = Path(remote_file)
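
The added rstrip("/") matters when a listed path carries a trailing slash, as GCS "directory" prefixes often do; without it the basename comes back empty:

    remote_file = "gs://my-bucket/tiles/r01_c02.zarr/"  # illustrative value

    print(str(remote_file).split("/")[-1])              # '' (previous behaviour)
    print(str(remote_file).rstrip("/").split("/")[-1])  # 'r01_c02.zarr'
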
@@ -316,7 +316,45 @@ def get_remote_file_path(row, col):
     return None
 
 
-def download_file(row, col):
+def gcloud_download_dir(gs_prefix: str, local_dir: Path) -> None:
+    """
+    Recursively download a GCS prefix to a local directory using gcloud.
+    Example gs_prefix: 'gs://my-bucket/some/prefix/'
+    """
+    local_dir.mkdir(parents=True, exist_ok=True)
+
+    # Use a list (no shell=True) to avoid injection issues
+    cmd = [
+        "gcloud",
+        "storage",
+        "cp",
+        "--recursive",
+        "--project",
+        gcs_project,
+        gs_prefix,
+        str(local_dir),
+    ]
+
+    print("Running command:", " ".join(cmd))
+    try:
+        res = subprocess.run(
+            cmd,
+            check=True,  # raises CalledProcessError on nonzero exit
+            capture_output=True,  # capture logs; integrate with your logger
+            text=True,
+        )
+        print(res.stdout)
+        if res.stderr:
+            print(res.stderr)
+    except subprocess.CalledProcessError as e:
+        # Surface meaningful diagnostics
+        print("gcloud cp failed:", e.returncode)
+        print(e.stdout)
+        print(e.stderr)
+        raise
+
+
+def download_zarr_file(row, col):
     """
     Download the file for a specific row and column to local cache.
 
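
A usage sketch for the new helper, assuming it is called from this module with an authenticated gcloud CLI on PATH (bucket, prefix, and local directory are illustrative):

    import subprocess
    from pathlib import Path

    try:
        gcloud_download_dir("gs://my-bucket/tiles/r01_c02.zarr", Path("/temp/in"))
    except subprocess.CalledProcessError as err:
        # check=True plus capture_output=True means the exception carries the CLI's output.
        print("download failed with exit code", err.returncode)
        print(err.stderr)
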
@@ -341,23 +379,13 @@ def download_file(row, col):
         print(f"File already cached: {local_path}")
         return local_path
 
-    local_path.parent.mkdir(exist_ok=True, parents=True)
+    local_path.mkdir(exist_ok=True, parents=True)
 
     if use_gcs_bucket:
-        # Download from GCS
-        try:
-            client = storage.Client(project=gcs_project)
-            bucket = client.bucket(gcs_bucket_name)
-            blob = bucket.blob(remote_file)
-
-            print(f"Downloading {remote_file} to {local_path}")
-            blob.download_to_filename(str(local_path))
-            print(f"Downloaded successfully: {local_path}")
-            return local_path
-        except Exception as e:
-            print(f"Error downloading {remote_file}: {e}")
-            return None
-    return remote_file  # For local files, just return the path
+        gcloud_download_dir(remote_file, local_path.parent)
+        return local_path
+    else:
+        return remote_file  # For local files, just return the path
 
 
 def load_file(row, col):
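
For context on the switch away from blob.download_to_filename: a Zarr store is not a single object but a whole prefix of chunk files, which a single-blob download cannot fetch. If shelling out to gcloud is undesirable, a rough client-library equivalent (a sketch only, with a hypothetical helper name, not part of this commit) could mirror the prefix object by object:

    from pathlib import Path
    from google.cloud import storage

    def download_prefix(bucket_name: str, prefix: str, dest: Path, project: str) -> None:
        # Hypothetical helper: mirror every object under `prefix` into `dest`.
        client = storage.Client(project=project)
        for blob in client.list_blobs(bucket_name, prefix=prefix):
            if blob.name.endswith("/"):  # skip zero-byte "folder" placeholders
                continue
            target = dest / blob.name[len(prefix):].lstrip("/")
            target.parent.mkdir(parents=True, exist_ok=True)
            blob.download_to_filename(str(target))
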
@@ -372,7 +400,7 @@ def load_file(row, col):
     Returns:
         zarr store object, or None if not found/error
     """
-    local_path = download_file(row, col)
+    local_path = download_zarr_file(row, col)
     if local_path is None:
         return None
 
@@ -385,7 +413,7 @@ def load_file(row, col):
         return None
 
 
-def delete_cached_file(row, col):
+def delete_cached_zarr_file(row, col):
     """
     Delete the locally cached file for a specific row and column to save disk space.
 
@@ -396,15 +424,23 @@ def delete_cached_file(row, col):
     Returns:
         bool: True if file was deleted or didn't exist, False if error
     """
-    if not use_gcs_bucket:
+    if not use_gcs_bucket or not delete_input:
         return True
     local_path = get_local_cache_path(row, col)
     if local_path is None:
         return True
 
     try:
+        # Check that local_path is not something dangerous like the root or home directory
+        if local_path in [Path("/"), Path.home()]:
+            print(f"Refusing to delete dangerous path: {local_path}")
+            return False
+        # The cached path should also end with .zarr
+        if local_path.suffix != ".zarr":
+            print(f"Refusing to delete non-zarr path: {local_path}")
+            return False
         if local_path.exists():
-            local_path.unlink()
+            shutil.rmtree(local_path)  # a .zarr store is a directory tree, not a single file
             print(f"Deleted cached file: {local_path}")
         return True
     except Exception as e:
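
The guard means cleanup only ever touches paths that look like a cached Zarr store. Isolated for illustration (paths are made up, and the helper is hypothetical, mirroring the checks above), the logic behaves like this:

    from pathlib import Path

    def looks_safe_to_delete(local_path: Path) -> bool:
        # Hypothetical helper for illustration only.
        if local_path in [Path("/"), Path.home()]:
            return False
        return local_path.suffix == ".zarr"

    print(looks_safe_to_delete(Path("/temp/in/r01_c02.zarr")))  # True
    print(looks_safe_to_delete(Path("/")))                      # False
    print(looks_safe_to_delete(Path("/temp/in")))               # False
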
@@ -743,7 +779,7 @@ def process(args):
 
     # Clean up cached file to save disk space
     # (you can comment this out if you want to keep files cached)
-    delete_cached_file(x_i, y_i)
+    delete_cached_zarr_file(x_i, y_i)
 
     # Return the bounds of the processed chunk
     return (start, end)
@@ -892,7 +928,7 @@ def check_and_upload_completed_chunks():
         ):
             uploaded_count += 1
             # Remove local chunk to save space
-            if use_gcs_output:
+            if use_gcs_output and delete_output:
                 chunk_file.unlink()
             uploaded_files.append((chunk_file, gcs_chunk_path))
 
@@ -926,7 +962,7 @@ def upload_any_remaining_chunks():
         if upload_file_to_gcs(chunk_file, gcs_chunk_path, overwrite=overwrite_gcs):
             uploaded_count += 1
             # Remove local chunk to save space
-            if use_gcs_output:
+            if use_gcs_output and delete_output:
                 chunk_file.unlink()
             uploaded_files.append((chunk_file, gcs_chunk_path))
 