30 changes: 3 additions & 27 deletions parcels/field.py
@@ -3,7 +3,7 @@
 import warnings
 from collections.abc import Iterable
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, cast
 
 import dask.array as da
 import numpy as np
@@ -20,6 +20,7 @@
 from parcels._typing import (
     GridIndexingType,
     InterpMethod,
+    InterpMethodOption,
     Mesh,
     VectorType,
     assert_valid_gridindexingtype,
@@ -140,8 +141,6 @@ class Field:
         Minimum allowed value on the field. Data below this value are set to zero
     vmax : float
         Maximum allowed value on the field. Data above this value are set to zero
-    cast_data_dtype : str
-        Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64 (np.float64).
     time_origin : parcels.tools.converters.TimeConverter
         Time origin of the time axis (only if grid is None)
     interp_method : str
@@ -162,7 +161,6 @@ class Field:
     """
 
     allow_time_extrapolation: bool
-    _cast_data_dtype: type[np.float32] | type[np.float64]
 
     def __init__(
         self,
@@ -179,7 +177,6 @@ def __init__(
         transpose: bool = False,
         vmin: float | None = None,
         vmax: float | None = None,
-        cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
         time_origin: TimeConverter | None = None,
         interp_method: InterpMethod = "linear",
         allow_time_extrapolation: bool | None = None,
@@ -246,19 +243,6 @@ def __init__(
         self.vmin = vmin
         self.vmax = vmax
 
-        match cast_data_dtype:
-            case "float32":
-                self._cast_data_dtype = np.float32
-            case "float64":
-                self._cast_data_dtype = np.float64
-            case _:
-                self._cast_data_dtype = cast_data_dtype
-
-        if self.cast_data_dtype not in [np.float32, np.float64]:
-            raise ValueError(
-                f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
-            )
-
         if not self.grid.defer_load:
             self.data = self._reshape(self.data, transpose)
             self._loaded_time_indices = range(self.grid.tdim)
@@ -332,10 +316,6 @@ def interp_method(self, value):
     def gridindexingtype(self):
         return self._gridindexingtype
 
-    @property
-    def cast_data_dtype(self):
-        return self._cast_data_dtype
-
     @property
     def netcdf_engine(self):
         return self._netcdf_engine
@@ -522,6 +502,7 @@ def from_netcdf(
                 interp_method = interp_method[variable[0]]
             else:
                 raise RuntimeError(f"interp_method is a dictionary but {variable[0]} is not in it")
+        interp_method = cast(InterpMethodOption, interp_method)
 
         if "lon" in dimensions and "lat" in dimensions:
             with NetcdfFileBuffer(
@@ -719,10 +700,6 @@ def _reshape(self, data, transpose=False):
         # Ensure that field data is the right data type
         if not isinstance(data, (np.ndarray)):
             data = np.array(data)
-        if (self.cast_data_dtype == np.float32) and (data.dtype != np.float32):
-            data = data.astype(np.float32)
-        elif (self.cast_data_dtype == np.float64) and (data.dtype != np.float64):
-            data = data.astype(np.float64)
         if transpose:
             data = np.transpose(data)
         if self.grid._lat_flipped:
@@ -1059,7 +1036,6 @@ def computeTimeChunk(self, data, tindex):
             timestamp=timestamp,
             interp_method=self.interp_method,
             data_full_zdim=self.data_full_zdim,
-            cast_data_dtype=self.cast_data_dtype,
         )
         filebuffer.__enter__()
         time_data = filebuffer.time
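Note on the `cast(InterpMethodOption, interp_method)` line added in `from_netcdf`: `typing.cast` only narrows the static type for the checker and returns its argument unchanged at runtime, so, unlike the removed `cast_data_dtype` machinery, it performs no data conversion or validation. A minimal sketch; the `Literal` members below are illustrative stand-ins for parcels' actual `InterpMethodOption` alias:

```python
from typing import Literal, cast

# Hypothetical stand-in for parcels' InterpMethodOption alias.
InterpMethodOption = Literal["linear", "nearest", "cgrid_velocity"]

value = "linear"
narrowed = cast(InterpMethodOption, value)  # typing-only: no runtime conversion or check
assert narrowed is value  # the very same object comes back
```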
9 changes: 3 additions & 6 deletions parcels/fieldfilebuffer.py
@@ -19,7 +19,6 @@
         timestamp=None,
         interp_method: InterpMethodOption = "linear",
         data_full_zdim=None,
-        cast_data_dtype=np.float32,
         gridindexingtype="nemo",
         **kwargs,
     ):
@@ -28,7 +27,6 @@
         self.indices = indices
         self.dataset = None
         self.timestamp = timestamp
-        self.cast_data_dtype = cast_data_dtype
         self.ti = None
         self.interp_method = interp_method
         self.gridindexingtype = gridindexingtype
@@ -140,10 +138,10 @@
         else:
             return np.empty((0, len(self.indices["depth"]), len(self.indices["lat"]), len(self.indices["lon"])))
 
-    def _check_extend_depth(self, data, di):
+    def _check_extend_depth(self, data, dim):
         return (
             self.indices["depth"][-1] == self.data_full_zdim - 1
-            and data.shape[di] == self.data_full_zdim - 1
+            and data.shape[dim] == self.data_full_zdim - 1
[Codecov / codecov/patch: added line parcels/fieldfilebuffer.py#L144 was not covered by tests]
             and self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]
         )
@@ -192,8 +190,7 @@
     def data_access(self):
         data = self.dataset[self.name]
         ti = range(data.shape[0]) if self.ti is None else self.ti
-        data = self._apply_indices(data, ti)
-        return np.array(data, dtype=self.cast_data_dtype)
+        return np.array(self._apply_indices(data, ti))
 
     @property
     def time(self):
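The `data_access` change is behavioural as well as cosmetic: `np.array` without an explicit `dtype` keeps the dtype of the source data, whereas the removed code always coerced to `cast_data_dtype` (float32 by default). A small illustration of the difference, assuming a float64 source array:

```python
import numpy as np

source = np.arange(4, dtype=np.float64)

old_style = np.array(source, dtype=np.float32)  # removed behaviour: forced cast
new_style = np.array(source)                    # new behaviour: source dtype preserved

assert old_style.dtype == np.float32
assert new_style.dtype == np.float64
```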
3 changes: 0 additions & 3 deletions parcels/fieldset.py
@@ -279,9 +279,6 @@ def check_velocityfields(U, V, W):
         if V.gridindexingtype != U.gridindexingtype or (W and W.gridindexingtype != U.gridindexingtype):
             raise ValueError("Not all velocity Fields have the same gridindexingtype")
 
-        if U.cast_data_dtype != V.cast_data_dtype or (W and W.cast_data_dtype != U.cast_data_dtype):
-            raise ValueError("Not all velocity Fields have the same dtype")
-
         if isinstance(self.U, NestedField):
             w = self.W if hasattr(self, "W") else [None] * len(self.U)
             for U, V, W in zip(self.U, self.V, w, strict=True):
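With the cross-field dtype check removed, field precision is simply whatever the input arrays carry. A hedged sketch, not part of this PR, of how a user can still pin a precision by casting inputs before constructing the `FieldSet`; the grid values mirror the test deleted below:

```python
import numpy as np

from parcels import FieldSet

# Illustrative flat grid, mirroring the deleted test_fieldset_float64 below.
lon = np.linspace(0.0, 10.0, 10)
lat = np.linspace(0.0, 10.0, 5)
U, V = np.meshgrid(lon, lat)

# Choose the precision up front by casting the input arrays themselves.
data = {"U": U.astype(np.float32), "V": V.astype(np.float32)}
dimensions = {"lat": lat, "lon": lon}

fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
assert fieldset.U.data.dtype == np.float32  # assumes this PR preserves the input dtype
```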
32 changes: 0 additions & 32 deletions tests/test_fieldset.py
@@ -192,38 +192,6 @@ def test_fieldset_from_cgrid_interpmethod():
     FieldSet.from_c_grid_dataset(filenames, variable, dimensions, interp_method="partialslip")
 
 
-@pytest.mark.parametrize("cast_data_dtype", ["float32", "float64"])
-def test_fieldset_float64(cast_data_dtype, tmpdir):
-    xdim, ydim = 10, 5
-    lon = np.linspace(0.0, 10.0, xdim, dtype=np.float64)
-    lat = np.linspace(0.0, 10.0, ydim, dtype=np.float64)
-    U, V = np.meshgrid(lon, lat)
-    dimensions = {"lat": lat, "lon": lon}
-    data = {"U": np.array(U, dtype=np.float64), "V": np.array(V, dtype=np.float64)}
-
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat", cast_data_dtype=cast_data_dtype)
-    if cast_data_dtype == "float32":
-        assert fieldset.U.data.dtype == np.float32
-    else:
-        assert fieldset.U.data.dtype == np.float64
-    pset = ParticleSet(fieldset, Particle, lon=1, lat=2)
-
-    failed = False
-    try:
-        pset.execute(AdvectionRK4, runtime=2)
-    except RuntimeError:
-        failed = True  # noqa
-    assert np.isclose(pset[0].lon, 2.70833)
-    assert np.isclose(pset[0].lat, 5.41667)
-    filepath = tmpdir.join("test_fieldset_float64")
-    fieldset.U.write(filepath)
-    da = xr.open_dataset(str(filepath) + "U.nc")
-    if cast_data_dtype == "float32":
-        assert da["U"].dtype == np.float32
-    else:
-        assert da["U"].dtype == np.float64
-
-
 @pytest.mark.parametrize("indslon", [range(10, 20), [1]])
 @pytest.mark.parametrize("indslat", [range(30, 60), [22]])
 def test_fieldset_from_file_subsets(indslon, indslat, tmpdir):