Skip to content

Commit 0ca39bd

Browse files
Merge branch 'v4-dev' into cleaning_index_search
2 parents 2d20319 + 62a13d8 commit 0ca39bd

30 files changed

+1120
-968
lines changed

docs/examples/tutorial_stommel_uxarray.ipynb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
"\n",
8989
"A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n",
9090
"\n",
91-
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions."
91+
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions. Setting the `mesh` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
9292
]
9393
},
9494
{
@@ -99,7 +99,7 @@
9999
"source": [
100100
"from parcels.uxgrid import UxGrid\n",
101101
"\n",
102-
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n",
102+
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"], mesh=\"spherical\")\n",
103103
"# You can view the uxgrid object with the following command:\n",
104104
"grid.uxgrid"
105105
]
@@ -112,7 +112,7 @@
112112
"\n",
113113
"In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n",
114114
"\n",
115-
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
115+
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method."
116116
]
117117
},
118118
{
@@ -128,21 +128,18 @@
128128
" name=\"U\",\n",
129129
" data=ds.U,\n",
130130
" grid=grid,\n",
131-
" mesh_type=\"spherical\",\n",
132131
" interp_method=UXPiecewiseConstantFace,\n",
133132
")\n",
134133
"V = Field(\n",
135134
" name=\"V\",\n",
136135
" data=ds.V,\n",
137136
" grid=grid,\n",
138-
" mesh_type=\"spherical\",\n",
139137
" interp_method=UXPiecewiseConstantFace,\n",
140138
")\n",
141139
"P = Field(\n",
142140
" name=\"P\",\n",
143141
" data=ds.p,\n",
144142
" grid=grid,\n",
145-
" mesh_type=\"spherical\",\n",
146143
" interp_method=UXPiecewiseConstantFace,\n",
147144
")"
148145
]

parcels/_core/utils/time.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def __init__(self, left: T, right: T) -> None:
4848
def __contains__(self, item: T) -> bool:
4949
return self.left <= item <= self.right
5050

51+
def is_all_time_in_interval(self, time):
52+
item = np.atleast_1d(time)
53+
return (self.left <= item).all() and (item <= self.right).all()
54+
5155
def __repr__(self) -> str:
5256
return f"TimeInterval(left={self.left!r}, right={self.right!r})"
5357

parcels/_datasets/structured/generated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import xarray as xr
55

66

7-
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"):
8-
max_lon = 180.0 if mesh_type == "spherical" else 1e6
7+
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
8+
max_lon = 180.0 if mesh == "spherical" else 1e6
99

1010
return xr.Dataset(
1111
{"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))},

parcels/_index_search.py

Lines changed: 27 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -30,47 +30,24 @@ def _search_time_index(field: Field, time: datetime):
3030
if the sampled value is outside the time value range.
3131
"""
3232
if field.time_interval is None:
33-
return 0, 0
33+
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)
3434

35-
if time not in field.time_interval:
35+
if not field.time_interval.is_all_time_in_interval(time):
3636
_raise_time_extrapolation_error(time, field=None)
3737

38-
time_index = field.data.time <= time
39-
40-
if time_index.all():
41-
# If given time > last known field time, use
42-
# the last field frame without interpolation
43-
ti = len(field.data.time) - 1
44-
45-
elif np.logical_not(time_index).all():
46-
# If given time < any time in the field, use
47-
# the first field frame without interpolation
48-
ti = 0
49-
else:
50-
ti = int(time_index.argmin() - 1) if time_index.any() else 0
51-
if len(field.data.time) == 1:
52-
tau = 0
53-
elif ti == len(field.data.time) - 1:
54-
tau = 1
55-
else:
56-
tau = (
57-
(time - field.data.time[ti]).dt.total_seconds().values
58-
/ (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds().values
59-
if field.data.time[ti] != field.data.time[ti + 1]
60-
else 0
61-
)
62-
return tau, ti
38+
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
39+
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
40+
return np.atleast_1d(tau), np.atleast_1d(ti)
6341

6442

6543
def _search_indices_curvilinear_2d(
6644
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
67-
):
45+
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
6846
yi, xi = yi_guess, xi_guess
6947
if yi is None or xi is None:
70-
faces = grid.get_spatial_hash().query(np.column_stack((y, x)))
71-
yi, xi = faces[0]
48+
yi, xi = grid.get_spatial_hash().query(y, x)
7249

73-
xsi = eta = -1.0
50+
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
7451
invA = np.array(
7552
[
7653
[1, 0, 0, 0],
@@ -94,7 +71,7 @@ def _search_indices_curvilinear_2d(
9471
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
9572
# _raise_field_out_of_bound_error(z, y, x)
9673

97-
while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
74+
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
9875
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
9976

10077
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
@@ -104,40 +81,29 @@ def _search_indices_curvilinear_2d(
10481
aa = a[3] * b[2] - a[2] * b[3]
10582
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
10683
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
107-
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
108-
eta = -cc / bb
109-
else:
110-
det2 = bb * bb - 4 * aa * cc
111-
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
112-
det = np.sqrt(det2)
113-
eta = (-bb + det) / (2 * aa)
114-
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
115-
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
116-
else:
117-
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
118-
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
119-
_raise_field_out_of_bound_error(0, y, x)
120-
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
121-
_raise_field_out_of_bound_error(0, y, x)
122-
if xsi < -tol:
123-
xi -= 1
124-
elif xsi > 1 + tol:
125-
xi += 1
126-
if eta < -tol:
127-
yi -= 1
128-
elif eta > 1 + tol:
129-
yi += 1
130-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
84+
85+
det2 = bb * bb - 4 * aa * cc
86+
det = np.where(det2 > 0, np.sqrt(det2), eta)
87+
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
88+
89+
xsi = np.where(
90+
abs(a[1] + a[3] * eta) < 1e-12,
91+
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
92+
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
93+
)
94+
95+
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
96+
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
97+
98+
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
13199
it += 1
132100
if it > maxIterSearch:
133101
print(f"Correct cell not found after {maxIterSearch} iterations")
134102
_raise_field_out_of_bound_error(0, y, x)
135-
xsi = max(0.0, xsi)
136-
eta = max(0.0, eta)
137-
xsi = min(1.0, xsi)
138-
eta = min(1.0, eta)
103+
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
104+
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
139105

140-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1)):
106+
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
141107
_raise_field_sampling_error(y, x)
142108
return (yi, eta, xi, xsi)
143109

parcels/application_kernels/advection.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ def AdvectionRK4(particle, fieldset, time): # pragma: no cover
2121
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
2222
(u1, v1) = fieldset.UV[particle]
2323
lon1, lat1 = (particle.lon + u1 * 0.5 * dt, particle.lat + v1 * 0.5 * dt)
24-
(u2, v2) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
24+
(u2, v2) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
2525
lon2, lat2 = (particle.lon + u2 * 0.5 * dt, particle.lat + v2 * 0.5 * dt)
26-
(u3, v3) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
26+
(u3, v3) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
2727
lon3, lat3 = (particle.lon + u3 * dt, particle.lat + v3 * dt)
28-
(u4, v4) = fieldset.UV[time + particle.dt, particle.depth, lat3, lon3, particle]
28+
(u4, v4) = fieldset.UV[particle.time + particle.dt, particle.depth, lat3, lon3, particle]
2929
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6.0 * dt
3030
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6.0 * dt
3131

@@ -37,15 +37,15 @@ def AdvectionRK4_3D(particle, fieldset, time): # pragma: no cover
3737
lon1 = particle.lon + u1 * 0.5 * dt
3838
lat1 = particle.lat + v1 * 0.5 * dt
3939
dep1 = particle.depth + w1 * 0.5 * dt
40-
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
40+
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
4141
lon2 = particle.lon + u2 * 0.5 * dt
4242
lat2 = particle.lat + v2 * 0.5 * dt
4343
dep2 = particle.depth + w2 * 0.5 * dt
44-
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
44+
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
4545
lon3 = particle.lon + u3 * dt
4646
lat3 = particle.lat + v3 * dt
4747
dep3 = particle.depth + w3 * dt
48-
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
48+
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
4949
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
5050
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
5151
particle.ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * dt
@@ -56,35 +56,35 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
5656
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
5757
"""
5858
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
59-
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]
59+
sig_dep = particle.depth / fieldset.H[particle.time, 0, particle.lat, particle.lon]
6060

61-
(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
62-
w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
61+
(u1, v1, w1) = fieldset.UVW[particle.time, particle.depth, particle.lat, particle.lon, particle]
62+
w1 *= sig_dep / fieldset.H[particle.time, 0, particle.lat, particle.lon]
6363
lon1 = particle.lon + u1 * 0.5 * dt
6464
lat1 = particle.lat + v1 * 0.5 * dt
6565
sig_dep1 = sig_dep + w1 * 0.5 * dt
66-
dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]
66+
dep1 = sig_dep1 * fieldset.H[particle.time, 0, lat1, lon1]
6767

68-
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
69-
w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
68+
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
69+
w2 *= sig_dep1 / fieldset.H[particle.time, 0, lat1, lon1]
7070
lon2 = particle.lon + u2 * 0.5 * dt
7171
lat2 = particle.lat + v2 * 0.5 * dt
7272
sig_dep2 = sig_dep + w2 * 0.5 * dt
73-
dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]
73+
dep2 = sig_dep2 * fieldset.H[particle.time, 0, lat2, lon2]
7474

75-
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
76-
w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
75+
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
76+
w3 *= sig_dep2 / fieldset.H[particle.time, 0, lat2, lon2]
7777
lon3 = particle.lon + u3 * dt
7878
lat3 = particle.lat + v3 * dt
7979
sig_dep3 = sig_dep + w3 * dt
80-
dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]
80+
dep3 = sig_dep3 * fieldset.H[particle.time, 0, lat3, lon3]
8181

82-
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
83-
w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
82+
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
83+
w4 *= sig_dep3 / fieldset.H[particle.time, 0, lat3, lon3]
8484
lon4 = particle.lon + u4 * dt
8585
lat4 = particle.lat + v4 * dt
8686
sig_dep4 = sig_dep + w4 * dt
87-
dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]
87+
dep4 = sig_dep4 * fieldset.H[particle.time, 0, lat4, lon4]
8888

8989
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
9090
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
@@ -115,14 +115,7 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
115115
Time-step dt is halved if error is larger than fieldset.RK45_tol,
116116
and doubled if error is smaller than 1/10th of tolerance.
117117
"""
118-
dt = particle.next_dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
119-
if dt > fieldset.RK45_max_dt:
120-
dt = fieldset.RK45_max_dt
121-
particle.next_dt = fieldset.RK45_max_dt * np.timedelta64(1, "s")
122-
if dt < fieldset.RK45_min_dt:
123-
particle.next_dt = fieldset.RK45_min_dt * np.timedelta64(1, "s")
124-
return StatusCode.Repeat
125-
particle.dt = particle.next_dt
118+
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
126119

127120
c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
128121
A = [
@@ -137,42 +130,58 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
137130

138131
(u1, v1) = fieldset.UV[particle]
139132
lon1, lat1 = (particle.lon + u1 * A[0][0] * dt, particle.lat + v1 * A[0][0] * dt)
140-
(u2, v2) = fieldset.UV[time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
133+
(u2, v2) = fieldset.UV[particle.time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
141134
lon2, lat2 = (
142135
particle.lon + (u1 * A[1][0] + u2 * A[1][1]) * dt,
143136
particle.lat + (v1 * A[1][0] + v2 * A[1][1]) * dt,
144137
)
145-
(u3, v3) = fieldset.UV[time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
138+
(u3, v3) = fieldset.UV[particle.time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
146139
lon3, lat3 = (
147140
particle.lon + (u1 * A[2][0] + u2 * A[2][1] + u3 * A[2][2]) * dt,
148141
particle.lat + (v1 * A[2][0] + v2 * A[2][1] + v3 * A[2][2]) * dt,
149142
)
150-
(u4, v4) = fieldset.UV[time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
143+
(u4, v4) = fieldset.UV[particle.time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
151144
lon4, lat4 = (
152145
particle.lon + (u1 * A[3][0] + u2 * A[3][1] + u3 * A[3][2] + u4 * A[3][3]) * dt,
153146
particle.lat + (v1 * A[3][0] + v2 * A[3][1] + v3 * A[3][2] + v4 * A[3][3]) * dt,
154147
)
155-
(u5, v5) = fieldset.UV[time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
148+
(u5, v5) = fieldset.UV[particle.time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
156149
lon5, lat5 = (
157150
particle.lon + (u1 * A[4][0] + u2 * A[4][1] + u3 * A[4][2] + u4 * A[4][3] + u5 * A[4][4]) * dt,
158151
particle.lat + (v1 * A[4][0] + v2 * A[4][1] + v3 * A[4][2] + v4 * A[4][3] + v5 * A[4][4]) * dt,
159152
)
160-
(u6, v6) = fieldset.UV[time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
153+
(u6, v6) = fieldset.UV[particle.time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
161154

162155
lon_4th = (u1 * b4[0] + u2 * b4[1] + u3 * b4[2] + u4 * b4[3] + u5 * b4[4]) * dt
163156
lat_4th = (v1 * b4[0] + v2 * b4[1] + v3 * b4[2] + v4 * b4[3] + v5 * b4[4]) * dt
164157
lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * dt
165158
lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * dt
166159

167-
kappa = math.sqrt(math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2))
168-
if (kappa <= fieldset.RK45_tol) or (math.fabs(dt) < math.fabs(fieldset.RK45_min_dt)):
169-
particle.dlon += lon_4th
170-
particle.dlat += lat_4th
171-
if (kappa <= fieldset.RK45_tol / 10) and (math.fabs(dt * 2) <= math.fabs(fieldset.RK45_max_dt)):
172-
particle.next_dt *= 2
173-
else:
174-
particle.next_dt /= 2
175-
return StatusCode.Repeat
160+
kappa = np.sqrt(np.pow(lon_5th - lon_4th, 2) + np.pow(lat_5th - lat_4th, 2))
161+
162+
good_particles = (kappa <= fieldset.RK45_tol) | (np.fabs(dt) <= np.fabs(fieldset.RK45_min_dt))
163+
particle.dlon += np.where(good_particles, lon_5th, 0)
164+
particle.dlat += np.where(good_particles, lat_5th, 0)
165+
166+
increase_dt_particles = (
167+
good_particles & (kappa <= fieldset.RK45_tol / 10) & (np.fabs(dt * 2) <= np.fabs(fieldset.RK45_max_dt))
168+
)
169+
particle.dt = np.where(increase_dt_particles, particle.dt * 2, particle.dt)
170+
particle.dt = np.where(
171+
particle.dt > fieldset.RK45_max_dt * np.timedelta64(1, "s"),
172+
fieldset.RK45_max_dt * np.timedelta64(1, "s"),
173+
particle.dt,
174+
)
175+
particle.state = np.where(good_particles, StatusCode.Success, particle.state)
176+
177+
repeat_particles = np.invert(good_particles)
178+
particle.dt = np.where(repeat_particles, particle.dt / 2, particle.dt)
179+
particle.dt = np.where(
180+
particle.dt < fieldset.RK45_min_dt * np.timedelta64(1, "s"),
181+
fieldset.RK45_min_dt * np.timedelta64(1, "s"),
182+
particle.dt,
183+
)
184+
particle.state = np.where(repeat_particles, StatusCode.Repeat, particle.state)
176185

177186

178187
def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover

0 commit comments

Comments
 (0)