Skip to content

Commit 6442f06

Browse files
Add submodule parcels._core.sgrid and tests (#2418)
Co-authored-by: Erik van Sebille <[email protected]>
1 parent 6210166 commit 6442f06

File tree

3 files changed

+712
-0
lines changed

3 files changed

+712
-0
lines changed

src/parcels/_core/utils/sgrid.py

Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
"""
2+
Provides helpers and utils for working with SGrid conventions, as well as data objects
3+
useful for representing the SGRID metadata model in code.
4+
5+
This code is best read alongside the SGrid conventions documentation:
6+
https://sgrid.github.io/sgrid/
7+
8+
Note this code doesn't aim to completely cover the SGrid conventions, but aims to
9+
cover SGrid to the extent to which Parcels is concerned.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import enum
15+
import re
16+
from collections.abc import Hashable, Iterable
17+
from dataclasses import dataclass
18+
from typing import Any, Literal, Protocol, Self, overload
19+
20+
import xarray as xr
21+
22+
# Regex for a single "dim1:dim2 (padding:kind)" entry. The three groups capture
# dim1, dim2 and the padding keyword; whitespace is tolerated before "(" and
# after "padding:", but not after the first ":".
RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"
23+
24+
# Alias used in annotations for a plain dimension name.
Dim = str
25+
26+
27+
class Padding(enum.Enum):
    """Padding keywords that may appear in the ``(padding: ...)`` clause of an SGrid mapping.

    The member values are the lowercase strings used when parsing
    (``DimDimPadding.load``) and when serializing (``DimDimPadding.__str__``).
    """

    NONE = "none"
    LOW = "low"
    HIGH = "high"
    BOTH = "both"
32+
33+
34+
class SGridMetadataProtocol(Protocol):
35+
def to_attrs(self) -> dict[str, str | int]: ...
36+
def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...
37+
38+
39+
class Grid2DMetadata(SGridMetadataProtocol):
    """Metadata of a 2D SGrid grid topology, optionally with a vertical dimension.

    Only the attributes that matter to Parcels are modelled: ``cf_role``,
    ``topology_dimension``, ``node_dimensions``, ``face_dimensions`` and the
    optional ``vertical_dimensions``. Every argument is validated on
    construction; malformed input raises ``ValueError``.
    """

    def __init__(
        self,
        cf_role: Literal["grid_topology"],
        topology_dimension: Literal[2],
        node_dimensions: tuple[Dim, Dim],
        face_dimensions: tuple[DimDimPadding, DimDimPadding],
        vertical_dimensions: None | tuple[DimDimPadding] = None,
    ):
        if cf_role != "grid_topology":
            raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}")

        if topology_dimension != 2:
            raise ValueError("topology_dimension must be 2 for a 2D grid")

        # Exactly two plain dimension names.
        nodes_ok = (
            isinstance(node_dimensions, tuple)
            and len(node_dimensions) == 2
            and all(isinstance(name, str) for name in node_dimensions)
        )
        if not nodes_ok:
            raise ValueError("node_dimensions must be a tuple of 2 dimensions for a 2D grid")

        # Exactly two dim-dim-padding triplets.
        faces_ok = (
            isinstance(face_dimensions, tuple)
            and len(face_dimensions) == 2
            and all(isinstance(entry, DimDimPadding) for entry in face_dimensions)
        )
        if not faces_ok:
            raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")

        # Optional: a single dim-dim-padding triplet for the vertical layering.
        if vertical_dimensions is not None:
            vertical_ok = (
                isinstance(vertical_dimensions, tuple)
                and len(vertical_dimensions) == 1
                and isinstance(vertical_dimensions[0], DimDimPadding)
            )
            if not vertical_ok:
                raise ValueError("vertical_dimensions must be a tuple of 1 DimDimPadding for a 2D grid")

        # Required attributes
        self.cf_role = cf_role
        self.topology_dimension = topology_dimension
        self.node_dimensions = node_dimensions
        self.face_dimensions = face_dimensions

        # Important optional attribute for 2D grids with vertical layering
        self.vertical_dimensions = vertical_dimensions

        # The remaining optional SGrid attributes aren't really important to
        # Parcels and can be added later if needed:
        #   with defaults:    edge1_dimensions, edge2_dimensions
        #   without defaults: node_coordinates, edge1_coordinates,
        #                     edge2_coordinates, face_coordinate

    def __eq__(self, other: Any) -> bool:
        """Compare all stored attributes; defer to ``other`` for foreign types."""
        if isinstance(other, Grid2DMetadata):
            return (
                self.cf_role == other.cf_role
                and self.topology_dimension == other.topology_dimension
                and self.node_dimensions == other.node_dimensions
                and self.face_dimensions == other.face_dimensions
                and self.vertical_dimensions == other.vertical_dimensions
            )
        return NotImplemented

    @classmethod
    def from_attrs(cls, attrs):
        """Build an instance from a grid topology attribute dictionary.

        Any failure (missing key, unparsable mapping, failed validation) is
        re-raised as ``SGridParsingException`` with the original error chained.
        """
        try:
            kwargs = {
                "cf_role": attrs["cf_role"],
                "topology_dimension": attrs["topology_dimension"],
                "node_dimensions": load_mappings(attrs["node_dimensions"]),
                "face_dimensions": load_mappings(attrs["face_dimensions"]),
                "vertical_dimensions": maybe_load_mappings(attrs.get("vertical_dimensions")),
            }
            return cls(**kwargs)
        except Exception as e:
            raise SGridParsingException(f"Failed to parse Grid2DMetadata from {attrs=!r}") from e

    def to_attrs(self) -> dict[str, str | int]:
        """Serialize to an attribute dictionary; ``vertical_dimensions`` is omitted when unset."""
        attrs_out = {
            "cf_role": self.cf_role,
            "topology_dimension": self.topology_dimension,
            "node_dimensions": dump_mappings(self.node_dimensions),
            "face_dimensions": dump_mappings(self.face_dimensions),
        }
        if self.vertical_dimensions is not None:
            attrs_out["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
        return attrs_out
131+
132+
133+
class Grid3DMetadata(SGridMetadataProtocol):
    """Metadata of a 3D SGrid grid topology.

    Only the attributes that matter to Parcels are modelled: ``cf_role``,
    ``topology_dimension``, ``node_dimensions`` and ``volume_dimensions``.
    Every argument is validated on construction; malformed input raises
    ``ValueError``.
    """

    def __init__(
        self,
        cf_role: Literal["grid_topology"],
        topology_dimension: Literal[3],
        node_dimensions: tuple[Dim, Dim, Dim],
        volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding],
    ):
        if cf_role != "grid_topology":
            raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}")

        if topology_dimension != 3:
            raise ValueError("topology_dimension must be 3 for a 3D grid")

        if not (
            isinstance(node_dimensions, tuple)
            and len(node_dimensions) == 3
            and all(isinstance(nd, str) for nd in node_dimensions)
        ):
            raise ValueError("node_dimensions must be a tuple of 3 dimensions for a 3D grid")

        if not (
            isinstance(volume_dimensions, tuple)
            and len(volume_dimensions) == 3
            and all(isinstance(vd, DimDimPadding) for vd in volume_dimensions)
        ):
            # The message names the argument actually being validated here.
            raise ValueError("volume_dimensions must be a tuple of 3 DimDimPadding for a 3D grid")

        # Required attributes
        self.cf_role = cf_role
        self.topology_dimension = topology_dimension
        self.node_dimensions = node_dimensions
        self.volume_dimensions = volume_dimensions

        # ! Optional attributes aren't really important to Parcels, can be added later if needed
        # Optional attributes
        # # With defaults (set in init)
        # edge1_dimensions: tuple[DimDimPadding, Dim, Dim]
        # edge2_dimensions: tuple[Dim, DimDimPadding, Dim]
        # edge3_dimensions: tuple[Dim, Dim, DimDimPadding]
        # face1_dimensions: tuple[Dim, DimDimPadding, DimDimPadding]
        # face2_dimensions: tuple[DimDimPadding, Dim, DimDimPadding]
        # face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim]

        # # Without defaults
        # node_coordinates
        # edge *i_coordinates*
        # face *i_coordinates*
        # volume_coordinates

    def __eq__(self, other: Any) -> bool:
        """Compare all stored attributes; defer to ``other`` for foreign types."""
        if not isinstance(other, Grid3DMetadata):
            return NotImplemented
        return (
            self.cf_role == other.cf_role
            and self.topology_dimension == other.topology_dimension
            and self.node_dimensions == other.node_dimensions
            and self.volume_dimensions == other.volume_dimensions
        )

    @classmethod
    def from_attrs(cls, attrs):
        """Build an instance from a grid topology attribute dictionary.

        Any failure (missing key, unparsable mapping, failed validation) is
        re-raised as ``SGridParsingException`` with the original error chained.
        """
        try:
            return cls(
                cf_role=attrs["cf_role"],
                topology_dimension=attrs["topology_dimension"],
                node_dimensions=load_mappings(attrs["node_dimensions"]),
                volume_dimensions=load_mappings(attrs["volume_dimensions"]),
            )
        except Exception as e:
            raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e

    def to_attrs(self) -> dict[str, str | int]:
        """Serialize to an attribute dictionary using the SGrid mapping string format."""
        return dict(
            cf_role=self.cf_role,
            topology_dimension=self.topology_dimension,
            node_dimensions=dump_mappings(self.node_dimensions),
            volume_dimensions=dump_mappings(self.volume_dimensions),
        )
212+
213+
214+
@dataclass
215+
class DimDimPadding:
216+
"""A data class representing a dimension-dimension-padding triplet for SGrid metadata.
217+
218+
This triplet can represent different relations depending on context within the standard
219+
For example - for "face_dimensions" this can show the relation between an edge (dim1) and a node
220+
(dim2).
221+
"""
222+
223+
dim1: str
224+
dim2: str
225+
padding: Padding
226+
227+
def __repr__(self) -> str:
228+
return f"DimDimPadding(dim1={self.dim1!r}, dim2={self.dim2!r}, padding={self.padding!r})"
229+
230+
def __str__(self) -> str:
231+
return f"{self.dim1}:{self.dim2} (padding:{self.padding.value})"
232+
233+
@classmethod
234+
def load(cls, s: str) -> Self:
235+
match = re.match(RE_DIM_DIM_PADDING, s)
236+
if not match:
237+
raise ValueError(f"String {s!r} does not match expected format for DimDimPadding")
238+
dim1 = match.group(1)
239+
dim2 = match.group(2)
240+
padding = Padding(match.group(3).lower())
241+
return cls(dim1, dim2, padding)
242+
243+
244+
def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str:
    """Serialize dims and dim-dim-padding entries into one SGrid mapping string.

    Each entry is converted with ``str`` and the results are joined with
    single spaces.
    """
    return " ".join([str(entry) for entry in parts])
252+
253+
254+
@overload
def maybe_dump_mappings(parts: None) -> None: ...
@overload
def maybe_dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: ...


def maybe_dump_mappings(parts):
    """``None``-tolerant ``dump_mappings``: ``None`` is passed through unchanged."""
    return None if parts is None else dump_mappings(parts)
264+
265+
266+
def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]:
    """Takes in a string indicating the mappings of dims and dim-dim-padding
    and returns a tuple with this data destructured.

    The string is a whitespace-separated sequence of plain dimension names and
    ``dim1:dim2 (padding:kind)`` triplets. Triplets are returned as
    ``DimDimPadding`` objects and plain names as strings, in input order.

    Treats `:` and `: ` equivalently (in line with the convention). Leading,
    trailing and repeated whitespace between entries is ignored.

    Raises ``ValueError`` if ``s`` is not a string, and ``AssertionError`` if
    an entry contains ``:`` but is not a valid dim-dim-padding triplet.
    """
    if not isinstance(s, str):
        raise ValueError(f"Expected string input, got {s!r} of type {type(s)}")

    remaining = s.replace(": ", ":").strip()
    ret: list[DimDimPadding | Dim] = []
    while remaining:
        # ``remaining`` is never empty and never starts with whitespace here,
        # so each iteration consumes at least one character.
        match = re.match(RE_DIM_DIM_PADDING, remaining)
        if match:
            # Triplet at the start: take the whole match as the next part.
            part = match.group(0)
            rest = remaining[match.end() :]
        else:
            # No triplet at the start: assume a plain Dim up to the next whitespace.
            part, *tail = remaining.split(None, 1)
            rest = tail[0] if tail else ""

        parsed: DimDimPadding | Dim
        try:
            parsed = DimDimPadding.load(part)
        except ValueError as e:
            e.add_note(f"Failed to parse part {part!r} from {remaining!r} as a dimension dimension padding string")
            if ":" in part:
                # Raised explicitly (not via ``assert``) so the check survives
                # ``python -O``; the exception type is kept for existing callers.
                raise AssertionError(f"Part {part!r} from {remaining!r} not a valid dim (contains ':')") from e
            # Not a DimDimPadding, so it's just a Dim.
            parsed = part

        ret.append(parsed)
        remaining = rest.lstrip()

    return tuple(ret)
307+
308+
309+
@overload
def maybe_load_mappings(s: None) -> None: ...
@overload
def maybe_load_mappings(s: Hashable) -> tuple[DimDimPadding | Dim, ...]: ...


def maybe_load_mappings(s):
    """``None``-tolerant ``load_mappings``: ``None`` is passed through unchanged."""
    return None if s is None else load_mappings(s)
319+
320+
321+
# Translates an SGrid padding keyword into the position name that ``parse_sgrid``
# uses as the key for ``dim1`` in each axis' coordinate dictionary.
SGRID_PADDING_TO_XGCM_POSITION = {
    Padding.LOW: "right",
    Padding.HIGH: "left",
    Padding.BOTH: "inner",
    Padding.NONE: "outer",
    # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
}
328+
329+
330+
class SGridParsingException(Exception):
    """Raised when SGrid attributes cannot be parsed into a metadata object."""
334+
335+
336+
def parse_grid_attrs(attrs: dict[str, Hashable]) -> Grid2DMetadata | Grid3DMetadata:
    """Parse grid topology attributes as a 2D grid, falling back to a 3D grid.

    The 2D interpretation is attempted first. If it fails, the 3D
    interpretation is attempted. If that fails too, ``SGridParsingException``
    is raised, chained from the 3D failure.
    """
    try:
        return Grid2DMetadata.from_attrs(attrs)
    except Exception as err_2d:
        err_2d.add_note("Failed to parse as 2D SGrid, trying 3D SGrid")
        # The fallback stays inside this handler so the 2D failure remains
        # part of the exception context of anything raised below.
        try:
            return Grid3DMetadata.from_attrs(attrs)
        except Exception as err_3d:
            err_3d.add_note("Failed to parse as 3D SGrid")
            raise SGridParsingException("Failed to parse SGrid metadata as either 2D or 3D grid") from err_3d
348+
349+
350+
def get_grid_topology(ds: xr.Dataset) -> xr.DataArray | None:
    """Return the first variable of ``ds`` whose ``cf_role`` attribute is ``"grid_topology"``.

    Variables are inspected in the iteration order of ``ds.variables``;
    ``None`` is returned when none of them carries that attribute value.
    """
    candidates = (ds[name] for name in ds.variables)
    return next((var for var in candidates if var.attrs.get("cf_role") == "grid_topology"), None)
356+
357+
358+
def parse_sgrid(ds: xr.Dataset):
    """Derive per-axis coordinate metadata from the SGrid topology stored in ``ds``.

    Function similar to that provided in ``xgcm.metadata_parsers``.
    Might at some point be upstreamed to xgcm directly.

    The grid topology variable is located and parsed. Its face dimensions
    (plus the vertical dimension, if any) for a 2D grid, or its volume
    dimensions for a 3D grid, are then assigned in order to the axes
    ``"X"``, ``"Y"`` and ``"Z"``.

    Returns a tuple ``(ds, {"coords": coords})`` where ``coords[axis]`` maps
    ``"center"`` to the entry's ``dim2`` and the position looked up in
    ``SGRID_PADDING_TO_XGCM_POSITION`` to the entry's ``dim1``.

    Raises ``SGridParsingException`` if no grid topology variable is found or
    its attributes cannot be parsed.
    """
    # Bound before the ``try`` so the error message below can always reference
    # it, even when locating the topology variable is what failed.
    grid_topology = None
    try:
        grid_topology = get_grid_topology(ds)
        if grid_topology is None:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise ValueError("No grid_topology variable found in dataset")
        grid = parse_grid_attrs(grid_topology.attrs)

    except Exception as e:
        raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e

    if isinstance(grid, Grid2DMetadata):
        dimensions = grid.face_dimensions + (grid.vertical_dimensions or ())
    else:
        assert isinstance(grid, Grid3DMetadata)  # type narrowing only
        dimensions = grid.volume_dimensions

    # NOTE(review): dim2 is registered as "center" and dim1 at the padded
    # position - confirm this orientation is what the consumer expects.
    xgcm_coords = {}
    for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False):
        xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding]
        xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1}

    return (ds, {"coords": xgcm_coords})

0 commit comments

Comments
 (0)