Skip to content

Commit dcdd2d1

Browse files
Adding unit test to text multiple boolean masks in Kernels
1 parent a60bc9b commit dcdd2d1

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/v4/test_particleset_execute.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,49 @@ def PythonFail(particles, fieldset): # pragma: no cover
364364
assert all(pset.time == fieldset.time_interval.left + np.timedelta64(10, "s"))
365365

366366

367+
@pytest.mark.parametrize(
368+
"kernel_names, expected", [("Lat1", [0, 1]), ("Lat2", [2, 0]), ("Lat1and2", [2, 1]), ("Lat1then2", [2, 1])]
369+
)
370+
def test_execution_update_particle_in_kernel_function(fieldset, kernel_names, expected):
371+
npart = 2
372+
373+
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart))
374+
375+
def Lat1(particles, fieldset): # pragma: no cover
376+
def SetLat1(p):
377+
p.lat = 1
378+
379+
SetLat1(particles[(particles.lat == 0) & (particles.lon > 0.5)])
380+
381+
def Lat2(particles, fieldset): # pragma: no cover
382+
def SetLat2(p):
383+
p.lat = 2
384+
385+
SetLat2(particles[(particles.lat == 0) & (particles.lon < 0.5)])
386+
387+
def Lat1and2(particles, fieldset): # pragma: no cover
388+
def SetLat1(p):
389+
p.lat = 1
390+
391+
def SetLat2(p):
392+
p.lat = 2
393+
394+
SetLat1(particles[(particles.lat == 0) & (particles.lon > 0.5)])
395+
SetLat2(particles[(particles.lat == 0) & (particles.lon < 0.5)])
396+
397+
if kernel_names == "Lat1":
398+
kernels = [Lat1]
399+
elif kernel_names == "Lat2":
400+
kernels = [Lat2]
401+
elif kernel_names == "Lat1and2":
402+
kernels = [Lat1and2]
403+
elif kernel_names == "Lat1then2":
404+
kernels = [Lat1, Lat2]
405+
406+
pset.execute(kernels, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
407+
np.testing.assert_allclose(pset.lat, expected, rtol=1e-5)
408+
409+
367410
def test_uxstommelgyre_pset_execute():
368411
ds = datasets_unstructured["stommel_gyre_delaunay"]
369412
grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"], mesh="spherical")

0 commit comments

Comments
 (0)