Skip to content

Commit a1815d8

Browse files
Merge branch 'v4-dev' into support_1d_fields
2 parents fb6fcb1 + feb1a46 commit a1815d8

File tree

4 files changed

+24
-17
lines changed

4 files changed

+24
-17
lines changed

src/parcels/_core/fieldset.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
1818
from parcels._logger import logger
1919
from parcels._typing import Mesh
20+
from parcels.interpolators import XConstantField
2021

2122
if TYPE_CHECKING:
2223
from parcels._core.basegrid import BaseGrid
@@ -116,7 +117,7 @@ def add_field(self, field: Field, name: str | None = None):
116117

117118
self.fields[name] = field
118119

119-
def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
120+
def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
120121
"""Wrapper function to add a Field that is constant in space,
121122
useful e.g. when using constant horizontal diffusivity
122123
@@ -134,16 +135,15 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
134135
correction for zonal velocity U near the poles.
135136
2. flat: No conversion, lat/lon are assumed to be in m.
136137
"""
137-
ds = xr.Dataset({name: (["time", "lat", "lon", "depth"], np.full((1, 1, 1, 1), value))})
138-
grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS))
139-
self.add_field(
140-
Field(
141-
name,
142-
ds[name],
143-
grid,
144-
interp_method=None, # TODO : Need to define an interpolation method for constants
145-
)
138+
ds = xr.Dataset(
139+
{name: (["lat", "lon"], np.full((1, 1), value))},
140+
coords={"lat": (["lat"], [0], {"axis": "Y"}), "lon": (["lon"], [0], {"axis": "X"})},
141+
)
142+
xgrid = xgcm.Grid(
143+
ds, coords={"X": {"left": "lon"}, "Y": {"left": "lat"}}, autoparse_metadata=False, **_DEFAULT_XGCM_KWARGS
146144
)
145+
grid = XGrid(xgrid, mesh=mesh)
146+
self.add_field(Field(name, ds[name], grid, interp_method=XConstantField))
147147

148148
def add_constant(self, name, value):
149149
"""Add a constant to the FieldSet. Note that all constants are

src/parcels/interpolators.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"CGrid_Velocity",
2121
"UXPiecewiseConstantFace",
2222
"UXPiecewiseLinearNode",
23+
"XConstantField",
2324
"XFreeslip",
2425
"XLinear",
2526
"XLinearInvdistLandTracer",
@@ -136,6 +137,15 @@ def XLinear(
136137
return value.compute() if is_dask_collection(value) else value
137138

138139

140+
def XConstantField(
141+
particle_positions: dict[str, float | np.ndarray],
142+
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
143+
field: Field,
144+
):
145+
"""Returning the single value of a Constant Field (with a size=(1,1,1,1) array)"""
146+
return field.data[0, 0, 0, 0].values
147+
148+
139149
def CGrid_Velocity(
140150
particle_positions: dict[str, float | np.ndarray],
141151
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
@@ -599,7 +609,7 @@ def XLinearInvdistLandTracer(
599609
all_land_mask = nb_land == 4 * lenZ * lenT
600610
values[all_land_mask] = 0.0
601611

602-
not_all_land = ~all_land_mask
612+
not_all_land = np.asarray(~all_land_mask, dtype=bool)
603613
if np.any(not_all_land):
604614
i_grid = np.arange(2)[None, None, None, :, None]
605615
j_grid = np.arange(2)[None, None, :, None, None]

tests/test_diffusion.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@ def test_fieldKh_Brownian(mesh):
2323
grid = XGrid.from_dataset(ds, mesh=mesh)
2424
U = Field("U", ds["U"], grid, interp_method=XLinear)
2525
V = Field("V", ds["V"], grid, interp_method=XLinear)
26-
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal))
27-
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional))
28-
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
29-
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
3026
UV = VectorField("UV", U, V)
31-
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
27+
fieldset = FieldSet([U, V, UV])
28+
fieldset.add_constant_field("Kh_zonal", kh_zonal, mesh=mesh)
29+
fieldset.add_constant_field("Kh_meridional", kh_meridional, mesh=mesh)
3230

3331
npart = 100
3432
runtime = np.timedelta64(2, "h")

tests/test_fieldset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_fieldset_add_constant_field(fieldset):
5959
lat = ds["lat"].mean()
6060
lon = ds["lon"].mean()
6161

62-
pytest.xfail(reason="Not yet implemented interpolation.")
6362
assert fieldset.test_constant_field[time, z, lat, lon] == 1.0
6463

6564

0 commit comments

Comments
 (0)