Skip to content

Commit 649925a

Browse files
committed
Add testing of SGRID xgcm ingestion helpers
1 parent 6a204d7 commit 649925a

File tree

3 files changed

+171
-11
lines changed

3 files changed

+171
-11
lines changed

src/parcels/_core/sgrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class SGridParsingException(Exception):
313313
pass
314314

315315

316-
def parse_grid(attrs: dict[str, Hashable]) -> Grid2DMetadata | Grid3DMetadata:
316+
def parse_grid_attrs(attrs: dict[str, Hashable]) -> Grid2DMetadata | Grid3DMetadata:
317317
grid: Grid2DMetadata | Grid3DMetadata
318318
try:
319319
grid = Grid2DMetadata.from_attrs(attrs)
@@ -341,7 +341,7 @@ def parse_sgrid(ds: xr.Dataset):
341341
try:
342342
grid_topology = get_grid_topology(ds)
343343
assert grid_topology is not None, "No grid_topology variable found in dataset"
344-
grid = parse_grid(grid_topology.attrs)
344+
grid = parse_grid_attrs(grid_topology.attrs)
345345

346346
except Exception as e:
347347
raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e

tests/strategies/sgrid.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,15 @@
77
cover SGrid to the extent to which Parcels is concerned.
88
"""
99

10-
import string
11-
10+
import xarray.testing.strategies as xr_st
1211
from hypothesis import strategies as st
1312

1413
from parcels._core import sgrid
1514

16-
ALLOWED_DIM_LETTERS = (
17-
string.ascii_letters + string.digits + "_"
18-
) # We can make this more aligned with SGrid by adjusting our regex - but this is good for now
19-
2015
padding = st.sampled_from(sgrid.Padding)
21-
dimension_name = st.text(
22-
min_size=1, alphabet=st.characters(categories=(), whitelist_characters=ALLOWED_DIM_LETTERS)
23-
).filter(lambda s: " " not in s) # assuming for now spaces are allowed in dimension names in SGrid convention
16+
dimension_name = xr_st.names().filter(
17+
lambda s: " " not in s
18+
) # assuming for now spaces are allowed in dimension names in SGrid convention
2419
dim_dim_padding = (
2520
st.tuples(dimension_name, dimension_name, padding)
2621
.filter(lambda t: t[0] != t[1])

tests/test_sgrid.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,129 @@
1+
import numpy as np
12
import pytest
3+
import xarray as xr
4+
import xgcm
25
from hypothesis import example, given
36

47
from parcels._core import sgrid
58
from tests.strategies import sgrid as sgrid_strategies
69

710

11+
def get_unique_dim_names(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> set[str]:
12+
dims = set()
13+
dims.update(set(grid.node_dimensions))
14+
15+
for value in [
16+
grid.node_dimensions,
17+
grid.face_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else grid.volume_dimensions,
18+
grid.vertical_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else None,
19+
]:
20+
if value is None:
21+
continue
22+
for item in value:
23+
if isinstance(item, sgrid.DimDimPadding):
24+
dims.add(item.dim1)
25+
dims.add(item.dim2)
26+
else:
27+
assert isinstance(item, str)
28+
dims.add(item)
29+
return dims
30+
31+
32+
def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset:
33+
if isinstance(grid, sgrid.Grid2DMetadata):
34+
return dummy_sgrid_2d_ds(grid)
35+
elif isinstance(grid, sgrid.Grid3DMetadata):
36+
return dummy_sgrid_3d_ds(grid)
37+
else:
38+
raise NotImplementedError(f"Cannot create dummy SGrid dataset for grid type {type(grid)}")
39+
40+
41+
def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset:
42+
ds = dummy_comodo_3d_ds()
43+
44+
renamings = {}
45+
if grid.vertical_dimensions is None:
46+
ds = ds.isel(ZC=0, ZG=0)
47+
else:
48+
renamings.update({"ZC": grid.vertical_dimensions[0].dim2, "ZG": grid.vertical_dimensions[0].dim1})
49+
50+
for old, new in zip(["XG", "YG"], grid.node_dimensions, strict=True):
51+
renamings[old] = new
52+
53+
for old, dim_dim_padding in zip(["XC", "YC"], grid.face_dimensions, strict=True):
54+
renamings[old] = dim_dim_padding.dim1
55+
56+
ds = ds.rename_dims(renamings)
57+
58+
ds["grid"] = xr.DataArray(1, attrs=grid.to_attrs())
59+
ds.attrs["convention"] = "SGRID"
60+
return ds
61+
62+
63+
def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset:
64+
ds = dummy_comodo_3d_ds()
65+
66+
renamings = {}
67+
for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True):
68+
renamings[old] = new
69+
70+
for old, dim_dim_padding in zip(["XC", "YC", "ZC"], grid.volume_dimensions, strict=True):
71+
renamings[old] = dim_dim_padding.dim1
72+
73+
ds = ds.rename_dims(renamings)
74+
75+
ds["grid"] = xr.DataArray(1, attrs=grid.to_attrs())
76+
ds.attrs["convention"] = "SGRID"
77+
return ds
78+
79+
80+
def dummy_comodo_3d_ds() -> xr.Dataset:
81+
T, Z, Y, X = 7, 6, 5, 4
82+
TIME = xr.date_range("2000", "2001", T)
83+
return xr.Dataset(
84+
{
85+
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
86+
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
87+
"U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
88+
"V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
89+
"U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
90+
"V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
91+
},
92+
coords={
93+
# "XG": (
94+
# ["XG"],
95+
# 2 * np.pi / X * np.arange(0, X),
96+
# {"axis": "X", "c_grid_axis_shift": -0.5},
97+
# ),
98+
# "XC": (["XC"], 2 * np.pi / X * (np.arange(0, X) + 0.5), {"axis": "X"}),
99+
# "YG": (
100+
# ["YG"],
101+
# 2 * np.pi / (Y) * np.arange(0, Y),
102+
# {"axis": "Y", "c_grid_axis_shift": -0.5},
103+
# ),
104+
# "YC": (
105+
# ["YC"],
106+
# 2 * np.pi / (Y) * (np.arange(0, Y) + 0.5),
107+
# {"axis": "Y"},
108+
# ),
109+
# "ZG": (
110+
# ["ZG"],
111+
# np.arange(Z),
112+
# {"axis": "Z", "c_grid_axis_shift": -0.5},
113+
# ),
114+
# "ZC": (
115+
# ["ZC"],
116+
# np.arange(Z) + 0.5,
117+
# {"axis": "Z"},
118+
# ),
119+
# "lon": (["XG"], 2 * np.pi / X * np.arange(0, X)),
120+
# "lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)),
121+
# "depth": (["ZG"], np.arange(Z)),
122+
"time": (["time"], TIME, {"axis": "T"}),
123+
},
124+
)
125+
126+
8127
@example(
9128
edge_node_padding=(
10129
sgrid.DimDimPadding("edge1", "node1", sgrid.Padding.NONE),
@@ -69,3 +188,49 @@ def test_Grid3DMetadata_roundtrip(grid: sgrid.Grid3DMetadata):
69188
attrs = grid.to_attrs()
70189
parsed = sgrid.Grid3DMetadata.from_attrs(attrs)
71190
assert parsed == grid
191+
192+
193+
@given(sgrid_strategies.grid_metadata)
194+
def test_parse_grid_attrs(grid: sgrid.SGridMetadataProtocol):
195+
attrs = grid.to_attrs()
196+
parsed = sgrid.parse_grid_attrs(attrs)
197+
assert parsed == grid
198+
199+
200+
@given(sgrid_strategies.grid2Dmetadata())
201+
def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
202+
"""Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided"""
203+
ds = dummy_sgrid_2d_ds(grid_metadata)
204+
205+
ds, xgcm_kwargs = sgrid.parse_sgrid(ds)
206+
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)
207+
208+
for ddp, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True):
209+
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
210+
coords = grid.axes[axis].coords
211+
assert coords["center"] == dim_edge
212+
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
213+
214+
if grid_metadata.vertical_dimensions is None:
215+
assert "Z" not in grid.axes
216+
else:
217+
ddp = grid_metadata.vertical_dimensions[0]
218+
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
219+
coords = grid.axes["Z"].coords
220+
assert coords["center"] == dim_edge
221+
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
222+
223+
224+
@given(sgrid_strategies.grid3Dmetadata())
225+
def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
226+
"""Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided"""
227+
ds = dummy_sgrid_3d_ds(grid_metadata)
228+
229+
ds, xgcm_kwargs = sgrid.parse_sgrid(ds)
230+
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)
231+
232+
for ddp, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True):
233+
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
234+
coords = grid.axes[axis].coords
235+
assert coords["center"] == dim_edge
236+
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node

0 commit comments

Comments
 (0)