@@ -300,11 +300,11 @@ def test_decaying_moving_eddy(method, rtol):
300300 fieldset = FieldSet ([U , V , UV ])
301301
302302 start_lon , start_lat = 10000 , 10000
303- dt = np .timedelta64 (30 , "m" )
303+ dt = np .timedelta64 (60 , "m" )
304304
305305 if method == "RK45" :
306306 fieldset .add_constant ("RK45_tol" , rtol )
307- fieldset .add_constant ("RK45_min_dt" , 60 )
307+ fieldset .add_constant ("RK45_min_dt" , 10 * 60 )
308308
309309 pset = ParticleSet (fieldset , lon = start_lon , lat = start_lat , time = np .timedelta64 (0 , "s" ))
310310 pset .execute (kernel [method ], dt = dt , endtime = np .timedelta64 (1 , "D" ))
@@ -328,39 +328,70 @@ def truth_moving(x_0, y_0, t):
328328 np .testing .assert_allclose (pset .lat_nextloop , exp_lat , rtol = rtol )
329329
330330
331- # TODO decrease atol for these tests once the C-grid is implemented
332331@pytest .mark .parametrize (
333- "method, atol " ,
332+ "method, rtol " ,
334333 [
335- ("RK4" , 1 ),
336- ("RK45" , 10 ),
334+ ("RK4" , 0. 1 ),
335+ ("RK45" , 0.1 ),
337336 ],
338337)
339338@pytest .mark .parametrize ("grid_type" , ["A" ]) # TODO also implement C-grid once available
340- @pytest .mark .parametrize ("flowfield" , ["stommel_gyre" , "peninsula" ])
341- def test_gyre_flowfields (method , grid_type , atol , flowfield ):
339+ def test_stommelgyre_fieldset (method , rtol , grid_type ):
342340 npart = 2
343- if flowfield == "peninsula" :
344- ds = peninsula_dataset (grid_type = grid_type )
345- start_lat = np .linspace (3e3 , 47e3 , npart )
346- start_lon = 3e3 * np .ones_like (start_lat )
347- runtime = np .timedelta64 (1 , "D" )
348- else :
349- ds = stommel_gyre_dataset (grid_type = grid_type )
350- start_lon = np .linspace (10e3 , 100e3 , npart )
351- start_lat = np .ones_like (start_lon ) * 5000e3
352- runtime = np .timedelta64 (2 , "D" )
341+ ds = stommel_gyre_dataset (grid_type = grid_type )
353342 grid = XGrid .from_dataset (ds )
354343 U = Field ("U" , ds ["U" ], grid , interp_method = XLinear )
355344 V = Field ("V" , ds ["V" ], grid , interp_method = XLinear )
356345 P = Field ("P" , ds ["P" ], grid , interp_method = XLinear )
357346 UV = VectorField ("UV" , U , V )
358347 fieldset = FieldSet ([U , V , P , UV ])
359348
360- dt = np .timedelta64 (1 , "m" ) # TODO check these settings (and possibly increase)
349+ dt = np .timedelta64 (30 , "m" )
350+ runtime = np .timedelta64 (1 , "D" )
351+ start_lon = np .linspace (10e3 , 100e3 , npart )
352+ start_lat = np .ones_like (start_lon ) * 5000e3
361353
362354 if method == "RK45" :
363- fieldset .add_constant ("RK45_tol" , atol )
355+ fieldset .add_constant ("RK45_tol" , rtol )
356+
357+ SampleParticle = Particle .add_variable (
358+ [Variable ("p" , initial = 0.0 , dtype = np .float32 ), Variable ("p_start" , initial = 0.0 , dtype = np .float32 )]
359+ )
360+
361+ def UpdateP (particle , fieldset , time ): # pragma: no cover
362+ particle .p = fieldset .P [particle .time , particle .depth , particle .lat , particle .lon ]
363+ particle .p_start = np .where (particle .time == 0 , particle .p , particle .p_start )
364+
365+ pset = ParticleSet (fieldset , pclass = SampleParticle , lon = start_lon , lat = start_lat , time = np .timedelta64 (0 , "s" ))
366+ pset .execute ([kernel [method ], UpdateP ], dt = dt , runtime = runtime )
367+ np .testing .assert_allclose (pset .p , pset .p_start , rtol = rtol )
368+
369+
370+ @pytest .mark .parametrize (
371+ "method, rtol" ,
372+ [
373+ ("RK4" , 5e-3 ),
374+ ("RK45" , 1e-4 ),
375+ ],
376+ )
377+ @pytest .mark .parametrize ("grid_type" , ["A" ]) # TODO also implement C-grid once available
378+ def test_peninsula_fieldset (method , rtol , grid_type ):
379+ npart = 2
380+ ds = peninsula_dataset (grid_type = grid_type )
381+ grid = XGrid .from_dataset (ds )
382+ U = Field ("U" , ds ["U" ], grid , interp_method = XLinear )
383+ V = Field ("V" , ds ["V" ], grid , interp_method = XLinear )
384+ P = Field ("P" , ds ["P" ], grid , interp_method = XLinear )
385+ UV = VectorField ("UV" , U , V )
386+ fieldset = FieldSet ([U , V , P , UV ])
387+
388+ dt = np .timedelta64 (30 , "m" )
389+ runtime = np .timedelta64 (1 , "D" )
390+ start_lat = np .linspace (3e3 , 47e3 , npart )
391+ start_lon = 3e3 * np .ones_like (start_lat )
392+
393+ if method == "RK45" :
394+ fieldset .add_constant ("RK45_tol" , rtol )
364395
365396 SampleParticle = Particle .add_variable (
366397 [Variable ("p" , initial = 0.0 , dtype = np .float32 ), Variable ("p_start" , initial = 0.0 , dtype = np .float32 )]
@@ -372,4 +403,4 @@ def UpdateP(particle, fieldset, time): # pragma: no cover
372403
373404 pset = ParticleSet (fieldset , pclass = SampleParticle , lon = start_lon , lat = start_lat , time = np .timedelta64 (0 , "s" ))
374405 pset .execute ([kernel [method ], UpdateP ], dt = dt , runtime = runtime )
375- np .testing .assert_allclose (pset .p , pset .p_start , atol = atol )
406+ np .testing .assert_allclose (pset .p , pset .p_start , rtol = rtol )
0 commit comments