From 75a159e2bca604e7f204cf6f0c34e1388c54f59a Mon Sep 17 00:00:00 2001
From: Trevor James Smith <10819524+Zeitsperre@users.noreply.github.com>
Date: Mon, 5 Aug 2024 16:23:08 -0400
Subject: [PATCH] add needed functions for generate_atmos - WIP

---
 src/xsdba/testing.py |  41 ++++++++
 src/xsdba/utils.py   | 217 ++++++++++++++++++++++++++++++++++++++++++-
 tests/conftest.py    |  24 ++---
 tests/test_units.py  |   1 +
 4 files changed, 271 insertions(+), 12 deletions(-)

diff --git a/src/xsdba/testing.py b/src/xsdba/testing.py
index 18fc108..2979cd2 100644
--- a/src/xsdba/testing.py
+++ b/src/xsdba/testing.py
@@ -19,6 +19,7 @@ from scipy.stats import gamma
 from xarray import open_dataset as _open_dataset
 
+from xsdba.calendar import percentile_doy
 from xsdba.utils import equally_spaced_nodes
 
 __all__ = ["test_timelonlatseries", "test_timeseries"]
@@ -343,3 +344,43 @@ def nancov(X):
     """Drop observations with NaNs from Numpy's cov."""
     X_na = np.isnan(X).any(axis=0)
     return np.cov(X[:, ~X_na])
+
+
+# XC
+def generate_atmos(cache_dir: Path) -> dict[str, xr.DataArray]:
+    """Create the `atmosds` synthetic testing dataset."""
+    with open_dataset(
+        "ERA5/daily_surface_cancities_1990-1993.nc",
+        cache_dir=cache_dir,
+        branch=TESTDATA_BRANCH,
+        engine="h5netcdf",
+    ) as ds:
+        tn10 = percentile_doy(ds.tasmin, per=10)
+        t10 = percentile_doy(ds.tas, per=10)
+        t90 = percentile_doy(ds.tas, per=90)
+        tx90 = percentile_doy(ds.tasmax, per=90)
+
+        # rsus = shortwave_upwelling_radiation_from_net_downwelling(ds.rss, ds.rsds)
+        # rlus = longwave_upwelling_radiation_from_net_downwelling(ds.rls, ds.rlds)
+
+        ds = ds.assign(
+            # rsus=rsus,
+            # rlus=rlus,
+            tn10=tn10,
+            t10=t10,
+            t90=t90,
+            tx90=tx90,
+        )
+
+        # Create a file in the session-scoped temporary directory
+        atmos_file = cache_dir.joinpath("atmosds.nc")
+        ds.to_netcdf(atmos_file, engine="h5netcdf")
+
+    # Give access to the dataset variables by name in the namespace
+    namespace = dict()
+    with open_dataset(
+        atmos_file, branch=TESTDATA_BRANCH, cache_dir=cache_dir, engine="h5netcdf"
+    ) as ds:
+        for variable in ds.data_vars:
+            namespace[f"{variable}_dataset"] = ds.get(variable)
+    return namespace
diff --git a/src/xsdba/utils.py b/src/xsdba/utils.py
index 58f588c..bd5303a 100644
--- a/src/xsdba/utils.py
+++ b/src/xsdba/utils.py
@@ -19,7 +19,7 @@
 from .base import Grouper, ensure_chunk_size, parse_group, uses_dask
 from .calendar import ensure_longest_doy
-from .nbutils import _extrapolate_on_quantiles
+from .nbutils import _extrapolate_on_quantiles, _linear_interpolation
 
 MULTIPLICATIVE = "*"
 ADDITIVE = "+"
@@ -964,3 +964,218 @@ def load_module(path: os.PathLike, name: str | None = None):
     mod = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(mod)  # This executes code, effectively loading the module
     return mod
+
+
+# calc_perc and supporting functions needed for generate_atmos
+
+
+# XC
+def calc_perc(
+    arr: np.ndarray,
+    percentiles: Sequence[float] | None = None,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+    copy: bool = True,
+) -> np.ndarray:
+    """Compute percentiles using nan_calc_percentiles and move the percentiles' axis to the end."""
+    if percentiles is None:
+        _percentiles = [50.0]
+    else:
+        _percentiles = percentiles
+
+    return np.moveaxis(
+        nan_calc_percentiles(
+            arr=arr,
+            percentiles=_percentiles,
+            axis=-1,
+            alpha=alpha,
+            beta=beta,
+            copy=copy,
+        ),
+        source=0,
+        destination=-1,
+    )
+
+
+# XC
+def nan_calc_percentiles(
+    arr: np.ndarray,
+    percentiles: Sequence[float] | None = None,
+    axis: int = -1,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+    copy: bool = True,
+) -> np.ndarray:
+    """Convert the percentiles to quantiles and compute them using _nan_quantile."""
+    if percentiles is None:
+        _percentiles = [50.0]
+    else:
+        _percentiles = percentiles
+
+    if copy:
+        # Bootstrapping already works on a copy of the data;
+        # doing it again is extremely costly, especially with dask.
+        arr = arr.copy()
+    quantiles = np.array([per / 100.0 for per in _percentiles])
+    return _nan_quantile(arr, quantiles, axis, alpha, beta)
+
+
+# XC
+def _nan_quantile(
+    arr: np.ndarray,
+    quantiles: np.ndarray,
+    axis: int = 0,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> float | np.ndarray:
+    """Get the quantiles of the array for the given axis.
+
+    A linear interpolation is performed using alpha and beta.
+
+    Notes
+    -----
+    By default, alpha == beta == 1, which performs the 7th method of :cite:t:`hyndman_sample_1996`;
+    with alpha == beta == 1/3, we get the 8th method.
+    """
+    # --- Setup
+    data_axis_length = arr.shape[axis]
+    if data_axis_length == 0:
+        return np.nan
+    if data_axis_length == 1:
+        result = np.take(arr, 0, axis=axis)
+        return np.broadcast_to(result, (quantiles.size,) + result.shape)
+    # The dimensions of `q` are prepended to the output shape, so we need the
+    # axis being sampled from `arr` to be last.
+    DATA_AXIS = 0
+    if axis != DATA_AXIS:  # But moveaxis is slow, so only call it if axis != 0.
+        arr = np.moveaxis(arr, axis, destination=DATA_AXIS)
+    # nan_count is not a scalar
+    nan_count = np.isnan(arr).sum(axis=DATA_AXIS).astype(float)
+    valid_values_count = data_axis_length - nan_count
+    # We need at least two values to do an interpolation
+    too_few_values = valid_values_count < 2
+    if too_few_values.any():
+        # This will result in getting the only available value if it exists
+        valid_values_count[too_few_values] = np.nan
+    # --- Computation of indexes
+    # Add axis for quantiles
+    valid_values_count = valid_values_count[..., np.newaxis]
+    virtual_indexes = _compute_virtual_index(valid_values_count, quantiles, alpha, beta)
+    virtual_indexes = np.asanyarray(virtual_indexes)
+    previous_indexes, next_indexes = _get_indexes(
+        arr, virtual_indexes, valid_values_count
+    )
+    # --- Sorting
+    arr.sort(axis=DATA_AXIS)
+    # --- Get values from indexes
+    arr = arr[..., np.newaxis]
+    previous = np.squeeze(
+        np.take_along_axis(arr, previous_indexes.astype(int)[np.newaxis, ...], axis=0),
+        axis=0,
+    )
+    next_elements = np.squeeze(
+        np.take_along_axis(arr, next_indexes.astype(int)[np.newaxis, ...], axis=0),
+        axis=0,
+    )
+    # --- Linear interpolation
+    gamma = _get_gamma(virtual_indexes, previous_indexes)
+    interpolation = _linear_interpolation(previous, next_elements, gamma)
+    # When an interpolation is in the NaN range (near the end of the sorted array),
+    # it means we can clip to the array's max value.
+    result = np.where(np.isnan(interpolation), np.nanmax(arr, axis=0), interpolation)
+    # Move the quantile axis to the front
+    result = np.moveaxis(result, axis, 0)
+    return result
+
+
+# XC
+def _get_gamma(virtual_indexes: np.ndarray, previous_indexes: np.ndarray):
+    """Compute gamma (AKA 'm' or 'weight') for the linear interpolation of quantiles.
+
+    Parameters
+    ----------
+    virtual_indexes : array_like
+        The indexes where the percentile is supposed to be found in the sorted sample.
+    previous_indexes : array_like
+        The floor values of virtual_indexes.
+
+    Notes
+    -----
+    `gamma` is usually the fractional part of virtual_indexes but can be modified by the interpolation method.
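+
+    Examples
+    --------
+    A hypothetical illustration (values invented for this docstring, not drawn
+    from the library's tests); with the default linear interpolation, gamma is
+    simply the fractional part of the virtual index:
+
+    >>> _get_gamma(np.array([2.25]), np.array([2.0]))
+    array([0.25])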
+    """
+    gamma = np.asanyarray(virtual_indexes - previous_indexes)
+    return np.asanyarray(gamma)
+
+
+# XC
+def _compute_virtual_index(
+    n: np.ndarray, quantiles: np.ndarray, alpha: float, beta: float
+):
+    """Compute the floating point indexes of an array for the linear interpolation of quantiles.
+
+    Based on the approach used by :cite:t:`hyndman_sample_1996`.
+
+    Parameters
+    ----------
+    n : array_like
+        The sample sizes.
+    quantiles : array_like
+        The quantiles values.
+    alpha : float
+        A constant used to correct the computed index.
+    beta : float
+        A constant used to correct the computed index.
+
+    Notes
+    -----
+    The `alpha` and `beta` values depend on the chosen method (see the quantile documentation).
+
+    References
+    ----------
+    :cite:cts:`hyndman_sample_1996`
+    """
+    return n * quantiles + (alpha + quantiles * (1 - alpha - beta)) - 1
+
+
+# XC
+def _get_indexes(
+    arr: np.ndarray, virtual_indexes: np.ndarray, valid_values_count: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """Get the valid indexes of arr neighbouring virtual_indexes.
+
+    Parameters
+    ----------
+    arr : array-like
+    virtual_indexes : array-like
+    valid_values_count : array-like
+
+    Returns
+    -------
+    array-like, array-like
+        A tuple of the previous and next indexes neighbouring virtual_indexes.
+
+    Notes
+    -----
+    This is a companion function to the linear interpolation of quantiles.
+    """
+    previous_indexes = np.asanyarray(np.floor(virtual_indexes))
+    next_indexes = np.asanyarray(previous_indexes + 1)
+    indexes_above_bounds = virtual_indexes >= valid_values_count - 1
+    # When indexes are above the max index, take the max value of the array
+    if indexes_above_bounds.any():
+        previous_indexes[indexes_above_bounds] = -1
+        next_indexes[indexes_above_bounds] = -1
+    # When indexes are below the min index, take the min value of the array
+    indexes_below_bounds = virtual_indexes < 0
+    if indexes_below_bounds.any():
+        previous_indexes[indexes_below_bounds] = 0
+        next_indexes[indexes_below_bounds] = 0
+    if np.issubdtype(arr.dtype, np.inexact):
+        # After the sort, slices containing NaNs will have a NaN as their last element
+        virtual_indexes_nans = np.isnan(virtual_indexes)
+        if virtual_indexes_nans.any():
+            previous_indexes[virtual_indexes_nans] = -1
+            next_indexes[virtual_indexes_nans] = -1
+    previous_indexes = previous_indexes.astype(np.intp)
+    next_indexes = next_indexes.astype(np.intp)
+    return previous_indexes, next_indexes
diff --git a/tests/conftest.py b/tests/conftest.py
index e67c721..28b9f3e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -20,7 +20,7 @@
 # from filelock import FileLock
 from packaging.version import Version
 
-from xsdba.testing import TESTDATA_BRANCH
+from xsdba.testing import TESTDATA_BRANCH, generate_atmos
 from xsdba.testing import open_dataset as _open_dataset
 from xsdba.testing import (
     test_cannon_2015_dist,
@@ -317,17 +317,20 @@ def atmosds(threadsafe_data_dir) -> xr.Dataset:
 #     )
 
 
-# @pytest.fixture(scope="session", autouse=True)
-# def gather_session_data(threadsafe_data_dir, worker_id, xdoctest_namespace):
-#     """Gather testing data on pytest run.
+@pytest.fixture(scope="session", autouse=True)
+def gather_session_data(threadsafe_data_dir):
+    """Gather testing data on pytest run.
+
+    When running pytest with multiple workers, one worker will copy data remotely to _default_cache_dir while
+    other workers wait using a lockfile. Once the lock is released, all workers will then copy data to their local
+    threadsafe_data_dir. As this fixture is scoped to the session, it will only run once per pytest run.
+
+    Additionally, this fixture is used to generate the `atmosds` synthetic testing dataset and to add the
+    example file paths to the xdoctest_namespace, used when running doctests.
+    """
+    generate_atmos(threadsafe_data_dir)
 
-
 #     if (
 #         not _default_cache_dir.joinpath(helpers.TESTDATA_BRANCH).exists()
 #         or helpers.PREFETCH_TESTING_DATA
 #     ):
 #         if sys.platform == "win32":
 #             raise OSError(
 #                 "UNIX-style file-locking is not supported on Windows."
 #             )
 #         else:
 #             if lockfile.exists():
 #                 lockfile.unlink()
 #             shutil.copytree(_default_cache_dir, threadsafe_data_dir)
-#             helpers.generate_atmos(threadsafe_data_dir)
 #     xdoctest_namespace.update(helpers.add_example_file_paths(threadsafe_data_dir))
 
 
diff --git a/tests/test_units.py b/tests/test_units.py
index 144ff82..271c7d4 100644
--- a/tests/test_units.py
+++ b/tests/test_units.py
@@ -7,6 +7,7 @@
 import pytest
 import xarray as xr
 from dask import array as dsk
+from packaging.version import Version
 
 from xsdba.logging import ValidationError
 from xsdba.typing import Quantified
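Note on the quantile helpers added above: `alpha` and `beta` select among the Hyndman & Fan (1996) quantile estimators, and the default `alpha == beta == 1` (method 7) matches NumPy's default "linear" quantile interpolation. A minimal standalone sketch of that equivalence follows; it is not part of the patch, and `virtual_index` is a hypothetical local stand-in for the patch's `_compute_virtual_index` (same formula), assuming only NumPy:

    import numpy as np

    # Same formula as _compute_virtual_index in the patch above.
    def virtual_index(n: int, q: float, alpha: float = 1.0, beta: float = 1.0) -> float:
        return n * q + (alpha + q * (1 - alpha - beta)) - 1

    data = np.sort(np.array([1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 9.0]))
    q = 0.4

    # With alpha == beta == 1 (method 7), the virtual index reduces to q * (n - 1).
    idx = virtual_index(data.size, q)
    previous = int(np.floor(idx))
    gamma = idx - previous  # fractional part, as computed by _get_gamma
    interpolated = data[previous] + gamma * (data[previous + 1] - data[previous])

    # NumPy's default "linear" method is Hyndman & Fan method 7.
    assert np.isclose(interpolated, np.quantile(data, q))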