|
23 | 23 | from parcels._datasets.structured.generic import datasets as datasets_structured |
24 | 24 | from parcels._datasets.unstructured.generic import datasets as datasets_unstructured |
25 | 25 | from parcels.interpolators import UXPiecewiseConstantFace |
26 | | -from parcels.kernels import AdvectionEE |
| 26 | +from parcels.kernels import AdvectionEE, AdvectionRK4 |
27 | 27 | from tests import utils |
28 | 28 | from tests.common_kernels import DoNothing |
29 | 29 |
|
@@ -61,64 +61,62 @@ def zonal_flow_fieldset() -> FieldSet: |
61 | 61 | return FieldSet([U, V, UV]) |
62 | 62 |
|
63 | 63 |
|
64 | | -def test_pset_execute_implicit_dt_one_second(fieldset): |
65 | | - pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle) |
66 | | - pset.execute(DoNothing, runtime=np.timedelta64(1, "s")) |
67 | | - |
68 | | - time = pset.time.copy() |
69 | | - |
70 | | - pset.execute(DoNothing, runtime=np.timedelta64(1, "s")) |
71 | | - np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s")) |
72 | | - |
73 | | - |
74 | 64 | def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval): |
75 | 65 | for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]: |
76 | 66 | with pytest.raises( |
77 | 67 | ValueError, |
78 | 68 | match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*", |
79 | 69 | ): |
80 | | - ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt) |
| 70 | + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(AdvectionRK4, dt=dt) |
81 | 71 |
|
82 | 72 | with pytest.raises( |
83 | 73 | ValueError, |
84 | 74 | match="runtime and endtime are mutually exclusive - provide one or the other. Got .*", |
85 | 75 | ): |
86 | 76 | ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
87 | | - runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01") |
| 77 | + AdvectionRK4, runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(1, "s") |
88 | 78 | ) |
89 | 79 |
|
90 | 80 | with pytest.raises( |
91 | 81 | ValueError, |
92 | 82 | match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*", |
93 | 83 | ): |
94 | | - ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1) |
| 84 | + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
| 85 | + AdvectionRK4, runtime=1, dt=np.timedelta64(1, "s") |
| 86 | + ) |
95 | 87 |
|
96 | 88 | msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*""" |
97 | 89 | with pytest.raises( |
98 | 90 | ValueError, |
99 | 91 | match=msg, |
100 | 92 | ): |
101 | | - ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01")) |
| 93 | + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
| 94 | + AdvectionRK4, endtime=np.datetime64("1990-01-01"), dt=np.timedelta64(1, "s") |
| 95 | + ) |
102 | 96 |
|
103 | 97 | with pytest.raises( |
104 | 98 | ValueError, |
105 | 99 | match=msg, |
106 | 100 | ): |
107 | 101 | ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
108 | | - endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s") |
| 102 | + AdvectionRK4, endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s") |
109 | 103 | ) |
110 | 104 |
|
111 | 105 | with pytest.raises( |
112 | 106 | ValueError, |
113 | 107 | match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*", |
114 | 108 | ): |
115 | | - ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345) |
| 109 | + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
| 110 | + AdvectionRK4, endtime=12345, dt=np.timedelta64(1, "s") |
| 111 | + ) |
116 | 112 |
|
117 | 113 | with pytest.raises( |
118 | 114 | ValueError, |
119 | 115 | match="The runtime must be provided when the time_interval is not defined for a fieldset.", |
120 | 116 | ): |
121 | | - ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute() |
| 117 | + ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute( |
| 118 | + AdvectionRK4, dt=np.timedelta64(1, "s") |
| 119 | + ) |
122 | 120 |
|
123 | 121 |
|
124 | 122 | @pytest.mark.parametrize( |
@@ -222,13 +220,12 @@ def AddLat(particles, fieldset): # pragma: no cover |
222 | 220 |
|
223 | 221 | @pytest.mark.parametrize( |
224 | 222 | "starttime, endtime, dt", |
225 | | - [(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)], |
| 223 | + [(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)], |
226 | 224 | ) |
227 | 225 | def test_execution_endtime(fieldset, starttime, endtime, dt): |
228 | 226 | starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s") |
229 | 227 | endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s") |
230 | | - if dt is not None: |
231 | | - dt = np.timedelta64(dt, "s") |
| 228 | + dt = np.timedelta64(dt, "s") |
232 | 229 | pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0) |
233 | 230 | pset.execute(DoNothing, endtime=endtime, dt=dt) |
234 | 231 | assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms") |
@@ -348,15 +345,14 @@ def MoveLeft(particles, fieldset): # pragma: no cover |
348 | 345 |
|
349 | 346 | @pytest.mark.parametrize( |
350 | 347 | "starttime, runtime, dt", |
351 | | - [(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, None)], |
| 348 | + [(0, 10, 1), (0, 10, 3), (2, 16, 3), (20, 10, -1), (20, 0, -2), (5, 15, 1)], |
352 | 349 | ) |
353 | 350 | @pytest.mark.parametrize("npart", [1, 10]) |
354 | 351 | def test_execution_runtime(fieldset, starttime, runtime, dt, npart): |
355 | 352 | starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s") |
356 | 353 | runtime = np.timedelta64(runtime, "s") |
357 | | - sign_dt = 1 if dt is None else np.sign(dt) |
358 | | - if dt is not None: |
359 | | - dt = np.timedelta64(dt, "s") |
| 354 | + sign_dt = np.sign(dt) |
| 355 | + dt = np.timedelta64(dt, "s") |
360 | 356 | pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart)) |
361 | 357 | pset.execute(DoNothing, runtime=runtime, dt=dt) |
362 | 358 | assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset]) |
@@ -475,9 +471,9 @@ def test_uxstommelgyre_pset_execute(): |
475 | 471 | pclass=Particle, |
476 | 472 | ) |
477 | 473 | pset.execute( |
| 474 | + AdvectionEE, |
478 | 475 | runtime=np.timedelta64(10, "m"), |
479 | 476 | dt=np.timedelta64(60, "s"), |
480 | | - pyfunc=AdvectionEE, |
481 | 477 | ) |
482 | 478 | assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086 |
483 | 479 | assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776 |
|
0 commit comments