Skip to content

Commit 55fb783

Browse files
Merge branch 'v4-dev' into adding_unity_unitconverter
2 parents 1881031 + 053c5b9 commit 55fb783

File tree

12 files changed

+72
-59
lines changed

12 files changed

+72
-59
lines changed

src/parcels/_core/fieldset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import xarray as xr
1111
import xgcm
1212

13-
from parcels._core.converters import Geographic, GeographicPolar
1413
from parcels._core.field import Field, VectorField
1514
from parcels._core.utils.string import _assert_str_and_python_varname
1615
from parcels._core.utils.time import get_datetime_type_calendar
@@ -277,15 +276,13 @@ def from_fesom2(ds: ux.UxDataset):
277276
raise ValueError(
278277
f"Dataset missing one of the required dimensions 'time', 'nz', or 'nz1'. Found dimensions {ds_dims}"
279278
)
280-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
279+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="spherical")
281280
ds = _discover_fesom2_U_and_V(ds)
282281

283282
fields = {}
284283
if "U" in ds.data_vars and "V" in ds.data_vars:
285284
fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
286285
fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
287-
fields["U"].units = GeographicPolar()
288-
fields["V"].units = Geographic()
289286

290287
if "W" in ds.data_vars:
291288
fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))

src/parcels/_core/uxgrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
1818
for interpolation on unstructured grids.
1919
"""
2020

21-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
21+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
2222
"""
2323
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2424
@@ -30,8 +30,8 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid
3030
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
3131
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
3232
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
33-
mesh : str, optional
34-
The type of mesh used for the grid. Either "flat" (default) or "spherical".
33+
mesh : str
34+
The type of mesh used for the grid. Either "flat" or "spherical".
3535
"""
3636
self.uxgrid = grid
3737
if not isinstance(z, ux.UxDataArray):

src/parcels/_core/xgrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class XGrid(BaseGrid):
102102
103103
"""
104104

105-
def __init__(self, grid: xgcm.Grid, mesh="flat"):
105+
def __init__(self, grid: xgcm.Grid, mesh):
106106
self.xgcm_grid = grid
107107
self._mesh = mesh
108108
self._spatialhash = None
@@ -124,7 +124,7 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
124124
self._ds = ds
125125

126126
@classmethod
127-
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
127+
def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None):
128128
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
129129
if xgcm_kwargs is None:
130130
xgcm_kwargs = {}

tests/test_advection.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_advection_zonal_periodic():
8484
halo.XG.values = ds.XG.values[1] + 2
8585
ds = xr.concat([ds, halo], dim="XG")
8686

87-
grid = XGrid.from_dataset(ds)
87+
grid = XGrid.from_dataset(ds, mesh="flat")
8888
U = Field("U", ds["U"], grid, interp_method=XLinear)
8989
V = Field("V", ds["V"], grid, interp_method=XLinear)
9090
UV = VectorField("UV", U, V)
@@ -103,7 +103,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
103103
"""Flat 2D zonal flow that increases linearly with z from 0 m/s to 1 m/s."""
104104
ds = simple_UV_dataset(mesh="flat")
105105
ds["U"].data[:] = 1.0
106-
grid = XGrid.from_dataset(ds)
106+
grid = XGrid.from_dataset(ds, mesh="flat")
107107
U = Field("U", ds["U"], grid, interp_method=XLinear)
108108
U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
109109
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -121,7 +121,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
121121
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
122122
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
123123
ds = simple_UV_dataset(mesh="flat")
124-
grid = XGrid.from_dataset(ds)
124+
grid = XGrid.from_dataset(ds, mesh="flat")
125125
U = Field("U", ds["U"], grid, interp_method=XLinear)
126126
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
127127
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -202,7 +202,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
202202
if w:
203203
ds["W"] = (["time", "depth", "YG", "XG"], W)
204204

205-
grid = XGrid.from_dataset(ds)
205+
grid = XGrid.from_dataset(ds, mesh="flat")
206206
U = Field("U", ds["U"], grid, interp_method=XLinear)
207207
V = Field("V", ds["V"], grid, interp_method=XLinear)
208208
fields = [U, V, VectorField("UV", U, V)]
@@ -263,7 +263,7 @@ def test_radialrotation(npart=10):
263263
)
264264
def test_moving_eddy(kernel, rtol):
265265
ds = moving_eddy_dataset()
266-
grid = XGrid.from_dataset(ds)
266+
grid = XGrid.from_dataset(ds, mesh="flat")
267267
U = Field("U", ds["U"], grid, interp_method=XLinear)
268268
V = Field("V", ds["V"], grid, interp_method=XLinear)
269269
if kernel in [AdvectionRK2_3D, AdvectionRK4_3D]:
@@ -315,7 +315,7 @@ def truth_moving(x_0, y_0, t):
315315
)
316316
def test_decaying_moving_eddy(kernel, rtol):
317317
ds = decaying_moving_eddy_dataset()
318-
grid = XGrid.from_dataset(ds)
318+
grid = XGrid.from_dataset(ds, mesh="flat")
319319
U = Field("U", ds["U"], grid, interp_method=XLinear)
320320
V = Field("V", ds["V"], grid, interp_method=XLinear)
321321
UV = VectorField("UV", U, V)
@@ -363,7 +363,7 @@ def truth_moving(x_0, y_0, t):
363363
def test_stommelgyre_fieldset(kernel, rtol, grid_type):
364364
npart = 2
365365
ds = stommel_gyre_dataset(grid_type=grid_type)
366-
grid = XGrid.from_dataset(ds)
366+
grid = XGrid.from_dataset(ds, mesh="flat")
367367
vector_interp_method = None if grid_type == "A" else CGrid_Velocity
368368
U = Field("U", ds["U"], grid, interp_method=XLinear)
369369
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -404,7 +404,7 @@ def UpdateP(particles, fieldset): # pragma: no cover
404404
def test_peninsula_fieldset(kernel, rtol, grid_type):
405405
npart = 2
406406
ds = peninsula_dataset(grid_type=grid_type)
407-
grid = XGrid.from_dataset(ds)
407+
grid = XGrid.from_dataset(ds, mesh="flat")
408408
U = Field("U", ds["U"], grid, interp_method=XLinear)
409409
V = Field("V", ds["V"], grid, interp_method=XLinear)
410410
P = Field("P", ds["P"], grid, interp_method=XLinear)

tests/test_field.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def test_field_init_param_types():
1616
data = datasets_structured["ds_2d_left"]
17-
grid = XGrid.from_dataset(data)
17+
grid = XGrid.from_dataset(data, mesh="flat")
1818

1919
with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
2020
Field(name=123, data=data["data_g"], grid=grid, interp_method=XLinear)
@@ -47,14 +47,15 @@ def test_field_init_param_types():
4747
[
4848
pytest.param(
4949
ux.UxDataArray(),
50-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
50+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
5151
id="uxdata-grid",
5252
),
5353
pytest.param(
5454
xr.DataArray(),
5555
UxGrid(
5656
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
5757
z=datasets_unstructured["stommel_gyre_delaunay"].coords["nz"],
58+
mesh="flat",
5859
),
5960
id="xarray-uxgrid",
6061
),
@@ -75,7 +76,7 @@ def test_field_incompatible_combination(data, grid):
7576
[
7677
pytest.param(
7778
datasets_structured["ds_2d_left"]["data_g"],
78-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
79+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
7980
id="ds_2d_left",
8081
), # TODO: Perhaps this test should be expanded to cover more datasets?
8182
],
@@ -106,7 +107,7 @@ def test_field_init_fail_on_float_time_dim():
106107
)
107108

108109
data = ds["data_g"]
109-
grid = XGrid.from_dataset(ds)
110+
grid = XGrid.from_dataset(ds, mesh="flat")
110111
with pytest.raises(
111112
ValueError,
112113
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?",
@@ -124,7 +125,7 @@ def test_field_init_fail_on_float_time_dim():
124125
[
125126
pytest.param(
126127
datasets_structured["ds_2d_left"]["data_g"],
127-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
128+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
128129
id="ds_2d_left",
129130
),
130131
],
@@ -143,7 +144,7 @@ def test_vectorfield_init_different_time_intervals():
143144

144145
def test_field_invalid_interpolator():
145146
ds = datasets_structured["ds_2d_left"]
146-
grid = XGrid.from_dataset(ds)
147+
grid = XGrid.from_dataset(ds, mesh="flat")
147148

148149
def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
149150
return 0.0
@@ -160,7 +161,7 @@ def invalid_interpolator_wrong_signature(particle_positions, grid_positions, inv
160161

161162
def test_vectorfield_invalid_interpolator():
162163
ds = datasets_structured["ds_2d_left"]
163-
grid = XGrid.from_dataset(ds)
164+
grid = XGrid.from_dataset(ds, mesh="flat")
164165

165166
def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
166167
return 0.0
@@ -194,7 +195,7 @@ def test_field_unstructured_z_linear():
194195
for k, z in enumerate(ds.coords["nz"]):
195196
ds["W"].values[:, k, :] = z
196197

197-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
198+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
198199
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
199200
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)
200201

@@ -232,7 +233,7 @@ def test_field_unstructured_z_linear():
232233
def test_field_constant_in_time():
233234
"""Tests field evaluation for a field with no time interval (i.e., constant in time)."""
234235
ds = datasets_unstructured["stommel_gyre_delaunay"]
235-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
236+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
236237
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
237238
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)
238239

tests/test_index_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
@pytest.fixture
1414
def field_cone():
1515
ds = datasets["2d_left_unrolled_cone"]
16-
grid = XGrid.from_dataset(ds)
16+
grid = XGrid.from_dataset(ds, mesh="flat")
1717
field = Field(
1818
name="test_field",
1919
data=ds["data_g"],

tests/test_interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def field():
5252
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
5353
},
5454
)
55-
return Field("U", ds["U"], XGrid.from_dataset(ds), interp_method=XLinear)
55+
return Field("U", ds["U"], XGrid.from_dataset(ds, mesh="flat"), interp_method=XLinear)
5656

5757

5858
@pytest.mark.parametrize(

tests/test_particlefile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remove duplicates
3333
"""Fixture to create a FieldSet object for testing."""
3434
ds = datasets["ds_2d_left"]
35-
grid = XGrid.from_dataset(ds)
35+
grid = XGrid.from_dataset(ds, mesh="flat")
3636
U = Field("U", ds["U_A_grid"], grid, XLinear)
3737
V = Field("V", ds["V_A_grid"], grid, XLinear)
3838
UV = VectorField("UV", U, V)
@@ -73,7 +73,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset):
7373
def test_write_fieldset_without_time(tmp_zarrfile):
7474
ds = peninsula_dataset() # DataSet without time
7575
assert "time" not in ds.dims
76-
grid = XGrid.from_dataset(ds)
76+
grid = XGrid.from_dataset(ds, mesh="flat")
7777
fieldset = FieldSet([Field("U", ds["U"], grid, XLinear)])
7878

7979
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)

tests/test_spatialhash.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
def test_spatialhash_init():
88
ds = datasets["2d_left_rotated"]
9-
grid = XGrid.from_dataset(ds)
9+
grid = XGrid.from_dataset(ds, mesh="flat")
1010
spatialhash = grid.get_spatial_hash()
1111
assert spatialhash is not None
1212

1313

1414
def test_invalid_positions():
1515
ds = datasets["2d_left_rotated"]
16-
grid = XGrid.from_dataset(ds)
16+
grid = XGrid.from_dataset(ds, mesh="flat")
1717

1818
j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
1919
assert np.all(j == -3)
@@ -22,7 +22,7 @@ def test_invalid_positions():
2222

2323
def test_mixed_positions():
2424
ds = datasets["2d_left_rotated"]
25-
grid = XGrid.from_dataset(ds)
25+
grid = XGrid.from_dataset(ds, mesh="flat")
2626
lat = grid.lat.mean()
2727
lon = grid.lon.mean()
2828
y = [lat, np.nan]

tests/test_uxarray_fieldset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def uv_fesom_channel(ds_fesom_channel) -> VectorField:
3838
U=Field(
3939
name="U",
4040
data=ds_fesom_channel.U,
41-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
41+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
4242
interp_method=UXPiecewiseConstantFace,
4343
),
4444
V=Field(
4545
name="V",
4646
data=ds_fesom_channel.V,
47-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
47+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
4848
interp_method=UXPiecewiseConstantFace,
4949
),
5050
)
@@ -58,19 +58,19 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
5858
U=Field(
5959
name="U",
6060
data=ds_fesom_channel.U,
61-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
61+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
6262
interp_method=UXPiecewiseConstantFace,
6363
),
6464
V=Field(
6565
name="V",
6666
data=ds_fesom_channel.V,
67-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
67+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
6868
interp_method=UXPiecewiseConstantFace,
6969
),
7070
W=Field(
7171
name="W",
7272
data=ds_fesom_channel.W,
73-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
73+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
7474
interp_method=UXPiecewiseLinearNode,
7575
),
7676
)
@@ -112,13 +112,14 @@ def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
112112
Since the underlying data is constant, we can check that the values are as expected.
113113
"""
114114
ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"]
115+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
115116
UVW = VectorField(
116117
name="UVW",
117-
U=Field(name="U", data=ds.U, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
118-
V=Field(name="V", data=ds.V, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
119-
W=Field(name="W", data=ds.W, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode),
118+
U=Field(name="U", data=ds.U, grid=grid, interp_method=UXPiecewiseConstantFace),
119+
V=Field(name="V", data=ds.V, grid=grid, interp_method=UXPiecewiseConstantFace),
120+
W=Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode),
120121
)
121-
P = Field(name="p", data=ds.p, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode)
122+
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseLinearNode)
122123
fieldset = FieldSet([UVW, P, UVW.U, UVW.V, UVW.W])
123124

124125
assert fieldset.U.eval(time=ds.time[0].values, z=[1.0], y=[30.0], x=[30.0], applyConversion=False) == 1.0

0 commit comments

Comments
 (0)