diff --git a/parcels/kernel.py b/parcels/kernel.py index e4c42238f7..41409c15ab 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -92,10 +92,6 @@ def funcname(self): ret += f.__name__ return ret - @property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file) - def name(self): - return f"{self._ptype.name}{self.funcname}" - @property def ptype(self): return self._ptype diff --git a/parcels/particle.py b/parcels/particle.py index 8a51e3680f..3fbbacb30b 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -165,18 +165,42 @@ def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClas return ParticleClass( variables=[ - Variable("lon", dtype=spatial_dtype), + Variable( + "lon", + dtype=spatial_dtype, + attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + ), Variable("lon_nextloop", dtype=spatial_dtype, to_write=False), - Variable("lat", dtype=spatial_dtype), + Variable( + "lat", + dtype=spatial_dtype, + attrs={"standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, + ), Variable("lat_nextloop", dtype=spatial_dtype, to_write=False), - Variable("depth", dtype=spatial_dtype), + Variable( + "depth", + dtype=spatial_dtype, + attrs={"standard_name": "depth", "units": "m", "positive": "down"}, + ), Variable("dlon", dtype=spatial_dtype, to_write=False), Variable("dlat", dtype=spatial_dtype, to_write=False), Variable("ddepth", dtype=spatial_dtype, to_write=False), Variable("depth_nextloop", dtype=spatial_dtype, to_write=False), - Variable("time", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE), + Variable( + "time", + dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, + attrs={"standard_name": "time", "units": "seconds", "axis": "T"}, + ), Variable("time_nextloop", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, to_write=False), - Variable("trajectory", dtype=np.int64, to_write="once"), + Variable( + "trajectory", + dtype=np.int64, + to_write="once", + attrs={ + "long_name": "Unique identifier for each particle", + "cf_role": "trajectory_id", + }, + ), Variable("obs_written", dtype=np.int32, initial=0, to_write=False), Variable("dt", dtype="timedelta64[s]", initial=np.timedelta64(1, "s"), to_write=False), Variable("state", dtype=np.int32, initial=StatusCode.Evaluate, to_write=False), @@ -239,7 +263,7 @@ def _create_array_for_variable(variable: Variable, nparticles: int, time_interva "This function cannot handle attrgetter initial values." 
) if (dtype := variable.dtype) is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: - dtype = type(time_interval.left) + dtype = _get_time_interval_dtype(time_interval) return np.full( shape=(nparticles,), fill_value=variable.initial, @@ -250,4 +274,8 @@ def _get_time_interval_dtype(time_interval: TimeInterval | None) -> np.dtype: if time_interval is None: return np.timedelta64(1, "ns") - return type(time_interval.left) + time = time_interval.left + if isinstance(time, (np.datetime64, np.timedelta64)): + return time.dtype + else: + return object # cftime objects need to be stored as object dtype diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 33451caae4..8fd4ea25dd 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -3,37 +3,40 @@ from __future__ import annotations import os -import warnings -from datetime import timedelta -from typing import TYPE_CHECKING +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Literal +import cftime import numpy as np import xarray as xr import zarr from zarr.storage import DirectoryStore import parcels -from parcels._reprs import default_repr +from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, ParticleClass from parcels.tools._helpers import timedelta_to_float if TYPE_CHECKING: - from pathlib import Path + from parcels._core.utils.time import TimeInterval + from parcels.particle import Variable + from parcels.particleset import ParticleSet __all__ = ["ParticleFile"] _DATATYPES_TO_FILL_VALUES = { - np.float16: np.nan, - np.float32: np.nan, - np.float64: np.nan, - np.bool_: np.iinfo(np.int8).max, - np.int8: np.iinfo(np.int8).max, - np.int16: np.iinfo(np.int16).max, - np.int32: np.iinfo(np.int32).max, - np.int64: np.iinfo(np.int64).max, - np.uint8: np.iinfo(np.uint8).max, - np.uint16: np.iinfo(np.uint16).max, - np.uint32: np.iinfo(np.uint32).max, - np.uint64: np.iinfo(np.uint64).max, + np.dtype(np.float16): np.nan, + np.dtype(np.float32): np.nan, + np.dtype(np.float64): np.nan, + np.dtype(np.bool_): np.iinfo(np.int8).max, + np.dtype(np.int8): np.iinfo(np.int8).max, + np.dtype(np.int16): np.iinfo(np.int16).max, + np.dtype(np.int32): np.iinfo(np.int32).max, + np.dtype(np.int64): np.iinfo(np.int64).max, + np.dtype(np.uint8): np.iinfo(np.uint8).max, + np.dtype(np.uint16): np.iinfo(np.uint16).max, + np.dtype(np.uint32): np.iinfo(np.uint32).max, + np.dtype(np.uint64): np.iinfo(np.uint64).max, } @@ -61,57 +64,51 @@ class ParticleFile: ParticleFile object that can be used to write particle data to file """ - def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfile=True): - self._outputdt = timedelta_to_float(outputdt) + def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): + if isinstance(outputdt, timedelta): + outputdt = np.timedelta64(int(outputdt.total_seconds()), "s") + + if not isinstance(outputdt, np.timedelta64): + raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime.timedelta, got {type(outputdt)}") + + self._outputdt = outputdt + + _assert_valid_chunks_tuple(chunks) self._chunks = chunks - self._particleset = particleset - self._parcels_mesh = "spherical" - if self.particleset.fieldset is not None: - self._parcels_mesh = self.particleset.fieldset.gridset[0].mesh - self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype self._maxids = 0 self._pids_written = {} + self.metadata = {} self._create_new_zarrfile = 
create_new_zarrfile - self._vars_to_write = {} - for var in self.particleset.particledata.ptype.variables: - if var.to_write: - self.vars_to_write[var.name] = var.dtype - self.particleset.fieldset._particlefile = self - - # Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet - particleset.particledata.setallvardata("obs_written", 0) - - self.metadata = { - "feature_type": "trajectory", - "Conventions": "CF-1.6/CF-1.7", - "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", - "parcels_version": parcels.__version__, - "parcels_mesh": self._parcels_mesh, - } - - if isinstance(store, zarr.storage.Store): - self.store = store - else: - self.store = _get_store_from_pathlike(store) - if store.read_only: - raise ValueError(f"Store {store} is read-only. Please provide a writable store.") + if not isinstance(store, zarr.storage.Store): + store = _get_store_from_pathlike(store) + + self._store = store + + # TODO v4: Enable once updating to zarr v3 + # if store.read_only: + # raise ValueError(f"Store {store} is read-only. Please provide a writable store.") # TODO v4: Add check that if create_new_zarrfile is False, the store already exists def __repr__(self) -> str: return ( f"{type(self).__name__}(" - f"name={self.fname!r}, " - f"particleset={default_repr(self.particleset)}, " f"outputdt={self.outputdt!r}, " f"chunks={self.chunks!r}, " f"create_new_zarrfile={self.create_new_zarrfile!r})" ) - @property - def create_new_zarrfile(self): - return self._create_new_zarrfile + def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]): + self.metadata.update( + { + "feature_type": "trajectory", + "Conventions": "CF-1.6/CF-1.7", + "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", + "parcels_version": parcels.__version__, + "parcels_grid_mesh": parcels_grid_mesh, + } + ) @property def outputdt(self): @@ -122,49 +119,12 @@ def chunks(self): return self._chunks @property - def particleset(self): - return self._particleset + def store(self): + return self._store @property - def fname(self): - return self._fname - - @property - def vars_to_write(self): - return self._vars_to_write - - def _create_variables_attribute_dict(self): - """Creates the dictionary with variable attributes. - - Notes - ----- - For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. 
- """ - attrs = { - "z": {"long_name": "", "standard_name": "depth", "units": "m", "positive": "down"}, - "trajectory": { - "long_name": "Unique identifier for each particle", - "cf_role": "trajectory_id", - "_FillValue": _DATATYPES_TO_FILL_VALUES[np.int64], - }, - "time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"}, - "lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"}, - "lat": {"long_name": "", "standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, - } - - attrs["time"]["units"] = "seconds since " # TODO fix units - attrs["time"]["calendar"] = "None" # TODO fix calendar - - for vname in self.vars_to_write: - if vname not in ["time", "lat", "lon", "depth", "trajectory"]: - attrs[vname] = { - "_FillValue": _DATATYPES_TO_FILL_VALUES[self.vars_to_write[vname]], - "long_name": "", - "standard_name": vname, - "units": "unknown", - } - - return attrs + def create_new_zarrfile(self): + return self._create_new_zarrfile def _convert_varout_name(self, var): if var == "depth": @@ -172,9 +132,6 @@ def _convert_varout_name(self, var): else: return var - def _write_once(self, var): - return self.particleset.particledata.ptype[var].to_write == "once" - def _extend_zarr_dims(self, Z, store, dtype, axis): if axis == 1: a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) @@ -190,7 +147,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis): Z.append(a, axis=axis) zarr.consolidate_metadata(store) - def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None): + def write(self, pset: ParticleSet, time, indices=None): """Write all data from one time step to the zarr file, before the particle locations are updated. @@ -199,34 +156,44 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N pset : ParticleSet object to write time : - Time at which to write ParticleSet + Time at which to write ParticleSet (same time object as fieldset) """ - time = timedelta_to_float(time) if time is not None else None - - if pset.particledata._ncount == 0: - warnings.warn( - f"ParticleSet is empty on writing as array at time {time:g}", - RuntimeWarning, - stacklevel=2, - ) - return + pclass = pset._ptype + time_interval = pset.fieldset.time_interval + particle_data = pset._data + time = timedelta_to_float(time - time_interval.left) + particle_data = _convert_particle_data_time_to_float_seconds(particle_data, time_interval) + + self._write_particle_data( + particle_data=particle_data, pclass=pclass, time_interval=time_interval, time=time, indices=indices + ) + def _write_particle_data(self, *, particle_data, pclass, time_interval, time, indices=None): + # if pset._data._ncount == 0: + # warnings.warn( + # f"ParticleSet is empty on writing as array at time {time:g}", + # RuntimeWarning, + # stacklevel=2, + # ) + # return + nparticles = len(particle_data["trajectory"]) + vars_to_write = _get_vars_to_write(pclass) if indices is None: - indices_to_write = pset.particledata._to_write_particles(time) + indices_to_write = _to_write_particles(particle_data, time) else: indices_to_write = indices if len(indices_to_write) == 0: return - pids = pset.particledata.getvardata("trajectory", indices_to_write) + pids = particle_data["trajectory"][indices_to_write] to_add = sorted(set(pids) - set(self._pids_written.keys())) for i, pid in enumerate(to_add): self._pids_written[pid] = self._maxids + i ids = np.array([self._pids_written[p] for p in pids], dtype=int) 
self._maxids = len(self._pids_written) - once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0] + once_ids = np.where(particle_data["obs_written"][indices_to_write] == 0)[0] if len(once_ids) > 0: ids_once = ids[once_ids] indices_to_write_once = indices_to_write[once_ids] @@ -234,7 +201,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N store = self.store if self.create_new_zarrfile: if self.chunks is None: - self._chunks = (len(pset), 1) + self._chunks = (nparticles, 1) if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index] arrsize = (self._maxids, self.chunks[1]) # type: ignore[index] else: @@ -243,45 +210,45 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N attrs=self.metadata, coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, ) - attrs = self._create_variables_attribute_dict() + attrs = _create_variables_attribute_dict(pclass, time_interval) obs = np.zeros((self._maxids), dtype=np.int32) - for var in self.vars_to_write: - varout = self._convert_varout_name(var) + for var in vars_to_write: + dtype = _maybe_convert_time_dtype(var.dtype) + varout = self._convert_varout_name(var.name) if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate - if self._write_once(var): + if var.to_write == "once": data = np.full( (arrsize[0],), - _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]], - dtype=self.vars_to_write[var], + _DATATYPES_TO_FILL_VALUES[dtype], + dtype=dtype, ) - data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) + data[ids_once] = particle_data[var.name][indices_to_write_once] dims = ["trajectory"] else: - data = np.full( - arrsize, _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]], dtype=self.vars_to_write[var] - ) - data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) + data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + data[ids, 0] = particle_data[var.name][indices_to_write] dims = ["trajectory", "obs"] - ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) - ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index] + ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) + ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] ds.to_zarr(store, mode="w") self._create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) - obs = pset.particledata.getvardata("obs_written", indices_to_write) - for var in self.vars_to_write: - varout = self._convert_varout_name(var) + obs = particle_data["obs_written"][indices_to_write] + for var in vars_to_write: + dtype = _maybe_convert_time_dtype(var.dtype) + varout = self._convert_varout_name(var.name) if self._maxids > Z[varout].shape[0]: - self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0) - if self._write_once(var): + self._extend_zarr_dims(Z[varout], store, dtype=dtype, axis=0) + if var.to_write == "once": if len(once_ids) > 0: - Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) + Z[varout].vindex[ids_once] = particle_data[var.name][indices_to_write_once] else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] - self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1) - Z[varout].vindex[ids, obs] = 
pset.particledata.getvardata(var, indices_to_write) + self._extend_zarr_dims(Z[varout], store, dtype=dtype, axis=1) + Z[varout].vindex[ids, obs] = particle_data[var.name][indices_to_write] - pset.particledata.setvardata("obs_written", indices_to_write, obs + 1) + particle_data["obs_written"][indices_to_write] = obs + 1 def write_latest_locations(self, pset, time): """Write the current (latest) particle locations to zarr file. @@ -296,7 +263,7 @@ def write_latest_locations(self, pset, time): Time at which to write ParticleSet. Note that typically this would be pset.time_nextloop """ for var in ["lon", "lat", "depth", "time"]: - pset.particledata.setallvardata(f"{var}", pset.particledata.getvardata(f"{var}_nextloop")) + pset._data[f"{var}"] = pset._data[f"{var}_nextloop"] self.write(pset, time) @@ -308,3 +275,105 @@ def _get_store_from_pathlike(path: Path | str) -> DirectoryStore: raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.") return DirectoryStore(path) + + +def _get_vars_to_write(particle: ParticleClass) -> list[Variable]: + return [v for v in particle.variables if v.to_write is not False] + + +def _create_variables_attribute_dict(particle: ParticleClass, time_interval: TimeInterval) -> dict: + """Creates the dictionary with variable attributes. + + Notes + ----- + For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. + """ + attrs = {} + + vars = [var for var in particle.variables if var.to_write is not False] + for var in vars: + fill_value = {} + if var.dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + fill_value = {"_FillValue": _DATATYPES_TO_FILL_VALUES[var.dtype]} + + attrs[var.name] = {**var.attrs, **fill_value} + + attrs["time"].update(_get_calendar_and_units(time_interval)) + + return attrs + + +def _to_write_particles(particle_data, time): + """Return the particles that need to be written at time: those whose particle.time lies between time-dt/2 and time+dt/2""" + return np.where( + ( + np.less_equal( + time - np.abs(particle_data["dt"] / 2), + particle_data["time_nextloop"], + where=np.isfinite(particle_data["time_nextloop"]), + ) + & np.greater_equal( + time + np.abs(particle_data["dt"] / 2), + particle_data["time_nextloop"], + where=np.isfinite(particle_data["time_nextloop"]), + ) # check time - dt/2 <= particle_data["time"] <= time + dt/2 + | ( + (np.isnan(particle_data["dt"])) + & np.equal(time, particle_data["time_nextloop"], where=np.isfinite(particle_data["time_nextloop"])) + ) # or dt is NaN and time matches particle_data["time"] + ) + & (np.isfinite(particle_data["trajectory"])) + & (np.isfinite(particle_data["time_nextloop"])) + )[0] + + +def _convert_particle_data_time_to_float_seconds(particle_data, time_interval): + #! Important that this is a shallow copy, so that updates to this propagate back to the original data + particle_data = particle_data.copy() + + particle_data["time"] = ((particle_data["time"] - time_interval.left) / np.timedelta64(1, "s")).astype(np.float64) + particle_data["time_nextloop"] = ( + (particle_data["time_nextloop"] - time_interval.left) / np.timedelta64(1, "s") + ).astype(np.float64) + particle_data["dt"] = (particle_data["dt"] / np.timedelta64(1, "s")).astype(np.float64) + return particle_data + + +def _maybe_convert_time_dtype(dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL) -> np.dtype: + """Convert the dtype of time to float64 if it is not already.""" + if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + return np.dtype( + np.uint64 + ) #! 
We need a proper mechanism here for converting particle data to the data that is output to zarr (namely, time needs to be converted to float seconds by subtracting time_interval.left) + return dtype + + +def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: + calendar = None + units = "seconds" + if isinstance(time_interval.left, (np.datetime64, datetime)): + calendar = "standard" + elif isinstance(time_interval.left, cftime.datetime): + calendar = time_interval.left.calendar + + if calendar is not None: + units += f" since {time_interval.left}" + + attrs = {"units": units} + if calendar is not None: + attrs["calendar"] = calendar + + return attrs + + +def _assert_valid_chunks_tuple(chunks): + e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.") + if chunks is None: + return + + if not isinstance(chunks, tuple): + raise e + if len(chunks) != 2: + raise e + if not all(isinstance(c, int) for c in chunks): + raise e diff --git a/parcels/particleset.py b/parcels/particleset.py index d8e9cbac29..b44cef0be6 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -15,7 +15,6 @@ from parcels.basegrid import GridType from parcels.kernel import Kernel from parcels.particle import KernelParticle, Particle, create_particle_data -from parcels.particlefile import ParticleFile from parcels.tools.converters import convert_to_flat_array from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode @@ -392,10 +391,6 @@ def InteractionKernel(self, pyfunc_inter): return None return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter) - def ParticleFile(self, *args, **kwargs): - """Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet.""" - return ParticleFile(*args, particleset=self, **kwargs) - def data_indices(self, variable_name, compare_values, invert=False): """Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`. 
@@ -505,8 +500,9 @@ def execute( self._kernel = pyfunc - if output_file: - output_file.metadata["parcels_kernels"] = self._kernel.name + if output_file is not None: + output_file.set_metadata(self.fieldset.gridset[0]._mesh) + output_file.metadata["parcels_kernels"] = self._kernel.funcname if dt is None: dt = np.timedelta64(1, "s") @@ -542,24 +538,25 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {output_file.fname}.") + logger.info(f"Output files are stored in {output_file.store}.") if verbose_progress: pbar = tqdm(total=(end_time - start_time) / np.timedelta64(1, "s"), file=sys.stdout) - next_output = outputdt if output_file else None + next_output = start_time + sign_dt * outputdt if output_file else None time = start_time while sign_dt * (time - end_time) < 0: - if sign_dt > 0: - next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works + if next_output is not None: + f = min if sign_dt > 0 else max + next_time = f(next_output, end_time) else: - next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works + next_time = end_time + self._kernel.execute(self, endtime=next_time, dt=dt) - # TODO: Handle IO timing based of timedelta or datetime objects if next_output: - if abs(next_time - next_output) < 1e-12: + if np.abs(next_time - next_output) < np.timedelta64(1000, "ns"): if output_file: output_file.write(self, next_output) if np.isfinite(outputdt): diff --git a/tests/conftest.py b/tests/conftest.py index 69787802c6..82020c37e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,13 @@ import pytest +from zarr.storage import MemoryStore @pytest.fixture() def tmp_zarrfile(tmp_path, request): test_name = request.node.name yield tmp_path / f"{test_name}-output.zarr" + + +@pytest.fixture +def tmp_store(): + return MemoryStore() diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py deleted file mode 100755 index 39cee4ae16..0000000000 --- a/tests/test_particlefile.py +++ /dev/null @@ -1,406 +0,0 @@ -import os -import tempfile -from datetime import timedelta - -import numpy as np -import pytest -import xarray as xr -from zarr.storage import MemoryStore - -import parcels -from parcels import ( - AdvectionRK4, - Field, - FieldSet, - Particle, - ParticleSet, - Variable, -) -from tests.common_kernels import DoNothing -from tests.utils import create_fieldset_zeros_simple - - -@pytest.fixture -def fieldset(): - return create_fieldset_zeros_simple() - - -def test_metadata(fieldset, tmp_zarrfile): - pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - - pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=1)) - - ds = xr.open_zarr(tmp_zarrfile) - assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() - - -def test_pfile_array_write_zarr_memorystore(fieldset): - """Check that writing to a Zarr MemoryStore works.""" - npart = 10 - zarr_store = MemoryStore() - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - pfile = pset.ParticleFile(zarr_store, outputdt=1) - pfile.write(pset, 0) - - ds = xr.open_zarr(zarr_store) - assert ds.sizes["trajectory"] == npart - ds.close() - - -def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) - pset.remove_indices(3) 
- for p in pset: - p.time = 1 - pfile.write(pset, 1) - - ds = xr.open_zarr(tmp_zarrfile) - timearr = ds["time"][:] - assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) - ds.close() - - -def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) - pset.set_variable_write_status("depth", False) - pset.set_variable_write_status("lat", False) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset.execute(Update_lon, runtime=10, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - assert "time" in ds - assert "z" not in ds - assert "lat" not in ds - ds.close() - - # For pytest purposes, we need to reset to original status - pset.set_variable_write_status("depth", True) - pset.set_variable_write_status("lat", True) - - -@pytest.mark.parametrize("chunks_obs", [1, None]) -def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - chunks = (npart, chunks_obs) if chunks_obs else None - pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1) - pfile.write(pset, 0) - for _ in range(npart): - pset.remove_indices(-1) - pfile.write(pset, 1) - pfile.write(pset, 2) - - ds = xr.open_zarr(tmp_zarrfile).load() - assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms")) - if chunks_obs is not None: - assert ds["time"][:].shape == chunks - else: - assert ds["time"][:].shape[0] == npart - assert np.all(np.isnan(ds["time"][:, 1:])) - ds.close() - - -@pytest.mark.xfail(reason="lonlatdepth_dtype removed. 
Update implementation to use a different particle") -def test_variable_write_double(fieldset, tmp_zarrfile): - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001) - pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) - - ds = xr.open_zarr(tmp_zarrfile) - lons = ds["lon"][:] - assert isinstance(lons.values[0, 0], np.float64) - ds.close() - - -def test_write_dtypes_pfile(fieldset, tmp_zarrfile): - dtypes = [ - np.float32, - np.float64, - np.int32, - np.uint32, - np.int64, - np.uint64, - np.bool_, - np.int8, - np.uint8, - np.int16, - np.uint16, - ] - - extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] - MyParticle = Particle.add_variables(extra_vars) - - pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) - - ds = xr.open_zarr( - tmp_zarrfile, mask_and_scale=False - ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float - for d in dtypes: - assert ds[f"v_{d.__name__}"].dtype == d - - -@pytest.mark.parametrize("npart", [1, 2, 5]) -def test_variable_written_once(fieldset, tmp_zarrfile, npart): - def Update_v(particle, fieldset, time): # pragma: no cover - particle.v_once += 1.0 - particle.age += particle.dt - - MyParticle = Particle.add_variables( - [ - Variable("v_once", dtype=np.float64, initial=0.0, to_write="once"), - Variable("age", dtype=np.float32, initial=0.0), - ] - ) - lon = np.linspace(0, 1, npart) - lat = np.linspace(1, 0, npart) - time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64) - pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1) - pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile) - - assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5) - ds = xr.open_zarr(tmp_zarrfile) - vfile = np.ma.filled(ds["v_once"][:], np.nan) - assert vfile.shape == (npart,) - ds.close() - - -@pytest.mark.parametrize("type", ["repeatdt", "timearr"]) -@pytest.mark.parametrize("repeatdt", range(1, 3)) -@pytest.mark.parametrize("dt", [-1, 1]) -@pytest.mark.parametrize("maxvar", [2, 4, 10]) -def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, tmp_zarrfile, dt, maxvar): - runtime = 10 - fieldset.maxvar = maxvar - pset = None - - MyParticle = Particle.add_variables( - [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] - ) - - if type == "repeatdt": - pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt) - elif type == "timearr": - pset = ParticleSet( - fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) - ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) - - def IncrLon(particle, fieldset, time): # pragma: no cover - particle.sample_var += 1.0 - if particle.sample_var > fieldset.maxvar: - particle.delete() - - for _ in range(runtime): - pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - samplevar = ds["sample_var"][:] - if type == "repeatdt": - assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime)) - assert 
np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt)) - elif type == "timearr": - assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) - # test whether samplevar[:, k] = k - for k in range(samplevar.shape[1]): - assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) - filesize = os.path.getsize(str(tmp_zarrfile)) - assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB - ds.close() - - -@pytest.mark.parametrize("repeatdt", [1, 2]) -@pytest.mark.parametrize("nump", [1, 10]) -def test_pfile_chunks_repeatedrelease(fieldset, repeatdt, nump, tmp_zarrfile): - runtime = 8 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt) - chunks = (20, 10) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks) - - def DoNothing(particle, fieldset, time): # pragma: no cover - pass - - pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1]) - - -def test_write_timebackward(fieldset, tmp_zarrfile): - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon -= 0.1 * particle.dt - - pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3]) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0) - pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - trajs = ds["trajectory"][:] - assert trajs.values.dtype == "int64" - assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release - ds.close() - - -def test_write_xiyi(fieldset, tmp_zarrfile): - fieldset.U.data[:] = 1 # set a non-zero zonal velocity - fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) - dt = 3600 - - XiYiParticle = Particle.add_variables( - [ - Variable("pxi0", dtype=np.int32, initial=0.0), - Variable("pxi1", dtype=np.int32, initial=0.0), - Variable("pyi", dtype=np.int32, initial=0.0), - ] - ) - - def Get_XiYi(particle, fieldset, time): # pragma: no cover - """Kernel to sample the grid indices of the particle. - Note that this sampling should be done _before_ the advection kernel - and that the first outputted value is zero. - Be careful when using multiple grids, as the index may be different for the grids. 
- """ - particle.pxi0 = fieldset.U.unravel_index(particle.ei)[2] - particle.pxi1 = fieldset.P.unravel_index(particle.ei)[2] - particle.pyi = fieldset.U.unravel_index(particle.ei)[1] - - def SampleP(particle, fieldset, time): # pragma: no cover - if time > 5 * 3600: - _ = fieldset.P[particle] # To trigger sampling of the P field - - pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt) - pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - pxi0 = ds["pxi0"][:].values.astype(np.int32) - pxi1 = ds["pxi1"][:].values.astype(np.int32) - lons = ds["lon"][:].values - pyi = ds["pyi"][:].values.astype(np.int32) - lats = ds["lat"][:].values - - for p in range(pyi.shape[0]): - assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved - assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6 - assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6 - for xi, lon in zip(pxi0[p, 1:], lons[p, 1:], strict=True): - assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi + 1] - for xi, lon in zip(pxi1[p, 6:], lons[p, 6:], strict=True): - assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi + 1] - for yi, lat in zip(pyi[p, 1:], lats[p, 1:], strict=True): - assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] - ds.close() - - -def test_reset_dt(fieldset, tmp_zarrfile): - # Assert that p.dt gets reset when a write_time is not a multiple of dt - # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05) - pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) - - assert np.allclose(pset.lon, 0.6) - - -def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): - """Testing that outputdt does not need to be a multiple of dt.""" - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += particle.dt - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=3) - pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) - - ds = xr.open_zarr(tmp_zarrfile) - assert np.allclose(ds.lon.values, [0, 3, 6, 9]) - assert np.allclose( - ds.time.values[0, :], [np.timedelta64(t, "s") for t in [0, 3, 6, 9]], atol=np.timedelta64(1, "ns") - ) - - -def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): - npart = 10 - - if fieldset is None: - fieldset = create_fieldset_zeros_simple() - - pset = ParticleSet( - fieldset, - pclass=particle_class, - lon=np.full(npart, fieldset.U.lon.mean()), - lat=np.full(npart, fieldset.U.lat.mean()), - ) - - with tempfile.TemporaryDirectory() as dir: - name = f"{dir}/test.zarr" - output_file = pset.ParticleFile(name=name, outputdt=outputdt) - - pset.execute(DoNothing, output_file=output_file, **execute_kwargs) - ds = xr.open_zarr(name).load() - - return ds - - -def test_pset_execute_outputdt_forwards(): - """Testing output data 
dt matches outputdt in forward time.""" - outputdt = timedelta(hours=1) - runtime = timedelta(hours=5) - dt = timedelta(minutes=5) - - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) - - assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) - - -def test_pset_execute_outputdt_backwards(): - """Testing output data dt matches outputdt in backwards time.""" - outputdt = timedelta(hours=1) - runtime = timedelta(days=2) - dt = -timedelta(minutes=5) - - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)) - - -def test_pset_execute_outputdt_backwards_fieldset_timevarying(): - """test_pset_execute_outputdt_backwards() still passed despite #1722 as it doesn't account for time-varying fields, - which for some reason #1722 - """ - outputdt = timedelta(hours=1) - runtime = timedelta(days=2) - dt = -timedelta(minutes=5) - - # TODO: Not ideal using the `download_example_dataset` here, but I'm struggling to recreate this error using the test suite fieldsets we have - example_dataset_folder = parcels.download_example_dataset("MovingEddies_data") - filenames = { - "U": str(example_dataset_folder / "moving_eddiesU.nc"), - "V": str(example_dataset_folder / "moving_eddiesV.nc"), - } - variables = {"U": "vozocrtx", "V": "vomecrty"} - dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"} - fieldset = parcels.FieldSet.from_netcdf(filenames, variables, dimensions) - - ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) diff --git a/tests/tools/test_warnings.py b/tests/tools/test_warnings.py index 49007b6717..1570d44bfa 100644 --- a/tests/tools/test_warnings.py +++ b/tests/tools/test_warnings.py @@ -11,6 +11,7 @@ ParticleSet, ParticleSetWarning, ) +from parcels.particlefile import ParticleFile from tests.utils import TEST_DATA @@ -30,7 +31,7 @@ def test_file_warnings(tmp_zarrfile): data={"U": np.zeros((1, 1)), "V": np.zeros((1, 1))}, dimensions={"lon": [0], "lat": [0]} ) pset = ParticleSet(fieldset=fieldset, pclass=Particle, lon=[0, 0], lat=[0, 0], time=[0, 1]) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(2, "s")) with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"): pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 2522a8b6f5..751c149a35 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -17,6 +17,7 @@ from parcels.field import Field, VectorField from parcels.fieldset import FieldSet from parcels.particle import Particle, Variable +from parcels.particlefile import ParticleFile from parcels.particleset import ParticleSet from parcels.tools.statuscodes import StatusCode from parcels.xgrid import XGrid @@ -52,6 +53,26 @@ def test_advection_zonal(mesh, npart=10): assert (np.diff(pset.lon) < 1.0e-4).all() +def test_advection_zonal_with_particlefile(tmp_store): + """Check that output written via a ParticleFile matches the in-memory particle data after zonal advection.""" + npart = 10 + ds = simple_UV_dataset(mesh="flat") + ds["U"].data[:] = 1.0 + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) + UV = VectorField("UV", U, V) + fieldset = FieldSet([U, V, UV]) + + pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(30, "m")) + pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) + + assert (np.diff(pset.lon) < 1.0e-4).all() + ds = xr.open_zarr(tmp_store) + np.testing.assert_allclose(ds.isel(obs=-1).lon.values, pset.lon) + + def periodicBC(particle, fieldset, time): particle.total_dlon += particle.dlon particle.lon = np.fmod(particle.lon, 2) diff --git a/tests/v4/test_interpolation.py b/tests/v4/test_interpolation.py index d3ca4dea35..5fa92ab713 100644 --- a/tests/v4/test_interpolation.py +++ b/tests/v4/test_interpolation.py @@ -10,6 +10,7 @@ from parcels.field import Field, VectorField from parcels.fieldset import FieldSet from parcels.particle import Particle, Variable +from parcels.particlefile import ParticleFile from parcels.particleset import ParticleSet from parcels.tools.statuscodes import StatusCode from parcels.uxgrid import UxGrid @@ -169,7 +170,7 @@ def DeleteParticle(particle, fieldset, time): if particle.state >= 50: particle.state = StatusCode.Delete - outfile = pset.ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s")) + outfile = ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s")) pset.execute( [AdvectionRK4_3D, DeleteParticle], runtime=np.timedelta64(4, "s"), diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py new file mode 100755 index 0000000000..567793b4e6 --- /dev/null +++ b/tests/v4/test_particlefile.py @@ -0,0 +1,486 @@ +import os +import tempfile +from datetime import timedelta + +import numpy as np +import pytest +import xarray as xr +from zarr.storage import MemoryStore + +import parcels +from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField +from parcels._core.utils.time import TimeInterval +from parcels._datasets.structured.generic import datasets +from parcels.particle import Particle, create_particle_data, get_default_particle +from parcels.particlefile import ParticleFile +from parcels.tools.statuscodes import StatusCode +from parcels.xgrid import XGrid +from tests.common_kernels import DoNothing + + +@pytest.fixture +def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remove duplicates + """Fixture to create a FieldSet object for testing.""" + ds = datasets["ds_2d_left"] + grid = XGrid.from_dataset(ds) + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) + UV = VectorField("UV", U, V) + + return FieldSet( + [U, V, UV], + ) + + +@pytest.mark.skip +def test_metadata(fieldset, tmp_zarrfile): + pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) + + pset.execute(DoNothing, runtime=1, output_file=ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))) + + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf + assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() + + +def test_pfile_array_write_zarr_memorystore(fieldset): + """Check that writing to a Zarr MemoryStore works.""" + npart = 10 + 
zarr_store = MemoryStore() + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) + pfile = ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) + pfile.write(pset, time=fieldset.time_interval.left) + + ds = xr.open_zarr(zarr_store) + assert ds.sizes["trajectory"] == npart + + +def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): + npart = 10 + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pset._data["time"][:] = fieldset.time_interval.left + pset._data["time_nextloop"][:] = fieldset.time_interval.left + pfile.write(pset, time=fieldset.time_interval.left) + pset.remove_indices(3) + new_time = fieldset.time_interval.left + np.timedelta64(1, "D") + pset._data["time"][:] = new_time + pset._data["time_nextloop"][:] = new_time + pfile.write(pset, new_time) + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) + timearr = ds["time"][:] + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) + assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) + + +@pytest.mark.parametrize("chunks_obs", [1, None]) +def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): + npart = 10 + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) + chunks = (npart, chunks_obs) if chunks_obs else None + pfile = ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) + pfile.write(pset, time=fieldset.time_interval.left) + for _ in range(npart): + pset.remove_indices(-1) + pfile.write(pset, fieldset.time_interval.left + np.timedelta64(1, "D")) + pfile.write(pset, fieldset.time_interval.left + np.timedelta64(2, "D")) + + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False).load() + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) + assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms")) + if chunks_obs is not None: + assert ds["time"][:].shape == chunks + else: + assert ds["time"][:].shape[0] == npart + assert np.all(np.isnan(ds["time"][:, 1:])) + + +@pytest.mark.skip(reason="TODO v4: stuck in infinite loop") +def test_variable_write_double(fieldset, tmp_zarrfile): + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us")) + pset.execute( + pset.Kernel(Update_lon), + runtime=np.timedelta64(1, "ms"), + dt=np.timedelta64(10, "us"), + output_file=ofile, + ) + + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf + lons = ds["lon"][:] + assert isinstance(lons.values[0, 0], np.float64) + + +def test_write_dtypes_pfile(fieldset, tmp_zarrfile): + dtypes = [ + np.float32, + np.float64, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.bool_, + np.int8, + np.uint8, + np.int16, + np.uint16, + ] + + extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] + MyParticle = Particle.add_variable(extra_vars) + + pset = ParticleSet(fieldset, 
pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile.write(pset, time=fieldset.time_interval.left) + + ds = xr.open_zarr( + tmp_zarrfile, mask_and_scale=False + ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float + for d in dtypes: + assert ds[f"v_{d.__name__}"].dtype == d + + +def test_variable_written_once(): + # Test that a variable is only written once. This should also work with gradual particle release (so the written-once time is actually after the release of the particle) + ... + + +@pytest.mark.parametrize( + "dt", + [ + pytest.param(-np.timedelta64(1, "s"), marks=pytest.mark.xfail(reason="need to fix backwards in time")), + np.timedelta64(1, "s"), + ], +) +@pytest.mark.parametrize("maxvar", [2, 4, 10]) +def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, dt, maxvar): + """Tests that if particles are released and deleted based on age, the resulting output file is correct.""" + npart = 10 + runtime = np.timedelta64(npart, "s") + fieldset.add_constant("maxvar", maxvar) + pset = None + + MyParticle = Particle.add_variable( + [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] + ) + + pset = ParticleSet( + fieldset, + lon=np.zeros(npart), + lat=np.zeros(npart), + pclass=MyParticle, + time=fieldset.time_interval.left + np.array([np.timedelta64(i, "s") for i in range(npart)]), + ) + pfile = ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) + + def IncrLon(particle, fieldset, time): # pragma: no cover + particle.sample_var += 1.0 + particle.state = np.where( + particle.sample_var > fieldset.maxvar, + StatusCode.Delete, + particle.state, + ) + + for _ in range(npart): + pset.execute(IncrLon, dt=dt, runtime=np.timedelta64(1, "s"), output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) + samplevar = ds["sample_var"][:] + assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) + # test whether samplevar[:, k] = k + for k in range(samplevar.shape[1]): + assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) + filesize = os.path.getsize(str(tmp_zarrfile)) + assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB + + +@pytest.mark.xfail(reason="need to fix backwards in time") +def test_write_timebackward(fieldset, tmp_zarrfile): + def Update_lon(particle, fieldset, time): # pragma: no cover + dt = particle.dt / np.timedelta64(1, "s") + particle.dlon -= 0.1 * dt + + pset = ParticleSet( + fieldset, + pclass=Particle, + lat=np.linspace(0, 1, 3), + lon=[0, 0, 0], + time=np.array([np.datetime64("2000-01-01") for _ in range(3)]), + ) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(1, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile) + ds = xr.open_zarr(tmp_zarrfile) + trajs = ds["trajectory"][:] + assert trajs.values.dtype == "int64" + assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release + + +@pytest.mark.xfail +@pytest.mark.v4alpha +def test_write_xiyi(fieldset, tmp_zarrfile): + fieldset.U.data[:] = 1 # set a non-zero zonal velocity + fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) + dt = 
np.timedelta64(3600, "s") + + particle = get_default_particle(np.float64) + XiYiParticle = particle.add_variable( + [ + Variable("pxi0", dtype=np.int32, initial=0.0), + Variable("pxi1", dtype=np.int32, initial=0.0), + Variable("pyi", dtype=np.int32, initial=0.0), + ] + ) + + def Get_XiYi(particle, fieldset, time): # pragma: no cover + """Kernel to sample the grid indices of the particle. + Note that this sampling should be done _before_ the advection kernel + and that the first outputted value is zero. + Be careful when using multiple grids, as the index may be different for the grids. + """ + particle.pxi0 = fieldset.U.unravel_index(particle.ei)[2] + particle.pxi1 = fieldset.P.unravel_index(particle.ei)[2] + particle.pyi = fieldset.U.unravel_index(particle.ei)[1] + + def SampleP(particle, fieldset, time): # pragma: no cover + if time > 5 * 3600: + _ = fieldset.P[particle] # To trigger sampling of the P field + + pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1]) + pfile = ParticleFile(tmp_zarrfile, outputdt=dt) + pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + pxi0 = ds["pxi0"][:].values.astype(np.int32) + pxi1 = ds["pxi1"][:].values.astype(np.int32) + lons = ds["lon"][:].values + pyi = ds["pyi"][:].values.astype(np.int32) + lats = ds["lat"][:].values + + for p in range(pyi.shape[0]): + assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved + assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6 + assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6 + for xi, lon in zip(pxi0[p, 1:], lons[p, 1:], strict=True): + assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi + 1] + for xi, lon in zip(pxi1[p, 6:], lons[p, 6:], strict=True): + assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi + 1] + for yi, lat in zip(pyi[p, 1:], lats[p, 1:], strict=True): + assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] + + +@pytest.mark.skip +@pytest.mark.v4alpha +def test_reset_dt(fieldset, tmp_zarrfile): + # Assert that p.dt gets reset when a write_time is not a multiple of dt + # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.02, 0.02, 0.01, 0.02, 0.02, 0.01], resulting in 6 kernel executions + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms")) + dt = np.timedelta64(20, "ms") + pset.execute(pset.Kernel(Update_lon), runtime=6 * dt, dt=dt, output_file=ofile) + + assert np.allclose(pset.lon, 0.6) + + +@pytest.mark.v4alpha +@pytest.mark.xfail +def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): + """Testing that outputdt does not need to be a multiple of dt.""" + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += particle.dt / np.timedelta64(1, "s") + + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) + pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile) + + ds = xr.open_zarr(tmp_zarrfile) + assert np.allclose(ds.lon.values, [0, 3, 6, 9]) + assert np.allclose( + 
ds.time.values[0, :], [np.timedelta64(t, "s") for t in [0, 3, 6, 9]], atol=np.timedelta64(1, "ns") + ) + + +def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): + npart = 10 + + pset = ParticleSet( + fieldset, + pclass=particle_class, + lon=np.full(npart, fieldset.U.data.lon.mean()), + lat=np.full(npart, fieldset.U.data.lat.mean()), + ) + + with tempfile.TemporaryDirectory() as dir: + name = f"{dir}/test.zarr" + output_file = ParticleFile(name, outputdt=outputdt) + + pset.execute(DoNothing, output_file=output_file, **execute_kwargs) + ds = xr.open_zarr(name).load() + + return ds + + +def test_pset_execute_outputdt_forwards(fieldset): + """Testing output data dt matches outputdt in forward time.""" + outputdt = timedelta(hours=1) + runtime = timedelta(hours=5) + dt = timedelta(minutes=5) + + ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + + assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) + + +@pytest.mark.skip(reason="backwards in time not yet working") +def test_pset_execute_outputdt_backwards(fieldset): + """Testing output data dt matches outputdt in backwards time.""" + outputdt = timedelta(hours=1) + runtime = timedelta(days=2) + dt = -timedelta(minutes=5) + + ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + assert np.all(file_outputdt == np.timedelta64(-outputdt)) + + +@pytest.mark.xfail(reason="TODO v4: Update dataset loading") +def test_pset_execute_outputdt_backwards_fieldset_timevarying(): + """test_pset_execute_outputdt_backwards() still passed despite #1722 as it doesn't account for time-varying fields, + which for some reason #1722 + """ + outputdt = timedelta(hours=1) + runtime = timedelta(days=2) + dt = -timedelta(minutes=5) + + # TODO: Not ideal using the `download_example_dataset` here, but I'm struggling to recreate this error using the test suite fieldsets we have + example_dataset_folder = parcels.download_example_dataset("MovingEddies_data") + filenames = { + "U": str(example_dataset_folder / "moving_eddiesU.nc"), + "V": str(example_dataset_folder / "moving_eddiesV.nc"), + } + variables = {"U": "vozocrtx", "V": "vomecrty"} + dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"} + fieldset = parcels.FieldSet.from_netcdf(filenames, variables, dimensions) + + ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) + + +def test_particlefile_init(tmp_store): + ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) + + +@pytest.mark.parametrize("name", ["store", "outputdt", "chunks", "create_new_zarrfile"]) +def test_particlefile_readonly_attrs(tmp_store, name): + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) + with pytest.raises(AttributeError, match="property .* of 'ParticleFile' object has no setter"): + setattr(pfile, name, "something") + + +def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store + with pytest.raises(ValueError, match="chunks must be a tuple"): + ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) + + +def 
test_particlefile_write_particle_data(tmp_store): + nparticles = 100 + + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) + pclass = Particle + + left, right = np.datetime64("2019-05-30T12:00:00.000000000", "ns"), np.datetime64("2020-01-02", "ns") + time_interval = TimeInterval(left=left, right=right) + + initial_lon = np.linspace(0, 1, nparticles) + data = create_particle_data( + pclass=pclass, + nparticles=nparticles, + ngrids=4, + time_interval=time_interval, + initial={ + "time": np.full(nparticles, fill_value=left), + "time_nextloop": np.full(nparticles, fill_value=left), + "lon": initial_lon, + "dt": np.full(nparticles, fill_value=1.0), + "trajectory": np.arange(nparticles), + }, + ) + np.testing.assert_array_equal(data["time"], left) + pfile._write_particle_data( + particle_data=data, + pclass=pclass, + time_interval=time_interval, + time=left, + ) + ds = xr.open_zarr(tmp_store, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf + # assert ds.time.dtype == "datetime64[ns]" + # np.testing.assert_equal(ds["time"].isel(obs=0).values, left) + assert ds.sizes["trajectory"] == nparticles + np.testing.assert_allclose(ds["lon"].isel(obs=0).values, initial_lon) + + +def test_pfile_write_custom_particle(): + # Test the writing of a custom particle with variables that are to_write, some to_write once, and some not to_write + # ? This is more of an integration test... Should it be housed here? + ... + + +@pytest.mark.xfail( + reason="set_variable_write_status should be removed - with Particle writing defined on the particle level. GH2186" +) +def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): + npart = 10 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) + pset.set_variable_write_status("depth", False) + pset.set_variable_write_status("lat", False) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + pset.execute(Update_lon, runtime=10, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + assert "time" in ds + assert "z" not in ds + assert "lat" not in ds + ds.close() + + # For pytest purposes, we need to reset to original status + pset.set_variable_write_status("depth", True) + pset.set_variable_write_status("lat", True) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 4e4ba21e12..80087d1207 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -17,6 +17,7 @@ from parcels._datasets.structured.generated import simple_UV_dataset from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import datasets as datasets_unstructured +from parcels.particlefile import ParticleFile from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError from parcels.uxgrid import UxGrid from parcels.xgrid import XGrid @@ -433,7 +434,7 @@ def test_uxstommelgyre_pset_execute_output(): time=[0.0], pclass=Particle, ) - output_file = pset.ParticleFile( + output_file = ParticleFile( name="stommel_uxarray_particles.zarr", # the file name outputdt=np.timedelta64(5, "m"), # the time step of the outputs ) diff --git a/v3to4-breaking-changes.md b/v3to4-breaking-changes.md index 3afdb3e05f..f913a0d33e 100644 --- a/v3to4-breaking-changes.md +++ b/v3to4-breaking-changes.md @@ -16,3 +16,7 @@ - `repeatdt` and 
`lonlatdepth_dtype` have been removed from the ParticleSet. - ParticleSet.execute() expects `numpy.datetime64`/`numpy.timedelta64` for `runtime`, `endtime` and `dt`. - `ParticleSet.from_field()`, `ParticleSet.from_line()`, `ParticleSet.from_list()` have been removed. + +## ParticleFile + +- `ParticleFile` objects should be created directly via `ParticleFile(...)` instead of `pset.ParticleFile(...)` (see the example below).
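
As a quick illustration of the new workflow, here is a minimal sketch assembled from the tests in this diff. It assumes a valid v4 `FieldSet` is already available (`build_fieldset()` below is a hypothetical stand-in for your own construction); the `ParticleFile` store must be a zarr store or a path ending in `.zarr`, and `outputdt`, `runtime`, and `dt` must be `np.timedelta64` (or `datetime.timedelta` for `outputdt`) values.

```python
import numpy as np

from parcels import AdvectionRK4, ParticleSet
from parcels.particlefile import ParticleFile

fieldset = build_fieldset()  # hypothetical helper; any valid v4 FieldSet works here

pset = ParticleSet(fieldset, lon=np.zeros(10), lat=np.linspace(0, 80, 10))

# v3: pfile = pset.ParticleFile(name="out.zarr", outputdt=1800)
# v4: construct the ParticleFile directly; outputdt must be a np.timedelta64
# or datetime.timedelta, not a float number of seconds
pfile = ParticleFile("out.zarr", outputdt=np.timedelta64(30, "m"))

# runtime and dt are likewise numpy timedelta64 values in v4
pset.execute(
    AdvectionRK4,
    runtime=np.timedelta64(2, "h"),
    dt=np.timedelta64(15, "m"),
    output_file=pfile,
)
```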