Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import xgcm

from parcels._core.field import Field, VectorField
from parcels._core.utils import sgrid
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import get_datetime_type_calendar
from parcels._core.utils.time import is_compatible as datetime_is_compatible
Expand Down Expand Up @@ -295,6 +296,92 @@ def from_fesom2(ds: ux.UxDataset):

return FieldSet(list(fields.values()))

def from_sgrid_conventions(
ds: xr.Dataset, mesh: Mesh
): # TODO: Update mesh to be discovered from the dataset metadata
"""Create a FieldSet from a dataset using SGRID convention metadata.

This is the primary ingestion method in Parcels for structured grid datasets.

Assumes that U, V, (and optionally W) variables are named 'U', 'V', and 'W' in the dataset.

Parameters
----------
ds : xarray.Dataset
xarray.Dataset with SGRID convention metadata.
mesh : str
String indicating the type of mesh coordinates and units used during
velocity interpolation. Options are "spherical" or "flat".

Returns
-------
FieldSet
FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.

Notes
-----
This method uses the SGRID convention metadata to parse the grid structure
and create appropriate Fields for a Parcels simulation. The dataset should
contain a variable with 'cf_role' attribute set to 'grid_topology'.

See https://sgrid.github.io/ for more information on the SGRID conventions.
"""
ds = ds.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to make a copy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a shallow copy - from_fesom2 and from_copernicusmarine do them both to avoid changes to the dataset propogating to the original dataset


# Ensure time dimension has axis attribute if present
if "time" in ds.dims and "time" in ds.coords:
if "axis" not in ds["time"].attrs:
logger.debug(
"Dataset contains 'time' dimension but no 'axis' attribute. Setting 'axis' attribute to 'T'."
)
ds["time"].attrs["axis"] = "T"

# Find time dimension based on axis attribute and rename to `time`
if (time_dims := ds.cf.axes.get("T")) is not None:
if len(time_dims) > 1:
raise ValueError("Multiple time coordinates found in dataset. This is not supported by Parcels.")
(time_dim,) = time_dims
if time_dim != "time":
logger.debug(f"Renaming time axis coordinate from {time_dim} to 'time'.")
ds = ds.rename({time_dim: "time"})

# Parse SGRID metadata and get xgcm kwargs
_, xgcm_kwargs = sgrid.parse_sgrid(ds)

# Add time axis to xgcm_kwargs if present
if "time" in ds.dims:
if "T" not in xgcm_kwargs["coords"]:
xgcm_kwargs["coords"]["T"] = {"center": "time"}

# Create xgcm Grid object
xgcm_grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs, **_DEFAULT_XGCM_KWARGS)

# Wrap in XGrid
grid = XGrid(xgcm_grid, mesh=mesh)

# Create fields from data variables, skipping grid metadata variables
# Skip variables that are SGRID metadata (have cf_role='grid_topology')
skip_vars = set()
for var in ds.data_vars:
if ds[var].attrs.get("cf_role") == "grid_topology":
skip_vars.add(var)

fields = {}
if "U" in ds.data_vars and "V" in ds.data_vars:
fields["U"] = Field("U", ds["U"], grid, XLinear)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this interpolator may not be the default for nemo; so how would from_nemo overwrite the the interpolator here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be the default no? Since XLinear would be aware of the grid positioning?

Otherwise we would need to expose this as a param to the function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, XLinear is only the default for an A-grid layout. For a C-Grid layout, the velocities should be CGrid_Velocity and any other field should be CGrid_Tracer. SO it depends on the Grid topology

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in person. Outcome: I think we can find where the variables are on the grid via the sgrid conventions, and then use that to choose the interpolator. We'll merge for now, I can rework from_copernicusmarine to use this new code path, and we can explore in future how we can specify this in SGRID and adapt our interpolators

fields["V"] = Field("V", ds["V"], grid, XLinear)

if "W" in ds.data_vars:
fields["W"] = Field("W", ds["W"], grid, XLinear)
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
else:
fields["UV"] = VectorField("UV", fields["U"], fields["V"])

for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars:
fields[varname] = Field(varname, ds[varname], grid, XLinear)

return FieldSet(list(fields.values()))


class CalendarError(Exception): # TODO: Move to a parcels errors module
"""Exception raised when the calendar of a field is not compatible with the rest of the Fields. The user should ensure that they only add fields to a FieldSet that have compatible CFtime calendars."""
Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def parse_sgrid(ds: xr.Dataset):
xgcm_coords = {}
for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False):
xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding]
xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1}
xgcm_coords[axis] = {"center": dim_dim_padding.dim1, xgcm_position: dim_dim_padding.dim2}

return (ds, {"coords": xgcm_coords})

Expand Down
8 changes: 4 additions & 4 deletions src/parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ def _unrolled_cone_curvilinear_grid():
Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("YG", "XG"),
node_dimensions=("XG", "YG"),
face_dimensions=(
DimDimPadding("YC", "YG", Padding.HIGH),
DimDimPadding("XC", "XG", Padding.HIGH),
DimDimPadding("YC", "YG", Padding.HIGH),
),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),),
),
Expand All @@ -284,10 +284,10 @@ def _unrolled_cone_curvilinear_grid():
Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("YG", "XG"),
node_dimensions=("XG", "YG"),
face_dimensions=(
DimDimPadding("YC", "YG", Padding.LOW),
DimDimPadding("XC", "XG", Padding.LOW),
DimDimPadding("YC", "YG", Padding.LOW),
),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),),
),
Expand Down
9 changes: 9 additions & 0 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from parcels._datasets.structured.circulation_models import datasets as datasets_circulation_models
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.structured.generic import datasets_sgrid
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels.interpolators import XLinear
from tests import utils
Expand Down Expand Up @@ -342,3 +343,11 @@ def test_fieldset_from_fesom2_missingUV():
with pytest.raises(ValueError) as info:
_ = FieldSet.from_fesom2(localds)
assert "Dataset has only one of the two variables 'U' and 'V'" in str(info)


@pytest.mark.parametrize("ds_name", list(datasets_sgrid.keys()))
def test_fieldset_from_sgrid_conventions(ds_name):
ds = datasets_sgrid[ds_name]
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
assert isinstance(fieldset, FieldSet)
assert len(fieldset.fields) > 0
73 changes: 29 additions & 44 deletions tests/utils/test_sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,29 @@
from parcels._core.utils import sgrid
from tests.strategies import sgrid as sgrid_strategies

grid2dmetadata = sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=(
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
),
)

@pytest.fixture
def grid2dmetadata():
return sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=(
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
),
)
grid3dmetadata = sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
volume_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
),
)


def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset:
Expand Down Expand Up @@ -151,39 +159,15 @@ def test_load_dump_mappings(input_, expected):
assert sgrid.load_mappings(input_) == expected


@example(
grid=sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=(
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
),
)
)
@example(grid2dmetadata)
@given(sgrid_strategies.grid2Dmetadata())
def test_Grid2DMetadata_roundtrip(grid: sgrid.Grid2DMetadata):
attrs = grid.to_attrs()
parsed = sgrid.Grid2DMetadata.from_attrs(attrs)
assert parsed == grid


@example(
grid=sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
volume_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
),
)
)
@example(grid3dmetadata)
@given(sgrid_strategies.grid3Dmetadata())
def test_Grid3DMetadata_roundtrip(grid: sgrid.Grid3DMetadata):
attrs = grid.to_attrs()
Expand All @@ -198,6 +182,7 @@ def test_parse_grid_attrs(grid: sgrid.AttrsSerializable):
assert parsed == grid


@example(grid2dmetadata)
@given(sgrid_strategies.grid2Dmetadata())
def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
"""Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided"""
Expand All @@ -207,7 +192,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)

for ddp, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True):
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
coords = grid.axes[axis].coords
assert coords["center"] == dim_edge
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
Expand All @@ -216,7 +201,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
assert "Z" not in grid.axes
else:
ddp = grid_metadata.vertical_dimensions[0]
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
coords = grid.axes["Z"].coords
assert coords["center"] == dim_edge
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
Expand All @@ -231,7 +216,7 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)

for ddp, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True):
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
coords = grid.axes[axis].coords
assert coords["center"] == dim_edge
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
Expand Down Expand Up @@ -291,7 +276,7 @@ def test_rename_dims(grid):
assert grid == grid_new.rename_dims(dims_dict_inv)


def test_rename_dims_errors(grid2dmetadata):
def test_rename_dims_errors():
# Test various error modes of rename_dims
grid = grid2dmetadata
# Non-unique target dimension names
Expand Down
Loading