diff --git a/HISTORY.rst b/HISTORY.rst
index adff47a2..7016b896 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -9,6 +9,7 @@ Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Juliette
 Announcements
 ^^^^^^^^^^^^^
 * `xscen` is now offered as a conda package available through Anaconda.org. Refer to the installation documentation for more information. (:issue:`149`, :pull:`171`).
+* Deprecation: Release 0.6.0 of `xscen` will be the last version to support ``xscen.extract.clisops_subset``. (:pull:`182`).

 New features and enhancements
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -20,6 +21,8 @@ New features and enhancements
 * Allow passing ``GeoDataFrame`` instances in ``spatial_mean``'s ``region`` argument, not only geospatial file paths. (:pull:`174`).
 * Allow searching for periods in `catalog.search`. (:issue:`123`, :pull:`170`).
 * Allow searching and extracting multiple frequencies for a given variable. (:issue:`168`, :pull:`170`).
+* New masking feature in ``extract_dataset``. (:issue:`180`, :pull:`182`).
+* New function ``xs.spatial.subset`` to replace ``xs.extract.clisops_subset``, adding a "sel" method. (:issue:`180`, :pull:`182`).

 Breaking changes
 ^^^^^^^^^^^^^^^^
diff --git a/xscen/extract.py b/xscen/extract.py
index 6db32e77..f0dd1f8d 100644
--- a/xscen/extract.py
+++ b/xscen/extract.py
@@ -5,18 +5,14 @@
 import re
 import warnings
 from collections import defaultdict
-from copy import deepcopy
 from pathlib import Path
 from typing import Callable, List, Optional, Union

-import clisops.core.subset
-import dask
 import numpy as np
 import pandas as pd
 import xarray as xr
 import xclim as xc
 from intake_esm.derived import DerivedVariableRegistry
-from xclim.core.utils import uses_dask

 from .catalog import DataCatalog  # noqa
 from .catalog import (
@@ -28,6 +24,7 @@
 )
 from .config import parse_config
 from .indicators import load_xclim_module, registry_from_module
+from .spatial import subset
 from .utils import CV
 from .utils import ensure_correct_time as _ensure_correct_time
 from .utils import natural_sort
@@ -73,73 +70,13 @@ def clisops_subset(ds: xr.Dataset, region: dict) -> xr.Dataset:
     --------
     clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape
     """
-    if uses_dask(ds.lon) or uses_dask(ds.lat):
-        warnings.warn("Loading longitude and latitude for more efficient subsetting.")
-        ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat)
-    if "buffer" in region.keys():
-        # estimate the model resolution
-        if len(ds.lon.dims) == 1:  # 1D lat-lon
-            lon_res = np.abs(ds.lon.diff("lon")[0].values)
-            lat_res = np.abs(ds.lat.diff("lat")[0].values)
-        else:
-            lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values)
-            lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values)
-
-    kwargs = deepcopy(region[region["method"]])
-
-    if region["method"] in ["gridpoint"]:
-        ds_subset = clisops.core.subset_gridpoint(ds, **kwargs)
-        new_history = (
-            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
-            f"{region['method']} spatial subsetting on {len(region['gridpoint']['lon'])} coordinates - clisops v{clisops.__version__}"
-        )
-
-    elif region["method"] in ["bbox"]:
-        if "buffer" in region.keys():
-            # adjust the boundaries
-            kwargs["lon_bnds"] = (
-                kwargs["lon_bnds"][0] - lon_res * region["buffer"],
-                kwargs["lon_bnds"][1] + lon_res * region["buffer"],
-            )
-            kwargs["lat_bnds"] = (
-                kwargs["lat_bnds"][0] - lat_res * region["buffer"],
-                kwargs["lat_bnds"][1] + lat_res * region["buffer"],
-            )
-
-        if xc.core.utils.uses_dask(ds.cf["longitude"]):
-            ds[ds.cf["longitude"].name].load()
-        if xc.core.utils.uses_dask(ds.cf["latitude"]):
-            ds[ds.cf["latitude"].name].load()
-
-        ds_subset = clisops.core.subset_bbox(ds, **kwargs)
-        new_history = (
-            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
-            f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
-            f", lon_bnds={np.array(region['bbox']['lon_bnds'])}, lat_bnds={np.array(region['bbox']['lat_bnds'])}"
-            f" - clisops v{clisops.__version__}"
-        )
-
-    elif region["method"] in ["shape"]:
-        if "buffer" in region.keys():
-            kwargs["buffer"] = np.max([lon_res, lat_res]) * region["buffer"]
-
-        ds_subset = clisops.core.subset_shape(ds, **kwargs)
-        new_history = (
-            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
-            f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
-            f", shape={Path(region['shape']['shape']).name if isinstance(region['shape']['shape'], (str, Path)) else 'gpd.GeoDataFrame'}"
-            f" - clisops v{clisops.__version__}"
-        )
-
-    else:
-        raise ValueError("Subsetting type not recognized")
-
-    history = (
-        new_history + " \n " + ds_subset.attrs["history"]
-        if "history" in ds_subset.attrs
-        else new_history
+    warnings.warn(
+        "clisops_subset is deprecated and will not be available in future versions. "
+        "Use xscen.spatial.subset instead.",
+        category=FutureWarning,
     )
-    ds_subset.attrs["history"] = history
+
+    ds_subset = subset(ds, region)

     return ds_subset

@@ -157,6 +94,7 @@ def extract_dataset(
     xr_combine_kwargs: dict = None,
     preprocess: Callable = None,
     resample_methods: Optional[dict] = None,
+    mask: Union[bool, xr.Dataset, xr.DataArray] = False,
 ) -> Union[dict, xr.Dataset]:
     """Take one element of the output of `search_data_catalogs` and returns a dataset, performing conversions and resampling as needed.

@@ -174,7 +112,7 @@
         [start, end] of the period to be evaluated (or a list of lists)
         Will be read from catalog._requested_periods if None. Leave both None to extract everything.
     region : dict, optional
-        Description of the region and the subsetting method (required fields listed in the Notes).
+        Description of the region and the subsetting method (required fields listed in the Notes), passed to `xscen.spatial.subset`.
     to_level : str
         The processing level to assign to the output.
         Defaults to 'extracted'
@@ -197,6 +135,12 @@
         If the method is not given for a variable, it is guessed from the variable name and frequency,
         using the mapping in CVs/resampling_methods.json. If the variable is not found there, "mean" is used by default.
+    mask : xr.Dataset, xr.DataArray, or bool
+        A mask that is applied to all variables, keeping data only where it is True.
+        Where the mask is False, variable values are replaced by NaNs.
+        The mask should have the same dimensions as the extracted variables.
+        If `mask` is a dataset, it should contain a variable named 'mask'.
+        If `mask` is True, a `mask` variable at xrfreq `fx` is expected to have been extracted.

     Returns
     -------
@@ -211,7 +155,7 @@
         name: str
             Region name used to overwrite domain in the catalog.
         method: str
-            ['gridpoint', 'bbox', shape']
+            ['gridpoint', 'bbox', 'shape', 'sel']
         : dict
             Arguments specific to the method used.
         buffer: float, optional
@@ -375,10 +319,9 @@ def extract_dataset(
             slices.extend([ds.sel({"time": slice(str(period[0]), str(period[1]))})])
         ds = xr.concat(slices, dim="time", **xr_combine_kwargs)

-        # Custom call to clisops
+        # subset to the region
         if region is not None:
-            ds = clisops_subset(ds, region)
-            ds.attrs["cat:domain"] = region["name"]
+            ds = subset(ds, region)

         # add relevant attrs
         ds.attrs["cat:processing_level"] = to_level
@@ -387,6 +330,26 @@

         out_dict[xrfreq] = ds

+    if mask:
+        if isinstance(mask, xr.Dataset):
+            ds_mask = mask["mask"]
+        elif isinstance(mask, xr.DataArray):
+            ds_mask = mask
+        elif (
+            "fx" in out_dict and "mask" in out_dict["fx"]
+        ):  # get mask that was extracted above
+            ds_mask = out_dict["fx"]["mask"].copy()
+        else:
+            raise ValueError(
+                "No mask found. Either pass an xr.Dataset/xr.DataArray to the `mask` argument or pass a `dc` that includes a dataset with a variable named `mask`."
+            )
+
+        # iterate over all xrfreq to apply the mask
+        for xrfreq, ds in out_dict.items():
+            out_dict[xrfreq] = ds.where(ds_mask)
+            if xrfreq == "fx":  # put back the mask
+                out_dict[xrfreq]["mask"] = ds_mask
+
     return out_dict

diff --git a/xscen/spatial.py b/xscen/spatial.py
index 78d57d98..f6626eec 100644
--- a/xscen/spatial.py
+++ b/xscen/spatial.py
@@ -1,9 +1,23 @@
 """Spatial tools."""
+import datetime
 import itertools
+import warnings
+from copy import deepcopy
+from pathlib import Path

+import clisops.core.subset
+import dask
 import numpy as np
 import sparse as sp
 import xarray as xr
+import xclim as xc
+from xclim.core.utils import uses_dask
+
+__all__ = [
+    "creep_weights",
+    "creep_fill",
+    "subset",
+]


 def creep_weights(mask, n=1, mode="clip"):
@@ -103,3 +117,122 @@ def _dot(arr, wei):
         dask="parallelized",
         output_dtypes=["float64"],
     )
+
+
+def subset(ds: xr.Dataset, region: dict) -> xr.Dataset:
+    """
+    Subset the data to a region.
+
+    Either creates a slice and uses the .sel() method, or customizes a call to
+    clisops.core.subset functions, which allows for an automatic buffer around the region.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Dataset to be subsetted.
+    region : dict
+        Description of the region and the subsetting method (required fields listed in the Notes).
+
+    Notes
+    -----
+    'region' fields:
+        name: str
+            Region name used to overwrite domain in the catalog.
+        method: str
+            ['gridpoint', 'bbox', 'shape', 'sel']
+            If the method is `sel`, this is not a call to clisops, but a simple subsetting with the xarray .sel() function.
+            The keys are the dimensions to subset and the values are turned into a slice.
+        : dict
+            Arguments specific to the method used.
+        buffer: float, optional
+            Multiplier to apply to the model resolution.
+
+    Returns
+    -------
+    xr.Dataset
+        Subsetted Dataset.
+ + See Also + -------- + clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape + """ + if uses_dask(ds.lon) or uses_dask(ds.lat): + warnings.warn("Loading longitude and latitude for more efficient subsetting.") + ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat) + if "buffer" in region.keys(): + # estimate the model resolution + if len(ds.lon.dims) == 1: # 1D lat-lon + lon_res = np.abs(ds.lon.diff("lon")[0].values) + lat_res = np.abs(ds.lat.diff("lat")[0].values) + else: + lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values) + lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values) + + kwargs = deepcopy(region[region["method"]]) + + if region["method"] in ["gridpoint"]: + ds_subset = clisops.core.subset_gridpoint(ds, **kwargs) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"{region['method']} spatial subsetting on {len(region['gridpoint']['lon'])} coordinates - clisops v{clisops.__version__}" + ) + + elif region["method"] in ["bbox"]: + if "buffer" in region.keys(): + # adjust the boundaries + kwargs["lon_bnds"] = ( + kwargs["lon_bnds"][0] - lon_res * region["buffer"], + kwargs["lon_bnds"][1] + lon_res * region["buffer"], + ) + kwargs["lat_bnds"] = ( + kwargs["lat_bnds"][0] - lat_res * region["buffer"], + kwargs["lat_bnds"][1] + lat_res * region["buffer"], + ) + + if xc.core.utils.uses_dask(ds.cf["longitude"]): + ds[ds.cf["longitude"].name].load() + if xc.core.utils.uses_dask(ds.cf["latitude"]): + ds[ds.cf["latitude"].name].load() + + ds_subset = clisops.core.subset_bbox(ds, **kwargs) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}" + f", lon_bnds={np.array(region['bbox']['lon_bnds'])}, lat_bnds={np.array(region['bbox']['lat_bnds'])}" + f" - clisops v{clisops.__version__}" + ) + + elif region["method"] in ["shape"]: + if "buffer" in region.keys(): + kwargs["buffer"] = np.max([lon_res, lat_res]) * region["buffer"] + + ds_subset = clisops.core.subset_shape(ds, **kwargs) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}" + f", shape={Path(region['shape']['shape']).name if isinstance(region['shape']['shape'], (str, Path)) else 'gpd.GeoDataFrame'}" + f" - clisops v{clisops.__version__}" + ) + + elif region["method"] in ["sel"]: + arg_sel = { + dim: slice(*map(float, bounds)) for dim, bounds in region["sel"].items() + } + ds_subset = ds.sel(**arg_sel) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"{region['method']} subsetting with arguments {arg_sel}" + ) + + else: + raise ValueError("Subsetting type not recognized") + + history = ( + new_history + " \n " + ds_subset.attrs["history"] + if "history" in ds_subset.attrs + else new_history + ) + ds_subset.attrs["history"] = history + ds_subset.attrs["cat:domain"] = region["name"] + + return ds_subset
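Below is a minimal usage sketch of the new ``xs.spatial.subset`` entry point with the added "sel" method; the toy dataset, coordinate values, and region name are illustrative assumptions and are not part of the changeset.

import numpy as np
import xarray as xr
import xscen as xs

# Toy dataset with 1D lat/lon coordinates (illustrative values only).
ds = xr.Dataset(
    data_vars={"tas": (("lat", "lon"), np.zeros((10, 20)))},
    coords={"lat": np.linspace(40, 60, 10), "lon": np.linspace(-80, -60, 20)},
)

# With method "sel", each entry of region["sel"] is turned into a slice and
# passed to ds.sel(); no call to clisops is made.
region = {
    "name": "example-domain",  # written to the "cat:domain" attribute
    "method": "sel",
    "sel": {"lon": [-75, -70], "lat": [45, 50]},
}
ds_sub = xs.spatial.subset(ds, region)

In ``extract_dataset``, the same ``region`` dictionary is forwarded to ``xs.spatial.subset``; the new ``mask`` argument applies ``ds.where(mask)`` to every extracted frequency, with ``mask=True`` reusing a ``mask`` variable extracted at xrfreq ``fx``.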