Skip to content

Commit ef12d26

Browse files
Using DEFAULT_PARTICLES for RK45/next_dt in test_advection
1 parent 2b166b8 commit ef12d26

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

tests/test_advection.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
AdvectionRK4_3D,
2626
AdvectionRK45,
2727
)
28-
from tests.utils import round_and_hash_float_array
28+
from tests.utils import DEFAULT_PARTICLES, round_and_hash_float_array
2929

3030

3131
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
@@ -287,12 +287,9 @@ def test_moving_eddy(kernel, rtol):
287287

288288
if kernel == AdvectionRK45:
289289
fieldset.add_constant("RK45_tol", rtol)
290-
MyParticle = Particle.add_variable(Variable("next_dt"))
291-
else:
292-
MyParticle = Particle
293290

294291
pset = ParticleSet(
295-
fieldset, pclass=MyParticle, lon=start_lon, lat=start_lat, z=start_z, time=np.timedelta64(0, "s")
292+
fieldset, pclass=DEFAULT_PARTICLES[kernel], lon=start_lon, lat=start_lat, z=start_z, time=np.timedelta64(0, "s")
296293
)
297294
pset.execute(kernel, dt=dt, endtime=endtime)
298295

@@ -333,11 +330,10 @@ def test_decaying_moving_eddy(kernel, rtol):
333330
if kernel == AdvectionRK45:
334331
fieldset.add_constant("RK45_tol", rtol)
335332
fieldset.add_constant("RK45_min_dt", 10 * 60)
336-
MyParticle = Particle.add_variable(Variable("next_dt"))
337-
else:
338-
MyParticle = Particle
339333

340-
pset = ParticleSet(fieldset, pclass=MyParticle, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
334+
pset = ParticleSet(
335+
fieldset, pclass=DEFAULT_PARTICLES[kernel], lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s")
336+
)
341337
pset.execute(kernel, dt=dt, endtime=endtime)
342338

343339
def truth_moving(x_0, y_0, t):
@@ -384,13 +380,12 @@ def test_stommelgyre_fieldset(kernel, rtol, grid_type):
384380
start_lon = np.linspace(10e3, 100e3, npart)
385381
start_lat = np.ones_like(start_lon) * 5000e3
386382

387-
SampleParticle = Particle.add_variable(
383+
SampleParticle = DEFAULT_PARTICLES[kernel].add_variable(
388384
[Variable("p", initial=0.0, dtype=np.float32), Variable("p_start", initial=0.0, dtype=np.float32)]
389385
)
390386

391387
if kernel == AdvectionRK45:
392388
fieldset.add_constant("RK45_tol", rtol)
393-
SampleParticle = SampleParticle.add_variable(Variable("next_dt"))
394389

395390
def UpdateP(particles, fieldset): # pragma: no cover
396391
particles.p = fieldset.P[particles.time, particles.z, particles.lat, particles.lon]
@@ -425,13 +420,12 @@ def test_peninsula_fieldset(kernel, rtol, grid_type):
425420
start_lat = np.linspace(3e3, 47e3, npart)
426421
start_lon = 3e3 * np.ones_like(start_lat)
427422

428-
SampleParticle = Particle.add_variable(
423+
SampleParticle = DEFAULT_PARTICLES[kernel].add_variable(
429424
[Variable("p", initial=0.0, dtype=np.float32), Variable("p_start", initial=0.0, dtype=np.float32)]
430425
)
431426

432427
if kernel == AdvectionRK45:
433428
fieldset.add_constant("RK45_tol", rtol)
434-
SampleParticle = SampleParticle.add_variable(Variable("next_dt"))
435429

436430
def UpdateP(particles, fieldset): # pragma: no cover
437431
particles.p = fieldset.P[particles.time, particles.z, particles.lat, particles.lon]

0 commit comments

Comments
 (0)