diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py new file mode 100644 index 000000000..825649cf2 --- /dev/null +++ b/src/parcels/_core/utils/sgrid.py @@ -0,0 +1,380 @@ +""" +Provides helpers and utils for working with SGrid conventions, as well as data objects +useful for representing the SGRID metadata model in code. + +This code is best read alongside the SGrid conventions documentation: +https://sgrid.github.io/sgrid/ + +Note this code doesn't aim to completely cover the SGrid conventions, but aim to +cover SGrid to the extent to which Parcels is concerned. +""" + +from __future__ import annotations + +import enum +import re +from collections.abc import Hashable, Iterable +from dataclasses import dataclass +from typing import Any, Literal, Protocol, Self, overload + +import xarray as xr + +RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" + +Dim = str + + +class Padding(enum.Enum): + NONE = "none" + LOW = "low" + HIGH = "high" + BOTH = "both" + + +class SGridMetadataProtocol(Protocol): + def to_attrs(self) -> dict[str, str | int]: ... + def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... + + +class Grid2DMetadata(SGridMetadataProtocol): + def __init__( + self, + cf_role: Literal["grid_topology"], + topology_dimension: Literal[2], + node_dimensions: tuple[Dim, Dim], + face_dimensions: tuple[DimDimPadding, DimDimPadding], + vertical_dimensions: None | tuple[DimDimPadding] = None, + ): + if cf_role != "grid_topology": + raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}") + + if topology_dimension != 2: + raise ValueError("topology_dimension must be 2 for a 2D grid") + + if not ( + isinstance(node_dimensions, tuple) + and len(node_dimensions) == 2 + and all(isinstance(nd, str) for nd in node_dimensions) + ): + raise ValueError("node_dimensions must be a tuple of 2 dimensions for a 2D grid") + + if not ( + isinstance(face_dimensions, tuple) + and len(face_dimensions) == 2 + and all(isinstance(fd, DimDimPadding) for fd in face_dimensions) + ): + raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + + if vertical_dimensions is not None: + if not ( + isinstance(vertical_dimensions, tuple) + and len(vertical_dimensions) == 1 + and isinstance(vertical_dimensions[0], DimDimPadding) + ): + raise ValueError("vertical_dimensions must be a tuple of 1 DimDimPadding for a 2D grid") + + # Required attributes + self.cf_role = cf_role + self.topology_dimension = topology_dimension + self.node_dimensions = node_dimensions + self.face_dimensions = face_dimensions + + #! Optional attributes aren't really important to Parcels, can be added later if needed + # Optional attributes + # # With defaults (set in init) + # edge1_dimensions: tuple[Dim, DimDimPadding] + # edge2_dimensions: tuple[DimDimPadding, Dim] + + # # Without defaults + # node_coordinates: None | Any = None + # edge1_coordinates: None | Any = None + # edge2_coordinates: None | Any = None + # face_coordinate: None | Any = None + + #! Important optional attribute for 2D grids with vertical layering + self.vertical_dimensions = vertical_dimensions + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Grid2DMetadata): + return NotImplemented + return ( + self.cf_role == other.cf_role + and self.topology_dimension == other.topology_dimension + and self.node_dimensions == other.node_dimensions + and self.face_dimensions == other.face_dimensions + and self.vertical_dimensions == other.vertical_dimensions + ) + + @classmethod + def from_attrs(cls, attrs): + try: + return cls( + cf_role=attrs["cf_role"], + topology_dimension=attrs["topology_dimension"], + node_dimensions=load_mappings(attrs["node_dimensions"]), + face_dimensions=load_mappings(attrs["face_dimensions"]), + vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")), + ) + except Exception as e: + raise SGridParsingException(f"Failed to parse Grid2DMetadata from {attrs=!r}") from e + + def to_attrs(self) -> dict[str, str | int]: + d = dict( + cf_role=self.cf_role, + topology_dimension=self.topology_dimension, + node_dimensions=dump_mappings(self.node_dimensions), + face_dimensions=dump_mappings(self.face_dimensions), + ) + if self.vertical_dimensions is not None: + d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions) + return d + + +class Grid3DMetadata(SGridMetadataProtocol): + def __init__( + self, + cf_role: Literal["grid_topology"], + topology_dimension: Literal[3], + node_dimensions: tuple[Dim, Dim, Dim], + volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding], + ): + if cf_role != "grid_topology": + raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}") + + if topology_dimension != 3: + raise ValueError("topology_dimension must be 3 for a 3D grid") + + if not ( + isinstance(node_dimensions, tuple) + and len(node_dimensions) == 3 + and all(isinstance(nd, str) for nd in node_dimensions) + ): + raise ValueError("node_dimensions must be a tuple of 3 dimensions for a 3D grid") + + if not ( + isinstance(volume_dimensions, tuple) + and len(volume_dimensions) == 3 + and all(isinstance(fd, DimDimPadding) for fd in volume_dimensions) + ): + raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + + # Required attributes + self.cf_role = cf_role + self.topology_dimension = topology_dimension + self.node_dimensions = node_dimensions + self.volume_dimensions = volume_dimensions + + # ! Optional attributes aren't really important to Parcels, can be added later if needed + # Optional attributes + # # With defaults (set in init) + # edge1_dimensions: tuple[DimDimPadding, Dim, Dim] + # edge2_dimensions: tuple[Dim, DimDimPadding, Dim] + # edge3_dimensions: tuple[Dim, Dim, DimDimPadding] + # face1_dimensions: tuple[Dim, DimDimPadding, DimDimPadding] + # face2_dimensions: tuple[DimDimPadding, Dim, DimDimPadding] + # face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim] + + # # Without defaults + # node_coordinates + # edge *i_coordinates* + # face *i_coordinates* + # volume_coordinates + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Grid3DMetadata): + return NotImplemented + return ( + self.cf_role == other.cf_role + and self.topology_dimension == other.topology_dimension + and self.node_dimensions == other.node_dimensions + and self.volume_dimensions == other.volume_dimensions + ) + + @classmethod + def from_attrs(cls, attrs): + try: + return cls( + cf_role=attrs["cf_role"], + topology_dimension=attrs["topology_dimension"], + node_dimensions=load_mappings(attrs["node_dimensions"]), + volume_dimensions=load_mappings(attrs["volume_dimensions"]), + ) + except Exception as e: + raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e + + def to_attrs(self) -> dict[str, str | int]: + return dict( + cf_role=self.cf_role, + topology_dimension=self.topology_dimension, + node_dimensions=dump_mappings(self.node_dimensions), + volume_dimensions=dump_mappings(self.volume_dimensions), + ) + + +@dataclass +class DimDimPadding: + """A data class representing a dimension-dimension-padding triplet for SGrid metadata. + + This triplet can represent different relations depending on context within the standard + For example - for "face_dimensions" this can show the relation between an edge (dim1) and a node + (dim2). + """ + + dim1: str + dim2: str + padding: Padding + + def __repr__(self) -> str: + return f"DimDimPadding(dim1={self.dim1!r}, dim2={self.dim2!r}, padding={self.padding!r})" + + def __str__(self) -> str: + return f"{self.dim1}:{self.dim2} (padding:{self.padding.value})" + + @classmethod + def load(cls, s: str) -> Self: + match = re.match(RE_DIM_DIM_PADDING, s) + if not match: + raise ValueError(f"String {s!r} does not match expected format for DimDimPadding") + dim1 = match.group(1) + dim2 = match.group(2) + padding = Padding(match.group(3).lower()) + return cls(dim1, dim2, padding) + + +def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: + """Takes in a list of edge-node-padding tuples and serializes them into a string + according to the SGrid convention. + """ + ret = [] + for part in parts: + ret.append(str(part)) + return " ".join(ret) + + +@overload +def maybe_dump_mappings(parts: None) -> None: ... +@overload +def maybe_dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: ... + + +def maybe_dump_mappings(parts): + if parts is None: + return None + return dump_mappings(parts) + + +def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]: + """Takes in a string indicating the mappings of dims and dim-dim-padding + and returns a tuple with this data destructured. + + Treats `:` and `: ` equivalently (in line with the convention). + """ + if not isinstance(s, str): + raise ValueError(f"Expected string input, got {s!r} of type {type(s)}") + + s = s.replace(": ", ":") + ret = [] + while s: + # find next part + match = re.match(RE_DIM_DIM_PADDING, s) + if match and match.start() == 0: + # match found at start, take that as next part + part = match.group(0) + s_new = s[match.end() :].lstrip() + else: + # no DimDimPadding match at start, assume just a Dim until next space + part, *s_new = s.split(" ", 1) + s_new = "".join(s_new) + + assert s != s_new, f"SGrid parsing did not advance, stuck at {s!r}" + + parsed: DimDimPadding | Dim + try: + parsed = DimDimPadding.load(part) + except ValueError as e: + e.add_note(f"Failed to parse part {part!r} from {s!r} as a dimension dimension padding string") + try: + # Not a DimDimPadding, assume it's just a Dim + assert ":" not in part, f"Part {part!r} from {s!r} not a valid dim (contains ':')" + parsed = part + except AssertionError as e2: + raise e2 from e + + ret.append(parsed) + s = s_new + + return tuple(ret) + + +@overload +def maybe_load_mappings(s: None) -> None: ... +@overload +def maybe_load_mappings(s: Hashable) -> tuple[DimDimPadding | Dim, ...]: ... + + +def maybe_load_mappings(s): + if s is None: + return None + return load_mappings(s) + + +SGRID_PADDING_TO_XGCM_POSITION = { + Padding.LOW: "right", + Padding.HIGH: "left", + Padding.BOTH: "inner", + Padding.NONE: "outer", + # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves +} + + +class SGridParsingException(Exception): + """Exception raised when parsing SGrid attributes fails.""" + + pass + + +def parse_grid_attrs(attrs: dict[str, Hashable]) -> Grid2DMetadata | Grid3DMetadata: + grid: Grid2DMetadata | Grid3DMetadata + try: + grid = Grid2DMetadata.from_attrs(attrs) + except Exception as e: + e.add_note("Failed to parse as 2D SGrid, trying 3D SGrid") + try: + grid = Grid3DMetadata.from_attrs(attrs) + except Exception as e2: + e2.add_note("Failed to parse as 3D SGrid") + raise SGridParsingException("Failed to parse SGrid metadata as either 2D or 3D grid") from e2 + return grid + + +def get_grid_topology(ds: xr.Dataset) -> xr.DataArray | None: + """Extracts grid topology DataArray from an xarray Dataset.""" + for var_name in ds.variables: + if ds[var_name].attrs.get("cf_role") == "grid_topology": + return ds[var_name] + return None + + +def parse_sgrid(ds: xr.Dataset): + # Function similar to that provided in `xgcm.metadata_parsers. + # Might at some point be upstreamed to xgcm directly + try: + grid_topology = get_grid_topology(ds) + assert grid_topology is not None, "No grid_topology variable found in dataset" + grid = parse_grid_attrs(grid_topology.attrs) + + except Exception as e: + raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e + + if isinstance(grid, Grid2DMetadata): + dimensions = grid.face_dimensions + (grid.vertical_dimensions or ()) + else: + assert isinstance(grid, Grid3DMetadata) + dimensions = grid.volume_dimensions + + xgcm_coords = {} + for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False): + xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding] + xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1} + + return (ds, {"coords": xgcm_coords}) diff --git a/tests/strategies/sgrid.py b/tests/strategies/sgrid.py new file mode 100644 index 000000000..eed1d7583 --- /dev/null +++ b/tests/strategies/sgrid.py @@ -0,0 +1,90 @@ +"""Provides Hypothesis strategies to help testing the parsing and serialization of datasets +According to the SGrid conventions. + +This code is best read alongside the SGrid conventions documentation: +https://sgrid.github.io/sgrid/ + +Note this code doesn't aim to completely cover the SGrid conventions, but aim to +cover SGrid to the extent to which Parcels is concerned. +""" + +import xarray.testing.strategies as xr_st +from hypothesis import strategies as st + +from parcels._core.utils import sgrid + +padding = st.sampled_from(sgrid.Padding) +dimension_name = xr_st.names().filter( + lambda s: " " not in s +) # assuming for now spaces are allowed in dimension names in SGrid convention +dim_dim_padding = ( + st.tuples(dimension_name, dimension_name, padding) + .filter(lambda t: t[0] != t[1]) + .map(lambda t: sgrid.DimDimPadding(*t)) +) + +mappings = st.lists(dim_dim_padding | dimension_name).map(tuple) + + +@st.composite +def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: + N = 6 + names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True)) + node_dimension1 = names[0] + node_dimension2 = names[1] + face_dimension1 = names[2] + face_dimension2 = names[3] + padding_type1 = draw(padding) + padding_type2 = draw(padding) + + vertical_dimensions_dim1 = names[4] + vertical_dimensions_dim2 = names[5] + vertical_dimensions_padding = draw(padding) + has_vertical_dimensions = draw(st.booleans()) + + if has_vertical_dimensions: + vertical_dimensions = ( + sgrid.DimDimPadding(vertical_dimensions_dim1, vertical_dimensions_dim2, vertical_dimensions_padding), + ) + else: + vertical_dimensions = None + + return sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=(node_dimension1, node_dimension2), + face_dimensions=( + sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1), + sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), + ), + vertical_dimensions=vertical_dimensions, + ) + + +@st.composite +def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata: + N = 6 + names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True)) + node_dimension1 = names[0] + node_dimension2 = names[1] + node_dimension3 = names[2] + face_dimension1 = names[3] + face_dimension2 = names[4] + face_dimension3 = names[5] + padding_type1 = draw(padding) + padding_type2 = draw(padding) + padding_type3 = draw(padding) + + return sgrid.Grid3DMetadata( + cf_role="grid_topology", + topology_dimension=3, + node_dimensions=(node_dimension1, node_dimension2, node_dimension3), + volume_dimensions=( + sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1), + sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), + sgrid.DimDimPadding(face_dimension3, node_dimension3, padding_type3), + ), + ) + + +grid_metadata = grid2Dmetadata() | grid3Dmetadata() diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py new file mode 100644 index 000000000..c53d17512 --- /dev/null +++ b/tests/utils/test_sgrid.py @@ -0,0 +1,242 @@ +import numpy as np +import pytest +import xarray as xr +import xgcm +from hypothesis import assume, example, given + +from parcels._core.utils import sgrid +from tests.strategies import sgrid as sgrid_strategies + + +def get_unique_dim_names(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> set[str]: + dims = set() + dims.update(set(grid.node_dimensions)) + + for value in [ + grid.node_dimensions, + grid.face_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else grid.volume_dimensions, + grid.vertical_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else None, + ]: + if value is None: + continue + for item in value: + if isinstance(item, sgrid.DimDimPadding): + dims.add(item.dim1) + dims.add(item.dim2) + else: + assert isinstance(item, str) + dims.add(item) + return dims + + +def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset: + if isinstance(grid, sgrid.Grid2DMetadata): + return dummy_sgrid_2d_ds(grid) + elif isinstance(grid, sgrid.Grid3DMetadata): + return dummy_sgrid_3d_ds(grid) + else: + raise NotImplementedError(f"Cannot create dummy SGrid dataset for grid type {type(grid)}") + + +def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset: + ds = dummy_comodo_3d_ds() + + # Can't rename dimensions that already exist in the dataset + assume(get_unique_dim_names(grid) & set(ds.dims) == set()) + + renamings = {} + if grid.vertical_dimensions is None: + ds = ds.isel(ZC=0, ZG=0) + else: + renamings.update({"ZC": grid.vertical_dimensions[0].dim2, "ZG": grid.vertical_dimensions[0].dim1}) + + for old, new in zip(["XG", "YG"], grid.node_dimensions, strict=True): + renamings[old] = new + + for old, dim_dim_padding in zip(["XC", "YC"], grid.face_dimensions, strict=True): + renamings[old] = dim_dim_padding.dim1 + + ds = ds.rename_dims(renamings) + + ds["grid"] = xr.DataArray(1, attrs=grid.to_attrs()) + ds.attrs["convention"] = "SGRID" + return ds + + +def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset: + ds = dummy_comodo_3d_ds() + + # Can't rename dimensions that already exist in the dataset + assume(get_unique_dim_names(grid) & set(ds.dims) == set()) + + renamings = {} + for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True): + renamings[old] = new + + for old, dim_dim_padding in zip(["XC", "YC", "ZC"], grid.volume_dimensions, strict=True): + renamings[old] = dim_dim_padding.dim1 + + ds = ds.rename_dims(renamings) + + ds["grid"] = xr.DataArray(1, attrs=grid.to_attrs()) + ds.attrs["convention"] = "SGRID" + return ds + + +def dummy_comodo_3d_ds() -> xr.Dataset: + T, Z, Y, X = 7, 6, 5, 4 + TIME = xr.date_range("2000", "2001", T) + return xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)), + "U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)), + "V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)), + }, + coords={ + # "XG": ( + # ["XG"], + # 2 * np.pi / X * np.arange(0, X), + # {"axis": "X", "c_grid_axis_shift": -0.5}, + # ), + # "XC": (["XC"], 2 * np.pi / X * (np.arange(0, X) + 0.5), {"axis": "X"}), + # "YG": ( + # ["YG"], + # 2 * np.pi / (Y) * np.arange(0, Y), + # {"axis": "Y", "c_grid_axis_shift": -0.5}, + # ), + # "YC": ( + # ["YC"], + # 2 * np.pi / (Y) * (np.arange(0, Y) + 0.5), + # {"axis": "Y"}, + # ), + # "ZG": ( + # ["ZG"], + # np.arange(Z), + # {"axis": "Z", "c_grid_axis_shift": -0.5}, + # ), + # "ZC": ( + # ["ZC"], + # np.arange(Z) + 0.5, + # {"axis": "Z"}, + # ), + # "lon": (["XG"], 2 * np.pi / X * np.arange(0, X)), + # "lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)), + # "depth": (["ZG"], np.arange(Z)), + "time": (["time"], TIME, {"axis": "T"}), + }, + ) + + +@example( + edge_node_padding=( + sgrid.DimDimPadding("edge1", "node1", sgrid.Padding.NONE), + sgrid.DimDimPadding("edge2", "node2", sgrid.Padding.LOW), + ) +) +@given(sgrid_strategies.mappings) +def test_edge_node_mapping_metadata_roundtrip(edge_node_padding): + serialized = sgrid.dump_mappings(edge_node_padding) + parsed = sgrid.load_mappings(serialized) + assert parsed == edge_node_padding + + +@pytest.mark.parametrize( + "input_, expected", + [ + ( + "edge1: node1(padding: none)", + (sgrid.DimDimPadding("edge1", "node1", sgrid.Padding.NONE),), + ), + ], +) +def test_load_dump_mappings(input_, expected): + assert sgrid.load_mappings(input_) == expected + + +@example( + grid=sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + ), + vertical_dimensions=( + sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW), + ), + ) +) +@given(sgrid_strategies.grid2Dmetadata()) +def test_Grid2DMetadata_roundtrip(grid: sgrid.Grid2DMetadata): + attrs = grid.to_attrs() + parsed = sgrid.Grid2DMetadata.from_attrs(attrs) + assert parsed == grid + + +@example( + grid=sgrid.Grid3DMetadata( + cf_role="grid_topology", + topology_dimension=3, + node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), + volume_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), + ), + ) +) +@given(sgrid_strategies.grid3Dmetadata()) +def test_Grid3DMetadata_roundtrip(grid: sgrid.Grid3DMetadata): + attrs = grid.to_attrs() + parsed = sgrid.Grid3DMetadata.from_attrs(attrs) + assert parsed == grid + + +@given(sgrid_strategies.grid_metadata) +def test_parse_grid_attrs(grid: sgrid.SGridMetadataProtocol): + attrs = grid.to_attrs() + parsed = sgrid.parse_grid_attrs(attrs) + assert parsed == grid + + +@given(sgrid_strategies.grid2Dmetadata()) +def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata): + """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" + ds = dummy_sgrid_2d_ds(grid_metadata) + + ds, xgcm_kwargs = sgrid.parse_sgrid(ds) + grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) + + for ddp, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): + dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding + coords = grid.axes[axis].coords + assert coords["center"] == dim_edge + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + + if grid_metadata.vertical_dimensions is None: + assert "Z" not in grid.axes + else: + ddp = grid_metadata.vertical_dimensions[0] + dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding + coords = grid.axes["Z"].coords + assert coords["center"] == dim_edge + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + + +@given(sgrid_strategies.grid3Dmetadata()) +def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata): + """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" + ds = dummy_sgrid_3d_ds(grid_metadata) + + ds, xgcm_kwargs = sgrid.parse_sgrid(ds) + grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) + + for ddp, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): + dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding + coords = grid.axes[axis].coords + assert coords["center"] == dim_edge + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node