diff --git a/nevis/_interpolation.py b/nevis/_interpolation.py index 7311341..6d6cfcd 100644 --- a/nevis/_interpolation.py +++ b/nevis/_interpolation.py @@ -13,10 +13,17 @@ class linear_interpolant(object): """ - Returns a linear interpolation over the full GB data set. + Returns a linear interpolation and optionally its gradient over the full + GB data set. - The returned function takes two arguments ``x`` and ``y`` (both in metres) - and returns an interpolated height ``z`` (in meters). + When ``grad`` is set to ``False`` (by default), the returned function + takes two arguments ``x`` and ``y`` (both in metres) and returns an + interpolated height ``z`` (in meters). + + When ``grad`` is set to ``True``, the returned function takes two + arguments ``x`` and ``y`` (both in metres) and returns a tuple ``(z, g)``, + where ``z`` is an interpolated height (in meters) and ``g`` is a tuple + ``(dz/dx, dz/dy)``. The height for each grid point ``(i, j)`` is assumed to be in the center of the square from ``(i, j)`` to ``(i + 1, j + 1)``. @@ -26,13 +33,18 @@ class linear_interpolant(object): f = linear_interpolation() print(f(1000, 500)) + f_grad = linear_interpolation(grad=True) + z, (gx, gy) = f_grad(1000, 500) + print(f"Height: {z:.2f} m, gradient: ({gx:.2f} m/m, {gy:.2f} m/m)") + """ # Note: This is technically a class, but used as a function here so # following the underscore naming convention. - def __init__(self): + def __init__(self, grad=False): self._heights = nevis.gb() self._resolution = nevis.spacing() + self.grad = grad def __call__(self, x, y): ny, nx = self._heights.shape @@ -65,7 +77,13 @@ def __call__(self, x, y): f2 = np.where(h12 == h22, h12, (x2 - x) * h12 + (x - x1) * h22) # Final result - return np.where(f1 == f2, f1, (y2 - y) * f1 + (y - y1) * f2) + f = float(np.where(f1 == f2, f1, (y2 - y) * f1 + (y - y1) * f2)) + if self.grad: + # Gradient of the interpolant + g = (h21 - h11) / self._resolution, (h12 - h11) / self._resolution + return f, g + else: + return f def spline(verbose=False):