
Commit f96587d

feat: scaffolding out a downsample module
1 parent 7383196 commit f96587d

File tree: 8 files changed, +1136 −0 lines changed

create_downsampled/__init__.py

Whitespace-only changes.

create_downsampled/chunking.py

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
import math

from cloudvolume.lib import touch
import numpy as np
from neuroglancer.downsample import downsample_with_averaging

# Local helpers from this package (a bare `from io import ...` would resolve to
# the stdlib io module, which has none of these names).
from .io import load_data_from_zarr_store, delete_cached_zarr_file, load_file


def compute_optimal_chunk_size(single_file_shape, num_mips, max_chunk_size=None):
    """
    Compute optimal chunk size based on single file shape and number of MIP levels.

    Args:
        single_file_shape: [x, y, z] shape of a single file
        num_mips: Number of MIP levels
        max_chunk_size: Optional maximum chunk size (default: 512)

    Returns:
        List[int]: [chunk_x, chunk_y, chunk_z] optimal chunk sizes
    """
    if max_chunk_size is None:
        max_chunk_size = 512

    # Size chunks so the coarsest MIP level still spans whole chunks per file,
    # clamped to the configured maximum
    single_file_shape = np.array(single_file_shape)
    optimal_chunks = np.ceil(single_file_shape / (2 ** (num_mips - 1)))
    optimal_chunks = np.minimum(optimal_chunks, max_chunk_size)

    return [int(c) for c in optimal_chunks]

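A quick worked example of the sizing rule (the shape and MIP count here are hypothetical): each axis of a single file is divided by 2 ** (num_mips - 1) and rounded up, so the coarsest level still fills whole chunks.

    chunks = compute_optimal_chunk_size([2048, 2048, 64], num_mips=5)
    # 2 ** (5 - 1) == 16, so: [2048 / 16, 2048 / 16, ceil(64 / 16)] == [128, 128, 4]
    print(chunks)
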
def compute_volume_and_chunk_size(
    single_file_xyz_shape,
    computed_num_rows,
    computed_num_cols,
    num_mips,
    manual_chunk_size,
):
    """Derive the full volume size and the chunk size used to write it."""
    if manual_chunk_size is not None:
        if len(manual_chunk_size) != 3:
            print(
                "Error: MANUAL_CHUNK_SIZE must be a list of three integers (e.g., 64,64,16)"
            )
            exit(1)
        print(f"Using manual chunk size from configuration: {manual_chunk_size}")
        chunk_size = manual_chunk_size
    else:
        # Compute optimal chunk size based on single file shape and MIP levels
        print(
            f"Computing optimal chunk size for shape {single_file_xyz_shape} with {num_mips} MIP levels..."
        )
        chunk_size = compute_optimal_chunk_size(single_file_xyz_shape, num_mips)
        print(f"Computed optimal chunk size: {chunk_size}")

    # Files tile the volume as rows in X and columns in Y; Z is one file deep
    volume_size = [
        single_file_xyz_shape[0] * computed_num_rows,
        single_file_xyz_shape[1] * computed_num_cols,
        single_file_xyz_shape[2],
    ]  # XYZ (T)
    print("Volume size:", volume_size)

    # Validate that the chunk size works with the data
    for dim_name, dim_size, chunk_dim in zip(["X", "Y", "Z"], volume_size, chunk_size):
        num_chunks_this_dim = math.ceil(dim_size / chunk_dim)
        print(
            f"  {dim_name} dimension: {dim_size} -> {num_chunks_this_dim} chunks of size {chunk_dim}"
        )

        # Show chunk counts and utilization at every MIP level
        for mip in range(num_mips):
            effective_size = dim_size // (2**mip)
            if effective_size > 0:
                mip_chunks = math.ceil(effective_size / chunk_dim)
                utilization = (
                    (effective_size / (mip_chunks * chunk_dim)) * 100
                    if mip_chunks > 0
                    else 0
                )
                print(
                    f"    MIP {mip}: {effective_size} -> {mip_chunks} chunks ({utilization:.1f}% utilization)"
                )

    return volume_size, chunk_size

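A sketch of the full-volume math under the same hypothetical numbers (a 4-row by 3-column grid of 2048x2048x64 files, 5 MIP levels, no manual override):

    volume_size, chunk_size = compute_volume_and_chunk_size(
        [2048, 2048, 64],
        computed_num_rows=4,
        computed_num_cols=3,
        num_mips=5,
        manual_chunk_size=None,
    )
    # volume_size == [8192, 6144, 64], chunk_size == [128, 128, 4]
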
def process(
    args,
    single_file_shape,
    volume_shape,
    vols,
    chunk_size,
    num_mips,
    mip_cutoff,
    allow_non_aligned_write,
    overwrite_output,
    progress_dir,
):
    """Load one source file, then write it and its downsamples into `vols`."""
    x_i, y_i, z_i = args

    # Bounds of this file within the full volume, clamped at the far edges
    start = [
        x_i * single_file_shape[0],
        y_i * single_file_shape[1],
        z_i * single_file_shape[2],
    ]
    end = [
        min((x_i + 1) * single_file_shape[0], volume_shape[0]),
        min((y_i + 1) * single_file_shape[1], volume_shape[1]),
        min((z_i + 1) * single_file_shape[2], volume_shape[2]),
    ]
    f_name = progress_dir / f"{start[0]}-{end[0]}_{start[1]}-{end[1]}.done"
    print(f"Processing chunk: {start} to {end}, file: {f_name}")
    if f_name.exists() and not overwrite_output:
        return (start, end)

    # Use the load_file helper, which handles download/caching
    print(f"Loading file for coordinates ({x_i}, {y_i}, {z_i})")
    loaded_zarr_store = load_file(x_i, y_i)

    if loaded_zarr_store is None:
        print(f"Warning: Could not load file for row {x_i}, col {y_i}. Skipping...")
        return

    rawdata = load_data_from_zarr_store(loaded_zarr_store)

    # Process all MIP levels, coarsest first
    for mip_level in reversed(range(mip_cutoff, num_mips)):
        if mip_level == 0:
            downsampled = rawdata
            ds_start = start
            ds_end = end
            if not allow_non_aligned_write:
                # Align to chunk boundaries
                ds_start = [
                    int(round(math.floor(s / c) * c))
                    for s, c in zip(ds_start, chunk_size)
                ]
                ds_end = [
                    int(round(math.ceil(e / c) * c)) for e, c in zip(ds_end, chunk_size)
                ]
                ds_end = [min(e, v) for e, v in zip(ds_end, volume_shape)]
        else:
            factor = 2**mip_level
            factor_tuple = (factor, factor, factor, 1)  # XYZ downsampled, channel kept
            ds_start = [int(np.round(s / factor)) for s in start]
            if not allow_non_aligned_write:
                # Align to chunk boundaries
                ds_start = [
                    int(round(math.floor(s / c) * c))
                    for s, c in zip(ds_start, chunk_size)
                ]
            downsample_shape = [
                int(math.ceil(s / f)) for s, f in zip(rawdata.shape, factor_tuple)
            ]
            ds_end_est = [s + d for s, d in zip(ds_start, downsample_shape)]
            if allow_non_aligned_write:
                bounds_from_end = [int(math.ceil(e / factor)) for e in end]
                ds_end = [max(e1, e2) for e1, e2 in zip(ds_end_est, bounds_from_end)]
            else:
                # Align to chunk boundaries, clamped to the volume extent at
                # this MIP level (volume_shape is full-resolution)
                mip_volume_shape = [int(math.ceil(v / factor)) for v in volume_shape]
                ds_end = [
                    int(round(math.ceil(e / c) * c))
                    for e, c in zip(ds_end_est, chunk_size)
                ]
                ds_end = [min(e, v) for e, v in zip(ds_end, mip_volume_shape)]
            print("DS fill", ds_start, ds_end)
            downsampled = downsample_with_averaging(rawdata, factor_tuple)
            print("Downsampled shape:", downsampled.shape)

            if not allow_non_aligned_write:
                # TODO may need to ignore padding at the data edges
                # Pad the downsampled data out to the chunk boundaries
                pad_width = [
                    (0, max(0, de - ds - s))
                    for ds, de, s in zip(ds_start, ds_end, downsampled.shape)
                ]
                pad_width.append((0, 0))  # No padding for channel dimension
                # Sanity check: boundary rounding should never require more
                # padding than 2 ** (num_mips - mip_level) voxels on any axis
                max_allowed_pad = 2 ** (num_mips - mip_level)
                max_actual_pad = max(pw[1] for pw in pad_width)
                if max_actual_pad > max_allowed_pad:
                    raise ValueError(
                        f"Padding too large at mip {mip_level}: {pad_width}, max allowed {max_allowed_pad}"
                    )
                if any(pw[1] > 0 for pw in pad_width):
                    print("Padding downsampled data with:", pad_width)
                    downsampled = np.pad(
                        downsampled, pad_width, mode="constant", constant_values=0
                    )
                    print("Padded downsampled shape:", downsampled.shape)

        vols[mip_level][
            ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
        ] = downsampled

    # Mark chunk as complete
    touch(f_name)

    # Clean up cached file to save disk space
    delete_cached_zarr_file(x_i, y_i)

    # Return the bounds of the processed chunk
    return (start, end)

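A minimal driver sketch for process (the grid variables and `vols`, a per-MIP list of writable CloudVolume-style volumes, are assumptions from the surrounding scaffolding, not part of this commit):

    # Hypothetical loop over the file grid, collecting processed bounds
    processed_chunks_bounds = []
    for x_i in range(computed_num_rows):
        for y_i in range(computed_num_cols):
            bounds = process(
                (x_i, y_i, 0),
                single_file_shape,
                volume_shape,
                vols,
                chunk_size,
                num_mips,
                mip_cutoff=0,
                allow_non_aligned_write=False,
                overwrite_output=False,
                progress_dir=progress_dir,
            )
            if bounds is not None:
                processed_chunks_bounds.append(bounds)
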
def is_chunk_fully_covered(chunk_bounds, processed_chunks_bounds):
    """
    Check if a chunk is fully covered by processed bounds.

    Args:
        chunk_bounds: [start_coord, end_coord] where each coord is [x, y, z]
        processed_chunks_bounds: List of tuples (start, end) where start and end are [x, y, z]

    Returns:
        bool: True if all 8 corners of the chunk are covered by processed bounds
    """
    if not processed_chunks_bounds:
        return False

    start_coord, end_coord = chunk_bounds
    x0, y0, z0 = start_coord
    x1, y1, z1 = end_coord

    # Generate all 8 corners of the chunk. End coordinates are exclusive,
    # so the last covered voxel along each axis is end - 1.
    corners = [
        [x0, y0, z0],  # min corner
        [x1 - 1, y0, z0],
        [x0, y1 - 1, z0],
        [x0, y0, z1 - 1],
        [x1 - 1, y1 - 1, z0],
        [x1 - 1, y0, z1 - 1],
        [x0, y1 - 1, z1 - 1],
        [x1 - 1, y1 - 1, z1 - 1],  # max corner
    ]

    # Check that each corner is covered by at least one processed bound
    for corner in corners:
        corner_covered = False
        for start, end in processed_chunks_bounds:
            # Check if the corner is inside this half-open processed bound
            if (
                start[0] <= corner[0] < end[0]
                and start[1] <= corner[1] < end[1]
                and start[2] <= corner[2] < end[2]
            ):
                corner_covered = True
                break

        # If any corner is not covered, the chunk is not fully covered
        if not corner_covered:
            return False

    # All corners are covered
    return True
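
For instance, with two processed tiles side by side (coordinates hypothetical), a chunk spanning exactly both is covered, while one extending past them in Y is not:

    done = [([0, 0, 0], [2048, 2048, 64]), ([2048, 0, 0], [4096, 2048, 64])]
    is_chunk_fully_covered([[0, 0, 0], [4096, 2048, 64]], done)  # True
    is_chunk_fully_covered([[0, 0, 0], [4096, 4096, 64]], done)  # False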
