Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _search_indices_curvilinear_2d(
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))

(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh_type)
it += 1
if it > maxIterSearch:
print(f"Correct cell not found after {maxIterSearch} iterations")
Expand Down Expand Up @@ -408,11 +408,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
return (zeta, eta, xsi, zi, yi, xi)


def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh_type: str):
xi = np.where(xi < 0, (xdim - 2) if mesh_type == "spherical" else 0, xi)
xi = np.where(xi > xdim - 2, 0 if mesh_type == "spherical" else (xdim - 2), xi)

xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
xi = np.where(yi > ydim - 2, xdim - xi if mesh_type == "spherical" else xi, xi)

yi = np.where(yi < 0, 0, yi)
yi = np.where(yi > ydim - 2, ydim - 2, yi)
Expand Down
17 changes: 3 additions & 14 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@

from parcels._core.utils.time import TimeInterval
from parcels._reprs import default_repr
from parcels._typing import (
Mesh,
VectorType,
assert_valid_mesh,
)
from parcels._typing import VectorType
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
from parcels.particle import KernelParticle
from parcels.tools.converters import (
Expand Down Expand Up @@ -114,7 +110,6 @@ def __init__(
name: str,
data: xr.DataArray | ux.UxDataArray,
grid: UxGrid | XGrid,
mesh_type: Mesh = "flat",
interp_method: Callable | None = None,
):
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
Expand All @@ -126,8 +121,6 @@ def __init__(
if not isinstance(grid, (UxGrid, XGrid)):
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")

assert_valid_mesh(mesh_type)

_assert_compatible_combination(data, grid)

if isinstance(grid, XGrid):
Expand Down Expand Up @@ -155,8 +148,6 @@ def __init__(
e.add_note(f"Error validating field {name!r}.")
raise e

self._mesh_type = mesh_type

# Setting the interpolation method dynamically
if interp_method is None:
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
Expand All @@ -166,12 +157,10 @@ def __init__(

self.igrid = -1 # Default the grid index to -1

if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
if self.grid._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
self.units = UnitConverter()
elif self._mesh_type == "spherical":
elif self.grid._mesh_type == "spherical":
self.units = unitconverters_map[self.name]
else:
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")

if self.data.shape[0] > 1:
if "time" not in self.data.coords:
Expand Down
2 changes: 1 addition & 1 deletion parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth
stacklevel=2,
)
self.fieldset.add_constant("RK45_tol", 10)
if self.fieldset.U.grid.mesh == "spherical":
if self.fieldset.U.grid._mesh_type == "spherical":
self.fieldset.RK45_tol /= (
1852 * 60
) # TODO does not account for zonal variation in meter -> degree conversion
Expand Down
2 changes: 1 addition & 1 deletion parcels/spatialhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _hash_index2d(self, coords):
as the source grid coordinates
"""
# Wrap longitude to [-180, 180]
if self._source_grid.mesh == "spherical":
if self._source_grid._mesh_type == "spherical":
lon = (coords[:, 1] + 180.0) % (360.0) - 180.0
else:
lon = coords[:, 1]
Expand Down
8 changes: 7 additions & 1 deletion parcels/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import uxarray as ux

from parcels._typing import assert_valid_mesh
from parcels.spatialhash import _barycentric_coordinates
from parcels.tools.statuscodes import FieldOutOfBoundError
from parcels.xgrid import _search_1d_array
Expand All @@ -20,7 +21,7 @@ class UxGrid(BaseGrid):
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh_type="flat") -> UxGrid:
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Expand All @@ -32,13 +33,18 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
mesh_type : str, optional
The type of mesh used for the grid. Either "flat" (default) or "spherical".
"""
self.uxgrid = grid
if not isinstance(z, ux.UxDataArray):
raise TypeError("z must be an instance of ux.UxDataArray")
if z.ndim != 1:
raise ValueError("z must be a 1D array of vertical coordinates")
self.z = z
self._mesh_type = mesh_type

assert_valid_mesh(mesh_type)

@property
def depth(self):
Expand Down
11 changes: 7 additions & 4 deletions parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from parcels import xgcm
from parcels._index_search import _search_indices_curvilinear_2d
from parcels._typing import assert_valid_mesh
from parcels.basegrid import BaseGrid
from parcels.spatialhash import SpatialHash

Expand Down Expand Up @@ -95,17 +96,19 @@ class XGrid(BaseGrid):

"""

def __init__(self, grid: xgcm.Grid, mesh="flat"):
def __init__(self, grid: xgcm.Grid, mesh_type="flat"):
self.xgcm_grid = grid
self.mesh = mesh
self._mesh_type = mesh_type
self._spatialhash = None
ds = grid._ds

if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)

assert_valid_mesh(mesh_type)

@classmethod
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
def from_dataset(cls, ds: xr.Dataset, mesh_type="flat", xgcm_kwargs=None):
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
if xgcm_kwargs is None:
xgcm_kwargs = {}
Expand All @@ -114,7 +117,7 @@ def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):

ds = _drop_field_data(ds)
grid = xgcm.Grid(ds, **xgcm_kwargs)
return cls(grid, mesh=mesh)
return cls(grid, mesh_type=mesh_type)

@property
def axes(self) -> list[_XGRID_AXES]:
Expand Down
6 changes: 3 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100)
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim)
ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim)
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)

UV = VectorField("UV", U, V)
return FieldSet([U, V, UV])
Expand Down
12 changes: 6 additions & 6 deletions tests/v4/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def test_advection_zonal(mesh_type, npart=10):
"""Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`."""
ds = simple_UV_dataset(mesh_type=mesh_type)
ds["U"].data[:] = 1.0
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV])

Expand Down Expand Up @@ -208,9 +208,9 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read

def test_radialrotation(npart=10):
ds = radial_rotation_dataset()
grid = XGrid.from_dataset(ds)
U = parcels.Field("U", ds["U"], grid, mesh_type="flat", interp_method=XLinear)
V = parcels.Field("V", ds["V"], grid, mesh_type="flat", interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh_type="flat")
U = parcels.Field("U", ds["U"], grid, interp_method=XLinear)
V = parcels.Field("V", ds["V"], grid, interp_method=XLinear)
UV = parcels.VectorField("UV", U, V)
fieldset = parcels.FieldSet([U, V, UV])

Expand Down
20 changes: 10 additions & 10 deletions tests/v4/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def test_fieldKh_Brownian(mesh_type):
ds = simple_UV_dataset(dims=(2, 1, 2, 2), mesh_type=mesh_type)
ds["lon"].data = np.array([-1e6, 1e6])
ds["lat"].data = np.array([-1e6, 1e6])
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal))
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional))
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])

Expand Down Expand Up @@ -61,18 +61,18 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh_type, kernel):
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
ds["lon"].data = np.linspace(-1e6, 1e6, xdim)
ds["lat"].data = np.linspace(-1e6, 1e6, ydim)
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
grid = XGrid.from_dataset(ds, mesh_type=mesh_type)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)

Kh = np.zeros((ydim, xdim), dtype=np.float32)
for x in range(xdim):
Kh[:, x] = np.tanh(ds["lon"][x] / ds["lon"][-1] * 10.0) * xdim / 2.0 + xdim / 2.0 + 100.0

ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh))
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh))
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear)
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
fieldset.add_constant("dres", float(ds["lon"][1] - ds["lon"][0]))
Expand Down
5 changes: 1 addition & 4 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def test_field_init_param_types():
with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"):
Field(name="test", data=data["data_g"], grid=123)

with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"):
Field(name="test", data=data["data_g"], grid=grid, mesh_type="invalid")


@pytest.mark.parametrize(
"data,grid",
Expand Down Expand Up @@ -107,7 +104,7 @@ def test_field_init_fail_on_float_time_dim():
)
def test_field_time_interval(data, grid):
"""Test creating a field."""
field = Field(name="test_field", data=data, grid=grid, mesh_type="flat")
field = Field(name="test_field", data=data, grid=grid)
assert field.time_interval.left == np.datetime64("2000-01-01")
assert field.time_interval.right == np.datetime64("2001-01-01")

Expand Down
44 changes: 22 additions & 22 deletions tests/v4/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
@pytest.fixture
def fieldset() -> FieldSet:
"""Fixture to create a FieldSet object for testing."""
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U (A grid)"], grid, mesh_type="flat")
V = Field("V", ds["V (A grid)"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds, mesh_type="flat")
U = Field("U", ds["U (A grid)"], grid)
V = Field("V", ds["V (A grid)"], grid)
UV = VectorField("UV", U, V)

return FieldSet(
Expand Down Expand Up @@ -55,8 +55,8 @@ def test_fieldset_add_constant_field(fieldset):


def test_fieldset_add_field(fieldset):
grid = XGrid.from_dataset(ds)
field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds, mesh_type="flat")
field = Field("test_field", ds["U (A grid)"], grid)
fieldset.add_field(field)
assert fieldset.test_field == field

Expand All @@ -68,8 +68,8 @@ def test_fieldset_add_field_wrong_type(fieldset):


def test_fieldset_add_field_already_exists(fieldset):
grid = XGrid.from_dataset(ds)
field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds, mesh_type="flat")
field = Field("test_field", ds["U (A grid)"], grid)
fieldset.add_field(field, "test_field")
with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"):
fieldset.add_field(field, "test_field")
Expand All @@ -87,10 +87,10 @@ def test_fieldset_gridset(fieldset):

@pytest.mark.parametrize("ds", [pytest.param(ds, id=k) for k, ds in datasets_structured.items()])
def test_fieldset_from_structured_generic_datasets(ds):
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh_type="flat")
fields = []
for var in ds.data_vars:
fields.append(Field(var, ds[var], grid, mesh_type="flat"))
fields.append(Field(var, ds[var], grid))

fieldset = FieldSet(fields)

Expand All @@ -105,13 +105,13 @@ def test_fieldset_gridset_multiple_grids(): ...


def test_fieldset_time_interval():
grid1 = XGrid.from_dataset(ds)
field1 = Field("field1", ds["U (A grid)"], grid1, mesh_type="flat")
grid1 = XGrid.from_dataset(ds, mesh_type="flat")
field1 = Field("field1", ds["U (A grid)"], grid1)

ds2 = ds.copy()
ds2["time"] = (ds2["time"].dims, ds2["time"].data + np.timedelta64(timedelta(days=1)), ds2["time"].attrs)
grid2 = XGrid.from_dataset(ds2)
field2 = Field("field2", ds2["U (A grid)"], grid2, mesh_type="flat")
grid2 = XGrid.from_dataset(ds2, mesh_type="flat")
field2 = Field("field2", ds2["U (A grid)"], grid2)

fieldset = FieldSet([field1, field2])
fieldset.add_constant_field("constant_field", 1.0)
Expand All @@ -136,9 +136,9 @@ def test_fieldset_init_incompatible_calendars():
ds1["time"].attrs,
)

grid = XGrid.from_dataset(ds1)
U = Field("U", ds1["U (A grid)"], grid, mesh_type="flat")
V = Field("V", ds1["V (A grid)"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds1, mesh_type="flat")
U = Field("U", ds1["U (A grid)"], grid)
V = Field("V", ds1["V (A grid)"], grid)
UV = VectorField("UV", U, V)

ds2 = ds.copy()
Expand All @@ -147,8 +147,8 @@ def test_fieldset_init_incompatible_calendars():
xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True),
ds2["time"].attrs,
)
grid2 = XGrid.from_dataset(ds2)
incompatible_calendar = Field("test", ds2["data_g"], grid2, mesh_type="flat")
grid2 = XGrid.from_dataset(ds2, mesh_type="flat")
incompatible_calendar = Field("test", ds2["data_g"], grid2)

with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
FieldSet([U, V, UV, incompatible_calendar])
Expand All @@ -161,8 +161,8 @@ def test_fieldset_add_field_incompatible_calendars(fieldset):
xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True),
ds_test["time"].attrs,
)
grid = XGrid.from_dataset(ds_test)
field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds_test, mesh_type="flat")
field = Field("test_field", ds_test["data_g"], grid)

with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
fieldset.add_field(field, "test_field")
Expand All @@ -173,8 +173,8 @@ def test_fieldset_add_field_incompatible_calendars(fieldset):
np.linspace(0, 100, T_structured, dtype="timedelta64[s]"),
ds_test["time"].attrs,
)
grid = XGrid.from_dataset(ds_test)
field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat")
grid = XGrid.from_dataset(ds_test, mesh_type="flat")
field = Field("test_field", ds_test["data_g"], grid)

with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
fieldset.add_field(field, "test_field")
Expand Down
Loading
Loading