Skip to content

Commit 370151a

Browse files
Update runtime and dt in pset.execute to accept timedelta (#2178)
1 parent 3cfa92c commit 370151a

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

parcels/_core/utils/time.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from datetime import datetime
3+
from datetime import datetime, timedelta
44
from typing import TYPE_CHECKING, TypeVar
55

66
import cftime
@@ -118,3 +118,29 @@ def get_datetime_type_calendar(
118118
# datetime isn't a cftime datetime object
119119
pass
120120
return type(example_datetime), calendar
121+
122+
123+
_TD_PRECISION_GETTER_FOR_UNIT = (
124+
(lambda dt: dt.days, "D"),
125+
(lambda dt: dt.seconds, "s"),
126+
(lambda dt: dt.microseconds, "us"),
127+
)
128+
129+
130+
def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> np.timedelta64:
131+
if isinstance(dt, np.timedelta64):
132+
return dt
133+
134+
try:
135+
dts = []
136+
for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT:
137+
value = get_value_for_unit(dt)
138+
if value != 0:
139+
dts.append(np.timedelta64(value, np_unit))
140+
141+
if dts:
142+
return sum(dts)
143+
else:
144+
return np.timedelta64(0, "s")
145+
except Exception as e:
146+
raise ValueError(f"Could not convert {dt!r} to np.timedelta64.") from e

parcels/particleset.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import sys
23
import warnings
34
from collections.abc import Iterable
@@ -8,7 +9,7 @@
89
from scipy.spatial import KDTree
910
from tqdm import tqdm
1011

11-
from parcels._core.utils.time import TimeInterval
12+
from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy
1213
from parcels._reprs import particleset_repr
1314
from parcels.application_kernels.advection import AdvectionRK4
1415
from parcels.basegrid import GridType
@@ -460,8 +461,8 @@ def execute(
460461
self,
461462
pyfunc=AdvectionRK4,
462463
endtime: np.timedelta64 | np.datetime64 | None = None,
463-
runtime: np.timedelta64 | None = None,
464-
dt: np.timedelta64 | None = None,
464+
runtime: datetime.timedelta | np.timedelta64 | None = None,
465+
dt: datetime.timedelta | np.timedelta64 | None = None,
465466
output_file=None,
466467
verbose_progress=True,
467468
):
@@ -510,8 +511,21 @@ def execute(
510511
if dt is None:
511512
dt = np.timedelta64(1, "s")
512513

513-
if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]:
514-
raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}")
514+
try:
515+
dt = maybe_convert_python_timedelta_to_numpy(dt)
516+
assert not np.isnat(dt)
517+
sign_dt = np.sign(dt).astype(int)
518+
assert sign_dt in [-1, 1]
519+
except (ValueError, AssertionError) as e:
520+
raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") from e
521+
522+
if runtime is not None:
523+
try:
524+
runtime = maybe_convert_python_timedelta_to_numpy(runtime)
525+
except ValueError as e:
526+
raise ValueError(
527+
f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}"
528+
) from e
515529

516530
self._data["dt"][:] = dt
517531

@@ -609,9 +623,6 @@ def _get_simulation_start_and_end_times(
609623
start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime)
610624

611625
if endtime is None:
612-
if not isinstance(runtime, np.timedelta64):
613-
raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}")
614-
615626
endtime = start_time + sign_dt * runtime
616627

617628
if time_interval is not None:

tests/v4/test_particleset_execute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
7171
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
7272
with pytest.raises(
7373
ValueError,
74-
match="dt must be a positive or negative np.timedelta64 object, got .*",
74+
match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*",
7575
):
7676
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)
7777

@@ -85,7 +85,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
8585

8686
with pytest.raises(
8787
ValueError,
88-
match="The runtime must be a np.timedelta64 object. Got .*",
88+
match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*",
8989
):
9090
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)
9191

@@ -121,8 +121,8 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
121121
"runtime, expectation",
122122
[
123123
(np.timedelta64(5, "s"), does_not_raise()),
124+
(timedelta(seconds=2), does_not_raise()),
124125
(5.0, pytest.raises(ValueError)),
125-
(timedelta(seconds=2), pytest.raises(ValueError)),
126126
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)),
127127
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
128128
],

tests/v4/utils/test_time.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22

3-
from datetime import datetime
3+
from datetime import datetime, timedelta
44

55
import numpy as np
66
import pytest
77
from cftime import datetime as cftime_datetime
88
from hypothesis import given
99
from hypothesis import strategies as st
1010

11-
from parcels._core.utils.time import TimeInterval
11+
from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy
1212

1313
calendar_strategy = st.sampled_from(
1414
[
@@ -181,3 +181,20 @@ def test_time_interval_intersection_different_calendars():
181181
)
182182
with pytest.raises(ValueError, match="TimeIntervals are not compatible."):
183183
interval1.intersection(interval2)
184+
185+
186+
@pytest.mark.parametrize(
187+
"td,expected",
188+
[
189+
pytest.param(np.timedelta64(1, "s"), np.timedelta64(1, "s"), id="noop"),
190+
pytest.param(timedelta(days=5), np.timedelta64(5, "D"), id="single unit"),
191+
pytest.param(timedelta(days=5, seconds=30), np.timedelta64(5, "D") + np.timedelta64(30, "s"), id="mixed units"),
192+
pytest.param(timedelta(days=0), np.timedelta64(0, "s"), id="zero timedelta"),
193+
pytest.param(
194+
timedelta(seconds=-2), np.timedelta64(-2, "s"), id="negative timedelta"
195+
), # included because timedelta(seconds=-2) -> timedelta(days=-1, seconds=86398)
196+
],
197+
)
198+
def test_maybe_convert_python_timedelta_to_numpy(td, expected):
199+
result = maybe_convert_python_timedelta_to_numpy(td)
200+
assert result == expected

0 commit comments

Comments
 (0)