diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py index 1df817a2a..20008c457 100644 --- a/parcels/application_kernels/interpolation.py +++ b/parcels/application_kernels/interpolation.py @@ -17,6 +17,7 @@ "UXPiecewiseConstantFace", "UXPiecewiseLinearNode", "XLinear", + "XNearest", "ZeroInterpolator", ] @@ -111,6 +112,69 @@ def XLinear( return value.compute() if isinstance(value, dask.Array) else value +def XNearest( + field: Field, + ti: int, + position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]], + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, +): + """ + Nearest-Neighbour spatial interpolation on a regular grid. + Note that this still uses linear interpolation in time. + """ + xi, xsi = position["X"] + yi, eta = position["Y"] + zi, zeta = position["Z"] + + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + data = field.data + + lenT = 2 if np.any(tau > 0) else 1 + + # Spatial coordinates: left if barycentric < 0.5, otherwise right + zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1) + zi_full = np.where(zeta < 0.5, zi, zi_1) + + yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1) + yi_full = np.where(eta < 0.5, yi, yi_1) + + xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1) + xi_full = np.where(xsi < 0.5, xi, xi_1) + + # Time coordinates: 1 point at ti, then 1 point at ti+1 + if lenT == 1: + ti_full = ti + else: + ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) + ti_full = np.concatenate([ti, ti_1]) + xi_full = np.repeat(xi_full, 2) + yi_full = np.repeat(yi_full, 2) + zi_full = np.repeat(zi_full, 2) + + # Create DataArrays for indexing + selection_dict = { + axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), + axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), + } + if "Z" in axis_dim: + selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) + if "time" in data.dims: + selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) + + corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi)) + + if lenT == 2: + value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau + else: + value = corner_data[0, :] + + return value.compute() if isinstance(value, dask.Array) else value + + def UXPiecewiseConstantFace( field: Field, ti: int, diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 6dd9ff691..67eab6420 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,6 +1,4 @@ -import numpy as np import pytest -import xarray as xr import parcels._interpolation as interpolation from tests.utils import create_fieldset_zeros_3d @@ -31,50 +29,6 @@ def some_function(): assert f() == g() == "test" -def create_interpolation_data(): - """Reference data used for testing interpolation. - - Most interpolation will be focussed around index - (depth, lat, lon) = (zi, yi, xi) = (1, 1, 1) with ti=0. - """ - z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous - [ - [0.0, 1.0, 2.0, 3.0], - [2.0, 3.0, 4.0, 5.0], - [4.0, 5.0, 6.0, 7.0], - [6.0, 7.0, 8.0, 9.0], - ] - ) - spatial_data = [z0, z0 + 3, z0 + 6, z0 + 9] # each z is +3 from the previous - return xr.DataArray([spatial_data, spatial_data, spatial_data], dims=("time", "depth", "lat", "lon")) - - -@pytest.fixture -def data_2d(): - """2D slice of the reference data at depth=0.""" - return create_interpolation_data().isel(depth=0).values - - -@pytest.mark.v4remove -@pytest.mark.xfail(reason="GH1946") -@pytest.mark.parametrize( - "func, eta, xsi, expected", - [ - pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"), - pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"), - pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"), - pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"), - pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"), - ], -) -def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected): - """Test the 2D interpolation functions on the raw arrays.""" - tau, ti = 0, 0 - yi, xi = 1, 1 - ctx = interpolation.InterpolationContext2D(data_2d, tau, eta, xsi, ti, yi, xi) - assert func(ctx) == expected - - @pytest.mark.v4remove @pytest.mark.xfail(reason="GH1946") @pytest.mark.usefixtures("tmp_interpolator_registry") diff --git a/tests/v4/test_interpolation.py b/tests/v4/test_interpolation.py index a8c52b7cf..d3ca4dea3 100644 --- a/tests/v4/test_interpolation.py +++ b/tests/v4/test_interpolation.py @@ -4,8 +4,9 @@ from parcels._datasets.structured.generated import simple_UV_dataset from parcels._datasets.unstructured.generic import datasets as datasets_unstructured +from parcels._index_search import _search_time_index from parcels.application_kernels.advection import AdvectionRK4_3D -from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear +from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator from parcels.field import Field, VectorField from parcels.fieldset import FieldSet from parcels.particle import Particle, Variable @@ -16,8 +17,70 @@ from tests.utils import TEST_DATA +@pytest.fixture +def field(): + """Reference data used for testing interpolation.""" + z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous + [ + [0.0, 1.0, 2.0, 3.0], + [2.0, 3.0, 4.0, 5.0], + [4.0, 5.0, 6.0, 7.0], + [6.0, 7.0, 8.0, 9.0], + ] + ) + spatial_data = np.array([z0, z0 + 3, z0 + 6, z0 + 9]) # each z is +3 from the previous + temporal_data = np.array([spatial_data, spatial_data + 10, spatial_data + 20]) # each t is +10 from the previous + + ds = xr.Dataset( + {"U": (["time", "depth", "lat", "lon"], temporal_data)}, + coords={ + "time": (["time"], [np.timedelta64(t, "s") for t in [0, 2, 4]], {"axis": "T"}), + "depth": (["depth"], [0, 1, 2, 3], {"axis": "Z"}), + "lat": (["lat"], [0, 1, 2, 3], {"axis": "Y", "c_grid_axis_shift": -0.5}), + "lon": (["lon"], [0, 1, 2, 3], {"axis": "X", "c_grid_axis_shift": -0.5}), + "x": (["x"], [0.5, 1.5, 2.5, 3.5], {"axis": "X"}), + "y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}), + }, + ) + return Field("U", ds["U"], XGrid.from_dataset(ds)) + + +@pytest.mark.parametrize( + "func, t, z, y, x, expected", + [ + pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"), + pytest.param( + XLinear, + [np.timedelta64(0, "s"), np.timedelta64(1, "s")], + [0, 0], + [0.49, 0.49], + [0.51, 0.51], + [1.49, 6.49], + id="Linear", + ), + pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"), + pytest.param( + XNearest, + [np.timedelta64(0, "s"), np.timedelta64(3, "s")], + [0.2, 0.2], + [0.2, 0.2], + [0.51, 0.51], + [1.0, 16.0], + id="Nearest", + ), + ], +) +def test_raw_2d_interpolation(field, func, t, z, y, x, expected): + """Test the interpolation functions on the Field.""" + tau, ti = _search_time_index(field, t) + position = field.grid.search(z, y, x) + + value = func(field, ti, position, tau, 0, 0, y, x) + np.testing.assert_equal(value, expected) + + @pytest.mark.parametrize("mesh", ["spherical", "flat"]) -def test_interpolation_mesh(mesh, npart=10): +def test_interpolation_mesh_type(mesh, npart=10): ds = simple_UV_dataset(mesh=mesh) ds["U"].data[:] = 1.0 grid = XGrid.from_dataset(ds, mesh=mesh)