1 change: 1 addition & 0 deletions environment.yml
@@ -20,6 +20,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml
- uxarray>=2025.3.0
- xgcm>=0.9.0
- pooch
+  - cf_xarray

# Notebooks
- trajan
30 changes: 18 additions & 12 deletions parcels/_datasets/structured/circulation_models.py
@@ -88,28 +88,34 @@ def _copernicusmarine():
)


-def _copernicusmarine_globcurrent():
+def _copernicusmarine_waves():
"""Copernicus Marine Service GlobCurrent dataset (MULTIOBS_GLO_PHY_MYNRT_015_003)"""
return xr.Dataset(
{
"ue": (
"VSDX": (
["time", "depth", "latitude", "longitude"],
np.random.rand(T, Z, Y, X),
{
"units": "m/s",
"standard_name": "eastward_sea_water_velocity_due_to_ekman_drift",
"long_name": "Depth Ekman driven velocity : zonal component",
"grid_mapping": "crs",
"units": "m s-1",
"standard_name": "sea_surface_wave_stokes_drift_x_velocity",
"long_name": "Stokes drift U",
"WMO_code": 215,
"cell_methods": "time:point area:mean",
"missing_value": -32767,
"type_of_analysis": "spectral analysis",
},
),
"ve": (
"VSDY": (
["time", "depth", "latitude", "longitude"],
np.random.rand(T, Z, Y, X),
{
"units": "m/s",
"standard_name": "northward_sea_water_velocity_due_to_ekman_drift",
"long_name": "Depth Ekman driven velocity : meridional component",
"grid_mapping": "crs",
"units": "m s-1",
"standard_name": "sea_surface_wave_stokes_drift_y_velocity",
"long_name": "Stokes drift V",
"WMO_code": 216,
"cell_methods": "time:point area:mean",
"missing_value": -32767,
"type_of_analysis": "spectral analysis",
},
),
},
@@ -1244,7 +1250,7 @@ def _CROCO_idealized():

datasets = {
"ds_copernicusmarine": _copernicusmarine(),
"ds_copernicusmarine_globcurrent": _copernicusmarine_globcurrent(),
"ds_copernicusmarine_waves": _copernicusmarine_waves(),
"ds_NEMO_MOI_U": _NEMO_MOI_U(),
"ds_NEMO_MOI_V": _NEMO_MOI_V(),
"ds_CESM": _CESM(),
155 changes: 155 additions & 0 deletions parcels/fieldset.py
@@ -4,6 +4,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING

+import cf_xarray  # noqa: F401
import numpy as np
import xarray as xr
import xgcm
@@ -12,6 +13,8 @@
from parcels._core.utils.time import is_compatible as datetime_is_compatible
from parcels._typing import Mesh
from parcels.field import Field, VectorField
+from parcels.tools.converters import Geographic, GeographicPolar
+from parcels.tools.loggers import logger
from parcels.xgrid import XGrid

if TYPE_CHECKING:
@@ -174,6 +177,75 @@ def gridset(self) -> list[BaseGrid]:
grids.append(field.grid)
return grids

def from_copernicusmarine(ds: xr.Dataset):
"""Create a FieldSet from a Copernicus Marine Service xarray.Dataset.

Parameters
----------
ds : xarray.Dataset
xarray.Dataset as obtained from the copernicusmarine toolbox.

Returns
-------
FieldSet
FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.

Notes
-----
See https://help.marine.copernicus.eu/en/collections/9080063-copernicus-marine-toolbox for more information on the copernicusmarine toolbox.
The toolbox can ingest data from most of the products on the Copernicus Marine Service (https://data.marine.copernicus.eu/products) into an xarray.Dataset.
You can use indexing and slicing to select a subset of the data before passing it to this function.
Note that most Parcels use cases require both U and V fields to be present in the dataset. This function will try to detect which variables in the dataset correspond to U and V.
To override the automatic detection, rename the appropriate variables in your dataset to 'U' and 'V' before passing it to this function.

"""
ds = ds.copy()
ds = _discover_copernicusmarine_U_and_V(ds)
expected_axes = set("XYZT") # TODO: Update after we have support for 2D spatial fields
if missing_axes := (expected_axes - set(ds.cf.axes)):
raise ValueError(
f"Dataset missing axes {missing_axes} to have coordinates for all {expected_axes} axes according to CF conventions."
)

ds = _rename_coords_copernicusmarine(ds)
grid = XGrid(
xgcm.Grid(

Review thread on this line:

Member:

    I get a lot of warnings

        /Users/erik/anaconda3/envs/parcels-v4/lib/python3.12/site-packages/xgcm/grid.py:196: DeprecationWarning: The `periodic` argument will be deprecated. To preserve previous behavior supply `boundary = 'periodic'.

    Should we fix these here?

Contributor Author:

    Investigating, but not sure how simple this is (might be upstream) - but merging for now so others can easily test

Contributor Author:

    > might be upstream

    Indeed upstream - we can't really get around the warnings xgcm/xgcm#678 .
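
Until that upstream issue is resolved, one possible stopgap is a targeted warning filter — a sketch only, not part of this PR; the message pattern is copied from the warning text above:

    import warnings

    # Silence only xgcm's `periodic` deprecation, not other DeprecationWarnings,
    # until xgcm/xgcm#678 is resolved.
    warnings.filterwarnings(
        "ignore",
        message="The `periodic` argument will be deprecated",
        category=DeprecationWarning,
        module="xgcm.grid",
    )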

ds,
coords={
"X": {
"left": "lon",
},
"Y": {
"left": "lat",
},
"Z": {
"left": "depth",
},
"T": {
"center": "time",
},
},
autoparse_metadata=False,
)
)

U = Field("U", ds["U"], grid)
V = Field("V", ds["V"], grid)

U.units = GeographicPolar()
V.units = Geographic()

fields = {"U": U, "V": V}
for varname in set(ds.data_vars) - set(fields.keys()):
fields[varname] = Field(varname, ds[varname], grid)

if "U" in fields and "V" in fields:
if "W" in fields:
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
else:
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
return FieldSet(list(fields.values()))
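
A usage sketch of the new entry point (the `dataset_id` and the subsetting below are illustrative, not taken from this PR):

    import copernicusmarine

    from parcels import FieldSet

    # Open a CMEMS product lazily via the copernicusmarine toolbox, subset it,
    # then let from_copernicusmarine discover U/V and build the grid.
    ds = copernicusmarine.open_dataset(dataset_id="cmems_mod_glo_phy_anfc_0.083deg_PT1H-m")
    ds = ds.sel(longitude=slice(-40.0, -10.0), latitude=slice(30.0, 50.0))
    fieldset = FieldSet.from_copernicusmarine(ds)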


class CalendarError(Exception): # TODO: Move to a parcels errors module
"""Exception raised when the calendar of a field is not compatible with the rest of the Fields. The user should ensure that they only add fields to a FieldSet that have compatible CFtime calendars."""
@@ -206,3 +278,86 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:

def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"


_COPERNICUS_MARINE_AXIS_VARNAMES = {
"X": "lon",
"Y": "lat",
"Z": "depth",
"T": "time",
}


def _rename_coords_copernicusmarine(ds):
try:
for axis, [coord] in ds.cf.axes.items():
ds = ds.rename({coord: _COPERNICUS_MARINE_AXIS_VARNAMES[axis]})
except ValueError as e:
raise ValueError(f"Multiple coordinates found for Copernicus dataset on axis '{axis}'. Check your data.") from e
return ds
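
For reference, cf_xarray's `.cf.axes` (the hook this helper relies on) maps each CF axis letter to the coordinate name(s) it detected. A minimal, self-contained illustration using the explicit CF `axis` attribute:

    import cf_xarray  # noqa: F401  # registers the .cf accessor on xarray objects
    import xarray as xr

    ds = xr.Dataset(coords={"latitude": ("latitude", [0.0, 1.0], {"units": "degrees_north", "axis": "Y"})})
    print(ds.cf.axes)  # {'Y': ['latitude']} -> this helper would then rename "latitude" to "lat"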


def _discover_copernicusmarine_U_and_V(ds: xr.Dataset) -> xr.Dataset:
# Assumes that the dataset has U and V data

cf_UV_standard_name_fallbacks = [
(
"eastward_sea_water_velocity",
"northward_sea_water_velocity",
), # GLOBAL_ANALYSISFORECAST_PHY_001_024, MEDSEA_ANALYSISFORECAST_PHY_006_013, BALTICSEA_ANALYSISFORECAST_PHY_003_006, BLKSEA_ANALYSISFORECAST_PHY_007_001, IBI_ANALYSISFORECAST_PHY_005_001, NWSHELF_ANALYSISFORECAST_PHY_004_013, MULTIOBS_GLO_PHY_MYNRT_015_003, MULTIOBS_GLO_PHY_W_3D_REP_015_007
(
"surface_geostrophic_eastward_sea_water_velocity",
"surface_geostrophic_northward_sea_water_velocity",
), # SEALEVEL_GLO_PHY_L4_MY_008_047, SEALEVEL_EUR_PHY_L4_NRT_008_060
(
"geostrophic_eastward_sea_water_velocity",
"geostrophic_northward_sea_water_velocity",
), # MULTIOBS_GLO_PHY_TSUV_3D_MYNRT_015_012
(
"sea_surface_wave_stokes_drift_x_velocity",
"sea_surface_wave_stokes_drift_y_velocity",
), # GLOBAL_ANALYSISFORECAST_WAV_001_027, MEDSEA_MULTIYEAR_WAV_006_012, ARCTIC_ANALYSIS_FORECAST_WAV_002_014, BLKSEA_ANALYSISFORECAST_WAV_007_003, IBI_ANALYSISFORECAST_WAV_005_005, NWSHELF_ANALYSISFORECAST_WAV_004_014
("sea_water_x_velocity", "sea_water_y_velocity"), # ARCTIC_ANALYSISFORECAST_PHY_002_001
(
"eastward_sea_water_velocity_vertical_mean_over_pelagic_layer",
"northward_sea_water_velocity_vertical_mean_over_pelagic_layer",
), # GLOBAL_MULTIYEAR_BGC_001_033
]

if "U" in ds and "V" in ds:
return ds # U and V already present
elif "U" in ds or "V" in ds:
raise ValueError(
"Dataset has only one of the two variables 'U' and 'V'. Please rename the appropriate variable in your dataset to have both 'U' and 'V' for Parcels simulation."
)

for cf_standard_name_U, cf_standard_name_V in cf_UV_standard_name_fallbacks:
if cf_standard_name_U in ds.cf.standard_names:
if cf_standard_name_V not in ds.cf.standard_names:
raise ValueError(
f"Dataset has variable with CF standard name {cf_standard_name_U!r}, "
f"but not the matching variable with CF standard name {cf_standard_name_V!r}. "
"Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation."
)
else:
continue

ds = _ds_rename_using_standard_names(ds, {cf_standard_name_U: "U", cf_standard_name_V: "V"})
break
else:
raise ValueError(
f"Could not find variables 'U' and 'V' in dataset, nor any of the fallback CF standard names "
f"{cf_UV_standard_name_fallbacks}. Please rename the appropriate variables to 'U' and 'V' in "
"your dataset for the Parcels simulation."
)
return ds
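
If a product is not covered by the fallback table, the escape hatch the docstring describes is to rename up front; a sketch with illustrative CMEMS variable names ("uo"/"vo"):

    ds = ds.rename({"uo": "U", "vo": "V"})
    fieldset = FieldSet.from_copernicusmarine(ds)  # discovery returns early: U and V already present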


def _ds_rename_using_standard_names(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset:
for standard_name, rename_to in name_dict.items():
name = ds.cf[standard_name].name
ds = ds.rename({name: rename_to})
logger.info(
f"cf_xarray found variable {name!r} with CF standard name {standard_name!r} in dataset, renamed it to {rename_to!r} for Parcels simulation."
)
return ds
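
This helper leans on cf_xarray's name-based indexing: `ds.cf[standard_name]` resolves a standard name to the backing variable, whose `.name` feeds a plain `Dataset.rename`. A small self-contained illustration (the variable name "uo" is made up):

    import cf_xarray  # noqa: F401
    import xarray as xr

    ds = xr.Dataset({"uo": ("x", [0.1, 0.2], {"standard_name": "eastward_sea_water_velocity"})})
    print(ds.cf["eastward_sea_water_velocity"].name)  # -> "uo"
    print(list(_ds_rename_using_standard_names(ds, {"eastward_sea_water_velocity": "U"}).data_vars))  # -> ["U"]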
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
"uxarray",
"pooch",
"xgcm >=0.9.0",
"cf_xarray",
]

[project.urls]
@@ -106,6 +107,7 @@ pydata-sphinx-theme = "*"
sphinx-autobuild = "*"
myst-parser = "*"
sphinxcontrib-mermaid = "*"
+cf_xarray = "*"

[tool.pixi.pypi-dependencies]
parcels = { path = ".", editable = true }
10 changes: 5 additions & 5 deletions tests/v4/test_basegrid.py
@@ -8,7 +8,7 @@
from parcels.basegrid import BaseGrid


-class TestGrid(BaseGrid):
+class MockGrid(BaseGrid):
def __init__(self, axis_dim: dict[str, int]):
self.axis_dim = axis_dim

@@ -26,10 +26,10 @@ def get_axis_dim(self, axis: str) -> int:
@pytest.mark.parametrize(
"grid",
[
TestGrid({"Z": 10, "Y": 20, "X": 30}),
TestGrid({"Z": 5, "Y": 15}),
TestGrid({"Z": 8}),
TestGrid({"Z": 12, "FACE": 25}),
MockGrid({"Z": 10, "Y": 20, "X": 30}),
MockGrid({"Z": 5, "Y": 15}),
MockGrid({"Z": 8}),
MockGrid({"Z": 12, "FACE": 25}),
],
)
def test_basegrid_ravel_unravel_index(grid):
34 changes: 31 additions & 3 deletions tests/v4/test_fieldset.py
@@ -5,9 +5,7 @@
import pytest
import xarray as xr

-from parcels._datasets.structured.circulation_models import (
-    datasets as datasets_circulation_models,  # noqa: F401
-)  # just making sure the import works. Will eventually be used in tests
+from parcels._datasets.structured.circulation_models import datasets as datasets_circulation_models
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels.field import Field, VectorField
@@ -216,3 +214,33 @@ def test_fieldset_grid_deduplication():
def test_fieldset_add_field_after_pset():
# ? Should it be allowed to add fields (normal or vector) after a ParticleSet has been initialized?
...


_COPERNICUS_DATASETS = [
datasets_circulation_models["ds_copernicusmarine"],
datasets_circulation_models["ds_copernicusmarine_waves"],
]


@pytest.mark.parametrize("ds", _COPERNICUS_DATASETS)
def test_fieldset_from_copernicusmarine(ds, caplog):
fieldset = FieldSet.from_copernicusmarine(ds)
assert "U" in fieldset.fields
assert "V" in fieldset.fields
assert "UV" in fieldset.fields
assert "renamed it to 'U'" in caplog.text
assert "renamed it to 'V'" in caplog.text


@pytest.mark.parametrize("ds", _COPERNICUS_DATASETS)
def test_fieldset_from_copernicusmarine_no_logs(ds, caplog):
ds = ds.copy()
zeros = xr.zeros_like(list(ds.data_vars.values())[0])
ds["U"] = zeros
ds["V"] = zeros

fieldset = FieldSet.from_copernicusmarine(ds)
assert "U" in fieldset.fields
assert "V" in fieldset.fields
assert "UV" in fieldset.fields
assert caplog.text == ""