From 24b72af4a245067bd61dab5e96a0c35220b4820a Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 5 Nov 2024 10:21:32 -0500 Subject: [PATCH 1/3] Draft initial basegrid setup Properties and procedures that are common to unstructured grids and structured grids are defined in the BaseGrid class. This includes lon, lat, time, time_origin, and mesh properties. The Grid class is now defined as a type extension of BaseGrid. This configuration is compatible with the previous commit of Parcels (all existing tests pass). The _zonal_periodic, _zonal_halo, _meridional_halo, and _lat_flipped properties are seen as attributes specific to the structured grid. The UGrid class is defined as a type extension of the BaseGrid. This child class adds the `face_node_connectivity` property that relates vertices to the 2-D lateral grid faces. The lon and lat attributes refer to the `node_lon` and `node_lat` that are defined in the uxarray.Grid structure; these are the corner node vertices. Note that there are other possible nodes, including the edge vertices (centers of the edges), and `face_vertices` that define the element centroids; these are currently not defined. --- parcels/basegrid.py | 228 ++++++++++++++++++++++++++++++++++++++++++++ parcels/grid.py | 198 +------------------------------------- parcels/ugrid.py | 87 +++++++++++++++++ 3 files changed, 320 insertions(+), 193 deletions(-) create mode 100644 parcels/basegrid.py create mode 100644 parcels/ugrid.py diff --git a/parcels/basegrid.py b/parcels/basegrid.py new file mode 100644 index 0000000000..d02c1fdf92 --- /dev/null +++ b/parcels/basegrid.py @@ -0,0 +1,228 @@ +import functools +import warnings +from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer +from enum import IntEnum +from abc import ABC, abstractmethod + +import numpy as np +import numpy.typing as npt + +from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import TimeConverter +from parcels.tools.warnings import FieldSetWarning + +class CGrid(Structure): + _fields_ = [("gtype", c_int), ("grid", c_void_p)] + +class BaseGrid(ABC): + """Abstract Base Class for Grid types.""" + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None, + time_origin: TimeConverter | None, + mesh: Mesh, + ): + self._ti = -1 + self._update_status: UpdateStatus | None = None + if not lon.flags["C_CONTIGUOUS"]: + lon = np.array(lon, order="C") + if not lat.flags["C_CONTIGUOUS"]: + lat = np.array(lat, order="C") + time = np.zeros(1, dtype=np.float64) if time is None else time + if not time.flags["C_CONTIGUOUS"]: + time = np.array(time, order="C") + if not lon.dtype == np.float32: + lon = lon.astype(np.float32) + if not lat.dtype == np.float32: + lat = lat.astype(np.float32) + if not time.dtype == np.float64: + assert isinstance( + time[0], (np.integer, np.floating, float, int) + ), "Time vector must be an array of int or floats" + time = time.astype(np.float64) + + self._lon = lon + self._lat = lat + self.time = time + self.time_full = self.time # needed for deferred_loaded Fields + self._time_origin = TimeConverter() if time_origin is None else time_origin + assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" + assert_valid_mesh(mesh) + self._mesh = mesh + self._cstruct = None + self._cell_edge_sizes: dict[str, npt.NDArray] = {} + # self._zonal_periodic = False + # self._zonal_halo = 0 + # self._meridional_halo = 0 + # self._lat_flipped = False + self._defer_load = False + self._lonlat_minmax = np.array( + [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 + ) + self.periods = 0 + self._load_chunk: npt.NDArray = np.array([]) + self.chunk_info = None + self.chunksize = None + self._add_last_periodic_data_timestep = False + self.depth_field = None + + def __repr__(self): + with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}): + return ( + f"{type(self).__name__}(" + f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, " + f"time_origin={self.time_origin!r}, mesh={self.mesh!r})" + ) + + @property + def lon(self): + return self._lon + + @property + def lat(self): + return self._lat + + @property + def depth(self): + return self._depth + + @property + def mesh(self): + return self._mesh + + @property + def lonlat_minmax(self): + return self._lonlat_minmax + + @property + def cell_edge_sizes(self): + return self._cell_edge_sizes + + @property + def defer_load(self): + return self._defer_load + + @property + def time_origin(self): + return self._time_origin + + @property + def ctypes_struct(self): + # This is unnecessary for the moment, but it could be useful when going will fully unstructured grids + self._cgrid = cast(pointer(self._child_ctypes_struct), c_void_p) + cstruct = CGrid(self._gtype, self._cgrid.value) + return cstruct + + @property + @abstractmethod + def _child_ctypes_struct(self): + pass + + @abstractmethod + def lon_grid_to_target(self): + pass + + @abstractmethod + def lon_grid_to_source(self): + pass + + @abstractmethod + def lon_particle_to_target(self, lon): + pass + + def _computeTimeChunk(self, f, time, signdt): + nextTime_loc = np.inf if signdt >= 0 else -np.inf + periods = self.periods.value if isinstance(self.periods, c_int) else self.periods + prev_time_indices = self.time + if self._update_status == "not_updated": + if self._ti >= 0: + if ( + time - periods * (self.time_full[-1] - self.time_full[0]) < self.time[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time[1] + ): + self._ti = -1 # reset + elif signdt >= 0 and ( + time - periods * (self.time_full[-1] - self.time_full[0]) < self.time_full[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time_full[-1] + ): + self._ti = -1 # reset + elif signdt < 0 and ( + time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time_full[0] + or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time_full[-1] + ): + self._ti = -1 # reset + elif ( + signdt >= 0 + and time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time[1] + and self._ti < len(self.time_full) - 2 + ): + self._ti += 1 + self.time = self.time_full[self._ti : self._ti + 2] + self._update_status = "updated" + elif ( + signdt < 0 + and time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time[0] + and self._ti > 0 + ): + self._ti -= 1 + self.time = self.time_full[self._ti : self._ti + 2] + self._update_status = "updated" + if self._ti == -1: + self.time = self.time_full + self._ti, _ = f._time_index(time) + periods = self.periods.value if isinstance(self.periods, c_int) else self.periods + if ( + signdt == -1 + and self._ti == 0 + and (time - periods * (self.time_full[-1] - self.time_full[0])) == self.time[0] + and f.time_periodic + ): + self._ti = len(self.time) - 1 + periods -= 1 + if signdt == -1 and self._ti > 0 and self.time_full[self._ti] == time: + self._ti -= 1 + if self._ti >= len(self.time_full) - 1: + self._ti = len(self.time_full) - 2 + + self.time = self.time_full[self._ti : self._ti + 2] + self.tdim = 2 + if prev_time_indices is None or len(prev_time_indices) != 2 or len(prev_time_indices) != len(self.time): + self._update_status = "first_updated" + elif functools.reduce( + lambda i, j: i and j, map(lambda m, k: m == k, self.time, prev_time_indices), True + ) and len(prev_time_indices) == len(self.time): + self._update_status = "not_updated" + elif functools.reduce( + lambda i, j: i and j, map(lambda m, k: m == k, self.time[:1], prev_time_indices[:1]), True + ) and len(prev_time_indices) == len(self.time): + self._update_status = "updated" + else: + self._update_status = "first_updated" + if signdt >= 0 and (self._ti < len(self.time_full) - 2 or not f.allow_time_extrapolation): + nextTime_loc = self.time[1] + periods * (self.time_full[-1] - self.time_full[0]) + elif signdt < 0 and (self._ti > 0 or not f.allow_time_extrapolation): + nextTime_loc = self.time[0] + periods * (self.time_full[-1] - self.time_full[0]) + return nextTime_loc + + @property + def _chunk_not_loaded(self): + return 0 + + @property + def _chunk_loading_requested(self): + return 1 + + @property + def _chunk_loaded_touched(self): + return 2 + + @property + def _chunk_deprecated(self): + return 3 + + @property + def _chunk_loaded(self): + return [2, 3] \ No newline at end of file diff --git a/parcels/grid.py b/parcels/grid.py index 0d8745a096..4a60d9b794 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -10,6 +10,8 @@ from parcels.tools._helpers import deprecated_made_private from parcels.tools.converters import TimeConverter from parcels.tools.warnings import FieldSetWarning +from parcels.basegrid import BaseGrid +from parcels.basegrid import CGrid __all__ = [ "GridType", @@ -22,7 +24,6 @@ "Grid", ] - class GridType(IntEnum): RectilinearZGrid = 0 RectilinearSGrid = 1 @@ -35,101 +36,20 @@ class GridType(IntEnum): GridCode = GridType -class CGrid(Structure): - _fields_ = [("gtype", c_int), ("grid", c_void_p)] - - -class Grid: +class Grid(BaseGrid): """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" - def __init__( - self, - lon: npt.NDArray, - lat: npt.NDArray, - time: npt.NDArray | None, - time_origin: TimeConverter | None, - mesh: Mesh, - ): - self._ti = -1 - self._update_status: UpdateStatus | None = None - if not lon.flags["C_CONTIGUOUS"]: - lon = np.array(lon, order="C") - if not lat.flags["C_CONTIGUOUS"]: - lat = np.array(lat, order="C") - time = np.zeros(1, dtype=np.float64) if time is None else time - if not time.flags["C_CONTIGUOUS"]: - time = np.array(time, order="C") - if not lon.dtype == np.float32: - lon = lon.astype(np.float32) - if not lat.dtype == np.float32: - lat = lat.astype(np.float32) - if not time.dtype == np.float64: - assert isinstance( - time[0], (np.integer, np.floating, float, int) - ), "Time vector must be an array of int or floats" - time = time.astype(np.float64) - - self._lon = lon - self._lat = lat - self.time = time - self.time_full = self.time # needed for deferred_loaded Fields - self._time_origin = TimeConverter() if time_origin is None else time_origin - assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" - assert_valid_mesh(mesh) - self._mesh = mesh - self._cstruct = None - self._cell_edge_sizes: dict[str, npt.NDArray] = {} + def __init__(self,*args, **kwargs): + super().__init__(*args, **kwargs) self._zonal_periodic = False self._zonal_halo = 0 self._meridional_halo = 0 self._lat_flipped = False - self._defer_load = False - self._lonlat_minmax = np.array( - [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 - ) - self.periods = 0 - self._load_chunk: npt.NDArray = np.array([]) - self.chunk_info = None - self.chunksize = None - self._add_last_periodic_data_timestep = False - self.depth_field = None - - def __repr__(self): - with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}): - return ( - f"{type(self).__name__}(" - f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, " - f"time_origin={self.time_origin!r}, mesh={self.mesh!r})" - ) - - @property - def lon(self): - return self._lon - - @property - def lat(self): - return self._lat - - @property - def depth(self): - return self._depth - - @property - def mesh(self): - return self._mesh @property def meridional_halo(self): return self._meridional_halo - @property - def lonlat_minmax(self): - return self._lonlat_minmax - - @property - def time_origin(self): - return self._time_origin - @property def zonal_periodic(self): return self._zonal_periodic @@ -138,14 +58,6 @@ def zonal_periodic(self): def zonal_halo(self): return self._zonal_halo - @property - def defer_load(self): - return self._defer_load - - @property - def cell_edge_sizes(self): - return self._cell_edge_sizes - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def ti(self): @@ -213,12 +125,6 @@ def create_grid( else: return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) - @property - def ctypes_struct(self): - # This is unnecessary for the moment, but it could be useful when going will fully unstructured grids - self._cgrid = cast(pointer(self._child_ctypes_struct), c_void_p) - cstruct = CGrid(self._gtype, self._cgrid.value) - return cstruct @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -340,125 +246,31 @@ def _add_Sdepth_periodic_halo(self, zonal, meridional, halosize): def computeTimeChunk(self, *args, **kwargs): return self._computeTimeChunk(*args, **kwargs) - def _computeTimeChunk(self, f, time, signdt): - nextTime_loc = np.inf if signdt >= 0 else -np.inf - periods = self.periods.value if isinstance(self.periods, c_int) else self.periods - prev_time_indices = self.time - if self._update_status == "not_updated": - if self._ti >= 0: - if ( - time - periods * (self.time_full[-1] - self.time_full[0]) < self.time[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time[1] - ): - self._ti = -1 # reset - elif signdt >= 0 and ( - time - periods * (self.time_full[-1] - self.time_full[0]) < self.time_full[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time_full[-1] - ): - self._ti = -1 # reset - elif signdt < 0 and ( - time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time_full[0] - or time - periods * (self.time_full[-1] - self.time_full[0]) > self.time_full[-1] - ): - self._ti = -1 # reset - elif ( - signdt >= 0 - and time - periods * (self.time_full[-1] - self.time_full[0]) >= self.time[1] - and self._ti < len(self.time_full) - 2 - ): - self._ti += 1 - self.time = self.time_full[self._ti : self._ti + 2] - self._update_status = "updated" - elif ( - signdt < 0 - and time - periods * (self.time_full[-1] - self.time_full[0]) <= self.time[0] - and self._ti > 0 - ): - self._ti -= 1 - self.time = self.time_full[self._ti : self._ti + 2] - self._update_status = "updated" - if self._ti == -1: - self.time = self.time_full - self._ti, _ = f._time_index(time) - periods = self.periods.value if isinstance(self.periods, c_int) else self.periods - if ( - signdt == -1 - and self._ti == 0 - and (time - periods * (self.time_full[-1] - self.time_full[0])) == self.time[0] - and f.time_periodic - ): - self._ti = len(self.time) - 1 - periods -= 1 - if signdt == -1 and self._ti > 0 and self.time_full[self._ti] == time: - self._ti -= 1 - if self._ti >= len(self.time_full) - 1: - self._ti = len(self.time_full) - 2 - - self.time = self.time_full[self._ti : self._ti + 2] - self.tdim = 2 - if prev_time_indices is None or len(prev_time_indices) != 2 or len(prev_time_indices) != len(self.time): - self._update_status = "first_updated" - elif functools.reduce( - lambda i, j: i and j, map(lambda m, k: m == k, self.time, prev_time_indices), True - ) and len(prev_time_indices) == len(self.time): - self._update_status = "not_updated" - elif functools.reduce( - lambda i, j: i and j, map(lambda m, k: m == k, self.time[:1], prev_time_indices[:1]), True - ) and len(prev_time_indices) == len(self.time): - self._update_status = "updated" - else: - self._update_status = "first_updated" - if signdt >= 0 and (self._ti < len(self.time_full) - 2 or not f.allow_time_extrapolation): - nextTime_loc = self.time[1] + periods * (self.time_full[-1] - self.time_full[0]) - elif signdt < 0 and (self._ti > 0 or not f.allow_time_extrapolation): - nextTime_loc = self.time[0] + periods * (self.time_full[-1] - self.time_full[0]) - return nextTime_loc - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_not_loaded(self): return self._chunk_not_loaded - @property - def _chunk_not_loaded(self): - return 0 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_loading_requested(self): return self._chunk_loading_requested - @property - def _chunk_loading_requested(self): - return 1 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_loaded_touched(self): return self._chunk_loaded_touched - @property - def _chunk_loaded_touched(self): - return 2 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_deprecated(self): return self._chunk_deprecated - @property - def _chunk_deprecated(self): - return 3 - @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def chunk_loaded(self): return self._chunk_loaded - @property - def _chunk_loaded(self): - return [2, 3] - class RectilinearGrid(Grid): """Rectilinear Grid class diff --git a/parcels/ugrid.py b/parcels/ugrid.py new file mode 100644 index 0000000000..896dacdc39 --- /dev/null +++ b/parcels/ugrid.py @@ -0,0 +1,87 @@ +import functools +import warnings +from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer +from enum import IntEnum + +import numpy as np +import numpy.typing as npt + +from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import TimeConverter +from parcels.tools.warnings import FieldSetWarning +from parcels.basegrid import BaseGrid +from parcels.basegrid import CGrid + +# Note : +# For variable placement in FESOM - see https://fesom2.readthedocs.io/en/latest/geometry.html +__all__ = [ + "UGridType", + "CGrid", + "UGrid" +] + +class UGrid(BaseGrid): + """Grid class that defines a (spatial and temporal) unstructured grid on which Fields are defined.""" + + def __init__(self,face_node_connectivity,*args, **kwargs): + self._face_node_connectivity = face_node_connectivity # nface_node_connectivity x n_nodes_per_element array listing the vertex ids of each face_node_connectivity. + super().__init__(*args, **kwargs) + + if not isinstance(self.lon, np.ndarray): + raise TypeError("lon must be a NumPy array.") + + if self.lon.ndim != 1: + raise ValueError("lon must be a 1-D array.") + + if not isinstance(self.lat, np.ndarray): + raise TypeError("lat must be a NumPy array.") + + if self.lat.ndim != 1: + raise ValueError("lat must be a 1-D array.") + + if self.lon.shape != self.lat.shape: + raise ValueError("lon and lat must have the same shape.") + + if not isinstance(self.face_node_connectivity, np.ndarray): + raise TypeError("face_node_connectivity must be a NumPy array.") + + if self.face_node_connectivity.ndim != 2: + raise ValueError("face_node_connectivity must be a 2-D array.") + + if self.face_node_connectivity.shape[1] > 3: # Enforce triangle face_node_connectivity + raise ValueError("face_node_connectivity must be a 2-D array with at most 3 columns.") + + self._n_face = self.face_node_connectivity.shape[0] + self._nodes_per_element = face_node_connectivity.shape[1] + self._n_vertices = self.lon.shape[0] + + # The lon and lat fields are assumed identical to the + # node_lon and node_lat fields in UXArray.Grid + # data structure. + @property + def node_lon(self): + return self.lon + + @property + def node_lat(self): + return self.lat + + @property + def face_node_connectivity(self): + return self._face_node_connectivity + + @property + def n_face(self): + return self._n_face + + @property + def nodes_per_element(self): + return self._nodes_per_element + + @property + def n_vertices(self): + return self._n_vertices + + + From 47ffe4a2096cacffd3ff57461c53b94a887c4792 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 8 Nov 2024 10:31:28 -0500 Subject: [PATCH 2/3] Add UField class to support fields on unstructured grids --- parcels/_typing.py | 2 +- parcels/basefield.py | 420 +++++++++++++++++++++++++++++++++++++++++++ parcels/field.py | 263 +-------------------------- parcels/ufield.py | 259 ++++++++++++++++++++++++++ parcels/ugrid.py | 20 ++- 5 files changed, 703 insertions(+), 261 deletions(-) create mode 100644 parcels/basefield.py create mode 100644 parcels/ufield.py diff --git a/parcels/_typing.py b/parcels/_typing.py index 92af7ba2fa..9e530c6be7 100644 --- a/parcels/_typing.py +++ b/parcels/_typing.py @@ -37,7 +37,7 @@ class ParcelsAST(ast.AST): Mesh = Literal["spherical", "flat"] # corresponds with `mesh` VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type` ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode` -GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo","fesom2","icon2"] # corresponds with `gridindexingtype` UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status` TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `time_periodic` NetcdfEngine = Literal["netcdf4", "xarray", "scipy"] diff --git a/parcels/basefield.py b/parcels/basefield.py new file mode 100644 index 0000000000..8ac4c41683 --- /dev/null +++ b/parcels/basefield.py @@ -0,0 +1,420 @@ +import collections +import datetime +import math +import warnings +from collections.abc import Iterable +from ctypes import POINTER, Structure, c_float, c_int, pointer +from pathlib import Path +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np +import xarray as xr + +import parcels.tools.interpolation_utils as i_u +from parcels._typing import ( + GridIndexingType, + InterpMethod, + Mesh, + TimePeriodic, + VectorType, + assert_valid_gridindexingtype, + assert_valid_interp_method, +) +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import ( + Geographic, + GeographicPolar, + TimeConverter, + UnitConverter, + unitconverters_map, +) +from parcels.tools.statuscodes import ( + AllParcelsErrorCodes, + FieldOutOfBoundError, + FieldOutOfBoundSurfaceError, + FieldSamplingError, + TimeExtrapolationError, +) +from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning + +from .fieldfilebuffer import ( + DaskFileBuffer, + DeferredDaskFileBuffer, + DeferredNetcdfFileBuffer, + NetcdfFileBuffer, +) +from .grid import CGrid, Grid, GridType +from .ugrid import UGrid + +if TYPE_CHECKING: + from ctypes import _Pointer as PointerType + + from parcels.fieldset import FieldSet + +__all__ = ["Field", "VectorField", "NestedField"] + + +def _isParticle(key): + if hasattr(key, "obs_written"): + return True + else: + return False + + +def _deal_with_errors(error, key, vector_type: VectorType): + if _isParticle(key): + key.state = AllParcelsErrorCodes[type(error)] + elif _isParticle(key[-1]): + key[-1].state = AllParcelsErrorCodes[type(error)] + else: + raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.") + + if vector_type == "3D": + return (0, 0, 0) + elif vector_type == "2D": + return (0, 0) + else: + return 0 + + +class BaseField: + """Class that encapsulates access to field data. + + Parameters + ---------- + name : str + Name of the field + data : np.ndarray + 2D, 3D or 4D numpy array of field data. + + 1. If data shape is [xdim, ydim], [xdim, ydim, zdim], [xdim, ydim, tdim] or [xdim, ydim, zdim, tdim], + whichever is relevant for the dataset, use the flag transpose=True + 2. If data shape is [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], + use the flag transpose=False + 3. If data has any other shape, you first need to reorder it + lon : np.ndarray or list + Longitude coordinates (numpy vector or array) of the field (only if grid is None) + lat : np.ndarray or list + Latitude coordinates (numpy vector or array) of the field (only if grid is None) + depth : np.ndarray or list + Depth coordinates (numpy vector or array) of the field (only if grid is None) + time : np.ndarray + Time coordinates (numpy vector) of the field (only if grid is None) + mesh : str + String indicating the type of mesh coordinates and + units used during velocity interpolation: (only if grid is None) + + 1. spherical: Lat and lon in degree, with a + correction for zonal velocity U near the poles. + 2. flat (default): No conversion, lat/lon are assumed to be in m. + timestamps : np.ndarray + A numpy array containing the timestamps for each of the files in filenames, for loading + from netCDF files only. Default is None if the netCDF dimensions dictionary includes time. + grid : parcels.grid.Grid + :class:`parcels.grid.Grid` object containing all the lon, lat depth, time + mesh and time_origin information. Can be constructed from any of the Grid objects + fieldtype : str + Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) + transpose : bool + Transpose data to required (lon, lat) layout + vmin : float + Minimum allowed value on the field. Data below this value are set to zero + vmax : float + Maximum allowed value on the field. Data above this value are set to zero + cast_data_dtype : str + Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64 (np.float64). + Note that dtype can only be "float32" in JIT mode + time_origin : parcels.tools.converters.TimeConverter + Time origin of the time axis (only if grid is None) + interp_method : str + Method for interpolation. Options are 'linear' (default), 'nearest', + 'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity' + allow_time_extrapolation : bool + boolean whether to allow for extrapolation in time + (i.e. beyond the last available time snapshot) + time_periodic : bool, float or datetime.timedelta + To loop periodically over the time component of the Field. It is set to either False or the length of the period (either float in seconds or datetime.timedelta object). + The last value of the time series can be provided (which is the same as the initial one) or not (Default: False) + This flag overrides the allow_time_extrapolation and sets it to False + chunkdims_name_map : str, optional + Gives a name map to the FieldFileBuffer that declared a mapping between chunksize name, NetCDF dimension and Parcels dimension; + required only if currently incompatible OCM field is loaded and chunking is used by 'chunksize' (which is the default) + to_write : bool + Write the Field in NetCDF format at the same frequency as the ParticleFile outputdt, + using a filenaming scheme based on the ParticleFile name + + Examples + -------- + For usage examples see the following tutorials: + + * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__ + """ + + def __init__( + self, + name: str | tuple[str, str], + data, + lon=None, + lat=None, + face_node_connectivity=None, + depth=None, + time=None, + grid=None, + mesh: Mesh = "flat", + timestamps=None, + fieldtype=None, + transpose=False, + vmin=None, + vmax=None, + cast_data_dtype="float32", + time_origin=None, + interp_method: InterpMethod = "linear", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + gridindexingtype: GridIndexingType = "nemo", + to_write=False, + **kwargs, + ): + if kwargs.get("netcdf_decodewarning") is not None: + _deprecated_param_netcdf_decodewarning() + kwargs.pop("netcdf_decodewarning") + + if not isinstance(name, tuple): + self.name = name + self.filebuffername = name + else: + self.name = name[0] + self.filebuffername = name[1] + self.data = data + if grid: + if grid.defer_load and isinstance(data, np.ndarray): + raise ValueError( + "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately" + ) + self._grid = grid + else: + if (time is not None) and isinstance(time[0], np.datetime64): + time_origin = TimeConverter(time[0]) + time = np.array([time_origin.reltime(t) for t in time]) + else: + time_origin = TimeConverter(0) + + # joe@fluidnumerics.com + # This allows for the creation of a UGrid object + if face_node_connectivity is None: + self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) + else: + self._grid = UGrid.create_grid(lon, lat, depth, time, face_node_connectivity, time_origin=time_origin, mesh=mesh) + + + self.igrid = -1 + self.fieldtype = self.name if fieldtype is None else fieldtype + self.to_write = to_write + if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): + self.units = UnitConverter() + elif self.grid.mesh == "spherical": + self.units = unitconverters_map[self.fieldtype] + else: + raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") + self.timestamps = timestamps + if isinstance(interp_method, dict): + if self.name in interp_method: + self.interp_method = interp_method[self.name] + else: + raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") + else: + self.interp_method = interp_method + assert_valid_gridindexingtype(gridindexingtype) + self._gridindexingtype = gridindexingtype + if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid._gtype in [ + GridType.RectilinearSGrid, + GridType.CurvilinearSGrid, + ]: + warnings.warn( + "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.", + FieldSetWarning, + stacklevel=2, + ) + + self.fieldset: FieldSet | None = None + if allow_time_extrapolation is None: + self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False + else: + self.allow_time_extrapolation = allow_time_extrapolation + + self.time_periodic = time_periodic + if self.time_periodic is not False and self.allow_time_extrapolation: + warnings.warn( + "allow_time_extrapolation and time_periodic cannot be used together. allow_time_extrapolation is set to False", + FieldSetWarning, + stacklevel=2, + ) + self.allow_time_extrapolation = False + if self.time_periodic is True: + raise ValueError( + "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." + ) + if self.time_periodic is not False: + if isinstance(self.time_periodic, datetime.timedelta): + self.time_periodic = self.time_periodic.total_seconds() + if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): + if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: + raise ValueError("Time series provided is longer than the time_periodic parameter") + self.grid._add_last_periodic_data_timestep = True + self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) + self.grid.time_full = self.grid.time + + self.vmin = vmin + self.vmax = vmax + self._cast_data_dtype = cast_data_dtype + if self.cast_data_dtype == "float32": + self._cast_data_dtype = np.float32 + elif self.cast_data_dtype == "float64": + self._cast_data_dtype = np.float64 + + if not self.grid.defer_load: + self.data = self._reshape(self.data, transpose) + + # Hack around the fact that NaN and ridiculously large values + # propagate in SciPy's interpolators + lib = np if isinstance(self.data, np.ndarray) else da + self.data[lib.isnan(self.data)] = 0.0 + if self.vmin is not None: + self.data[self.data < self.vmin] = 0.0 + if self.vmax is not None: + self.data[self.data > self.vmax] = 0.0 + + if self.grid._add_last_periodic_data_timestep: + self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) + + self._scaling_factor = None + + # Variable names in JIT code + self._dimensions = kwargs.pop("dimensions", None) + self.indices = kwargs.pop("indices", None) + self._dataFiles = kwargs.pop("dataFiles", None) + if self.grid._add_last_periodic_data_timestep and self._dataFiles is not None: + self._dataFiles = np.append(self._dataFiles, self._dataFiles[0]) + self._field_fb_class = kwargs.pop("FieldFileBuffer", None) + self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") + self._loaded_time_indices: Iterable[int] = [] # type: ignore + self._creation_log = kwargs.pop("creation_log", "") + self.chunksize = kwargs.pop("chunksize", None) + self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) + self.grid.depth_field = kwargs.pop("depth_field", None) + + if self.grid.depth_field == "not_yet_set": + assert ( + self.grid._z4d + ), "Providing the depth dimensions from another field data is only available for 4d S grids" + + # data_full_zdim is the vertical dimension of the complete field data, ignoring the indices. + # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, + # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). + self.data_full_zdim = kwargs.pop("data_full_zdim", None) + self._data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays + self._c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array + self.nchunks: tuple[int, ...] = () + self._chunk_set: bool = False + self.filebuffers = [None] * 2 + if len(kwargs) > 0: + raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') + + @property + def dimensions(self): + return self._dimensions + + @property + def grid(self): + return self._grid + + @property + def lon(self): + """Lon defined on the Grid object""" + return self.grid.lon + + @property + def lat(self): + """Lat defined on the Grid object""" + return self.grid.lat + + @property + def depth(self): + """Depth defined on the Grid object""" + return self.grid.depth + + @property + def cell_edge_sizes(self): + return self.grid.cell_edge_sizes + + @property + def interp_method(self): + return self._interp_method + + @interp_method.setter + def interp_method(self, value): + assert_valid_interp_method(value) + self._interp_method = value + + @property + def gridindexingtype(self): + return self._gridindexingtype + + @property + def cast_data_dtype(self): + return self._cast_data_dtype + + @property + def netcdf_engine(self): + return self._netcdf_engine + + @classmethod + def _get_dim_filenames(cls, filenames, dim): + if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): + return [filenames] + elif isinstance(filenames, dict): + assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" + filename = filenames[dim] + if isinstance(filename, str): + return [filename] + else: + return filename + else: + return filenames + + @staticmethod + def _collect_timeslices( + timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None + ): + if netcdf_decodewarning is not None: + _deprecated_param_netcdf_decodewarning() + if timestamps is not None: + dataFiles = [] + for findex in range(len(data_filenames)): + stamps_in_file = 1 if isinstance(timestamps[findex], (int, np.datetime64)) else len(timestamps[findex]) + for f in [data_filenames[findex]] * stamps_in_file: + dataFiles.append(f) + timeslices = np.array([stamp for file in timestamps for stamp in file]) + time = timeslices + else: + timeslices = [] + dataFiles = [] + for fname in data_filenames: + with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer: + ftime = filebuffer.time + timeslices.append(ftime) + dataFiles.append([fname] * len(ftime)) + time = np.concatenate(timeslices).ravel() + dataFiles = np.concatenate(dataFiles).ravel() + if time.size == 1 and time[0] is None: + time[0] = 0 + time_origin = TimeConverter(time[0]) + time = time_origin.reltime(time) + + if not np.all((time[1:] - time[:-1]) > 0): + id_not_ordered = np.where(time[1:] < time[:-1])[0][0] + raise AssertionError( + f"Please make sure your netCDF files are ordered in time. First pair of non-ordered files: {dataFiles[id_not_ordered]}, {dataFiles[id_not_ordered + 1]}" + ) + return time, time_origin, timeslices, dataFiles \ No newline at end of file diff --git a/parcels/field.py b/parcels/field.py index df3f6b991e..a1c6799ae8 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -37,6 +37,7 @@ TimeExtrapolationError, ) from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning +from parcels.basefield import BaseField from .fieldfilebuffer import ( DaskFileBuffer, @@ -77,7 +78,7 @@ def _deal_with_errors(error, key, vector_type: VectorType): return 0 -class Field: +class Field(BaseField): """Class that encapsulates access to field data. Parameters @@ -150,166 +151,8 @@ class Field: * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__ """ - def __init__( - self, - name: str | tuple[str, str], - data, - lon=None, - lat=None, - depth=None, - time=None, - grid=None, - mesh: Mesh = "flat", - timestamps=None, - fieldtype=None, - transpose=False, - vmin=None, - vmax=None, - cast_data_dtype="float32", - time_origin=None, - interp_method: InterpMethod = "linear", - allow_time_extrapolation: bool | None = None, - time_periodic: TimePeriodic = False, - gridindexingtype: GridIndexingType = "nemo", - to_write=False, - **kwargs, - ): - if kwargs.get("netcdf_decodewarning") is not None: - _deprecated_param_netcdf_decodewarning() - kwargs.pop("netcdf_decodewarning") - - if not isinstance(name, tuple): - self.name = name - self.filebuffername = name - else: - self.name = name[0] - self.filebuffername = name[1] - self.data = data - if grid: - if grid.defer_load and isinstance(data, np.ndarray): - raise ValueError( - "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately" - ) - self._grid = grid - else: - if (time is not None) and isinstance(time[0], np.datetime64): - time_origin = TimeConverter(time[0]) - time = np.array([time_origin.reltime(t) for t in time]) - else: - time_origin = TimeConverter(0) - self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) - self.igrid = -1 - self.fieldtype = self.name if fieldtype is None else fieldtype - self.to_write = to_write - if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): - self.units = UnitConverter() - elif self.grid.mesh == "spherical": - self.units = unitconverters_map[self.fieldtype] - else: - raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") - self.timestamps = timestamps - if isinstance(interp_method, dict): - if self.name in interp_method: - self.interp_method = interp_method[self.name] - else: - raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") - else: - self.interp_method = interp_method - assert_valid_gridindexingtype(gridindexingtype) - self._gridindexingtype = gridindexingtype - if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid._gtype in [ - GridType.RectilinearSGrid, - GridType.CurvilinearSGrid, - ]: - warnings.warn( - "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.", - FieldSetWarning, - stacklevel=2, - ) - - self.fieldset: FieldSet | None = None - if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False - else: - self.allow_time_extrapolation = allow_time_extrapolation - - self.time_periodic = time_periodic - if self.time_periodic is not False and self.allow_time_extrapolation: - warnings.warn( - "allow_time_extrapolation and time_periodic cannot be used together. allow_time_extrapolation is set to False", - FieldSetWarning, - stacklevel=2, - ) - self.allow_time_extrapolation = False - if self.time_periodic is True: - raise ValueError( - "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." - ) - if self.time_periodic is not False: - if isinstance(self.time_periodic, datetime.timedelta): - self.time_periodic = self.time_periodic.total_seconds() - if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): - if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: - raise ValueError("Time series provided is longer than the time_periodic parameter") - self.grid._add_last_periodic_data_timestep = True - self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) - self.grid.time_full = self.grid.time - - self.vmin = vmin - self.vmax = vmax - self._cast_data_dtype = cast_data_dtype - if self.cast_data_dtype == "float32": - self._cast_data_dtype = np.float32 - elif self.cast_data_dtype == "float64": - self._cast_data_dtype = np.float64 - - if not self.grid.defer_load: - self.data = self._reshape(self.data, transpose) - - # Hack around the fact that NaN and ridiculously large values - # propagate in SciPy's interpolators - lib = np if isinstance(self.data, np.ndarray) else da - self.data[lib.isnan(self.data)] = 0.0 - if self.vmin is not None: - self.data[self.data < self.vmin] = 0.0 - if self.vmax is not None: - self.data[self.data > self.vmax] = 0.0 - - if self.grid._add_last_periodic_data_timestep: - self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) - - self._scaling_factor = None - - # Variable names in JIT code - self._dimensions = kwargs.pop("dimensions", None) - self.indices = kwargs.pop("indices", None) - self._dataFiles = kwargs.pop("dataFiles", None) - if self.grid._add_last_periodic_data_timestep and self._dataFiles is not None: - self._dataFiles = np.append(self._dataFiles, self._dataFiles[0]) - self._field_fb_class = kwargs.pop("FieldFileBuffer", None) - self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") - self._loaded_time_indices: Iterable[int] = [] # type: ignore - self._creation_log = kwargs.pop("creation_log", "") - self.chunksize = kwargs.pop("chunksize", None) - self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) - self.grid.depth_field = kwargs.pop("depth_field", None) - - if self.grid.depth_field == "not_yet_set": - assert ( - self.grid._z4d - ), "Providing the depth dimensions from another field data is only available for 4d S grids" - - # data_full_zdim is the vertical dimension of the complete field data, ignoring the indices. - # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, - # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). - self.data_full_zdim = kwargs.pop("data_full_zdim", None) - self._data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays - self._c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array - self.nchunks: tuple[int, ...] = () - self._chunk_set: bool = False - self.filebuffers = [None] * 2 - if len(kwargs) > 0: - raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) @property @deprecated_made_private # TODO: Remove 6 months after v3.1.0 @@ -341,114 +184,16 @@ def creation_log(self): def loaded_time_indices(self): return self._loaded_time_indices - @property - def dimensions(self): - return self._dimensions - - @property - def grid(self): - return self._grid - - @property - def lon(self): - """Lon defined on the Grid object""" - return self.grid.lon - - @property - def lat(self): - """Lat defined on the Grid object""" - return self.grid.lat - - @property - def depth(self): - """Depth defined on the Grid object""" - return self.grid.depth - - @property - def cell_edge_sizes(self): - return self.grid.cell_edge_sizes - - @property - def interp_method(self): - return self._interp_method - - @interp_method.setter - def interp_method(self, value): - assert_valid_interp_method(value) - self._interp_method = value - - @property - def gridindexingtype(self): - return self._gridindexingtype - - @property - def cast_data_dtype(self): - return self._cast_data_dtype - - @property - def netcdf_engine(self): - return self._netcdf_engine - @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def get_dim_filenames(cls, *args, **kwargs): return cls._get_dim_filenames(*args, **kwargs) - @classmethod - def _get_dim_filenames(cls, filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if isinstance(filename, str): - return [filename] - else: - return filename - else: - return filenames - @staticmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def collect_timeslices(*args, **kwargs): return Field._collect_timeslices(*args, **kwargs) - @staticmethod - def _collect_timeslices( - timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None - ): - if netcdf_decodewarning is not None: - _deprecated_param_netcdf_decodewarning() - if timestamps is not None: - dataFiles = [] - for findex in range(len(data_filenames)): - stamps_in_file = 1 if isinstance(timestamps[findex], (int, np.datetime64)) else len(timestamps[findex]) - for f in [data_filenames[findex]] * stamps_in_file: - dataFiles.append(f) - timeslices = np.array([stamp for file in timestamps for stamp in file]) - time = timeslices - else: - timeslices = [] - dataFiles = [] - for fname in data_filenames: - with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer: - ftime = filebuffer.time - timeslices.append(ftime) - dataFiles.append([fname] * len(ftime)) - time = np.concatenate(timeslices).ravel() - dataFiles = np.concatenate(dataFiles).ravel() - if time.size == 1 and time[0] is None: - time[0] = 0 - time_origin = TimeConverter(time[0]) - time = time_origin.reltime(time) - - if not np.all((time[1:] - time[:-1]) > 0): - id_not_ordered = np.where(time[1:] < time[:-1])[0][0] - raise AssertionError( - f"Please make sure your netCDF files are ordered in time. First pair of non-ordered files: {dataFiles[id_not_ordered]}, {dataFiles[id_not_ordered + 1]}" - ) - return time, time_origin, timeslices, dataFiles - @classmethod def from_netcdf( cls, diff --git a/parcels/ufield.py b/parcels/ufield.py new file mode 100644 index 0000000000..8d433755bb --- /dev/null +++ b/parcels/ufield.py @@ -0,0 +1,259 @@ +import collections +import datetime +import math +import warnings +from collections.abc import Iterable +from ctypes import POINTER, Structure, c_float, c_int, pointer +from pathlib import Path +from typing import TYPE_CHECKING + +import dask.array as da +import numpy as np +import xarray as xr + +import parcels.tools.interpolation_utils as i_u +from parcels._typing import ( + GridIndexingType, + InterpMethod, + Mesh, + TimePeriodic, + VectorType, + assert_valid_gridindexingtype, + assert_valid_interp_method, +) +from parcels.tools._helpers import deprecated_made_private +from parcels.tools.converters import ( + Geographic, + GeographicPolar, + TimeConverter, + UnitConverter, + unitconverters_map, +) +from parcels.tools.statuscodes import ( + AllParcelsErrorCodes, + FieldOutOfBoundError, + FieldOutOfBoundSurfaceError, + FieldSamplingError, + TimeExtrapolationError, +) +from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning + +from .fieldfilebuffer import ( + DaskFileBuffer, + DeferredDaskFileBuffer, + DeferredNetcdfFileBuffer, + NetcdfFileBuffer, +) +from .grid import CGrid, Grid, GridType + +if TYPE_CHECKING: + from ctypes import _Pointer as PointerType + + from parcels.fieldset import FieldSet + + + +__all__ = ["UField", "UVectorField", "UNestedField"] + + +def _isParticle(key): + if hasattr(key, "obs_written"): + return True + else: + return False + + +def _deal_with_errors(error, key, vector_type: VectorType): + if _isParticle(key): + key.state = AllParcelsErrorCodes[type(error)] + elif _isParticle(key[-1]): + key[-1].state = AllParcelsErrorCodes[type(error)] + else: + raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.") + + if vector_type == "3D": + return (0, 0, 0) + elif vector_type == "2D": + return (0, 0) + else: + return 0 + +class UField: + """Class that encapsulates access to field data. + + Parameters + ---------- + name : str + Name of the field + data : np.ndarray + 2D, 3D or 4D numpy array of field data. + + 1. If data shape is [xdim, ydim], [xdim, ydim, zdim], [xdim, ydim, tdim] or [xdim, ydim, zdim, tdim], + whichever is relevant for the dataset, use the flag transpose=True + 2. If data shape is [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], + use the flag transpose=False + 3. If data has any other shape, you first need to reorder it + timestamps : np.ndarray + A numpy array containing the timestamps for each of the files in filenames, for loading + from netCDF files only. Default is None if the netCDF dimensions dictionary includes time. + grid : parcels.ugrid.UGrid + :class:`parcels.ugrid.UGrid` object containing all the lon, lat depth, time + mesh and time_origin information. Can be constructed from any of the Grid objects + fieldtype : str + Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) + transpose : bool + Transpose data to required (lon, lat) layout + vmin : float + Minimum allowed value on the field. Data below this value are set to zero + vmax : float + Maximum allowed value on the field. Data above this value are set to zero + cast_data_dtype : str + Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64 (np.float64). + Note that dtype can only be "float32" in JIT mode + time_origin : parcels.tools.converters.TimeConverter + Time origin of the time axis (only if grid is None) + interp_method : str + Method for interpolation. Options are 'linear' (default), 'nearest', + 'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity' + allow_time_extrapolation : bool + boolean whether to allow for extrapolation in time + (i.e. beyond the last available time snapshot) + time_periodic : bool, float or datetime.timedelta + To loop periodically over the time component of the Field. It is set to either False or the length of the period (either float in seconds or datetime.timedelta object). + The last value of the time series can be provided (which is the same as the initial one) or not (Default: False) + This flag overrides the allow_time_extrapolation and sets it to False + chunkdims_name_map : str, optional + Gives a name map to the FieldFileBuffer that declared a mapping between chunksize name, NetCDF dimension and Parcels dimension; + required only if currently incompatible OCM field is loaded and chunking is used by 'chunksize' (which is the default) + to_write : bool + Write the Field in NetCDF format at the same frequency as the ParticleFile outputdt, + using a filenaming scheme based on the ParticleFile name + + Examples + -------- + + """ + + def __init__( + self, + name: str | tuple[str, str], + data, + grid, + timestamps=None, + fieldtype=None, + transpose=False, + vmin=None, + vmax=None, + cast_data_dtype="float32", + time_origin=None, + interp_method: InterpMethod = "linear", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + gridindexingtype: GridIndexingType = "fesom2", + to_write=False, + **kwargs, + ): + + if kwargs.get("netcdf_decodewarning") is not None: + _deprecated_param_netcdf_decodewarning() + kwargs.pop("netcdf_decodewarning") + + if not isinstance(name, tuple): + self.name = name + self.filebuffername = name + else: + self.name = name[0] + self.filebuffername = name[1] + self.data = data + + # 11/8/2024 - joe@fluidnumerics.com + # This conditional block is commented out since the UGrid class currently does + # not have a create_grid method. + #if grid: + if grid.defer_load and isinstance(data, np.ndarray): + raise ValueError( + "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately" + ) + self._grid = grid + #else: + # if (time is not None) and isinstance(time[0], np.datetime64): + # time_origin = TimeConverter(time[0]) + # time = np.array([time_origin.reltime(t) for t in time]) + # else: + # time_origin = TimeConverter(0) + # self._grid = Grid.create_grid(lon, lat, face_node_connectivity, depth, time, time_origin=time_origin, mesh=mesh) + + self.igrid = -1 + self.fieldtype = self.name if fieldtype is None else fieldtype + self.to_write = to_write + if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): + self.units = UnitConverter() + elif self.grid.mesh == "spherical": + self.units = unitconverters_map[self.fieldtype] + else: + raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") + self.timestamps = timestamps + + if isinstance(interp_method, dict): + if self.name in interp_method: + self.interp_method = interp_method[self.name] + else: + raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") + else: + self.interp_method = interp_method + assert_valid_gridindexingtype(gridindexingtype) + self._gridindexingtype = gridindexingtype + + + self.fieldset: FieldSet | None = None + if allow_time_extrapolation is None: + self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False + else: + self.allow_time_extrapolation = allow_time_extrapolation\ + + self.time_periodic = time_periodic + if self.time_periodic is not False and self.allow_time_extrapolation: + warnings.warn( + "allow_time_extrapolation and time_periodic cannot be used together. allow_time_extrapolation is set to False", + FieldSetWarning, + stacklevel=2, + ) + self.allow_time_extrapolation = False + if self.time_periodic is True: + raise ValueError( + "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." + ) + if self.time_periodic is not False: + if isinstance(self.time_periodic, datetime.timedelta): + self.time_periodic = self.time_periodic.total_seconds() + if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): + if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: + raise ValueError("Time series provided is longer than the time_periodic parameter") + self.grid._add_last_periodic_data_timestep = True + self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) + self.grid.time_full = self.grid.time + + self.vmin = vmin + self.vmax = vmax + self._cast_data_dtype = cast_data_dtype + if self.cast_data_dtype == "float32": + self._cast_data_dtype = np.float32 + elif self.cast_data_dtype == "float64": + self._cast_data_dtype = np.float64 + + if not self.grid.defer_load: + self.data = self._reshape(self.data, transpose) + + # Hack around the fact that NaN and ridiculously large values + # propagate in SciPy's interpolators + lib = np if isinstance(self.data, np.ndarray) else da + self.data[lib.isnan(self.data)] = 0.0 + if self.vmin is not None: + self.data[self.data < self.vmin] = 0.0 + if self.vmax is not None: + self.data[self.data > self.vmax] = 0.0 + + if self.grid._add_last_periodic_data_timestep: + self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) + + self._scaling_factor = None \ No newline at end of file diff --git a/parcels/ugrid.py b/parcels/ugrid.py index 896dacdc39..59ac7991e7 100644 --- a/parcels/ugrid.py +++ b/parcels/ugrid.py @@ -16,7 +16,6 @@ # Note : # For variable placement in FESOM - see https://fesom2.readthedocs.io/en/latest/geometry.html __all__ = [ - "UGridType", "CGrid", "UGrid" ] @@ -85,3 +84,22 @@ def n_vertices(self): + @staticmethod + def create_grid( + lon: npt.ArrayLike, + lat: npt.ArrayLike, + face_node_connectivity: npt.ArrayLike, + depth, + time, + time_origin, + mesh: Mesh, + **kwargs, + ): + lon = np.array(lon) + lat = np.array(lat) + face_node_connectivity = np.array(face_node_connectivity) + + if depth is not None: + depth = np.array(depth) + + return UGrid(lon, lat, face_node_connectivity, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) From 429a591e872885db6a6c173f220bad94e19cb9db Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 25 Nov 2024 10:41:44 -0500 Subject: [PATCH 3/3] Add hashgrid class with setup methods. Add barycentric coordinate calc for ugrid In the ufield class, I've sketched out the methods that we need to define in order to do 2D and 3D interpolation. Once these are in place, I'll kick off some examples that "regrid" onto structured grid and interpolate onto random particles. --- parcels/hashgrid.py | 77 ++++++++++++++ parcels/ufield.py | 253 ++++++++++++++++++++++---------------------- parcels/ugrid.py | 78 +++++++++++++- 3 files changed, 281 insertions(+), 127 deletions(-) create mode 100644 parcels/hashgrid.py diff --git a/parcels/hashgrid.py b/parcels/hashgrid.py new file mode 100644 index 0000000000..eb402f1601 --- /dev/null +++ b/parcels/hashgrid.py @@ -0,0 +1,77 @@ + + +class HashGrid: + """Class for creating a hash grid for fast particle lookup. + + Parameters + ---------- + x0 : float + x-coordinate of the lower left corner of the grid. + y0 : float + y-coordinate of the lower left corner of the grid. + dx : float + Grid spacing in x-direction. + dy : float + Grid spacing in y-direction. + nx : int + Number of grid points in x-direction. + ny : int + Number of grid points in y-direction + ugrid_elements : list + List of lists of unstructured grid indices in each hash cell. + """ + + def __init__(self,x0,y0,dx,dy,nx,ny): + self.x0 = x0 + self.y0 = y0 + self.dx = dx + self.dy = dy + self.nx = nx + self.ny = ny + ugrid_elements = [[] for i in range(self.nx*self.ny)] + + + + def get_hashindex_for_xy(self,x,y): + """Get the grid indices for a given x and y coordinate.""" + i = int((x-self.x0)/self.dx) + j = int((y-self.y0)/self.dy) + return i+self.nx*j + + + def populate_ugrid_elements(self,vertices, elements): + """ + Efficiently find the list of triangles whose bounding box overlaps with the specified hash cells. + + Parameters: + - vertices (np.ndarray): Array of vertex coordinates of shape (n_vertices, 2). + - elements (np.ndarray): Array of element-to-vertex connectivity, where each row contains 3 indices into the vertices array. + + Returns: + - overlapping_triangles (dict): A dictionary where keys are the hash cell index and values are lists of triangle indices. + """ + import numpy as np + + overlapping_triangles = [[] for i in range(self.nx*self.ny)] + + # Loop over each triangle element + for triangle_idx, triangle in enumerate(elements): + # Get the coordinates of the triangle's vertices + triangle_vertices = vertices[triangle] + + # Calculate the bounding box of the triangle + x_min, y_min = np.min(triangle_vertices, axis=0) + x_max, y_max = np.max(triangle_vertices, axis=0) + + # Find the hash cell range that overlaps with the triangle's bounding box + i_min = int(np.floor((x_min-self.x0) / self.dx)) + i_max = int(np.floor((x_max-self.x0) / self.dx)) + j_min = int(np.floor((y_min-self.y0) / self.dy)) + j_max = int(np.floor((y_max-self.y0) / self.dy)) + + # Iterate over all hash cells that intersect the bounding box + for j in range(j_min, j_max + 1): + for i in range(i_min, i_max + 1): + overlapping_triangles[i+self.nx*j].append(triangle_idx) + + self.ugrid_elements = overlapping_triangles \ No newline at end of file diff --git a/parcels/ufield.py b/parcels/ufield.py index 8d433755bb..c983a4ac3a 100644 --- a/parcels/ufield.py +++ b/parcels/ufield.py @@ -78,7 +78,7 @@ def _deal_with_errors(error, key, vector_type: VectorType): else: return 0 -class UField: +class UField(BaseField): """Class that encapsulates access to field data. Parameters @@ -93,11 +93,28 @@ class UField: 2. If data shape is [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], use the flag transpose=False 3. If data has any other shape, you first need to reorder it + lon : np.ndarray or list + Longitude coordinates (numpy vector or array) of the field (only if grid is None) + lat : np.ndarray or list + Latitude coordinates (numpy vector or array) of the field (only if grid is None) + face_node_connectivity: np.ndarray or list dimensioned [nfaces, max_nodes_per_face] + Connectivity array between faces and nodes (only if grid is None) + depth : np.ndarray or list + Depth coordinates (numpy vector or array) of the field (only if grid is None) + time : np.ndarray + Time coordinates (numpy vector) of the field (only if grid is None) + mesh : str + String indicating the type of mesh coordinates and + units used during velocity interpolation: (only if grid is None) + + 1. spherical: Lat and lon in degree, with a + correction for zonal velocity U near the poles. + 2. flat (default): No conversion, lat/lon are assumed to be in m. timestamps : np.ndarray A numpy array containing the timestamps for each of the files in filenames, for loading from netCDF files only. Default is None if the netCDF dimensions dictionary includes time. - grid : parcels.ugrid.UGrid - :class:`parcels.ugrid.UGrid` object containing all the lon, lat depth, time + grid : parcels.grid.Grid + :class:`parcels.grid.Grid` object containing all the lon, lat depth, time mesh and time_origin information. Can be constructed from any of the Grid objects fieldtype : str Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) @@ -134,126 +151,112 @@ class UField: """ - def __init__( - self, - name: str | tuple[str, str], - data, - grid, - timestamps=None, - fieldtype=None, - transpose=False, - vmin=None, - vmax=None, - cast_data_dtype="float32", - time_origin=None, - interp_method: InterpMethod = "linear", - allow_time_extrapolation: bool | None = None, - time_periodic: TimePeriodic = False, - gridindexingtype: GridIndexingType = "fesom2", - to_write=False, - **kwargs, - ): + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + @property + def face_node_connectivity(self): + return self.grid.face_node_connectivity - if kwargs.get("netcdf_decodewarning") is not None: - _deprecated_param_netcdf_decodewarning() - kwargs.pop("netcdf_decodewarning") - - if not isinstance(name, tuple): - self.name = name - self.filebuffername = name - else: - self.name = name[0] - self.filebuffername = name[1] - self.data = data - - # 11/8/2024 - joe@fluidnumerics.com - # This conditional block is commented out since the UGrid class currently does - # not have a create_grid method. - #if grid: - if grid.defer_load and isinstance(data, np.ndarray): - raise ValueError( - "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately" - ) - self._grid = grid - #else: - # if (time is not None) and isinstance(time[0], np.datetime64): - # time_origin = TimeConverter(time[0]) - # time = np.array([time_origin.reltime(t) for t in time]) - # else: - # time_origin = TimeConverter(0) - # self._grid = Grid.create_grid(lon, lat, face_node_connectivity, depth, time, time_origin=time_origin, mesh=mesh) - - self.igrid = -1 - self.fieldtype = self.name if fieldtype is None else fieldtype - self.to_write = to_write - if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): - self.units = UnitConverter() - elif self.grid.mesh == "spherical": - self.units = unitconverters_map[self.fieldtype] - else: - raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") - self.timestamps = timestamps - - if isinstance(interp_method, dict): - if self.name in interp_method: - self.interp_method = interp_method[self.name] - else: - raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") - else: - self.interp_method = interp_method - assert_valid_gridindexingtype(gridindexingtype) - self._gridindexingtype = gridindexingtype - - - self.fieldset: FieldSet | None = None - if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False - else: - self.allow_time_extrapolation = allow_time_extrapolation\ - - self.time_periodic = time_periodic - if self.time_periodic is not False and self.allow_time_extrapolation: - warnings.warn( - "allow_time_extrapolation and time_periodic cannot be used together. allow_time_extrapolation is set to False", - FieldSetWarning, - stacklevel=2, - ) - self.allow_time_extrapolation = False - if self.time_periodic is True: - raise ValueError( - "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." - ) - if self.time_periodic is not False: - if isinstance(self.time_periodic, datetime.timedelta): - self.time_periodic = self.time_periodic.total_seconds() - if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): - if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: - raise ValueError("Time series provided is longer than the time_periodic parameter") - self.grid._add_last_periodic_data_timestep = True - self.grid.time = np.append(self.grid.time, self.grid.time[0] + self.time_periodic) - self.grid.time_full = self.grid.time - - self.vmin = vmin - self.vmax = vmax - self._cast_data_dtype = cast_data_dtype - if self.cast_data_dtype == "float32": - self._cast_data_dtype = np.float32 - elif self.cast_data_dtype == "float64": - self._cast_data_dtype = np.float64 - - if not self.grid.defer_load: - self.data = self._reshape(self.data, transpose) - - # Hack around the fact that NaN and ridiculously large values - # propagate in SciPy's interpolators - lib = np if isinstance(self.data, np.ndarray) else da - self.data[lib.isnan(self.data)] = 0.0 - if self.vmin is not None: - self.data[self.data < self.vmin] = 0.0 - if self.vmax is not None: - self.data[self.data > self.vmax] = 0.0 - - if self.grid._add_last_periodic_data_timestep: - self.data = lib.concatenate((self.data, self.data[:1, :]), axis=0) - - self._scaling_factor = None \ No newline at end of file + # To do + #@classmethod + #def from_netcdf() + # Likely want to use uxarray for this + + # To do + #@classmethod + #def from_uxarray() + + # def _reshape(self, data, transpose=False): + + # def set_scaling_factor(self, factor): + # """Scales the field data by some constant factor. + + # Parameters + # ---------- + # factor : + # scaling factor + + + # Examples + # -------- + # For usage examples see the following tutorial: + + # * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ + # """ + + + # def set_depth_from_field(self, field): + # """Define the depth dimensions from another (time-varying) field. + + # Notes + # ----- + # See `this tutorial <../examples/tutorial_timevaryingdepthdimensions.ipynb>`__ + # for a detailed explanation on how to set up time-evolving depth dimensions. + + # """ + # self.grid.depth_field = field + # if self.grid != field.grid: + # field.grid.depth_field = field + + # def _calc_cell_edge_sizes(self): + # """Method to calculate cell sizes based on numpy.gradient method. + + # Currently only works for Rectilinear Grids + # """ + + # def cell_areas(self): + # """Method to calculate cell sizes based on cell_edge_sizes. + + # Currently only works for Rectilinear Grids + # """ + + # def _search_indices_vertical_s( + # self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time: float + # ): + + # def _reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): + + # def _search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): + + # def _search_indices(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): + + # Do hash table search based on x,y location + # Get list of elements to check + # Loop over elements, check if particle is in element (barycentric coordinate calc) + # + # (l1,l2,l3, inside_element) = self.barycentric_coordinates(x, y) + # + # return barycentric coordinates (2d) + # vertical interpolation weight + # element index + # nearest vertical layer index *above* particle (lower k bound); particle is between layer k and k+1 + + #def _interpolator2D(self, ti, z, y, x, particle=None): + # """Interpolation method for 2D UField. The UField.data is assumed to + # be provided at the ugrid vertices. The method uses either nearest + # neighbor or linear (barycentric) interpolation """ + + # (bc, _, ei, _) = self._search_indices(x, y, z, ti, particle=particle) + + # if self.interp_method == "nearest" + # idxs = self.face_node_connectivity[ei] + # vi = idxs[np.argmax(bc)] + # return self.data[ti,vi] + # + # Do nearest neighbour interpolation using vertex with largest barycentric coordinate + # elif self.interp_method == "linear" + # Do barycentric interpolation + + # def _interpolator3D(self, ti, z, y, x, time, particle=None): + # (bc, zeta, ei, zi) = self._search_indices(x, y, z, ti, particle=particle) + + # if self.interp_method == "nearest" + # idxs = self.face_node_connectivity[ei] + # vi = idxs[np.argmax(bc)] + # zii = zi if zeta <= 0.5 else zi + 1 + # return self.data[ti,zi,vi] + # + # Do nearest neighbour interpolation using vertex with largest barycentric coordinate + # elif self.interp_method == "linear" + # Do barycentric interpolation diff --git a/parcels/ugrid.py b/parcels/ugrid.py index 59ac7991e7..197f406724 100644 --- a/parcels/ugrid.py +++ b/parcels/ugrid.py @@ -12,6 +12,7 @@ from parcels.tools.warnings import FieldSetWarning from parcels.basegrid import BaseGrid from parcels.basegrid import CGrid +from parcels.hashgrid import HashGrid # Note : # For variable placement in FESOM - see https://fesom2.readthedocs.io/en/latest/geometry.html @@ -25,6 +26,7 @@ class UGrid(BaseGrid): def __init__(self,face_node_connectivity,*args, **kwargs): self._face_node_connectivity = face_node_connectivity # nface_node_connectivity x n_nodes_per_element array listing the vertex ids of each face_node_connectivity. + self._hashgrid = None super().__init__(*args, **kwargs) if not isinstance(self.lon, np.ndarray): @@ -82,8 +84,6 @@ def nodes_per_element(self): def n_vertices(self): return self._n_vertices - - @staticmethod def create_grid( lon: npt.ArrayLike, @@ -103,3 +103,77 @@ def create_grid( depth = np.array(depth) return UGrid(lon, lat, face_node_connectivity, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + + @staticmethod + def create_hashgrid(self, hash_cell_size_scalefac=1.0): + """Create the hashgrid attribute and populates the ugrid_element lookup table that + relates the hash cell indices to the associated ugrid elements. The hashgrid + spacing is based on the median bounding box diagonal length of all triangles in the mesh. + The `hash_cell_size_scalefac` parameter can be used to scale the hash cell size. Values + greater than 1 will result in larger hash cells, likely having more unstructured elements + per hash cell. + """ + import numpy as np + + # Initialize a list to store bounding box diagonals + diagonals = np.zeros(self.face_node_connectivity.shape[0]) + vertices = np.column_stack((self.lon,self.lat)) + + # Loop over each triangle element + k = 0 + for triangle in self.face_node_connectivity: + # Get the coordinates of the triangle's vertices + triangle_vertices = vertices[triangle] + + # Calculate the bounding box of the triangle + x_min, y_min = np.min(triangle_vertices, axis=0) + x_max, y_max = np.max(triangle_vertices, axis=0) + + # Calculate the diagonal length of the bounding box + diagonal = np.sqrt((x_max - x_min) ** 2 + (y_max - y_min) ** 2) + + # Store the diagonal length + diagonals[k] = diagonal + + k+=1 + + # Use the median diagonal as a basis for the cell size + dh = np.median(diagonals)*hash_cell_size_scalefac + + Nx = int((self.lon.max() - self.lon.min()) / dh) + 1 + Ny = int((self.lat.max() - self.lat.min()) / dh) + 1 + self.hashgrid = HashGrid(self.lon.min(), self.lat.min(), dh, dh, Nx, Ny) + self.hashgrid.populate_ugrid_elements(self.lon, self.lat, self.face_node_connectivity) + + @staticmethod + def barycentric_coordinates(self,xP, yP): + """ + Compute the barycentric coordinates of a particle in a triangular element + + Parameters: + - xP, yP: The coordinates of the particle + - triangle_vertices (np.ndarray) : The vertices of the triangle as a (3,2) array. + + Returns: + - The barycentric coordinates (l1,l2,l3) + - True if the point is inside the triangle, False otherwise. + """ + + xv = np.squeeze(self.lon) + yv = np.squeeze(self.lat[:,1]) + + A_ABC = xv[0]*(yv[1]-yv[2]) + xv[1]*(yv[2]-yv[0]) + xv[2]*(yv[0]-yv[1]) + A_BCP = xv[1]*(yv[2]-yP ) + xv[2]*(yP -yv[1]) + xP *(yv[1]-yv[2]) + A_CAP = xv[2]*(yv[0]-yP ) + xv[0]*(yP -yv[2]) + xP *(yv[2]-yv[0]) + A_ABP = xv[0]*(yv[1]-yP ) + xv[1]*(yP -yv[0]) + xP *(yv[0]-yv[1]) + + # Compute the vectors + l1 = A_BCP/A_ABC + l2 = A_CAP/A_ABC + l3 = A_ABP/A_ABC + + inside_triangle = all( [l1 >= 0.0, l1 <= 1.0, + l2 >= 0.0, l2 <= 1.0, + l3 >= 0.0, l3 <= 1.0] ) + + return l1,l2,l3,inside_triangle \ No newline at end of file