|
1 | 1 | import sys |
2 | 2 | import warnings |
3 | 3 | from collections.abc import Iterable |
| 4 | +from typing import Literal |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import xarray as xr |
@@ -506,59 +507,17 @@ def execute( |
506 | 507 | if output_file: |
507 | 508 | output_file.metadata["parcels_kernels"] = self._kernel.name |
508 | 509 |
|
509 | | - if (dt is not None) and (not isinstance(dt, np.timedelta64)): |
510 | | - raise TypeError("dt must be a np.timedelta64 object") |
511 | | - if dt is None or np.isnat(dt): |
| 510 | + if dt is None: |
512 | 511 | dt = np.timedelta64(1, "s") |
513 | | - self._data["dt"][:] = dt |
514 | | - sign_dt = np.sign(dt).astype(int) |
515 | | - if sign_dt not in [-1, 1]: |
516 | | - raise ValueError("dt must be a positive or negative np.timedelta64 object") |
517 | 512 |
|
518 | | - if self.fieldset.time_interval is None: |
519 | | - start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object |
520 | | - if runtime is None: |
521 | | - raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.") |
| 513 | + if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]: |
| 514 | + raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}") |
522 | 515 |
|
523 | | - else: |
524 | | - if isinstance(runtime, np.timedelta64): |
525 | | - end_time = runtime |
526 | | - else: |
527 | | - raise TypeError("The runtime must be a np.timedelta64 object") |
| 516 | + self._data["dt"][:] = dt |
528 | 517 |
|
529 | | - else: |
530 | | - if not np.isnat(self.time_nextloop).any(): |
531 | | - if sign_dt > 0: |
532 | | - start_time = self.time_nextloop.min() |
533 | | - else: |
534 | | - start_time = self.time_nextloop.max() |
535 | | - else: |
536 | | - if sign_dt > 0: |
537 | | - start_time = self.fieldset.time_interval.left |
538 | | - else: |
539 | | - start_time = self.fieldset.time_interval.right |
540 | | - |
541 | | - if runtime is None: |
542 | | - if endtime is None: |
543 | | - raise ValueError( |
544 | | - "Must provide either runtime or endtime when time_interval is defined for a fieldset." |
545 | | - ) |
546 | | - # Ensure that the endtime uses the same type as the start_time |
547 | | - if isinstance(endtime, self.fieldset.time_interval.left.__class__): |
548 | | - if sign_dt > 0: |
549 | | - if endtime < self.fieldset.time_interval.left: |
550 | | - raise ValueError("The endtime must be after the start time of the fieldset.time_interval") |
551 | | - end_time = min(endtime, self.fieldset.time_interval.right) |
552 | | - else: |
553 | | - if endtime > self.fieldset.time_interval.right: |
554 | | - raise ValueError( |
555 | | - "The endtime must be before the end time of the fieldset.time_interval when dt < 0" |
556 | | - ) |
557 | | - end_time = max(endtime, self.fieldset.time_interval.left) |
558 | | - else: |
559 | | - raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.") |
560 | | - else: |
561 | | - end_time = start_time + runtime * sign_dt |
| 518 | + start_time, end_time = _get_simulation_start_and_end_times( |
| 519 | + self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt |
| 520 | + ) |
562 | 521 |
|
563 | 522 | # Set the time of the particles if it hadn't been set on initialisation |
564 | 523 | if np.isnat(self._data["time"]).any(): |
@@ -619,15 +578,69 @@ def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, |
619 | 578 |
|
620 | 579 | if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64): |
621 | 580 | release_times = np.array([t + time.left for t in release_times]) |
622 | | - if np.any(release_times < time.left): |
| 581 | + if np.any(release_times < time.left) or np.any(release_times > time.right): |
623 | 582 | warnings.warn( |
624 | 583 | "Some particles are set to be released outside the FieldSet's executable time domain.", |
625 | 584 | ParticleSetWarning, |
626 | 585 | stacklevel=2, |
627 | 586 | ) |
628 | | - if np.any(release_times > time.right): |
629 | | - warnings.warn( |
630 | | - "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.", |
631 | | - ParticleSetWarning, |
632 | | - stacklevel=2, |
| 587 | + |
| 588 | + |
| 589 | +def _get_simulation_start_and_end_times( |
| 590 | + time_interval: TimeInterval, |
| 591 | + particle_release_times: np.ndarray, |
| 592 | + runtime: np.timedelta64 | None, |
| 593 | + endtime: np.datetime64 | None, |
| 594 | + sign_dt: Literal[-1, 1], |
| 595 | +) -> tuple[np.datetime64, np.datetime64]: |
| 596 | + if runtime is not None and endtime is not None: |
| 597 | + raise ValueError( |
| 598 | + f"runtime and endtime are mutually exclusive - provide one or the other. Got {runtime=!r}, {endtime=!r}" |
633 | 599 | ) |
| 600 | + |
| 601 | + if runtime is None and time_interval is None: |
| 602 | + raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.") |
| 603 | + |
| 604 | + if sign_dt == 1: |
| 605 | + first_release_time = particle_release_times.min() |
| 606 | + else: |
| 607 | + first_release_time = particle_release_times.max() |
| 608 | + |
| 609 | + start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime) |
| 610 | + |
| 611 | + if endtime is None: |
| 612 | + if not isinstance(runtime, np.timedelta64): |
| 613 | + raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}") |
| 614 | + |
| 615 | + endtime = start_time + sign_dt * runtime |
| 616 | + |
| 617 | + if time_interval is not None: |
| 618 | + if type(endtime) != type(time_interval.left): # noqa: E721 |
| 619 | + raise ValueError( |
| 620 | + f"The endtime must be of the same type as the fieldset.time_interval start time. Got {endtime=!r} with {time_interval=!r}" |
| 621 | + ) |
| 622 | + if endtime not in time_interval: |
| 623 | + msg = ( |
| 624 | + f"Calculated/provided end time of {endtime!r} is not in fieldset time interval {time_interval!r}. Either reduce your runtime, modify your " |
| 625 | + "provided endtime, or change your release timing." |
| 626 | + "Important info:\n" |
| 627 | + f" First particle release: {first_release_time!r}\n" |
| 628 | + f" runtime: {runtime!r}\n" |
| 629 | + f" (calculated) endtime: {endtime!r}" |
| 630 | + ) |
| 631 | + raise ValueError(msg) |
| 632 | + |
| 633 | + return start_time, endtime |
| 634 | + |
| 635 | + |
| 636 | +def _get_start_time(first_release_time, time_interval, sign_dt, runtime): |
| 637 | + if time_interval is None: |
| 638 | + time_interval = TimeInterval(left=np.timedelta64(0, "s"), right=runtime) |
| 639 | + |
| 640 | + if sign_dt == 1: |
| 641 | + fieldset_start = time_interval.left |
| 642 | + else: |
| 643 | + fieldset_start = time_interval.right |
| 644 | + |
| 645 | + start_time = first_release_time if not np.isnat(first_release_time) else fieldset_start |
| 646 | + return start_time |
0 commit comments