Skip to content

Commit de62134

Browse files
Merge branch 'v4-dev' into interpolation_tutorial_v4
2 parents 302c59a + 6fc7724 commit de62134

File tree

6 files changed

+44
-11
lines changed

6 files changed

+44
-11
lines changed

src/parcels/_core/field.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from parcels._core.utils.string import _assert_str_and_python_varname
2222
from parcels._core.utils.time import TimeInterval
2323
from parcels._core.uxgrid import UxGrid
24-
from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx
24+
from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx, assert_all_field_dims_have_axis
2525
from parcels._python import assert_same_function_signature
2626
from parcels._reprs import default_repr
2727
from parcels._typing import VectorType
@@ -103,6 +103,7 @@ def __init__(
103103
_assert_compatible_combination(data, grid)
104104

105105
if isinstance(grid, XGrid):
106+
assert_all_field_dims_have_axis(data, grid.xgcm_grid)
106107
data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid)
107108

108109
self.name = name

src/parcels/_core/index_search.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,23 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray
102102
)
103103

104104
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
105+
# Map grid and particle longitudes to [-180,180)
106+
px = ((px + 180.0) % 360.0) - 180.0
107+
x = ((x + 180.0) % 360.0) - 180.0
108+
109+
# Create a mask for antimeridian cells
110+
lon_span = px.max(axis=0) - px.min(axis=0)
111+
antimeridian_cell = lon_span > 180.0
112+
113+
if np.any(antimeridian_cell):
114+
# For any antimeridian cell ...
115+
# If particle longitude is closer to 180.0, then negative cell longitudes need to be bumped by +360
116+
mask = (px < 0.0) & antimeridian_cell[np.newaxis, :] & (x[np.newaxis, :] > 0.0)
117+
px[mask] += 360.0
118+
# If particle longitude is closer to -180.0, then positive cell longitudes need to be bumped by -360
119+
mask = (px > 0.0) & antimeridian_cell[np.newaxis, :] & (x[np.newaxis, :] < 0.0)
120+
px[mask] -= 360.0
121+
105122
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
106123

107124
a, b = np.dot(invA, px), np.dot(invA, py)
@@ -119,7 +136,6 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray
119136
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
120137
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
121138
)
122-
123139
is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0)
124140

125141
return is_in_cell, np.column_stack((xsi, eta))

src/parcels/_core/xgrid.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def _drop_field_data(ds: xr.Dataset) -> xr.Dataset:
4848
return ds.drop_vars(ds.data_vars)
4949

5050

51+
def assert_all_field_dims_have_axis(da: xr.DataArray, xgcm_grid: xgcm.Grid) -> None:
52+
ax_dims = [(get_axis_from_dim_name(xgcm_grid.axes, dim), dim) for dim in da.dims]
53+
for dim in ax_dims:
54+
if dim[0] is None:
55+
raise ValueError(
56+
f'Dimension "{dim[1]}" has no axis attribute. '
57+
f'HINT: You may want to add an {{"axis": A}} to your DataSet["{dim[1]}"], where A is one of "X", "Y", "Z" or "T"'
58+
)
59+
60+
5161
def _transpose_xfield_data_to_tzyx(da: xr.DataArray, xgcm_grid: xgcm.Grid) -> xr.DataArray:
5262
"""
5363
Transpose a DataArray of any shape into a 4D array of order TZYX. Uses xgcm to determine

src/parcels/interpolators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ def CGrid_Velocity(
175175
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
176176

177177
if grid._mesh == "spherical":
178-
px[0] = np.where(px[0] < lon - 225, px[0] + 360, px[0])
179-
px[0] = np.where(px[0] > lon + 225, px[0] - 360, px[0])
178+
px = ((px + 180.0) % 360.0) - 180.0
180179
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
181180
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
182181
c1 = i_u._geodetic_distance(
@@ -291,7 +290,10 @@ def CGrid_Velocity(
291290

292291
# check whether the grid conversion has been applied correctly
293292
xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
294-
u = np.where(np.abs((xx - lon) / lon) > 1e-4, np.nan, u)
293+
dlon = xx - lon
294+
if grid._mesh == "spherical":
295+
dlon = ((dlon + 180.0) % 360.0) - 180.0
296+
u = np.where(np.abs(dlon / lon) > 1e-4, np.nan, u)
295297

296298
if vectorfield.W:
297299
data = vectorfield.W.data

tests/test_advection.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,20 +454,17 @@ def test_nemo_curvilinear_fieldset():
454454
U = parcels.Field("U", ds["U"], grid, interp_method=XLinear)
455455
V = parcels.Field("V", ds["V"], grid, interp_method=XLinear)
456456
U.units = parcels.GeographicPolar()
457-
V.units = parcels.GeographicPolar() # U and V need GoegraphicPolar for C-Grid interpolation to work correctly
457+
V.units = parcels.GeographicPolar() # U and V need GeographicPolar for C-Grid interpolation to work correctly
458458
UV = parcels.VectorField("UV", U, V, vector_interp_method=CGrid_Velocity)
459459
fieldset = parcels.FieldSet([U, V, UV])
460460

461461
npart = 20
462462
lonp = 30 * np.ones(npart)
463463
latp = np.linspace(-70, 88, npart)
464-
runtime = np.timedelta64(12, "h") # TODO increase to 160 days
465-
466-
def periodicBC(particles, fieldset): # pragma: no cover
467-
particles.dlon = np.where(particles.lon > 180, particles.dlon - 360, particles.dlon)
464+
runtime = np.timedelta64(160, "D")
468465

469466
pset = parcels.ParticleSet(fieldset, lon=lonp, lat=latp)
470-
pset.execute([AdvectionEE, periodicBC], runtime=runtime, dt=np.timedelta64(6, "h"))
467+
pset.execute(AdvectionEE, runtime=runtime, dt=np.timedelta64(6, "h"))
471468
np.testing.assert_allclose(pset.lat, latp, atol=1e-1)
472469

473470

tests/test_xgrid.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,13 @@ def test_invalid_depth():
136136
XGrid.from_dataset(ds)
137137

138138

139+
def test_dim_without_axis():
140+
ds = xr.Dataset({"z1d": (["depth"], [0])}, coords={"depth": [0]})
141+
grid = XGrid.from_dataset(ds)
142+
with pytest.raises(ValueError, match='Dimension "depth" has no axis attribute*'):
143+
Field("z1d", ds["z1d"], grid, XLinear)
144+
145+
139146
def test_vertical1D_field():
140147
nz = 11
141148
ds = xr.Dataset(

0 commit comments

Comments
 (0)