Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import xarray as xr
import xgcm

from parcels._core.converters import Geographic, GeographicPolar
from parcels._core.field import Field, VectorField
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import get_datetime_type_calendar
Expand Down Expand Up @@ -277,15 +276,13 @@ def from_fesom2(ds: ux.UxDataset):
raise ValueError(
f"Dataset missing one of the required dimensions 'time', 'nz', or 'nz1'. Found dimensions {ds_dims}"
)
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="spherical")
ds = _discover_fesom2_U_and_V(ds)

fields = {}
if "U" in ds.data_vars and "V" in ds.data_vars:
fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
fields["U"].units = GeographicPolar()
fields["V"].units = Geographic()

if "W" in ds.data_vars:
fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))
Expand Down
6 changes: 3 additions & 3 deletions src/parcels/_core/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Expand All @@ -30,8 +30,8 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
mesh : str, optional
The type of mesh used for the grid. Either "flat" (default) or "spherical".
mesh : str
The type of mesh used for the grid. Either "flat" or "spherical".
"""
self.uxgrid = grid
if not isinstance(z, ux.UxDataArray):
Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class XGrid(BaseGrid):

"""

def __init__(self, grid: xgcm.Grid, mesh="flat"):
def __init__(self, grid: xgcm.Grid, mesh):
self.xgcm_grid = grid
self._mesh = mesh
self._spatialhash = None
Expand All @@ -124,7 +124,7 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
self._ds = ds

@classmethod
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None):
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
if xgcm_kwargs is None:
xgcm_kwargs = {}
Expand Down
16 changes: 8 additions & 8 deletions tests/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_advection_zonal_periodic():
halo.XG.values = ds.XG.values[1] + 2
ds = xr.concat([ds, halo], dim="XG")

grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
Expand All @@ -103,7 +103,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
"""Flat 2D zonal flow that increases linearly with z from 0 m/s to 1 m/s."""
ds = simple_UV_dataset(mesh="flat")
ds["U"].data[:] = 1.0
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
V = Field("V", ds["V"], grid, interp_method=XLinear)
Expand All @@ -121,7 +121,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
ds = simple_UV_dataset(mesh="flat")
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
V = Field("V", ds["V"], grid, interp_method=XLinear)
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
if w:
ds["W"] = (["time", "depth", "YG", "XG"], W)

grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
fields = [U, V, VectorField("UV", U, V)]
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_radialrotation(npart=10):
)
def test_moving_eddy(kernel, rtol):
ds = moving_eddy_dataset()
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
if kernel in [AdvectionRK2_3D, AdvectionRK4_3D]:
Expand Down Expand Up @@ -315,7 +315,7 @@ def truth_moving(x_0, y_0, t):
)
def test_decaying_moving_eddy(kernel, rtol):
ds = decaying_moving_eddy_dataset()
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
Expand Down Expand Up @@ -363,7 +363,7 @@ def truth_moving(x_0, y_0, t):
def test_stommelgyre_fieldset(kernel, rtol, grid_type):
npart = 2
ds = stommel_gyre_dataset(grid_type=grid_type)
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
vector_interp_method = None if grid_type == "A" else CGrid_Velocity
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
Expand Down Expand Up @@ -404,7 +404,7 @@ def UpdateP(particles, fieldset): # pragma: no cover
def test_peninsula_fieldset(kernel, rtol, grid_type):
npart = 2
ds = peninsula_dataset(grid_type=grid_type)
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
P = Field("P", ds["P"], grid, interp_method=XLinear)
Expand Down
19 changes: 10 additions & 9 deletions tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def test_field_init_param_types():
data = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(data)
grid = XGrid.from_dataset(data, mesh="flat")

with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
Field(name=123, data=data["data_g"], grid=grid, interp_method=XLinear)
Expand Down Expand Up @@ -47,14 +47,15 @@ def test_field_init_param_types():
[
pytest.param(
ux.UxDataArray(),
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="uxdata-grid",
),
pytest.param(
xr.DataArray(),
UxGrid(
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
z=datasets_unstructured["stommel_gyre_delaunay"].coords["nz"],
mesh="flat",
),
id="xarray-uxgrid",
),
Expand All @@ -75,7 +76,7 @@ def test_field_incompatible_combination(data, grid):
[
pytest.param(
datasets_structured["ds_2d_left"]["data_g"],
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="ds_2d_left",
), # TODO: Perhaps this test should be expanded to cover more datasets?
],
Expand Down Expand Up @@ -106,7 +107,7 @@ def test_field_init_fail_on_float_time_dim():
)

data = ds["data_g"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
with pytest.raises(
ValueError,
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?",
Expand All @@ -124,7 +125,7 @@ def test_field_init_fail_on_float_time_dim():
[
pytest.param(
datasets_structured["ds_2d_left"]["data_g"],
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="ds_2d_left",
),
],
Expand All @@ -143,7 +144,7 @@ def test_vectorfield_init_different_time_intervals():

def test_field_invalid_interpolator():
ds = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")

def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
return 0.0
Expand All @@ -160,7 +161,7 @@ def invalid_interpolator_wrong_signature(particle_positions, grid_positions, inv

def test_vectorfield_invalid_interpolator():
ds = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")

def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
return 0.0
Expand Down Expand Up @@ -194,7 +195,7 @@ def test_field_unstructured_z_linear():
for k, z in enumerate(ds.coords["nz"]):
ds["W"].values[:, k, :] = z

grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)

Expand Down Expand Up @@ -232,7 +233,7 @@ def test_field_unstructured_z_linear():
def test_field_constant_in_time():
"""Tests field evaluation for a field with no time interval (i.e., constant in time)."""
ds = datasets_unstructured["stommel_gyre_delaunay"]
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.fixture
def field_cone():
ds = datasets["2d_left_unrolled_cone"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
field = Field(
name="test_field",
data=ds["data_g"],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def field():
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
},
)
return Field("U", ds["U"], XGrid.from_dataset(ds), interp_method=XLinear)
return Field("U", ds["U"], XGrid.from_dataset(ds, mesh="flat"), interp_method=XLinear)


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remove duplicates
"""Fixture to create a FieldSet object for testing."""
ds = datasets["ds_2d_left"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U_A_grid"], grid, XLinear)
V = Field("V", ds["V_A_grid"], grid, XLinear)
UV = VectorField("UV", U, V)
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset):
def test_write_fieldset_without_time(tmp_zarrfile):
ds = peninsula_dataset() # DataSet without time
assert "time" not in ds.dims
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
fieldset = FieldSet([Field("U", ds["U"], grid, XLinear)])

pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_spatialhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

def test_spatialhash_init():
ds = datasets["2d_left_rotated"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
spatialhash = grid.get_spatial_hash()
assert spatialhash is not None


def test_invalid_positions():
ds = datasets["2d_left_rotated"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")

j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
assert np.all(j == -3)
Expand All @@ -22,7 +22,7 @@ def test_invalid_positions():

def test_mixed_positions():
ds = datasets["2d_left_rotated"]
grid = XGrid.from_dataset(ds)
grid = XGrid.from_dataset(ds, mesh="flat")
lat = grid.lat.mean()
lon = grid.lon.mean()
y = [lat, np.nan]
Expand Down
19 changes: 10 additions & 9 deletions tests/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def uv_fesom_channel(ds_fesom_channel) -> VectorField:
U=Field(
name="U",
data=ds_fesom_channel.U,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
interp_method=UXPiecewiseConstantFace,
),
V=Field(
name="V",
data=ds_fesom_channel.V,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
interp_method=UXPiecewiseConstantFace,
),
)
Expand All @@ -58,19 +58,19 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
U=Field(
name="U",
data=ds_fesom_channel.U,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
interp_method=UXPiecewiseConstantFace,
),
V=Field(
name="V",
data=ds_fesom_channel.V,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
interp_method=UXPiecewiseConstantFace,
),
W=Field(
name="W",
data=ds_fesom_channel.W,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"], mesh="flat"),
interp_method=UXPiecewiseLinearNode,
),
)
Expand Down Expand Up @@ -112,13 +112,14 @@ def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
Since the underlying data is constant, we can check that the values are as expected.
"""
ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"]
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="flat")
UVW = VectorField(
name="UVW",
U=Field(name="U", data=ds.U, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds.V, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds.W, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode),
U=Field(name="U", data=ds.U, grid=grid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds.V, grid=grid, interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode),
)
P = Field(name="p", data=ds.p, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode)
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseLinearNode)
fieldset = FieldSet([UVW, P, UVW.U, UVW.V, UVW.W])

assert fieldset.U.eval(time=ds.time[0].values, z=[1.0], y=[30.0], x=[30.0], applyConversion=False) == 1.0
Expand Down
13 changes: 10 additions & 3 deletions tests/test_uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@

@pytest.mark.parametrize("uxds", [pytest.param(uxds, id=key) for key, uxds in uxdatasets.items()])
def test_uxgrid_init_on_generic_datasets(uxds):
UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")


@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
def test_uxgrid_axes(uxds):
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")
assert grid.axes == ["Z", "FACE"]


@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
@pytest.mark.parametrize("mesh", ["flat", "spherical"])
def test_uxgrid_mesh(uxds, mesh):
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh=mesh)
assert grid._mesh == mesh


@pytest.mark.parametrize("uxds", [uxdatasets["stommel_gyre_delaunay"]])
def test_xgrid_get_axis_dim(uxds):
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"])
grid = UxGrid(uxds.uxgrid, z=uxds.coords["nz"], mesh="flat")

assert grid.get_axis_dim("FACE") == 721
assert grid.get_axis_dim("Z") == 2
Loading