Skip to content

Commit 0d82a2b

Browse files
Merge pull request #2201 from OceanParcels/v4-new-kernel-api
Update Kernel signature to (particles, fieldset)
2 parents 129540e + fe00581 commit 0d82a2b

File tree

13 files changed

+340
-280
lines changed

13 files changed

+340
-280
lines changed

docs/community/v4-migration.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Version 4 of Parcels is unreleased at the moment. The information in this migrat
1111
- Sharing state between kernels must be done via the particle data (as the kernels are not combined under the hood anymore).
1212
- `particl_dlon`, `particle_dlat` etc have been renamed to `particle.dlon` and `particle.dlat`.
1313
- `particle.dt` is a np.timedelta64 object; be careful when multiplying `particle.dt` with a velocity, as its value may be cast to nanoseconds.
14-
- The `time` argument in the Kernel signature is now standard `None` (and may be removed in the Kernel API before release of v4), so can't be used. Use `particle.time` instead.
14+
- The `time` argument in the Kernel signature has been removed in the Kernel API, so can't be used. Use `particle.time` instead.
15+
- The `particle` argument in the Kernel signature has been renamed to `particles`.
1516
- `math` functions should be replaced with array compatible equivalents (e.g., `math.sin` -> `np.sin`). Instead of `ParcelsRandom` you should use numpy's random functions.
1617

1718
## FieldSet

parcels/application_kernels/advection.py

Lines changed: 114 additions & 114 deletions
Large diffs are not rendered by default.

parcels/application_kernels/advectiondiffusion.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
__all__ = ["AdvectionDiffusionEM", "AdvectionDiffusionM1", "DiffusionUniformKh"]
99

1010

11-
def AdvectionDiffusionM1(particle, fieldset, time): # pragma: no cover
11+
def AdvectionDiffusionM1(particles, fieldset): # pragma: no cover
1212
"""Kernel for 2D advection-diffusion, solved using the Milstein scheme at first order (M1).
1313
1414
Assumes that fieldset has fields `Kh_zonal` and `Kh_meridional`
@@ -23,30 +23,34 @@ def AdvectionDiffusionM1(particle, fieldset, time): # pragma: no cover
2323
The Wiener increment `dW` is normally distributed with zero
2424
mean and a standard deviation of sqrt(dt).
2525
"""
26-
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
26+
dt = particles.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
2727
# Wiener increment with zero mean and std of sqrt(dt)
2828
dWx = np.random.normal(0, np.sqrt(np.fabs(dt)))
2929
dWy = np.random.normal(0, np.sqrt(np.fabs(dt)))
3030

31-
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres, particle]
32-
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres, particle]
31+
Kxp1 = fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon + fieldset.dres, particles]
32+
Kxm1 = fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon - fieldset.dres, particles]
3333
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
3434

35-
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon, particle]
36-
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon, particle])
35+
u, v = fieldset.UV[particles.time, particles.depth, particles.lat, particles.lon, particles]
36+
bx = np.sqrt(2 * fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon, particles])
3737

38-
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon, particle]
39-
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon, particle]
38+
Kyp1 = fieldset.Kh_meridional[
39+
particles.time, particles.depth, particles.lat + fieldset.dres, particles.lon, particles
40+
]
41+
Kym1 = fieldset.Kh_meridional[
42+
particles.time, particles.depth, particles.lat - fieldset.dres, particles.lon, particles
43+
]
4044
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
4145

42-
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon, particle])
46+
by = np.sqrt(2 * fieldset.Kh_meridional[particles.time, particles.depth, particles.lat, particles.lon, particles])
4347

4448
# Particle positions are updated only after evaluating all terms.
45-
particle.dlon += u * dt + 0.5 * dKdx * (dWx**2 + dt) + bx * dWx
46-
particle.dlat += v * dt + 0.5 * dKdy * (dWy**2 + dt) + by * dWy
49+
particles.dlon += u * dt + 0.5 * dKdx * (dWx**2 + dt) + bx * dWx
50+
particles.dlat += v * dt + 0.5 * dKdy * (dWy**2 + dt) + by * dWy
4751

4852

49-
def AdvectionDiffusionEM(particle, fieldset, time): # pragma: no cover
53+
def AdvectionDiffusionEM(particles, fieldset): # pragma: no cover
5054
"""Kernel for 2D advection-diffusion, solved using the Euler-Maruyama scheme (EM).
5155
5256
Assumes that fieldset has fields `Kh_zonal` and `Kh_meridional`
@@ -59,31 +63,35 @@ def AdvectionDiffusionEM(particle, fieldset, time): # pragma: no cover
5963
The Wiener increment `dW` is normally distributed with zero
6064
mean and a standard deviation of sqrt(dt).
6165
"""
62-
dt = particle.dt / np.timedelta64(1, "s")
66+
dt = particles.dt / np.timedelta64(1, "s")
6367
# Wiener increment with zero mean and std of sqrt(dt)
6468
dWx = np.random.normal(0, np.sqrt(np.fabs(dt)))
6569
dWy = np.random.normal(0, np.sqrt(np.fabs(dt)))
6670

67-
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon, particle]
71+
u, v = fieldset.UV[particles.time, particles.depth, particles.lat, particles.lon, particles]
6872

69-
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres, particle]
70-
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres, particle]
73+
Kxp1 = fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon + fieldset.dres, particles]
74+
Kxm1 = fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon - fieldset.dres, particles]
7175
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
7276
ax = u + dKdx
73-
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon, particle])
74-
75-
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon, particle]
76-
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon, particle]
77+
bx = np.sqrt(2 * fieldset.Kh_zonal[particles.time, particles.depth, particles.lat, particles.lon, particles])
78+
79+
Kyp1 = fieldset.Kh_meridional[
80+
particles.time, particles.depth, particles.lat + fieldset.dres, particles.lon, particles
81+
]
82+
Kym1 = fieldset.Kh_meridional[
83+
particles.time, particles.depth, particles.lat - fieldset.dres, particles.lon, particles
84+
]
7785
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
7886
ay = v + dKdy
79-
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon, particle])
87+
by = np.sqrt(2 * fieldset.Kh_meridional[particles.time, particles.depth, particles.lat, particles.lon, particles])
8088

8189
# Particle positions are updated only after evaluating all terms.
82-
particle.dlon += ax * dt + bx * dWx
83-
particle.dlat += ay * dt + by * dWy
90+
particles.dlon += ax * dt + bx * dWx
91+
particles.dlat += ay * dt + by * dWy
8492

8593

86-
def DiffusionUniformKh(particle, fieldset, time): # pragma: no cover
94+
def DiffusionUniformKh(particles, fieldset): # pragma: no cover
8795
"""Kernel for simple 2D diffusion where diffusivity (Kh) is assumed uniform.
8896
8997
Assumes that fieldset has constant fields `Kh_zonal` and `Kh_meridional`.
@@ -101,15 +109,13 @@ def DiffusionUniformKh(particle, fieldset, time): # pragma: no cover
101109
The Wiener increment `dW` is normally distributed with zero
102110
mean and a standard deviation of sqrt(dt).
103111
"""
104-
dt = particle.dt / np.timedelta64(1, "s")
112+
dt = particles.dt / np.timedelta64(1, "s")
105113
# Wiener increment with zero mean and std of sqrt(dt)
106114
dWx = np.random.normal(0, np.sqrt(np.fabs(dt)))
107115
dWy = np.random.normal(0, np.sqrt(np.fabs(dt)))
108116

109-
print(particle)
110-
111-
bx = np.sqrt(2 * fieldset.Kh_zonal[particle])
112-
by = np.sqrt(2 * fieldset.Kh_meridional[particle])
117+
bx = np.sqrt(2 * fieldset.Kh_zonal[particles])
118+
by = np.sqrt(2 * fieldset.Kh_meridional[particles])
113119

114-
particle.dlon += bx * dWx
115-
particle.dlat += by * dWy
120+
particles.dlon += bx * dWx
121+
particles.dlat += by * dWy

parcels/field.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import inspect
43
import warnings
54
from collections.abc import Callable
65
from datetime import datetime
@@ -19,6 +18,7 @@
1918
ZeroInterpolator_Vector,
2019
)
2120
from parcels.particle import KernelParticle
21+
from parcels.tools._helpers import _assert_same_function_signature
2222
from parcels.tools.converters import (
2323
UnitConverter,
2424
unitconverters_map,
@@ -57,25 +57,6 @@ def _deal_with_errors(error, key, vector_type: VectorType):
5757
}
5858

5959

60-
def _assert_same_function_signature(f: Callable, *, ref: Callable) -> None:
61-
"""Ensures a function `f` has the same signature as the reference function `ref`."""
62-
sig_ref = inspect.signature(ref)
63-
sig = inspect.signature(f)
64-
65-
if len(sig_ref.parameters) != len(sig.parameters):
66-
raise ValueError(
67-
f"Interpolation function must have {len(sig_ref.parameters)} parameters, got {len(sig.parameters)}"
68-
)
69-
70-
for (_name1, param1), (_name2, param2) in zip(sig_ref.parameters.items(), sig.parameters.items(), strict=False):
71-
if param1.kind != param2.kind:
72-
raise ValueError(
73-
f"Parameter '{_name2}' has incorrect parameter kind. Expected {param1.kind}, got {param2.kind}"
74-
)
75-
if param1.name != param2.name:
76-
raise ValueError(f"Parameter '{_name2}' has incorrect name. Expected '{param1.name}', got '{param2.name}'")
77-
78-
7960
class Field:
8061
"""The Field class that holds scalar field data.
8162
The `Field` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object.
@@ -157,7 +138,7 @@ def __init__(
157138
if interp_method is None:
158139
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
159140
else:
160-
_assert_same_function_signature(interp_method, ref=ZeroInterpolator)
141+
_assert_same_function_signature(interp_method, ref=ZeroInterpolator, context="Interpolation")
161142
self._interp_method = interp_method
162143

163144
self.igrid = -1 # Default the grid index to -1
@@ -213,7 +194,7 @@ def interp_method(self):
213194

214195
@interp_method.setter
215196
def interp_method(self, method: Callable):
216-
_assert_same_function_signature(method, ref=ZeroInterpolator)
197+
_assert_same_function_signature(method, ref=ZeroInterpolator, context="Interpolation")
217198
self._interp_method = method
218199

219200
def _check_velocitysampling(self):
@@ -287,7 +268,7 @@ def __init__(
287268
if vector_interp_method is None:
288269
self._vector_interp_method = None
289270
else:
290-
_assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator_Vector)
271+
_assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator_Vector, context="Interpolation")
291272
self._vector_interp_method = vector_interp_method
292273

293274
def __repr__(self):
@@ -303,7 +284,7 @@ def vector_interp_method(self):
303284

304285
@vector_interp_method.setter
305286
def vector_interp_method(self, method: Callable):
306-
_assert_same_function_signature(method, ref=ZeroInterpolator_Vector)
287+
_assert_same_function_signature(method, ref=ZeroInterpolator_Vector, context="Interpolation")
307288
self._vector_interp_method = method
308289

309290
def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):

parcels/kernel.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AdvectionRK45,
1414
)
1515
from parcels.basegrid import GridType
16+
from parcels.tools._helpers import _assert_same_function_signature
1617
from parcels.tools.statuscodes import (
1718
StatusCode,
1819
_raise_field_interpolation_error,
@@ -23,6 +24,7 @@
2324
_raise_time_extrapolation_error,
2425
)
2526
from parcels.tools.warnings import KernelWarning
27+
from tests.common_kernels import DoNothing
2628

2729
if TYPE_CHECKING:
2830
from collections.abc import Callable
@@ -67,6 +69,7 @@ def __init__(
6769
for f in pyfuncs:
6870
if not isinstance(f, types.FunctionType):
6971
raise TypeError(f"Argument pyfunc should be a function or list of functions. Got {type(f)}")
72+
_assert_same_function_signature(f, ref=DoNothing, context="Kernel")
7073

7174
if len(pyfuncs) == 0:
7275
raise ValueError("List of `pyfuncs` should have at least one function.")
@@ -112,22 +115,22 @@ def remove_deleted(self, pset):
112115

113116
def add_positionupdate_kernels(self):
114117
# Adding kernels that set and update the coordinate changes
115-
def Setcoords(particle, fieldset, time): # pragma: no cover
118+
def Setcoords(particles, fieldset): # pragma: no cover
116119
import numpy as np # noqa
117120

118-
particle.dlon = 0
119-
particle.dlat = 0
120-
particle.ddepth = 0
121-
particle.lon = particle.lon_nextloop
122-
particle.lat = particle.lat_nextloop
123-
particle.depth = particle.depth_nextloop
124-
particle.time = particle.time_nextloop
121+
particles.dlon = 0
122+
particles.dlat = 0
123+
particles.ddepth = 0
124+
particles.lon = particles.lon_nextloop
125+
particles.lat = particles.lat_nextloop
126+
particles.depth = particles.depth_nextloop
127+
particles.time = particles.time_nextloop
125128

126-
def Updatecoords(particle, fieldset, time): # pragma: no cover
127-
particle.lon_nextloop = particle.lon + particle.dlon
128-
particle.lat_nextloop = particle.lat + particle.dlat
129-
particle.depth_nextloop = particle.depth + particle.ddepth
130-
particle.time_nextloop = particle.time + particle.dt
129+
def Updatecoords(particles, fieldset): # pragma: no cover
130+
particles.lon_nextloop = particles.lon + particles.dlon
131+
particles.lat_nextloop = particles.lat + particles.dlat
132+
particles.depth_nextloop = particles.depth + particles.ddepth
133+
particles.time_nextloop = particles.time + particles.dt
131134

132135
self._pyfuncs = (Setcoords + self + Updatecoords)._pyfuncs
133136

@@ -255,12 +258,12 @@ def execute(self, pset, endtime, dt):
255258
# run kernels for all particles that need to be evaluated
256259
evaluate_particles = (pset.state == StatusCode.Evaluate) & (pset.dt != 0)
257260
for f in self._pyfuncs:
258-
f(pset[evaluate_particles], self._fieldset, None)
261+
f(pset[evaluate_particles], self._fieldset)
259262

260263
# check for particles that have to be repeated
261264
repeat_particles = pset.state == StatusCode.Repeat
262265
while np.any(repeat_particles):
263-
f(pset[repeat_particles], self._fieldset, None)
266+
f(pset[repeat_particles], self._fieldset)
264267
repeat_particles = pset.state == StatusCode.Repeat
265268

266269
# revert to original dt (unless in RK45 mode)

parcels/tools/_helpers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import functools
6+
import inspect
67
import warnings
78
from collections.abc import Callable
89
from datetime import timedelta
@@ -75,3 +76,24 @@ def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
7576
def should_calculate_next_ti(ti: int, tau: float, tdim: int):
7677
"""Check if the time is beyond the last time in the field"""
7778
return np.greater(tau, 0) and ti < tdim - 1
79+
80+
81+
def _assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None:
82+
"""Ensures a function `f` has the same signature as the reference function `ref`."""
83+
sig_ref = inspect.signature(ref)
84+
sig = inspect.signature(f)
85+
86+
if len(sig_ref.parameters) != len(sig.parameters):
87+
raise ValueError(
88+
f"{context} function must have {len(sig_ref.parameters)} parameters, got {len(sig.parameters)}"
89+
)
90+
91+
for param1, param2 in zip(sig_ref.parameters.values(), sig.parameters.values(), strict=False):
92+
if param1.kind != param2.kind:
93+
raise ValueError(
94+
f"Parameter '{param2.name}' has incorrect parameter kind. Expected {param1.kind}, got {param2.kind}"
95+
)
96+
if param1.name != param2.name:
97+
raise ValueError(
98+
f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'"
99+
)

tests/common_kernels.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
"""Shared kernels between tests."""
22

3+
import numpy as np
34

4-
def DoNothing(particle, fieldset, time): # pragma: no cover
5+
from parcels.tools.statuscodes import StatusCode
6+
7+
8+
def DoNothing(particles, fieldset): # pragma: no cover
59
pass
610

711

8-
def DeleteParticle(particle, fieldset, time): # pragma: no cover
9-
if particle.state >= 50: # This captures all Errors
10-
particle.delete()
12+
def DeleteParticle(particles, fieldset): # pragma: no cover
13+
particles.state = np.where(particles.state >= 50, StatusCode.Delete, particles.state)
1114

1215

13-
def MoveEast(particle, fieldset, time): # pragma: no cover
14-
particle.dlon += 0.1
16+
def MoveEast(particles, fieldset): # pragma: no cover
17+
particles.dlon += 0.1
1518

1619

17-
def MoveNorth(particle, fieldset, time): # pragma: no cover
18-
particle.dlat += 0.1
20+
def MoveNorth(particles, fieldset): # pragma: no cover
21+
particles.dlat += 0.1

0 commit comments

Comments
 (0)