Skip to content

Commit 44135af

Browse files
authored
Merge pull request #18 from cpelley/ENH_CONSERVATIVE_APPROACH
ENH: Added conservative interpolation approach
2 parents 1e03c25 + bab0fdd commit 44135af

File tree

5 files changed

+497
-3
lines changed

5 files changed

+497
-3
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
extensions = [Extension('{}._vinterp'.format(NAME),
1515
[os.path.join(NAME, '_vinterp.pyx')],
16+
include_dirs=[np.get_include()]),
17+
Extension('{}._conservative'.format(NAME),
18+
[os.path.join(NAME, '_conservative.pyx')],
1619
include_dirs=[np.get_include()])]
1720

1821

stratify/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import absolute_import
22

3-
from ._vinterp import (interpolate, interp_schemes, extrap_schemes,
4-
INTERPOLATE_LINEAR, INTERPOLATE_NEAREST,
3+
from ._vinterp import (interpolate, interp_schemes, # noqa: F401
4+
extrap_schemes, INTERPOLATE_LINEAR, INTERPOLATE_NEAREST,
55
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST,
6-
EXTRAPOLATE_LINEAR, PyFuncExtrapolator,
6+
EXTRAPOLATE_LINEAR, PyFuncExtrapolator,
77
PyFuncInterpolator)
8+
from ._bounded_vinterp import interpolate_conservative # noqa: F401
9+
810

911
__version__ = '0.1a3.dev0'

stratify/_bounded_vinterp.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from __future__ import absolute_import
2+
3+
import numpy as np
4+
5+
from ._conservative import conservative_interpolation
6+
7+
8+
def interpolate_conservative(z_target, z_src, fz_src, axis=-1):
9+
"""
10+
1d conservative interpolation across multiple dimensions.
11+
12+
This function provides the ability to perform 1d interpolation on datasets
13+
with more than one dimension. For instance, this function can be used to
14+
interpolate a set of vertical levels, even if the interpolation coordinate
15+
depends upon other dimensions.
16+
17+
A good use case might be when wanting to interpolate at a specific height
18+
for height data which also depends on x and y - e.g. extract 1000hPa level
19+
from a 3d dataset and associated pressure field. In the case of this
20+
example, pressure would be the `z` coordinate, and the dataset
21+
(e.g. geopotential height / temperature etc.) would be `f(z)`.
22+
23+
Parameters
24+
----------
25+
z_target: :class:`np.ndarray`
26+
Target coordinate.
27+
This coordinate defines the levels to interpolate the source data
28+
``fz_src`` to. ``z_target`` must have the same dimensionality as the
29+
source coordinate ``z_src``, and the shape of ``z_target`` must match
30+
the shape of ``z_src``, although the axis of interpolation may differ
31+
in dimension size.
32+
z_src: :class:`np.ndarray`
33+
Source coordinate.
34+
This coordinate defines the levels that the source data ``fz_src`` is
35+
interpolated from.
36+
fz_src: :class:`np.ndarray`
37+
The source data; the phenomenon data values to be interpolated from
38+
``z_src`` to ``z_target``.
39+
The data array must be at least ``z_src.ndim``, and its trailing
40+
dimensions (i.e. those on its right hand side) must be exactly
41+
the same as the shape of ``z_src``.
42+
axis: int (default -1)
43+
The ``fz_src`` axis to perform the interpolation over.
44+
45+
Returns
46+
-------
47+
: :class:`np.ndarray`
48+
fz_src interpolated from z_src to z_target.
49+
50+
Note
51+
----
52+
- Support for 1D z_target and corresponding ND z_src will be provided in
53+
future as driven by user requirement.
54+
- Those cells, where 'nan' values in the source data contribute, a 'nan'
55+
value is returned.
56+
57+
"""
58+
if z_src.ndim != z_target.ndim:
59+
msg = ('Expecting source and target levels dimensionality to be '
60+
'identical. {} != {}.')
61+
raise ValueError(msg.format(z_src.ndim, z_target.ndim))
62+
63+
# Relative axis
64+
axis = axis % fz_src.ndim
65+
axis_relative = axis - (fz_src.ndim - (z_target.ndim-1))
66+
67+
src_shape = list(z_src.shape)
68+
src_shape.pop(axis_relative)
69+
tgt_shape = list(z_target.shape)
70+
tgt_shape.pop(axis_relative)
71+
72+
if src_shape != tgt_shape:
73+
src_shape = list(z_src.shape)
74+
src_shape[axis_relative] = '-'
75+
tgt_shape = list(z_target.shape)
76+
src_shape[axis_relative] = '-'
77+
msg = ('Expecting the shape of the source and target levels except '
78+
'the axis of interpolation to be identical. {} != {}')
79+
raise ValueError(msg.format(tuple(src_shape), tuple(tgt_shape)))
80+
81+
dat_shape = list(fz_src.shape)
82+
dat_shape = dat_shape[-(z_src.ndim-1):]
83+
src_shape = list(z_src.shape[:-1])
84+
if dat_shape != src_shape:
85+
dat_shape = list(fz_src.shape)
86+
dat_shape[:-(z_src.ndim-1)] = '-'
87+
msg = ('The provided data is not of compatible shape with the '
88+
'provided source bounds. {} != {}')
89+
raise ValueError(msg.format(tuple(dat_shape), tuple(src_shape)))
90+
91+
if z_src.shape[-1] != 2:
92+
msg = 'Unexpected source and target bounds shape. shape[-1] != 2'
93+
raise ValueError(msg)
94+
95+
# Define our source in a consistent way.
96+
# [broadcasting_dims, axis_interpolation, z_varying]
97+
98+
# src_data
99+
bdims = list(range(fz_src.ndim - (z_src.ndim-1)))
100+
data_vdims = [ind for ind in range(fz_src.ndim) if ind not in
101+
(bdims + [axis])]
102+
data_transpose = bdims + [axis] + data_vdims
103+
fz_src_reshaped = np.transpose(fz_src, data_transpose)
104+
fz_src_orig = list(fz_src_reshaped.shape)
105+
shape = (
106+
int(np.product(fz_src_reshaped.shape[:len(bdims)])),
107+
fz_src_reshaped.shape[len(bdims)],
108+
int(np.product(fz_src_reshaped.shape[len(bdims)+1:])))
109+
fz_src_reshaped = fz_src_reshaped.reshape(shape)
110+
111+
# Define our src and target bounds in a consistent way.
112+
# [axis_interpolation, z_varying, 2]
113+
vdims = list(set(range(z_src.ndim)) - set([axis_relative]))
114+
z_src_reshaped = np.transpose(z_src, [axis_relative] + vdims)
115+
z_target_reshaped = np.transpose(z_target, [axis_relative] + vdims)
116+
117+
shape = int(np.product(z_src_reshaped.shape[1:-1]))
118+
z_src_reshaped = z_src_reshaped.reshape([z_src_reshaped.shape[0], shape,
119+
z_src_reshaped.shape[-1]])
120+
shape = int(np.product(z_target_reshaped.shape[1:-1]))
121+
z_target_reshaped = z_target_reshaped.reshape(
122+
[z_target_reshaped.shape[0], shape, z_target_reshaped.shape[-1]])
123+
124+
result = conservative_interpolation(
125+
z_src_reshaped, z_target_reshaped, fz_src_reshaped)
126+
127+
# Turn the result into a shape consistent with the source.
128+
# First reshape, then reverse transpose.
129+
shape = fz_src_orig
130+
shape[len(bdims)] = z_target.shape[axis_relative]
131+
result = result.reshape(shape)
132+
invert_transpose = [data_transpose.index(ind) for ind in
133+
list(range(result.ndim))]
134+
result = result.transpose(invert_transpose)
135+
return result

stratify/_conservative.pyx

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import numpy as np
2+
cimport numpy as np
3+
4+
5+
cdef calculate_weights(np.ndarray[np.float64_t, ndim=2] src_point,
6+
np.ndarray[np.float64_t, ndim=2] tgt_point):
7+
"""
8+
Calculate weights for a given point.
9+
10+
The following visually illustrates the calculation::
11+
12+
src_min src_max
13+
|----------| : Source
14+
tgt_min tgt_max
15+
|------------| : Target
16+
|----------| : Delta (src_max - src_min)
17+
|----| : Overlap (between src & tgt)
18+
weight = overlap / delta
19+
20+
Parameters
21+
----------
22+
src_point (2d double array) - Source point (at a specific location).
23+
tgt_point (2d double array) - Target point (at a specific location).
24+
25+
Returns
26+
-------
27+
2d double array - Weights corresponding to shape [src_point.shape[0],
28+
tgt_point.shape[0]].
29+
30+
"""
31+
cdef Py_ssize_t src_ind, tgt_ind
32+
cdef np.float64_t delta, weight
33+
cdef np.ndarray[np.float64_t, ndim=2] weights
34+
cdef np.ndarray[np.float64_t, ndim=1] src_cell, tgt_cell
35+
36+
weights = np.zeros([src_point.shape[0], tgt_point.shape[0]],
37+
dtype=np.float64)
38+
for src_ind, src_cell in enumerate(src_point):
39+
delta = src_cell.max() - src_cell.min()
40+
for tgt_ind, tgt_cell in enumerate(tgt_point):
41+
weight = (min(src_cell.max(), tgt_cell.max()) -
42+
max(src_cell.min(), tgt_cell.min())) / float(delta)
43+
if weight > 0:
44+
weights[src_ind, tgt_ind] = weight
45+
return weights
46+
47+
48+
cdef apply_weights(np.ndarray[np.float64_t, ndim=3] src_point,
49+
np.ndarray[np.float64_t, ndim=3] tgt_point,
50+
np.ndarray[np.float64_t, ndim=3] src_data):
51+
"""
52+
Perform conservative interpolation.
53+
54+
Conservation interpolation of a dataset between a provided source
55+
coordinate and a target coordinate. Where no source cells contribute to a
56+
target cell, a np.nan value is returned.
57+
58+
Parameters
59+
----------
60+
src_points (3d double array) - Source coordinate, taking the form
61+
[axis_interpolation, z_varying, 2].
62+
tgt_points (3d double array) - Target coordinate, taking the form
63+
[axis_interpolation, z_varying, 2].
64+
src_data (3d double array) - The source data, the phenomenon data to be
65+
interpolated from ``src_points`` to ``tgt_points``. Taking the form
66+
[broadcasting_dims, axis_interpolation, z_varying].
67+
68+
Returns
69+
-------
70+
3d double array - Interpolated result over target levels (``tgt_points``).
71+
Taking the form [broadcasting_dims, axis_interpolation, z_varying].
72+
73+
"""
74+
cdef Py_ssize_t ind
75+
cdef np.ndarray[np.float64_t, ndim=3] results, weighted_contrib
76+
cdef np.ndarray[np.float64_t, ndim=2] weights
77+
results = np.zeros(
78+
[src_data.shape[0], tgt_point.shape[0], src_data.shape[2]],
79+
dtype='float64')
80+
# Calculate and apply weights
81+
for ind in range(src_data.shape[2]):
82+
weights = calculate_weights(src_point[:, ind], tgt_point[:, ind])
83+
if not (weights.sum(axis=1) == 1).all():
84+
msg = ('Weights calculation yields a less than conservative '
85+
'result. Aborting.')
86+
raise ValueError(msg)
87+
weighted_contrib = weights * src_data[..., ind][..., None]
88+
results[..., ind] = (
89+
np.nansum(weighted_contrib, axis=1))
90+
# Return nan values for those target cells, where there is any
91+
# contribution of nan data from the source data.
92+
results[..., ind][
93+
((weights > 0) * np.isnan(weighted_contrib)).any(axis=1)] = np.nan
94+
95+
# Return np.nan for those target cells where no source contributes.
96+
results[:, weights.sum(axis=0) == 0, ind] = np.nan
97+
return results
98+
99+
100+
def conservative_interpolation(src_points, tgt_points, src_data):
101+
"""
102+
Perform conservative interpolation.
103+
104+
Conservation interpolation of a dataset between a provided source
105+
coordinate and a target coordinate. All inputs are recast to 64-bit float
106+
arrays before calculation.
107+
108+
Parameters
109+
----------
110+
src_points (3d array) - Source coordinate, taking the form
111+
[axis_interpolation, z_varying, 2].
112+
tgt_points (3d array) - Target coordinate, taking the form
113+
[axis_interpolation, z_varying, 2].
114+
src_data (3d array) - The source data, the phenonenon data to be
115+
interpolated from ``src_points`` to ``tgt_points``. Taking the form
116+
[broadcasting_dims, axis_interpolation, z_varying].
117+
118+
Returns
119+
-------
120+
3d double array - Interpolated result over target levels (``tgt_points``).
121+
Taking the form [broadcasting_dims, axis_interpolation, z_varying].
122+
123+
"""
124+
return apply_weights(src_points.astype('float64'),
125+
tgt_points.astype('float64'),
126+
src_data.astype('float64'))

0 commit comments

Comments
 (0)