Skip to content

Commit e479387

Browse files
committed
Refactor AdvectionRK45 kernel selection into dedicated util
1 parent 5bcf378 commit e479387

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

tests/test_particleset_execute.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from parcels.interpolators import UXPiecewiseConstantFace, UXPiecewiseLinearNode, XLinear
2727
from parcels.kernels import AdvectionEE, AdvectionRK2, AdvectionRK4, AdvectionRK4_3D, AdvectionRK45
2828
from tests.common_kernels import DoNothing
29+
from tests.utils import DEFAULT_PARTICLES
2930

3031

3132
@pytest.fixture
@@ -166,12 +167,7 @@ def test_particleset_run_RK_to_endtime_fwd_bwd(fieldset, kernel, dt):
166167
fieldset.U.data[:] = 0.0
167168
fieldset.V.data[:] = 0.0
168169

169-
if kernel == AdvectionRK45:
170-
MyParticle = Particle.add_variable(Variable("next_dt"))
171-
else:
172-
MyParticle = Particle
173-
174-
pset = ParticleSet(fieldset, pclass=MyParticle, lon=[0.2], lat=[5.0], time=[starttime])
170+
pset = ParticleSet(fieldset, pclass=DEFAULT_PARTICLES[kernel], lon=[0.2], lat=[5.0], time=[starttime])
175171
pset.execute(kernel, endtime=endtime, dt=dt)
176172
assert pset[0].time == fieldset.time_interval.time_length_as_flt
177173

tests/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from __future__ import annotations
44

55
import struct
6+
from collections import defaultdict
67
from pathlib import Path
78

89
import numpy as np
910
import xarray as xr
1011

1112
import parcels
12-
from parcels import Field, FieldSet, VectorField
13+
from parcels import Field, FieldSet, Particle, Variable, VectorField
1314
from parcels._core.xgrid import _FIELD_DATA_ORDERING, XGrid, get_axis_from_dim_name
1415
from parcels._datasets.structured.generated import simple_UV_dataset
1516
from parcels.interpolators import XLinear
@@ -18,6 +19,10 @@
1819
TEST_ROOT = PROJECT_ROOT / "tests"
1920
TEST_DATA = TEST_ROOT / "test_data"
2021

22+
# Define default particle classes for different built-in kernels
23+
DEFAULT_PARTICLES = defaultdict(lambda: Particle)
24+
DEFAULT_PARTICLES[parcels.kernels.AdvectionRK45] = Particle.add_variable(Variable("next_dt"))
25+
2126

2227
def create_fieldset_unit_mesh(xdim=20, ydim=20, mesh="flat") -> FieldSet:
2328
"""Standard unit mesh fieldset with U and V equivalent to longitude and latitude."""

0 commit comments

Comments
 (0)