Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
"XLinear",
"XNearest",
"ZeroInterpolator",
]

Expand Down Expand Up @@ -111,6 +112,69 @@ def XLinear(
return value.compute() if isinstance(value, dask.Array) else value


def XNearest(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""
Nearest-Neighbour spatial interpolation on a regular grid.
Note that this still uses linear interpolation in time.
"""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data

lenT = 2 if np.any(tau > 0) else 1

# Spatial coordinates: left if barycentric < 0.5, otherwise right
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi_full = np.where(zeta < 0.5, zi, zi_1)

yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi_full = np.where(eta < 0.5, yi, yi_1)

xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi_full = np.where(xsi < 0.5, xi, xi_1)

# Time coordinates: 1 point at ti, then 1 point at ti+1
if lenT == 1:
ti_full = ti
else:
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti_full = np.concatenate([ti, ti_1])
xi_full = np.repeat(xi_full, 2)
yi_full = np.repeat(yi_full, 2)
zi_full = np.repeat(zi_full, 2)

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
}
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
if "time" in data.dims:
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))

corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi))

if lenT == 2:
value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau
else:
value = corner_data[0, :]

return value.compute() if isinstance(value, dask.Array) else value


def UXPiecewiseConstantFace(
field: Field,
ti: int,
Expand Down
46 changes: 0 additions & 46 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
import pytest
import xarray as xr

import parcels._interpolation as interpolation
from tests.utils import create_fieldset_zeros_3d
Expand Down Expand Up @@ -31,50 +29,6 @@ def some_function():
assert f() == g() == "test"


def create_interpolation_data():
"""Reference data used for testing interpolation.

Most interpolation will be focussed around index
(depth, lat, lon) = (zi, yi, xi) = (1, 1, 1) with ti=0.
"""
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
[
[0.0, 1.0, 2.0, 3.0],
[2.0, 3.0, 4.0, 5.0],
[4.0, 5.0, 6.0, 7.0],
[6.0, 7.0, 8.0, 9.0],
]
)
spatial_data = [z0, z0 + 3, z0 + 6, z0 + 9] # each z is +3 from the previous
return xr.DataArray([spatial_data, spatial_data, spatial_data], dims=("time", "depth", "lat", "lon"))


@pytest.fixture
def data_2d():
"""2D slice of the reference data at depth=0."""
return create_interpolation_data().isel(depth=0).values


@pytest.mark.v4remove
@pytest.mark.xfail(reason="GH1946")
@pytest.mark.parametrize(
"func, eta, xsi, expected",
[
pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"),
pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"),
pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"),
],
)
def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected):
"""Test the 2D interpolation functions on the raw arrays."""
tau, ti = 0, 0
yi, xi = 1, 1
ctx = interpolation.InterpolationContext2D(data_2d, tau, eta, xsi, ti, yi, xi)
assert func(ctx) == expected


@pytest.mark.v4remove
@pytest.mark.xfail(reason="GH1946")
@pytest.mark.usefixtures("tmp_interpolator_registry")
Expand Down
67 changes: 65 additions & 2 deletions tests/v4/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from parcels._datasets.structured.generated import simple_UV_dataset
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels._index_search import _search_time_index
from parcels.application_kernels.advection import AdvectionRK4_3D
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator
from parcels.field import Field, VectorField
from parcels.fieldset import FieldSet
from parcels.particle import Particle, Variable
Expand All @@ -16,8 +17,70 @@
from tests.utils import TEST_DATA


@pytest.fixture
def field():
"""Reference data used for testing interpolation."""
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
[
[0.0, 1.0, 2.0, 3.0],
[2.0, 3.0, 4.0, 5.0],
[4.0, 5.0, 6.0, 7.0],
[6.0, 7.0, 8.0, 9.0],
]
)
spatial_data = np.array([z0, z0 + 3, z0 + 6, z0 + 9]) # each z is +3 from the previous
temporal_data = np.array([spatial_data, spatial_data + 10, spatial_data + 20]) # each t is +10 from the previous

ds = xr.Dataset(
{"U": (["time", "depth", "lat", "lon"], temporal_data)},
coords={
"time": (["time"], [np.timedelta64(t, "s") for t in [0, 2, 4]], {"axis": "T"}),
"depth": (["depth"], [0, 1, 2, 3], {"axis": "Z"}),
"lat": (["lat"], [0, 1, 2, 3], {"axis": "Y", "c_grid_axis_shift": -0.5}),
"lon": (["lon"], [0, 1, 2, 3], {"axis": "X", "c_grid_axis_shift": -0.5}),
"x": (["x"], [0.5, 1.5, 2.5, 3.5], {"axis": "X"}),
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
},
)
return Field("U", ds["U"], XGrid.from_dataset(ds))


@pytest.mark.parametrize(
"func, t, z, y, x, expected",
[
pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"),
pytest.param(
XLinear,
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
[0, 0],
[0.49, 0.49],
[0.51, 0.51],
[1.49, 6.49],
id="Linear",
),
pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"),
pytest.param(
XNearest,
[np.timedelta64(0, "s"), np.timedelta64(3, "s")],
[0.2, 0.2],
[0.2, 0.2],
[0.51, 0.51],
[1.0, 16.0],
id="Nearest",
),
],
)
def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
"""Test the interpolation functions on the Field."""
tau, ti = _search_time_index(field, t)
position = field.grid.search(z, y, x)

value = func(field, ti, position, tau, 0, 0, y, x)
np.testing.assert_equal(value, expected)


@pytest.mark.parametrize("mesh", ["spherical", "flat"])
def test_interpolation_mesh(mesh, npart=10):
def test_interpolation_mesh_type(mesh, npart=10):
ds = simple_UV_dataset(mesh=mesh)
ds["U"].data[:] = 1.0
grid = XGrid.from_dataset(ds, mesh=mesh)
Expand Down
Loading