Skip to content

Commit 161ad9c

Browse files
updating advection tests for faster runtime
1 parent 4c90bbc commit 161ad9c

File tree

1 file changed

+52
-21
lines changed

1 file changed

+52
-21
lines changed

tests/v4/test_advection.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,11 @@ def test_decaying_moving_eddy(method, rtol):
300300
fieldset = FieldSet([U, V, UV])
301301

302302
start_lon, start_lat = 10000, 10000
303-
dt = np.timedelta64(30, "m")
303+
dt = np.timedelta64(60, "m")
304304

305305
if method == "RK45":
306306
fieldset.add_constant("RK45_tol", rtol)
307-
fieldset.add_constant("RK45_min_dt", 60)
307+
fieldset.add_constant("RK45_min_dt", 10 * 60)
308308

309309
pset = ParticleSet(fieldset, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
310310
pset.execute(kernel[method], dt=dt, endtime=np.timedelta64(1, "D"))
@@ -328,39 +328,70 @@ def truth_moving(x_0, y_0, t):
328328
np.testing.assert_allclose(pset.lat_nextloop, exp_lat, rtol=rtol)
329329

330330

331-
# TODO decrease atol for these tests once the C-grid is implemented
332331
@pytest.mark.parametrize(
333-
"method, atol",
332+
"method, rtol",
334333
[
335-
("RK4", 1),
336-
("RK45", 10),
334+
("RK4", 0.1),
335+
("RK45", 0.1),
337336
],
338337
)
339338
@pytest.mark.parametrize("grid_type", ["A"]) # TODO also implement C-grid once available
340-
@pytest.mark.parametrize("flowfield", ["stommel_gyre", "peninsula"])
341-
def test_gyre_flowfields(method, grid_type, atol, flowfield):
339+
def test_stommelgyre_fieldset(method, rtol, grid_type):
342340
npart = 2
343-
if flowfield == "peninsula":
344-
ds = peninsula_dataset(grid_type=grid_type)
345-
start_lat = np.linspace(3e3, 47e3, npart)
346-
start_lon = 3e3 * np.ones_like(start_lat)
347-
runtime = np.timedelta64(1, "D")
348-
else:
349-
ds = stommel_gyre_dataset(grid_type=grid_type)
350-
start_lon = np.linspace(10e3, 100e3, npart)
351-
start_lat = np.ones_like(start_lon) * 5000e3
352-
runtime = np.timedelta64(2, "D")
341+
ds = stommel_gyre_dataset(grid_type=grid_type)
353342
grid = XGrid.from_dataset(ds)
354343
U = Field("U", ds["U"], grid, interp_method=XLinear)
355344
V = Field("V", ds["V"], grid, interp_method=XLinear)
356345
P = Field("P", ds["P"], grid, interp_method=XLinear)
357346
UV = VectorField("UV", U, V)
358347
fieldset = FieldSet([U, V, P, UV])
359348

360-
dt = np.timedelta64(1, "m") # TODO check these settings (and possibly increase)
349+
dt = np.timedelta64(30, "m")
350+
runtime = np.timedelta64(1, "D")
351+
start_lon = np.linspace(10e3, 100e3, npart)
352+
start_lat = np.ones_like(start_lon) * 5000e3
361353

362354
if method == "RK45":
363-
fieldset.add_constant("RK45_tol", atol)
355+
fieldset.add_constant("RK45_tol", rtol)
356+
357+
SampleParticle = Particle.add_variable(
358+
[Variable("p", initial=0.0, dtype=np.float32), Variable("p_start", initial=0.0, dtype=np.float32)]
359+
)
360+
361+
def UpdateP(particle, fieldset, time): # pragma: no cover
362+
particle.p = fieldset.P[particle.time, particle.depth, particle.lat, particle.lon]
363+
particle.p_start = np.where(particle.time == 0, particle.p, particle.p_start)
364+
365+
pset = ParticleSet(fieldset, pclass=SampleParticle, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
366+
pset.execute([kernel[method], UpdateP], dt=dt, runtime=runtime)
367+
np.testing.assert_allclose(pset.p, pset.p_start, rtol=rtol)
368+
369+
370+
@pytest.mark.parametrize(
371+
"method, rtol",
372+
[
373+
("RK4", 5e-3),
374+
("RK45", 1e-4),
375+
],
376+
)
377+
@pytest.mark.parametrize("grid_type", ["A"]) # TODO also implement C-grid once available
378+
def test_peninsula_fieldset(method, rtol, grid_type):
379+
npart = 2
380+
ds = peninsula_dataset(grid_type=grid_type)
381+
grid = XGrid.from_dataset(ds)
382+
U = Field("U", ds["U"], grid, interp_method=XLinear)
383+
V = Field("V", ds["V"], grid, interp_method=XLinear)
384+
P = Field("P", ds["P"], grid, interp_method=XLinear)
385+
UV = VectorField("UV", U, V)
386+
fieldset = FieldSet([U, V, P, UV])
387+
388+
dt = np.timedelta64(30, "m")
389+
runtime = np.timedelta64(1, "D")
390+
start_lat = np.linspace(3e3, 47e3, npart)
391+
start_lon = 3e3 * np.ones_like(start_lat)
392+
393+
if method == "RK45":
394+
fieldset.add_constant("RK45_tol", rtol)
364395

365396
SampleParticle = Particle.add_variable(
366397
[Variable("p", initial=0.0, dtype=np.float32), Variable("p_start", initial=0.0, dtype=np.float32)]
@@ -372,4 +403,4 @@ def UpdateP(particle, fieldset, time): # pragma: no cover
372403

373404
pset = ParticleSet(fieldset, pclass=SampleParticle, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
374405
pset.execute([kernel[method], UpdateP], dt=dt, runtime=runtime)
375-
np.testing.assert_allclose(pset.p, pset.p_start, atol=atol)
406+
np.testing.assert_allclose(pset.p, pset.p_start, rtol=rtol)

0 commit comments

Comments
 (0)