Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions docs/examples/tutorial_stommel_uxarray.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"\n",
"A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n",
"\n",
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions."
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions. Setting the `mesh` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
]
},
{
Expand All @@ -99,7 +99,7 @@
"source": [
"from parcels.uxgrid import UxGrid\n",
"\n",
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n",
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"], mesh=\"spherical\")\n",
"# You can view the uxgrid object with the following command:\n",
"grid.uxgrid"
]
Expand All @@ -112,7 +112,7 @@
"\n",
"In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n",
"\n",
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method."
]
},
{
Expand All @@ -128,21 +128,18 @@
" name=\"U\",\n",
" data=ds.U,\n",
" grid=grid,\n",
" mesh_type=\"spherical\",\n",
" interp_method=UXPiecewiseConstantFace,\n",
")\n",
"V = Field(\n",
" name=\"V\",\n",
" data=ds.V,\n",
" grid=grid,\n",
" mesh_type=\"spherical\",\n",
" interp_method=UXPiecewiseConstantFace,\n",
")\n",
"P = Field(\n",
" name=\"P\",\n",
" data=ds.p,\n",
" grid=grid,\n",
" mesh_type=\"spherical\",\n",
" interp_method=UXPiecewiseConstantFace,\n",
")"
]
Expand Down
4 changes: 2 additions & 2 deletions parcels/_datasets/structured/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import xarray as xr


def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"):
max_lon = 180.0 if mesh_type == "spherical" else 1e6
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
max_lon = 180.0 if mesh == "spherical" else 1e6

return xr.Dataset(
{"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))},
Expand Down
13 changes: 7 additions & 6 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from parcels._typing import (
GridIndexingType,
InterpMethodOption,
Mesh,
)
from parcels.tools.statuscodes import (
FieldOutOfBoundError,
Expand Down Expand Up @@ -174,7 +175,7 @@ def _search_indices_rectilinear(
_raise_field_out_of_bound_error(z, y, x)

if field.xdim > 1:
if field._mesh_type != "spherical":
if field._mesh != "spherical":
lon_index = field.lon < x
if lon_index.all():
xi = len(field.lon) - 2
Expand Down Expand Up @@ -305,7 +306,7 @@ def _search_indices_curvilinear_2d(
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))

(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
it += 1
if it > maxIterSearch:
print(f"Correct cell not found after {maxIterSearch} iterations")
Expand Down Expand Up @@ -408,11 +409,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
return (zeta, eta, xsi, zi, yi, xi)


def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)

xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)

yi = np.where(yi < 0, 0, yi)
yi = np.where(yi > ydim - 2, ydim - 2, yi)
Expand Down
19 changes: 4 additions & 15 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@

from parcels._core.utils.time import TimeInterval
from parcels._reprs import default_repr
from parcels._typing import (
Mesh,
VectorType,
assert_valid_mesh,
)
from parcels._typing import VectorType
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
from parcels.particle import KernelParticle
from parcels.tools.converters import (
Expand Down Expand Up @@ -86,7 +82,7 @@ class Field:
-----
The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata.
* dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon])
* attrs: (location, mesh, mesh_type)
* attrs: (location, mesh, mesh)

When using a xarray.DataArray object,
* The xarray.DataArray object must have the "location" and "mesh" attributes set.
Expand Down Expand Up @@ -114,7 +110,6 @@ def __init__(
name: str,
data: xr.DataArray | ux.UxDataArray,
grid: UxGrid | XGrid,
mesh_type: Mesh = "flat",
interp_method: Callable | None = None,
):
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
Expand All @@ -126,8 +121,6 @@ def __init__(
if not isinstance(grid, (UxGrid, XGrid)):
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")

assert_valid_mesh(mesh_type)

_assert_compatible_combination(data, grid)

if isinstance(grid, XGrid):
Expand Down Expand Up @@ -155,8 +148,6 @@ def __init__(
e.add_note(f"Error validating field {name!r}.")
raise e

self._mesh_type = mesh_type

# Setting the interpolation method dynamically
if interp_method is None:
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
Expand All @@ -166,12 +157,10 @@ def __init__(

self.igrid = -1 # Default the grid index to -1

if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
if self.grid._mesh == "flat" or (self.name not in unitconverters_map.keys()):
self.units = UnitConverter()
elif self._mesh_type == "spherical":
elif self.grid._mesh == "spherical":
self.units = unitconverters_map[self.name]
else:
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")

if self.data.shape[0] > 1:
if "time" not in self.data.coords:
Expand Down
2 changes: 1 addition & 1 deletion parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth
stacklevel=2,
)
self.fieldset.add_constant("RK45_tol", 10)
if self.fieldset.U.grid.mesh == "spherical":
if self.fieldset.U.grid._mesh == "spherical":
self.fieldset.RK45_tol /= (
1852 * 60
) # TODO does not account for zonal variation in meter -> degree conversion
Expand Down
2 changes: 1 addition & 1 deletion parcels/spatialhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _hash_index2d(self, coords):
as the source grid coordinates
"""
# Wrap longitude to [-180, 180]
if self._source_grid.mesh == "spherical":
if self._source_grid._mesh == "spherical":
lon = (coords[:, 1] + 180.0) % (360.0) - 180.0
else:
lon = coords[:, 1]
Expand Down
8 changes: 7 additions & 1 deletion parcels/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import uxarray as ux

from parcels._typing import assert_valid_mesh
from parcels.spatialhash import _barycentric_coordinates
from parcels.tools.statuscodes import FieldOutOfBoundError
from parcels.xgrid import _search_1d_array
Expand All @@ -20,7 +21,7 @@ class UxGrid(BaseGrid):
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Expand All @@ -32,13 +33,18 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
mesh : str, optional
The type of mesh used for the grid. Either "flat" (default) or "spherical".
"""
self.uxgrid = grid
if not isinstance(z, ux.UxDataArray):
raise TypeError("z must be an instance of ux.UxDataArray")
if z.ndim != 1:
raise ValueError("z must be a 1D array of vertical coordinates")
self.z = z
self._mesh = mesh

assert_valid_mesh(mesh)

@property
def depth(self):
Expand Down
5 changes: 4 additions & 1 deletion parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from parcels import xgcm
from parcels._index_search import _search_indices_curvilinear_2d
from parcels._typing import assert_valid_mesh
from parcels.basegrid import BaseGrid
from parcels.spatialhash import SpatialHash

Expand Down Expand Up @@ -97,13 +98,15 @@ class XGrid(BaseGrid):

def __init__(self, grid: xgcm.Grid, mesh="flat"):
self.xgcm_grid = grid
self.mesh = mesh
self._mesh = mesh
self._spatialhash = None
ds = grid._ds

if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)

assert_valid_mesh(mesh)

@classmethod
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
Expand Down
12 changes: 6 additions & 6 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def create_fieldset_global(xdim=200, ydim=100):
return FieldSet.from_data(data, dimensions, mesh="flat")


def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100) -> FieldSet:
def create_fieldset_zeros_conversion(mesh="spherical", xdim=200, ydim=100) -> FieldSet:
"""Zero velocity field with lat and lon determined by a conversion factor."""
mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh=mesh)
ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim)
ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim)
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh=mesh)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)

UV = VectorField("UV", U, V)
return FieldSet([U, V, UV])
Expand Down
26 changes: 13 additions & 13 deletions tests/v4/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@
}


@pytest.mark.parametrize("mesh_type", ["spherical", "flat"])
def test_advection_zonal(mesh_type, npart=10):
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
def test_advection_zonal(mesh, npart=10):
"""Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`."""
ds = simple_UV_dataset(mesh_type=mesh_type)
ds = simple_UV_dataset(mesh=mesh)
ds["U"].data[:] = 1.0
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh=mesh)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV])

pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"))

if mesh_type == "spherical":
if mesh == "spherical":
assert (np.diff(pset.lon) > 1.0e-4).all()
else:
assert (np.diff(pset.lon) < 1.0e-4).all()
Expand All @@ -58,7 +58,7 @@ def periodicBC(particle, fieldset, time):


def test_advection_zonal_periodic():
ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh_type="flat")
ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh="flat")
ds["U"].data[:] = 0.1
ds["lon"].data = np.array([0, 2])
ds["lat"].data = np.array([0, 2])
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_advection_zonal_periodic():

def test_horizontal_advection_in_3D_flow(npart=10):
"""Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s."""
ds = simple_UV_dataset(mesh_type="flat")
ds = simple_UV_dataset(mesh="flat")
ds["U"].data[:] = 1.0
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, interp_method=XLinear)
Expand All @@ -105,7 +105,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
@pytest.mark.parametrize("direction", ["up", "down"])
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
ds = simple_UV_dataset(mesh_type="flat")
ds = simple_UV_dataset(mesh="flat")
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, interp_method=XLinear)
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
Expand Down Expand Up @@ -208,9 +208,9 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read

def test_radialrotation(npart=10):
ds = radial_rotation_dataset()
grid = XGrid.from_dataset(ds)
U = parcels.Field("U", ds["U"], grid, mesh_type="flat", interp_method=XLinear)
V = parcels.Field("V", ds["V"], grid, mesh_type="flat", interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh="flat")
U = parcels.Field("U", ds["U"], grid, interp_method=XLinear)
V = parcels.Field("V", ds["V"], grid, interp_method=XLinear)
UV = parcels.VectorField("UV", U, V)
fieldset = parcels.FieldSet([U, V, UV])

Expand Down
Loading
Loading