diff --git a/parcels/_core/utils/time.py b/parcels/_core/utils/time.py index 4473e87e4..8cc34acf4 100644 --- a/parcels/_core/utils/time.py +++ b/parcels/_core/utils/time.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta from typing import TYPE_CHECKING, TypeVar import cftime @@ -118,3 +118,29 @@ def get_datetime_type_calendar( # datetime isn't a cftime datetime object pass return type(example_datetime), calendar + + +_TD_PRECISION_GETTER_FOR_UNIT = ( + (lambda dt: dt.days, "D"), + (lambda dt: dt.seconds, "s"), + (lambda dt: dt.microseconds, "us"), +) + + +def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> np.timedelta64: + if isinstance(dt, np.timedelta64): + return dt + + try: + dts = [] + for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT: + value = get_value_for_unit(dt) + if value != 0: + dts.append(np.timedelta64(value, np_unit)) + + if dts: + return sum(dts) + else: + return np.timedelta64(0, "s") + except Exception as e: + raise ValueError(f"Could not convert {dt!r} to np.timedelta64.") from e diff --git a/parcels/particleset.py b/parcels/particleset.py index f67b6b34e..d8e9cbac2 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,3 +1,4 @@ +import datetime import sys import warnings from collections.abc import Iterable @@ -8,7 +9,7 @@ from scipy.spatial import KDTree from tqdm import tqdm -from parcels._core.utils.time import TimeInterval +from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy from parcels._reprs import particleset_repr from parcels.application_kernels.advection import AdvectionRK4 from parcels.basegrid import GridType @@ -460,8 +461,8 @@ def execute( self, pyfunc=AdvectionRK4, endtime: np.timedelta64 | np.datetime64 | None = None, - runtime: np.timedelta64 | None = None, - dt: np.timedelta64 | None = None, + runtime: datetime.timedelta | np.timedelta64 | None = None, + dt: datetime.timedelta | np.timedelta64 | None = None, output_file=None, verbose_progress=True, ): @@ -510,8 +511,21 @@ def execute( if dt is None: dt = np.timedelta64(1, "s") - if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]: - raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}") + try: + dt = maybe_convert_python_timedelta_to_numpy(dt) + assert not np.isnat(dt) + sign_dt = np.sign(dt).astype(int) + assert sign_dt in [-1, 1] + except (ValueError, AssertionError) as e: + raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") from e + + if runtime is not None: + try: + runtime = maybe_convert_python_timedelta_to_numpy(runtime) + except ValueError as e: + raise ValueError( + f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}" + ) from e self._data["dt"][:] = dt @@ -609,9 +623,6 @@ def _get_simulation_start_and_end_times( start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime) if endtime is None: - if not isinstance(runtime, np.timedelta64): - raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}") - endtime = start_time + sign_dt * runtime if time_interval is not None: diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 0c6a4baf6..4e4ba21e1 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -71,7 +71,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval): for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]: with pytest.raises( ValueError, - match="dt must be a positive or negative np.timedelta64 object, got .*", + match="dt must be a non-zero datetime.timedelta or np.timedelta64 object, got .*", ): ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt) @@ -85,7 +85,7 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval): with pytest.raises( ValueError, - match="The runtime must be a np.timedelta64 object. Got .*", + match="The runtime must be a datetime.timedelta or np.timedelta64 object. Got .*", ): ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1) @@ -121,8 +121,8 @@ def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval): "runtime, expectation", [ (np.timedelta64(5, "s"), does_not_raise()), + (timedelta(seconds=2), does_not_raise()), (5.0, pytest.raises(ValueError)), - (timedelta(seconds=2), pytest.raises(ValueError)), (np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)), (datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)), ], diff --git a/tests/v4/utils/test_time.py b/tests/v4/utils/test_time.py index 497a4eb29..5bab70b3b 100644 --- a/tests/v4/utils/test_time.py +++ b/tests/v4/utils/test_time.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta import numpy as np import pytest @@ -8,7 +8,7 @@ from hypothesis import given from hypothesis import strategies as st -from parcels._core.utils.time import TimeInterval +from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy calendar_strategy = st.sampled_from( [ @@ -181,3 +181,20 @@ def test_time_interval_intersection_different_calendars(): ) with pytest.raises(ValueError, match="TimeIntervals are not compatible."): interval1.intersection(interval2) + + +@pytest.mark.parametrize( + "td,expected", + [ + pytest.param(np.timedelta64(1, "s"), np.timedelta64(1, "s"), id="noop"), + pytest.param(timedelta(days=5), np.timedelta64(5, "D"), id="single unit"), + pytest.param(timedelta(days=5, seconds=30), np.timedelta64(5, "D") + np.timedelta64(30, "s"), id="mixed units"), + pytest.param(timedelta(days=0), np.timedelta64(0, "s"), id="zero timedelta"), + pytest.param( + timedelta(seconds=-2), np.timedelta64(-2, "s"), id="negative timedelta" + ), # included because timedelta(seconds=-2) -> timedelta(days=-1, seconds=86398) + ], +) +def test_maybe_convert_python_timedelta_to_numpy(td, expected): + result = maybe_convert_python_timedelta_to_numpy(td) + assert result == expected