Skip to content

Commit f06f2e1

Browse files
Merge pull request #1897 from OceanParcels/v/remove_cast_data_dtype
Remove type casting of field data
2 parents 869137c + d55edc1 commit f06f2e1

File tree

4 files changed

+6
-68
lines changed

4 files changed

+6
-68
lines changed

parcels/field.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections.abc import Iterable
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Literal
6+
from typing import TYPE_CHECKING, cast
77

88
import dask.array as da
99
import numpy as np
@@ -20,6 +20,7 @@
2020
from parcels._typing import (
2121
GridIndexingType,
2222
InterpMethod,
23+
InterpMethodOption,
2324
Mesh,
2425
VectorType,
2526
assert_valid_gridindexingtype,
@@ -140,8 +141,6 @@ class Field:
140141
Minimum allowed value on the field. Data below this value are set to zero
141142
vmax : float
142143
Maximum allowed value on the field. Data above this value are set to zero
143-
cast_data_dtype : str
144-
Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64 (np.float64).
145144
time_origin : parcels.tools.converters.TimeConverter
146145
Time origin of the time axis (only if grid is None)
147146
interp_method : str
@@ -162,7 +161,6 @@ class Field:
162161
"""
163162

164163
allow_time_extrapolation: bool
165-
_cast_data_dtype: type[np.float32] | type[np.float64]
166164

167165
def __init__(
168166
self,
@@ -179,7 +177,6 @@ def __init__(
179177
transpose: bool = False,
180178
vmin: float | None = None,
181179
vmax: float | None = None,
182-
cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
183180
time_origin: TimeConverter | None = None,
184181
interp_method: InterpMethod = "linear",
185182
allow_time_extrapolation: bool | None = None,
@@ -246,19 +243,6 @@ def __init__(
246243
self.vmin = vmin
247244
self.vmax = vmax
248245

249-
match cast_data_dtype:
250-
case "float32":
251-
self._cast_data_dtype = np.float32
252-
case "float64":
253-
self._cast_data_dtype = np.float64
254-
case _:
255-
self._cast_data_dtype = cast_data_dtype
256-
257-
if self.cast_data_dtype not in [np.float32, np.float64]:
258-
raise ValueError(
259-
f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
260-
)
261-
262246
if not self.grid.defer_load:
263247
self.data = self._reshape(self.data, transpose)
264248
self._loaded_time_indices = range(self.grid.tdim)
@@ -332,10 +316,6 @@ def interp_method(self, value):
332316
def gridindexingtype(self):
333317
return self._gridindexingtype
334318

335-
@property
336-
def cast_data_dtype(self):
337-
return self._cast_data_dtype
338-
339319
@property
340320
def netcdf_engine(self):
341321
return self._netcdf_engine
@@ -522,6 +502,7 @@ def from_netcdf(
522502
interp_method = interp_method[variable[0]]
523503
else:
524504
raise RuntimeError(f"interp_method is a dictionary but {variable[0]} is not in it")
505+
interp_method = cast(InterpMethodOption, interp_method)
525506

526507
if "lon" in dimensions and "lat" in dimensions:
527508
with NetcdfFileBuffer(
@@ -719,10 +700,6 @@ def _reshape(self, data, transpose=False):
719700
# Ensure that field data is the right data type
720701
if not isinstance(data, (np.ndarray)):
721702
data = np.array(data)
722-
if (self.cast_data_dtype == np.float32) and (data.dtype != np.float32):
723-
data = data.astype(np.float32)
724-
elif (self.cast_data_dtype == np.float64) and (data.dtype != np.float64):
725-
data = data.astype(np.float64)
726703
if transpose:
727704
data = np.transpose(data)
728705
if self.grid._lat_flipped:
@@ -1059,7 +1036,6 @@ def computeTimeChunk(self, data, tindex):
10591036
timestamp=timestamp,
10601037
interp_method=self.interp_method,
10611038
data_full_zdim=self.data_full_zdim,
1062-
cast_data_dtype=self.cast_data_dtype,
10631039
)
10641040
filebuffer.__enter__()
10651041
time_data = filebuffer.time

parcels/fieldfilebuffer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def __init__(
1919
timestamp=None,
2020
interp_method: InterpMethodOption = "linear",
2121
data_full_zdim=None,
22-
cast_data_dtype=np.float32,
2322
gridindexingtype="nemo",
2423
**kwargs,
2524
):
@@ -28,7 +27,6 @@ def __init__(
2827
self.indices = indices
2928
self.dataset = None
3029
self.timestamp = timestamp
31-
self.cast_data_dtype = cast_data_dtype
3230
self.ti = None
3331
self.interp_method = interp_method
3432
self.gridindexingtype = gridindexingtype
@@ -140,10 +138,10 @@ def depth_dimensions(self):
140138
else:
141139
return np.empty((0, len(self.indices["depth"]), len(self.indices["lat"]), len(self.indices["lon"])))
142140

143-
def _check_extend_depth(self, data, di):
141+
def _check_extend_depth(self, data, dim):
144142
return (
145143
self.indices["depth"][-1] == self.data_full_zdim - 1
146-
and data.shape[di] == self.data_full_zdim - 1
144+
and data.shape[dim] == self.data_full_zdim - 1
147145
and self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]
148146
)
149147

@@ -192,8 +190,7 @@ def data(self):
192190
def data_access(self):
193191
data = self.dataset[self.name]
194192
ti = range(data.shape[0]) if self.ti is None else self.ti
195-
data = self._apply_indices(data, ti)
196-
return np.array(data, dtype=self.cast_data_dtype)
193+
return np.array(self._apply_indices(data, ti))
197194

198195
@property
199196
def time(self):

parcels/fieldset.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,6 @@ def check_velocityfields(U, V, W):
279279
if V.gridindexingtype != U.gridindexingtype or (W and W.gridindexingtype != U.gridindexingtype):
280280
raise ValueError("Not all velocity Fields have the same gridindexingtype")
281281

282-
if U.cast_data_dtype != V.cast_data_dtype or (W and W.cast_data_dtype != U.cast_data_dtype):
283-
raise ValueError("Not all velocity Fields have the same dtype")
284-
285282
if isinstance(self.U, NestedField):
286283
w = self.W if hasattr(self, "W") else [None] * len(self.U)
287284
for U, V, W in zip(self.U, self.V, w, strict=True):

tests/test_fieldset.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -192,38 +192,6 @@ def test_fieldset_from_cgrid_interpmethod():
192192
FieldSet.from_c_grid_dataset(filenames, variable, dimensions, interp_method="partialslip")
193193

194194

195-
@pytest.mark.parametrize("cast_data_dtype", ["float32", "float64"])
196-
def test_fieldset_float64(cast_data_dtype, tmpdir):
197-
xdim, ydim = 10, 5
198-
lon = np.linspace(0.0, 10.0, xdim, dtype=np.float64)
199-
lat = np.linspace(0.0, 10.0, ydim, dtype=np.float64)
200-
U, V = np.meshgrid(lon, lat)
201-
dimensions = {"lat": lat, "lon": lon}
202-
data = {"U": np.array(U, dtype=np.float64), "V": np.array(V, dtype=np.float64)}
203-
204-
fieldset = FieldSet.from_data(data, dimensions, mesh="flat", cast_data_dtype=cast_data_dtype)
205-
if cast_data_dtype == "float32":
206-
assert fieldset.U.data.dtype == np.float32
207-
else:
208-
assert fieldset.U.data.dtype == np.float64
209-
pset = ParticleSet(fieldset, Particle, lon=1, lat=2)
210-
211-
failed = False
212-
try:
213-
pset.execute(AdvectionRK4, runtime=2)
214-
except RuntimeError:
215-
failed = True # noqa
216-
assert np.isclose(pset[0].lon, 2.70833)
217-
assert np.isclose(pset[0].lat, 5.41667)
218-
filepath = tmpdir.join("test_fieldset_float64")
219-
fieldset.U.write(filepath)
220-
da = xr.open_dataset(str(filepath) + "U.nc")
221-
if cast_data_dtype == "float32":
222-
assert da["U"].dtype == np.float32
223-
else:
224-
assert da["U"].dtype == np.float64
225-
226-
227195
@pytest.mark.parametrize("indslon", [range(10, 20), [1]])
228196
@pytest.mark.parametrize("indslat", [range(30, 60), [22]])
229197
def test_fieldset_from_file_subsets(indslon, indslat, tmpdir):

0 commit comments

Comments
 (0)