Skip to content

Commit 79d694e

Browse files
authored
Merge pull request #68 from CardiacModelling/27-gradient
Add gradient to linear interpolant
2 parents a661763 + 26ca5ff commit 79d694e

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

nevis/_interpolation.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@
1313

1414
class linear_interpolant(object):
1515
"""
16-
Returns a linear interpolation over the full GB data set.
16+
Returns a linear interpolation and optionally its gradient over the full
17+
GB data set.
1718
18-
The returned function takes two arguments ``x`` and ``y`` (both in metres)
19-
and returns an interpolated height ``z`` (in meters).
19+
When ``grad`` is set to ``False`` (by default), the returned function
20+
takes two arguments ``x`` and ``y`` (both in metres) and returns an
21+
interpolated height ``z`` (in meters).
22+
23+
When ``grad`` is set to ``True``, the returned function takes two
24+
arguments ``x`` and ``y`` (both in metres) and returns a tuple ``(z, g)``,
25+
where ``z`` is an interpolated height (in meters) and ``g`` is a tuple
26+
``(dz/dx, dz/dy)``.
2027
2128
The height for each grid point ``(i, j)`` is assumed to be in the center of
2229
the square from ``(i, j)`` to ``(i + 1, j + 1)``.
@@ -26,13 +33,18 @@ class linear_interpolant(object):
2633
f = linear_interpolation()
2734
print(f(1000, 500))
2835
36+
f_grad = linear_interpolation(grad=True)
37+
z, (gx, gy) = f_grad(1000, 500)
38+
print(f"Height: {z:.2f} m, gradient: ({gx:.2f} m/m, {gy:.2f} m/m)")
39+
2940
"""
3041
# Note: This is technically a class, but used as a function here so
3142
# following the underscore naming convention.
3243

33-
def __init__(self):
44+
def __init__(self, grad=False):
3445
self._heights = nevis.gb()
3546
self._resolution = nevis.spacing()
47+
self.grad = grad
3648

3749
def __call__(self, x, y):
3850
ny, nx = self._heights.shape
@@ -65,7 +77,13 @@ def __call__(self, x, y):
6577
f2 = np.where(h12 == h22, h12, (x2 - x) * h12 + (x - x1) * h22)
6678

6779
# Final result
68-
return np.where(f1 == f2, f1, (y2 - y) * f1 + (y - y1) * f2)
80+
f = float(np.where(f1 == f2, f1, (y2 - y) * f1 + (y - y1) * f2))
81+
if self.grad:
82+
# Gradient of the interpolant
83+
g = (h21 - h11) / self._resolution, (h12 - h11) / self._resolution
84+
return f, g
85+
else:
86+
return f
6987

7088

7189
def spline(verbose=False):

0 commit comments

Comments
 (0)