Skip to content

Commit f77f5f4

Browse files
merge
2 parents 07a8238 + 807d2ee commit f77f5f4

20 files changed

+136
-143
lines changed

docs/examples/tutorial_stommel_uxarray.ipynb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
"\n",
8989
"A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n",
9090
"\n",
91-
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions."
91+
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions. Setting the `mesh` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
9292
]
9393
},
9494
{
@@ -99,7 +99,7 @@
9999
"source": [
100100
"from parcels.uxgrid import UxGrid\n",
101101
"\n",
102-
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n",
102+
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"], mesh=\"spherical\")\n",
103103
"# You can view the uxgrid object with the following command:\n",
104104
"grid.uxgrid"
105105
]
@@ -112,7 +112,7 @@
112112
"\n",
113113
"In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n",
114114
"\n",
115-
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
115+
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method."
116116
]
117117
},
118118
{
@@ -128,21 +128,18 @@
128128
" name=\"U\",\n",
129129
" data=ds.U,\n",
130130
" grid=grid,\n",
131-
" mesh_type=\"spherical\",\n",
132131
" interp_method=UXPiecewiseConstantFace,\n",
133132
")\n",
134133
"V = Field(\n",
135134
" name=\"V\",\n",
136135
" data=ds.V,\n",
137136
" grid=grid,\n",
138-
" mesh_type=\"spherical\",\n",
139137
" interp_method=UXPiecewiseConstantFace,\n",
140138
")\n",
141139
"P = Field(\n",
142140
" name=\"P\",\n",
143141
" data=ds.p,\n",
144142
" grid=grid,\n",
145-
" mesh_type=\"spherical\",\n",
146143
" interp_method=UXPiecewiseConstantFace,\n",
147144
")"
148145
]

parcels/_datasets/structured/generated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import xarray as xr
55

66

7-
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"):
8-
max_lon = 180.0 if mesh_type == "spherical" else 1e6
7+
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
8+
max_lon = 180.0 if mesh == "spherical" else 1e6
99

1010
return xr.Dataset(
1111
{"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))},

parcels/_index_search.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from parcels._typing import (
99
GridIndexingType,
1010
InterpMethodOption,
11+
Mesh,
1112
)
1213
from parcels.tools.statuscodes import (
1314
FieldOutOfBoundError,
@@ -174,7 +175,7 @@ def _search_indices_rectilinear(
174175
_raise_field_out_of_bound_error(z, y, x)
175176

176177
if field.xdim > 1:
177-
if field._mesh_type != "spherical":
178+
if field._mesh != "spherical":
178179
lon_index = field.lon < x
179180
if lon_index.all():
180181
xi = len(field.lon) - 2
@@ -307,7 +308,7 @@ def _search_indices_curvilinear_2d(
307308
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
308309
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
309310

310-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
311+
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
311312
it += 1
312313
if it > maxIterSearch:
313314
print(f"Correct cell not found after {maxIterSearch} iterations")
@@ -410,11 +411,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
410411
return (zeta, eta, xsi, zi, yi, xi)
411412

412413

413-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
414-
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
415-
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
414+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
415+
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
416+
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
416417

417-
xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
418+
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
418419

419420
yi = np.where(yi < 0, 0, yi)
420421
yi = np.where(yi > ydim - 2, ydim - 2, yi)

parcels/application_kernels/interpolation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,22 @@ def CGrid_Velocity(
148148
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
149149
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
150150

151-
if vectorfield._mesh_type == "spherical":
151+
if vectorfield._mesh == "spherical":
152152
px[0] = np.where(px[0] < x - 225, px[0] + 360, px[0])
153153
px[0] = np.where(px[0] > x + 225, px[0] - 360, px[0])
154154
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
155155
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
156156
c1 = i_u._geodetic_distance(
157-
py[0], py[1], px[0], px[1], vectorfield._mesh_type, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py)
157+
py[0], py[1], px[0], px[1], vectorfield._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py)
158158
)
159159
c2 = i_u._geodetic_distance(
160-
py[1], py[2], px[1], px[2], vectorfield._mesh_type, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py)
160+
py[1], py[2], px[1], px[2], vectorfield._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py)
161161
)
162162
c3 = i_u._geodetic_distance(
163-
py[2], py[3], px[2], px[3], vectorfield._mesh_type, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py)
163+
py[2], py[3], px[2], px[3], vectorfield._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py)
164164
)
165165
c4 = i_u._geodetic_distance(
166-
py[3], py[0], px[3], px[0], vectorfield._mesh_type, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py)
166+
py[3], py[0], px[3], px[0], vectorfield._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py)
167167
)
168168

169169
lenT = 2 if np.any(tau > 0) else 1
@@ -245,9 +245,9 @@ def CGrid_Velocity(
245245

246246
deg2m = 1852 * 60.0
247247
if applyConversion:
248-
meshJac = (deg2m * deg2m * np.cos(np.deg2rad(y))) if vectorfield._mesh_type == "spherical" else 1
248+
meshJac = (deg2m * deg2m * np.cos(np.deg2rad(y))) if vectorfield._mesh == "spherical" else 1
249249
else:
250-
meshJac = deg2m if vectorfield._mesh_type == "spherical" else 1
250+
meshJac = deg2m if vectorfield._mesh == "spherical" else 1
251251

252252
jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac
253253

parcels/field.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
from parcels._core.utils.time import TimeInterval
1313
from parcels._reprs import default_repr
1414
from parcels._typing import (
15-
Mesh,
1615
VectorType,
17-
assert_valid_mesh,
1816
)
1917
from parcels.application_kernels.interpolation import CGrid_Velocity, UXPiecewiseLinearNode, XLinear, ZeroInterpolator
2018
from parcels.particle import KernelParticle
@@ -86,7 +84,7 @@ class Field:
8684
-----
8785
The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata.
8886
* dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon])
89-
* attrs: (location, mesh, mesh_type)
87+
* attrs: (location, mesh, mesh)
9088
9189
When using a xarray.DataArray object,
9290
* The xarray.DataArray object must have the "location" and "mesh" attributes set.
@@ -114,7 +112,6 @@ def __init__(
114112
name: str,
115113
data: xr.DataArray | ux.UxDataArray,
116114
grid: UxGrid | XGrid,
117-
mesh_type: Mesh = "flat",
118115
interp_method: Callable | None = None,
119116
):
120117
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
@@ -126,8 +123,6 @@ def __init__(
126123
if not isinstance(grid, (UxGrid, XGrid)):
127124
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")
128125

129-
assert_valid_mesh(mesh_type)
130-
131126
_assert_compatible_combination(data, grid)
132127

133128
if isinstance(grid, XGrid):
@@ -155,8 +150,6 @@ def __init__(
155150
e.add_note(f"Error validating field {name!r}.")
156151
raise e
157152

158-
self._mesh_type = mesh_type
159-
160153
# Setting the interpolation method dynamically
161154
if interp_method is None:
162155
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
@@ -166,12 +159,10 @@ def __init__(
166159

167160
self.igrid = -1 # Default the grid index to -1
168161

169-
if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
162+
if self.grid._mesh == "flat" or (self.name not in unitconverters_map.keys()):
170163
self.units = UnitConverter()
171-
elif self._mesh_type == "spherical":
164+
elif self.grid._mesh == "spherical":
172165
self.units = unitconverters_map[self.name]
173-
else:
174-
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")
175166

176167
if self.data.shape[0] > 1:
177168
if "time" not in self.data.coords:
@@ -294,9 +285,8 @@ def __init__(
294285
_assert_same_function_signature(vector_interp_method, ref=CGrid_Velocity)
295286
self._vector_interp_method = vector_interp_method
296287

297-
if U._mesh_type != V._mesh_type or (W and U._mesh_type != W._mesh_type):
298-
raise ValueError(f"Inconsistent mesh types: {U._mesh_type}, {V._mesh_type}, {W._mesh_type}")
299-
self._mesh_type = U._mesh_type
288+
if U.grid._mesh != V.grid._mesh or (W and U.grid._mesh != W.grid._mesh):
289+
raise ValueError(f"Inconsistent mesh types: {U.grid._mesh}, {V.grid._mesh}, {W.grid._mesh}")
300290

301291
def __repr__(self):
302292
return f"""<{type(self).__name__}>

parcels/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth
143143
stacklevel=2,
144144
)
145145
self.fieldset.add_constant("RK45_tol", 10)
146-
if self.fieldset.U.grid.mesh == "spherical":
146+
if self.fieldset.U.grid._mesh == "spherical":
147147
self.fieldset.RK45_tol /= (
148148
1852 * 60
149149
) # TODO does not account for zonal variation in meter -> degree conversion

parcels/spatialhash.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _hash_index2d(self, coords):
8686
as the source grid coordinates
8787
"""
8888
# Wrap longitude to [-180, 180]
89-
if self._source_grid.mesh == "spherical":
89+
if self._source_grid._mesh == "spherical":
9090
lon = (coords[:, 1] + 180.0) % (360.0) - 180.0
9191
else:
9292
lon = coords[:, 1]

parcels/uxgrid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import uxarray as ux
77

8+
from parcels._typing import assert_valid_mesh
89
from parcels.spatialhash import _barycentric_coordinates
910
from parcels.tools.statuscodes import FieldOutOfBoundError
1011
from parcels.xgrid import _search_1d_array
@@ -20,7 +21,7 @@ class UxGrid(BaseGrid):
2021
for interpolation on unstructured grids.
2122
"""
2223

23-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
24+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
2425
"""
2526
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2627
@@ -32,13 +33,18 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
3233
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
3334
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
3435
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
36+
mesh : str, optional
37+
The type of mesh used for the grid. Either "flat" (default) or "spherical".
3538
"""
3639
self.uxgrid = grid
3740
if not isinstance(z, ux.UxDataArray):
3841
raise TypeError("z must be an instance of ux.UxDataArray")
3942
if z.ndim != 1:
4043
raise ValueError("z must be a 1D array of vertical coordinates")
4144
self.z = z
45+
self._mesh = mesh
46+
47+
assert_valid_mesh(mesh)
4248

4349
@property
4450
def depth(self):

parcels/xgrid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from parcels import xgcm
1010
from parcels._index_search import _search_indices_curvilinear_2d
11+
from parcels._typing import assert_valid_mesh
1112
from parcels.basegrid import BaseGrid
1213
from parcels.spatialhash import SpatialHash
1314

@@ -97,7 +98,7 @@ class XGrid(BaseGrid):
9798

9899
def __init__(self, grid: xgcm.Grid, mesh="flat"):
99100
self.xgcm_grid = grid
100-
self.mesh = mesh
101+
self._mesh = mesh
101102
self._spatialhash = None
102103
ds = grid._ds
103104
if "lon" in ds and hasattr(ds["lon"], "load"):
@@ -108,6 +109,8 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
108109
if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
109110
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
110111

112+
assert_valid_mesh(mesh)
113+
111114
@classmethod
112115
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
113116
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release

tests/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ def create_fieldset_global(xdim=200, ydim=100):
6868
return FieldSet.from_data(data, dimensions, mesh="flat")
6969

7070

71-
def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100) -> FieldSet:
71+
def create_fieldset_zeros_conversion(mesh="spherical", xdim=200, ydim=100) -> FieldSet:
7272
"""Zero velocity field with lat and lon determined by a conversion factor."""
73-
mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1
74-
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
73+
mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1
74+
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh=mesh)
7575
ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim)
7676
ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim)
77-
grid = XGrid.from_dataset(ds)
78-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
79-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
77+
grid = XGrid.from_dataset(ds, mesh=mesh)
78+
U = Field("U", ds["U"], grid, interp_method=XLinear)
79+
V = Field("V", ds["V"], grid, interp_method=XLinear)
8080

8181
UV = VectorField("UV", U, V)
8282
return FieldSet([U, V, UV])

0 commit comments

Comments
 (0)