diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb index 5b703550b..bb0b3fd6a 100644 --- a/docs/examples/tutorial_stommel_uxarray.ipynb +++ b/docs/examples/tutorial_stommel_uxarray.ipynb @@ -88,7 +88,7 @@ "\n", "A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n", "\n", - "The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions." + "The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions. Setting the `mesh` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees." ] }, { @@ -99,7 +99,7 @@ "source": [ "from parcels.uxgrid import UxGrid\n", "\n", - "grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n", + "grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"], mesh=\"spherical\")\n", "# You can view the uxgrid object with the following command:\n", "grid.uxgrid" ] @@ -112,7 +112,7 @@ "\n", "In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n", "\n", - "For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees." + "For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method." ] }, { @@ -128,21 +128,18 @@ " name=\"U\",\n", " data=ds.U,\n", " grid=grid,\n", - " mesh_type=\"spherical\",\n", " interp_method=UXPiecewiseConstantFace,\n", ")\n", "V = Field(\n", " name=\"V\",\n", " data=ds.V,\n", " grid=grid,\n", - " mesh_type=\"spherical\",\n", " interp_method=UXPiecewiseConstantFace,\n", ")\n", "P = Field(\n", " name=\"P\",\n", " data=ds.p,\n", " grid=grid,\n", - " mesh_type=\"spherical\",\n", " interp_method=UXPiecewiseConstantFace,\n", ")" ] diff --git a/parcels/_datasets/structured/generated.py b/parcels/_datasets/structured/generated.py index 0a91f6864..4454cc58e 100644 --- a/parcels/_datasets/structured/generated.py +++ b/parcels/_datasets/structured/generated.py @@ -4,8 +4,8 @@ import xarray as xr -def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"): - max_lon = 180.0 if mesh_type == "spherical" else 1e6 +def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"): + max_lon = 180.0 if mesh == "spherical" else 1e6 return xr.Dataset( {"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))}, diff --git a/parcels/_index_search.py b/parcels/_index_search.py index f95c215b0..316e72bc8 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -8,6 +8,7 @@ from parcels._typing import ( GridIndexingType, InterpMethodOption, + Mesh, ) from parcels.tools.statuscodes import ( FieldOutOfBoundError, @@ -174,7 +175,7 @@ def _search_indices_rectilinear( _raise_field_out_of_bound_error(z, y, x) if field.xdim > 1: - if field._mesh_type != "spherical": + if field._mesh != "spherical": lon_index = field.lon < x if lon_index.all(): xi = len(field.lon) - 2 @@ -305,7 +306,7 @@ def _search_indices_curvilinear_2d( xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi)) yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi)) - (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) + (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh) it += 1 if it > maxIterSearch: print(f"Correct cell not found after {maxIterSearch} iterations") @@ -408,11 +409,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2 return (zeta, eta, xsi, zi, yi, xi) -def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool): - xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi) - xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi) +def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh): + xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi) + xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi) - xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi) + xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi) yi = np.where(yi < 0, 0, yi) yi = np.where(yi > ydim - 2, ydim - 2, yi) diff --git a/parcels/field.py b/parcels/field.py index 7c2170ac9..7d14eff60 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -11,11 +11,7 @@ from parcels._core.utils.time import TimeInterval from parcels._reprs import default_repr -from parcels._typing import ( - Mesh, - VectorType, - assert_valid_mesh, -) +from parcels._typing import VectorType from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator from parcels.particle import KernelParticle from parcels.tools.converters import ( @@ -86,7 +82,7 @@ class Field: ----- The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata. * dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon]) - * attrs: (location, mesh, mesh_type) + * attrs: (location, mesh, mesh) When using a xarray.DataArray object, * The xarray.DataArray object must have the "location" and "mesh" attributes set. @@ -114,7 +110,6 @@ def __init__( name: str, data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid, - mesh_type: Mesh = "flat", interp_method: Callable | None = None, ): if not isinstance(data, (ux.UxDataArray, xr.DataArray)): @@ -126,8 +121,6 @@ def __init__( if not isinstance(grid, (UxGrid, XGrid)): raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.") - assert_valid_mesh(mesh_type) - _assert_compatible_combination(data, grid) if isinstance(grid, XGrid): @@ -155,8 +148,6 @@ def __init__( e.add_note(f"Error validating field {name!r}.") raise e - self._mesh_type = mesh_type - # Setting the interpolation method dynamically if interp_method is None: self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)] @@ -166,12 +157,10 @@ def __init__( self.igrid = -1 # Default the grid index to -1 - if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()): + if self.grid._mesh == "flat" or (self.name not in unitconverters_map.keys()): self.units = UnitConverter() - elif self._mesh_type == "spherical": + elif self.grid._mesh == "spherical": self.units = unitconverters_map[self.name] - else: - raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'") if self.data.shape[0] > 1: if "time" not in self.data.coords: diff --git a/parcels/kernel.py b/parcels/kernel.py index 4cd8f3777..467724ee8 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth stacklevel=2, ) self.fieldset.add_constant("RK45_tol", 10) - if self.fieldset.U.grid.mesh == "spherical": + if self.fieldset.U.grid._mesh == "spherical": self.fieldset.RK45_tol /= ( 1852 * 60 ) # TODO does not account for zonal variation in meter -> degree conversion diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 57942610e..4a76a1113 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -86,7 +86,7 @@ def _hash_index2d(self, coords): as the source grid coordinates """ # Wrap longitude to [-180, 180] - if self._source_grid.mesh == "spherical": + if self._source_grid._mesh == "spherical": lon = (coords[:, 1] + 180.0) % (360.0) - 180.0 else: lon = coords[:, 1] diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index 7f0048843..bc8014f76 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -5,6 +5,7 @@ import numpy as np import uxarray as ux +from parcels._typing import assert_valid_mesh from parcels.spatialhash import _barycentric_coordinates from parcels.tools.statuscodes import FieldOutOfBoundError from parcels.xgrid import _search_1d_array @@ -20,7 +21,7 @@ class UxGrid(BaseGrid): for interpolation on unstructured grids. """ - def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid: + def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid: """ Initializes the UxGrid with a uxarray grid and vertical coordinate array. @@ -32,6 +33,8 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid: A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths). While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid. + mesh : str, optional + The type of mesh used for the grid. Either "flat" (default) or "spherical". """ self.uxgrid = grid if not isinstance(z, ux.UxDataArray): @@ -39,6 +42,9 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid: if z.ndim != 1: raise ValueError("z must be a 1D array of vertical coordinates") self.z = z + self._mesh = mesh + + assert_valid_mesh(mesh) @property def depth(self): diff --git a/parcels/xgrid.py b/parcels/xgrid.py index bac507d8e..75140d0a9 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -8,6 +8,7 @@ from parcels import xgcm from parcels._index_search import _search_indices_curvilinear_2d +from parcels._typing import assert_valid_mesh from parcels.basegrid import BaseGrid from parcels.spatialhash import SpatialHash @@ -97,13 +98,15 @@ class XGrid(BaseGrid): def __init__(self, grid: xgcm.Grid, mesh="flat"): self.xgcm_grid = grid - self.mesh = mesh + self._mesh = mesh self._spatialhash = None ds = grid._ds if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development) assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes) + assert_valid_mesh(mesh) + @classmethod def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None): """WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release diff --git a/tests/utils.py b/tests/utils.py index 551f82edd..c89cbea45 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,15 +68,15 @@ def create_fieldset_global(xdim=200, ydim=100): return FieldSet.from_data(data, dimensions, mesh="flat") -def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100) -> FieldSet: +def create_fieldset_zeros_conversion(mesh="spherical", xdim=200, ydim=100) -> FieldSet: """Zero velocity field with lat and lon determined by a conversion factor.""" - mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1 - ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type) + mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1 + ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh=mesh) ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim) ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim) - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear) - V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh=mesh) + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) UV = VectorField("UV", U, V) return FieldSet([U, V, UV]) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 7b209ab30..2522a8b6f 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -32,21 +32,21 @@ } -@pytest.mark.parametrize("mesh_type", ["spherical", "flat"]) -def test_advection_zonal(mesh_type, npart=10): +@pytest.mark.parametrize("mesh", ["spherical", "flat"]) +def test_advection_zonal(mesh, npart=10): """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" - ds = simple_UV_dataset(mesh_type=mesh_type) + ds = simple_UV_dataset(mesh=mesh) ds["U"].data[:] = 1.0 - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear) - V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh=mesh) + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, UV]) pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m")) - if mesh_type == "spherical": + if mesh == "spherical": assert (np.diff(pset.lon) > 1.0e-4).all() else: assert (np.diff(pset.lon) < 1.0e-4).all() @@ -58,7 +58,7 @@ def periodicBC(particle, fieldset, time): def test_advection_zonal_periodic(): - ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh_type="flat") + ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh="flat") ds["U"].data[:] = 0.1 ds["lon"].data = np.array([0, 2]) ds["lat"].data = np.array([0, 2]) @@ -86,7 +86,7 @@ def test_advection_zonal_periodic(): def test_horizontal_advection_in_3D_flow(npart=10): """Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s.""" - ds = simple_UV_dataset(mesh_type="flat") + ds = simple_UV_dataset(mesh="flat") ds["U"].data[:] = 1.0 grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) @@ -105,7 +105,7 @@ def test_horizontal_advection_in_3D_flow(npart=10): @pytest.mark.parametrize("direction", ["up", "down"]) @pytest.mark.parametrize("wErrorThroughSurface", [True, False]) def test_advection_3D_outofbounds(direction, wErrorThroughSurface): - ds = simple_UV_dataset(mesh_type="flat") + ds = simple_UV_dataset(mesh="flat") grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds) @@ -208,9 +208,9 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read def test_radialrotation(npart=10): ds = radial_rotation_dataset() - grid = XGrid.from_dataset(ds) - U = parcels.Field("U", ds["U"], grid, mesh_type="flat", interp_method=XLinear) - V = parcels.Field("V", ds["V"], grid, mesh_type="flat", interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh="flat") + U = parcels.Field("U", ds["U"], grid, interp_method=XLinear) + V = parcels.Field("V", ds["V"], grid, interp_method=XLinear) UV = parcels.VectorField("UV", U, V) fieldset = parcels.FieldSet([U, V, UV]) diff --git a/tests/v4/test_diffusion.py b/tests/v4/test_diffusion.py index 74cfa0429..1a4769302 100644 --- a/tests/v4/test_diffusion.py +++ b/tests/v4/test_diffusion.py @@ -15,22 +15,22 @@ from tests.utils import create_fieldset_zeros_conversion -@pytest.mark.parametrize("mesh_type", ["spherical", "flat"]) -def test_fieldKh_Brownian(mesh_type): +@pytest.mark.parametrize("mesh", ["spherical", "flat"]) +def test_fieldKh_Brownian(mesh): kh_zonal = 100 kh_meridional = 50 - mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1 + mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1 - ds = simple_UV_dataset(dims=(2, 1, 2, 2), mesh_type=mesh_type) + ds = simple_UV_dataset(dims=(2, 1, 2, 2), mesh=mesh) ds["lon"].data = np.array([-1e6, 1e6]) ds["lat"].data = np.array([-1e6, 1e6]) - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear) - V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh=mesh) + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal)) ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional)) - Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear) - Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear) + Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear) + Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional]) @@ -51,19 +51,19 @@ def test_fieldKh_Brownian(mesh_type): assert np.allclose(np.mean(pset.lat), 0, atol=tol) -@pytest.mark.parametrize("mesh_type", ["spherical", "flat"]) +@pytest.mark.parametrize("mesh", ["spherical", "flat"]) @pytest.mark.parametrize("kernel", [AdvectionDiffusionM1, AdvectionDiffusionEM]) -def test_fieldKh_SpatiallyVaryingDiffusion(mesh_type, kernel): +def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel): """Test advection-diffusion kernels on a non-uniform diffusivity field with a linear gradient in one direction.""" ydim, xdim = 100, 200 - mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1 - ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type) + mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1 + ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh=mesh) ds["lon"].data = np.linspace(-1e6, 1e6, xdim) ds["lat"].data = np.linspace(-1e6, 1e6, ydim) - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear) - V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh=mesh) + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) Kh = np.zeros((ydim, xdim), dtype=np.float32) for x in range(xdim): @@ -71,8 +71,8 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh_type, kernel): ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh)) ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, ydim, xdim), Kh)) - Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, mesh_type=mesh_type, interp_method=XLinear) - Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, mesh_type=mesh_type, interp_method=XLinear) + Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear) + Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional]) fieldset.add_constant("dres", float(ds["lon"][1] - ds["lon"][0])) diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index 6f1068896..cd71337d7 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -25,9 +25,6 @@ def test_field_init_param_types(): with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"): Field(name="test", data=data["data_g"], grid=123) - with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"): - Field(name="test", data=data["data_g"], grid=grid, mesh_type="invalid") - @pytest.mark.parametrize( "data,grid", @@ -107,7 +104,7 @@ def test_field_init_fail_on_float_time_dim(): ) def test_field_time_interval(data, grid): """Test creating a field.""" - field = Field(name="test_field", data=data, grid=grid, mesh_type="flat") + field = Field(name="test_field", data=data, grid=grid) assert field.time_interval.left == np.datetime64("2000-01-01") assert field.time_interval.right == np.datetime64("2001-01-01") diff --git a/tests/v4/test_fieldset.py b/tests/v4/test_fieldset.py index c3871018b..075256794 100644 --- a/tests/v4/test_fieldset.py +++ b/tests/v4/test_fieldset.py @@ -21,9 +21,9 @@ @pytest.fixture def fieldset() -> FieldSet: """Fixture to create a FieldSet object for testing.""" - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) UV = VectorField("UV", U, V) return FieldSet( @@ -55,8 +55,8 @@ def test_fieldset_add_constant_field(fieldset): def test_fieldset_add_field(fieldset): - grid = XGrid.from_dataset(ds) - field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + field = Field("test_field", ds["U (A grid)"], grid) fieldset.add_field(field) assert fieldset.test_field == field @@ -68,8 +68,8 @@ def test_fieldset_add_field_wrong_type(fieldset): def test_fieldset_add_field_already_exists(fieldset): - grid = XGrid.from_dataset(ds) - field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + field = Field("test_field", ds["U (A grid)"], grid) fieldset.add_field(field, "test_field") with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"): fieldset.add_field(field, "test_field") @@ -87,10 +87,10 @@ def test_fieldset_gridset(fieldset): @pytest.mark.parametrize("ds", [pytest.param(ds, id=k) for k, ds in datasets_structured.items()]) def test_fieldset_from_structured_generic_datasets(ds): - grid = XGrid.from_dataset(ds) + grid = XGrid.from_dataset(ds, mesh="flat") fields = [] for var in ds.data_vars: - fields.append(Field(var, ds[var], grid, mesh_type="flat")) + fields.append(Field(var, ds[var], grid)) fieldset = FieldSet(fields) @@ -105,13 +105,13 @@ def test_fieldset_gridset_multiple_grids(): ... def test_fieldset_time_interval(): - grid1 = XGrid.from_dataset(ds) - field1 = Field("field1", ds["U (A grid)"], grid1, mesh_type="flat") + grid1 = XGrid.from_dataset(ds, mesh="flat") + field1 = Field("field1", ds["U (A grid)"], grid1) ds2 = ds.copy() ds2["time"] = (ds2["time"].dims, ds2["time"].data + np.timedelta64(timedelta(days=1)), ds2["time"].attrs) - grid2 = XGrid.from_dataset(ds2) - field2 = Field("field2", ds2["U (A grid)"], grid2, mesh_type="flat") + grid2 = XGrid.from_dataset(ds2, mesh="flat") + field2 = Field("field2", ds2["U (A grid)"], grid2) fieldset = FieldSet([field1, field2]) fieldset.add_constant_field("constant_field", 1.0) @@ -136,9 +136,9 @@ def test_fieldset_init_incompatible_calendars(): ds1["time"].attrs, ) - grid = XGrid.from_dataset(ds1) - U = Field("U", ds1["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds1["V (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds1, mesh="flat") + U = Field("U", ds1["U (A grid)"], grid) + V = Field("V", ds1["V (A grid)"], grid) UV = VectorField("UV", U, V) ds2 = ds.copy() @@ -147,8 +147,8 @@ def test_fieldset_init_incompatible_calendars(): xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True), ds2["time"].attrs, ) - grid2 = XGrid.from_dataset(ds2) - incompatible_calendar = Field("test", ds2["data_g"], grid2, mesh_type="flat") + grid2 = XGrid.from_dataset(ds2, mesh="flat") + incompatible_calendar = Field("test", ds2["data_g"], grid2) with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): FieldSet([U, V, UV, incompatible_calendar]) @@ -161,8 +161,8 @@ def test_fieldset_add_field_incompatible_calendars(fieldset): xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True), ds_test["time"].attrs, ) - grid = XGrid.from_dataset(ds_test) - field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds_test, mesh="flat") + field = Field("test_field", ds_test["data_g"], grid) with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): fieldset.add_field(field, "test_field") @@ -173,8 +173,8 @@ def test_fieldset_add_field_incompatible_calendars(fieldset): np.linspace(0, 100, T_structured, dtype="timedelta64[s]"), ds_test["time"].attrs, ) - grid = XGrid.from_dataset(ds_test) - field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds_test, mesh="flat") + field = Field("test_field", ds_test["data_g"], grid) with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): fieldset.add_field(field, "test_field") diff --git a/tests/v4/test_interpolation.py b/tests/v4/test_interpolation.py index 959066992..a8c52b7cf 100644 --- a/tests/v4/test_interpolation.py +++ b/tests/v4/test_interpolation.py @@ -16,18 +16,18 @@ from tests.utils import TEST_DATA -@pytest.mark.parametrize("mesh_type", ["spherical", "flat"]) -def test_interpolation_mesh_type(mesh_type, npart=10): - ds = simple_UV_dataset(mesh_type=mesh_type) +@pytest.mark.parametrize("mesh", ["spherical", "flat"]) +def test_interpolation_mesh(mesh, npart=10): + ds = simple_UV_dataset(mesh=mesh) ds["U"].data[:] = 1.0 - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear) - V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear) + grid = XGrid.from_dataset(ds, mesh=mesh) + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) UV = VectorField("UV", U, V) lat = 30.0 time = U.time_interval.left - u_expected = 1.0 if mesh_type == "flat" else 1.0 / (1852 * 60 * np.cos(np.radians(lat))) + u_expected = 1.0 if mesh == "flat" else 1.0 / (1852 * 60 * np.cos(np.radians(lat))) assert np.isclose(U[time, 0, lat, 0], u_expected, atol=1e-7) assert V[time, 0, lat, 0] == 0.0 @@ -91,10 +91,10 @@ def test_interp_regression_v3(interp_name): }, ) - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type="flat", interp_method=interp_methods[interp_name]) - V = Field("V", ds["V"], grid, mesh_type="flat", interp_method=interp_methods[interp_name]) - W = Field("W", ds["W"], grid, mesh_type="flat", interp_method=interp_methods[interp_name]) + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U"], grid, interp_method=interp_methods[interp_name]) + V = Field("V", ds["V"], grid, interp_method=interp_methods[interp_name]) + W = Field("W", ds["W"], grid, interp_method=interp_methods[interp_name]) fieldset = FieldSet([U, V, W, VectorField("UVW", U, V, W)]) x, y, z = np.meshgrid(np.linspace(0, 1, 7), np.linspace(0, 1, 13), np.linspace(0, 1, 5)) diff --git a/tests/v4/test_kernel.py b/tests/v4/test_kernel.py index 098b1a4c3..acceb8dff 100644 --- a/tests/v4/test_kernel.py +++ b/tests/v4/test_kernel.py @@ -17,9 +17,9 @@ @pytest.fixture def fieldset() -> FieldSet: ds = datasets_structured["ds_2d_left"] - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) return FieldSet([U, V]) diff --git a/tests/v4/test_particleset.py b/tests/v4/test_particleset.py index e1aedb13c..7c3f2058c 100644 --- a/tests/v4/test_particleset.py +++ b/tests/v4/test_particleset.py @@ -22,9 +22,9 @@ @pytest.fixture def fieldset() -> FieldSet: ds = datasets_structured["ds_2d_left"] - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) return FieldSet([U, V]) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 7cf6aa21a..68a699626 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -24,19 +24,19 @@ @pytest.fixture def fieldset() -> FieldSet: ds = datasets_structured["ds_2d_left"] - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) return FieldSet([U, V]) @pytest.fixture def zonal_flow_fieldset() -> FieldSet: - ds = simple_UV_dataset(mesh_type="flat") + ds = simple_UV_dataset(mesh="flat") ds["U"].data[:] = 1.0 - grid = XGrid.from_dataset(ds) - U = Field("U", ds["U"], grid, mesh_type="flat") - V = Field("V", ds["V"], grid, mesh_type="flat") + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U"], grid) + V = Field("V", ds["V"], grid) UV = VectorField("UV", U, V) return FieldSet([U, V, UV]) @@ -234,26 +234,23 @@ def PythonFail(particle, fieldset, time): # pragma: no cover def test_uxstommelgyre_pset_execute(): ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"]) + grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"], mesh="spherical") U = Field( name="U", data=ds.U, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) V = Field( name="V", data=ds.V, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) P = Field( name="P", data=ds.p, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) UV = VectorField(name="UV", U=U, V=V) @@ -278,26 +275,23 @@ def test_uxstommelgyre_pset_execute(): @pytest.mark.xfail(reason="Output file not implemented yet") def test_uxstommelgyre_pset_execute_output(): ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"]) + grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"], mesh="spherical") U = Field( name="U", data=ds.U, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) V = Field( name="V", data=ds.V, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) P = Field( name="P", data=ds.p, grid=grid, - mesh_type="spherical", interp_method=UXPiecewiseConstantFace, ) UV = VectorField(name="UV", U=U, V=V) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index f987cf798..eb9e8524e 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -39,6 +39,12 @@ def assert_equal(actual, expected): assert_allclose(actual, expected) +@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]]) +def test_grid_init_param_types(ds): + with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"): + XGrid.from_dataset(ds, mesh="invalid") + + @pytest.mark.parametrize("ds, attr, expected", test_cases) def test_xgrid_properties_ground_truth(ds, attr, expected): grid = XGrid.from_dataset(ds) diff --git a/v3to4-breaking-changes.md b/v3to4-breaking-changes.md index 46004e0a1..14fe961f6 100644 --- a/v3to4-breaking-changes.md +++ b/v3to4-breaking-changes.md @@ -6,7 +6,6 @@ Kernels: FieldSet -- `mesh` is now called `mesh_type`? - `interp_method` has to be an Interpolation function, instead of a string ParticleSet