Skip to content

Commit 12a556a

Browse files
bjlittleCarwyn Pelley
authored andcommitted
Fix interpolate z_target nd
1 parent 3e44325 commit 12a556a

File tree

2 files changed

+120
-27
lines changed

2 files changed

+120
-27
lines changed

stratify/_vinterp.pyx

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,15 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
5656
5757
Parameters
5858
----------
59-
z_target - the levels to interpolate to.
60-
z_src - the coordinate from which to find the levels.
61-
fz_src - the data to use for the actual interpolation
59+
z_target - the levels to interpolate the source data ``fz_src`` to.
60+
z_src - the levels that the source data ``fz_src`` is interpolated from.
61+
fz_src - the source data to be interpolated.
6262
increasing - true when increasing Z index generally implies increasing Z values
6363
interpolation - the inner interpolation functionality. See the definition of
6464
Interpolator.
6565
extrapolation - the inner extrapolation functionality. See the definition of
6666
Extrapolator.
67-
fz_target - the pre-allocated array to be used for the outputting the result
68-
of interpolation.
67+
fz_target - the pre-allocated array to be used for the interpolated result
6968
7069
Note: This algorithm is not symmetric. It does not make assumptions about monotonicity
7170
of z_src nor z_target. Instead, the algorithm marches forwards from the last
@@ -79,8 +78,8 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
7978
We then continue from this index, only looking for the crossing of the next z_target.
8079
8180
For this reason, the order that the levels are provided is important.
82-
If z_src = [2, 4, 6], f_src = [2, 4, 6] and z_target = [3, 5], fz_target will be
83-
[3, 5]. But if z_target = [5, 3] fz_target will be [5, <extrapolation value>].
81+
If z_src = [2, 4, 6], fz_src = [2, 4, 6] and z_target = [3, 5], fz_target will be
82+
[3, 5]. But if z_target = [5, 3], fz_target will be [5, <extrapolation value>].
8483
8584
"""
8685
cdef unsigned int i_src, i_target, n_src, n_target, i, m
@@ -462,10 +461,13 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
462461
463462
Parameters
464463
----------
465-
z_target: 1d array
464+
z_target: 1d or nd array
466465
Target coordinate.
467466
This coordinate defines the levels to interpolate the source data
468-
``fz_src`` to.
467+
``fz_src`` to. If ``z_target`` is an nd array, it must have the same
468+
dimensionality as the source coordinate ``z_src``, and the shape of
469+
``z_target`` must match the shape of ``z_src``, although the axis
470+
of interpolation may differ in dimension size.
469471
z_src: nd array
470472
Source coordinate.
471473
This coordinate defines the levels that the source data ``fz_src`` is
@@ -477,7 +479,7 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
477479
dimensions (i.e. those on its right hand side) must be exactly
478480
the same as the shape of ``z_src``.
479481
axis: int (default -1)
480-
The axis to perform the interpolation over.
482+
The ``fz_src`` axis to perform the interpolation over.
481483
rising: bool (default None)
482484
Whether the values of the source's interpolation coordinate values
483485
are generally rising or generally falling. For example, values of
@@ -526,7 +528,7 @@ cdef class _Interpolation(object):
526528
cdef public np.dtype _target_dtype
527529
cdef int rising
528530
cpdef public z_target, orig_shape, axis, _zp_reshaped, _fp_reshaped
529-
cpdef public _result_working_shape, result_shape, _first_value
531+
cpdef public _result_working_shape, result_shape
530532

531533
def __init__(self, z_target, z_src, fz_src, axis=-1,
532534
rising=None,
@@ -543,12 +545,6 @@ cdef class _Interpolation(object):
543545
self._target_dtype = fz_src.dtype
544546
fz_src = fz_src.astype(np.float64)
545547

546-
# Broadcast the z_target shape if it is 1d (which it is in most cases)
547-
if z_target.ndim == 1:
548-
z_target_size = z_target.shape[0]
549-
else:
550-
z_target_size = z_target.shape[axis]
551-
552548
# Compute the axis in absolute terms.
553549
fp_axis = (axis + fz_src.ndim) % fz_src.ndim
554550
zp_axis = fp_axis - (fz_src.ndim - z_src.ndim)
@@ -557,7 +553,32 @@ cdef class _Interpolation(object):
557553

558554
# Ensure that fz_src's shape is a superset of z_src's.
559555
if z_src.shape != fz_src.shape[-z_src.ndim:]:
560-
raise ValueError('Shapes not consistent.')
556+
emsg = 'Shape for z_src {} is not a subset of fz_src {}.'
557+
raise ValueError(emsg.format(z_src.shape, fz_src.shape))
558+
559+
if z_target.ndim == 1:
560+
z_target_size = z_target.shape[0]
561+
else:
562+
# Ensure z_target and z_src have same ndims.
563+
if z_target.ndim != z_src.ndim:
564+
emsg = ('z_target and z_src must have the same number '
565+
'of dimensions, got {} != {}.')
566+
raise ValueError(emsg.format(z_target.ndim, z_src.ndim))
567+
# Ensure z_target and z_src have same shape over their
568+
# non-interpolated axes i.e. we need to ignore the axis of
569+
# interpolation when comparing the shapes of z_target and z_src.
570+
# E.g a z_target.shape=(3, 4, 5) and z_src.shape=(3, 10, 5),
571+
# interpolating over zp_axis=1 is fine as (3, :, 5) == (3, :, 5).
572+
# However, a z_target.shape=(3, 4, 6) and z_src.shape=(3, 10, 5),
573+
# interpolating over zp_axis=1 must fail as (3, :, 6) != (3, :, 5)
574+
zts, zss = z_target.shape, z_src.shape
575+
ztsp, zssp = zip(*[(str(j), str(k)) if i!=zp_axis else (':', ':')
576+
for i, (j, k) in enumerate(zip(zts, zss))])
577+
if ztsp != zssp:
578+
sep, emsg = ', ', ('z_target and z_src have different shapes, '
579+
'got ({}) != ({}).')
580+
raise ValueError(emsg.format(sep.join(ztsp), sep.join(zssp)))
581+
z_target_size = zts[zp_axis]
561582

562583
# We are going to put the source coordinate into a 3d shape for convenience of
563584
# Cython interface. Writing generic, fast, n-dimensional Cython code
@@ -574,7 +595,7 @@ cdef class _Interpolation(object):
574595
self.z_target = z_target
575596
#: The shape of the input data (fz_src).
576597
self.orig_shape = fz_src.shape
577-
#: The axis over which to do the interpolation.
598+
#: The fz_src axis over which to do the interpolation.
578599
self.axis = axis
579600

580601
#: The source z coordinate data reshaped into 3d working shape form.
@@ -694,4 +715,3 @@ cdef class _Interpolation(object):
694715
fz_target_view[:, i, :, j])
695716

696717
return fz_target.reshape(self.result_shape).astype(self._target_dtype)
697-

stratify/tests/test_vinterp.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,34 @@ def test_axis_2(self):
335335
def test_inconsistent_shape(self):
336336
data = np.empty([5, 4, 23, 7, 3])
337337
zdata = np.empty([5, 4, 3, 7, 3])
338-
with self.assertRaises(ValueError):
338+
emsg = 'z_src .* is not a subset of fz_src'
339+
with self.assertRaisesRegexp(ValueError, emsg):
339340
vinterp._Interpolation([1, 3], data, zdata, axis=2)
340341

341342
def test_axis_out_of_bounds(self):
342343
data = np.empty([5, 4])
343344
zdata = np.empty([5, 4])
344-
with self.assertRaises(ValueError):
345-
vinterp._Interpolation([1, 3], data, zdata, axis=4)
345+
axis = 4
346+
emsg = 'Axis {} out of range'
347+
with self.assertRaisesRegexp(ValueError, emsg.format(axis)):
348+
vinterp._Interpolation([1, 3], data, zdata, axis=axis)
349+
350+
def test_nd_inconsistent_ndims(self):
351+
z_target = np.empty((2, 3, 4))
352+
z_src = np.empty((3, 4))
353+
fz_src = np.empty((2, 3, 4))
354+
emsg = 'z_target and z_src must have the same number of dimensions'
355+
with self.assertRaisesRegexp(ValueError, emsg):
356+
vinterp._Interpolation(z_target, z_src, fz_src)
357+
358+
def test_nd_inconsistent_shape(self):
359+
z_target = np.empty((3, 2, 6))
360+
z_src = np.empty((3, 4, 5))
361+
fz_src = np.empty((2, 3, 4, 5))
362+
emsg = ('z_target and z_src have different shapes, '
363+
'got \(3, :, 6\) != \(3, :, 5\)')
364+
with self.assertRaisesRegexp(ValueError, emsg):
365+
vinterp._Interpolation(z_target, z_src, fz_src, axis=2)
346366

347367
def test_result_dtype_f4(self):
348368
interp = vinterp._Interpolation([17.5], np.arange(4) * 10,
@@ -362,26 +382,79 @@ def test_result_dtype_f8(self):
362382

363383

364384
class Test__Interpolation_interpolate_z_target_nd(unittest.TestCase):
365-
def test_target_z_3d_axis_0(self):
385+
def test_target_z_3d_on_axis_0(self):
366386
z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
367387
interp = vinterp._Interpolation(z_target, z_source, f_source,
368388
axis=0, extrapolation=stratify.EXTRAPOLATE_NEAREST)
369389
result = interp.interpolate_z_target_nd()
370390
assert_array_equal(result, f_source)
371391

372-
def test_target_z_3d_axis_m1(self):
392+
def test_target_z_3d_on_axis_m1(self):
373393
z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
374394
interp = vinterp._Interpolation(z_target, z_source, f_source,
375395
axis=-1, extrapolation=stratify.EXTRAPOLATE_NEAREST)
376396
result = interp.interpolate_z_target_nd()
377397
assert_array_equal(result, f_source)
378398

399+
def test_target_z_2d_over_3d_on_axis_1(self):
400+
"""
401+
Test the case where z_target(2, 4) and z_src(3, 4) are 2d, but the
402+
source data fz_src(3, 3, 4) is 3d. z_target and z_src cover the last
403+
2 dimensions of fz_src. The axis of interpolation is axis=1 wrt fz_src.
404+
405+
"""
406+
# Generate the 3d source data fz_src(3, 3, 4)
407+
base = np.arange(3).reshape(1, 3, 1) * 2
408+
data = np.broadcast_to(base, (3, 3, 4))
409+
fz_src = data * np.arange(1, 4).reshape(3, 1, 1) * 10
410+
# Generate the 2d target coordinate z_target(2, 4)
411+
# The target coordinate is configured to request the interpolated
412+
# mid-points over axis=1 of fz_src.
413+
z_target = np.repeat(np.arange(1, 4, 2).reshape(2, 1), 4, axis=1) * 10
414+
# Generate the 2d source coordinate z_src(3, 4)
415+
z_src = np.repeat(np.arange(3).reshape(3, 1), 4, axis=1) * 20
416+
# Configure the vertical interpolator.
417+
interp = vinterp._Interpolation(z_target, z_src, fz_src, axis=1)
418+
# Perform the vertical interpolation.
419+
result = interp.interpolate_z_target_nd()
420+
# Generate the 3d expected interpolated result(3, 2, 4).
421+
expected = np.repeat(z_target[np.newaxis, ...], 3, axis=0)
422+
expected = expected * np.arange(1, 4).reshape(3, 1, 1)
423+
assert_array_equal(result, expected)
424+
425+
def test_target_z_2d_over_3d_on_axis_m1(self):
426+
"""
427+
Test the case where z_target(3, 3) and z_src(3, 4) are 2d, but the
428+
source data fz_src(3, 3, 4) is 3d. z_target and z_src cover the last
429+
2 dimensions of fz_src. The axis of interpolation is the default last
430+
dimension, axis=-1, wrt fx_src.
431+
432+
"""
433+
# Generate the 3d source data fz_src(3, 3, 4)
434+
base = np.arange(4) * 2
435+
data = np.broadcast_to(base, (3, 3, 4))
436+
fz_src = data * np.arange(1, 4).reshape(3, 1, 1) * 10
437+
# Generate the 2d target coordinate z_target(3, 3)
438+
# The target coordinate is configured to request the interpolated
439+
# mid-points over axis=-1 (aka axis=2) of fz_src.
440+
z_target = np.repeat(np.arange(1, 6, 2).reshape(1, 3), 3, axis=0) * 10
441+
# Generate the 2d source coordinate z_src(3, 4)
442+
z_src = np.repeat(np.arange(4).reshape(1, 4), 3, axis=0) * 20
443+
# Configure the vertical interpolator.
444+
interp = vinterp._Interpolation(z_target, z_src, fz_src,)
445+
# Perform the vertical interpolation.
446+
result = interp.interpolate_z_target_nd()
447+
# Generate the 3d expected interpolated result(3, 3, 3)
448+
expected = np.repeat(z_target[np.newaxis, ...], 3, axis=0)
449+
expected = expected * np.arange(1, 4).reshape(3, 1, 1)
450+
assert_array_equal(result, expected)
451+
379452

380453
class Test_interpolate(unittest.TestCase):
381454
def test_target_z_3d_axis_0(self):
382455
z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
383-
result= vinterp.interpolate(z_target, z_source, f_source,
384-
extrapolation='linear')
456+
result = vinterp.interpolate(z_target, z_source, f_source,
457+
extrapolation='linear')
385458
assert_array_equal(result, f_source)
386459

387460

0 commit comments

Comments
 (0)