Skip to content

Commit 418b105

Browse files
Removing default mesh in XGrid and UxGrid
This addresses #2408
1 parent 6f178dc commit 418b105

File tree

7 files changed

+35
-23
lines changed

7 files changed

+35
-23
lines changed

src/parcels/_core/fieldset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import xarray as xr
1111
import xgcm
1212

13-
from parcels._core.converters import Geographic, GeographicPolar
1413
from parcels._core.field import Field, VectorField
1514
from parcels._core.utils.string import _assert_str_and_python_varname
1615
from parcels._core.utils.time import get_datetime_type_calendar
@@ -277,15 +276,13 @@ def from_fesom2(ds: ux.UxDataset):
277276
raise ValueError(
278277
f"Dataset missing one of the required dimensions 'time', 'nz', or 'nz1'. Found dimensions {ds_dims}"
279278
)
280-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
279+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="spherical")
281280
ds = _discover_fesom2_U_and_V(ds)
282281

283282
fields = {}
284283
if "U" in ds.data_vars and "V" in ds.data_vars:
285284
fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
286285
fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
287-
fields["U"].units = GeographicPolar()
288-
fields["V"].units = Geographic()
289286

290287
if "W" in ds.data_vars:
291288
fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))

src/parcels/_core/uxgrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
1818
for interpolation on unstructured grids.
1919
"""
2020

21-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
21+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
2222
"""
2323
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2424
@@ -30,8 +30,8 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid
3030
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
3131
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
3232
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
33-
mesh : str, optional
34-
The type of mesh used for the grid. Either "flat" (default) or "spherical".
33+
mesh : str
34+
The type of mesh used for the grid. Either "flat" or "spherical".
3535
"""
3636
self.uxgrid = grid
3737
if not isinstance(z, ux.UxDataArray):

src/parcels/_core/xgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class XGrid(BaseGrid):
102102
103103
"""
104104

105-
def __init__(self, grid: xgcm.Grid, mesh="flat"):
105+
def __init__(self, grid: xgcm.Grid, mesh):
106106
self.xgcm_grid = grid
107107
self._mesh = mesh
108108
self._spatialhash = None

tests/test_field.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_field_init_param_types():
5555
UxGrid(
5656
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
5757
z=datasets_unstructured["stommel_gyre_delaunay"].coords["nz"],
58+
mesh="flat",
5859
),
5960
id="xarray-uxgrid",
6061
),
@@ -194,7 +195,7 @@ def test_field_unstructured_z_linear():
194195
for k, z in enumerate(ds.coords["nz"]):
195196
ds["W"].values[:, k, :] = z
196197

197-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
198+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
198199
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
199200
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)
200201

@@ -232,7 +233,7 @@ def test_field_unstructured_z_linear():
232233
def test_field_constant_in_time():
233234
"""Tests field evaluation for a field with no time interval (i.e., constant in time)."""
234235
ds = datasets_unstructured["stommel_gyre_delaunay"]
235-
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
236+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
236237
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
237238
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)
238239

tests/test_uxarray_fieldset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def uv_fesom_channel(ds_fesom_channel) -> VectorField:
3838
U=Field(
3939
name="U",
4040
data=ds_fesom_channel.U,
41-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
41+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
4242
interp_method=UXPiecewiseConstantFace,
4343
),
4444
V=Field(
4545
name="V",
4646
data=ds_fesom_channel.V,
47-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
47+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
4848
interp_method=UXPiecewiseConstantFace,
4949
),
5050
)
@@ -58,19 +58,19 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
5858
U=Field(
5959
name="U",
6060
data=ds_fesom_channel.U,
61-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
61+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
6262
interp_method=UXPiecewiseConstantFace,
6363
),
6464
V=Field(
6565
name="V",
6666
data=ds_fesom_channel.V,
67-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
67+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
6868
interp_method=UXPiecewiseConstantFace,
6969
),
7070
W=Field(
7171
name="W",
7272
data=ds_fesom_channel.W,
73-
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
73+
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
7474
interp_method=UXPiecewiseLinearNode,
7575
),
7676
)
@@ -112,13 +112,14 @@ def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
112112
Since the underlying data is constant, we can check that the values are as expected.
113113
"""
114114
ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"]
115+
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
115116
UVW = VectorField(
116117
name="UVW",
117-
U=Field(name="U", data=ds.U, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
118-
V=Field(name="V", data=ds.V, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
119-
W=Field(name="W", data=ds.W, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode),
118+
U=Field(name="U", data=ds.U, grid=grid, interp_method=UXPiecewiseConstantFace),
119+
V=Field(name="V", data=ds.V, grid=grid, interp_method=UXPiecewiseConstantFace),
120+
W=Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode),
120121
)
121-
P = Field(name="p", data=ds.p, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode)
122+
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseLinearNode)
122123
fieldset = FieldSet([UVW, P, UVW.U, UVW.V, UVW.W])
123124

124125
assert fieldset.U.eval(time=ds.time[0].values, z=[1.0], y=[30.0], x=[30.0], applyConversion=False) == 1.0

tests/test_uxgrid.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@
66

77
@pytest.mark.parametrize("uxds", [pytest.param(uxds, id=key) for key, uxds in uxdatasets.items()])
88
def test_uxgrid_init_on_generic_datasets(uxds):
9-
UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
9+
UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")
1010

1111

1212
@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
1313
def test_uxgrid_axes(uxds):
14-
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
14+
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")
1515
assert grid.axes == ["Z", "FACE"]
1616

1717

18+
@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
19+
@pytest.mark.parametrize("mesh", ["flat", "spherical"])
20+
def test_uxgrid_mesh(uxds, mesh):
21+
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh=mesh)
22+
assert grid._mesh == mesh
23+
24+
1825
@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
1926
def test_xgrid_get_axis_dim(uxds):
20-
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
27+
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")
2128

2229
assert grid.get_axis_dim("FACE") == 721
2330
assert grid.get_axis_dim("Z") == 2

tests/test_xgrid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,18 @@ def test_xgrid_from_dataset_on_generic_datasets(ds):
6565
XGrid.from_dataset(ds)
6666

6767

68-
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
6968
def test_xgrid_axes(ds):
7069
grid = XGrid.from_dataset(ds)
7170
assert grid.axes == ["Z", "Y", "X"]
7271

7372

73+
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
74+
@pytest.mark.parametrize("mesh", ["flat", "spherical"])
75+
def test_uxgrid_mesh(ds, mesh):
76+
grid = XGrid.from_dataset(ds, mesh=mesh)
77+
assert grid._mesh == mesh
78+
79+
7480
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
7581
def test_transpose_xfield_data_to_tzyx(ds):
7682
da = ds["data_g"]

0 commit comments

Comments
 (0)