Skip to content

Commit fab164d

Browse files
committed
update all instrument kernels to parcels v4
1 parent fa1a04c commit fab164d

File tree

7 files changed

+228
-152
lines changed

7 files changed

+228
-152
lines changed

src/virtualship/instruments/adcp.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,26 @@
33
from pathlib import Path
44

55
import numpy as np
6-
from parcels import FieldSet, ParticleSet, ScipyParticle, Variable
6+
from parcels import FieldSet, ParticleSet, Variable
7+
from parcels.particle import Particle
78

89
from virtualship.models import Spacetime
910

10-
# we specifically use ScipyParticle because we have many small calls to execute
11-
# there is some overhead with JITParticle and this ends up being significantly faster
12-
_ADCPParticle = ScipyParticle.add_variables(
11+
_ADCPParticle = Particle.add_variable(
1312
[
1413
Variable("U", dtype=np.float32, initial=np.nan),
1514
Variable("V", dtype=np.float32, initial=np.nan),
1615
]
1716
)
1817

1918

20-
def _sample_velocity(particle, fieldset, time):
21-
particle.U, particle.V = fieldset.UV.eval(
22-
time, particle.depth, particle.lat, particle.lon, applyConversion=False
19+
def _sample_velocity(particles, fieldset):
20+
particles.U, particles.V = fieldset.UV.eval(
21+
particles.time,
22+
particles.z,
23+
particles.lat,
24+
particles.lon,
25+
applyConversion=False,
2326
)
2427

2528

src/virtualship/instruments/argo_float.py

Lines changed: 85 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Argo float instrument."""
22

3-
import math
43
from dataclasses import dataclass
54
from datetime import datetime, timedelta
65
from pathlib import Path
@@ -9,11 +8,11 @@
98
from parcels import (
109
AdvectionRK4,
1110
FieldSet,
12-
JITParticle,
1311
ParticleSet,
1412
StatusCode,
1513
Variable,
1614
)
15+
from parcels.particle import Particle
1716

1817
from virtualship.models import Spacetime
1918

@@ -31,7 +30,7 @@ class ArgoFloat:
3130
drift_days: float
3231

3332

34-
_ArgoParticle = JITParticle.add_variables(
33+
_ArgoParticle = Particle.add_variable(
3534
[
3635
Variable("cycle_phase", dtype=np.int32, initial=0.0),
3736
Variable("cycle_age", dtype=np.float32, initial=0.0),
@@ -48,71 +47,86 @@ class ArgoFloat:
4847
)
4948

5049

51-
def _argo_float_vertical_movement(particle, fieldset, time):
52-
if particle.cycle_phase == 0:
53-
# Phase 0: Sinking with vertical_speed until depth is drift_depth
54-
particle_ddepth += ( # noqa Parcels defines particle_* variables, which code checkers cannot know.
55-
particle.vertical_speed * particle.dt
56-
)
57-
if particle.depth + particle_ddepth <= particle.drift_depth:
58-
particle_ddepth = particle.drift_depth - particle.depth
59-
particle.cycle_phase = 1
60-
61-
elif particle.cycle_phase == 1:
62-
# Phase 1: Drifting at depth for drifttime seconds
63-
particle.drift_age += particle.dt
64-
if particle.drift_age >= particle.drift_days * 86400:
65-
particle.drift_age = 0 # reset drift_age for next cycle
66-
particle.cycle_phase = 2
67-
68-
elif particle.cycle_phase == 2:
69-
# Phase 2: Sinking further to max_depth
70-
particle_ddepth += particle.vertical_speed * particle.dt
71-
if particle.depth + particle_ddepth <= particle.max_depth:
72-
particle_ddepth = particle.max_depth - particle.depth
73-
particle.cycle_phase = 3
74-
75-
elif particle.cycle_phase == 3:
76-
# Phase 3: Rising with vertical_speed until at surface
77-
particle_ddepth -= particle.vertical_speed * particle.dt
78-
particle.cycle_age += (
79-
particle.dt
80-
) # solve issue of not updating cycle_age during ascent
81-
if particle.depth + particle_ddepth >= particle.min_depth:
82-
particle_ddepth = particle.min_depth - particle.depth
83-
particle.temperature = (
84-
math.nan
85-
) # reset temperature to NaN at end of sampling cycle
86-
particle.salinity = math.nan # idem
87-
particle.cycle_phase = 4
88-
else:
89-
particle.temperature = fieldset.T[
90-
time, particle.depth, particle.lat, particle.lon
91-
]
92-
particle.salinity = fieldset.S[
93-
time, particle.depth, particle.lat, particle.lon
94-
]
95-
96-
elif particle.cycle_phase == 4:
97-
# Phase 4: Transmitting at surface until cycletime is reached
98-
if particle.cycle_age > particle.cycle_days * 86400:
99-
particle.cycle_phase = 0
100-
particle.cycle_age = 0
101-
102-
if particle.state == StatusCode.Evaluate:
103-
particle.cycle_age += particle.dt # update cycle_age
104-
105-
106-
def _keep_at_surface(particle, fieldset, time):
50+
def ArgoPhase1(particles, fieldset):
51+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
52+
53+
def SinkingPhase(p):
54+
"""Phase 0: Sinking with p.vertical_speed until depth is driftdepth."""
55+
p.dz += p.verticle_speed * dt
56+
p.cycle_phase = np.where(p.z + p.dz >= p.drift_depth, 1, p.cycle_phase)
57+
p.dz = np.where(p.z + p.dz >= p.drift_depth, p.drift_depth - p.z, p.dz)
58+
59+
SinkingPhase(particles[particles.cycle_phase == 0])
60+
61+
62+
def ArgoPhase2(particles, fieldset):
63+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
64+
65+
def DriftingPhase(p):
66+
"""Phase 1: Drifting at depth for drift_time seconds."""
67+
p.drift_age += dt
68+
p.cycle_phase = np.where(p.drift_age >= p.drift_time, 2, p.cycle_phase)
69+
p.drift_age = np.where(p.drift_age >= p.drift_time, 0, p.drift_age)
70+
71+
DriftingPhase(particles[particles.cycle_phase == 1])
72+
73+
74+
def ArgoPhase3(particles, fieldset):
75+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
76+
77+
def SecondSinkingPhase(p):
78+
"""Phase 2: Sinking further to max_depth."""
79+
p.dz += p.vertical_speed * dt
80+
p.cycle_phase = np.where(p.z + p.dz >= p.max_depth, 3, p.cycle_phase)
81+
p.dz = np.where(p.z + p.dz >= p.max_depth, p.max_depth - p.z, p.dz)
82+
83+
SecondSinkingPhase(particles[particles.cycle_phase == 2])
84+
85+
86+
def ArgoPhase4(particles, fieldset):
87+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
88+
89+
def RisingPhase(p):
90+
"""Phase 3: Rising with p.vertical_speed until at surface."""
91+
p.dz -= p.vertical_speed * dt
92+
p.temp = fieldset.temp[p.time, p.z, p.lat, p.lon]
93+
p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)
94+
95+
RisingPhase(particles[particles.cycle_phase == 3])
96+
97+
98+
def ArgoPhase5(particles, fieldset):
99+
def TransmittingPhase(p):
100+
"""Phase 4: Transmitting at surface until cycletime (cycle_days * 86400 [seconds]) is reached."""
101+
p.cycle_phase = np.where(p.cycle_age >= p.cycle_days * 86400, 0, p.cycle_phase)
102+
p.cycle_age = np.where(p.cycle_age >= p.cycle_days * 86400, 0, p.cycle_age)
103+
104+
TransmittingPhase(particles[particles.cycle_phase == 4])
105+
106+
107+
def ArgoPhase6(particles, fieldset):
108+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
109+
particles.cycle_age += dt # update cycle_age
110+
111+
112+
def _keep_at_surface(particles, fieldset):
107113
# Prevent error when float reaches surface
108-
if particle.state == StatusCode.ErrorThroughSurface:
109-
particle.depth = particle.min_depth
110-
particle.state = StatusCode.Success
114+
particles.z = np.where(
115+
particles.state == StatusCode.ErrorThroughSurface,
116+
particles.min_depth,
117+
particles.z,
118+
)
119+
particles.state = np.where(
120+
particles.state == StatusCode.ErrorThroughSurface,
121+
StatusCode.Success,
122+
particles.state,
123+
)
111124

112125

113-
def _check_error(particle, fieldset, time):
114-
if particle.state >= 50: # This captures all Errors
115-
particle.delete()
126+
def _check_error(particles, fieldset):
127+
particles.state = np.where(
128+
particles.state >= 50, StatusCode.Delete, particles.state
129+
) # captures all errors
116130

117131

118132
def simulate_argo_floats(
@@ -174,7 +188,12 @@ def simulate_argo_floats(
174188
# execute simulation
175189
argo_float_particleset.execute(
176190
[
177-
_argo_float_vertical_movement,
191+
ArgoPhase1,
192+
ArgoPhase2,
193+
ArgoPhase3,
194+
ArgoPhase4,
195+
ArgoPhase5,
196+
ArgoPhase6,
178197
AdvectionRK4,
179198
_keep_at_surface,
180199
_check_error,

src/virtualship/instruments/ctd.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from pathlib import Path
66

77
import numpy as np
8-
from parcels import FieldSet, JITParticle, ParticleSet, Variable
8+
from parcels import FieldSet, ParticleSet, Variable
9+
from parcels.particle import Particle
10+
from parcels.tools import StatusCode
911

1012
from virtualship.models import Spacetime
1113

@@ -19,7 +21,7 @@ class CTD:
1921
max_depth: float
2022

2123

22-
_CTDParticle = JITParticle.add_variables(
24+
_CTDParticle = Particle.add_variable(
2325
[
2426
Variable("salinity", dtype=np.float32, initial=np.nan),
2527
Variable("temperature", dtype=np.float32, initial=np.nan),
@@ -31,26 +33,37 @@ class CTD:
3133
)
3234

3335

34-
def _sample_temperature(particle, fieldset, time):
35-
particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon]
36+
def _sample_temperature(particles, fieldset):
37+
particles.temperature = fieldset.T[
38+
particles.time, particles.z, particles.lat, particles.lon
39+
]
40+
41+
42+
def _sample_salinity(particles, fieldset):
43+
particles.salinity = fieldset.S[
44+
particles.time, particles.z, particles.lat, particles.lon
45+
]
46+
47+
48+
def _ctd_sinking(particles, fieldset):
49+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
50+
51+
def ctd_lowering(p):
52+
p.dz = -particles.winch_speed * dt
53+
p.raising = np.where(p.z + p.dz < p.max_depth, 1, p.raising)
54+
p.dz = np.where(p.z + p.dz < p.max_depth, -p.ddpeth, p.dz)
55+
56+
ctd_lowering(particles[particles.raising == 0])
3657

3758

38-
def _sample_salinity(particle, fieldset, time):
39-
particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon]
59+
def _ctd_rising(particles, fieldset):
60+
dt = particles.dt / np.timedelta64(1, "s") # convert dt to seconds
4061

62+
def ctd_rising(p):
63+
p.dz = p.winch_speed * dt
64+
p.state = np.where(p.z + p.dz > p.min_depth, StatusCode.Delete, p.state)
4165

42-
def _ctd_cast(particle, fieldset, time):
43-
# lowering
44-
if particle.raising == 0:
45-
particle_ddepth = -particle.winch_speed * particle.dt
46-
if particle.depth + particle_ddepth < particle.max_depth:
47-
particle.raising = 1
48-
particle_ddepth = -particle_ddepth
49-
# raising
50-
else:
51-
particle_ddepth = particle.winch_speed * particle.dt
52-
if particle.depth + particle_ddepth > particle.min_depth:
53-
particle.delete()
66+
ctd_rising(particles[particles.raising == 1])
5467

5568

5669
def simulate_ctd(
@@ -123,7 +136,7 @@ def simulate_ctd(
123136

124137
# execute simulation
125138
ctd_particleset.execute(
126-
[_sample_salinity, _sample_temperature, _ctd_cast],
139+
[_sample_salinity, _sample_temperature, _ctd_sinking, _ctd_rising],
127140
endtime=fieldset_endtime,
128141
dt=DT,
129142
verbose_progress=False,

0 commit comments

Comments
 (0)