diff --git a/parcels/field.py b/parcels/field.py
index 516da65b61..cc2f064645 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -43,7 +43,6 @@
 from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear
 from .fieldfilebuffer import (
-    DeferredNetcdfFileBuffer,
     NetcdfFileBuffer,
 )
 from .grid import Grid, GridType
@@ -281,7 +280,6 @@ def __init__(
         self._dimensions = kwargs.pop("dimensions", None)
         self.indices = kwargs.pop("indices", None)
         self._dataFiles = kwargs.pop("dataFiles", None)
-        self._field_fb_class = kwargs.pop("FieldFileBuffer", None)
         self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
         self._creation_log = kwargs.pop("creation_log", "")
         self.grid.depth_field = kwargs.pop("depth_field", None)
@@ -361,9 +359,7 @@ def _get_dim_filenames(cls, filenames, dim):
         return filenames
 
     @staticmethod
-    def _collect_timeslices(
-        timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None
-    ):
+    def _collect_timeslices(timestamps, data_filenames, dimensions, indices, netcdf_engine, netcdf_decodewarning=None):
         if netcdf_decodewarning is not None:
             _deprecated_param_netcdf_decodewarning()
         if timestamps is not None:
@@ -378,7 +374,7 @@ def _collect_timeslices(
         timeslices = []
         dataFiles = []
         for fname in data_filenames:
-            with _grid_fb_class(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer:
+            with NetcdfFileBuffer(fname, dimensions, indices, netcdf_engine=netcdf_engine) as filebuffer:
                 ftime = filebuffer.time
                 timeslices.append(ftime)
                 dataFiles.append([fname] * len(ftime))
@@ -541,10 +537,8 @@ def from_netcdf(
         else:
             raise RuntimeError(f"interp_method is a dictionary but {variable[0]} is not in it")
 
-        _grid_fb_class = NetcdfFileBuffer
-
         if "lon" in dimensions and "lat" in dimensions:
-            with _grid_fb_class(
+            with NetcdfFileBuffer(
                 lonlat_filename,
                 dimensions,
                 indices,
@@ -562,7 +556,7 @@ def from_netcdf(
             mesh = "flat"
 
         if "depth" in dimensions:
-            with _grid_fb_class(
+            with NetcdfFileBuffer(
                 depth_filename,
                 dimensions,
                 indices,
@@ -570,7 +564,7 @@ def from_netcdf(
                 interp_method=interp_method,
                 gridindexingtype=gridindexingtype,
             ) as filebuffer:
-                filebuffer.name = filebuffer.parse_name(variable[1])
+                filebuffer.name = variable[1]
                 if dimensions["depth"] == "not_yet_set":
                     depth = filebuffer.depth_dimensions
                     kwargs["depth_field"] = "not_yet_set"
@@ -592,7 +586,7 @@ def from_netcdf(
         # across multiple files
         if "time" in dimensions or timestamps is not None:
             time, time_origin, timeslices, dataFiles = cls._collect_timeslices(
-                timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
+                timestamps, data_filenames, dimensions, indices, netcdf_engine
             )
             grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
             grid.timeslices = timeslices
@@ -604,9 +598,7 @@ def from_netcdf(
         elif grid is not None and ("dataFiles" not in kwargs or kwargs["dataFiles"] is None):
             # ==== means: the field has a shared grid, but may have different data files, so we need to collect the
             # ==== correct file time series again.
-            _, _, _, dataFiles = cls._collect_timeslices(
-                timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
-            )
+            _, _, _, dataFiles = cls._collect_timeslices(timestamps, data_filenames, dimensions, indices, netcdf_engine)
             kwargs["dataFiles"] = dataFiles
 
         if "time" in indices:
@@ -617,19 +609,12 @@ def from_netcdf(
         if grid.time.size <= 2:
             deferred_load = False
 
-        _field_fb_class: type[DeferredNetcdfFileBuffer | NetcdfFileBuffer]
-        if deferred_load:
-            _field_fb_class = DeferredNetcdfFileBuffer
-        else:
-            _field_fb_class = NetcdfFileBuffer
-        kwargs["FieldFileBuffer"] = _field_fb_class
-
         if not deferred_load:
             # Pre-allocate data before reading files into buffer
             data_list = []
             ti = 0
             for tslice, fname in zip(grid.timeslices, data_filenames, strict=True):
-                with _field_fb_class(  # type: ignore[operator]
+                with NetcdfFileBuffer(
                     fname,
                     dimensions,
                     indices,
@@ -639,7 +624,7 @@ def from_netcdf(
                 ) as filebuffer:
                     # If Field.from_netcdf is called directly, it may not have a 'data' dimension
                     # In that case, assume that 'name' is the data dimension
-                    filebuffer.name = filebuffer.parse_name(variable[1])
+                    filebuffer.name = variable[1]
                     buffer_data = filebuffer.data
                     if len(buffer_data.shape) == 4:
                         errormessage = (
@@ -1080,7 +1065,7 @@ def computeTimeChunk(self, data, tindex):
                 ti = g._ti + tindex
             timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]
 
-        filebuffer = self._field_fb_class(
+        filebuffer = NetcdfFileBuffer(
             self._dataFiles[g._ti + tindex],
             self.dimensions,
             self.indices,
@@ -1095,7 +1080,7 @@ def computeTimeChunk(self, data, tindex):
             time_data = g.time_origin.reltime(time_data)
         filebuffer.ti = (time_data <= g.time[tindex]).argmin() - 1
         if self.netcdf_engine != "xarray":
-            filebuffer.name = filebuffer.parse_name(self.filebuffername)
+            filebuffer.name = self.filebuffername
         buffer_data = filebuffer.data
         if len(buffer_data.shape) == 2:
             buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py
index 9c47890294..2bfa4b5a31 100644
--- a/parcels/fieldfilebuffer.py
+++ b/parcels/fieldfilebuffer.py
@@ -1,5 +1,6 @@
 import datetime
 import warnings
+from pathlib import Path
 
 import numpy as np
 import xarray as xr
@@ -9,7 +10,7 @@
 from parcels.tools.warnings import FileWarning
 
 
-class _FileBuffer:
+class NetcdfFileBuffer:
     def __init__(
         self,
         filename,
@@ -36,34 +37,10 @@ def __init__(
             self.nolonlatindices = False
         else:
             self.nolonlatindices = True
-
-
-class NetcdfFileBuffer(_FileBuffer):
-    def __init__(self, *args, **kwargs):
-        self.lib = np
         self.netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
-        super().__init__(*args, **kwargs)
 
     def __enter__(self):
-        try:
-            # Unfortunately we need to do if-else here, cause the lock-parameter is either False or a Lock-object
-            # (which we would rather want to have being auto-managed).
-            # If 'lock' is not specified, the Lock-object is auto-created and managed by xarray internally.
-            self.dataset = xr.open_dataset(str(self.filename), decode_cf=True, engine=self.netcdf_engine)
-            self.dataset["decoded"] = True
-        except:
-            warnings.warn(
-                f"File {self.filename} could not be decoded properly by xarray (version {xr.__version__}). "
-                "It will be opened with no decoding. Filling values might be wrongly parsed.",
-                FileWarning,
-                stacklevel=2,
-            )
-
-            self.dataset = xr.open_dataset(str(self.filename), decode_cf=False, engine=self.netcdf_engine)
-            self.dataset["decoded"] = False
-        for inds in self.indices.values():
-            if type(inds) not in [list, range]:
-                raise RuntimeError("Indices for field subsetting need to be a list")
+        self.dataset = open_xarray_dataset(self.filename, self.netcdf_engine)
         return self
 
     def __exit__(self, type, value, traceback):
@@ -74,16 +51,6 @@ def close(self):
         self.dataset.close()
         self.dataset = None
 
-    def parse_name(self, name):
-        if isinstance(name, list):
-            for nm in name:
-                if hasattr(self.dataset, nm):
-                    name = nm
-                    break
-        if isinstance(name, list):
-            raise OSError("None of variables in list found in file")
-        return name
-
     @property
     def latlon(self):
         lon = self.dataset[self.dimensions["lon"]]
@@ -253,6 +220,21 @@ def time_access(self):
         return time
 
 
-class DeferredNetcdfFileBuffer(NetcdfFileBuffer):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+def open_xarray_dataset(filename: Path | str, netcdf_engine: str) -> xr.Dataset:
+    try:
+        # First try to open the file with CF decoding enabled. If xarray cannot
+        # decode the file (e.g. because of malformed attributes or time units),
+        # fall back to opening it without decoding in the except branch below.
+        ds = xr.open_dataset(filename, decode_cf=True, engine=netcdf_engine)
+        ds["decoded"] = True
+    except Exception:
+        warnings.warn(  # TODO: Check which failures actually reach this fallback and whether the warning is still needed.
+            f"File {filename} could not be decoded properly by xarray (version {xr.__version__}). "
+            "It will be opened with no decoding. Fill values might be wrongly parsed.",
+            FileWarning,
+            stacklevel=2,
+        )
+
+        ds = xr.open_dataset(filename, decode_cf=False, engine=netcdf_engine)
+        ds["decoded"] = False
+    return ds
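With `DeferredNetcdfFileBuffer` and the `_grid_fb_class`/`_field_fb_class` indirection removed, every call site now constructs `NetcdfFileBuffer` directly. A minimal sketch of the resulting call pattern, the same shape used in `_collect_timeslices` and `from_netcdf` above (the file name, dimension mapping, and variable name here are hypothetical, chosen only for illustration):

```python
from parcels.fieldfilebuffer import NetcdfFileBuffer

# Hypothetical inputs of the kind Field.from_netcdf passes through.
dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"}
indices = {}

with NetcdfFileBuffer("ocean_U.nc", dimensions, indices, netcdf_engine="netcdf4") as filebuffer:
    # parse_name() is gone: the variable name is assigned directly, so
    # list-valued variable names are no longer resolved against the file.
    filebuffer.name = "uo"
    data = filebuffer.data
```

Note that dropping `parse_name` is a small behavioural change: callers that previously passed a list of candidate variable names must now pass the exact name of the variable in the file.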
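The decode-and-fallback logic now lives in the module-level `open_xarray_dataset` helper, which records whether CF decoding succeeded in the dataset's `decoded` variable. A sketch of how a caller might use that flag, under the assumption that the flag is the intended way to detect the fallback path (the file name is again hypothetical):

```python
from parcels.fieldfilebuffer import open_xarray_dataset

ds = open_xarray_dataset("ocean_U.nc", "netcdf4")
if not bool(ds["decoded"]):
    # Fallback path: the file was opened with decode_cf=False, so CF
    # attributes such as scale_factor, add_offset and _FillValue have not
    # been applied and raw on-disk values are returned.
    ...
```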