diff --git a/parcels/_typing.py b/parcels/_typing.py index 92af7ba2fa..9e530c6be7 100644 --- a/parcels/_typing.py +++ b/parcels/_typing.py @@ -37,7 +37,7 @@ class ParcelsAST(ast.AST): Mesh = Literal["spherical", "flat"] # corresponds with `mesh` VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type` ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode` -GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo","fesom2","icon2"] # corresponds with `gridindexingtype` UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status` TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `time_periodic` NetcdfEngine = Literal["netcdf4", "xarray", "scipy"] diff --git a/parcels/basefield.py b/parcels/basefield.py new file mode 100644 index 0000000000..8ac4c41683 --- /dev/null +++ b/parcels/basefield.py @@ -0,0 +1,420 @@ +import collections +import datetime +import math +import warnings +from collections.abc import Iterable +from ctypes import POINTER, Structure, c_float, c_int, pointer +from pathlib import Path +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np +import xarray as xr + +import parcels.tools.interpolation_utils as i_u +from parcels._typing import ( + GridIndexingType, + InterpMethod, + Mesh, + TimePeriodic, + VectorType, + assert_valid_gridindexingtype, + assert_valid_interp_method, +) +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import ( + Geographic, + GeographicPolar, + TimeConverter, + UnitConverter, + unitconverters_map, +) +from parcels.tools.statuscodes import ( + AllParcelsErrorCodes, + FieldOutOfBoundError, + FieldOutOfBoundSurfaceError, + FieldSamplingError, + TimeExtrapolationError, +) +from parcels.tools.warnings import FieldSetWarning, 
_deprecated_param_netcdf_decodewarning + +from .fieldfilebuffer import ( + DaskFileBuffer, + DeferredDaskFileBuffer, + DeferredNetcdfFileBuffer, + NetcdfFileBuffer, +) +from .grid import CGrid, Grid, GridType +from .ugrid import UGrid + +if TYPE_CHECKING: + from ctypes import _Pointer as PointerType + + from parcels.fieldset import FieldSet + +__all__ = ["Field", "VectorField", "NestedField"] + + +def _isParticle(key): + if hasattr(key, "obs_written"): + return True + else: + return False + + +def _deal_with_errors(error, key, vector_type: VectorType): + if _isParticle(key): + key.state = AllParcelsErrorCodes[type(error)] + elif _isParticle(key[-1]): + key[-1].state = AllParcelsErrorCodes[type(error)] + else: + raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.") + + if vector_type == "3D": + return (0, 0, 0) + elif vector_type == "2D": + return (0, 0) + else: + return 0 + + +class BaseField: + """Class that encapsulates access to field data. + + Parameters + ---------- + name : str + Name of the field + data : np.ndarray + 2D, 3D or 4D numpy array of field data. + + 1. If data shape is [xdim, ydim], [xdim, ydim, zdim], [xdim, ydim, tdim] or [xdim, ydim, zdim, tdim], + whichever is relevant for the dataset, use the flag transpose=True + 2. If data shape is [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], + use the flag transpose=False + 3. 
If data has any other shape, you first need to reorder it + lon : np.ndarray or list + Longitude coordinates (numpy vector or array) of the field (only if grid is None) + lat : np.ndarray or list + Latitude coordinates (numpy vector or array) of the field (only if grid is None) + depth : np.ndarray or list + Depth coordinates (numpy vector or array) of the field (only if grid is None) + time : np.ndarray + Time coordinates (numpy vector) of the field (only if grid is None) + mesh : str + String indicating the type of mesh coordinates and + units used during velocity interpolation: (only if grid is None) + + 1. spherical: Lat and lon in degree, with a + correction for zonal velocity U near the poles. + 2. flat (default): No conversion, lat/lon are assumed to be in m. + timestamps : np.ndarray + A numpy array containing the timestamps for each of the files in filenames, for loading + from netCDF files only. Default is None if the netCDF dimensions dictionary includes time. + grid : parcels.grid.Grid + :class:`parcels.grid.Grid` object containing all the lon, lat, depth, time + mesh and time_origin information. Can be constructed from any of the Grid objects + fieldtype : str + Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) + transpose : bool + Transpose data to required (lon, lat) layout + vmin : float + Minimum allowed value on the field. Data below this value are set to zero + vmax : float + Maximum allowed value on the field. Data above this value are set to zero + cast_data_dtype : str + Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64" (np.float64). + Note that dtype can only be "float32" in JIT mode + time_origin : parcels.tools.converters.TimeConverter + Time origin of the time axis (only if grid is None) + interp_method : str + Method for interpolation. 
Options are 'linear' (default), 'nearest', + 'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity' + allow_time_extrapolation : bool + boolean whether to allow for extrapolation in time + (i.e. beyond the last available time snapshot) + time_periodic : bool, float or datetime.timedelta + To loop periodically over the time component of the Field. It is set to either False or the length of the period (either float in seconds or datetime.timedelta object). + The last value of the time series can be provided (which is the same as the initial one) or not (Default: False) + This flag overrides the allow_time_extrapolation and sets it to False + chunkdims_name_map : str, optional + Gives a name map to the FieldFileBuffer that declared a mapping between chunksize name, NetCDF dimension and Parcels dimension; + required only if currently incompatible OCM field is loaded and chunking is used by 'chunksize' (which is the default) + to_write : bool + Write the Field in NetCDF format at the same frequency as the ParticleFile outputdt, + using a filenaming scheme based on the ParticleFile name + + Examples + -------- + For usage examples see the following tutorials: + + * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__ + """ + + def __init__( + self, + name: str | tuple[str, str], + data, + lon=None, + lat=None, + face_node_connectivity=None, + depth=None, + time=None, + grid=None, + mesh: Mesh = "flat", + timestamps=None, + fieldtype=None, + transpose=False, + vmin=None, + vmax=None, + cast_data_dtype="float32", + time_origin=None, + interp_method: InterpMethod = "linear", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + gridindexingtype: GridIndexingType = "nemo", + to_write=False, + **kwargs, + ): + if kwargs.get("netcdf_decodewarning") is not None: + _deprecated_param_netcdf_decodewarning() + kwargs.pop("netcdf_decodewarning") + + if not isinstance(name, tuple): + self.name = name + 
self.filebuffername = name + else: + self.name = name[0] + self.filebuffername = name[1] + self.data = data + if grid: + if grid.defer_load and isinstance(data, np.ndarray): + raise ValueError( + "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately" + ) + self._grid = grid + else: + if (time is not None) and isinstance(time[0], np.datetime64): + time_origin = TimeConverter(time[0]) + time = np.array([time_origin.reltime(t) for t in time]) + else: + time_origin = TimeConverter(0) + + # joe@fluidnumerics.com + # This allows for the creation of a UGrid object + if face_node_connectivity is None: + self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) + else: + self._grid = UGrid.create_grid(lon, lat, depth, time, face_node_connectivity, time_origin=time_origin, mesh=mesh) + + + self.igrid = -1 + self.fieldtype = self.name if fieldtype is None else fieldtype + self.to_write = to_write + if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): + self.units = UnitConverter() + elif self.grid.mesh == "spherical": + self.units = unitconverters_map[self.fieldtype] + else: + raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") + self.timestamps = timestamps + if isinstance(interp_method, dict): + if self.name in interp_method: + self.interp_method = interp_method[self.name] + else: + raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") + else: + self.interp_method = interp_method + assert_valid_gridindexingtype(gridindexingtype) + self._gridindexingtype = gridindexingtype + if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid._gtype in [ + GridType.RectilinearSGrid, + GridType.CurvilinearSGrid, + ]: + warnings.warn( + "General s-levels are not supported in B-grid. 
RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.", + FieldSetWarning, + stacklevel=2, + ) + + self.fieldset: FieldSet | None = None + if allow_time_extrapolation is None: + self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False + else: + self.allow_time_extrapolation = allow_time_extrapolation + + self.time_periodic = time_periodic + if self.time_periodic is not False and self.allow_time_extrapolation: + warnings.warn( + "allow_time_extrapolation and time_periodic cannot be used together. allow_time_extrapolation is set to False", + FieldSetWarning, + stacklevel=2, + ) + self.allow_time_extrapolation = False + if self.time_periodic is True: + raise ValueError( + "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." + ) + if self.time_periodic is not False: + if isinstance(self.time_periodic, datetime.timedelta): + self.time_periodic = self.time_periodic.total_seconds() + if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): + if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: + raise ValueError("Time series provided is longer than the time_periodic parameter") + self.grid._add_last_periodic_data_timestep = True + self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) + self.grid.time_full = self.grid.time + + self.vmin = vmin + self.vmax = vmax + self._cast_data_dtype = cast_data_dtype + if self.cast_data_dtype == "float32": + self._cast_data_dtype = np.float32 + elif self.cast_data_dtype == "float64": + self._cast_data_dtype = np.float64 + + if not self.grid.defer_load: + self.data = self._reshape(self.data, transpose) + + # Hack around the fact that NaN and ridiculously large values + # propagate in SciPy's interpolators + lib = np if isinstance(self.data, np.ndarray) else da + self.data[lib.isnan(self.data)] 
= 0.0 + if self.vmin is not None: + self.data[self.data < self.vmin] = 0.0 + if self.vmax is not None: + self.data[self.data > self.vmax] = 0.0 + + if self.grid._add_last_periodic_data_timestep: + self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) + + self._scaling_factor = None + + # Variable names in JIT code + self._dimensions = kwargs.pop("dimensions", None) + self.indices = kwargs.pop("indices", None) + self._dataFiles = kwargs.pop("dataFiles", None) + if self.grid._add_last_periodic_data_timestep and self._dataFiles is not None: + self._dataFiles = np.append(self._dataFiles, self._dataFiles[0]) + self._field_fb_class = kwargs.pop("FieldFileBuffer", None) + self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") + self._loaded_time_indices: Iterable[int] = [] # type: ignore + self._creation_log = kwargs.pop("creation_log", "") + self.chunksize = kwargs.pop("chunksize", None) + self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) + self.grid.depth_field = kwargs.pop("depth_field", None) + + if self.grid.depth_field == "not_yet_set": + assert ( + self.grid._z4d + ), "Providing the depth dimensions from another field data is only available for 4d S grids" + + # data_full_zdim is the vertical dimension of the complete field data, ignoring the indices. + # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, + # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). + self.data_full_zdim = kwargs.pop("data_full_zdim", None) + self._data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays + self._c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array + self.nchunks: tuple[int, ...] 
= () + self._chunk_set: bool = False + self.filebuffers = [None] * 2 + if len(kwargs) > 0: + raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') + + @property + def dimensions(self): + return self._dimensions + + @property + def grid(self): + return self._grid + + @property + def lon(self): + """Lon defined on the Grid object""" + return self.grid.lon + + @property + def lat(self): + """Lat defined on the Grid object""" + return self.grid.lat + + @property + def depth(self): + """Depth defined on the Grid object""" + return self.grid.depth + + @property + def cell_edge_sizes(self): + return self.grid.cell_edge_sizes + + @property + def interp_method(self): + return self._interp_method + + @interp_method.setter + def interp_method(self, value): + assert_valid_interp_method(value) + self._interp_method = value + + @property + def gridindexingtype(self): + return self._gridindexingtype + + @property + def cast_data_dtype(self): + return self._cast_data_dtype + + @property + def netcdf_engine(self): + return self._netcdf_engine + + @classmethod + def _get_dim_filenames(cls, filenames, dim): + if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): + return [filenames] + elif isinstance(filenames, dict): + assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" + filename = filenames[dim] + if isinstance(filename, str): + return [filename] + else: + return filename + else: + return filenames + + @staticmethod + def _collect_timeslices( + timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None + ): + if netcdf_decodewarning is not None: + _deprecated_param_netcdf_decodewarning() + if timestamps is not None: + dataFiles = [] + for findex in range(len(data_filenames)): + stamps_in_file = 1 if isinstance(timestamps[findex], (int, np.datetime64)) else len(timestamps[findex]) + for f in [data_filenames[findex]] * 
stamps_in_file: + dataFiles.append(f) + timeslices = np.array([stamp for file in timestamps for stamp in file]) + time = timeslices + else: + timeslices = [] + dataFiles = [] + for fname in data_filenames: + with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer: + ftime = filebuffer.time + timeslices.append(ftime) + dataFiles.append([fname] * len(ftime)) + time = np.concatenate(timeslices).ravel() + dataFiles = np.concatenate(dataFiles).ravel() + if time.size == 1 and time[0] is None: + time[0] = 0 + time_origin = TimeConverter(time[0]) + time = time_origin.reltime(time) + + if not np.all((time[1:] - time[:-1]) > 0): + id_not_ordered = np.where(time[1:] < time[:-1])[0][0] + raise AssertionError( + f"Please make sure your netCDF files are ordered in time. First pair of non-ordered files: {dataFiles[id_not_ordered]}, {dataFiles[id_not_ordered + 1]}" + ) + return time, time_origin, timeslices, dataFiles \ No newline at end of file diff --git a/parcels/basegrid.py b/parcels/basegrid.py new file mode 100644 index 0000000000..d02c1fdf92 --- /dev/null +++ b/parcels/basegrid.py @@ -0,0 +1,228 @@ +import functools +import warnings +from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer +from enum import IntEnum +from abc import ABC, abstractmethod + +import numpy as np +import numpy.typing as npt + +from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import TimeConverter +from parcels.tools.warnings import FieldSetWarning + +class CGrid(Structure): + _fields_ = [("gtype", c_int), ("grid", c_void_p)] + +class BaseGrid(ABC): + """Abstract Base Class for Grid types.""" + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None, + time_origin: TimeConverter | None, + mesh: Mesh, + ): + self._ti = -1 + self._update_status: UpdateStatus | None = None + if not 
lon.flags["C_CONTIGUOUS"]: + lon = np.array(lon, order="C") + if not lat.flags["C_CONTIGUOUS"]: + lat = np.array(lat, order="C") + time = np.zeros(1, dtype=np.float64) if time is None else time + if not time.flags["C_CONTIGUOUS"]: + time = np.array(time, order="C") + if not lon.dtype == np.float32: + lon = lon.astype(np.float32) + if not lat.dtype == np.float32: + lat = lat.astype(np.float32) + if not time.dtype == np.float64: + assert isinstance( + time[0], (np.integer, np.floating, float, int) + ), "Time vector must be an array of int or floats" + time = time.astype(np.float64) + + self._lon = lon + self._lat = lat + self.time = time + self.time_full = self.time # needed for deferred_loaded Fields + self._time_origin = TimeConverter() if time_origin is None else time_origin + assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" + assert_valid_mesh(mesh) + self._mesh = mesh + self._cstruct = None + self._cell_edge_sizes: dict[str, npt.NDArray] = {} + # self._zonal_periodic = False + # self._zonal_halo = 0 + # self._meridional_halo = 0 + # self._lat_flipped = False + self._defer_load = False + self._lonlat_minmax = np.array( + [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 + ) + self.periods = 0 + self._load_chunk: npt.NDArray = np.array([]) + self.chunk_info = None + self.chunksize = None + self._add_last_periodic_data_timestep = False + self.depth_field = None + + def __repr__(self): + with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}): + return ( + f"{type(self).__name__}(" + f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, " + f"time_origin={self.time_origin!r}, mesh={self.mesh!r})" + ) + + @property + def lon(self): + return self._lon + + @property + def lat(self): + return self._lat + + @property + def depth(self): + return self._depth + + @property + def mesh(self): + return self._mesh + + @property + def 
lonlat_minmax(self): + return self._lonlat_minmax + + @property + def cell_edge_sizes(self): + return self._cell_edge_sizes + + @property + def defer_load(self): + return self._defer_load + + @property + def time_origin(self): + return self._time_origin + + @property + def ctypes_struct(self): + # This is unnecessary for the moment, but it could be useful when going with fully unstructured grids + self._cgrid = cast(pointer(self._child_ctypes_struct), c_void_p) + cstruct = CGrid(self._gtype, self._cgrid.value) + return cstruct + + @property + @abstractmethod + def _child_ctypes_struct(self): + pass + + @abstractmethod + def lon_grid_to_target(self): + pass + + @abstractmethod + def lon_grid_to_source(self): + pass + + @abstractmethod + def lon_particle_to_target(self, lon): + pass + + def _computeTimeChunk(self, f, time, signdt): + nextTime_loc = np.inf if signdt >= 0 else -np.inf + periods = self.periods.value if isinstance(self.periods, c_int) else self.periods + prev_time_indices = self.time + if self._update_status == "not_updated": + if self._ti >= 0: + if ( + time - periods * (self.time_full[-1] - self.time_full[0]) < self.time[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time[1] + ): + self._ti = -1 # reset + elif signdt >= 0 and ( + time - periods * (self.time_full[-1] - self.time_full[0]) < self.time_full[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time_full[-1] + ): + self._ti = -1 # reset + elif signdt < 0 and ( + time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time_full[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time_full[-1] + ): + self._ti = -1 # reset + elif ( + signdt >= 0 + and time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time[1] + and self._ti < len(self.time_full) - 2 + ): + self._ti += 1 + self.time = self.time_full[self._ti : self._ti + 2] + self._update_status = "updated" + elif ( + signdt < 0 + and time - periods * 
(self.time_full[-1] - self.time_full[0]) <= self.time[0] + and self._ti > 0 + ): + self._ti -= 1 + self.time = self.time_full[self._ti : self._ti + 2] + self._update_status = "updated" + if self._ti == -1: + self.time = self.time_full + self._ti, _ = f._time_index(time) + periods = self.periods.value if isinstance(self.periods, c_int) else self.periods + if ( + signdt == -1 + and self._ti == 0 + and (time - periods * (self.time_full[-1] - self.time_full[0])) == self.time[0] + and f.time_periodic + ): + self._ti = len(self.time) - 1 + periods -= 1 + if signdt == -1 and self._ti > 0 and self.time_full[self._ti] == time: + self._ti -= 1 + if self._ti >= len(self.time_full) - 1: + self._ti = len(self.time_full) - 2 + + self.time = self.time_full[self._ti : self._ti + 2] + self.tdim = 2 + if prev_time_indices is None or len(prev_time_indices) != 2 or len(prev_time_indices) != len(self.time): + self._update_status = "first_updated" + elif functools.reduce( + lambda i, j: i and j, map(lambda m, k: m == k, self.time, prev_time_indices), True + ) and len(prev_time_indices) == len(self.time): + self._update_status = "not_updated" + elif functools.reduce( + lambda i, j: i and j, map(lambda m, k: m == k, self.time[:1], prev_time_indices[:1]), True + ) and len(prev_time_indices) == len(self.time): + self._update_status = "updated" + else: + self._update_status = "first_updated" + if signdt >= 0 and (self._ti < len(self.time_full) - 2 or not f.allow_time_extrapolation): + nextTime_loc = self.time[1] + periods * (self.time_full[-1] - self.time_full[0]) + elif signdt < 0 and (self._ti > 0 or not f.allow_time_extrapolation): + nextTime_loc = self.time[0] + periods * (self.time_full[-1] - self.time_full[0]) + return nextTime_loc + + @property + def _chunk_not_loaded(self): + return 0 + + @property + def _chunk_loading_requested(self): + return 1 + + @property + def _chunk_loaded_touched(self): + return 2 + + @property + def _chunk_deprecated(self): + return 3 + + @property + def 
_chunk_loaded(self): + return [2, 3] \ No newline at end of file diff --git a/parcels/field.py b/parcels/field.py index df3f6b991e..a1c6799ae8 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -37,6 +37,7 @@ TimeExtrapolationError, ) from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning +from parcels.basefield import BaseField from .fieldfilebuffer import ( DaskFileBuffer, @@ -77,7 +78,7 @@ def _deal_with_errors(error, key, vector_type: VectorType): return 0 -class Field: +class Field(BaseField): """Class that encapsulates access to field data. Parameters @@ -150,166 +151,8 @@ class Field: * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__ """ - def __init__( - self, - name: str | tuple[str, str], - data, - lon=None, - lat=None, - depth=None, - time=None, - grid=None, - mesh: Mesh = "flat", - timestamps=None, - fieldtype=None, - transpose=False, - vmin=None, - vmax=None, - cast_data_dtype="float32", - time_origin=None, - interp_method: InterpMethod = "linear", - allow_time_extrapolation: bool | None = None, - time_periodic: TimePeriodic = False, - gridindexingtype: GridIndexingType = "nemo", - to_write=False, - **kwargs, - ): - if kwargs.get("netcdf_decodewarning") is not None: - _deprecated_param_netcdf_decodewarning() - kwargs.pop("netcdf_decodewarning") - - if not isinstance(name, tuple): - self.name = name - self.filebuffername = name - else: - self.name = name[0] - self.filebuffername = name[1] - self.data = data - if grid: - if grid.defer_load and isinstance(data, np.ndarray): - raise ValueError( - "Cannot combine Grid from defer_loaded Field with np.ndarray data. 
please specify lon, lat, depth and time dimensions separately" - ) - self._grid = grid - else: - if (time is not None) and isinstance(time[0], np.datetime64): - time_origin = TimeConverter(time[0]) - time = np.array([time_origin.reltime(t) for t in time]) - else: - time_origin = TimeConverter(0) - self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) - self.igrid = -1 - self.fieldtype = self.name if fieldtype is None else fieldtype - self.to_write = to_write - if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): - self.units = UnitConverter() - elif self.grid.mesh == "spherical": - self.units = unitconverters_map[self.fieldtype] - else: - raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") - self.timestamps = timestamps - if isinstance(interp_method, dict): - if self.name in interp_method: - self.interp_method = interp_method[self.name] - else: - raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") - else: - self.interp_method = interp_method - assert_valid_gridindexingtype(gridindexingtype) - self._gridindexingtype = gridindexingtype - if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid._gtype in [ - GridType.RectilinearSGrid, - GridType.CurvilinearSGrid, - ]: - warnings.warn( - "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.", - FieldSetWarning, - stacklevel=2, - ) - - self.fieldset: FieldSet | None = None - if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False - else: - self.allow_time_extrapolation = allow_time_extrapolation - - self.time_periodic = time_periodic - if self.time_periodic is not False and self.allow_time_extrapolation: - warnings.warn( - "allow_time_extrapolation and time_periodic cannot be used together. 
allow_time_extrapolation is set to False", - FieldSetWarning, - stacklevel=2, - ) - self.allow_time_extrapolation = False - if self.time_periodic is True: - raise ValueError( - "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." - ) - if self.time_periodic is not False: - if isinstance(self.time_periodic, datetime.timedelta): - self.time_periodic = self.time_periodic.total_seconds() - if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): - if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: - raise ValueError("Time series provided is longer than the time_periodic parameter") - self.grid._add_last_periodic_data_timestep = True - self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) - self.grid.time_full = self.grid.time - - self.vmin = vmin - self.vmax = vmax - self._cast_data_dtype = cast_data_dtype - if self.cast_data_dtype == "float32": - self._cast_data_dtype = np.float32 - elif self.cast_data_dtype == "float64": - self._cast_data_dtype = np.float64 - - if not self.grid.defer_load: - self.data = self._reshape(self.data, transpose) - - # Hack around the fact that NaN and ridiculously large values - # propagate in SciPy's interpolators - lib = np if isinstance(self.data, np.ndarray) else da - self.data[lib.isnan(self.data)] = 0.0 - if self.vmin is not None: - self.data[self.data < self.vmin] = 0.0 - if self.vmax is not None: - self.data[self.data > self.vmax] = 0.0 - - if self.grid._add_last_periodic_data_timestep: - self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) - - self._scaling_factor = None - - # Variable names in JIT code - self._dimensions = kwargs.pop("dimensions", None) - self.indices = kwargs.pop("indices", None) - self._dataFiles = kwargs.pop("dataFiles", None) - if self.grid._add_last_periodic_data_timestep and self._dataFiles is not None: - self._dataFiles = 
np.append(self._dataFiles, self._dataFiles[0]) - self._field_fb_class = kwargs.pop("FieldFileBuffer", None) - self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") - self._loaded_time_indices: Iterable[int] = [] # type: ignore - self._creation_log = kwargs.pop("creation_log", "") - self.chunksize = kwargs.pop("chunksize", None) - self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) - self.grid.depth_field = kwargs.pop("depth_field", None) - - if self.grid.depth_field == "not_yet_set": - assert ( - self.grid._z4d - ), "Providing the depth dimensions from another field data is only available for 4d S grids" - - # data_full_zdim is the vertical dimension of the complete field data, ignoring the indices. - # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, - # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). - self.data_full_zdim = kwargs.pop("data_full_zdim", None) - self._data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays - self._c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array - self.nchunks: tuple[int, ...] 
= () - self._chunk_set: bool = False - self.filebuffers = [None] * 2 - if len(kwargs) > 0: - raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -341,114 +184,16 @@ def creation_log(self): def loaded_time_indices(self): return self._loaded_time_indices - @property - def dimensions(self): - return self._dimensions - - @property - def grid(self): - return self._grid - - @property - def lon(self): - """Lon defined on the Grid object""" - return self.grid.lon - - @property - def lat(self): - """Lat defined on the Grid object""" - return self.grid.lat - - @property - def depth(self): - """Depth defined on the Grid object""" - return self.grid.depth - - @property - def cell_edge_sizes(self): - return self.grid.cell_edge_sizes - - @property - def interp_method(self): - return self._interp_method - - @interp_method.setter - def interp_method(self, value): - assert_valid_interp_method(value) - self._interp_method = value - - @property - def gridindexingtype(self): - return self._gridindexingtype - - @property - def cast_data_dtype(self): - return self._cast_data_dtype - - @property - def netcdf_engine(self): - return self._netcdf_engine - @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def get_dim_filenames(cls, *args, **kwargs): return cls._get_dim_filenames(*args, **kwargs) - @classmethod - def _get_dim_filenames(cls, filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if isinstance(filename, str): - return [filename] - else: - return filename - else: - return filenames - @staticmethod @deprecated_made_private # TODO: Remove 6 
months after v3.1.0 def collect_timeslices(*args, **kwargs): return Field._collect_timeslices(*args, **kwargs) - @staticmethod - def _collect_timeslices( - timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None - ): - if netcdf_decodewarning is not None: - _deprecated_param_netcdf_decodewarning() - if timestamps is not None: - dataFiles = [] - for findex in range(len(data_filenames)): - stamps_in_file = 1 if isinstance(timestamps[findex], (int, np.datetime64)) else len(timestamps[findex]) - for f in [data_filenames[findex]] * stamps_in_file: - dataFiles.append(f) - timeslices = np.array([stamp for file in timestamps for stamp in file]) - time = timeslices - else: - timeslices = [] - dataFiles = [] - for fname in data_filenames: - with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer: - ftime = filebuffer.time - timeslices.append(ftime) - dataFiles.append([fname] * len(ftime)) - time = np.concatenate(timeslices).ravel() - dataFiles = np.concatenate(dataFiles).ravel() - if time.size == 1 and time[0] is None: - time[0] = 0 - time_origin = TimeConverter(time[0]) - time = time_origin.reltime(time) - - if not np.all((time[1:] - time[:-1]) > 0): - id_not_ordered = np.where(time[1:] < time[:-1])[0][0] - raise AssertionError( - f"Please make sure your netCDF files are ordered in time. 
First pair of non-ordered files: {dataFiles[id_not_ordered]}, {dataFiles[id_not_ordered + 1]}" - ) - return time, time_origin, timeslices, dataFiles - @classmethod def from_netcdf( cls, diff --git a/parcels/grid.py b/parcels/grid.py index 0d8745a096..4a60d9b794 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -10,6 +10,8 @@ from parcels.tools._helpers import deprecated_made_private from parcels.tools.converters import TimeConverter from parcels.tools.warnings import FieldSetWarning +from parcels.basegrid import BaseGrid +from parcels.basegrid import CGrid __all__ = [ "GridType", @@ -22,7 +24,6 @@ "Grid", ] - class GridType(IntEnum): RectilinearZGrid = 0 RectilinearSGrid = 1 @@ -35,101 +36,20 @@ class GridType(IntEnum): GridCode = GridType -class CGrid(Structure): - _fields_ = [("gtype", c_int), ("grid", c_void_p)] - - -class Grid: +class Grid(BaseGrid): """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" - def __init__( - self, - lon: npt.NDArray, - lat: npt.NDArray, - time: npt.NDArray | None, - time_origin: TimeConverter | None, - mesh: Mesh, - ): - self._ti = -1 - self._update_status: UpdateStatus | None = None - if not lon.flags["C_CONTIGUOUS"]: - lon = np.array(lon, order="C") - if not lat.flags["C_CONTIGUOUS"]: - lat = np.array(lat, order="C") - time = np.zeros(1, dtype=np.float64) if time is None else time - if not time.flags["C_CONTIGUOUS"]: - time = np.array(time, order="C") - if not lon.dtype == np.float32: - lon = lon.astype(np.float32) - if not lat.dtype == np.float32: - lat = lat.astype(np.float32) - if not time.dtype == np.float64: - assert isinstance( - time[0], (np.integer, np.floating, float, int) - ), "Time vector must be an array of int or floats" - time = time.astype(np.float64) - - self._lon = lon - self._lat = lat - self.time = time - self.time_full = self.time # needed for deferred_loaded Fields - self._time_origin = TimeConverter() if time_origin is None else time_origin - assert 
isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" - assert_valid_mesh(mesh) - self._mesh = mesh - self._cstruct = None - self._cell_edge_sizes: dict[str, npt.NDArray] = {} + def __init__(self,*args, **kwargs): + super().__init__(*args, **kwargs) self._zonal_periodic = False self._zonal_halo = 0 self._meridional_halo = 0 self._lat_flipped = False - self._defer_load = False - self._lonlat_minmax = np.array( - [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 - ) - self.periods = 0 - self._load_chunk: npt.NDArray = np.array([]) - self.chunk_info = None - self.chunksize = None - self._add_last_periodic_data_timestep = False - self.depth_field = None - - def __repr__(self): - with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}): - return ( - f"{type(self).__name__}(" - f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, " - f"time_origin={self.time_origin!r}, mesh={self.mesh!r})" - ) - - @property - def lon(self): - return self._lon - - @property - def lat(self): - return self._lat - - @property - def depth(self): - return self._depth - - @property - def mesh(self): - return self._mesh @property def meridional_halo(self): return self._meridional_halo - @property - def lonlat_minmax(self): - return self._lonlat_minmax - - @property - def time_origin(self): - return self._time_origin - @property def zonal_periodic(self): return self._zonal_periodic @@ -138,14 +58,6 @@ def zonal_periodic(self): def zonal_halo(self): return self._zonal_halo - @property - def defer_load(self): - return self._defer_load - - @property - def cell_edge_sizes(self): - return self._cell_edge_sizes - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def ti(self): @@ -213,12 +125,6 @@ def create_grid( else: return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) - @property - def ctypes_struct(self): - # This 
is unnecessary for the moment, but it could be useful when going will fully unstructured grids - self._cgrid = cast(pointer(self._child_ctypes_struct), c_void_p) - cstruct = CGrid(self._gtype, self._cgrid.value) - return cstruct @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -340,125 +246,31 @@ def _add_Sdepth_periodic_halo(self, zonal, meridional, halosize): def computeTimeChunk(self, *args, **kwargs): return self._computeTimeChunk(*args, **kwargs) - def _computeTimeChunk(self, f, time, signdt): - nextTime_loc = np.inf if signdt >= 0 else -np.inf - periods = self.periods.value if isinstance(self.periods, c_int) else self.periods - prev_time_indices = self.time - if self._update_status == "not_updated": - if self._ti >= 0: - if ( - time - periods * (self.time_full[-1] - self.time_full[0]) < self.time[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time[1] - ): - self._ti = -1 # reset - elif signdt >= 0 and ( - time - periods * (self.time_full[-1] - self.time_full[0]) < self.time_full[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time_full[-1] - ): - self._ti = -1 # reset - elif signdt < 0 and ( - time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time_full[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time_full[-1] - ): - self._ti = -1 # reset - elif ( - signdt >= 0 - and time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time[1] - and self._ti < len(self.time_full) - 2 - ): - self._ti += 1 - self.time = self.time_full[self._ti : self._ti + 2] - self._update_status = "updated" - elif ( - signdt < 0 - and time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time[0] - and self._ti > 0 - ): - self._ti -= 1 - self.time = self.time_full[self._ti : self._ti + 2] - self._update_status = "updated" - if self._ti == -1: - self.time = self.time_full - self._ti, _ = f._time_index(time) - periods = self.periods.value if 
isinstance(self.periods, c_int) else self.periods - if ( - signdt == -1 - and self._ti == 0 - and (time - periods * (self.time_full[-1] - self.time_full[0])) == self.time[0] - and f.time_periodic - ): - self._ti = len(self.time) - 1 - periods -= 1 - if signdt == -1 and self._ti > 0 and self.time_full[self._ti] == time: - self._ti -= 1 - if self._ti >= len(self.time_full) - 1: - self._ti = len(self.time_full) - 2 - - self.time = self.time_full[self._ti : self._ti + 2] - self.tdim = 2 - if prev_time_indices is None or len(prev_time_indices) != 2 or len(prev_time_indices) != len(self.time): - self._update_status = "first_updated" - elif functools.reduce( - lambda i, j: i and j, map(lambda m, k: m == k, self.time, prev_time_indices), True - ) and len(prev_time_indices) == len(self.time): - self._update_status = "not_updated" - elif functools.reduce( - lambda i, j: i and j, map(lambda m, k: m == k, self.time[:1], prev_time_indices[:1]), True - ) and len(prev_time_indices) == len(self.time): - self._update_status = "updated" - else: - self._update_status = "first_updated" - if signdt >= 0 and (self._ti < len(self.time_full) - 2 or not f.allow_time_extrapolation): - nextTime_loc = self.time[1] + periods * (self.time_full[-1] - self.time_full[0]) - elif signdt < 0 and (self._ti > 0 or not f.allow_time_extrapolation): - nextTime_loc = self.time[0] + periods * (self.time_full[-1] - self.time_full[0]) - return nextTime_loc - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_not_loaded(self): return self._chunk_not_loaded - @property - def _chunk_not_loaded(self): - return 0 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_loading_requested(self): return self._chunk_loading_requested - @property - def _chunk_loading_requested(self): - return 1 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_loaded_touched(self): return self._chunk_loaded_touched - @property - def 
class HashGrid:
    """Uniform hash grid used to look up unstructured-grid elements near a point.

    The grid consists of ``nx * ny`` rectangular cells. Cell ``(i, j)`` is
    stored at flat index ``i + nx * j``.

    Parameters
    ----------
    x0 : float
        x-coordinate of the lower left corner of the grid.
    y0 : float
        y-coordinate of the lower left corner of the grid.
    dx : float
        Grid spacing in x-direction.
    dy : float
        Grid spacing in y-direction.
    nx : int
        Number of hash cells in x-direction.
    ny : int
        Number of hash cells in y-direction.

    Attributes
    ----------
    ugrid_elements : list of list of int
        One list per hash cell (flat index) holding the indices of the
        unstructured grid elements whose bounding box overlaps that cell.
        All lists are empty until :meth:`populate_ugrid_elements` is called.
    """

    def __init__(self, x0, y0, dx, dy, nx, ny):
        self.x0 = x0
        self.y0 = y0
        self.dx = dx
        self.dy = dy
        self.nx = nx
        self.ny = ny
        # Must be stored on the instance: binding it to a local name would
        # leave the attribute undefined until populate_ugrid_elements() runs.
        self.ugrid_elements = [[] for _ in range(self.nx * self.ny)]

    def get_hashindex_for_xy(self, x, y):
        """Return the flat hash-cell index ``i + nx * j`` for the point (x, y).

        No bounds check is performed; ``int()`` truncates toward zero, so
        points outside the grid do not map to a meaningful cell.
        """
        i = int((x - self.x0) / self.dx)
        j = int((y - self.y0) / self.dy)
        return i + self.nx * j

    def populate_ugrid_elements(self, vertices, elements):
        """Fill ``self.ugrid_elements`` from the bounding boxes of the elements.

        Every element is registered in each hash cell that its axis-aligned
        bounding box overlaps. Cell ranges are clamped to the grid, so the
        parts of an element that lie outside the grid are ignored.

        Parameters
        ----------
        vertices : array-like
            Vertex coordinates of shape (n_vertices, 2).
        elements : array-like
            Element-to-vertex connectivity; each row holds the indices of one
            element's vertices in ``vertices``.

        Returns
        -------
        None
            The lookup table is stored in ``self.ugrid_elements`` (a list
            indexed by flat hash-cell index), replacing any previous content.
        """
        import numpy as np  # local import: this module has no file-level imports

        vertices = np.asarray(vertices)
        cells = [[] for _ in range(self.nx * self.ny)]

        for element_idx, element in enumerate(elements):
            element_vertices = vertices[np.asarray(element)]

            # Axis-aligned bounding box of the element.
            x_min, y_min = np.min(element_vertices, axis=0)
            x_max, y_max = np.max(element_vertices, axis=0)

            # Range of hash cells overlapped by the bounding box, clamped to
            # the grid. Without clamping, an out-of-range column index would
            # either raise IndexError or silently land in a different row.
            i_min = max(int(np.floor((x_min - self.x0) / self.dx)), 0)
            i_max = min(int(np.floor((x_max - self.x0) / self.dx)), self.nx - 1)
            j_min = max(int(np.floor((y_min - self.y0) / self.dy)), 0)
            j_max = min(int(np.floor((y_max - self.y0) / self.dy)), self.ny - 1)

            for j in range(j_min, j_max + 1):
                row_offset = self.nx * j
                for i in range(i_min, i_max + 1):
                    cells[row_offset + i].append(element_idx)

        self.ugrid_elements = cells
from typing import TYPE_CHECKING

from parcels._typing import VectorType
from parcels.tools.statuscodes import (
    AllParcelsErrorCodes,
    FieldOutOfBoundError,
    FieldOutOfBoundSurfaceError,
    FieldSamplingError,
    TimeExtrapolationError,
)
from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning

# BaseField is the base class of UField below; without this import the module
# raises NameError as soon as it is imported.
from .basefield import BaseField
from .fieldfilebuffer import (
    DaskFileBuffer,
    DeferredDaskFileBuffer,
    DeferredNetcdfFileBuffer,
    NetcdfFileBuffer,
)
from .grid import CGrid, Grid, GridType

if TYPE_CHECKING:
    from ctypes import _Pointer as PointerType

    from parcels.fieldset import FieldSet


# Only names that are actually defined in this module may be exported;
# listing undefined names makes `from parcels.ufield import *` fail.
__all__ = ["UField"]


def _isParticle(key):
    """Return True if ``key`` looks like a particle (it has an ``obs_written`` attribute)."""
    return hasattr(key, "obs_written")


def _deal_with_errors(error, key, vector_type: VectorType):
    """Record a sampling error on the particle in ``key`` and return a zero value.

    Parameters
    ----------
    error : Exception
        The error raised while sampling; its type is looked up in
        ``AllParcelsErrorCodes`` to obtain the particle state to set.
    key :
        Either a particle, or an indexable whose last item is a particle.
    vector_type : VectorType
        "3D" yields ``(0, 0, 0)``, "2D" yields ``(0, 0)``, anything else ``0``.

    Raises
    ------
    RuntimeError
        If neither ``key`` nor ``key[-1]`` is a particle.
    """
    if _isParticle(key):
        key.state = AllParcelsErrorCodes[type(error)]
    elif _isParticle(key[-1]):
        key[-1].state = AllParcelsErrorCodes[type(error)]
    else:
        raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.")

    if vector_type == "3D":
        return (0, 0, 0)
    elif vector_type == "2D":
        return (0, 0)
    else:
        return 0


class UField(BaseField):
    """Class that encapsulates access to field data on an unstructured grid.

    Construction is inherited unchanged from :class:`BaseField`.
    NOTE(review): the parameter list below mirrors the structured ``Field``
    documentation; confirm which parameters BaseField actually accepts.

    Parameters
    ----------
    name : str
        Name of the field
    data : np.ndarray
        2D, 3D or 4D numpy array of field data.

        1. If data shape is [xdim, ydim], [xdim, ydim, zdim], [xdim, ydim, tdim] or [xdim, ydim, zdim, tdim],
           whichever is relevant for the dataset, use the flag transpose=True
        2. If data shape is [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim],
           use the flag transpose=False
        3. If data has any other shape, you first need to reorder it
    lon : np.ndarray or list
        Longitude coordinates (numpy vector or array) of the field (only if grid is None)
    lat : np.ndarray or list
        Latitude coordinates (numpy vector or array) of the field (only if grid is None)
    face_node_connectivity : np.ndarray or list dimensioned [nfaces, max_nodes_per_face]
        Connectivity array between faces and nodes (only if grid is None)
    depth : np.ndarray or list
        Depth coordinates (numpy vector or array) of the field (only if grid is None)
    time : np.ndarray
        Time coordinates (numpy vector) of the field (only if grid is None)
    mesh : str
        String indicating the type of mesh coordinates and
        units used during velocity interpolation: (only if grid is None)

        1. spherical: Lat and lon in degree, with a
           correction for zonal velocity U near the poles.
        2. flat (default): No conversion, lat/lon are assumed to be in m.
    timestamps : np.ndarray
        A numpy array containing the timestamps for each of the files in filenames, for loading
        from netCDF files only. Default is None if the netCDF dimensions dictionary includes time.
    grid : parcels.grid.Grid
        :class:`parcels.grid.Grid` object containing all the lon, lat depth, time
        mesh and time_origin information. Can be constructed from any of the Grid objects
    fieldtype : str
        Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None)
    transpose : bool
        Transpose data to required (lon, lat) layout
    vmin : float
        Minimum allowed value on the field. Data below this value are set to zero
    vmax : float
        Maximum allowed value on the field. Data above this value are set to zero
    cast_data_dtype : str
        Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64" (np.float64).
        Note that dtype can only be "float32" in JIT mode
    time_origin : parcels.tools.converters.TimeConverter
        Time origin of the time axis (only if grid is None)
    interp_method : str
        Method for interpolation. Options are 'linear' (default), 'nearest',
        'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity'
    allow_time_extrapolation : bool
        boolean whether to allow for extrapolation in time
        (i.e. beyond the last available time snapshot)
    time_periodic : bool, float or datetime.timedelta
        To loop periodically over the time component of the Field. It is set to either False or the length
        of the period (either float in seconds or datetime.timedelta object).
        The last value of the time series can be provided (which is the same as the initial one) or not
        (Default: False). This flag overrides the allow_time_extrapolation and sets it to False
    chunkdims_name_map : str, optional
        Gives a name map to the FieldFileBuffer that declared a mapping between chunksize name, NetCDF
        dimension and Parcels dimension; required only if currently incompatible OCM field is loaded and
        chunking is used by 'chunksize' (which is the default)
    to_write : bool
        Write the Field in NetCDF format at the same frequency as the ParticleFile outputdt,
        using a filenaming scheme based on the ParticleFile name
    """

    @property
    def face_node_connectivity(self):
        """Face-to-node connectivity array defined on the grid object."""
        return self.grid.face_node_connectivity

    # TODO: alternate constructors from_netcdf() and from_uxarray()
    #       (likely want to use uxarray for this).
    #
    # TODO: port or re-implement for unstructured grids:
    #       _reshape, set_scaling_factor, set_depth_from_field,
    #       _calc_cell_edge_sizes, cell_areas, _search_indices_vertical_s,
    #       _reconnect_bnd_indices, _search_indices_curvilinear.
    #
    # TODO: _search_indices(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False)
    #   1. Do a hash table search based on the (x, y) location to get the
    #      list of elements to check.
    #   2. Loop over those elements and check whether the particle is inside
    #      (barycentric coordinate calculation):
    #          (l1, l2, l3, inside_element) = self.barycentric_coordinates(x, y)
    #   3. Return the barycentric coordinates (2D), the vertical interpolation
    #      weight, the element index, and the nearest vertical layer index
    #      *above* the particle (lower k bound; the particle is between layer
    #      k and k+1).
    #
    # TODO: _interpolator2D(self, ti, z, y, x, particle=None)
    #   UField.data is assumed to be provided at the ugrid vertices.
    #       (bc, _, ei, _) = self._search_indices(x, y, z, ti, particle=particle)
    #   "nearest": use the vertex with the largest barycentric coordinate:
    #       idxs = self.face_node_connectivity[ei]
    #       vi = idxs[np.argmax(bc)]
    #       return self.data[ti, vi]
    #   "linear": barycentric interpolation.
    #
    # TODO: _interpolator3D(self, ti, z, y, x, time, particle=None)
    #       (bc, zeta, ei, zi) = self._search_indices(x, y, z, ti, particle=particle)
    #   "nearest":
    #       idxs = self.face_node_connectivity[ei]
    #       vi = idxs[np.argmax(bc)]
    #       zii = zi if zeta <= 0.5 else zi + 1
    #       return self.data[ti, zii, vi]  # the sketch indexed with zi, leaving zii unused
    #   "linear": barycentric interpolation.
import functools
import warnings
from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer
from enum import IntEnum

import numpy as np
import numpy.typing as npt

from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh
from parcels.basegrid import BaseGrid, CGrid
from parcels.hashgrid import HashGrid
from parcels.tools._helpers import deprecated_made_private
from parcels.tools.converters import TimeConverter
from parcels.tools.warnings import FieldSetWarning

# Note :
# For variable placement in FESOM - see https://fesom2.readthedocs.io/en/latest/geometry.html
__all__ = ["CGrid", "UGrid"]


class UGrid(BaseGrid):
    """Grid class that defines a (spatial and temporal) unstructured grid on which Fields are defined.

    Parameters
    ----------
    face_node_connectivity : np.ndarray
        2-D array with one row per face listing the node (vertex) ids of that
        face. At most 3 columns are accepted (triangular faces).
    *args, **kwargs
        Forwarded unchanged to ``BaseGrid.__init__``.

    Raises
    ------
    TypeError
        If ``lon``, ``lat`` or ``face_node_connectivity`` is not a NumPy array.
    ValueError
        If ``lon``/``lat`` are not 1-D arrays of equal shape, or if
        ``face_node_connectivity`` is not 2-D with at most 3 columns.
    """

    def __init__(self, face_node_connectivity, *args, **kwargs):
        self._face_node_connectivity = face_node_connectivity
        self._hashgrid = None  # built on demand by create_hashgrid()
        super().__init__(*args, **kwargs)

        if not isinstance(self.lon, np.ndarray):
            raise TypeError("lon must be a NumPy array.")

        if self.lon.ndim != 1:
            raise ValueError("lon must be a 1-D array.")

        if not isinstance(self.lat, np.ndarray):
            raise TypeError("lat must be a NumPy array.")

        if self.lat.ndim != 1:
            raise ValueError("lat must be a 1-D array.")

        if self.lon.shape != self.lat.shape:
            raise ValueError("lon and lat must have the same shape.")

        if not isinstance(self.face_node_connectivity, np.ndarray):
            raise TypeError("face_node_connectivity must be a NumPy array.")

        if self.face_node_connectivity.ndim != 2:
            raise ValueError("face_node_connectivity must be a 2-D array.")

        if self.face_node_connectivity.shape[1] > 3:  # Enforce triangular faces
            raise ValueError("face_node_connectivity must be a 2-D array with at most 3 columns.")

        self._n_face = self.face_node_connectivity.shape[0]
        self._nodes_per_element = self.face_node_connectivity.shape[1]
        self._n_vertices = self.lon.shape[0]

    # The lon and lat fields are assumed identical to the
    # node_lon and node_lat fields in UXArray.Grid
    # data structure.
    @property
    def node_lon(self):
        """Alias of ``lon``."""
        return self.lon

    @property
    def node_lat(self):
        """Alias of ``lat``."""
        return self.lat

    @property
    def face_node_connectivity(self):
        """Connectivity array passed to the constructor."""
        return self._face_node_connectivity

    @property
    def hashgrid(self):
        """The HashGrid built by :meth:`create_hashgrid`, or None if not built yet."""
        return self._hashgrid

    @property
    def n_face(self):
        """Number of faces (rows of ``face_node_connectivity``)."""
        return self._n_face

    @property
    def nodes_per_element(self):
        """Number of nodes per face (columns of ``face_node_connectivity``)."""
        return self._nodes_per_element

    @property
    def n_vertices(self):
        """Number of vertices (length of ``lon``)."""
        return self._n_vertices

    @staticmethod
    def create_grid(
        lon: npt.ArrayLike,
        lat: npt.ArrayLike,
        face_node_connectivity: npt.ArrayLike,
        depth,
        time,
        time_origin,
        mesh: Mesh,
        **kwargs,
    ):
        """Build a UGrid from array-like inputs, converting them to NumPy arrays first."""
        lon = np.array(lon)
        lat = np.array(lat)
        face_node_connectivity = np.array(face_node_connectivity)

        if depth is not None:
            depth = np.array(depth)

        # face_node_connectivity is the first parameter of UGrid.__init__; the
        # remaining arguments are forwarded to BaseGrid.__init__.
        # NOTE(review): BaseGrid.__init__ is not visible here - confirm that it
        # accepts (lon, lat, depth, time, ...) in this positional order.
        return UGrid(face_node_connectivity, lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)

    def create_hashgrid(self, hash_cell_size_scalefac=1.0):
        """Create the hashgrid attribute and populate its ugrid_elements lookup table.

        The lookup table relates hash cell indices to the associated ugrid
        elements. The hash cell spacing is based on the median bounding box
        diagonal length of all faces in the mesh.

        Parameters
        ----------
        hash_cell_size_scalefac : float
            Scale factor applied to the hash cell size. Values greater than 1
            result in larger hash cells, likely having more unstructured
            elements per hash cell.

        Raises
        ------
        ValueError
            If the resulting hash cell size is not a positive finite number
            (e.g. no faces, or all faces degenerate).
        """
        vertices = np.column_stack((self.lon, self.lat))

        # Vertex coordinates per face, shape (n_face, nodes_per_element, 2).
        face_vertices = vertices[self.face_node_connectivity]

        # Bounding-box extent of every face, then its diagonal length.
        extents = face_vertices.max(axis=1) - face_vertices.min(axis=1)
        diagonals = np.hypot(extents[:, 0], extents[:, 1])

        # Use the median diagonal as a basis for the cell size.
        dh = np.median(diagonals) * hash_cell_size_scalefac
        if not np.isfinite(dh) or dh <= 0:
            raise ValueError("Cannot build a hash grid: the hash cell size must be a positive finite number.")

        x0 = self.lon.min()
        y0 = self.lat.min()
        nx = int((self.lon.max() - x0) / dh) + 1
        ny = int((self.lat.max() - y0) / dh) + 1

        hashgrid = HashGrid(x0, y0, dh, dh, nx, ny)
        hashgrid.populate_ugrid_elements(vertices, self.face_node_connectivity)
        self._hashgrid = hashgrid

    def barycentric_coordinates(self, xP, yP, triangle_vertices=None):
        """Compute the barycentric coordinates of a particle in a triangular element.

        Parameters
        ----------
        xP, yP : float
            The coordinates of the particle.
        triangle_vertices : array-like of shape (3, 2), optional
            The (x, y) vertices of the triangle. When omitted, the first three
            nodes of the grid are used.
            NOTE(review): that fallback only keeps the two-argument call form
            working; callers presumably want to pass the vertices of a
            specific element - confirm.

        Returns
        -------
        tuple
            ``(l1, l2, l3, inside_triangle)``: the barycentric coordinates and
            True if the point is inside the triangle, False otherwise.
        """
        if triangle_vertices is None:
            xv = np.squeeze(self.lon)[:3]
            yv = np.squeeze(self.lat)[:3]
        else:
            corners = np.asarray(triangle_vertices)
            xv = corners[:, 0]
            yv = corners[:, 1]

        # Twice the signed areas of the full triangle and of the three
        # sub-triangles spanned by the particle and each edge.
        A_ABC = xv[0] * (yv[1] - yv[2]) + xv[1] * (yv[2] - yv[0]) + xv[2] * (yv[0] - yv[1])
        A_BCP = xv[1] * (yv[2] - yP) + xv[2] * (yP - yv[1]) + xP * (yv[1] - yv[2])
        A_CAP = xv[2] * (yv[0] - yP) + xv[0] * (yP - yv[2]) + xP * (yv[2] - yv[0])
        A_ABP = xv[0] * (yv[1] - yP) + xv[1] * (yP - yv[0]) + xP * (yv[0] - yv[1])

        l1 = A_BCP / A_ABC
        l2 = A_CAP / A_ABC
        l3 = A_ABP / A_ABC

        inside_triangle = all(0.0 <= weight <= 1.0 for weight in (l1, l2, l3))

        return l1, l2, l3, inside_triangle