Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
_ei = None
else:
_ei = particles.ei[:, self.igrid]
z = np.atleast_1d(z)
y = np.atleast_1d(y)
x = np.atleast_1d(x)

particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei)

Expand Down Expand Up @@ -300,6 +303,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
_ei = None
else:
_ei = particles.ei[:, self.igrid]
z = np.atleast_1d(z)
y = np.atleast_1d(y)
x = np.atleast_1d(x)

particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei)

Expand Down
40 changes: 24 additions & 16 deletions src/parcels/_core/xgrid.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that in future search() can be refactored making it clearer how the grids are handled to make sure that we're covering edge cases.

Something for my (later) todo - no need to make into an issue

Original file line number Diff line number Diff line change
Expand Up @@ -289,22 +289,13 @@ def search(self, z, y, x, ei=None):
else:
zi, zeta = np.zeros(z.shape, dtype=int), np.zeros(z.shape, dtype=float)

if ds.lon.ndim == 1:
yi, eta = _search_1d_array(ds.lat.values, y)
xi, xsi = _search_1d_array(ds.lon.values, x)
return {
"Z": {"index": zi, "bcoord": zeta},
"Y": {"index": yi, "bcoord": eta},
"X": {"index": xi, "bcoord": xsi},
}
if "X" in self.axes and "Y" in self.axes and ds.lon.ndim == 2:
yi, xi = None, None
if ei is not None:
axis_indices = self.unravel_index(ei)
xi = axis_indices.get("X")
yi = axis_indices.get("Y")

yi, xi = None, None
if ei is not None:
axis_indices = self.unravel_index(ei)
xi = axis_indices.get("X")
yi = axis_indices.get("Y")

if ds.lon.ndim == 2:
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)

return {
Expand All @@ -313,7 +304,24 @@ def search(self, z, y, x, ei=None):
"X": {"index": xi, "bcoord": xsi},
}

raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
if "X" in self.axes and ds.lon.ndim > 2:
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")

if "Y" in self.axes:
yi, eta = _search_1d_array(ds.lat.values, y)
else:
yi, eta = np.zeros(y.shape, dtype=int), np.zeros(y.shape, dtype=float)

if "X" in self.axes:
xi, xsi = _search_1d_array(ds.lon.values, x)
else:
xi, xsi = np.zeros(x.shape, dtype=int), np.zeros(x.shape, dtype=float)

return {
"Z": {"index": zi, "bcoord": zeta},
"Y": {"index": yi, "bcoord": eta},
"X": {"index": xi, "bcoord": xsi},
}

@cached_property
def _fpoint_info(self):
Expand Down
9 changes: 5 additions & 4 deletions src/parcels/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ def _get_corner_data_Agrid(
xi = np.tile(np.array([xi, xi_1]).flatten(), lenT * lenZ * 2)

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
}
selection_dict = {}
if "X" in axis_dim:
selection_dict[axis_dim["X"]] = xr.DataArray(xi, dims=("points"))
if "Y" in axis_dim:
selection_dict[axis_dim["Y"]] = xr.DataArray(yi, dims=("points"))
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
if "time" in data.dims:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import xarray as xr
from numpy.testing import assert_allclose

from parcels import Field
from parcels._core.index_search import (
LEFT_OUT_OF_BOUNDS,
RIGHT_OUT_OF_BOUNDS,
Expand Down Expand Up @@ -134,6 +135,30 @@ def test_invalid_depth():
XGrid.from_dataset(ds)


def test_vertical1D_field():
nz = 11
ds = xr.Dataset(
{"z1d": (["depth"], np.linspace(0, 10, nz))},
coords={"depth": (["depth"], np.linspace(0, 1, nz), {"axis": "Z"})},
)
grid = XGrid.from_dataset(ds)
field = Field("z1d", ds["z1d"], grid)

assert field.eval(np.timedelta64(0, "s"), 0.45, 0, 0) == 4.5


def test_time1D_field():
timerange = xr.date_range("2000-01-01", "2000-01-20")
ds = xr.Dataset(
{"t1d": (["time"], np.arange(0, len(timerange)))},
coords={"time": (["time"], timerange, {"axis": "T"})},
)
grid = XGrid.from_dataset(ds)
field = Field("t1d", ds["t1d"], grid)

assert field.eval(np.datetime64("2000-01-10T12:00:00"), -20, 5, 6) == 9.5


@pytest.mark.parametrize(
"ds",
[
Expand Down
Loading