Skip to content
145 changes: 120 additions & 25 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import xarray as xr

from parcels._python import repr_from_dunder_dict

RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"

Dim = str
Expand All @@ -31,12 +33,21 @@ class Padding(enum.Enum):
BOTH = "both"


class SGridMetadataProtocol(Protocol):
SGRID_PADDING_TO_XGCM_POSITION = {
Padding.LOW: "right",
Padding.HIGH: "left",
Padding.BOTH: "inner",
Padding.NONE: "outer",
# "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
}


class AttrsSerializable(Protocol):
def to_attrs(self) -> dict[str, str | int]: ...
def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...


class Grid2DMetadata(SGridMetadataProtocol):
class Grid2DMetadata(AttrsSerializable):
def __init__(
self,
cf_role: Literal["grid_topology"],
Expand Down Expand Up @@ -94,16 +105,13 @@ def __init__(
#! Important optional attribute for 2D grids with vertical layering
self.vertical_dimensions = vertical_dimensions

def __repr__(self) -> str:
return repr_from_dunder_dict(self)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Grid2DMetadata):
return NotImplemented
return (
self.cf_role == other.cf_role
and self.topology_dimension == other.topology_dimension
and self.node_dimensions == other.node_dimensions
and self.face_dimensions == other.face_dimensions
and self.vertical_dimensions == other.vertical_dimensions
)
return self.to_attrs() == other.to_attrs()

@classmethod
def from_attrs(cls, attrs):
Expand All @@ -129,8 +137,11 @@ def to_attrs(self) -> dict[str, str | int]:
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
return d

def rename_dims(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename_dims(self, dims_dict)


class Grid3DMetadata(SGridMetadataProtocol):
class Grid3DMetadata(AttrsSerializable):
def __init__(
self,
cf_role: Literal["grid_topology"],
Expand Down Expand Up @@ -180,15 +191,13 @@ def __init__(
# face *i_coordinates*
# volume_coordinates

def __repr__(self) -> str:
return repr_from_dunder_dict(self)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Grid3DMetadata):
return NotImplemented
return (
self.cf_role == other.cf_role
and self.topology_dimension == other.topology_dimension
and self.node_dimensions == other.node_dimensions
and self.volume_dimensions == other.volume_dimensions
)
return self.to_attrs() == other.to_attrs()

@classmethod
def from_attrs(cls, attrs):
Expand All @@ -210,6 +219,9 @@ def to_attrs(self) -> dict[str, str | int]:
volume_dimensions=dump_mappings(self.volume_dimensions),
)

def rename_dims(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename_dims(self, dims_dict)


@dataclass
class DimDimPadding:
Expand Down Expand Up @@ -318,15 +330,6 @@ def maybe_load_mappings(s):
return load_mappings(s)


SGRID_PADDING_TO_XGCM_POSITION = {
Padding.LOW: "right",
Padding.HIGH: "left",
Padding.BOTH: "inner",
Padding.NONE: "outer",
# "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
}


class SGridParsingException(Exception):
"""Exception raised when parsing SGrid attributes fails."""

Expand Down Expand Up @@ -378,3 +381,95 @@ def parse_sgrid(ds: xr.Dataset):
xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1}

return (ds, {"coords": xgcm_coords})


def rename_dims(ds: xr.Dataset, dims_dict: dict[str, str]) -> xr.Dataset:
grid_da = get_grid_topology(ds)
if grid_da is None:
raise ValueError(
"No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions."
)

ds = ds.rename_dims(dims_dict)

# Update the metadata
grid = parse_grid_attrs(grid_da.attrs)
ds[grid_da.name].attrs = grid.rename_dims(dims_dict).to_attrs()
return ds


def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
dims = set()
dims.update(set(grid.node_dimensions))

for key, value in grid.__dict__.items():
if key in ("cf_role", "topology_dimension") or value is None:
continue
assert isinstance(value, tuple), (
f"Expected sgrid metadata attribute to be represented as a tuple, got {value!r}. This is an internal error to Parcels - please post an issue if you encounter this."
)
for item in value:
if isinstance(item, DimDimPadding):
dims.add(item.dim1)
dims.add(item.dim2)
else:
assert isinstance(item, str)
dims.add(item)
return dims


@overload
def _metadata_rename_dims(grid: Grid2DMetadata, dims_dict: dict[str, str]) -> Grid2DMetadata: ...


@overload
def _metadata_rename_dims(grid: Grid3DMetadata, dims_dict: dict[str, str]) -> Grid3DMetadata: ...


def _metadata_rename_dims(grid, dims_dict):
"""
Renames dimensions in SGrid metadata.

Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping
of old dimension names to new dimension names.
"""
dims_dict = dims_dict.copy()
assert len(dims_dict) == len(set(dims_dict.values())), "dims_dict contains duplicate target dimension names"

existing_dims = get_unique_dim_names(grid)
for dim in dims_dict.keys():
if dim not in existing_dims:
raise ValueError(f"Dimension {dim!r} not found in SGrid metadata dimensions {existing_dims!r}")

for dim in existing_dims:
if dim not in dims_dict:
dims_dict[dim] = dim # identity mapping for dimensions not being renamed

kwargs = {}
for key, value in grid.__dict__.items():
if isinstance(value, tuple):
new_value = []
for item in value:
if isinstance(item, DimDimPadding):
new_item = DimDimPadding(
dim1=dims_dict[item.dim1],
dim2=dims_dict[item.dim2],
padding=item.padding,
)
new_value.append(new_item)
else:
assert isinstance(item, str)
new_value.append(dims_dict[item])
kwargs[key] = tuple(new_value)
continue

if key in ("cf_role", "topology_dimension") or value is None:
kwargs[key] = value
continue

if isinstance(value, str):
kwargs[key] = dims_dict[value]
continue

raise ValueError(f"Unexpected attribute {key!r} on {grid!r}")
return type(grid)(**kwargs)
73 changes: 73 additions & 0 deletions src/parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
import numpy as np
import xarray as xr

from parcels._core.utils.sgrid import (
DimDimPadding,
Grid2DMetadata,
Grid3DMetadata,
Padding,
)
from parcels._core.utils.sgrid import (
rename_dims as sgrid_rename_dims,
)

from . import T, X, Y, Z

__all__ = ["T", "X", "Y", "Z", "datasets"]

TIME = xr.date_range("2000", "2001", T)


def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
ds = ds.copy()
ds["grid"] = (
[],
0,
grid.to_attrs(),
)
ds.attrs["Conventions"] = "SGRID"
return ds


def _rotated_curvilinear_grid():
XG = np.arange(X)
YG = np.arange(Y)
Expand Down Expand Up @@ -225,3 +247,54 @@ def _unrolled_cone_curvilinear_grid():
),
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
}

_COMODO_TO_2D_SGRID = { # Note "2D SGRID" here is meant in the context of SGRID convention (i.e., 1D depth)
"XG": "node_dimension1",
"YG": "node_dimension2",
"XC": "face_dimension1",
"YC": "face_dimension2",
"ZG": "vertical_dimensions_dim1",
"ZC": "vertical_dimensions_dim2",
}
datasets_sgrid = {
"ds_2d_padded_high": (
datasets["ds_2d_left"]
.pipe(
_attach_sgrid_metadata,
Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("YG", "XG"),
face_dimensions=(
DimDimPadding("YC", "YG", Padding.HIGH),
DimDimPadding("XC", "XG", Padding.HIGH),
),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),),
),
)
.pipe(
sgrid_rename_dims,
_COMODO_TO_2D_SGRID,
)
),
"ds_2d_padded_low": (
datasets["ds_2d_right"]
.pipe(
_attach_sgrid_metadata,
Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("YG", "XG"),
face_dimensions=(
DimDimPadding("YC", "YG", Padding.LOW),
DimDimPadding("XC", "XG", Padding.LOW),
),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),),
),
)
.pipe(
sgrid_rename_dims,
_COMODO_TO_2D_SGRID,
)
),
}
10 changes: 8 additions & 2 deletions src/parcels/_python.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Generic Python helpers
import inspect
from collections.abc import Callable
from types import FunctionType


def isinstance_noimport(obj, class_or_tuple):
Expand All @@ -14,7 +14,13 @@ def isinstance_noimport(obj, class_or_tuple):
)


def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None:
def repr_from_dunder_dict(obj: object) -> str:
"""Dataclass-like __repr__ implementation based on __dict__."""
parts = [f"{k}={v!r}" for k, v in obj.__dict__.items()]
return f"{obj.__class__.__qualname__}(" + ", ".join(parts) + ")"


def assert_same_function_signature(f: FunctionType, *, ref: FunctionType, context: str) -> None:
"""Ensures a function `f` has the same signature as the reference function `ref`."""
sig_ref = inspect.signature(ref)
sig = inspect.signature(f)
Expand Down
Loading