Skip to content

Commit 9fee385

Browse files
Fixing test_drifter
1 parent c8ec729 commit 9fee385

File tree

2 files changed

+41
-32
lines changed

2 files changed

+41
-32
lines changed

src/virtualship/instruments/drifter.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,11 @@ def _sample_temperature(particles, fieldset):
3838

3939

4040
def _check_lifetime(particles, fieldset):
41-
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
42-
43-
def has_lifetime(p):
44-
p.age = np.where(p.has_lifetime == 1, p.age + dt, p.age)
45-
p.state = np.where(p.age >= p.lifetime, StatusCode.Delete, p.state)
46-
47-
has_lifetime(particles[particles.has_lifetime == 1])
41+
for i in range(len(particles)):
42+
if particles[i].has_lifetime == 1:
43+
particles[i].age += particles[i].dt / np.timedelta64(1, "s")
44+
if particles[i].age >= particles[i].lifetime:
45+
particles[i].state = StatusCode.Delete
4846

4947

5048
def simulate_drifters(
@@ -78,22 +76,22 @@ def simulate_drifters(
7876
pclass=_DrifterParticle,
7977
lat=[drifter.spacetime.location.lat for drifter in drifters],
8078
lon=[drifter.spacetime.location.lon for drifter in drifters],
81-
depth=[drifter.depth for drifter in drifters],
79+
z=[drifter.depth for drifter in drifters],
8280
time=[drifter.spacetime.time for drifter in drifters],
8381
has_lifetime=[1 if drifter.lifetime is not None else 0 for drifter in drifters],
8482
lifetime=[
85-
0 if drifter.lifetime is None else drifter.lifetime.total_seconds()
83+
0 if drifter.lifetime is None else drifter.lifetime / np.timedelta64(1, "s")
8684
for drifter in drifters
8785
],
8886
)
8987

9088
# define output file for the simulation
91-
out_file = drifter_particleset.ParticleFile(
92-
name=out_path, outputdt=outputdt, chunks=[len(drifter_particleset), 100]
89+
out_file = ParticleFile(
90+
store=out_path, outputdt=outputdt, chunks=(len(drifter_particleset), 100)
9391
)
9492

9593
# get earliest between fieldset end time and provide end time
96-
fieldset_endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1])
94+
fieldset_endtime = fieldset.time_interval.right - np.timedelta64(1, "s") # TODO remove hack stopping 1 second too early when v4 is fixed
9795
if endtime is None:
9896
actual_endtime = fieldset_endtime
9997
elif endtime > fieldset_endtime:
@@ -112,7 +110,7 @@ def simulate_drifters(
112110
)
113111

114112
# if there are more particles left than the number of drifters with an indefinite endtime, warn the user
115-
if len(drifter_particleset.particledata) > len(
113+
if len(drifter_particleset) > len(
116114
[d for d in drifters if d.lifetime is None]
117115
):
118116
print(

tests/instruments/test_drifter.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,59 @@
44

55
import numpy as np
66
import xarray as xr
7-
from parcels import FieldSet
7+
from parcels import FieldSet, Field, VectorField, XGrid
88

99
from virtualship.instruments.drifter import Drifter, simulate_drifters
1010
from virtualship.models import Location, Spacetime
1111

1212

1313
def test_simulate_drifters(tmpdir) -> None:
1414
# arbitrary time offset for the dummy fieldset
15-
base_time = datetime.datetime.strptime("1950-01-01", "%Y-%m-%d")
15+
base_time = np.datetime64("1950-01-01")
1616

1717
CONST_TEMPERATURE = 1.0 # constant temperature in fieldset
1818

19-
v = np.full((2, 2, 2), 1.0)
20-
u = np.full((2, 2, 2), 1.0)
21-
t = np.full((2, 2, 2), CONST_TEMPERATURE)
22-
23-
fieldset = FieldSet.from_data(
24-
{"V": v, "U": u, "T": t},
25-
{
26-
"lon": np.array([0.0, 10.0]),
27-
"lat": np.array([0.0, 10.0]),
28-
"time": [
29-
np.datetime64(base_time + datetime.timedelta(seconds=0)),
30-
np.datetime64(base_time + datetime.timedelta(days=3)),
31-
],
19+
dims = (2, 2, 2) # time, lat, lon
20+
v = np.full(dims, 1.0)
21+
u = np.full(dims, 1.0)
22+
t = np.full(dims, CONST_TEMPERATURE)
23+
24+
time = [base_time, base_time + np.timedelta64(3, "D")]
25+
ds = xr.Dataset(
26+
{"U": (["time", "YG", "XG"], u), "V": (["time", "YG", "XG"], v), "T": (["time", "YG", "XG"], t)},
27+
coords={
28+
"time": (["time"], time, {"axis": "T"}),
29+
"YC": (["YC"], np.arange(dims[1]) + 0.5, {"axis": "Y"}),
30+
"YG": (["YG"], np.arange(dims[1]), {"axis": "Y", "c_grid_axis_shift": -0.5}),
31+
"XC": (["XC"], np.arange(dims[2]) + 0.5, {"axis": "X"}),
32+
"XG": (["XG"], np.arange(dims[2]), {"axis": "X", "c_grid_axis_shift": -0.5}),
33+
"lat": (["YG"], np.linspace(-10, 10, dims[1]), {"axis": "Y", "c_grid_axis_shift": 0.5}),
34+
"lon": (["XG"], np.linspace(-10, 10, dims[2]), {"axis": "X", "c_grid_axis_shift": -0.5}),
3235
},
3336
)
3437

38+
grid = XGrid.from_dataset(ds, mesh="spherical")
39+
U = Field("U", ds["U"], grid)
40+
V = Field("V", ds["V"], grid)
41+
T = Field("T", ds["T"], grid)
42+
UV = VectorField("UV", U, V)
43+
fieldset = FieldSet([U, V, T, UV])
44+
45+
3546
# drifters to deploy
3647
drifters = [
3748
Drifter(
3849
spacetime=Spacetime(
3950
location=Location(latitude=0, longitude=0),
40-
time=base_time + datetime.timedelta(days=0),
51+
time=base_time + np.timedelta64(0, "D"),
4152
),
4253
depth=0.0,
43-
lifetime=datetime.timedelta(hours=2),
54+
lifetime=np.timedelta64(2, "h"),
4455
),
4556
Drifter(
4657
spacetime=Spacetime(
4758
location=Location(latitude=1, longitude=1),
48-
time=base_time + datetime.timedelta(hours=20),
59+
time=base_time + np.timedelta64(20, "h"),
4960
),
5061
depth=0.0,
5162
lifetime=None,
@@ -65,7 +76,7 @@ def test_simulate_drifters(tmpdir) -> None:
6576
)
6677

6778
# test if output is as expected
68-
results = xr.open_zarr(out_path)
79+
results = xr.open_zarr(out_path, decode_cf=False) # TODO fix decode_cf when parcels v4 is fixed
6980

7081
assert len(results.trajectory) == len(drifters)
7182

0 commit comments

Comments
 (0)