@@ -364,6 +364,49 @@ def PythonFail(particles, fieldset): # pragma: no cover
364364 assert all (pset .time == fieldset .time_interval .left + np .timedelta64 (10 , "s" ))
365365
366366
367+ @pytest .mark .parametrize (
368+ "kernel_names, expected" , [("Lat1" , [0 , 1 ]), ("Lat2" , [2 , 0 ]), ("Lat1and2" , [2 , 1 ]), ("Lat1then2" , [2 , 1 ])]
369+ )
370+ def test_execution_update_particle_in_kernel_function (fieldset , kernel_names , expected ):
371+ npart = 2
372+
373+ pset = ParticleSet (fieldset , lon = np .linspace (0 , 1 , npart ), lat = np .zeros (npart ))
374+
375+ def Lat1 (particles , fieldset ): # pragma: no cover
376+ def SetLat1 (p ):
377+ p .lat = 1
378+
379+ SetLat1 (particles [(particles .lat == 0 ) & (particles .lon > 0.5 )])
380+
381+ def Lat2 (particles , fieldset ): # pragma: no cover
382+ def SetLat2 (p ):
383+ p .lat = 2
384+
385+ SetLat2 (particles [(particles .lat == 0 ) & (particles .lon < 0.5 )])
386+
387+ def Lat1and2 (particles , fieldset ): # pragma: no cover
388+ def SetLat1 (p ):
389+ p .lat = 1
390+
391+ def SetLat2 (p ):
392+ p .lat = 2
393+
394+ SetLat1 (particles [(particles .lat == 0 ) & (particles .lon > 0.5 )])
395+ SetLat2 (particles [(particles .lat == 0 ) & (particles .lon < 0.5 )])
396+
397+ if kernel_names == "Lat1" :
398+ kernels = [Lat1 ]
399+ elif kernel_names == "Lat2" :
400+ kernels = [Lat2 ]
401+ elif kernel_names == "Lat1and2" :
402+ kernels = [Lat1and2 ]
403+ elif kernel_names == "Lat1then2" :
404+ kernels = [Lat1 , Lat2 ]
405+
406+ pset .execute (kernels , runtime = np .timedelta64 (2 , "s" ), dt = np .timedelta64 (1 , "s" ))
407+ np .testing .assert_allclose (pset .lat , expected , rtol = 1e-5 )
408+
409+
367410def test_uxstommelgyre_pset_execute ():
368411 ds = datasets_unstructured ["stommel_gyre_delaunay" ]
369412 grid = UxGrid (grid = ds .uxgrid , z = ds .coords ["nz" ], mesh = "spherical" )
0 commit comments