Skip to content

Commit b39f70c

Browse files
committed
feat: building out the final upload step
1 parent a5343fe commit b39f70c

File tree

1 file changed

+181
-88
lines changed

1 file changed

+181
-88
lines changed

examples/create_downampled.py

Lines changed: 181 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ def upload_file_to_gcs(local_file_path, gcs_file_path):
413413
bool: True if successful, False otherwise
414414
"""
415415
if not use_gcs_output:
416-
print("GCS output not configured, skipping upload.")
416+
print(
417+
f"GCS output not configured, skipping upload of {local_file_path} to {gcs_file_path}."
418+
)
417419
return True
418420

419421
try:
@@ -430,53 +432,6 @@ def upload_file_to_gcs(local_file_path, gcs_file_path):
430432
return False
431433

432434

433-
def check_and_upload_completed_chunks():
434-
"""
435-
Check for completed chunk files and upload them to GCS if configured.
436-
This helps manage local disk space by uploading and optionally removing completed chunks.
437-
438-
Returns:
439-
int: Number of chunks uploaded
440-
"""
441-
if not use_gcs_output:
442-
return 0
443-
444-
uploaded_count = 0
445-
446-
try:
447-
# Look for chunk files in the output directory
448-
for mip_level in range(num_mips):
449-
mip_dir = output_path / str(mip_level)
450-
if mip_dir.exists():
451-
# Find all chunk files in this mip level
452-
for chunk_file in mip_dir.glob("**/*"):
453-
if chunk_file.is_file():
454-
# Construct the GCS path for this chunk
455-
relative_path = chunk_file.relative_to(output_path)
456-
gcs_chunk_path = (
457-
gcs_output_path.rstrip("/")
458-
+ "/"
459-
+ str(relative_path).replace("\\", "/")
460-
)
461-
462-
# Check if chunk should be uploaded (you can add more logic here)
463-
if upload_file_to_gcs(chunk_file, gcs_chunk_path):
464-
uploaded_count += 1
465-
print(f"Uploaded chunk: {gcs_chunk_path}")
466-
467-
# Optionally remove local chunk to save space
468-
# Uncomment the next line if you want to delete local chunks after upload
469-
# chunk_file.unlink()
470-
471-
if uploaded_count > 0:
472-
print(f"Uploaded {uploaded_count} chunks to GCS output bucket")
473-
474-
except Exception as e:
475-
print(f"Error checking/uploading chunks: {e}")
476-
477-
return uploaded_count
478-
479-
480435
def load_zarr_store(file_path):
481436
"""
482437
Load zarr store from a file path.
@@ -515,18 +470,20 @@ def load_data_from_zarr_store(zarr_store):
515470
exit(1)
516471

517472
# %% Inspect the data
518-
shape = zarr_store.shape
473+
volume_shape = zarr_store.shape
519474
# Input is in Z, T, C, Y, X order
520475
# Want XYTCZ order
521-
single_file_xyz_shape = [shape[4], shape[3], shape[1]]
476+
single_file_xyz_shape = [volume_shape[4], volume_shape[3], volume_shape[1]]
522477
# Here, T and Z are kind of transferable.
523478
# The reason is because the z dimension in neuroglancer is the time dimension
524479
# from the raw original data.
525480
# So both terms might be used to represent the same thing.
526481
# It's a bit unusual that t is being used as the z dimension,
527482
# but otherwise you can't do volume rendering in neuroglancer.
528483

529-
num_channels = min(shape[2], channel_limit) # Limit to NUM_CHANNELS for memory usage
484+
num_channels = min(
485+
volume_shape[2], channel_limit
486+
) # Limit to NUM_CHANNELS for memory usage
530487
data_type = "uint16"
531488

532489
# %% Compute optimal chunk size based on data shape and MIP levels
@@ -653,9 +610,9 @@ def compute_optimal_chunk_size(single_file_shape, num_mips, max_chunk_size=None)
653610
progress_dir.mkdir(exist_ok=True)
654611

655612
# %% Functions for moving data
656-
shape = volume_size
657-
chunk_shape = np.array(single_file_xyz_shape) # this is for reading data
658-
num_chunks_per_dim = np.ceil(shape / chunk_shape).astype(int)
613+
volume_shape = volume_size
614+
single_file_shape = np.array(single_file_xyz_shape) # this is for reading data
615+
num_chunks_per_dim = np.ceil(volume_shape / single_file_shape).astype(int)
659616

660617

661618
def process(args):
@@ -669,11 +626,15 @@ def process(args):
669626
print(f"Warning: Could not load file for row {x_i}, col {y_i}. Skipping...")
670627
return
671628

672-
start = [x_i * chunk_shape[0], y_i * chunk_shape[1], z_i * chunk_shape[2]]
629+
start = [
630+
x_i * single_file_shape[0],
631+
y_i * single_file_shape[1],
632+
z_i * single_file_shape[2],
633+
]
673634
end = [
674-
min((x_i + 1) * chunk_shape[0], shape[0]),
675-
min((y_i + 1) * chunk_shape[1], shape[1]),
676-
min((z_i + 1) * chunk_shape[2], shape[2]),
635+
min((x_i + 1) * single_file_shape[0], volume_shape[0]),
636+
min((y_i + 1) * single_file_shape[1], volume_shape[1]),
637+
min((z_i + 1) * single_file_shape[2], volume_shape[2]),
677638
]
678639
f_name = progress_dir / f"{start[0]}-{end[0]}_{start[1]}-{end[1]}.done"
679640
print(f"Processing chunk: {start} to {end}, file: {f_name}")
@@ -688,32 +649,66 @@ def process(args):
688649
downsampled = rawdata
689650
ds_start = start
690651
ds_end = end
652+
if not allow_non_aligned_write:
653+
# Align to chunk boundaries
654+
ds_start = [
655+
int(round(math.floor(s / c) * c))
656+
for s, c in zip(ds_start, chunk_size)
657+
]
658+
ds_end = [
659+
int(round(math.ceil(e / c) * c)) for e, c in zip(ds_end, chunk_size)
660+
]
661+
ds_end = [min(e, s) for e, s in zip(ds_end, volume_shape)]
691662
else:
692663
factor = 2**mip_level
693664
factor_tuple = (factor, factor, factor, 1)
694665
ds_start = [int(np.round(s / (2**mip_level))) for s in start]
695-
# Actually make ds_start to be a multiple of chunk size
696-
old_start = ds_start.copy()
697-
ds_start = [int(np.round(s / c) * c) for s, c in zip(ds_start, chunk_shape)]
698-
if ds_start != old_start:
699-
print(f"Adjusted ds_start from {old_start} to {ds_start}")
700-
bounds_from_end = [int(math.ceil(e / (2**mip_level))) for e in end]
666+
if not allow_non_aligned_write:
667+
# Align to chunk boundaries
668+
ds_start = [
669+
int(round(math.floor(s / c) * c))
670+
for s, c in zip(ds_start, chunk_size)
671+
]
701672
downsample_shape = [
702673
int(math.ceil(s / f)) for s, f in zip(rawdata.shape, factor_tuple)
703674
]
704675
ds_end_est = [s + d for s, d in zip(ds_start, downsample_shape)]
705-
# Actually make ds_end_est to be a multiple of chunk size
706-
old_end = ds_end_est.copy()
707-
ds_end_est = [
708-
int(np.round(e / c) * c) for e, c in zip(ds_end_est, chunk_shape)
709-
]
710-
if ds_end_est != old_end:
711-
print(f"Adjusted ds_end_est from {old_end} to {ds_end_est}")
712-
ds_end = [max(e1, e2) for e1, e2 in zip(ds_end_est, bounds_from_end)]
676+
if allow_non_aligned_write:
677+
bounds_from_end = [int(math.ceil(e / (2**mip_level))) for e in end]
678+
ds_end = [max(e1, e2) for e1, e2 in zip(ds_end_est, bounds_from_end)]
679+
else:
680+
# Align to chunk boundaries
681+
ds_end = [
682+
int(round(math.ceil(e / c) * c))
683+
for e, c in zip(ds_end_est, chunk_size)
684+
]
685+
ds_end = [min(e, s) for e, s in zip(ds_end, volume_shape)]
713686
print("DS fill", ds_start, ds_end)
714687
downsampled = downsample_with_averaging(rawdata, factor_tuple)
715688
print("Downsampled shape:", downsampled.shape)
716689

690+
if not allow_non_aligned_write:
691+
# TODO may need to ignore padding at the data edges
692+
# We may need to pad the downsampled data to fit the chunk boundaries
693+
pad_width = [
694+
(0, max(0, de - ds - s))
695+
for ds, de, s in zip(ds_start, ds_end, downsampled.shape)
696+
]
697+
pad_width.append((0, 0)) # No padding for channel dimension
698+
# we should never pad more than the mip level times a factor inverse
699+
max_allowed_pad = 2 ** (num_mips - mip_level)
700+
max_actual_pad = max(pw[1] for pw in pad_width)
701+
if max_actual_pad > max_allowed_pad:
702+
raise ValueError(
703+
f"Padding too large at mip {mip_level}: {pad_width}, max allowed {max_allowed_pad}"
704+
)
705+
if any(pw[1] > 0 for pw in pad_width):
706+
print("Padding downsampled data with:", pad_width)
707+
downsampled = np.pad(
708+
downsampled, pad_width, mode="constant", constant_values=0
709+
)
710+
print("Padded downsampled shape:", downsampled.shape)
711+
717712
vols[mip_level][
718713
ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
719714
] = downsampled
@@ -725,6 +720,9 @@ def process(args):
725720
# (you can comment this out if you want to keep files cached)
726721
delete_cached_file(x_i, y_i)
727722

723+
# Return the bounds of the processed chunk
724+
return (start, end)
725+
728726

729727
# %% Try with a single chunk to see if it works
730728
x_i, y_i, z_i = 0, 0, 0
@@ -754,31 +752,126 @@ def process(args):
754752
# with ProcessPoolExecutor(max_workers=max_workers) as executor:
755753
# executor.map(process, coords)
756754

755+
# %% Function to check the output directory for completed chunks and upload them to GCS
756+
757+
processed_chunks_bounds = [(np.inf, np.inf, np.inf), (-np.inf, -np.inf, -np.inf)]
758+
759+
760+
# TODO this probably wants to bulk together uploads to reduce overhead
761+
def check_and_upload_completed_chunks():
762+
"""
763+
Check for completed chunk files and upload them to GCS if configured.
764+
This helps manage local disk space by uploading and optionally removing completed chunks.
765+
766+
Returns:
767+
int: Number of chunks uploaded
768+
"""
769+
if not use_gcs_output:
770+
return 0
771+
772+
uploaded_count = 0
773+
774+
for mip_level in range(num_mips):
775+
factor = 2**mip_level
776+
dir_name = f"{factor}_{factor}_{factor}"
777+
output_path_for_mip = output_path / dir_name
778+
# For each file in the output dir check if it is fully covered by the already processed bounds
779+
# First, we loop over all the files in the output directory
780+
for chunk_file in output_path_for_mip.glob("**/*"):
781+
# 1. Pull out the bounds of the chunk from the filename
782+
# filename format is x0-x1_y0-y1_z0-z1
783+
match = re.search(r"(\d+)-(\d+)_(\d+)-(\d+)_(\d+)-(\d+)", str(chunk_file))
784+
if not match:
785+
continue
786+
x0, x1, y0, y1, z0, z1 = map(int, match.groups())
787+
chunk_bounds = [(x0, y0, z0), (x1, y1, z1)]
788+
# Multiply by the factor to get back to original resolution
789+
chunk_bounds = [
790+
[c * factor for c in chunk_bounds[0]],
791+
[c * factor for c in chunk_bounds[1]],
792+
]
793+
# 2. Check if the chunk is fully covered by the processed bounds
794+
if all(
795+
pb0 <= cb0 and pb1 >= cb1
796+
for pb0, pb1, cb0, cb1 in zip(
797+
processed_chunks_bounds[0],
798+
processed_chunks_bounds[1],
799+
chunk_bounds[0],
800+
chunk_bounds[1],
801+
)
802+
):
803+
# 3. If it is, upload it to GCS
804+
relative_path = chunk_file.relative_to(output_path)
805+
gcs_chunk_path = (
806+
gcs_output_path.rstrip("/")
807+
+ "/"
808+
+ str(relative_path).replace("\\", "/")
809+
)
810+
if upload_file_to_gcs(chunk_file, gcs_chunk_path):
811+
uploaded_count += 1
812+
print(f"Uploaded chunk: {gcs_chunk_path}")
813+
# Remove local chunk to save space
814+
chunk_file.unlink()
815+
816+
return uploaded_count
817+
818+
819+
def upload_any_remaining_chunks():
820+
"""
821+
Upload any remaining chunks in the output directory to GCS.
822+
This is called at the end of processing to ensure all data is uploaded.
823+
824+
Returns:
825+
int: Number of chunks uploaded
826+
"""
827+
uploaded_count = 0
828+
829+
for mip_level in range(num_mips):
830+
factor = 2**mip_level
831+
dir_name = f"{factor}_{factor}_{factor}"
832+
output_path_for_mip = output_path / dir_name
833+
# For each file in the output dir
834+
for chunk_file in output_path_for_mip.glob("**/*"):
835+
relative_path = chunk_file.relative_to(output_path)
836+
gcs_chunk_path = (
837+
gcs_output_path.rstrip("/")
838+
+ "/"
839+
+ str(relative_path).replace("\\", "/")
840+
)
841+
if upload_file_to_gcs(chunk_file, gcs_chunk_path):
842+
uploaded_count += 1
843+
print(f"Uploaded chunk: {gcs_chunk_path}")
844+
# Remove local chunk to save space
845+
chunk_file.unlink()
846+
847+
return uploaded_count
848+
849+
757850
# %% Move the data across with a single worker
758-
chunk_count = 0
851+
total_uploaded_files = 0
759852
for coord in reversed_coords:
760-
process(coord)
761-
chunk_count += 1
853+
bounds = process(coord)
854+
if bounds is not None:
855+
start, end = bounds
856+
processed_chunks_bounds[0] = [
857+
min(ps, s) for ps, s in zip(processed_chunks_bounds[0], start)
858+
]
859+
processed_chunks_bounds[1] = [
860+
max(pe, e) for pe, e in zip(processed_chunks_bounds[1], end)
861+
]
862+
print(f"Updated processed bounds: {processed_chunks_bounds}")
762863

763864
# Periodically check and upload completed chunks to save disk space
764865
# This is done every 10 chunks to balance upload frequency vs overhead
765-
if use_gcs_output and chunk_count % 10 == 0:
766-
print(f"Processed {chunk_count} chunks, checking for uploads...")
767-
check_and_upload_completed_chunks()
768-
769-
# The original TODO was about uploading chunks as they're completed
770-
# The above implementation provides a basic version of this functionality
771-
# For more sophisticated chunk management, you could:
772-
# 1. Track which specific chunks are complete across all MIP levels
773-
# 2. Only upload chunks that are fully written at all relevant MIP levels
774-
# 3. Implement more granular deletion of local chunks after successful upload
775-
# 4. Add retry logic for failed uploads
866+
total_uploaded_files += check_and_upload_completed_chunks()
867+
print(f"Total uploaded chunks so far: {total_uploaded_files}")
868+
776869

777870
# Final upload of any remaining chunks
778871
if use_gcs_output:
779872
print("Processing complete, uploading any remaining chunks...")
780-
final_upload_count = check_and_upload_completed_chunks()
781-
print(f"Final upload completed: {final_upload_count} chunks uploaded")
873+
total_uploaded_files += upload_any_remaining_chunks()
874+
print(f"Final upload completed: {total_uploaded_files} chunks uploaded")
782875

783876
# %% Serve the dataset to be used in neuroglancer
784877
vols[0].viewer(port=1337)

0 commit comments

Comments
 (0)