Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, TypeVar

import cftime
Expand Down Expand Up @@ -118,3 +118,29 @@ def get_datetime_type_calendar(
# datetime isn't a cftime datetime object
pass
return type(example_datetime), calendar


_TD_PRECISION_GETTER_FOR_UNIT = (
(lambda dt: dt.days, "D"),
(lambda dt: dt.seconds, "s"),
(lambda dt: dt.microseconds, "us"),
)


def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> np.timedelta64:
if isinstance(dt, np.timedelta64):
return dt

try:
dts = []
for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT:
value = get_value_for_unit(dt)
if value != 0:
dts.append(np.timedelta64(value, np_unit))

if dts:
return sum(dts)
else:
return np.timedelta64(0, "s")
except Exception as e:
raise ValueError(f"Could not convert {dt!r} to np.timedelta64.") from e
27 changes: 19 additions & 8 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import sys
import warnings
from collections.abc import Iterable
Expand All @@ -8,7 +9,7 @@
from scipy.spatial import KDTree
from tqdm import tqdm

from parcels._core.utils.time import TimeInterval
from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy
from parcels._reprs import particleset_repr
from parcels.application_kernels.advection import AdvectionRK4
from parcels.basegrid import GridType
Expand Down Expand Up @@ -460,8 +461,8 @@ def execute(
self,
pyfunc=AdvectionRK4,
endtime: np.timedelta64 | np.datetime64 | None = None,
runtime: np.timedelta64 | None = None,
dt: np.timedelta64 | None = None,
runtime: datetime.timedelta | np.timedelta64 | None = None,
dt: datetime.timedelta | np.timedelta64 | None = None,
output_file=None,
verbose_progress=True,
):
Expand Down Expand Up @@ -510,8 +511,21 @@ def execute(
if dt is None:
dt = np.timedelta64(1, "s")

if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]:
raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}")
try:
dt = maybe_convert_python_timedelta_to_numpy(dt)
assert not np.isnat(dt)
sign_dt = np.sign(dt).astype(int)
assert sign_dt in [-1, 1]
except (ValueError, AssertionError) as e:
raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") from e

if runtime is not None:
try:
runtime = maybe_convert_python_timedelta_to_numpy(runtime)
except ValueError as e:
raise ValueError(
f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}"
) from e

self._data["dt"][:] = dt

Expand Down Expand Up @@ -609,9 +623,6 @@ def _get_simulation_start_and_end_times(
start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime)

if endtime is None:
if not isinstance(runtime, np.timedelta64):
raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}")

endtime = start_time + sign_dt * runtime

if time_interval is not None:
Expand Down
6 changes: 3 additions & 3 deletions tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
with pytest.raises(
ValueError,
match="dt must be a positive or negative np.timedelta64 object, got .*",
match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)

Expand All @@ -85,7 +85,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):

with pytest.raises(
ValueError,
match="The runtime must be a np.timedelta64 object. Got .*",
match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*",
):
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)

Expand Down Expand Up @@ -121,8 +121,8 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
"runtime, expectation",
[
(np.timedelta64(5, "s"), does_not_raise()),
(timedelta(seconds=2), does_not_raise()),
(5.0, pytest.raises(ValueError)),
(timedelta(seconds=2), pytest.raises(ValueError)),
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)),
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
],
Expand Down
21 changes: 19 additions & 2 deletions tests/v4/utils/test_time.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta

import numpy as np
import pytest
from cftime import datetime as cftime_datetime
from hypothesis import given
from hypothesis import strategies as st

from parcels._core.utils.time import TimeInterval
from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy

calendar_strategy = st.sampled_from(
[
Expand Down Expand Up @@ -181,3 +181,20 @@ def test_time_interval_intersection_different_calendars():
)
with pytest.raises(ValueError, match="TimeIntervals are not compatible."):
interval1.intersection(interval2)


@pytest.mark.parametrize(
"td,expected",
[
pytest.param(np.timedelta64(1, "s"), np.timedelta64(1, "s"), id="noop"),
pytest.param(timedelta(days=5), np.timedelta64(5, "D"), id="single unit"),
pytest.param(timedelta(days=5, seconds=30), np.timedelta64(5, "D") + np.timedelta64(30, "s"), id="mixed units"),
pytest.param(timedelta(days=0), np.timedelta64(0, "s"), id="zero timedelta"),
pytest.param(
timedelta(seconds=-2), np.timedelta64(-2, "s"), id="negative timedelta"
), # included because timedelta(seconds=-2) -> timedelta(days=-1, seconds=86398)
],
)
def test_maybe_convert_python_timedelta_to_numpy(td, expected):
result = maybe_convert_python_timedelta_to_numpy(td)
assert result == expected
Loading