diff --git a/docs/examples/example_globcurrent.py b/docs/examples/example_globcurrent.py
index c0d6b9103..543973568 100755
--- a/docs/examples/example_globcurrent.py
+++ b/docs/examples/example_globcurrent.py
@@ -196,7 +196,7 @@ def test_globcurrent_particle_independence(rundays=5):
     time0 = fieldset.U.grid.time[0]
 
     def DeleteP0(particle, fieldset, time):  # pragma: no cover
-        if particle.id == 0:
+        if particle.trajectory == 0:
             particle.delete()
 
     pset0 = parcels.ParticleSet(
@@ -244,7 +244,7 @@ def test_globcurrent_pset_fromfile(dt, pid_offset, tmpdir):
     pset.execute(parcels.AdvectionRK4, runtime=timedelta(days=1), dt=dt)
     pset_new.execute(parcels.AdvectionRK4, runtime=timedelta(days=1), dt=dt)
 
-    for var in ["lon", "lat", "depth", "time", "id"]:
+    for var in ["lon", "lat", "depth", "time", "trajectory"]:
         assert np.allclose(
             [getattr(p, var) for p in pset], [getattr(p, var) for p in pset_new]
         )
diff --git a/parcels/_compat.py b/parcels/_compat.py
index 6efab15a7..0abe9753d 100644
--- a/parcels/_compat.py
+++ b/parcels/_compat.py
@@ -17,3 +17,24 @@
         from sklearn.cluster import KMeans  # type: ignore[no-redef]
 except ModuleNotFoundError:
     pass
+
+
+# for compat with v3 of parcels when users provide `initial=attrgetter("lon")` to a Variable
+# so that particle initial state matches another variable
+class _AttrgetterHelper:
+    """
+    Example usage
+
+    >>> _attrgetter_helper = _AttrgetterHelper()
+    >>> _attrgetter_helper.some_attribute
+    'some_attribute'
+    >>> from operator import attrgetter
+    >>> attrgetter('some_attribute')(_attrgetter_helper)
+    'some_attribute'
+    """
+
+    def __getattr__(self, name):
+        return name
+
+
+_attrgetter_helper = _AttrgetterHelper()
diff --git a/parcels/application_kernels/interaction.py b/parcels/application_kernels/interaction.py
index db3c4e04e..41f4b18dd 100644
--- a/parcels/application_kernels/interaction.py
+++ b/parcels/application_kernels/interaction.py
@@ -25,12 +25,12 @@ def NearestNeighborWithinRange(particle, fieldset, time, neighbors, mutator):
         # undesirable results.
         if dist < min_dist or min_dist < 0:
             min_dist = dist
-            neighbor_id = n.id
+            neighbor_id = n.trajectory
 
     def f(p, neighbor):
         p.nearest_neighbor = neighbor
 
-    mutator[particle.id].append((f, [neighbor_id]))
+    mutator[particle.trajectory].append((f, [neighbor_id]))
 
     return StatusCode.Success
@@ -54,14 +54,14 @@ def merge_with_neighbor(p, nlat, nlon, ndepth, nmass):
         p.mass = p.mass + nmass
 
     for n in neighbors:
-        if n.id == particle.nearest_neighbor:
-            if n.nearest_neighbor == particle.id and particle.id < n.id:
+        if n.trajectory == particle.nearest_neighbor:
+            if n.nearest_neighbor == particle.trajectory and particle.trajectory < n.trajectory:
                 # Merge particles:
                 # Delete neighbor
-                mutator[n.id].append((delete_particle, ()))
+                mutator[n.trajectory].append((delete_particle, ()))
                 # Take position at the mid point and sum of masses
                 args = np.array([n.lat, n.lon, n.depth, n.mass])
-                mutator[particle.id].append((merge_with_neighbor, args))
+                mutator[particle.trajectory].append((merge_with_neighbor, args))
                 return StatusCode.Success
             else:
@@ -101,6 +101,6 @@ def f(n, dlat, dlon, ddepth):
             n.lon_nextloop += dlon
             n.depth_nextloop += ddepth
 
-        mutator[n.id].append((f, d_vec))
+        mutator[n.trajectory].append((f, d_vec))
 
     return StatusCode.Success
diff --git a/parcels/field.py b/parcels/field.py
index 46b47d69a..1228b0a1f 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -16,7 +16,7 @@
     VectorType,
     assert_valid_mesh,
 )
-from parcels.particle import Particle
+from parcels.particle import KernelParticle
 from parcels.tools.converters import (
     UnitConverter,
     unitconverters_map,
@@ -37,9 +37,9 @@
 
 
 def _deal_with_errors(error, key, vector_type: VectorType):
-    if isinstance(key, Particle):
+    if isinstance(key, KernelParticle):
         key.state = AllParcelsErrorCodes[type(error)]
-    elif isinstance(key[-1], Particle):
+    elif isinstance(key[-1], KernelParticle):
         key[-1].state = AllParcelsErrorCodes[type(error)]
     else:
         raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.")
@@ -278,7 +278,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
 
     def __getitem__(self, key):
         self._check_velocitysampling()
         try:
-            if isinstance(key, Particle):
+            if isinstance(key, KernelParticle):
                 return self.eval(key.time, key.depth, key.lat, key.lon, key)
             else:
                 return self.eval(*key)
@@ -373,7 +373,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
 
     def __getitem__(self, key):
         try:
-            if isinstance(key, Particle):
+            if isinstance(key, KernelParticle):
                 return self.eval(key.time, key.depth, key.lat, key.lon, key)
             else:
                 return self.eval(*key)
diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py
index a4646a33f..039297f5d 100644
--- a/parcels/interaction/interactionkernel.py
+++ b/parcels/interaction/interactionkernel.py
@@ -157,7 +157,7 @@ def execute_python(self, pset, endtime, dt):
             for particle_idx in active_idx:
                 p = pset[particle_idx]
                 try:
-                    for mutator_func, args in mutator[p.id]:
+                    for mutator_func, args in mutator[p.trajectory]:
                         mutator_func(p, *args)
                 except KeyError:
                     pass
@@ -201,7 +201,7 @@ def execute(self, pset, endtime, dt, output_file=None):
                     pass
                 else:
                     warnings.warn(
-                        f"Deleting particle {p.id} because of non-recoverable error",
+                        f"Deleting particle {p.trajectory} because of non-recoverable error",
                         RuntimeWarning,
                         stacklevel=2,
                     )
diff --git a/parcels/kernel.py b/parcels/kernel.py
index 73a91adb8..7bd4a59d3 100644
--- a/parcels/kernel.py
+++ b/parcels/kernel.py
@@ -259,7 +259,7 @@ def execute(self, pset, endtime, dt):
                 pass
             else:
                 warnings.warn(
-                    f"Deleting particle {p.id} because of non-recoverable error",
+                    f"Deleting particle {p.trajectory} because of non-recoverable error",
                     RuntimeWarning,
                     stacklevel=2,
                 )
diff --git a/parcels/particle.py b/parcels/particle.py
index 218614d0e..546f53cf6 100644
--- a/parcels/particle.py
+++ b/parcels/particle.py
@@ -1,9 +1,21 @@
+from __future__ import annotations
+
+import enum
+import operator
+from keyword import iskeyword
 from typing import Literal
 
 import numpy as np
-import xarray as xr
 
-__all__ = ["InteractionParticle", "Particle", "Variable"]
+from parcels._compat import _attrgetter_helper
+from parcels._core.utils.time import TimeInterval
+from parcels._reprs import _format_list_items_multiline
+from parcels.tools.statuscodes import StatusCode
+
+__all__ = ["KernelParticle", "Particle", "ParticleClass", "Variable"]
+
+_TO_WRITE_OPTIONS = [True, False, "once"]
+
+_SAME_AS_FIELDSET_TIME_INTERVAL = enum.Enum("_SAME_AS_FIELDSET_TIME_INTERVAL", "VALUE")
 
 
 class Variable:
@@ -21,94 +33,99 @@ class Variable:
     to_write : bool, 'once', optional
         Boolean or 'once'. Controls whether Variable is written to NetCDF file.
         If to_write = 'once', the variable will be written as a time-independent 1D array
+    attrs : dict, optional
+        Attributes to be stored with the variable when written to file. This can include metadata such as units, long_name, etc.
     """
 
-    def __init__(self, name, dtype=np.float32, initial=0, to_write: bool | Literal["once"] = True):
+    def __init__(
+        self,
+        name,
+        dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL = np.float32,
+        initial=0,
+        to_write: bool | Literal["once"] = True,
+        attrs: dict | None = None,
+    ):
+        if not isinstance(name, str):
+            raise TypeError(f"Variable name must be a string. Got {name=!r}")
+        _assert_valid_python_varname(name)
+
+        try:
+            dtype = np.dtype(dtype)
+        except (TypeError, ValueError):
+            if dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
+                raise TypeError(f"Variable dtype must be a valid numpy dtype. Got {dtype=!r}")
+
+        if to_write not in _TO_WRITE_OPTIONS:
+            raise ValueError(f"to_write must be one of {_TO_WRITE_OPTIONS!r}. Got {to_write=!r}")
+
+        if attrs is None:
+            attrs = {}
+
+        if not to_write:
+            if attrs != {}:
+                raise ValueError(f"Attributes cannot be set if {to_write=!r}.")
+
         self._name = name
         self.dtype = dtype
         self.initial = initial
         self.to_write = to_write
+        self.attrs = attrs
 
     @property
     def name(self):
         return self._name
 
-    def __get__(self, instance, cls):
-        if instance is None:
-            return self
-        return getattr(instance, f"_{self.name}", self.initial)
-
-    def __set__(self, instance, value):
-        setattr(instance, f"_{self.name}", value)
-
     def __repr__(self):
-        return f"Variable(name={self._name}, dtype={self.dtype}, initial={self.initial}, to_write={self.to_write})"
+        return f"Variable(name={self._name!r}, dtype={self.dtype!r}, initial={self.initial!r}, to_write={self.to_write!r}, attrs={self.attrs!r})"
 
 
-class ParticleType:
-    """Class encapsulating the type information for custom particles.
+class ParticleClass:
+    """Define a class of particles. This is used to generate the particle data which is then used in the simulation.
 
     Parameters
     ----------
-    user_vars :
-        Optional list of (name, dtype) tuples for custom variables
+    variables : list[Variable]
+        List of Variable objects that define the particle's attributes.
+
     """
 
-    def __init__(self, pclass):
-        if not isinstance(pclass, type):
-            raise TypeError("Class object required to derive ParticleType")
-        if not issubclass(pclass, Particle):
-            raise TypeError("Class object does not inherit from parcels.Particle")
-        self.name = pclass.__name__
-        # Pick Variable objects out of __dict__.
-        self.variables = [v for v in pclass.__dict__.values() if isinstance(v, Variable)]
-        for cls in pclass.__bases__:
-            if issubclass(cls, Particle):
-                # Add inherited particle variables
-                ptype = cls.getPType()
-                for v in self.variables:
-                    if v.name in [v.name for v in ptype.variables]:
-                        raise AttributeError(
-                            f"Custom Variable name '{v.name}' is not allowed, as it is also a built-in variable"
-                        )
-                    if v.name == "z":
-                        raise AttributeError(
-                            "Custom Variable name 'z' is not allowed, as it is used for depth in ParticleFile"
-                        )
-                self.variables = ptype.variables + self.variables
+    def __init__(self, variables: list[Variable]):
+        if not isinstance(variables, list):
+            raise TypeError(f"Expected list of Variable objects, got {type(variables)}")
+        if not all(isinstance(var, Variable) for var in variables):
+            raise ValueError(f"All items in variables must be instances of Variable. Got {variables=!r}")
+
+        self.variables = variables
 
     def __repr__(self):
-        return f"{type(self).__name__}(pclass={self.name})"
+        vars = [repr(v) for v in self.variables]
+        return f"ParticleClass(variables={_format_list_items_multiline(vars)})"
 
-    def __getitem__(self, item):
-        for v in self.variables:
-            if v.name == item:
-                return v
+    def add_variable(self, variable: Variable | list[Variable]):
+        """Add a new variable to the Particle class. This returns a new Particle class with the added variable(s).
 
+        Parameters
+        ----------
+        variable : Variable or list[Variable]
+            Variable or list of Variables to be added to the Particle class.
+            If a list is provided, all variables will be added to the class.
+ """ + if isinstance(variable, Variable): + variable = [variable] -class Particle: - """Class encapsulating the basic attributes of a particle, to be executed in SciPy mode. + for var in variable: + if not isinstance(var, Variable): + raise TypeError(f"Expected Variable, got {type(var)}") + + _assert_no_duplicate_variable_names(existing_vars=self.variables, new_vars=variable) + + return ParticleClass(variables=self.variables + variable) - Parameters - ---------- - lon : float - Initial longitude of particle - lat : float - Initial latitude of particle - depth : float - Initial depth of particle - fieldset : parcels.fieldset.FieldSet - mod:`parcels.fieldset.FieldSet` object to track this particle on - time : float - Current time of the particle - - - Notes - ----- - Additional Variables can be added via the :Class Variable: objects - """ - def __init__(self, data: xr.Dataset, index: int): +class KernelParticle: + """Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset.""" + + def __init__(self, data, index): self._data = data self._index = index @@ -121,41 +138,109 @@ def __setattr__(self, name, value): else: self._data[name][self._index] = value - @classmethod - def add_variable(cls, variable: Variable | list[Variable]): - """Add a new variable to the Particle class - Parameters - ---------- - variable : Variable or list[Variable] - Variable or list of Variables to be added to the Particle class. - If a list is provided, all variables will be added to the class. - """ - - class NewParticle(cls): - pass +def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]): + existing_names = {var.name for var in existing_vars} + for var in new_vars: + if var.name in existing_names: + raise ValueError(f"Variable name already exists: {var.name}") + + +def _assert_valid_python_varname(name): + if name.isidentifier() and not iskeyword(name): + return + raise ValueError(f"Particle variable has to be a valid Python variable name. Got {name=!r}") + + +def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass: + if spatial_dtype not in [np.float32, np.float64]: + raise ValueError(f"spatial_dtype must be np.float32 or np.float64. 
Got {spatial_dtype=!r}") + + return ParticleClass( + variables=[ + Variable("lon", dtype=spatial_dtype), + Variable("lon_nextloop", dtype=spatial_dtype, to_write=False), + Variable("lat", dtype=spatial_dtype), + Variable("lat_nextloop", dtype=spatial_dtype, to_write=False), + Variable("depth", dtype=spatial_dtype), + Variable("dlon", dtype=spatial_dtype, to_write=False), + Variable("dlat", dtype=spatial_dtype, to_write=False), + Variable("ddepth", dtype=spatial_dtype, to_write=False), + Variable("depth_nextloop", dtype=spatial_dtype, to_write=False), + Variable("time", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE), + Variable("time_nextloop", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, to_write=False), + Variable("trajectory", dtype=np.int64, to_write="once"), + Variable("obs_written", dtype=np.int32, initial=0, to_write=False), + Variable("dt", dtype="timedelta64[s]", initial=np.timedelta64(1, "s"), to_write=False), + Variable("state", dtype=np.int32, initial=StatusCode.Evaluate, to_write=False), + ] + ) + + +Particle = get_default_particle(np.float32) + + +def create_particle_data( + *, + pclass: ParticleClass, + nparticles: int, + ngrids: int, + time_interval: TimeInterval, + initial: dict[str, np.array] | None = None, +): + if initial is None: + initial = {} + + variables = {var.name: var for var in pclass.variables} + + assert "ei" not in initial, "'ei' is for internal use, and is unique since is only non 1D array" + + time_interval_dtype = _get_time_interval_dtype(time_interval) + + dtypes = {} + for var in variables.values(): + if var.dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + dtypes[var.name] = time_interval_dtype + else: + dtypes[var.name] = var.dtype - if isinstance(variable, Variable): - setattr(NewParticle, variable.name, variable) - elif isinstance(variable, list): - for var in variable: - if not isinstance(var, Variable): - raise TypeError(f"Expected Variable, got {type(var)}") - setattr(NewParticle, var.name, var) - return NewParticle + for var_name in initial: + if var_name not in variables: + raise ValueError(f"Variable {var_name} is not defined in the ParticleClass.") - @classmethod - def getPType(cls): - return ParticleType(cls) + values = initial[var_name] + if values.shape != (nparticles,): + raise ValueError(f"Initial value for {var_name} must have shape ({nparticles},). Got {values.shape=}") + initial[var_name] = values.astype(dtypes[var_name]) -InteractionParticle = Particle.add_variable( - [Variable("vert_dist", dtype=np.float32), Variable("horiz_dist", dtype=np.float32)] -) + data = {"ei": np.zeros((nparticles, ngrids), dtype=np.int32), **initial} + vars_to_create = {k: v for k, v in variables.items() if k not in data} -class JITParticle(Particle): - def __init__(self, *args, **kwargs): - raise NotImplementedError( - "JITParticle has been deprecated in Parcels v4. Use Particle instead." - ) # TODO v4: link to migration guide + for var in vars_to_create.values(): + if isinstance(var.initial, operator.attrgetter): + name_to_copy = var.initial(_attrgetter_helper) + data[var.name] = data[name_to_copy].copy() + else: + data[var.name] = _create_array_for_variable(var, nparticles, time_interval) + return data + + +def _create_array_for_variable(variable: Variable, nparticles: int, time_interval: TimeInterval): + assert not isinstance(variable.initial, operator.attrgetter), ( + "This function cannot handle attrgetter initial values." 
+    )
+    if (dtype := variable.dtype) is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
+        dtype = type(time_interval.left)
+    return np.full(
+        shape=(nparticles,),
+        fill_value=variable.initial,
+        dtype=dtype,
+    )
+
+
+def _get_time_interval_dtype(time_interval: TimeInterval | None) -> np.dtype:
+    if time_interval is None:
+        return np.dtype("timedelta64[ns]")
+    return type(time_interval.left)
diff --git a/parcels/particlefile.py b/parcels/particlefile.py
index e27a56243..33451caae 100644
--- a/parcels/particlefile.py
+++ b/parcels/particlefile.py
@@ -1,27 +1,40 @@
 """Module controlling the writing of ParticleSets to Zarr file."""
 
+from __future__ import annotations
+
 import os
 import warnings
 from datetime import timedelta
+from typing import TYPE_CHECKING
 
 import numpy as np
 import xarray as xr
 import zarr
+from zarr.storage import DirectoryStore
 
 import parcels
-from parcels._compat import MPI
 from parcels._reprs import default_repr
 from parcels.tools._helpers import timedelta_to_float
-from parcels.tools.warnings import FileWarning
 
-__all__ = ["ParticleFile"]
+if TYPE_CHECKING:
+    from pathlib import Path
 
+__all__ = ["ParticleFile"]
 
-def _set_calendar(origin_calendar):
-    if origin_calendar == "np_datetime64":
-        return "standard"
-    else:
-        return origin_calendar
+_DATATYPES_TO_FILL_VALUES = {
+    np.float16: np.nan,
+    np.float32: np.nan,
+    np.float64: np.nan,
+    np.bool_: np.iinfo(np.int8).max,
+    np.int8: np.iinfo(np.int8).max,
+    np.int16: np.iinfo(np.int16).max,
+    np.int32: np.iinfo(np.int32).max,
+    np.int64: np.iinfo(np.int64).max,
+    np.uint8: np.iinfo(np.uint8).max,
+    np.uint16: np.iinfo(np.uint16).max,
+    np.uint32: np.iinfo(np.uint32).max,
+    np.uint64: np.iinfo(np.uint64).max,
+}
 
 
 class ParticleFile:
@@ -48,7 +61,7 @@ class ParticleFile:
         ParticleFile object that can be used to write particle data to file
     """
 
-    def __init__(self, name, particleset, outputdt, chunks=None, create_new_zarrfile=True):
+    def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfile=True):
         self._outputdt = timedelta_to_float(outputdt)
         self._chunks = chunks
         self._particleset = particleset
@@ -63,7 +76,6 @@ def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfile
         for var in self.particleset.particledata.ptype.variables:
             if var.to_write:
                 self.vars_to_write[var.name] = var.dtype
-        self._mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
         self.particleset.fieldset._particlefile = self
 
         # Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet
@@ -77,44 +89,15 @@
             "parcels_mesh": self._parcels_mesh,
         }
 
-        # Create dictionary to translate datatypes and fill_values
-        self._fill_value_map = {
-            np.float16: np.nan,
-            np.float32: np.nan,
-            np.float64: np.nan,
-            np.bool_: np.iinfo(np.int8).max,
-            np.int8: np.iinfo(np.int8).max,
-            np.int16: np.iinfo(np.int16).max,
-            np.int32: np.iinfo(np.int32).max,
-            np.int64: np.iinfo(np.int64).max,
-            np.uint8: np.iinfo(np.uint8).max,
-            np.uint16: np.iinfo(np.uint16).max,
-            np.uint32: np.iinfo(np.uint32).max,
-            np.uint64: np.iinfo(np.uint64).max,
-        }
-        if issubclass(type(name), zarr.storage.Store):
-            # If we already got a Zarr store, we won't need any of the naming logic below.
-            # But we need to handle incompatibility with MPI mode for now:
-            if MPI and MPI.COMM_WORLD.Get_size() > 1:
-                raise ValueError("Currently, MPI mode is not compatible with directly passing a Zarr store.")
-            fname = name
+        if isinstance(store, zarr.storage.Store):
+            self.store = store
         else:
-            extension = os.path.splitext(str(name))[1]
-            if extension in [".nc", ".nc4"]:
-                raise RuntimeError(
-                    "Output in NetCDF is not supported anymore. Use .zarr extension for ParticleFile name."
-                )
-            if MPI and MPI.COMM_WORLD.Get_size() > 1:
-                fname = os.path.join(name, f"proc{self._mpi_rank:02d}.zarr")
-                if extension in [".zarr"]:
-                    warnings.warn(
-                        f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {fname}",
-                        FileWarning,
-                        stacklevel=2,
-                    )
-            else:
-                fname = name if extension in [".zarr"] else f"{name}.zarr"
-        self._fname = fname
+            self.store = _get_store_from_pathlike(store)
+
+        if self.store.read_only:
+            raise ValueError(f"Store {store} is read-only. Please provide a writable store.")
+
+        # TODO v4: Add check that if create_new_zarrfile is False, the store already exists
 
     def __repr__(self) -> str:
         return (
@@ -162,7 +145,7 @@ def _create_variables_attribute_dict(self):
             "trajectory": {
                 "long_name": "Unique identifier for each particle",
                 "cf_role": "trajectory_id",
-                "_FillValue": self._fill_value_map[np.int64],
+                "_FillValue": _DATATYPES_TO_FILL_VALUES[np.int64],
             },
             "time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"},
             "lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"},
@@ -173,9 +156,9 @@ def _create_variables_attribute_dict(self):
         attrs["time"]["calendar"] = "None"  # TODO fix calendar
 
         for vname in self.vars_to_write:
-            if vname not in ["time", "lat", "lon", "depth", "id"]:
+            if vname not in ["time", "lat", "lon", "depth", "trajectory"]:
                 attrs[vname] = {
-                    "_FillValue": self._fill_value_map[self.vars_to_write[vname]],
+                    "_FillValue": _DATATYPES_TO_FILL_VALUES[self.vars_to_write[vname]],
                     "long_name": "",
                     "standard_name": vname,
                     "units": "unknown",
@@ -186,8 +169,6 @@
     def _convert_varout_name(self, var):
         if var == "depth":
             return "z"
-        elif var == "id":
-            return "trajectory"
         else:
             return var
@@ -196,16 +177,16 @@ def _write_once(self, var):
 
     def _extend_zarr_dims(self, Z, store, dtype, axis):
         if axis == 1:
-            a = np.full((Z.shape[0], self.chunks[1]), self._fill_value_map[dtype], dtype=dtype)
+            a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
             obs = zarr.group(store=store, overwrite=False)["obs"]
             if len(obs) == Z.shape[1]:
                 obs.append(np.arange(self.chunks[1]) + obs[-1] + 1)
         else:
            extra_trajs = self._maxids - Z.shape[0]
            if len(Z.shape) == 2:
-                a = np.full((extra_trajs, Z.shape[1]), self._fill_value_map[dtype], dtype=dtype)
+                a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
            else:
-                a = np.full((extra_trajs,), self._fill_value_map[dtype], dtype=dtype)
+                a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
         Z.append(a, axis=axis)
         zarr.consolidate_metadata(store)
@@ -238,7 +219,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
         if len(indices_to_write) == 0:
             return
 
-        pids = pset.particledata.getvardata("id", indices_to_write)
+        pids = pset.particledata.getvardata("trajectory", indices_to_write)
         to_add = sorted(set(pids) - set(self._pids_written.keys()))
         for i, pid in enumerate(to_add):
             self._pids_written[pid] = self._maxids + i
@@ -250,6 +231,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
         ids_once = ids[once_ids]
         indices_to_write_once = indices_to_write[once_ids]
 
+        store = self.store
         if self.create_new_zarrfile:
             if self.chunks is None:
                 self._chunks = (len(pset), 1)
@@ -269,27 +251,22 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
                 if self._write_once(var):
                     data = np.full(
                         (arrsize[0],),
-                        self._fill_value_map[self.vars_to_write[var]],
+                        _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]],
                         dtype=self.vars_to_write[var],
                     )
                     data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
                     dims = ["trajectory"]
                 else:
                     data = np.full(
-                        arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
+                        arrsize, _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]], dtype=self.vars_to_write[var]
                     )
                     data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
                     dims = ["trajectory", "obs"]
                 ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
                 ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
-            ds.to_zarr(self.fname, mode="w")
+            ds.to_zarr(store, mode="w")
             self._create_new_zarrfile = False
         else:
-            # Either use the store that was provided directly or create a DirectoryStore:
-            if isinstance(self.fname, zarr.storage.Store):
-                store = self.fname
-            else:
-                store = zarr.DirectoryStore(self.fname)
             Z = zarr.group(store=store, overwrite=False)
             obs = pset.particledata.getvardata("obs_written", indices_to_write)
             for var in self.vars_to_write:
@@ -322,3 +299,12 @@ def write_latest_locations(self, pset, time):
             pset.particledata.setallvardata(f"{var}", pset.particledata.getvardata(f"{var}_nextloop"))
 
         self.write(pset, time)
+
+
+def _get_store_from_pathlike(path: Path | str) -> DirectoryStore:
+    path = os.fspath(path)  # Ensure a plain string path (pathlib.Path is only imported for type checking)
+    extension = os.path.splitext(path)[1]
+    if extension != ".zarr":
+        raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.")
+
+    return DirectoryStore(path)
diff --git a/parcels/particleset.py b/parcels/particleset.py
index c9b16cfde..2647800e1 100644
--- a/parcels/particleset.py
+++ b/parcels/particleset.py
@@ -1,7 +1,6 @@
 import sys
 import warnings
 from collections.abc import Iterable
-from operator import attrgetter
 
 import numpy as np
 import xarray as xr
@@ -13,7 +12,7 @@
 from parcels.application_kernels.advection import AdvectionRK4
 from parcels.basegrid import GridType
 from parcels.kernel import Kernel
-from parcels.particle import Particle, Variable
+from parcels.particle import KernelParticle, Particle, create_particle_data
 from parcels.particlefile import ParticleFile
 from parcels.tools.converters import convert_to_flat_array
 from parcels.tools.loggers import logger
@@ -70,7 +69,6 @@ def __init__(
         lat=None,
         depth=None,
         time=None,
-        lonlatdepth_dtype=None,
         trajectory_ids=None,
         **kwargs,
     ):
@@ -112,13 +110,6 @@ def __init__(
         if fieldset.time_interval:
             _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval)
 
-        if lonlatdepth_dtype is None:
-            lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U)
-        assert lonlatdepth_dtype in [
-            np.float32,
-            np.float64,
-        ], "lon lat depth precision should be set to either np.float32 or np.float64"
-
         for kwvar in kwargs:
             if kwvar not in ["partition_function"]:
                 kwargs[kwvar] = convert_to_flat_array(kwargs[kwvar])
                 assert lon.size == kwargs[kwvar].size, (
                     f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
                 )
 
-        self._data = {
-            "lon": lon.astype(lonlatdepth_dtype),
-            "lat": lat.astype(lonlatdepth_dtype),
-            "depth": depth.astype(lonlatdepth_dtype),
-            "dlon": np.zeros(lon.size, dtype=lonlatdepth_dtype),
-            "dlat": np.zeros(lon.size, dtype=lonlatdepth_dtype),
-            "ddepth": np.zeros(lon.size, dtype=lonlatdepth_dtype),
-            "time": time,
-            "dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
-            # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
-            "state": np.zeros((len(trajectory_ids)), dtype=np.int32),
-            "lon_nextloop": lon.astype(lonlatdepth_dtype),
-            "lat_nextloop": lat.astype(lonlatdepth_dtype),
-            "depth_nextloop": depth.astype(lonlatdepth_dtype),
-            "time_nextloop": time,
-            "trajectory": trajectory_ids,
-        }
-        self._ptype = pclass.getPType()
-        # add extra fields from the custom Particle class
-        for v in pclass.__dict__.values():
-            if isinstance(v, Variable):
-                if isinstance(v.initial, attrgetter):
-                    initial = v.initial(self)
-                else:
-                    initial = [np.array(v.initial, dtype=v.dtype)] * len(trajectory_ids)
-                self._data[v.name] = initial
+        self._data = create_particle_data(
+            pclass=pclass,
+            nparticles=lon.size,
+            ngrids=len(fieldset.gridset),
+            time_interval=fieldset.time_interval,
+            initial=dict(
+                lon=lon,
+                lat=lat,
+                depth=depth,
+                time=time,
+                lon_nextloop=lon,
+                lat_nextloop=lat,
+                depth_nextloop=depth,
+                time_nextloop=time,
+                trajectory=trajectory_ids,
+            ),
+        )
+        self._ptype = pclass
 
-        # update initial values provided on ParticleSet creation
+        # update initial values provided on ParticleSet creation  # TODO: Wrap this into create_particle_data
+        particle_variables = [v.name for v in pclass.variables]
         for kwvar, kwval in kwargs.items():
-            if not hasattr(pclass, kwvar):
+            if kwvar not in particle_variables:
                 raise RuntimeError(f"Particle class does not have Variable {kwvar}")
             self._data[kwvar][:] = kwval
@@ -190,7 +174,7 @@ def __getattr__(self, name):
 
     def __getitem__(self, index):
         """Get a single particle by index."""
-        return Particle(self._data, index=index)
+        return KernelParticle(self._data, index=index)
 
     @staticmethod
     def lonlatdepth_dtype_from_field_interp_method(field):
@@ -310,7 +294,7 @@ def _neighbors_by_index(self, particle_idx):
 
     def _neighbors_by_coor(self, coor):
         neighbor_idx = self._neighbor_tree.find_neighbors_by_coor(coor)
-        neighbor_ids = self._data["id"][neighbor_idx]
+        neighbor_ids = self._data["trajectory"][neighbor_idx]
         return neighbor_ids
 
     # TODO: This method is only tested in tutorial notebook. Add unit test?
diff --git a/tests/test_interaction.py b/tests/test_interaction.py
index 7c0e4fc08..c20fdda88 100644
--- a/tests/test_interaction.py
+++ b/tests/test_interaction.py
@@ -32,7 +32,7 @@ def DummyMoveNeighbor(particle, fieldset, time, neighbors, mutator):
     def f(p):
         p.lat_nextloop += 0.1
 
-    neighbor_id = neighbors[i_min_dist].id
+    neighbor_id = neighbors[i_min_dist].trajectory
     mutator[neighbor_id].append((f, ()))
 
     pass
@@ -172,7 +172,7 @@ def ConstantMoveInteraction(particle, fieldset, time, neighbors, mutator):
     def f(p):
         p.lat_nextloop += p.dt
 
-    mutator[particle.id].append((f, ()))
+    mutator[particle.trajectory].append((f, ()))
 
 
 @pytest.mark.parametrize("runtime, dt", [(1, 1e-2), (1, -2.123e-3), (1, -3.12452 - 3)])
diff --git a/tests/test_kernel_language.py b/tests/test_kernel_language.py
index 70b28c5e9..76b3ad1ba 100644
--- a/tests/test_kernel_language.py
+++ b/tests/test_kernel_language.py
@@ -230,14 +230,16 @@ def test_print(fieldset_unit_mesh, capfd):
     def kernel(particle, fieldset, time):  # pragma: no cover
         particle.p = 1e-3
         tmp = 5
-        print(f"{particle.id} {particle.p:f} {tmp:f}")
+        print(f"{particle.trajectory} {particle.p:f} {tmp:f}")
 
     pset.execute(kernel, endtime=1.0, dt=1.0, verbose_progress=False)
     out, err = capfd.readouterr()
     lst = out.split(" ")
     tol = 1e-8
     assert (
-        abs(float(lst[0]) - pset.id[0]) < tol and abs(float(lst[1]) - pset.p[0]) < tol and abs(float(lst[2]) - 5) < tol
+        abs(float(lst[0]) - pset.trajectory[0]) < tol
+        and abs(float(lst[1]) - pset.p[0]) < tol
+        and abs(float(lst[2]) - 5) < tol
     )
 
     def kernel2(particle, fieldset, time):  # pragma: no cover
diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py
index 6c0fc74b8..39cee4ae1 100755
--- a/tests/test_particlefile.py
+++ b/tests/test_particlefile.py
@@ -108,6 +108,7 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile):
     ds.close()
 
 
+@pytest.mark.xfail(reason="lonlatdepth_dtype removed. Update implementation to use a different particle")
 def test_variable_write_double(fieldset, tmp_zarrfile):
     def Update_lon(particle, fieldset, time):  # pragma: no cover
         particle.dlon += 0.1
diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py
index fc186519d..ed884f595 100644
--- a/tests/test_particlesets.py
+++ b/tests/test_particlesets.py
@@ -65,7 +65,7 @@ def Kernel(particle, fieldset, time):  # pragma: no cover
     assert np.allclose([getattr(p, var) for p in pset], [getattr(p, var) for p in pset_new])
 
     if restart:
-        assert np.allclose([p.id for p in pset], [p.id for p in pset_new])
+        assert np.allclose([p.trajectory for p in pset], [p.trajectory for p in pset_new])
     pset_new.execute(Kernel, runtime=2, dt=1)
     assert len(pset_new) == 3 * len(pset)
     assert pset[0].p3.dtype == np.float64
diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py
index b53946c3a..524450147 100644
--- a/tests/v4/test_advection.py
+++ b/tests/v4/test_advection.py
@@ -92,7 +92,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
     pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"))
 
     expected_lon = pset.depth * (pset.time - fieldset.time_interval.left) / np.timedelta64(1, "s")
-    np.testing.assert_allclose(expected_lon, pset.lon_nextloop, atol=1.0e-1)
+    np.testing.assert_allclose(pset.lon, expected_lon, atol=1.0e-1)
 
 
 @pytest.mark.parametrize("direction", ["up", "down"])
@@ -255,8 +255,8 @@ def test_moving_eddy(method, rtol):
 
     if method == "RK45":
         # Use RK45Particles to set next_dt
-        RK45Particles = Particle.add_variable(Variable("next_dt", initial=dt, dtype=np.timedelta64))
-        fieldset.add_constant("RK45_tol", 1e-6)
+        RK45Particles = Particle.add_variable(Variable("next_dt", initial=dt, dtype="timedelta64[s]"))
+        fieldset.add_constant("RK45_tol", 1e-3)
 
     pclass = RK45Particles if method == "RK45" else Particle
     pset = ParticleSet(
@@ -282,7 +282,7 @@ def truth_moving(x_0, y_0, t):
     [
         ("EE", 1e-2),
         ("RK4", 1e-5),
-        ("RK45", 1e-5),
+        ("RK45", 1e-3),
     ],
 )
 def test_decaying_moving_eddy(method, rtol):
@@ -298,8 +298,8 @@ def test_decaying_moving_eddy(method, rtol):
 
     if method == "RK45":
         # Use RK45Particles to set next_dt
-        RK45Particles = Particle.add_variable(Variable("next_dt", initial=dt, dtype=np.timedelta64))
-        fieldset.add_constant("RK45_tol", 1e-6)
+        RK45Particles = Particle.add_variable(Variable("next_dt", initial=dt, dtype="timedelta64[s]"))
+        fieldset.add_constant("RK45_tol", 1e-3)
 
     pclass = RK45Particles if method == "RK45" else Particle
 
@@ -362,10 +362,10 @@ def test_gyre_flowfields(method, grid_type, atol, flowfield):
             [
                 Variable("p", initial=0.0, dtype=np.float32),
                 Variable("p_start", initial=0.0, dtype=np.float32),
-                Variable("next_dt", initial=dt, dtype=np.timedelta64),
+                Variable("next_dt", initial=dt, dtype="timedelta64[s]"),
             ]
         )
-        fieldset.add_constant("RK45_tol", 1e-6)
+        fieldset.add_constant("RK45_tol", 1e-3)
     else:
         SampleParticle = Particle.add_variable(
             [Variable("p", initial=0.0, dtype=np.float32), Variable("p_start", initial=0.0, dtype=np.float32)]
         )
diff --git a/tests/v4/test_particle.py b/tests/v4/test_particle.py
new file mode 100644
index 000000000..3ca0a6c13
--- /dev/null
+++ b/tests/v4/test_particle.py
@@ -0,0 +1,156 @@
+import numpy as np
+import pytest
+
+from parcels._core.utils.time import TimeInterval
+from parcels._datasets.structured.generic import TIME
+from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, Particle, ParticleClass, Variable, create_particle_data
+
+
+def test_variable_init():
+    var = Variable("test")
+    assert var.name == "test"
+    assert var.dtype == np.float32
+    assert var.to_write
+    assert var.attrs == {}
+
+
+def test_variable_invalid_init():
+    with pytest.raises(ValueError, match=r"to_write must be one of .*\. Got to_write="):
+        Variable("name", to_write="test")
+
+    for name in ["a b", "123", "while"]:
+        with pytest.raises(ValueError, match="Particle variable has to be a valid Python variable name. Got "):
+            Variable(name)
+
+    with pytest.raises(ValueError, match="Attributes cannot be set if to_write=False"):
+        Variable("name", to_write=False, attrs={"description": "metadata to write"})
+
+
+@pytest.mark.parametrize(
+    "variable, expected",
+    [
+        (
+            Variable("test", np.float32, 0.0, True, {"some": "metadata"}),
+            "Variable(name='test', dtype=dtype('float32'), initial=0.0, to_write=True, attrs={'some': 'metadata'})",
+        ),
+        (
+            Variable("test", np.float32, 0.0, True),
+            "Variable(name='test', dtype=dtype('float32'), initial=0.0, to_write=True, attrs={})",
+        ),
+    ],
+)
+def test_variable_repr(variable, expected):
+    assert repr(variable) == expected
+
+
+def test_particleclass_init():
+    ParticleClass(
+        variables=[
+            Variable("vara", dtype=np.float32),
+            Variable("varb", dtype=np.float32, to_write=False),
+            Variable("varc", dtype=np.float32),
+        ]
+    )
+
+
+def test_particleclass_invalid_vars():
+    with pytest.raises(ValueError, match="All items in variables must be instances of Variable. Got"):
+        ParticleClass(variables=[Variable("vara", dtype=np.float32), "not a variable class"])
+
+    with pytest.raises(TypeError, match="Expected list of Variable objects, got "):
+        ParticleClass(variables="not a list")
+
+
+@pytest.mark.parametrize(
+    "obj, expected",
+    [
+        (
+            ParticleClass(
+                variables=[
+                    Variable("vara", dtype=np.float32, to_write=True),
+                    Variable("varb", dtype=np.float32, to_write=False),
+                    Variable("varc", dtype=np.float32, to_write=True),
+                ]
+            ),
+            """ParticleClass(variables=[
+    Variable(name='vara', dtype=dtype('float32'), initial=0, to_write=True, attrs={}),
+    Variable(name='varb', dtype=dtype('float32'), initial=0, to_write=False, attrs={}),
+    Variable(name='varc', dtype=dtype('float32'), initial=0, to_write=True, attrs={})
+])""",
+        ),
+    ],
+)
+def test_particleclass_repr(obj, expected):
+    assert repr(obj) == expected
+
+
+def test_particleclass_add_variable():
+    p_initial = ParticleClass(variables=[Variable("vara", dtype=np.float32)])
+    variables = [
+        Variable("varb", dtype=np.float32, to_write=True),
+        Variable("varc", dtype=np.float32, to_write=False),
+    ]
+    p_final = p_initial.add_variable(variables)
+
+    assert len(p_final.variables) == 3
+    assert p_final.variables[0].name == "vara"
+    assert p_final.variables[1].name == "varb"
+    assert p_final.variables[2].name == "varc"
+
+
+def test_particleclass_add_variable_in_loop():
+    p = ParticleClass(variables=[Variable("vara", dtype=np.float32)])
+    vars = [Variable("sample_var"), Variable("sample_var2")]
+    p_loop = p
+    for var in vars:
+        p_loop = p_loop.add_variable(var)
+
+    p_list = p.add_variable(vars)
+
+    for var1, var2 in zip(p_loop.variables, p_list.variables, strict=True):
+        assert var1.name == var2.name
+        assert var1.dtype == var2.dtype
+        assert var1.to_write == var2.to_write
+
+
+def test_particleclass_add_variable_collision():
+    p_initial = ParticleClass(variables=[Variable("vara", dtype=np.float32)])
+
+    with pytest.raises(ValueError, match="Variable name already exists: "):
+        p_initial.add_variable([Variable("vara", dtype=np.float32, to_write=True)])
+
+
+@pytest.mark.parametrize(
+    "particle",
+    [
+        ParticleClass(
+            variables=[
+                Variable("vara", dtype=np.float32, initial=1.0),
+                Variable("varb", dtype=np.float32, initial=2.0),
+            ]
+        ),
+        Particle,
+    ],
+)
+@pytest.mark.parametrize("nparticles", [5, 10])
+def test_create_particle_data(particle, nparticles):
+    time_interval = TimeInterval(TIME[0], TIME[-1])
+    ngrids = 4
+    data = create_particle_data(pclass=particle, nparticles=nparticles, ngrids=ngrids, time_interval=time_interval)
+
+    assert isinstance(data, dict)
+    assert len(data) == len(particle.variables) + 1  # ei variable is separate
+
+    variables = {var.name: var for var in particle.variables}
+
+    for variable_name in variables.keys():
+        variable = variables[variable_name]
+        variable_array = data[variable_name]
+
+        assert variable_array.shape[0] == nparticles
+
+        dtype = variable.dtype
+        if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
+            dtype = type(time_interval.left)
+
+        assert variable_array.dtype == dtype
diff --git a/tests/v4/test_particleset.py b/tests/v4/test_particleset.py
index d21322790..e1aedb13c 100644
--- a/tests/v4/test_particleset.py
+++ b/tests/v4/test_particleset.py
@@ -213,7 +213,7 @@ def test_pset_merge_inplace(fieldset, npart=100):
 def test_pset_remove_index(fieldset, npart=100):
     lon = np.linspace(0, 1, npart)
     lat = np.linspace(1, 0, npart)
-    pset = ParticleSet(fieldset, lon=lon, lat=lat, lonlatdepth_dtype=np.float64)
+    pset = ParticleSet(fieldset, lon=lon, lat=lat)
     indices_to_remove = [0, 10, 20]
     pset.remove_indices(indices_to_remove)
     assert pset.size == 97
diff --git a/v3to4-breaking-changes.md b/v3to4-breaking-changes.md
index 83c68d7f5..46004e0a1 100644
--- a/v3to4-breaking-changes.md
+++ b/v3to4-breaking-changes.md
@@ -11,6 +11,6 @@ FieldSet
 
 ParticleSet
 
-- ParticleSet init had repeatdt removed
+- ParticleSet init had `repeatdt` and `lonlatdepth_dtype` removed
 - ParticleSet.execute() expects `numpy.datetime64`/`numpy.timedelta64` for `runtime`, `endtime` and `dt`
 - `ParticleSet.from_field()`, `ParticleSet.from_line()`, `ParticleSet.from_list()` have been removed
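
Below is a minimal usage sketch (not part of the patch) of how the reworked particle pieces fit together. It assumes the interfaces exactly as introduced above — `Particle.add_variable`, the `_AttrgetterHelper` shim for v3-style `initial=attrgetter(...)`, `create_particle_data`, and `TimeInterval` as used in `tests/v4/test_particle.py`; the interval bounds and array values are illustrative only.

```python
import numpy as np
from operator import attrgetter

from parcels._core.utils.time import TimeInterval
from parcels.particle import Particle, Variable, create_particle_data

# v3 subclassing (class MyParticle(JITParticle): ...) becomes functional composition:
MyParticle = Particle.add_variable(
    [
        Variable("age", dtype=np.float32, initial=0.0),
        # v3-style attrgetter initial: "origin_lon" starts out as a copy of "lon"
        Variable("origin_lon", dtype=np.float32, initial=attrgetter("lon")),
    ]
)

n = 4
time_interval = TimeInterval(np.datetime64("2000-01-01"), np.datetime64("2000-01-02"))  # illustrative bounds
data = create_particle_data(
    pclass=MyParticle,
    nparticles=n,
    ngrids=1,
    time_interval=time_interval,
    initial=dict(
        lon=np.linspace(0.0, 1.0, n),
        lat=np.zeros(n),
        time=np.full(n, time_interval.left),
        time_nextloop=np.full(n, time_interval.left),
        trajectory=np.arange(n, dtype=np.int64),  # v3's particle.id is now particle.trajectory
    ),
)
assert np.allclose(data["origin_lon"], data["lon"])  # copied from "lon" via the attrgetter shim
```

Inside kernels the rename works the same way: `mutator[particle.trajectory]` and `n.trajectory` replace the v3 `particle.id`/`n.id` lookups, as the interaction-kernel hunks above show.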