66
77import cf_xarray # noqa: F401
88import numpy as np
9+ import uxarray as ux
910import xarray as xr
1011import xgcm
1112
1415from parcels ._core .utils .string import _assert_str_and_python_varname
1516from parcels ._core .utils .time import get_datetime_type_calendar
1617from parcels ._core .utils .time import is_compatible as datetime_is_compatible
18+ from parcels ._core .uxgrid import UxGrid
1719from parcels ._core .xgrid import _DEFAULT_XGCM_KWARGS , XGrid
1820from parcels ._logger import logger
1921from parcels ._typing import Mesh
22+ from parcels .interpolators import UXPiecewiseConstantFace , UXPiecewiseLinearNode , XConstantField , XLinear
2023
2124if TYPE_CHECKING :
2225 from parcels ._core .basegrid import BaseGrid
@@ -116,7 +119,7 @@ def add_field(self, field: Field, name: str | None = None):
116119
117120 self .fields [name ] = field
118121
119- def add_constant_field (self , name : str , value , mesh : Mesh = "flat " ):
122+ def add_constant_field (self , name : str , value , mesh : Mesh = "spherical " ):
120123 """Wrapper function to add a Field that is constant in space,
121124 useful e.g. when using constant horizontal diffusivity
122125
@@ -134,16 +137,15 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
134137 correction for zonal velocity U near the poles.
135138 2. flat: No conversion, lat/lon are assumed to be in m.
136139 """
137- ds = xr .Dataset ({name : (["time" , "lat" , "lon" , "depth" ], np .full ((1 , 1 , 1 , 1 ), value ))})
138- grid = XGrid (xgcm .Grid (ds , ** _DEFAULT_XGCM_KWARGS ))
139- self .add_field (
140- Field (
141- name ,
142- ds [name ],
143- grid ,
144- interp_method = None , # TODO : Need to define an interpolation method for constants
145- )
140+ ds = xr .Dataset (
141+ {name : (["lat" , "lon" ], np .full ((1 , 1 ), value ))},
142+ coords = {"lat" : (["lat" ], [0 ], {"axis" : "Y" }), "lon" : (["lon" ], [0 ], {"axis" : "X" })},
143+ )
144+ xgrid = xgcm .Grid (
145+ ds , coords = {"X" : {"left" : "lon" }, "Y" : {"left" : "lat" }}, autoparse_metadata = False , ** _DEFAULT_XGCM_KWARGS
146146 )
147+ grid = XGrid (xgrid , mesh = mesh )
148+ self .add_field (Field (name , ds [name ], grid , interp_method = XConstantField ))
147149
148150 def add_constant (self , name , value ):
149151 """Add a constant to the FieldSet. Note that all constants are
@@ -238,22 +240,62 @@ def from_copernicusmarine(ds: xr.Dataset):
238240
239241 fields = {}
240242 if "U" in ds .data_vars and "V" in ds .data_vars :
241- fields ["U" ] = Field ("U" , ds ["U" ], grid )
242- fields ["V" ] = Field ("V" , ds ["V" ], grid )
243+ fields ["U" ] = Field ("U" , ds ["U" ], grid , XLinear )
244+ fields ["V" ] = Field ("V" , ds ["V" ], grid , XLinear )
243245 fields ["U" ].units = GeographicPolar ()
244246 fields ["V" ].units = Geographic ()
245247
246248 if "W" in ds .data_vars :
247249 ds ["W" ] -= ds [
248250 "W"
249251 ] # Negate W to convert from up positive to down positive (as that's the direction of positive z)
250- fields ["W" ] = Field ("W" , ds ["W" ], grid )
252+ fields ["W" ] = Field ("W" , ds ["W" ], grid , XLinear )
253+ fields ["UVW" ] = VectorField ("UVW" , fields ["U" ], fields ["V" ], fields ["W" ])
254+ else :
255+ fields ["UV" ] = VectorField ("UV" , fields ["U" ], fields ["V" ])
256+
257+ for varname in set (ds .data_vars ) - set (fields .keys ()):
258+ fields [varname ] = Field (varname , ds [varname ], grid , XLinear )
259+
260+ return FieldSet (list (fields .values ()))
261+
262+ def from_fesom2 (ds : ux .UxDataset ):
263+ """Create a FieldSet from a FESOM2 uxarray.UxDataset.
264+
265+ Parameters
266+ ----------
267+ ds : uxarray.UxDataset
268+ uxarray.UxDataset as obtained from the uxarray package.
269+
270+ Returns
271+ -------
272+ FieldSet
273+ FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.
274+ """
275+ ds = ds .copy ()
276+ ds_dims = list (ds .dims )
277+ if not all (dim in ds_dims for dim in ["time" , "nz" , "nz1" ]):
278+ raise ValueError (
279+ f"Dataset missing one of the required dimensions 'time', 'nz', or 'nz1'. Found dimensions { ds_dims } "
280+ )
281+ grid = UxGrid (ds .uxgrid , z = ds .coords ["nz" ])
282+ ds = _discover_fesom2_U_and_V (ds )
283+
284+ fields = {}
285+ if "U" in ds .data_vars and "V" in ds .data_vars :
286+ fields ["U" ] = Field ("U" , ds ["U" ], grid , _select_uxinterpolator (ds ["U" ]))
287+ fields ["V" ] = Field ("V" , ds ["V" ], grid , _select_uxinterpolator (ds ["U" ]))
288+ fields ["U" ].units = GeographicPolar ()
289+ fields ["V" ].units = Geographic ()
290+
291+ if "W" in ds .data_vars :
292+ fields ["W" ] = Field ("W" , ds ["W" ], grid , _select_uxinterpolator (ds ["U" ]))
251293 fields ["UVW" ] = VectorField ("UVW" , fields ["U" ], fields ["V" ], fields ["W" ])
252294 else :
253295 fields ["UV" ] = VectorField ("UV" , fields ["U" ], fields ["V" ])
254296
255297 for varname in set (ds .data_vars ) - set (fields .keys ()):
256- fields [varname ] = Field (varname , ds [varname ], grid )
298+ fields [varname ] = Field (varname , ds [varname ], grid , _select_uxinterpolator ( ds [ varname ]) )
257299
258300 return FieldSet (list (fields .values ()))
259301
@@ -365,11 +407,86 @@ def _discover_copernicusmarine_U_and_V(ds: xr.Dataset) -> xr.Dataset:
365407 return ds
366408
367409
368- def _ds_rename_using_standard_names (ds : xr .Dataset , name_dict : dict [str , str ]) -> xr .Dataset :
410+ def _discover_fesom2_U_and_V (ds : ux .UxDataset ) -> ux .UxDataset :
411+ # Common variable names for U and V found in UxDatasets
412+ common_fesom_UV = [("unod" , "vnod" ), ("u" , "v" )]
413+ common_fesom_W = ["w" ]
414+
415+ if "W" not in ds :
416+ for common_W in common_fesom_W :
417+ if common_W in ds :
418+ ds = _ds_rename_using_standard_names (ds , {common_W : "W" })
419+ break
420+
421+ if "U" in ds and "V" in ds :
422+ return ds # U and V already present
423+ elif "U" in ds or "V" in ds :
424+ raise ValueError (
425+ "Dataset has only one of the two variables 'U' and 'V'. Please rename the appropriate variable in your dataset to have both 'U' and 'V' for Parcels simulation."
426+ )
427+
428+ for common_U , common_V in common_fesom_UV :
429+ if common_U in ds :
430+ if common_V not in ds :
431+ raise ValueError (
432+ f"Dataset has variable with standard name { common_U !r} , "
433+ f"but not the matching variable with standard name { common_V !r} . "
434+ "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation."
435+ )
436+ else :
437+ ds = _ds_rename_using_standard_names (ds , {common_U : "U" , common_V : "V" })
438+ break
439+
440+ else :
441+ if common_V in ds :
442+ raise ValueError (
443+ f"Dataset has variable with standard name { common_V !r} , "
444+ f"but not the matching variable with standard name { common_U !r} . "
445+ "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation."
446+ )
447+ continue
448+
449+ return ds
450+
451+
452+ def _ds_rename_using_standard_names (ds : xr .Dataset | ux .UxDataset , name_dict : dict [str , str ]) -> xr .Dataset :
369453 for standard_name , rename_to in name_dict .items ():
370454 name = ds .cf [standard_name ].name
371455 ds = ds .rename ({name : rename_to })
372456 logger .info (
373457 f"cf_xarray found variable { name !r} with CF standard name { standard_name !r} in dataset, renamed it to { rename_to !r} for Parcels simulation."
374458 )
375459 return ds
460+
461+
462+ def _select_uxinterpolator (da : ux .UxDataArray ):
463+ """Selects the appropriate uxarray interpolator for a given uxarray UxDataArray"""
464+ supported_uxinterp_mapping = {
465+ # (nz1,n_face): face-center laterally, layer centers vertically — piecewise constant
466+ "nz1,n_face" : UXPiecewiseConstantFace ,
467+ # (nz,n_node): node/corner laterally, layer interfaces vertically — barycentric lateral & linear vertical
468+ "nz,n_node" : UXPiecewiseLinearNode ,
469+ }
470+ # Extract only spatial dimensions, neglecting time
471+ da_spatial_dims = tuple (d for d in da .dims if d not in ("time" ,))
472+ if len (da_spatial_dims ) != 2 :
473+ raise ValueError (
474+ "Fields on unstructured grids must have two spatial dimensions, one vertical (nz or nz1) and one lateral (n_face, n_edge, or n_node)"
475+ )
476+
477+ # Construct key (string) for mapping to interpolator
478+ # Find vertical and lateral tokens
479+ vdim = None
480+ ldim = None
481+ for d in da_spatial_dims :
482+ if d in ("nz" , "nz1" ):
483+ vdim = d
484+ if d in ("n_face" , "n_node" ):
485+ ldim = d
486+ # Map to supported interpolators
487+ if vdim and ldim :
488+ key = f"{ vdim } ,{ ldim } "
489+ if key in supported_uxinterp_mapping .keys ():
490+ return supported_uxinterp_mapping [key ]
491+
492+ return None
0 commit comments