diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py index 825649cf2..bd98962cc 100644 --- a/src/parcels/_core/utils/sgrid.py +++ b/src/parcels/_core/utils/sgrid.py @@ -19,6 +19,8 @@ import xarray as xr +from parcels._python import repr_from_dunder_dict + RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" Dim = str @@ -31,12 +33,21 @@ class Padding(enum.Enum): BOTH = "both" -class SGridMetadataProtocol(Protocol): +SGRID_PADDING_TO_XGCM_POSITION = { + Padding.LOW: "right", + Padding.HIGH: "left", + Padding.BOTH: "inner", + Padding.NONE: "outer", + # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves +} + + +class AttrsSerializable(Protocol): def to_attrs(self) -> dict[str, str | int]: ... def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... -class Grid2DMetadata(SGridMetadataProtocol): +class Grid2DMetadata(AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -94,16 +105,13 @@ def __init__( #! Important optional attribute for 2D grids with vertical layering self.vertical_dimensions = vertical_dimensions + def __repr__(self) -> str: + return repr_from_dunder_dict(self) + def __eq__(self, other: Any) -> bool: if not isinstance(other, Grid2DMetadata): return NotImplemented - return ( - self.cf_role == other.cf_role - and self.topology_dimension == other.topology_dimension - and self.node_dimensions == other.node_dimensions - and self.face_dimensions == other.face_dimensions - and self.vertical_dimensions == other.vertical_dimensions - ) + return self.to_attrs() == other.to_attrs() @classmethod def from_attrs(cls, attrs): @@ -129,8 +137,11 @@ def to_attrs(self) -> dict[str, str | int]: d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions) return d + def rename_dims(self, dims_dict: dict[str, str]) -> Self: + return _metadata_rename_dims(self, dims_dict) + -class Grid3DMetadata(SGridMetadataProtocol): +class Grid3DMetadata(AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -180,15 +191,13 @@ def __init__( # face *i_coordinates* # volume_coordinates + def __repr__(self) -> str: + return repr_from_dunder_dict(self) + def __eq__(self, other: Any) -> bool: if not isinstance(other, Grid3DMetadata): return NotImplemented - return ( - self.cf_role == other.cf_role - and self.topology_dimension == other.topology_dimension - and self.node_dimensions == other.node_dimensions - and self.volume_dimensions == other.volume_dimensions - ) + return self.to_attrs() == other.to_attrs() @classmethod def from_attrs(cls, attrs): @@ -210,6 +219,9 @@ def to_attrs(self) -> dict[str, str | int]: volume_dimensions=dump_mappings(self.volume_dimensions), ) + def rename_dims(self, dims_dict: dict[str, str]) -> Self: + return _metadata_rename_dims(self, dims_dict) + @dataclass class DimDimPadding: @@ -318,15 +330,6 @@ def maybe_load_mappings(s): return load_mappings(s) -SGRID_PADDING_TO_XGCM_POSITION = { - Padding.LOW: "right", - Padding.HIGH: "left", - Padding.BOTH: "inner", - Padding.NONE: "outer", - # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves -} - - class SGridParsingException(Exception): """Exception raised when parsing SGrid attributes fails.""" @@ -378,3 +381,95 @@ def parse_sgrid(ds: xr.Dataset): xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1} return (ds, {"coords": xgcm_coords}) + + +def rename_dims(ds: xr.Dataset, dims_dict: dict[str, str]) -> xr.Dataset: + grid_da = get_grid_topology(ds) + if grid_da is None: + raise ValueError( + "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions." + ) + + ds = ds.rename_dims(dims_dict) + + # Update the metadata + grid = parse_grid_attrs(grid_da.attrs) + ds[grid_da.name].attrs = grid.rename_dims(dims_dict).to_attrs() + return ds + + +def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]: + dims = set() + dims.update(set(grid.node_dimensions)) + + for key, value in grid.__dict__.items(): + if key in ("cf_role", "topology_dimension") or value is None: + continue + assert isinstance(value, tuple), ( + f"Expected sgrid metadata attribute to be represented as a tuple, got {value!r}. This is an internal error to Parcels - please post an issue if you encounter this." + ) + for item in value: + if isinstance(item, DimDimPadding): + dims.add(item.dim1) + dims.add(item.dim2) + else: + assert isinstance(item, str) + dims.add(item) + return dims + + +@overload +def _metadata_rename_dims(grid: Grid2DMetadata, dims_dict: dict[str, str]) -> Grid2DMetadata: ... + + +@overload +def _metadata_rename_dims(grid: Grid3DMetadata, dims_dict: dict[str, str]) -> Grid3DMetadata: ... + + +def _metadata_rename_dims(grid, dims_dict): + """ + Renames dimensions in SGrid metadata. + + Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping + of old dimension names to new dimension names. + """ + dims_dict = dims_dict.copy() + assert len(dims_dict) == len(set(dims_dict.values())), "dims_dict contains duplicate target dimension names" + + existing_dims = get_unique_dim_names(grid) + for dim in dims_dict.keys(): + if dim not in existing_dims: + raise ValueError(f"Dimension {dim!r} not found in SGrid metadata dimensions {existing_dims!r}") + + for dim in existing_dims: + if dim not in dims_dict: + dims_dict[dim] = dim # identity mapping for dimensions not being renamed + + kwargs = {} + for key, value in grid.__dict__.items(): + if isinstance(value, tuple): + new_value = [] + for item in value: + if isinstance(item, DimDimPadding): + new_item = DimDimPadding( + dim1=dims_dict[item.dim1], + dim2=dims_dict[item.dim2], + padding=item.padding, + ) + new_value.append(new_item) + else: + assert isinstance(item, str) + new_value.append(dims_dict[item]) + kwargs[key] = tuple(new_value) + continue + + if key in ("cf_role", "topology_dimension") or value is None: + kwargs[key] = value + continue + + if isinstance(value, str): + kwargs[key] = dims_dict[value] + continue + + raise ValueError(f"Unexpected attribute {key!r} on {grid!r}") + return type(grid)(**kwargs) diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index 7758cfe18..df97b0cf3 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -1,6 +1,16 @@ import numpy as np import xarray as xr +from parcels._core.utils.sgrid import ( + DimDimPadding, + Grid2DMetadata, + Grid3DMetadata, + Padding, +) +from parcels._core.utils.sgrid import ( + rename_dims as sgrid_rename_dims, +) + from . import T, X, Y, Z __all__ = ["T", "X", "Y", "Z", "datasets"] @@ -8,6 +18,18 @@ TIME = xr.date_range("2000", "2001", T) +def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata): + """Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute.""" + ds = ds.copy() + ds["grid"] = ( + [], + 0, + grid.to_attrs(), + ) + ds.attrs["Conventions"] = "SGRID" + return ds + + def _rotated_curvilinear_grid(): XG = np.arange(X) YG = np.arange(Y) @@ -225,3 +247,54 @@ def _unrolled_cone_curvilinear_grid(): ), "2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(), } + +_COMODO_TO_2D_SGRID = { # Note "2D SGRID" here is meant in the context of SGRID convention (i.e., 1D depth) + "XG": "node_dimension1", + "YG": "node_dimension2", + "XC": "face_dimension1", + "YC": "face_dimension2", + "ZG": "vertical_dimensions_dim1", + "ZC": "vertical_dimensions_dim2", +} +datasets_sgrid = { + "ds_2d_padded_high": ( + datasets["ds_2d_left"] + .pipe( + _attach_sgrid_metadata, + Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("YG", "XG"), + face_dimensions=( + DimDimPadding("YC", "YG", Padding.HIGH), + DimDimPadding("XC", "XG", Padding.HIGH), + ), + vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),), + ), + ) + .pipe( + sgrid_rename_dims, + _COMODO_TO_2D_SGRID, + ) + ), + "ds_2d_padded_low": ( + datasets["ds_2d_right"] + .pipe( + _attach_sgrid_metadata, + Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("YG", "XG"), + face_dimensions=( + DimDimPadding("YC", "YG", Padding.LOW), + DimDimPadding("XC", "XG", Padding.LOW), + ), + vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),), + ), + ) + .pipe( + sgrid_rename_dims, + _COMODO_TO_2D_SGRID, + ) + ), +} diff --git a/src/parcels/_python.py b/src/parcels/_python.py index a78e9bf0d..1cc7b4fd5 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,6 +1,6 @@ # Generic Python helpers import inspect -from collections.abc import Callable +from types import FunctionType def isinstance_noimport(obj, class_or_tuple): @@ -14,7 +14,13 @@ def isinstance_noimport(obj, class_or_tuple): ) -def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None: +def repr_from_dunder_dict(obj: object) -> str: + """Dataclass-like __repr__ implementation based on __dict__.""" + parts = [f"{k}={v!r}" for k, v in obj.__dict__.items()] + return f"{obj.__class__.__qualname__}(" + ", ".join(parts) + ")" + + +def assert_same_function_signature(f: FunctionType, *, ref: FunctionType, context: str) -> None: """Ensures a function `f` has the same signature as the reference function `ref`.""" sig_ref = inspect.signature(ref) sig = inspect.signature(f) diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py index c53d17512..094c1c6ca 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/utils/test_sgrid.py @@ -8,25 +8,20 @@ from tests.strategies import sgrid as sgrid_strategies -def get_unique_dim_names(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> set[str]: - dims = set() - dims.update(set(grid.node_dimensions)) - - for value in [ - grid.node_dimensions, - grid.face_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else grid.volume_dimensions, - grid.vertical_dimensions if isinstance(grid, sgrid.Grid2DMetadata) else None, - ]: - if value is None: - continue - for item in value: - if isinstance(item, sgrid.DimDimPadding): - dims.add(item.dim1) - dims.add(item.dim2) - else: - assert isinstance(item, str) - dims.add(item) - return dims +@pytest.fixture +def grid2dmetadata(): + return sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + ), + vertical_dimensions=( + sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW), + ), + ) def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset: @@ -42,7 +37,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(get_unique_dim_names(grid) & set(ds.dims) == set()) + assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set()) renamings = {} if grid.vertical_dimensions is None: @@ -67,7 +62,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(get_unique_dim_names(grid) & set(ds.dims) == set()) + assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set()) renamings = {} for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True): @@ -197,7 +192,7 @@ def test_Grid3DMetadata_roundtrip(grid: sgrid.Grid3DMetadata): @given(sgrid_strategies.grid_metadata) -def test_parse_grid_attrs(grid: sgrid.SGridMetadataProtocol): +def test_parse_grid_attrs(grid: sgrid.AttrsSerializable): attrs = grid.to_attrs() parsed = sgrid.parse_grid_attrs(attrs) assert parsed == grid @@ -240,3 +235,75 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata): coords = grid.axes[axis].coords assert coords["center"] == dim_edge assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + + +@pytest.mark.parametrize( + "grid", + [ + ( + sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + ), + vertical_dimensions=( + sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW), + ), + ) + ), + ( + sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + ), + vertical_dimensions=None, + ) + ), + ( + sgrid.Grid3DMetadata( + cf_role="grid_topology", + topology_dimension=3, + node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), + volume_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), + ), + ) + ), + ], +) +def test_rename_dims(grid): + dims = sgrid.get_unique_dim_names(grid) + dims_dict = {dim: f"new_{dim}" for dim in dims} + dims_dict_inv = {v: k for k, v in dims_dict.items()} + + grid_new = grid.rename_dims(dims_dict) + assert dims & set(sgrid.get_unique_dim_names(grid_new)) == set() + + assert grid == grid_new.rename_dims(dims_dict_inv) + + +def test_rename_dims_errors(grid2dmetadata): + # Test various error modes of rename_dims + grid = grid2dmetadata + # Non-unique target dimension names + dims_dict = { + "node_dimension1": "new_node_dimension", + "node_dimension2": "new_node_dimension", + } + with pytest.raises(AssertionError, match="dims_dict contains duplicate target dimension names"): + grid.rename_dims(dims_dict) + # Unexpected attribute in dims_dict + dims_dict = { + "unexpected_dimension": "new_unexpected_dimension", + } + with pytest.raises(ValueError, match="Dimension 'unexpected_dimension' not found in SGrid metadata dimensions"): + grid.rename_dims(dims_dict)