import math
import sys

import numpy as np
from cloudvolume.lib import touch
from neuroglancer.downsample import downsample_with_averaging

# Note: this is the project's local io module (it shadows the stdlib io),
# providing the Zarr download/caching helpers used below.
from io import load_data_from_zarr_store, delete_cached_zarr_file, load_file


def compute_optimal_chunk_size(single_file_shape, num_mips, max_chunk_size=None):
    """
    Compute an optimal chunk size based on the single-file shape and the
    number of MIP levels.

    Args:
        single_file_shape: [x, y, z] shape of a single file
        num_mips: Number of MIP levels
        max_chunk_size: Optional per-dimension maximum chunk size (default: 512)

    Returns:
        List[int]: [chunk_x, chunk_y, chunk_z] optimal chunk sizes
    """
    if max_chunk_size is None:
        max_chunk_size = 512

    single_file_shape = np.array(single_file_shape)
    # Size chunks so a single file still spans a whole chunk at the
    # most-downsampled MIP level, then cap each dimension at max_chunk_size.
    optimal_chunks = np.ceil(single_file_shape / (2 ** (num_mips - 1)))
    optimal_chunks = np.minimum(optimal_chunks, max_chunk_size)

    return [int(c) for c in optimal_chunks]
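

# Worked example (hypothetical tile shape, not taken from this repo): a
# 2048 x 2048 x 64 file with num_mips=4 gives ceil(dim / 2**3) per axis,
# i.e. [256, 256, 8], all under the default 512 cap:
#
#     compute_optimal_chunk_size([2048, 2048, 64], 4)  # -> [256, 256, 8]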


def compute_volume_and_chunk_size(
    single_file_xyz_shape,
    computed_num_rows,
    computed_num_cols,
    num_mips,
    manual_chunk_size,
):
    # Pick the chunk size: a manual override from configuration if given,
    # otherwise the optimum computed from the single-file shape and MIP levels.
    if manual_chunk_size is not None:
        if len(manual_chunk_size) != 3:
            print(
                "Error: MANUAL_CHUNK_SIZE must be a list of three integers (e.g., 64,64,16)"
            )
            sys.exit(1)
        print(f"Using manual chunk size from configuration: {manual_chunk_size}")
        chunk_size = manual_chunk_size
    else:
        print(
            f"Computing optimal chunk size for shape {single_file_xyz_shape} with {num_mips} MIP levels..."
        )
        chunk_size = compute_optimal_chunk_size(single_file_xyz_shape, num_mips)
        print(f"Computed optimal chunk size: {chunk_size}")

    # Files tile a rows x cols grid in XY; Z comes from a single file.
    volume_size = [
        single_file_xyz_shape[0] * computed_num_rows,
        single_file_xyz_shape[1] * computed_num_cols,
        single_file_xyz_shape[2],
    ]  # XYZ (T)
    print("Volume size:", volume_size)

    # Validate that the chunk size tiles the data cleanly
    for dim_name, dim_size, chunk_dim in zip(["X", "Y", "Z"], volume_size, chunk_size):
        num_chunks_this_dim = math.ceil(dim_size / chunk_dim)
        print(
            f"  {dim_name} dimension: {dim_size} → {num_chunks_this_dim} chunks of size {chunk_dim}"
        )

        # Report chunk counts and utilization at every MIP level
        for mip in range(num_mips):
            effective_size = dim_size // (2**mip)
            if effective_size > 0:
                mip_chunks = math.ceil(effective_size / chunk_dim)
                utilization = (
                    (effective_size / (mip_chunks * chunk_dim)) * 100
                    if mip_chunks > 0
                    else 0
                )
                print(
                    f"    MIP {mip}: {effective_size} → {mip_chunks} chunks ({utilization:.1f}% utilization)"
                )

    return volume_size, chunk_size
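

# A usage sketch (the tile shape and grid dimensions here are illustrative
# assumptions, not values from this repo):
#
#     volume_size, chunk_size = compute_volume_and_chunk_size(
#         single_file_xyz_shape=[2048, 2048, 64],
#         computed_num_rows=4,
#         computed_num_cols=3,
#         num_mips=4,
#         manual_chunk_size=None,  # fall back to the computed optimum
#     )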


def process(
    args,
    single_file_shape,
    volume_shape,
    vols,
    chunk_size,
    num_mips,
    mip_cutoff,
    allow_non_aligned_write,
    overwrite_output,
    progress_dir,
):
    # args is the (x, y, z) grid index of the file to ingest
    x_i, y_i, z_i = args

    # Bounds of this file within the full volume, clamped at the volume edges
    start = [
        x_i * single_file_shape[0],
        y_i * single_file_shape[1],
        z_i * single_file_shape[2],
    ]
    end = [
        min((x_i + 1) * single_file_shape[0], volume_shape[0]),
        min((y_i + 1) * single_file_shape[1], volume_shape[1]),
        min((z_i + 1) * single_file_shape[2], volume_shape[2]),
    ]
    f_name = progress_dir / f"{start[0]}-{end[0]}_{start[1]}-{end[1]}.done"
    print(f"Processing chunk: {start} to {end}, file: {f_name}")
    # Skip chunks already marked complete unless overwriting is requested
    if f_name.exists() and not overwrite_output:
        return (start, end)

    # Use the load_file helper, which handles download/caching
    print(f"Loading file for coordinates ({x_i}, {y_i}, {z_i})")
    loaded_zarr_store = load_file(x_i, y_i)

    if loaded_zarr_store is None:
        print(f"Warning: Could not load file for row {x_i}, col {y_i}. Skipping...")
        return None

    rawdata = load_data_from_zarr_store(loaded_zarr_store)
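
    # Note: rawdata is assumed to be 4D (X, Y, Z, channel), which is why the
    # downsampling factor tuple below carries a trailing 1: the channel axis
    # is never averaged.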

    # Process all MIP levels from coarsest to finest, skipping levels below
    # the cutoff
    for mip_level in reversed(range(mip_cutoff, num_mips)):
        if mip_level == 0:
            downsampled = rawdata
            ds_start = start
            ds_end = end
            if not allow_non_aligned_write:
                # Align the write window outward to chunk boundaries
                ds_start = [
                    int(round(math.floor(s / c) * c))
                    for s, c in zip(ds_start, chunk_size)
                ]
                ds_end = [
                    int(round(math.ceil(e / c) * c)) for e, c in zip(ds_end, chunk_size)
                ]
                ds_end = [min(e, vs) for e, vs in zip(ds_end, volume_shape)]
        else:
            factor = 2**mip_level
            # Never downsample the channel dimension
            factor_tuple = (factor, factor, factor, 1)
            ds_start = [int(np.round(s / factor)) for s in start]
            if not allow_non_aligned_write:
                # Align the write window outward to chunk boundaries
                ds_start = [
                    int(round(math.floor(s / c) * c))
                    for s, c in zip(ds_start, chunk_size)
                ]
            downsample_shape = [
                int(math.ceil(s / f)) for s, f in zip(rawdata.shape, factor_tuple)
            ]
            ds_end_est = [s + d for s, d in zip(ds_start, downsample_shape)]
            if allow_non_aligned_write:
                bounds_from_end = [int(math.ceil(e / factor)) for e in end]
                ds_end = [max(e1, e2) for e1, e2 in zip(ds_end_est, bounds_from_end)]
            else:
                # Align to chunk boundaries, then clamp to the volume
                ds_end = [
                    int(round(math.ceil(e / c) * c))
                    for e, c in zip(ds_end_est, chunk_size)
                ]
                ds_end = [min(e, vs) for e, vs in zip(ds_end, volume_shape)]
            print("DS fill", ds_start, ds_end)
            downsampled = downsample_with_averaging(rawdata, factor_tuple)
            print("Downsampled shape:", downsampled.shape)

        if not allow_non_aligned_write:
            # TODO may need to ignore padding at the data edges
            # Pad the downsampled data out to the aligned write window
            pad_width = [
                (0, max(0, de - ds - s))
                for ds, de, s in zip(ds_start, ds_end, downsampled.shape)
            ]
            pad_width.append((0, 0))  # No padding for the channel dimension
            # Padding should never exceed the remaining downsample factor
            # above this MIP level
            max_allowed_pad = 2 ** (num_mips - mip_level)
            max_actual_pad = max(pw[1] for pw in pad_width)
            if max_actual_pad > max_allowed_pad:
                raise ValueError(
                    f"Padding too large at mip {mip_level}: {pad_width}, max allowed {max_allowed_pad}"
                )
            if any(pw[1] > 0 for pw in pad_width):
                print("Padding downsampled data with:", pad_width)
                downsampled = np.pad(
                    downsampled, pad_width, mode="constant", constant_values=0
                )
                print("Padded downsampled shape:", downsampled.shape)

        vols[mip_level][
            ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
        ] = downsampled

    # Mark the chunk as complete
    touch(f_name)

    # Clean up the cached file to save disk space
    delete_cached_zarr_file(x_i, y_i)

    # Return the bounds of the processed chunk
    return (start, end)
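

# A sketch of how process() might be driven (num_rows, num_cols, and the
# surrounding setup are assumptions, not code from this repo). Bounds from
# skipped tiles come back as None and are filtered out:
#
#     from itertools import product
#
#     processed_bounds = []
#     for coord in product(range(num_rows), range(num_cols), [0]):
#         bounds = process(
#             coord, single_file_shape, volume_shape, vols, chunk_size,
#             num_mips, 0, False, False, progress_dir,
#         )
#         if bounds is not None:
#             processed_bounds.append(bounds)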


def is_chunk_fully_covered(chunk_bounds, processed_chunks_bounds):
    """
    Check if a chunk is fully covered by processed bounds.

    Args:
        chunk_bounds: [start_coord, end_coord] where each coord is [x, y, z]
        processed_chunks_bounds: List of (start, end) tuples where start and
            end are [x, y, z]

    Returns:
        bool: True if all 8 corners of the chunk are covered by processed
            bounds. Corner coverage is a heuristic: it is necessary but not
            strictly sufficient when the processed bounds form a non-convex
            union.
    """
    if not processed_chunks_bounds:
        return False

    start_coord, end_coord = chunk_bounds
    x0, y0, z0 = start_coord
    x1, y1, z1 = end_coord

    # Generate all 8 corners of the chunk. Bounds are half-open [start, end),
    # so the last covered voxel along each axis is end - 1.
    corners = [
        [x0, y0, z0],  # min corner
        [x1 - 1, y0, z0],
        [x0, y1 - 1, z0],
        [x0, y0, z1 - 1],
        [x1 - 1, y1 - 1, z0],
        [x1 - 1, y0, z1 - 1],
        [x0, y1 - 1, z1 - 1],
        [x1 - 1, y1 - 1, z1 - 1],  # max corner
    ]

    # Check that each corner is covered by at least one processed bound
    for corner in corners:
        corner_covered = False
        for start, end in processed_chunks_bounds:
            # Corner is inside this half-open processed bound
            if (
                start[0] <= corner[0] < end[0]
                and start[1] <= corner[1] < end[1]
                and start[2] <= corner[2] < end[2]
            ):
                corner_covered = True
                break

        # If any corner is not covered, the chunk is not fully covered
        if not corner_covered:
            return False

    # All corners are covered
    return True
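

# Example (illustrative numbers): a 64-voxel cube whose corners all fall
# inside a single processed bound is reported as covered; extend the chunk
# past the bound and it is not.
#
#     bounds = [([0, 0, 0], [128, 128, 128])]
#     is_chunk_fully_covered([[0, 0, 0], [64, 64, 64]], bounds)   # True
#     is_chunk_fully_covered([[0, 0, 0], [256, 64, 64]], bounds)  # False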