Skip to content

Commit 43fbf59

Browse files
Merge pull request #2274 from andrew-s28/dt-required-arg
Require `dt` argument in ParticleSet.execute()
2 parents ee0de5f + 7c4e05f commit 43fbf59

File tree

3 files changed

+29
-37
lines changed

3 files changed

+29
-37
lines changed

parcels/_core/particleset.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from parcels._core.warnings import ParticleSetWarning
1818
from parcels._logger import logger
1919
from parcels._reprs import particleset_repr
20-
from parcels.kernels import AdvectionRK4
2120

2221
__all__ = ["ParticleSet"]
2322

@@ -452,10 +451,10 @@ def set_variable_write_status(self, var, write_status):
452451

453452
def execute(
454453
self,
455-
pyfunc=AdvectionRK4,
454+
pyfunc,
455+
dt: datetime.timedelta | np.timedelta64,
456456
endtime: np.timedelta64 | np.datetime64 | None = None,
457457
runtime: datetime.timedelta | np.timedelta64 | None = None,
458-
dt: datetime.timedelta | np.timedelta64 | None = None,
459458
output_file=None,
460459
verbose_progress=True,
461460
):
@@ -469,17 +468,17 @@ def execute(
469468
pyfunc :
470469
Kernel function to execute. This can be the name of a
471470
defined Python function or a :class:`parcels.kernel.Kernel` object.
472-
Kernels can be concatenated using the + operator (Default value = AdvectionRK4)
471+
Kernels can be concatenated using the + operator.
472+
dt (np.timedelta64):
473+
Timestep interval (as a np.timedelta64 object) to be passed to the kernel.
474+
Use a negative value for a backward-in-time simulation.
473475
endtime (np.datetime64 or np.timedelta64): :
474476
End time for the timestepping loop. If a np.timedelta64 is provided, it is interpreted as the total simulation time. In this case,
475477
the absolute end time is the start of the fieldset's time interval plus the np.timedelta64.
476478
If a datetime is provided, it is interpreted as the absolute end time of the simulation.
477479
runtime (np.timedelta64):
478480
The duration of the simuulation execution. Must be a np.timedelta64 object and is required to be set when the `fieldset.time_interval` is not defined.
479481
If the `fieldset.time_interval` is defined and the runtime is provided, the end time will be the start of the fieldset's time interval plus the runtime.
480-
dt (np.timedelta64):
481-
Timestep interval (as a np.timedelta64 object) to be passed to the kernel.
482-
Use a negative value for a backward-in-time simulation. (Default value = 1 second)
483482
output_file :
484483
mod:`parcels.particlefile.ParticleFile` object for particle output (Default value = None)
485484
verbose_progress : bool
@@ -502,9 +501,6 @@ def execute(
502501
output_file.set_metadata(self.fieldset.gridset[0]._mesh)
503502
output_file.metadata["parcels_kernels"] = self._kernel.funcname
504503

505-
if dt is None:
506-
dt = np.timedelta64(1, "s")
507-
508504
try:
509505
dt = maybe_convert_python_timedelta_to_numpy(dt)
510506
assert not np.isnat(dt)

tests/test_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def ErrorKernel(particles, fieldset): # pragma: no cover
3030
particles.unknown_varname += 0.2
3131

3232
with pytest.raises(KeyError, match="'unknown_varname'"):
33-
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"))
33+
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
3434

3535

3636
def test_kernel_init(fieldset):

tests/test_particleset_execute.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from parcels._datasets.structured.generic import datasets as datasets_structured
2424
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
2525
from parcels.interpolators import UXPiecewiseConstantFace
26-
from parcels.kernels import AdvectionEE
26+
from parcels.kernels import AdvectionEE, AdvectionRK4
2727
from tests import utils
2828
from tests.common_kernels import DoNothing
2929

@@ -61,64 +61,62 @@ def zonal_flow_fieldset() -> FieldSet:
6161
return FieldSet([U, V, UV])
6262

6363

64-
def test_pset_execute_implicit_dt_one_second(fieldset):
65-
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle)
66-
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
67-
68-
time = pset.time.copy()
69-
70-
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
71-
np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s"))
72-
73-
7464
def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
7565
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
7666
with pytest.raises(
7767
ValueError,
7868
match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*",
7969
):
80-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)
70+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(AdvectionRK4, dt=dt)
8171

8272
with pytest.raises(
8373
ValueError,
8474
match="runtime and endtime are mutually exclusive - provide one or the other. Got .*",
8575
):
8676
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
87-
runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01")
77+
AdvectionRK4, runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(1, "s")
8878
)
8979

9080
with pytest.raises(
9181
ValueError,
9282
match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*",
9383
):
94-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)
84+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
85+
AdvectionRK4, runtime=1, dt=np.timedelta64(1, "s")
86+
)
9587

9688
msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*"""
9789
with pytest.raises(
9890
ValueError,
9991
match=msg,
10092
):
101-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01"))
93+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
94+
AdvectionRK4, endtime=np.datetime64("1990-01-01"), dt=np.timedelta64(1, "s")
95+
)
10296

10397
with pytest.raises(
10498
ValueError,
10599
match=msg,
106100
):
107101
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
108-
endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
102+
AdvectionRK4, endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
109103
)
110104

111105
with pytest.raises(
112106
ValueError,
113107
match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*",
114108
):
115-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345)
109+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
110+
AdvectionRK4, endtime=12345, dt=np.timedelta64(1, "s")
111+
)
116112

117113
with pytest.raises(
118114
ValueError,
119115
match="The runtime must be provided when the time_interval is not defined for a fieldset.",
120116
):
121-
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute()
117+
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute(
118+
AdvectionRK4, dt=np.timedelta64(1, "s")
119+
)
122120

123121

124122
@pytest.mark.parametrize(
@@ -222,13 +220,12 @@ def AddLat(particles, fieldset): # pragma: no cover
222220

223221
@pytest.mark.parametrize(
224222
"starttime, endtime, dt",
225-
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
223+
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
226224
)
227225
def test_execution_endtime(fieldset, starttime, endtime, dt):
228226
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
229227
endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s")
230-
if dt is not None:
231-
dt = np.timedelta64(dt, "s")
228+
dt = np.timedelta64(dt, "s")
232229
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
233230
pset.execute(DoNothing, endtime=endtime, dt=dt)
234231
assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms")
@@ -348,15 +345,14 @@ def MoveLeft(particles, fieldset): # pragma: no cover
348345

349346
@pytest.mark.parametrize(
350347
"starttime, runtime, dt",
351-
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
348+
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
352349
)
353350
@pytest.mark.parametrize("npart", [1, 10])
354351
def test_execution_runtime(fieldset, starttime, runtime, dt, npart):
355352
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
356353
runtime = np.timedelta64(runtime, "s")
357-
sign_dt = 1 if dt is None else np.sign(dt)
358-
if dt is not None:
359-
dt = np.timedelta64(dt, "s")
354+
sign_dt = np.sign(dt)
355+
dt = np.timedelta64(dt, "s")
360356
pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart))
361357
pset.execute(DoNothing, runtime=runtime, dt=dt)
362358
assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset])
@@ -475,9 +471,9 @@ def test_uxstommelgyre_pset_execute():
475471
pclass=Particle,
476472
)
477473
pset.execute(
474+
AdvectionEE,
478475
runtime=np.timedelta64(10, "m"),
479476
dt=np.timedelta64(60, "s"),
480-
pyfunc=AdvectionEE,
481477
)
482478
assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086
483479
assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776

0 commit comments

Comments
 (0)