parcels/field.py: 37 changes (11 additions, 26 deletions)
@@ -43,7 +43,6 @@

 from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear
 from .fieldfilebuffer import (
-    DeferredNetcdfFileBuffer,
     NetcdfFileBuffer,
 )
 from .grid import Grid, GridType
@@ -281,7 +280,6 @@
         self._dimensions = kwargs.pop("dimensions", None)
         self.indices = kwargs.pop("indices", None)
         self._dataFiles = kwargs.pop("dataFiles", None)
-        self._field_fb_class = kwargs.pop("FieldFileBuffer", None)
         self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
         self._creation_log = kwargs.pop("creation_log", "")
         self.grid.depth_field = kwargs.pop("depth_field", None)
@@ -361,9 +359,7 @@
         return filenames

     @staticmethod
-    def _collect_timeslices(
-        timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None
-    ):
+    def _collect_timeslices(timestamps, data_filenames, dimensions, indices, netcdf_engine, netcdf_decodewarning=None):
         if netcdf_decodewarning is not None:
             _deprecated_param_netcdf_decodewarning()
         if timestamps is not None:
@@ -378,7 +374,7 @@
         timeslices = []
         dataFiles = []
         for fname in data_filenames:
-            with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer:
+            with NetcdfFileBuffer(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer:
                 ftime = filebuffer.time
                 timeslices.append(ftime)
                 dataFiles.append([fname] * len(ftime))
@@ -541,10 +537,8 @@
         else:
             raise RuntimeError(f"interp_method is a dictionary but {variable[0]} is not in it")

-        _grid_fb_class = NetcdfFileBuffer
-
         if "lon" in dimensions and "lat" in dimensions:
-            with _grid_fb_class(
+            with NetcdfFileBuffer(
                 lonlat_filename,
                 dimensions,
                 indices,
@@ -562,15 +556,15 @@
                 mesh = "flat"

         if "depth" in dimensions:
-            with _grid_fb_class(
+            with NetcdfFileBuffer(
                 depth_filename,
                 dimensions,
                 indices,
                 netcdf_engine,
                 interp_method=interp_method,
                 gridindexingtype=gridindexingtype,
             ) as filebuffer:
-                filebuffer.name = filebuffer.parse_name(variable[1])
+                filebuffer.name = variable[1]
                 if dimensions["depth"] == "not_yet_set":
                     depth = filebuffer.depth_dimensions
                     kwargs["depth_field"] = "not_yet_set"
@@ -592,7 +586,7 @@
         # across multiple files
         if "time" in dimensions or timestamps is not None:
             time, time_origin, timeslices, dataFiles = cls._collect_timeslices(
-                timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
+                timestamps, data_filenames, dimensions, indices, netcdf_engine

Check warning on line 589 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L589

Added line #L589 was not covered by tests
             )
             grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
             grid.timeslices = timeslices
@@ -604,9 +598,7 @@
         elif grid is not None and ("dataFiles" not in kwargs or kwargs["dataFiles"] is None):
             # ==== means: the field has a shared grid, but may have different data files, so we need to collect the
             # ==== correct file time series again.
-            _, _, _, dataFiles = cls._collect_timeslices(
-                timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
-            )
+            _, _, _, dataFiles = cls._collect_timeslices(timestamps, data_filenames, dimensions, indices, netcdf_engine)
             kwargs["dataFiles"] = dataFiles

         if "time" in indices:
@@ -617,19 +609,12 @@
         if grid.time.size <= 2:
             deferred_load = False

-        _field_fb_class: type[DeferredNetcdfFileBuffer | NetcdfFileBuffer]
-        if deferred_load:
-            _field_fb_class = DeferredNetcdfFileBuffer
-        else:
-            _field_fb_class = NetcdfFileBuffer
-        kwargs["FieldFileBuffer"] = _field_fb_class
-
         if not deferred_load:
             # Pre-allocate data before reading files into buffer
             data_list = []
             ti = 0
             for tslice, fname in zip(grid.timeslices, data_filenames, strict=True):
-                with _field_fb_class(  # type: ignore[operator]
+                with NetcdfFileBuffer(  # type: ignore[operator]
                     fname,
                     dimensions,
                     indices,
@@ -639,7 +624,7 @@
                 ) as filebuffer:
                     # If Field.from_netcdf is called directly, it may not have a 'data' dimension
                     # In that case, assume that 'name' is the data dimension
-                    filebuffer.name = filebuffer.parse_name(variable[1])
+                    filebuffer.name = variable[1]
                     buffer_data = filebuffer.data
                     if len(buffer_data.shape) == 4:
                         errormessage = (
@@ -1080,7 +1065,7 @@
             ti = g._ti + tindex
             timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]

-        filebuffer = self._field_fb_class(
+        filebuffer = NetcdfFileBuffer(
             self._dataFiles[g._ti + tindex],
             self.dimensions,
             self.indices,
@@ -1095,7 +1080,7 @@
             time_data = g.time_origin.reltime(time_data)
         filebuffer.ti = (time_data <= g.time[tindex]).argmin() - 1
         if self.netcdf_engine != "xarray":
-            filebuffer.name = filebuffer.parse_name(self.filebuffername)
+            filebuffer.name = self.filebuffername
         buffer_data = filebuffer.data
         if len(buffer_data.shape) == 2:
             buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
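Note: the net effect of the field.py changes above is that the _grid_fb_class / _field_fb_class indirection (and the FieldFileBuffer kwarg that carried it) is gone; every call site constructs NetcdfFileBuffer directly and assigns the variable name without parse_name(). A minimal sketch of the resulting call pattern (the file path, dimension mapping, and variable name below are hypothetical, not taken from this PR):

# Sketch only: "example_data.nc", the dimension names, and "uo" are made up for illustration.
from parcels.fieldfilebuffer import NetcdfFileBuffer

dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"}
with NetcdfFileBuffer("example_data.nc", dimensions, indices={}, netcdf_engine="netcdf4") as filebuffer:
    filebuffer.name = "uo"  # plain attribute assignment; the parse_name() indirection is removed
    data = filebuffer.data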
parcels/fieldfilebuffer.py: 60 changes (21 additions, 39 deletions)
@@ -1,5 +1,6 @@
 import datetime
 import warnings
+from pathlib import Path

 import numpy as np
 import xarray as xr
@@ -9,7 +10,7 @@
 from parcels.tools.warnings import FileWarning


-class _FileBuffer:
+class NetcdfFileBuffer:
     def __init__(
         self,
         filename,
@@ -36,34 +37,10 @@ def __init__(
             self.nolonlatindices = False
         else:
             self.nolonlatindices = True
-
-
-class NetcdfFileBuffer(_FileBuffer):
-    def __init__(self, *args, **kwargs):
         self.lib = np
         self.netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
-        super().__init__(*args, **kwargs)

     def __enter__(self):
-        try:
-            # Unfortunately we need to do if-else here, cause the lock-parameter is either False or a Lock-object
-            # (which we would rather want to have being auto-managed).
-            # If 'lock' is not specified, the Lock-object is auto-created and managed by xarray internally.
-            self.dataset = xr.open_dataset(str(self.filename), decode_cf=True, engine=self.netcdf_engine)
-            self.dataset["decoded"] = True
-        except:
-            warnings.warn(
-                f"File {self.filename} could not be decoded properly by xarray (version {xr.__version__}). "
-                "It will be opened with no decoding. Filling values might be wrongly parsed.",
-                FileWarning,
-                stacklevel=2,
-            )
-
-            self.dataset = xr.open_dataset(str(self.filename), decode_cf=False, engine=self.netcdf_engine)
-            self.dataset["decoded"] = False
         for inds in self.indices.values():
             if type(inds) not in [list, range]:
                 raise RuntimeError("Indices for field subsetting need to be a list")
+        self.dataset = open_xarray_dataset(self.filename, self.netcdf_engine)
         return self

     def __exit__(self, type, value, traceback):
@@ -74,16 +51,6 @@ def close(self):
         self.dataset.close()
         self.dataset = None

-    def parse_name(self, name):
-        if isinstance(name, list):
-            for nm in name:
-                if hasattr(self.dataset, nm):
-                    name = nm
-                    break
-        if isinstance(name, list):
-            raise OSError("None of variables in list found in file")
-        return name
-
     @property
     def latlon(self):
         lon = self.dataset[self.dimensions["lon"]]
@@ -253,6 +220,21 @@ def time_access(self):
         return time


-class DeferredNetcdfFileBuffer(NetcdfFileBuffer):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+def open_xarray_dataset(filename: Path | str, netcdf_engine: str) -> xr.Dataset:
+    try:
+        # Unfortunately we need to do if-else here, cause the lock-parameter is either False or a Lock-object
+        # (which we would rather want to have being auto-managed).
+        # If 'lock' is not specified, the Lock-object is auto-created and managed by xarray internally.
+        ds = xr.open_dataset(filename, decode_cf=True, engine=netcdf_engine)
+        ds["decoded"] = True
+    except:
+        warnings.warn(  # TODO: Is this warning necessary? What cases does this except block get triggered - is it to do with the bare except???
+            f"File {filename} could not be decoded properly by xarray (version {xr.__version__}). "
+            "It will be opened with no decoding. Filling values might be wrongly parsed.",
+            FileWarning,
+            stacklevel=2,
+        )
+
+        ds = xr.open_dataset(filename, decode_cf=False, engine=netcdf_engine)
+        ds["decoded"] = False
+    return ds