Skip to content

Commit 7696988

Browse files
Implementing new Interpolation API
1 parent 948caa2 commit 7696988

File tree

6 files changed

+52
-108
lines changed

6 files changed

+52
-108
lines changed

src/parcels/_core/field.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,13 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
216216
else:
217217
_ei = particles.ei[:, self.igrid]
218218

219-
tau, ti = _search_time_index(self, time)
220-
position = self.grid.search(z, y, x, ei=_ei)
219+
position = {"time": time, "z": z, "lat": y, "lon": x}
220+
position["T"] = _search_time_index(self, time)
221+
position.update(self.grid.search(z, y, x, ei=_ei))
221222
_update_particles_ei(particles, position, self)
222223
_update_particle_states_position(particles, position)
223224

224-
value = self._interp_method(self, ti, position, tau, time, z, y, x)
225+
value = self._interp_method(position, self)
225226

226227
_update_particle_states_interp_value(particles, value)
227228

@@ -300,20 +301,21 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
300301
else:
301302
_ei = particles.ei[:, self.igrid]
302303

303-
tau, ti = _search_time_index(self.U, time)
304-
position = self.grid.search(z, y, x, ei=_ei)
304+
position = {"time": time, "z": z, "lat": y, "lon": x}
305+
position["T"] = _search_time_index(self.U, time)
306+
position.update(self.grid.search(z, y, x, ei=_ei))
305307
_update_particles_ei(particles, position, self)
306308
_update_particle_states_position(particles, position)
307309

308310
if self._vector_interp_method is None:
309-
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
310-
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
311+
u = self.U._interp_method(position, self.U)
312+
v = self.V._interp_method(position, self.V)
311313
if "3D" in self.vector_type:
312-
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
314+
w = self.W._interp_method(position, self.W)
313315
else:
314316
w = 0.0
315317
else:
316-
(u, v, w) = self._vector_interp_method(self, ti, position, tau, time, z, y, x)
318+
(u, v, w) = self._vector_interp_method(position, self)
317319

318320
if applyConversion:
319321
u = self.U.units.to_target(u, z, y, x)

src/parcels/_core/index_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@ def _search_time_index(field: Field, time: datetime):
7575
if the sampled value is outside the time value range.
7676
"""
7777
if field.time_interval is None:
78-
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)
78+
return np.zeros(shape=time.shape, dtype=np.int32), np.zeros(shape=time.shape, dtype=np.float32)
7979

8080
if not field.time_interval.is_all_time_in_interval(time):
8181
_raise_time_extrapolation_error(time, field=None)
8282

8383
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
8484
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
85-
return np.atleast_1d(tau), np.atleast_1d(ti)
85+
return np.atleast_1d(ti), np.atleast_1d(tau)
8686

8787

8888
def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):

src/parcels/interpolators.py

Lines changed: 33 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,16 @@
3030

3131

3232
def ZeroInterpolator(
33-
field: Field,
34-
ti: int,
3533
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
36-
tau: np.float32 | np.float64,
37-
t: np.float32 | np.float64,
38-
z: np.float32 | np.float64,
39-
y: np.float32 | np.float64,
40-
x: np.float32 | np.float64,
34+
field: Field,
4135
) -> np.float32 | np.float64:
4236
"""Template function used for the signature check of the lateral interpolation methods."""
4337
return 0.0
4438

4539

4640
def ZeroInterpolator_Vector(
47-
vectorfield: VectorField,
48-
ti: int,
4941
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
50-
tau: np.float32 | np.float64,
51-
t: np.float32 | np.float64,
52-
z: np.float32 | np.float64,
53-
y: np.float32 | np.float64,
54-
x: np.float32 | np.float64,
42+
vectorfield: VectorField,
5543
) -> np.float32 | np.float64:
5644
"""Template function used for the signature check of the interpolation methods for velocity fields."""
5745
return 0.0
@@ -105,19 +93,14 @@ def _get_corner_data_Agrid(
10593

10694

10795
def XLinear(
108-
field: Field,
109-
ti: int,
11096
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
111-
tau: np.float32 | np.float64,
112-
t: np.float32 | np.float64,
113-
z: np.float32 | np.float64,
114-
y: np.float32 | np.float64,
115-
x: np.float32 | np.float64,
97+
field: Field,
11698
):
11799
"""Trilinear interpolation on a regular grid."""
118100
xi, xsi = position["X"]
119101
yi, eta = position["Y"]
120102
zi, zeta = position["Z"]
103+
ti, tau = position["T"]
121104

122105
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
123106
data = field.data
@@ -149,14 +132,8 @@ def XLinear(
149132

150133

151134
def CGrid_Velocity(
152-
vectorfield: VectorField,
153-
ti: int,
154135
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
155-
tau: np.float32 | np.float64,
156-
t: np.float32 | np.float64,
157-
z: np.float32 | np.float64,
158-
y: np.float32 | np.float64,
159-
x: np.float32 | np.float64,
136+
vectorfield: VectorField,
160137
):
161138
"""
162139
Interpolation kernel for velocity fields on a C-Grid.
@@ -166,6 +143,8 @@ def CGrid_Velocity(
166143
xi, xsi = position["X"]
167144
yi, eta = position["Y"]
168145
zi, zeta = position["Z"]
146+
ti, tau = position["T"]
147+
lon = position["lon"]
169148

170149
U = vectorfield.U.data
171150
V = vectorfield.V.data
@@ -180,8 +159,8 @@ def CGrid_Velocity(
180159
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
181160

182161
if grid._mesh == "spherical":
183-
px[0] = np.where(px[0] < x - 225, px[0] + 360, px[0])
184-
px[0] = np.where(px[0] > x + 225, px[0] - 360, px[0])
162+
px[0] = np.where(px[0] < lon - 225, px[0] + 360, px[0])
163+
px[0] = np.where(px[0] > lon + 225, px[0] - 360, px[0])
185164
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
186165
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
187166
c1 = i_u._geodetic_distance(
@@ -296,7 +275,7 @@ def CGrid_Velocity(
296275

297276
# check whether the grid conversion has been applied correctly
298277
xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
299-
u = np.where(np.abs((xx - x) / x) > 1e-4, np.nan, u)
278+
u = np.where(np.abs((xx - lon) / lon) > 1e-4, np.nan, u)
300279

301280
if vectorfield.W:
302281
data = vectorfield.W.data
@@ -348,14 +327,8 @@ def CGrid_Velocity(
348327

349328

350329
def CGrid_Tracer(
351-
field: Field,
352-
ti: int,
353330
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
354-
tau: np.float32 | np.float64,
355-
t: np.float32 | np.float64,
356-
z: np.float32 | np.float64,
357-
y: np.float32 | np.float64,
358-
x: np.float32 | np.float64,
331+
field: Field,
359332
):
360333
"""Interpolation kernel for tracer fields on a C-Grid.
361334
@@ -365,6 +338,7 @@ def CGrid_Tracer(
365338
xi, _ = position["X"]
366339
yi, _ = position["Y"]
367340
zi, _ = position["Z"]
341+
ti, tau = position["T"]
368342

369343
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
370344
data = field.data
@@ -403,31 +377,26 @@ def CGrid_Tracer(
403377

404378

405379
def _Spatialslip(
406-
vectorfield: VectorField,
407-
ti: int,
408380
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
409-
tau: np.float32 | np.float64,
410-
t: np.float32 | np.float64,
411-
z: np.float32 | np.float64,
412-
y: np.float32 | np.float64,
413-
x: np.float32 | np.float64,
381+
vectorfield: VectorField,
414382
a: np.float32,
415383
b: np.float32,
416384
):
417385
"""Helper function for spatial boundary condition interpolation for velocity fields."""
418386
xi, xsi = position["X"]
419387
yi, eta = position["Y"]
420388
zi, zeta = position["Z"]
389+
ti, tau = position["T"]
421390

422391
axis_dim = vectorfield.U.grid.get_axis_dim_mapping(vectorfield.U.data.dims)
423392
lenT = 2 if np.any(tau > 0) else 1
424393
lenZ = 2 if np.any(zeta > 0) else 1
425394
npart = len(xsi)
426395

427-
u = XLinear(vectorfield.U, ti, position, tau, t, z, y, x)
428-
v = XLinear(vectorfield.V, ti, position, tau, t, z, y, x)
396+
u = XLinear(position, vectorfield.U)
397+
v = XLinear(position, vectorfield.V)
429398
if vectorfield.W:
430-
w = XLinear(vectorfield.W, ti, position, tau, t, z, y, x)
399+
w = XLinear(position, vectorfield.W)
431400

432401
corner_dataU = _get_corner_data_Agrid(vectorfield.U.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
433402
corner_dataV = _get_corner_data_Agrid(vectorfield.V.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
@@ -519,42 +488,24 @@ def is_land(ti: int, zi: int, yi: int, xi: int):
519488

520489

521490
def XFreeslip(
522-
vectorfield: VectorField,
523-
ti: int,
524491
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
525-
tau: np.float32 | np.float64,
526-
t: np.float32 | np.float64,
527-
z: np.float32 | np.float64,
528-
y: np.float32 | np.float64,
529-
x: np.float32 | np.float64,
492+
vectorfield: VectorField,
530493
):
531494
"""Free-slip boundary condition interpolation for velocity fields."""
532-
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=1.0, b=0.0)
495+
return _Spatialslip(position, vectorfield, a=1.0, b=0.0)
533496

534497

535498
def XPartialslip(
536-
vectorfield: VectorField,
537-
ti: int,
538499
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
539-
tau: np.float32 | np.float64,
540-
t: np.float32 | np.float64,
541-
z: np.float32 | np.float64,
542-
y: np.float32 | np.float64,
543-
x: np.float32 | np.float64,
500+
vectorfield: VectorField,
544501
):
545502
"""Partial-slip boundary condition interpolation for velocity fields."""
546-
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=0.5, b=0.5)
503+
return _Spatialslip(position, vectorfield, a=0.5, b=0.5)
547504

548505

549506
def XNearest(
550-
field: Field,
551-
ti: int,
552507
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
553-
tau: np.float32 | np.float64,
554-
t: np.float32 | np.float64,
555-
z: np.float32 | np.float64,
556-
y: np.float32 | np.float64,
557-
x: np.float32 | np.float64,
508+
field: Field,
558509
):
559510
"""
560511
Nearest-Neighbour spatial interpolation on a regular grid.
@@ -563,6 +514,7 @@ def XNearest(
563514
xi, xsi = position["X"]
564515
yi, eta = position["Y"]
565516
zi, zeta = position["Z"]
517+
ti, tau = position["T"]
566518

567519
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
568520
data = field.data
@@ -610,49 +562,39 @@ def XNearest(
610562

611563

612564
def UXPiecewiseConstantFace(
613-
field: Field,
614-
ti: int,
615565
position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]],
616-
tau: np.float32 | np.float64,
617-
t: np.float32 | np.float64,
618-
z: np.float32 | np.float64,
619-
y: np.float32 | np.float64,
620-
x: np.float32 | np.float64,
566+
field: Field,
621567
):
622568
"""
623569
Piecewise constant interpolation kernel for face registered data.
624570
This interpolation method is appropriate for fields that are
625571
face registered, such as u,v in FESOM.
626572
"""
627-
return field.data.values[ti, position["Z"][0], position["FACE"][0]]
573+
return field.data.values[position["T"][0], position["Z"][0], position["FACE"][0]]
628574

629575

630576
def UXPiecewiseLinearNode(
631-
field: Field,
632-
ti: int,
633577
position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]],
634-
tau: np.float32 | np.float64,
635-
t: np.float32 | np.float64,
636-
z: np.float32 | np.float64,
637-
y: np.float32 | np.float64,
638-
x: np.float32 | np.float64,
578+
field: Field,
639579
):
640580
"""
641581
Piecewise linear interpolation kernel for node registered data located at vertical interface levels.
642582
This interpolation method is appropriate for fields that are node registered such as the vertical
643583
velocity W in FESOM2. Effectively, it applies barycentric interpolation in the lateral direction
644584
and piecewise linear interpolation in the vertical direction.
645585
"""
646-
k, fi = position["Z"][0], position["FACE"][0]
586+
ti, _ = position["T"]
587+
zi, fi = position["Z"][0], position["FACE"][0]
588+
z = position["z"]
647589
bcoords = position["FACE"][1]
648590
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values
649591
# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
650592
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
651593
# First, do barycentric interpolation in the lateral direction for each interface level
652-
fzk = np.sum(field.data.values[ti[:, None], k[:, None], node_ids] * bcoords, axis=-1)
653-
fzkp1 = np.sum(field.data.values[ti[:, None], k[:, None] + 1, node_ids] * bcoords, axis=-1)
594+
fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1)
595+
fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1)
654596

655597
# Then, do piecewise linear interpolation in the vertical direction
656-
zk = field.grid.z.values[k]
657-
zkp1 = field.grid.z.values[k + 1]
598+
zk = field.grid.z.values[zi]
599+
zkp1 = field.grid.z.values[zi + 1]
658600
return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction

tests/test_field.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_field_invalid_interpolator():
117117
ds = datasets_structured["ds_2d_left"]
118118
grid = XGrid.from_dataset(ds)
119119

120-
def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, invalid):
120+
def invalid_interpolator_wrong_signature(position, invalid):
121121
return 0.0
122122

123123
# Test invalid interpolator with wrong signature
@@ -129,7 +129,7 @@ def test_vectorfield_invalid_interpolator():
129129
ds = datasets_structured["ds_2d_left"]
130130
grid = XGrid.from_dataset(ds)
131131

132-
def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, invalid):
132+
def invalid_interpolator_wrong_signature(position, invalid):
133133
return 0.0
134134

135135
# Create component fields

tests/test_interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ def field():
8484
)
8585
def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
8686
"""Test the interpolation functions on the Field."""
87-
tau, ti = _search_time_index(field, t)
8887
position = field.grid.search(z, y, x)
88+
position["T"] = _search_time_index(field, t)
8989

90-
value = func(field, ti, position, tau, 0, 0, y, x)
90+
value = func(position, field)
9191
np.testing.assert_equal(value, expected)
9292

9393

tests/test_particleset_execute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ def test_raise_general_error(): ...
298298

299299

300300
def test_errorinterpolation(fieldset):
301-
def NaNInterpolator(field, ti, position, tau, t, z, y, x): # pragma: no cover
302-
return np.nan * np.zeros_like(x)
301+
def NaNInterpolator(position, field): # pragma: no cover
302+
return np.nan * np.zeros_like(position["lon"])
303303

304304
def SampleU(particles, fieldset): # pragma: no cover
305305
fieldset.U[particles.time, particles.z, particles.lat, particles.lon, particles]

0 commit comments

Comments
 (0)