Skip to content

Commit 0504d40

Browse files
Merge pull request #2363 from Parcels-code/support_1d_fields
Support for 1D fields without lon and/or lat
2 parents 9dca8d5 + 3c6cc8e commit 0504d40

File tree

4 files changed

+61
-20
lines changed

4 files changed

+61
-20
lines changed

src/parcels/_core/field.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
206206
_ei = None
207207
else:
208208
_ei = particles.ei[:, self.igrid]
209+
z = np.atleast_1d(z)
210+
y = np.atleast_1d(y)
211+
x = np.atleast_1d(x)
209212

210213
particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei)
211214

@@ -289,6 +292,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
289292
_ei = None
290293
else:
291294
_ei = particles.ei[:, self.igrid]
295+
z = np.atleast_1d(z)
296+
y = np.atleast_1d(y)
297+
x = np.atleast_1d(x)
292298

293299
particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei)
294300

src/parcels/_core/xgrid.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -289,22 +289,13 @@ def search(self, z, y, x, ei=None):
289289
else:
290290
zi, zeta = np.zeros(z.shape, dtype=int), np.zeros(z.shape, dtype=float)
291291

292-
if ds.lon.ndim == 1:
293-
yi, eta = _search_1d_array(ds.lat.values, y)
294-
xi, xsi = _search_1d_array(ds.lon.values, x)
295-
return {
296-
"Z": {"index": zi, "bcoord": zeta},
297-
"Y": {"index": yi, "bcoord": eta},
298-
"X": {"index": xi, "bcoord": xsi},
299-
}
292+
if "X" in self.axes and "Y" in self.axes and ds.lon.ndim == 2:
293+
yi, xi = None, None
294+
if ei is not None:
295+
axis_indices = self.unravel_index(ei)
296+
xi = axis_indices.get("X")
297+
yi = axis_indices.get("Y")
300298

301-
yi, xi = None, None
302-
if ei is not None:
303-
axis_indices = self.unravel_index(ei)
304-
xi = axis_indices.get("X")
305-
yi = axis_indices.get("Y")
306-
307-
if ds.lon.ndim == 2:
308299
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)
309300

310301
return {
@@ -313,7 +304,24 @@ def search(self, z, y, x, ei=None):
313304
"X": {"index": xi, "bcoord": xsi},
314305
}
315306

316-
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
307+
if "X" in self.axes and ds.lon.ndim > 2:
308+
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
309+
310+
if "Y" in self.axes:
311+
yi, eta = _search_1d_array(ds.lat.values, y)
312+
else:
313+
yi, eta = np.zeros(y.shape, dtype=int), np.zeros(y.shape, dtype=float)
314+
315+
if "X" in self.axes:
316+
xi, xsi = _search_1d_array(ds.lon.values, x)
317+
else:
318+
xi, xsi = np.zeros(x.shape, dtype=int), np.zeros(x.shape, dtype=float)
319+
320+
return {
321+
"Z": {"index": zi, "bcoord": zeta},
322+
"Y": {"index": yi, "bcoord": eta},
323+
"X": {"index": xi, "bcoord": xsi},
324+
}
317325

318326
@cached_property
319327
def _fpoint_info(self):

src/parcels/interpolators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ def _get_corner_data_Agrid(
8484
xi = np.tile(np.array([xi, xi_1]).flatten(), lenT * lenZ * 2)
8585

8686
# Create DataArrays for indexing
87-
selection_dict = {
88-
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
89-
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
90-
}
87+
selection_dict = {}
88+
if "X" in axis_dim:
89+
selection_dict[axis_dim["X"]] = xr.DataArray(xi, dims=("points"))
90+
if "Y" in axis_dim:
91+
selection_dict[axis_dim["Y"]] = xr.DataArray(yi, dims=("points"))
9192
if "Z" in axis_dim:
9293
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
9394
if "time" in data.dims:

tests/test_xgrid.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import xarray as xr
77
from numpy.testing import assert_allclose
88

9+
from parcels import Field
910
from parcels._core.index_search import (
1011
LEFT_OUT_OF_BOUNDS,
1112
RIGHT_OUT_OF_BOUNDS,
@@ -16,6 +17,7 @@
1617
_transpose_xfield_data_to_tzyx,
1718
)
1819
from parcels._datasets.structured.generic import X, Y, Z, datasets
20+
from parcels.interpolators import XLinear
1921
from tests import utils
2022

2123
GridTestCase = namedtuple("GridTestCase", ["ds", "attr", "expected"])
@@ -134,6 +136,30 @@ def test_invalid_depth():
134136
XGrid.from_dataset(ds)
135137

136138

139+
def test_vertical1D_field():
140+
nz = 11
141+
ds = xr.Dataset(
142+
{"z1d": (["depth"], np.linspace(0, 10, nz))},
143+
coords={"depth": (["depth"], np.linspace(0, 1, nz), {"axis": "Z"})},
144+
)
145+
grid = XGrid.from_dataset(ds)
146+
field = Field("z1d", ds["z1d"], grid, XLinear)
147+
148+
assert field.eval(np.timedelta64(0, "s"), 0.45, 0, 0) == 4.5
149+
150+
151+
def test_time1D_field():
152+
timerange = xr.date_range("2000-01-01", "2000-01-20")
153+
ds = xr.Dataset(
154+
{"t1d": (["time"], np.arange(0, len(timerange)))},
155+
coords={"time": (["time"], timerange, {"axis": "T"})},
156+
)
157+
grid = XGrid.from_dataset(ds)
158+
field = Field("t1d", ds["t1d"], grid, XLinear)
159+
160+
assert field.eval(np.datetime64("2000-01-10T12:00:00"), -20, 5, 6) == 9.5
161+
162+
137163
@pytest.mark.parametrize(
138164
"ds",
139165
[

0 commit comments

Comments
 (0)