Skip to content

Commit c13b218

Browse files
committed
Generalize reproject_and_coadd to N-dimensions and fix test failures
1 parent bc503a4 commit c13b218

File tree

5 files changed

+182
-220
lines changed

5 files changed

+182
-220
lines changed

Diff for: reproject/array_utils.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
__all__ = ["map_coordinates"]
3+
__all__ = ["map_coordinates", "sample_array_edges"]
44

55

66
def map_coordinates(image, coords, **kwargs):
@@ -35,3 +35,22 @@ def map_coordinates(image, coords, **kwargs):
3535
values[reset] = kwargs.get("cval", 0.0)
3636

3737
return values
38+
39+
40+
def sample_array_edges(shape, *, n_samples):
41+
# Given an N-dimensional array shape, sample each edge of the array using
42+
# the requested number of samples (which will include vertices). To do this
43+
# we iterate through the dimensions and for each one we sample the points
44+
# in that dimension and iterate over the combination of other vertices.
45+
# Returns an array with dimensions (N, n_samples)
46+
all_positions = []
47+
ndim = len(shape)
48+
shape = np.array(shape)
49+
for idim in range(ndim):
50+
for vertex in range(2**ndim):
51+
positions = -0.5 + shape * ((vertex & (2 ** np.arange(ndim))) > 0).astype(int)
52+
positions = np.broadcast_to(positions, (n_samples, ndim)).copy()
53+
positions[:, idim] = np.linspace(-0.5, shape[idim] - 0.5, n_samples)
54+
all_positions.append(positions)
55+
positions = np.unique(np.vstack(all_positions), axis=0).T
56+
return positions

Diff for: reproject/mosaicking/coadd.py

+77-103
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from astropy.wcs import WCS
55
from astropy.wcs.wcsapi import SlicedLowLevelWCS
66

7+
from ..array_utils import sample_array_edges
78
from ..utils import parse_input_data, parse_input_weights, parse_output_projection
89
from .background import determine_offset_matrix, solve_corrections_sgd
910
from .subset_array import ReprojectedArraySubset
@@ -30,15 +31,13 @@ def reproject_and_coadd(
3031
output_footprint=None,
3132
block_sizes=None,
3233
progress_bar=None,
33-
blank_pixel_value=np.nan,
34+
blank_pixel_value=0,
3435
**kwargs,
3536
):
3637
"""
37-
Given a set of input images, reproject and co-add these to a single
38+
Given a set of input data, reproject and co-add these to a single
3839
final image.
3940
40-
This currently only works with 2-d images with celestial WCS.
41-
4241
Parameters
4342
----------
4443
input_data : iterable
@@ -149,24 +148,31 @@ def reproject_and_coadd(
149148

150149
wcs_out, shape_out = parse_output_projection(output_projection, shape_out=shape_out)
151150

152-
if output_array is not None and output_array.shape != shape_out:
151+
if output_array is None:
152+
output_array = np.zeros(shape_out)
153+
elif output_array.shape != shape_out:
153154
raise ValueError(
154155
"If you specify an output array, it must have a shape matching "
155156
f"the output shape {shape_out}"
156157
)
157-
if output_footprint is not None and output_footprint.shape != shape_out:
158+
159+
if output_footprint is None:
160+
output_footprint = np.zeros(shape_out)
161+
elif output_footprint.shape != shape_out:
158162
raise ValueError(
159163
"If you specify an output footprint array, it must have a shape matching "
160164
f"the output shape {shape_out}"
161165
)
162166

163-
if output_array is None:
164-
output_array = np.zeros(shape_out)
165-
if output_footprint is None:
166-
output_footprint = np.zeros(shape_out)
167+
# Define 'on-the-fly' mode: in the case where we don't need to match
168+
# the backgrounds and we are combining with 'mean' or 'sum', we don't
169+
# have to keep track of the intermediate arrays and can just modify
170+
# the output array on-the-fly
171+
on_the_fly = not match_background and combine_function in ("mean", "sum")
167172

168173
# Start off by reprojecting individual images to the final projection
169-
if match_background:
174+
175+
if not on_the_fly:
170176
arrays = []
171177

172178
for idata in progress_bar(range(len(input_data))):
@@ -192,71 +198,42 @@ def reproject_and_coadd(
192198
# significant distortion (when the edges of the input image become
193199
# convex in the output projection), and transforming every edge pixel,
194200
# which provides a lot of redundant information.
195-
if array_in.ndim == 2:
196-
ny, nx = array_in.shape
197-
n_per_edge = 11
198-
xs = np.linspace(-0.5, nx - 0.5, n_per_edge)
199-
ys = np.linspace(-0.5, ny - 0.5, n_per_edge)
200-
xs = np.concatenate((xs, np.full(n_per_edge, xs[-1]), xs, np.full(n_per_edge, xs[0])))
201-
ys = np.concatenate((np.full(n_per_edge, ys[0]), ys, np.full(n_per_edge, ys[-1]), ys))
202-
xc_out, yc_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(xs, ys))
203-
shape_out_cel = shape_out
204-
elif array_in.ndim == 3:
205-
# for cubes, we only handle single corners now
206-
nz, ny, nx = array_in.shape
207-
xc = np.array([-0.5, nx - 0.5, nx - 0.5, -0.5])
208-
yc = np.array([-0.5, -0.5, ny - 0.5, ny - 0.5])
209-
zc = np.array([-0.5, nz - 0.5])
210-
# TODO: figure out what to do here if the low_level_wcs doesn't support subsetting
211-
xc_out, yc_out = wcs_out.low_level_wcs.celestial.world_to_pixel(
212-
wcs_in.celestial.pixel_to_world(xc, yc)
213-
)
214-
zc_out = wcs_out.low_level_wcs.spectral.world_to_pixel(
215-
wcs_in.spectral.pixel_to_world(zc)
216-
)
217-
shape_out_cel = shape_out[1:]
218-
else:
219-
raise ValueError(f"Wrong number of dimensions: {array_in.ndim}")
201+
202+
edges = sample_array_edges(array_in.shape, n_samples=11)[::-1]
203+
edges_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(*edges))[::-1]
220204

221205
# Determine the cutout parameters
222206

223207
# In some cases, images might not have valid coordinates in the corners,
224208
# such as all-sky images or full solar disk views. In this case we skip
225209
# this step and just use the full output WCS for reprojection.
226210

227-
if np.any(np.isnan(xc_out)) or np.any(np.isnan(yc_out)):
228-
imin = 0
229-
imax = shape_out_cel[1]
230-
jmin = 0
231-
jmax = shape_out_cel[0]
232-
else:
233-
imin = max(0, int(np.floor(xc_out.min() + 0.5)))
234-
imax = min(shape_out_cel[1], int(np.ceil(xc_out.max() + 0.5)))
235-
jmin = max(0, int(np.floor(yc_out.min() + 0.5)))
236-
jmax = min(shape_out_cel[0], int(np.ceil(yc_out.max() + 0.5)))
211+
ndim_out = len(shape_out)
237212

238-
if imax < imin or jmax < jmin:
213+
skip_data = False
214+
if np.any(np.isnan(edges_out)):
215+
bounds = list(zip([0] * ndim_out, shape_out))
216+
else:
217+
bounds = []
218+
for idim in range(ndim_out):
219+
imin = max(0, int(np.floor(edges_out[idim].min() + 0.5)))
220+
imax = min(shape_out[idim], int(np.ceil(edges_out[idim].max() + 0.5)))
221+
bounds.append((imin, imax))
222+
if imax < imin:
223+
skip_data = True
224+
break
225+
226+
if skip_data:
239227
continue
240228

241-
if array_in.ndim == 2:
242-
if isinstance(wcs_out, WCS):
243-
wcs_out_indiv = wcs_out[jmin:jmax, imin:imax]
244-
else:
245-
wcs_out_indiv = SlicedLowLevelWCS(
246-
wcs_out.low_level_wcs, (slice(jmin, jmax), slice(imin, imax))
247-
)
248-
shape_out_indiv = (jmax - jmin, imax - imin)
249-
kmin, kmax = None, None # for reprojectedarraysubset below
250-
elif array_in.ndim == 3:
251-
kmin = max(0, int(np.floor(zc_out.min() + 0.5)))
252-
kmax = min(shape_out[0], int(np.ceil(zc_out.max() + 0.5)))
253-
if isinstance(wcs_out, WCS):
254-
wcs_out_indiv = wcs_out[kmin:kmax, jmin:jmax, imin:imax]
255-
else:
256-
wcs_out_indiv = SlicedLowLevelWCS(
257-
wcs_out.low_level_wcs, (slice(kmin, kmax), slice(jmin, jmax), slice(imin, imax))
258-
)
259-
shape_out_indiv = (kmax - kmin, jmax - jmin, imax - imin)
229+
slice_out = tuple([slice(imin, imax) for (imin, imax) in bounds])
230+
231+
if isinstance(wcs_out, WCS):
232+
wcs_out_indiv = wcs_out[slice_out]
233+
else:
234+
wcs_out_indiv = SlicedLowLevelWCS(wcs_out.low_level_wcs, slice_out)
235+
236+
shape_out_indiv = [imax - imin for (imin, imax) in bounds]
260237

261238
if block_sizes is not None:
262239
if len(block_sizes) == len(input_data) and len(block_sizes[idata]) == len(shape_out):
@@ -296,22 +273,20 @@ def reproject_and_coadd(
296273
weights[reset] = 0.0
297274
footprint *= weights
298275

299-
array = ReprojectedArraySubset(array, footprint, imin, imax, jmin, jmax, kmin, kmax)
276+
array = ReprojectedArraySubset(array, footprint, bounds)
300277

301278
# TODO: make sure we gracefully handle the case where the
302279
# output image is empty (due e.g. to no overlap).
303280

304-
if match_background:
305-
arrays.append(array)
281+
if on_the_fly:
282+
# By default, values outside of the footprint are set to NaN
283+
# but we set these to 0 here to avoid getting NaNs in the
284+
# means/sums.
285+
array.array[array.footprint == 0] = 0
286+
output_array[array.view_in_original_array] += array.array * array.footprint
287+
output_footprint[array.view_in_original_array] += array.footprint
306288
else:
307-
if combine_function in ("mean", "sum"):
308-
# By default, values outside of the footprint are set to NaN
309-
# but we set these to 0 here to avoid getting NaNs in the
310-
# means/sums.
311-
array.array[array.footprint == 0] = 0
312-
313-
output_array[array.view_in_original_array] += array.array * array.footprint
314-
output_footprint[array.view_in_original_array] += array.footprint
289+
arrays.append(array)
315290

316291
# If requested, try and match the backgrounds.
317292
if match_background and len(arrays) > 1:
@@ -322,11 +297,6 @@ def reproject_and_coadd(
322297
for array, correction in zip(arrays, corrections, strict=True):
323298
array.array -= correction
324299

325-
if combine_function == "min":
326-
output_array[...] = np.inf
327-
elif combine_function == "max":
328-
output_array[...] = -np.inf
329-
330300
if combine_function in ("mean", "sum"):
331301
if match_background:
332302
# if we're not matching the background, this part has already been done
@@ -336,37 +306,41 @@ def reproject_and_coadd(
336306
# means/sums.
337307
array.array[array.footprint == 0] = 0
338308

339-
output_array[array.view_in_original_array] += array.array * array.footprint
340-
output_footprint[array.view_in_original_array] += array.footprint
309+
output_array[array.view_in_original_array] += array.array * array.footprint
310+
output_footprint[array.view_in_original_array] += array.footprint
341311

342312
if combine_function == "mean":
343313
with np.errstate(invalid="ignore"):
344314
output_array /= output_footprint
345315
output_array[output_footprint == 0] = blank_pixel_value
346316

347317
elif combine_function in ("first", "last", "min", "max"):
348-
if match_background:
349-
for array in arrays:
350-
if combine_function == "first":
351-
mask = output_footprint[array.view_in_original_array] == 0
352-
elif combine_function == "last":
353-
mask = array.footprint > 0
354-
elif combine_function == "min":
355-
mask = (array.footprint > 0) & (
356-
array.array < output_array[array.view_in_original_array]
357-
)
358-
elif combine_function == "max":
359-
mask = (array.footprint > 0) & (
360-
array.array > output_array[array.view_in_original_array]
361-
)
362-
363-
output_footprint[array.view_in_original_array] = np.where(
364-
mask, array.footprint, output_footprint[array.view_in_original_array]
318+
if combine_function == "min":
319+
output_array[...] = np.inf
320+
elif combine_function == "max":
321+
output_array[...] = -np.inf
322+
323+
for array in arrays:
324+
if combine_function == "first":
325+
mask = output_footprint[array.view_in_original_array] == 0
326+
elif combine_function == "last":
327+
mask = array.footprint > 0
328+
elif combine_function == "min":
329+
mask = (array.footprint > 0) & (
330+
array.array < output_array[array.view_in_original_array]
365331
)
366-
output_array[array.view_in_original_array] = np.where(
367-
mask, array.array, output_array[array.view_in_original_array]
332+
elif combine_function == "max":
333+
mask = (array.footprint > 0) & (
334+
array.array > output_array[array.view_in_original_array]
368335
)
369336

337+
output_footprint[array.view_in_original_array] = np.where(
338+
mask, array.footprint, output_footprint[array.view_in_original_array]
339+
)
340+
output_array[array.view_in_original_array] = np.where(
341+
mask, array.array, output_array[array.view_in_original_array]
342+
)
343+
370344
output_array[output_footprint == 0] = blank_pixel_value
371345

372346
return output_array, output_footprint

0 commit comments

Comments
 (0)