Skip to content

Commit bec46b8

Browse files
committed
adds dt and kernel arguments where required in tests
1 parent 53ae8cd commit bec46b8

File tree

2 files changed

+23
-27
lines changed

2 files changed

+23
-27
lines changed

tests/test_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def ErrorKernel(particles, fieldset): # pragma: no cover
3030
particles.unknown_varname += 0.2
3131

3232
with pytest.raises(KeyError, match="'unknown_varname'"):
33-
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"))
33+
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
3434

3535

3636
def test_kernel_init(fieldset):

tests/test_particleset_execute.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from parcels._datasets.structured.generic import datasets as datasets_structured
2424
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
2525
from parcels.interpolators import UXPiecewiseConstantFace
26-
from parcels.kernels import AdvectionEE
26+
from parcels.kernels import AdvectionEE, AdvectionRK4
2727
from tests import utils
2828
from tests.common_kernels import DoNothing
2929

@@ -61,64 +61,62 @@ def zonal_flow_fieldset() -> FieldSet:
6161
return FieldSet([U, V, UV])
6262

6363

64-
def test_pset_execute_implicit_dt_one_second(fieldset):
65-
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle)
66-
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
67-
68-
time = pset.time.copy()
69-
70-
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
71-
np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s"))
72-
73-
7464
def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
7565
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
7666
with pytest.raises(
7767
ValueError,
7868
match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*",
7969
):
80-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)
70+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(AdvectionRK4, dt=dt)
8171

8272
with pytest.raises(
8373
ValueError,
8474
match="runtime and endtime are mutually exclusive - provide one or the other. Got .*",
8575
):
8676
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
87-
runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01")
77+
AdvectionRK4, runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(1, "s")
8878
)
8979

9080
with pytest.raises(
9181
ValueError,
9282
match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*",
9383
):
94-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)
84+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
85+
AdvectionRK4, runtime=1, dt=np.timedelta64(1, "s")
86+
)
9587

9688
msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*"""
9789
with pytest.raises(
9890
ValueError,
9991
match=msg,
10092
):
101-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01"))
93+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
94+
AdvectionRK4, endtime=np.datetime64("1990-01-01"), dt=np.timedelta64(1, "s")
95+
)
10296

10397
with pytest.raises(
10498
ValueError,
10599
match=msg,
106100
):
107101
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
108-
endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
102+
AdvectionRK4, endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
109103
)
110104

111105
with pytest.raises(
112106
ValueError,
113107
match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*",
114108
):
115-
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345)
109+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
110+
AdvectionRK4, endtime=12345, dt=np.timedelta64(1, "s")
111+
)
116112

117113
with pytest.raises(
118114
ValueError,
119115
match="The runtime must be provided when the time_interval is not defined for a fieldset.",
120116
):
121-
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute()
117+
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute(
118+
AdvectionRK4, dt=np.timedelta64(1, "s")
119+
)
122120

123121

124122
@pytest.mark.parametrize(
@@ -222,13 +220,12 @@ def AddLat(particles, fieldset): # pragma: no cover
222220

223221
@pytest.mark.parametrize(
224222
"starttime, endtime, dt",
225-
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
223+
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
226224
)
227225
def test_execution_endtime(fieldset, starttime, endtime, dt):
228226
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
229227
endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s")
230-
if dt is not None:
231-
dt = np.timedelta64(dt, "s")
228+
dt = np.timedelta64(dt, "s")
232229
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
233230
pset.execute(DoNothing, endtime=endtime, dt=dt)
234231
assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms")
@@ -348,15 +345,14 @@ def MoveLeft(particles, fieldset): # pragma: no cover
348345

349346
@pytest.mark.parametrize(
350347
"starttime, runtime, dt",
351-
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)],
348+
[(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)],
352349
)
353350
@pytest.mark.parametrize("npart", [1, 10])
354351
def test_execution_runtime(fieldset, starttime, runtime, dt, npart):
355352
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
356353
runtime = np.timedelta64(runtime, "s")
357-
sign_dt = 1 if dt is None else np.sign(dt)
358-
if dt is not None:
359-
dt = np.timedelta64(dt, "s")
354+
sign_dt = np.sign(dt)
355+
dt = np.timedelta64(dt, "s")
360356
pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart))
361357
pset.execute(DoNothing, runtime=runtime, dt=dt)
362358
assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset])
@@ -475,9 +471,9 @@ def test_uxstommelgyre_pset_execute():
475471
pclass=Particle,
476472
)
477473
pset.execute(
474+
AdvectionEE,
478475
runtime=np.timedelta64(10, "m"),
479476
dt=np.timedelta64(60, "s"),
480-
pyfunc=AdvectionEE,
481477
)
482478
assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086
483479
assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776

0 commit comments

Comments
 (0)