|
| 1 | +import datetime |
1 | 2 | import sys |
2 | 3 | import warnings |
3 | 4 | from collections.abc import Iterable |
|
8 | 9 | from scipy.spatial import KDTree |
9 | 10 | from tqdm import tqdm |
10 | 11 |
|
11 | | -from parcels._core.utils.time import TimeInterval |
| 12 | +from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy |
12 | 13 | from parcels._reprs import particleset_repr |
13 | 14 | from parcels.application_kernels.advection import AdvectionRK4 |
14 | 15 | from parcels.basegrid import GridType |
@@ -460,8 +461,8 @@ def execute( |
460 | 461 | self, |
461 | 462 | pyfunc=AdvectionRK4, |
462 | 463 | endtime: np.timedelta64 | np.datetime64 | None = None, |
463 | | - runtime: np.timedelta64 | None = None, |
464 | | - dt: np.timedelta64 | None = None, |
| 464 | + runtime: datetime.timedelta | np.timedelta64 | None = None, |
| 465 | + dt: datetime.timedelta | np.timedelta64 | None = None, |
465 | 466 | output_file=None, |
466 | 467 | verbose_progress=True, |
467 | 468 | ): |
@@ -510,8 +511,21 @@ def execute( |
510 | 511 | if dt is None: |
511 | 512 | dt = np.timedelta64(1, "s") |
512 | 513 |
|
513 | | - if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]: |
514 | | - raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}") |
| 514 | + try: |
| 515 | + dt = maybe_convert_python_timedelta_to_numpy(dt) |
| 516 | + assert not np.isnat(dt) |
| 517 | + sign_dt = np.sign(dt).astype(int) |
| 518 | + assert sign_dt in [-1, 1] |
| 519 | + except (ValueError, AssertionError): |
| 520 | + raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") |
| 521 | + |
| 522 | + if runtime is not None: |
| 523 | + try: |
| 524 | + runtime = maybe_convert_python_timedelta_to_numpy(runtime) |
| 525 | + except ValueError: |
| 526 | + raise ValueError( |
| 527 | + f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}" |
| 528 | + ) |
515 | 529 |
|
516 | 530 | self._data["dt"][:] = dt |
517 | 531 |
|
@@ -609,9 +623,6 @@ def _get_simulation_start_and_end_times( |
609 | 623 | start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime) |
610 | 624 |
|
611 | 625 | if endtime is None: |
612 | | - if not isinstance(runtime, np.timedelta64): |
613 | | - raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}") |
614 | | - |
615 | 626 | endtime = start_time + sign_dt * runtime |
616 | 627 |
|
617 | 628 | if time_interval is not None: |
|
0 commit comments