Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 68 additions & 55 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import warnings
from collections.abc import Iterable
from typing import Literal

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -506,59 +507,17 @@ def execute(
if output_file:
output_file.metadata["parcels_kernels"] = self._kernel.name

if (dt is not None) and (not isinstance(dt, np.timedelta64)):
raise TypeError("dt must be a np.timedelta64 object")
if dt is None or np.isnat(dt):
if dt is None:
dt = np.timedelta64(1, "s")
self._data["dt"][:] = dt
sign_dt = np.sign(dt).astype(int)
if sign_dt not in [-1, 1]:
raise ValueError("dt must be a positive or negative np.timedelta64 object")

if self.fieldset.time_interval is None:
start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object
if runtime is None:
raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.")
if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]:
raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}")

else:
if isinstance(runtime, np.timedelta64):
end_time = runtime
else:
raise TypeError("The runtime must be a np.timedelta64 object")
self._data["dt"][:] = dt

else:
if not np.isnat(self.time_nextloop).any():
if sign_dt > 0:
start_time = self.time_nextloop.min()
else:
start_time = self.time_nextloop.max()
else:
if sign_dt > 0:
start_time = self.fieldset.time_interval.left
else:
start_time = self.fieldset.time_interval.right

if runtime is None:
if endtime is None:
raise ValueError(
"Must provide either runtime or endtime when time_interval is defined for a fieldset."
)
# Ensure that the endtime uses the same type as the start_time
if isinstance(endtime, self.fieldset.time_interval.left.__class__):
if sign_dt > 0:
if endtime < self.fieldset.time_interval.left:
raise ValueError("The endtime must be after the start time of the fieldset.time_interval")
end_time = min(endtime, self.fieldset.time_interval.right)
else:
if endtime > self.fieldset.time_interval.right:
raise ValueError(
"The endtime must be before the end time of the fieldset.time_interval when dt < 0"
)
end_time = max(endtime, self.fieldset.time_interval.left)
else:
raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.")
else:
end_time = start_time + runtime * sign_dt
start_time, end_time = _get_simulation_start_and_end_times(
self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt
)

# Set the time of the particles if it hadn't been set on initialisation
if np.isnat(self._data["time"]).any():
Expand Down Expand Up @@ -619,15 +578,69 @@ def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray,

if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64):
release_times = np.array([t + time.left for t in release_times])
if np.any(release_times < time.left):
if np.any(release_times < time.left) or np.any(release_times > time.right):
warnings.warn(
"Some particles are set to be released outside the FieldSet's executable time domain.",
ParticleSetWarning,
stacklevel=2,
)
if np.any(release_times > time.right):
warnings.warn(
"Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
ParticleSetWarning,
stacklevel=2,


def _get_simulation_start_and_end_times(
time_interval: TimeInterval,
particle_release_times: np.ndarray,
runtime: np.timedelta64 | None,
endtime: np.datetime64 | None,
sign_dt: Literal[-1, 1],
) -> tuple[np.datetime64, np.datetime64]:
if runtime is not None and endtime is not None:
raise ValueError(
f"runtime and endtime are mutually exclusive - provide one or the other. Got {runtime=!r}, {endtime=!r}"
)

if runtime is None and time_interval is None:
raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.")

if sign_dt == 1:
first_release_time = particle_release_times.min()
else:
first_release_time = particle_release_times.max()

start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime)

if endtime is None:
if not isinstance(runtime, np.timedelta64):
raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}")

endtime = start_time + sign_dt * runtime

if time_interval is not None:
if type(endtime) != type(time_interval.left): # noqa: E721
raise ValueError(
f"The endtime must be of the same type as the fieldset.time_interval start time. Got {endtime=!r} with {time_interval=!r}"
)
if endtime not in time_interval:
msg = (
f"Calculated/provided end time of {endtime!r} is not in fieldset time interval {time_interval!r}. Either reduce your runtime, modify your "
"provided endtime, or change your release timing."
"Important info:\n"
f" First particle release: {first_release_time!r}\n"
f" runtime: {runtime!r}\n"
f" (calculated) endtime: {endtime!r}"
)
raise ValueError(msg)

return start_time, endtime


def _get_start_time(first_release_time, time_interval, sign_dt, runtime):
if time_interval is None:
time_interval = TimeInterval(left=np.timedelta64(0, "s"), right=runtime)

if sign_dt == 1:
fieldset_start = time_interval.left
else:
fieldset_start = time_interval.right

start_time = first_release_time if not np.isnat(first_release_time) else fieldset_start
return start_time
47 changes: 0 additions & 47 deletions tests/v4/test_particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,6 @@ def test_pset_create_outside_time(fieldset):
ParticleSet(fieldset, pclass=Particle, lon=[0] * len(time), lat=[0] * len(time), time=time)


@pytest.mark.parametrize(
"dt, expectation",
[
(np.timedelta64(5, "s"), does_not_raise()),
(5.0, pytest.raises(TypeError)),
(np.datetime64("2000-01-02T00:00:00"), pytest.raises(TypeError)),
(timedelta(seconds=2), pytest.raises(TypeError)),
],
)
def test_particleset_dt_type(fieldset, dt, expectation):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
with expectation:
pset.execute(runtime=np.timedelta64(10, "s"), dt=dt, pyfunc=DoNothing)


def test_pset_starttime_not_multiple_dt(fieldset):
times = [0, 1, 2]
datetimes = [fieldset.time_interval.left + np.timedelta64(t, "s") for t in times]
Expand All @@ -141,38 +126,6 @@ def Addlon(particle, fieldset, time): # pragma: no cover
assert np.allclose([p.lon_nextloop for p in pset], [8 - t for t in times])


@pytest.mark.parametrize(
"runtime, expectation",
[
(np.timedelta64(5, "s"), does_not_raise()),
(5.0, pytest.raises(TypeError)),
(timedelta(seconds=2), pytest.raises(TypeError)),
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(TypeError)),
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)),
],
)
def test_particleset_runtime_type(fieldset, runtime, expectation):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
with expectation:
pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing)


@pytest.mark.parametrize(
"endtime, expectation",
[
(np.datetime64("2000-01-02T00:00:00"), does_not_raise()),
(5.0, pytest.raises(TypeError)),
(np.timedelta64(5, "s"), pytest.raises(TypeError)),
(timedelta(seconds=2), pytest.raises(TypeError)),
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)),
],
)
def test_particleset_endtime_type(fieldset, endtime, expectation):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
with expectation:
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)


def test_pset_add_explicit(fieldset):
npart = 11
lon = np.linspace(0, 1, npart)
Expand Down
120 changes: 115 additions & 5 deletions tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from contextlib import nullcontext as does_not_raise
from datetime import datetime, timedelta

import numpy as np
import pytest

Expand Down Expand Up @@ -27,7 +30,20 @@ def fieldset() -> FieldSet:
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U (A grid)"], grid)
V = Field("V", ds["V (A grid)"], grid)
return FieldSet([U, V])
UV = VectorField("UV", U, V)
return FieldSet([U, V, UV])


@pytest.fixture
def fieldset_no_time_interval() -> FieldSet:
# i.e., no time variation
ds = datasets_structured["ds_2d_left"].isel(time=0).drop("time")

grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U (A grid)"], grid)
V = Field("V", ds["V (A grid)"], grid)
UV = VectorField("UV", U, V)
return FieldSet([U, V, UV])


@pytest.fixture
Expand All @@ -41,6 +57,98 @@ def zonal_flow_fieldset() -> FieldSet:
return FieldSet([U, V, UV])


def test_pset_execute_implicit_dt_one_second(fieldset):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle)
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))

time = pset.time.copy()

pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s"))


def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
with pytest.raises(
ValueError,
match="dt must be a positive or negative np.timedelta64 object, got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)

with pytest.raises(
ValueError,
match="runtime and endtime are mutually exclusive - provide one or the other. Got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01")
)

with pytest.raises(
ValueError,
match="The runtime must be a np.timedelta64 object. Got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)

msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*"""
with pytest.raises(
ValueError,
match=msg,
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01"))

with pytest.raises(
ValueError,
match=msg,
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
)

with pytest.raises(
ValueError,
match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345)

with pytest.raises(
ValueError,
match="The runtime must be provided when the time_interval is not defined for a fieldset.",
):
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute()


@pytest.mark.parametrize(
"runtime, expectation",
[
(np.timedelta64(5, "s"), does_not_raise()),
(5.0, pytest.raises(ValueError)),
(timedelta(seconds=2), pytest.raises(ValueError)),
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)),
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
],
)
def test_particleset_runtime_type(fieldset, runtime, expectation):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
with expectation:
pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing)


@pytest.mark.parametrize(
"endtime, expectation",
[
(np.datetime64("2000-01-02T00:00:00"), does_not_raise()),
(5.0, pytest.raises(ValueError)),
(np.timedelta64(5, "s"), pytest.raises(ValueError)),
(timedelta(seconds=2), pytest.raises(ValueError)),
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
],
)
def test_particleset_endtime_type(fieldset, endtime, expectation):
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
with expectation:
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)


def test_pset_remove_particle_in_kernel(fieldset):
npart = 100
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
Expand Down Expand Up @@ -92,7 +200,8 @@ def AddLat(particle, fieldset, time): # pragma: no cover
def test_execution_endtime(fieldset, starttime, endtime, dt):
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s")
dt = np.timedelta64(dt, "s")
if dt is not None:
dt = np.timedelta64(dt, "s")
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
pset.execute(DoNothing, endtime=endtime, dt=dt)
assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms")
Expand Down Expand Up @@ -152,10 +261,10 @@ def test_some_particles_throw_outoftime(fieldset):
pset = ParticleSet(fieldset, lon=np.zeros_like(time), lat=np.zeros_like(time), time=time)

def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover
fieldset.U[particle.time + np.timedelta64(1, "D"), particle.depth, particle.lat, particle.lon, particle]
fieldset.U[particle.time + np.timedelta64(400, "D"), particle.depth, particle.lat, particle.lon, particle]

with pytest.raises(TimeExtrapolationError):
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(400, "D"), dt=np.timedelta64(10, "D"))
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D"))


def test_execution_check_stopallexecution(fieldset):
Expand Down Expand Up @@ -200,7 +309,8 @@ def test_execution_runtime(fieldset, starttime, runtime, dt, npart):
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
runtime = np.timedelta64(runtime, "s")
sign_dt = 1 if dt is None else np.sign(dt)
dt = np.timedelta64(dt, "s")
if dt is not None:
dt = np.timedelta64(dt, "s")
pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart))
pset.execute(DoNothing, runtime=runtime, dt=dt)
assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset])
Expand Down
Loading