Skip to content

Commit 01cfef6

Browse files
Moving the mesh_type to the Grid class
1 parent 62a596c commit 01cfef6

16 files changed

+93
-98
lines changed

parcels/_index_search.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def _search_indices_curvilinear_2d(
305305
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
306306
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
307307

308-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
308+
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh_type)
309309
it += 1
310310
if it > maxIterSearch:
311311
print(f"Correct cell not found after {maxIterSearch} iterations")
@@ -408,11 +408,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
408408
return (zeta, eta, xsi, zi, yi, xi)
409409

410410

411-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
412-
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
413-
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
411+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh_type: str):
412+
xi = np.where(xi < 0, (xdim - 2) if mesh_type == "spherical" else 0, xi)
413+
xi = np.where(xi > xdim - 2, 0 if mesh_type == "spherical" else (xdim - 2), xi)
414414

415-
xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
415+
xi = np.where(yi > ydim - 2, xdim - xi if mesh_type == "spherical" else xi, xi)
416416

417417
yi = np.where(yi < 0, 0, yi)
418418
yi = np.where(yi > ydim - 2, ydim - 2, yi)

parcels/field.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111

1212
from parcels._core.utils.time import TimeInterval
1313
from parcels._reprs import default_repr
14-
from parcels._typing import (
15-
Mesh,
16-
VectorType,
17-
assert_valid_mesh,
18-
)
14+
from parcels._typing import VectorType
1915
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
2016
from parcels.particle import KernelParticle
2117
from parcels.tools.converters import (
@@ -114,7 +110,6 @@ def __init__(
114110
name: str,
115111
data: xr.DataArray | ux.UxDataArray,
116112
grid: UxGrid | XGrid,
117-
mesh_type: Mesh = "flat",
118113
interp_method: Callable | None = None,
119114
):
120115
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
@@ -126,8 +121,6 @@ def __init__(
126121
if not isinstance(grid, (UxGrid, XGrid)):
127122
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")
128123

129-
assert_valid_mesh(mesh_type)
130-
131124
_assert_compatible_combination(data, grid)
132125

133126
if isinstance(grid, XGrid):
@@ -155,8 +148,6 @@ def __init__(
155148
e.add_note(f"Error validating field {name!r}.")
156149
raise e
157150

158-
self._mesh_type = mesh_type
159-
160151
# Setting the interpolation method dynamically
161152
if interp_method is None:
162153
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
@@ -166,12 +157,10 @@ def __init__(
166157

167158
self.igrid = -1 # Default the grid index to -1
168159

169-
if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
160+
if self.grid._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
170161
self.units = UnitConverter()
171-
elif self._mesh_type == "spherical":
162+
elif self.grid._mesh_type == "spherical":
172163
self.units = unitconverters_map[self.name]
173-
else:
174-
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")
175164

176165
if self.data.shape[0] > 1:
177166
if "time" not in self.data.coords:

parcels/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth
143143
stacklevel=2,
144144
)
145145
self.fieldset.add_constant("RK45_tol", 10)
146-
if self.fieldset.U.grid.mesh == "spherical":
146+
if self.fieldset.U.grid._mesh_type == "spherical":
147147
self.fieldset.RK45_tol /= (
148148
1852 * 60
149149
) # TODO does not account for zonal variation in meter -> degree conversion

parcels/spatialhash.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _hash_index2d(self, coords):
8686
as the source grid coordinates
8787
"""
8888
# Wrap longitude to [-180, 180]
89-
if self._source_grid.mesh == "spherical":
89+
if self._source_grid._mesh_type == "spherical":
9090
lon = (coords[:, 1] + 180.0) % (360.0) - 180.0
9191
else:
9292
lon = coords[:, 1]

parcels/uxgrid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import uxarray as ux
77

8+
from parcels._typing import assert_valid_mesh
89
from parcels.spatialhash import _barycentric_coordinates
910
from parcels.tools.statuscodes import FieldOutOfBoundError
1011
from parcels.xgrid import _search_1d_array
@@ -20,7 +21,7 @@ class UxGrid(BaseGrid):
2021
for interpolation on unstructured grids.
2122
"""
2223

23-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
24+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh_type="flat") -> UxGrid:
2425
"""
2526
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2627
@@ -32,13 +33,18 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
3233
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
3334
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
3435
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
36+
mesh_type : str, optional
37+
The type of mesh used for the grid. Either "flat" (default) or "spherical".
3538
"""
3639
self.uxgrid = grid
3740
if not isinstance(z, ux.UxDataArray):
3841
raise TypeError("z must be an instance of ux.UxDataArray")
3942
if z.ndim != 1:
4043
raise ValueError("z must be a 1D array of vertical coordinates")
4144
self.z = z
45+
self._mesh_type = mesh_type
46+
47+
assert_valid_mesh(mesh_type)
4248

4349
@property
4450
def depth(self):

parcels/xgrid.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from parcels import xgcm
1010
from parcels._index_search import _search_indices_curvilinear_2d
11+
from parcels._typing import assert_valid_mesh
1112
from parcels.basegrid import BaseGrid
1213
from parcels.spatialhash import SpatialHash
1314

@@ -95,17 +96,19 @@ class XGrid(BaseGrid):
9596
9697
"""
9798

98-
def __init__(self, grid: xgcm.Grid, mesh="flat"):
99+
def __init__(self, grid: xgcm.Grid, mesh_type="flat"):
99100
self.xgcm_grid = grid
100-
self.mesh = mesh
101+
self._mesh_type = mesh_type
101102
self._spatialhash = None
102103
ds = grid._ds
103104

104105
if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
105106
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
106107

108+
assert_valid_mesh(mesh_type)
109+
107110
@classmethod
108-
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
111+
def from_dataset(cls, ds: xr.Dataset, mesh_type="flat", xgcm_kwargs=None):
109112
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
110113
if xgcm_kwargs is None:
111114
xgcm_kwargs = {}
@@ -114,7 +117,7 @@ def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
114117

115118
ds = _drop_field_data(ds)
116119
grid = xgcm.Grid(ds, **xgcm_kwargs)
117-
return cls(grid, mesh=mesh)
120+
return cls(grid, mesh_type=mesh_type)
118121

119122
@property
120123
def axes(self) -> list[_XGRID_AXES]:

tests/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100)
7474
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
7575
ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim)
7676
ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim)
77-
grid = XGrid.from_dataset(ds)
78-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
79-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
77+
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
78+
U = Field("U", ds["U"], grid, interp_method=XLinear)
79+
V = Field("V", ds["V"], grid, interp_method=XLinear)
8080

8181
UV = VectorField("UV", U, V)
8282
return FieldSet([U, V, UV])

tests/v4/test_advection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def test_advection_zonal(mesh_type, npart=10):
3737
"""Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`."""
3838
ds = simple_UV_dataset(mesh_type=mesh_type)
3939
ds["U"].data[:] = 1.0
40-
grid = XGrid.from_dataset(ds)
41-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
42-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
40+
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
41+
U = Field("U", ds["U"], grid, interp_method=XLinear)
42+
V = Field("V", ds["V"], grid, interp_method=XLinear)
4343
UV = VectorField("UV", U, V)
4444
fieldset = FieldSet([U, V, UV])
4545

@@ -208,9 +208,9 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
208208

209209
def test_radialrotation(npart=10):
210210
ds = radial_rotation_dataset()
211-
grid = XGrid.from_dataset(ds)
212-
U = parcels.Field("U", ds["U"], grid, mesh_type="flat", interp_method=XLinear)
213-
V = parcels.Field("V", ds["V"], grid, mesh_type="flat", interp_method=XLinear)
211+
grid = XGrid.from_dataset(ds, mesh_type="flat")
212+
U = parcels.Field("U", ds["U"], grid, interp_method=XLinear)
213+
V = parcels.Field("V", ds["V"], grid, interp_method=XLinear)
214214
UV = parcels.VectorField("UV", U, V)
215215
fieldset = parcels.FieldSet([U, V, UV])
216216

tests/v4/test_diffusion.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ def test_fieldKh_Brownian(mesh_type):
2424
ds = simple_UV_dataset(dims=(2, 1, 2, 2), mesh_type=mesh_type)
2525
ds["lon"].data = np.array([-1e6, 1e6])
2626
ds["lat"].data = np.array([-1e6, 1e6])
27-
grid = XGrid.from_dataset(ds)
28-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
29-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
27+
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
28+
U = Field("U", ds["U"], grid, interp_method=XLinear)
29+
V = Field("V", ds["V"], grid, interp_method=XLinear)
3030
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal))
3131
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional))
32-
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
33-
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
32+
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
33+
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
3434
UV = VectorField("UV", U, V)
3535
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
3636

@@ -61,18 +61,18 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh_type, kernel):
6161
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
6262
ds["lon"].data = np.linspace(-1e6, 1e6, xdim)
6363
ds["lat"].data = np.linspace(-1e6, 1e6, ydim)
64-
grid = XGrid.from_dataset(ds)
65-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
66-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
64+
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
65+
U = Field("U", ds["U"], grid, interp_method=XLinear)
66+
V = Field("V", ds["V"], grid, interp_method=XLinear)
6767

6868
Kh = np.zeros((ydim, xdim), dtype=np.float32)
6969
for x in range(xdim):
7070
Kh[:, x] = np.tanh(ds["lon"][x] / ds["lon"][-1] * 10.0) * xdim / 2.0 + xdim / 2.0 + 100.0
7171

7272
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh))
7373
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh))
74-
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
75-
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
74+
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
75+
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
7676
UV = VectorField("UV", U, V)
7777
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
7878
fieldset.add_constant("dres", float(ds["lon"][1] - ds["lon"][0]))

tests/v4/test_field.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ def test_field_init_param_types():
2525
with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"):
2626
Field(name="test", data=data["data_g"], grid=123)
2727

28-
with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"):
29-
Field(name="test", data=data["data_g"], grid=grid, mesh_type="invalid")
30-
3128

3229
@pytest.mark.parametrize(
3330
"data,grid",
@@ -107,7 +104,7 @@ def test_field_init_fail_on_float_time_dim():
107104
)
108105
def test_field_time_interval(data, grid):
109106
"""Test creating a field."""
110-
field = Field(name="test_field", data=data, grid=grid, mesh_type="flat")
107+
field = Field(name="test_field", data=data, grid=grid)
111108
assert field.time_interval.left == np.datetime64("2000-01-01")
112109
assert field.time_interval.right == np.datetime64("2001-01-01")
113110

0 commit comments

Comments
 (0)