Skip to content

Commit 68f6e80

Browse files
committed
fix: correct downsampling
1 parent df5b2e2 commit 68f6e80

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

examples/create_downampled.py

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import itertools
44
import math
55
from pathlib import Path
6-
import time
76
import numpy as np
87
from cloudvolume import CloudVolume
98
from cloudvolume.lib import touch, Vec
@@ -21,8 +20,9 @@
2120
# Other settings
2221
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
2322
OVERWRITE = False
24-
NUM_MIPS = 3
25-
MIP_CUTOFF = 0 # To save time you can start at the lowest resolution and work up
23+
NUM_MIPS = 5
24+
MIP_CUTOFF = 3 # To save time you can start at the lowest resolution and work up
25+
NUM_CHANNELS = 2 # For less memory usage (can't be 1 right now though)
2626

2727
# %% Load the data
2828
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
@@ -45,13 +45,13 @@ def load_zarr_and_permute(file_path):
4545

4646

4747
def load_chunk_from_zarr_store(
48-
zarr_store, x_start, x_end, y_start, y_end, z_start, z_end, channel=0
48+
zarr_store, x_start, x_end, y_start, y_end, z_start, z_end
4949
):
5050
# Input is in Z, T, C, Y, X order
5151
data = zarr_store[
5252
:, # T
5353
z_start:z_end, # Z
54-
channel, # C
54+
:NUM_CHANNELS, # C
5555
y_start:y_end, # Y
5656
x_start:x_end, # X
5757
]
@@ -60,17 +60,18 @@ def load_chunk_from_zarr_store(
6060
data = np.squeeze(data) # Remove any singleton dimensions
6161
print("Loaded data shape:", data.shape)
6262
# Then we permute to XYTCZ
63-
data = np.transpose(data, (-1, -2, 0)) # Permute to XYTCZ
63+
data = np.transpose(data, (-1, -2, 0, 1)) # Permute to XYTCZ
6464
return data
6565

6666

6767
zarr_store = load_zarr_data(all_files[0])
6868

6969
# It may take too long to just load one file, might need to process in chunks
7070
# %% Check how long to load a single file
71-
start_time = time.time()
72-
data = load_chunk_from_zarr_store(zarr_store, 0, 256, 0, 200, 0, 128, channel=0)
73-
print("Time to load a single file:", time.time() - start_time)
71+
# import time
72+
# start_time = time.time()
73+
# data = load_chunk_from_zarr_store(zarr_store, 0, 256, 0, 200, 0, 128)
74+
# print("Time to load a single file:", time.time() - start_time)
7475

7576
# %% Inspect the data
7677
shape = zarr_store.shape
@@ -82,9 +83,9 @@ def load_chunk_from_zarr_store(
8283
size_y = 1
8384
size_z = 1
8485

85-
num_channels = shape[2]
86+
num_channels = min(shape[2], NUM_CHANNELS) # Limit to NUM_CHANNELS for memory usage
8687
data_type = "uint16"
87-
chunk_size = [256, 256, 128]
88+
chunk_size = [64, 64, 32]
8889

8990
# You can provide a subset here also
9091
num_rows = 1
@@ -128,16 +129,17 @@ def load_chunk_from_zarr_store(
128129
progress_dir.mkdir(exist_ok=True)
129130

130131
# %% Functions for moving data
131-
# TODO setup file loop
132-
# TODO setup channel handling
133-
134132
shape = single_file_dims_shape
135133
chunk_shape = np.array([1500, 936, 687]) # this is for reading data
136134
num_chunks_per_dim = np.ceil(shape / chunk_shape).astype(int)
137135

138136

139137
def process(args):
140138
x_i, y_i, z_i = args
139+
flat_index = x_i * num_cols + y_i
140+
print(f"Processing chunk {flat_index} at coordinates ({x_i}, {y_i}, {z_i})")
141+
# Load the data for this chunk
142+
loaded_zarr_store = load_zarr_data(all_files[flat_index])
141143
start = [x_i * chunk_shape[0], y_i * chunk_shape[1], z_i * chunk_shape[2]]
142144
end = [
143145
min((x_i + 1) * chunk_shape[0], shape[0]),
@@ -149,31 +151,31 @@ def process(args):
149151
if f_name.exists() and not OVERWRITE:
150152
return
151153
rawdata = load_chunk_from_zarr_store(
152-
zarr_store, start[0], end[0], start[1], end[1], start[2], end[2], channel=0
154+
loaded_zarr_store, start[0], end[0], start[1], end[1], start[2], end[2]
153155
)
154156
for mip_level in reversed(range(MIP_CUTOFF, NUM_MIPS)):
155157
if mip_level == 0:
156158
downsampled = rawdata
157159
ds_start = start
158160
ds_end = end
159161
else:
162+
ds_start = [int(math.ceil(s / (2**mip_level))) for s in start]
163+
ds_end = [int(math.ceil(e / (2**mip_level))) for e in end]
164+
print("DS fill", ds_start, ds_end)
160165
downsampled = downsample_with_averaging(
161-
rawdata, [2 * mip_level, 2 * mip_level, 2 * mip_level]
166+
rawdata, [2**mip_level, 2**mip_level, 2**mip_level, 1]
162167
)
163-
ds_start = [int(math.ceil(s / (2 * mip_level))) for s in start]
164-
ds_end = [int(math.ceil(e / (2 * mip_level))) for e in end]
165-
print(ds_start, ds_end)
166168
print("Downsampled shape:", downsampled.shape)
167169

168170
vols[mip_level][
169171
ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
170-
] = downsampled
172+
] = downsampled.astype(np.uint16)
171173
touch(f_name)
172174

173175

174176
# %% Try with a single chunk to see if it works
175-
x_i, y_i, z_i = 0, 0, 0
176-
process((x_i, y_i, z_i))
177+
# x_i, y_i, z_i = 0, 0, 0
178+
# process((x_i, y_i, z_i))
177179

178180

179181
# %% Loop over all the chunks

0 commit comments

Comments (0)