Skip to content

Commit 97d9a3c

Browse files
committed
feat: use gcloud cp -r instead of own recursion
1 parent 0324302 commit 97d9a3c

File tree

1 file changed

+67
-31
lines changed

1 file changed

+67
-31
lines changed

examples/create_downampled.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import itertools
44
import math
5+
import subprocess
56
from pathlib import Path
67
import numpy as np
78
from cloudvolume import CloudVolume
@@ -72,6 +73,8 @@ def load_env_config():
7273
# Local paths (used when USE_GCS_BUCKET is False)
7374
"INPUT_PATH": Path(os.getenv("INPUT_PATH", "/temp/in")),
7475
"OUTPUT_PATH": Path(os.getenv("OUTPUT_PATH", "/temp/out")),
76+
"DELETE_INPUT": parse_bool(os.getenv("DELETE_INPUT", "false")),
77+
"DELETE_OUTPUT": parse_bool(os.getenv("DELETE_OUTPUT", "false")),
7578
# Processing settings
7679
"OVERWRITE": parse_bool(os.getenv("OVERWRITE", "false")),
7780
"OVERWRITE_GCS": parse_bool(os.getenv("OVERWRITE_GCS", "false")),
@@ -107,6 +110,8 @@ def load_env_config():
107110
gcs_output_path = config["GCS_OUTPUT_PREFIX"]
108111
input_path = config["INPUT_PATH"]
109112
output_path = config["OUTPUT_PATH"]
113+
delete_input = config["DELETE_INPUT"]
114+
delete_output = config["DELETE_OUTPUT"]
110115
overwrite_output = config["OVERWRITE"]
111116
overwrite_gcs = config["OVERWRITE_GCS"]
112117
num_mips = config["NUM_MIPS"]
@@ -148,13 +153,8 @@ def list_gcs_files(bucket_name, prefix="", file_extension=""):
148153
"""
149154
if gcs_local_list and gcs_local_list.exists():
150155
print(f"Loading file list from local file: {gcs_local_list}")
151-
full_prefix_str = "gs://" + bucket_name + "/"
152156
with open(gcs_local_list, "r") as f:
153-
files = [
154-
line.strip().rstrip("/")[len(full_prefix_str) :]
155-
for line in f
156-
if line.strip()
157-
]
157+
files = [line.strip() for line in f if line.strip()]
158158
print(f"Found {len(files)} files in local list")
159159
return files
160160
client = storage.Client(project=gcs_project)
@@ -281,9 +281,9 @@ def get_local_cache_path(row, col):
281281

282282
# Create local filename based on remote file
283283
if use_gcs_bucket:
284-
cache_dir = input_path / "cache"
284+
cache_dir = input_path
285285
cache_dir.mkdir(exist_ok=True, parents=True)
286-
local_name = str(remote_file).split("/")[-1]
286+
local_name = str(remote_file).rstrip("/").split("/")[-1]
287287
output = Path(cache_dir / local_name)
288288
else:
289289
output = Path(remote_file)
@@ -316,7 +316,45 @@ def get_remote_file_path(row, col):
316316
return None
317317

318318

319-
def download_file(row, col):
319+
def gcloud_download_dir(gs_prefix: str, local_dir: Path) -> None:
    """
    Recursively download a GCS prefix to a local directory using gcloud.

    Example gs_prefix: 'gs://my-bucket/some/prefix/'

    Args:
        gs_prefix: Fully-qualified GCS URI of the prefix to copy.
        local_dir: Destination directory; created (with parents) if missing.

    Raises:
        subprocess.CalledProcessError: if `gcloud storage cp` exits nonzero.
        FileNotFoundError: if the `gcloud` CLI is not installed / not on PATH.
    """
    local_dir.mkdir(parents=True, exist_ok=True)

    # Use a list (no shell=True) to avoid injection issues
    cmd = [
        "gcloud",
        "storage",
        "cp",
        "--recursive",
        "--project",
        gcs_project,  # module-level config loaded from the environment
        gs_prefix,
        str(local_dir),
    ]

    print("Running command:", " ".join(cmd))
    try:
        res = subprocess.run(
            cmd,
            check=True,  # raises CalledProcessError on nonzero exit
            capture_output=True,  # capture logs; integrate with your logger
            text=True,
        )
        # Only echo output that actually exists (avoid spurious blank lines).
        if res.stdout:
            print(res.stdout)
        if res.stderr:
            print(res.stderr)
    except FileNotFoundError:
        # `gcloud` itself is missing — CalledProcessError would never fire,
        # so surface an actionable message before propagating.
        print(
            "gcloud CLI not found. Install the Google Cloud SDK and ensure "
            "'gcloud' is on PATH."
        )
        raise
    except subprocess.CalledProcessError as e:
        # Surface meaningful diagnostics
        print("gcloud cp failed:", e.returncode)
        print(e.stdout)
        print(e.stderr)
        raise
355+
356+
357+
def download_zarr_file(row, col):
320358
"""
321359
Download the file for a specific row and column to local cache.
322360
@@ -341,23 +379,13 @@ def download_file(row, col):
341379
print(f"File already cached: {local_path}")
342380
return local_path
343381

344-
local_path.parent.mkdir(exist_ok=True, parents=True)
382+
local_path.mkdir(exist_ok=True, parents=True)
345383

346384
if use_gcs_bucket:
347-
# Download from GCS
348-
try:
349-
client = storage.Client(project=gcs_project)
350-
bucket = client.bucket(gcs_bucket_name)
351-
blob = bucket.blob(remote_file)
352-
353-
print(f"Downloading {remote_file} to {local_path}")
354-
blob.download_to_filename(str(local_path))
355-
print(f"Downloaded successfully: {local_path}")
356-
return local_path
357-
except Exception as e:
358-
print(f"Error downloading {remote_file}: {e}")
359-
return None
360-
return remote_file # For local files, just return the path
385+
gcloud_download_dir(remote_file, local_path.parent)
386+
return local_path
387+
else:
388+
return remote_file # For local files, just return the path
361389

362390

363391
def load_file(row, col):
@@ -372,7 +400,7 @@ def load_file(row, col):
372400
Returns:
373401
zarr store object, or None if not found/error
374402
"""
375-
local_path = download_file(row, col)
403+
local_path = download_zarr_file(row, col)
376404
if local_path is None:
377405
return None
378406

@@ -385,7 +413,7 @@ def load_file(row, col):
385413
return None
386414

387415

388-
def delete_cached_file(row, col):
416+
def delete_cached_zarr_file(row, col):
389417
"""
390418
Delete the locally cached file for a specific row and column to save disk space.
391419
@@ -396,15 +424,23 @@ def delete_cached_file(row, col):
396424
Returns:
397425
bool: True if file was deleted or didn't exist, False if error
398426
"""
399-
if not use_gcs_bucket:
427+
if not use_gcs_bucket or not delete_input:
400428
return True
401429
local_path = get_local_cache_path(row, col)
402430
if local_path is None:
403431
return True
404432

405433
try:
434+
# Check that local_path is not something dangerous like root or home directory
435+
if local_path in [Path("/"), Path.home()]:
436+
print(f"Refusing to delete dangerous path: {local_path}")
437+
return False
438+
# It should also end with .zarr
439+
if not local_path.suffix == ".zarr":
440+
print(f"Refusing to delete non-zarr path: {local_path}")
441+
return False
406442
if local_path.exists():
407-
local_path.unlink()
443+
local_path.rmdir()
408444
print(f"Deleted cached file: {local_path}")
409445
return True
410446
except Exception as e:
@@ -743,7 +779,7 @@ def process(args):
743779

744780
# Clean up cached file to save disk space
745781
# (you can comment this out if you want to keep files cached)
746-
delete_cached_file(x_i, y_i)
782+
delete_cached_zarr_file(x_i, y_i)
747783

748784
# Return the bounds of the processed chunk
749785
return (start, end)
@@ -892,7 +928,7 @@ def check_and_upload_completed_chunks():
892928
):
893929
uploaded_count += 1
894930
# Remove local chunk to save space
895-
if use_gcs_output:
931+
if use_gcs_output and delete_output:
896932
chunk_file.unlink()
897933
uploaded_files.append((chunk_file, gcs_chunk_path))
898934

@@ -926,7 +962,7 @@ def upload_any_remaining_chunks():
926962
if upload_file_to_gcs(chunk_file, gcs_chunk_path, overwrite=overwrite_gcs):
927963
uploaded_count += 1
928964
# Remove local chunk to save space
929-
if use_gcs_output:
965+
if use_gcs_output and delete_output:
930966
chunk_file.unlink()
931967
uploaded_files.append((chunk_file, gcs_chunk_path))
932968

0 commit comments

Comments
 (0)